In [43]:
import torch

# ok so we're implementing a custom cumsum function
# we want it to be super numerically stable

# set seed to 42
torch.manual_seed(42)

input_data = torch.randn(1000000, dtype=torch.float64, device="cuda")
input_data_fp32 = input_data.float()

# let's start by checking torch fp64 cumsum. that'll be our ground truth.
fp64_cumsum_result = torch.cumsum(input_data, dim=-1)
fp64_sum_result = torch.sum(input_data, dim=-1)

# let's see how much it diverges from doing a non-recursive for loop. my past experiments suggested this is kinda unstable.
# fp64_slow_loop_cumsum = torch.zeros_like(input_data)
# for i in range(input_data.shape[-1]):
#     fp64_slow_loop_cumsum[..., i] = torch.sum(input_data[..., :i+1], dim=-1)
# print(f"Max difference between fp64 and fp64 loop cumsum:",(fp64_cumsum - fp64_slow_loop_cumsum).max())
# then let's compare to a loop that uses an accumulator.
fp64_fast_loop_cumsum = torch.zeros_like(input_data)
fp64_fast_loop_cumsum[..., 0] = input_data[..., 0]
for i in range(1, input_data.shape[-1]):
    fp64_fast_loop_cumsum[..., i] = fp64_fast_loop_cumsum[..., i-1] + input_data[..., i]
print(f"Max difference between fp64 and fp64 fast loop cumsum:",(fp64_cumsum_result - fp64_fast_loop_cumsum).max())

# then we'll compare to torch fp32 cumsum. past experiments say this is super stable and close to fp64 torch.cumsum.
fp32_cumsum_result = torch.cumsum(input_data_fp32, dim=-1)
print(f"Max difference between fp64 and fp32 cumsum:",(fp64_cumsum_result - fp32_cumsum_result).max())

fp32_sum_result = torch.sum(input_data_fp32, dim=-1)
print(f"Max difference between fp64 and fp32 sum:",(fp64_sum_result - fp32_sum_result).max())


# then we'll test our own stuff

# first a triton kernel that uses fp64. this should be as good as torch fp64 with loop.
from stability import triton_cumsum

fp64_triton_cumsum_result = triton_cumsum(input_data)
print(f"Max difference between fp64 and fp64 triton cumsum:",(fp64_cumsum_result - fp64_triton_cumsum_result).max())

fp64_triton_cumsum_result_fp32 = fp64_triton_cumsum_result.float()
print(f"Max difference between fp64 and fp64 triton cumsum (converted to fp32):",(fp64_cumsum_result - fp64_triton_cumsum_result_fp32).max())


# then a triton kernel to do parallel scan, I think. we'll do it in fp64 at first. it should hopefully match torch fp64 cumsum super closely.

# then we'll do the same thing in fp32.

fp32_triton_cumsum_result = triton_cumsum(input_data_fp32)
print("Max diff between fp64 and fp32 triton cumsum:",(fp64_cumsum_result - fp32_triton_cumsum_result).max())

# then we'll implement our custom protect_and_attack function in fp64.

# it should be super close to torch fp64 cumsum too - about as close as our fp64 parallel scan triton kernel.

# and as we finalize each of the functions / etc, we'll move it into the stability.py file (for brevity).

Max difference between fp64 and fp64 fast loop cumsum: tensor(2.3874e-11, device='cuda:0', dtype=torch.float64)
Max difference between fp64 and fp32 cumsum: tensor(0.0007, device='cuda:0', dtype=torch.float64)
Max difference between fp64 and fp32 sum: tensor(6.2322e-05, device='cuda:0', dtype=torch.float64)
Max difference between fp64 and fp64 triton cumsum: tensor(2.3874e-11, device='cuda:0', dtype=torch.float64)
Max difference between fp64 and fp64 triton cumsum (converted to fp32): tensor(6.1035e-05, device='cuda:0', dtype=torch.float64)
Max diff between fp64 and fp32 triton cumsum: tensor(0.0216, device='cuda:0', dtype=torch.float64)


In [11]:
# ok let's make smth simple. let's just do a super simple naive algorithm.
# so we'll just try to make a rly good sum(inputs) function that takes in a list of inputs and returns a single sum.
# it will use a binary tree to accumulate them. it'll fill it out to the nearest power of 2, then treat it like some binary tree (hrrm, let's say w/ narrowing the set of active nodes by just removing the nonzero bits)
# ah what's that thing with indexing? should it be one-indexed then? b/c then you can go from the first element (i=1) to its first child (i=2) by multiplying by 2.
# yes yes that's it.
# hrrm but do we rly *want* to be doubling, or whatever? maybe that's just some random artefact of min-heaps.
# nono. I think we want zero-indexed. b/c we're gonna accumulate everything into i=0.

import torch

def bliasson_sum(x: torch.Tensor):
    closest_power_of_2 = 2**(x.shape[-1].bit_length())
    x_big = torch.zeros(closest_power_of_2, dtype=x.dtype, device=x.device)
    x_big[:x.shape[-1]] = x

    # ok so now we're going to do several rounds of summing.
    # each round will have half as many active nodes as the previous one. the final round will have 1 active node. after that round, we'll just return the sum.
    # ok so we can compute this by just using a diff set of indices each time.

    indices = torch.arange(0, closest_power_of_2)

    while len(indices) > 1:
        assert len(indices) % 2 == 0
        next_indices = indices[:len(indices)//2] * 2
        x_big[next_indices] = x_big[indices[::2]] + x_big[indices[1::2]]
        indices = next_indices
    
    return x_big[0]

assert sum(torch.tensor([1])) == 1
assert sum(torch.tensor([1, 2])) == 3
assert sum(torch.tensor([1, 2, 3])) == 6
assert sum(torch.tensor([1, 2, 3, 4])) == 10
assert sum(torch.tensor([1, 2, 3, 4, 5])) == 15
assert sum(torch.tensor([1, 2, 3, 4, 5, 6])) == 21
assert sum(torch.tensor([1, 2, 3, 4, 5, 6, 7])) == 28

def bliasson_cumsum(x: torch.Tensor):
    return torch.tensor([bliasson_sum(x[:i+1]) for i in range(x.shape[-1])])

assert (bliasson_cumsum(torch.tensor([1, 2, 3, 4, 5, 6, 7])) == torch.tensor([1, 3, 6, 10, 15, 21, 28])).all()



# now let's check the diff versus torch fp64 cumsum.

bliasson_sum_result_fp64 = bliasson_sum(input_data)
print(f"Max diff between fp64 and bliasson sum:",(fp64_sum_result - bliasson_sum_result_fp64).max())

# now let's try it in fp32.

bliasson_sum_result_fp32 = bliasson_sum(input_data_fp32)
print(f"Max diff between fp64 and bliasson sum (converted to fp32):",(fp64_sum_result - bliasson_sum_result_fp32).max())


Max diff between fp64 and bliasson sum: tensor(4.5475e-13, device='cuda:0', dtype=torch.float64)
Max diff between fp64 and bliasson sum (converted to fp32): tensor(-0.0002, device='cuda:0', dtype=torch.float64)


In [64]:
import torch

def bliasson_cumsum(x: torch.Tensor, dim: int=-1):
    x = x.transpose(dim, x.ndim-1) if dim != -1 else x
    *rest, n = x.shape
    bit_length = n.bit_length()
    closest_power_of_2 = 2**bit_length
    x_big = torch.zeros((*rest, closest_power_of_2), dtype=x.dtype, device=x.device)
    x_big[...,:n] = x

    # ok so now we're going to do several rounds of summing.
    # each round will have half as many active nodes as the previous one. the final round will have 1 active node. after that round, we'll just return the sum.
    # ok so we can compute this by just using a diff set of indices each time.

    indices = torch.arange(0, closest_power_of_2)

    while len(indices) > 1:
        assert len(indices) % 2 == 0
        next_indices = indices[1::2]
        x_big[...,next_indices] = x_big[...,indices[::2]] + x_big[...,indices[1::2]]
        indices = next_indices
    
    # ok now we're going to propagate the info back down the tree, from top-down.

    for i in range(bit_length,1,-1):
        end_of_first_chunk = torch.arange(2 ** (i-1),closest_power_of_2,2 ** (i-1)) - 1
        end_of_first_half_of_second_chunk = end_of_first_chunk + 2 ** (i - 2)

        x_big[...,end_of_first_half_of_second_chunk] += x_big[...,end_of_first_chunk]
    
    raw_out = x_big[...,:n]

    return raw_out.transpose(dim, x.ndim-1) if dim != -1 else raw_out

(bliasson_cumsum(torch.tensor([1, 2, 3, 4, 5, 6, 7])) == torch.tensor([1, 3, 6, 10, 15, 21, 28])).all()

bliasson_cumsum_result_fp64 = bliasson_cumsum(input_data)
print(f"Max diff between fp64 and bliasson cumsum:",(fp64_cumsum_result - bliasson_cumsum_result_fp64).max())

bliasson_cumsum_result_fp32 = bliasson_cumsum(input_data_fp32)
print(f"Max diff between fp64 and bliasson cumsum on fp32:",(fp64_cumsum_result - bliasson_cumsum_result_fp32).max())




Max diff between fp64 and bliasson cumsum: tensor(1.8190e-12, device='cuda:0', dtype=torch.float64)
Max diff between fp64 and bliasson cumsum on fp32: tensor(0.0005, device='cuda:0', dtype=torch.float64)


In [65]:
# now let's test our bliasson cumsum on our dataset of inputs that cause instability.

# inputs_causing_instability = torch.load("../../inputs_causing_instability.pt")

for S_64_np, max_diff in inputs_causing_instability:
    S_64 = torch.from_numpy(S_64_np)
    FF_64 = torch.cumsum(S_64, dim=-2)
    bliasson_cumsum_result_fp64 = bliasson_cumsum(S_64,dim=1)
    print(f"Max diff during training: {max_diff} | Max diff with bliasson cumsum: {(FF_64 - bliasson_cumsum_result_fp64).max()}")

Max diff during training: 2.288818359375e-05 | Max diff with bliasson cumsum: 0.0
Max diff during training: 1.1444091796875e-05 | Max diff with bliasson cumsum: 0.0
Max diff during training: 1.52587890625e-05 | Max diff with bliasson cumsum: 0.0
Max diff during training: 1.52587890625e-05 | Max diff with bliasson cumsum: 0.0
Max diff during training: 1.52587890625e-05 | Max diff with bliasson cumsum: 0.0
Max diff during training: 1.537799835205078e-05 | Max diff with bliasson cumsum: 0.0
Max diff during training: 0.0 | Max diff with bliasson cumsum: 0.0
Max diff during training: 0.0 | Max diff with bliasson cumsum: 0.0
Max diff during training: 1.52587890625e-05 | Max diff with bliasson cumsum: 0.0
Max diff during training: 5.364418029785156e-07 | Max diff with bliasson cumsum: 0.0
Max diff during training: 0.0 | Max diff with bliasson cumsum: 0.0
Max diff during training: 1.52587890625e-05 | Max diff with bliasson cumsum: 0.0
Max diff during training: 1.430511474609375e-05 | Max diff 

## Protect-and-attack algorithm