In [9]:
import torch

# ok so we're implementing a custom cumsum function
# we want it to be super numerically stable

# set seed to 42
torch.manual_seed(42)

input_data = torch.randn(1000000, dtype=torch.float64, device="cuda")
input_data_fp32 = input_data.float()

# let's start by checking torch fp64 cumsum. that'll be our ground truth.
fp64_cumsum = torch.cumsum(input_data, dim=-1)
# let's see how much it diverges from doing a non-recursive for loop. my past experiments suggested this is kinda unstable.
# fp64_slow_loop_cumsum = torch.zeros_like(input_data)
# for i in range(input_data.shape[-1]):
#     fp64_slow_loop_cumsum[..., i] = torch.sum(input_data[..., :i+1], dim=-1)
# print(f"Max difference between fp64 and fp64 loop cumsum:",(fp64_cumsum - fp64_slow_loop_cumsum).max())
# then let's compare to a loop that uses an accumulator.
# fp64_fast_loop_cumsum = torch.zeros_like(input_data)
# fp64_fast_loop_cumsum[..., 0] = input_data[..., 0]
# for i in range(1, input_data.shape[-1]):
#     fp64_fast_loop_cumsum[..., i] = fp64_fast_loop_cumsum[..., i-1] + input_data[..., i]
print(f"Max difference between fp64 and fp64 fast loop cumsum:",(fp64_cumsum - fp64_fast_loop_cumsum).max())

# then we'll compare to torch fp32 cumsum. past experiments say this is super stable and close to fp64 torch.cumsum.
fp32_cumsum = torch.cumsum(input_data_fp32, dim=-1)
print(f"Max difference between fp64 and fp32 cumsum:",(fp64_cumsum - fp32_cumsum).max())

# then we'll test our own stuff

# first a triton kernel that uses fp64. this should be as good as torch fp64 with loop.

import torch
import triton
import triton.language as tl

@triton.jit
def cumsum_triton(x_ptr, y_ptr, n_elements):
    """
    A simple sequential cumulative sum kernel in fp64.
    
    This kernel reads the input array from x_ptr, initializes the accumulator
    with the first element (to ensure the accumulator is fp64) and writes the
    cumulative sum into y_ptr. It then loops over the remaining elements,
    adds each value to the accumulator, and stores the cumulative sum.
    
    Parameters:
      x_ptr: pointer to the beginning of the input array.
      y_ptr: pointer to the beginning of the output array.
      n_elements: total number of elements to process (expected to be a Python int).
    """
    # Initialize the accumulator with the first element, ensuring the type is fp64.
    acc = tl.load(x_ptr)
    tl.store(y_ptr, acc)
    
    # Process the rest of the elements.
    i = 1
    while i < n_elements:
        val = tl.load(x_ptr + i)
        acc = acc + val
        tl.store(y_ptr + i, acc)
        i = i + 1

def fp64_triton_cumsum(x):
    """
    Computes the cumulative sum of the input tensor x (in fp64) using a Triton kernel.
    
    The function allocates an output tensor of the same shape as x and then launches
    the Triton kernel `cumsum_triton` with a grid size of 1 (i.e. a single program instance)
    to perform the sequential accumulation.
    
    Parameters:
      x: a 1-dimensional torch tensor of dtype torch.float64 on the CUDA device.
    
    Returns:
      A torch tensor with the cumulative sum computed elementwise.
    """
    n_elements = x.numel()       # total number of elements in x
    y = torch.empty_like(x)      # allocate output tensor
    # Launch a single kernel instance; note: this kernel assumes n_elements >= 1.
    grid = (1,)
    cumsum_triton[grid](x, y, n_elements)
    return y

fp64_triton_cumsum_result = fp64_triton_cumsum(input_data)
print(f"Max difference between fp64 and fp64 triton cumsum:",(fp64_cumsum - fp64_triton_cumsum_result).max())

fp64_triton_cumsum_result_fp32 = fp64_triton_cumsum_result.float()
print(f"Max difference between fp64 and fp64 triton cumsum (converted to fp32):",(fp64_cumsum - fp64_triton_cumsum_result_fp32).max())


# then a triton kernel to do parallel scan, I think. we'll do it in fp64 at first. it should hopefully match torch fp64 cumsum super closely.

# then we'll do the same thing in fp32.

fp32_triton_cumsum_result = fp64_triton_cumsum(input_data_fp32)
print("Max diff between fp64 and fp32 triton cumsum:",(fp64_cumsum - fp32_triton_cumsum_result).max())

# then we'll implement our custom protect_and_attack function in fp64.

# it should be super close to torch fp64 cumsum too - about as close as our fp64 parallel scan triton kernel.

# and as we finalize each of the functions / etc, we'll move it into the stability.py file (for brevity).

Max difference between fp64 and fp64 fast loop cumsum: tensor(267.5462, device='cuda:0', dtype=torch.float64)
Max difference between fp64 and fp32 cumsum: tensor(0.0007, device='cuda:0', dtype=torch.float64)
Max difference between fp64 and fp64 triton cumsum: tensor(2.3874e-11, device='cuda:0', dtype=torch.float64)
Max difference between fp64 and fp64 triton cumsum (converted to fp32): tensor(6.1035e-05, device='cuda:0', dtype=torch.float64)
Max diff between fp64 and fp32 triton cumsum: tensor(0.0216, device='cuda:0', dtype=torch.float64)
