# Test some fusion details for a fused linear + cross-entropy ("xent") loss, via an online logsumexp and a kind of forward AD

In [None]:
import torch
from torch.utils.checkpoint import checkpoint
import torch.nn.functional as F

import triton
import triton.language as tl

import logging


In [None]:
# Display whether a CUDA device is visible. The tensors below are created without a device= argument.
torch.cuda.is_available()

In [None]:
from torch.profiler import profile, record_function, ProfilerActivity
# Profiler activity list (CPU only). It is not referenced by any of the cells shown here.
activities = [ProfilerActivity.CPU]

# Uncomment any of these to get torch.compile logs (dynamo, graph, fusion, generated code):
# torch._logging.set_logs(dynamo=logging.DEBUG)
# torch._logging.set_logs(graph=True)
# torch._logging.set_logs(fusion=True)
# torch._logging.set_logs(output_code=True)

In [None]:
def cosim(x, y):
    """Return the cosine similarity of two tensors (flattened) as a float32 scalar.

    Both inputs are flattened and promoted to float64 before the dot product and
    the norms, so the comparison itself adds essentially no rounding error.
    ``reshape`` is used instead of ``view`` so non-contiguous inputs (for example
    transposed gradients) are accepted.
    """
    x_flat = x.reshape(-1).double()
    y_flat = y.reshape(-1).double()
    return ((x_flat * y_flat).sum() / x_flat.norm() / y_flat.norm()).float()

#

### Baseline sanity check

In [None]:
# Problem sizes. x is (B, S, H), y is (B, S) with ids in [0, V), A is (V, H).
B, S , H, V = 4, 512, 32, 4096
# Number of rows once batch and sequence dimensions are flattened.
N = B * S

x = torch.randn(B, S, H, requires_grad=True) # B S H
y = torch.randint(0, V, (B, S)) # target ids in [0, V), shape B S
A = torch.randn(V, H, requires_grad=True)

def baseline(x, y, A):
    """Reference loss: float64 linear projection followed by mean cross-entropy.

    ``F.linear`` computes ``x @ A.T``, so A has one row per class. The number of
    classes is read from ``A.shape[0]`` rather than from a notebook global.
    ``y`` must hold one target id per row of the flattened logits.
    """
    logits = F.linear(x.double(), A.double())
    return F.cross_entropy(logits.reshape(-1, A.shape[0]), y.reshape(-1))

# Run the float64 baseline once. Keep float32 copies of its loss and gradients;
# later cells compare against reference_loss / reference_A_grad / reference_x_grad.
loss = baseline(x, y, A)
loss.backward()

reference_A_grad = A.grad.float().clone()
reference_x_grad = x.grad.float().clone()
reference_loss = loss.detach().float().clone()

loss.item(), A.grad.mean(), x.grad.mean()

In [None]:
# torch.compile version: each row chunk goes through activation checkpointing,
# so its logits are recomputed in backward instead of being kept alive.
chunk_size = 64

@torch.compile(mode="max-autotune", fullgraph=True)
def checkpointed_chunked(x, y, A):
    """Mean cross-entropy of ``x @ A.T``, computed chunk-by-chunk under checkpointing.

    Rows of the flattened ``x`` are processed in chunks of ``chunk_size``. Each
    chunk contributes the *sum* of its per-row losses, and the total is divided
    once by the number of rows. This keeps the result equal to the plain mean
    even when the last chunk is shorter; averaging per-chunk means would not.
    """
    hidden = x.shape[-1]
    x_rows = x.reshape(-1, hidden)
    y_rows = y.reshape(-1)
    num_rows = y_rows.numel()

    def chunk_loss(weight, x_chunk, y_chunk):
        # Summed, not averaged: rows are weighted equally across chunks.
        return F.cross_entropy(F.linear(x_chunk, weight), y_chunk, reduction="sum")

    loss = 0.
    for x_chunk, y_chunk in zip(x_rows.split(chunk_size), y_rows.split(chunk_size)):
        loss += checkpoint(chunk_loss, A, x_chunk, y_chunk, use_reentrant=False)
    return loss / num_rows

# The baseline cell already filled x.grad / A.grad, so clear them first. Then
# backprop through the compiled, checkpointed loss, so the gradients shown
# belong to this implementation and are not stale baseline values.
x.grad = None
A.grad = None
loss = checkpointed_chunked(x, y, A)
loss.backward()
loss.item(), A.grad.mean(), x.grad.mean(), torch.dist(reference_A_grad, A.grad), torch.dist(reference_x_grad, x.grad)

In [None]:
def manual_implementation_vectorized(x, y, A):
    """Mean cross-entropy of ``x @ A.T`` against ``y``, written out by hand.

    Per row n: loss_n = logsumexp(z_n) - z_n[y_n]. The logsumexp is stabilised
    by subtracting the row max c_n. The target logit is the dot product
    A[y_n] . x_n rather than a gather from z. The hidden size is read from
    ``A.shape[1]`` instead of a notebook global.
    """
    x_rows = x.reshape(-1, A.shape[1])
    y_rows = y.reshape(-1)
    logits = x_rows @ A.T
    row_max = logits.max(dim=-1).values
    target_logit = (A[y_rows] * x_rows).sum(dim=-1)
    lse = row_max + (logits - row_max[:, None]).exp().sum(dim=-1).log()
    return (lse - target_logit).mean()

# Forward-only check of the hand-written vectorized loss (compare with reference_loss).
loss = manual_implementation_vectorized(x, y, A)
loss.item()

In [None]:
def manual_implementation_chunked(x, y, A, rows_per_chunk=None):
    """Hand-written mean cross-entropy computed over chunks of rows.

    Each chunk materialises only its own ``(rows, V)`` logit block. Per-row
    losses are summed across chunks and divided once by the total row count,
    so a shorter last chunk is weighted correctly.

    Args:
        rows_per_chunk: rows per chunk; defaults to the notebook-level ``chunk_size``.
    """
    if rows_per_chunk is None:
        rows_per_chunk = chunk_size
    x_rows = x.reshape(-1, A.shape[1])
    y_rows = y.reshape(-1)

    total = 0.
    for x_chunk, y_chunk in zip(x_rows.split(rows_per_chunk), y_rows.split(rows_per_chunk)):
        logits = x_chunk @ A.T
        row_max = logits.max(dim=-1).values
        target_logit = (A[y_chunk] * x_chunk).sum(dim=-1)
        lse = row_max + (logits - row_max[:, None]).exp().sum(dim=-1).log()
        total = total + (lse - target_logit).sum()
    return total / y_rows.numel()

# Forward-only check of the chunked hand-written loss (compare with reference_loss).
loss = manual_implementation_chunked(x, y, A)
loss.item()

In [None]:
@torch.no_grad
def manual_implementation_chunked_online_lse(x, y, A):
    loss = 0.
    x_chunks = x.view(-1, H).split(chunk_size)
    y_chunks = y.view(-1).split(chunk_size)
    num_chunks = len(y_chunks)
    m_global = torch.zeros(N)
    s_global = torch.zeros(N)

    for idx, (x_chunk, y_chunk) in enumerate(zip(x_chunks, y_chunks)):
        m = float("-inf") * torch.ones(chunk_size)
        s = torch.zeros(chunk_size)

        for v in range(V):
            m_prev = m.clone()
            z_j = (x_chunk * A[v]).sum(dim=-1)
            m = torch.maximum(m_prev, z_j)
            s = s * (m_prev - m).exp() + (z_j - m).exp()

        loss += (-(A[y_chunk] * x_chunk).sum(dim=-1) + m + s.log()).mean()
        m_global[idx * chunk_size: (idx+1)*chunk_size] = m
        s_global[idx * chunk_size: (idx+1)*chunk_size] = s
    
    return loss / num_chunks, m_global, s_global

# Online-logsumexp forward. m and s are kept for the backward check in the next cell.
loss, m, s = manual_implementation_chunked_online_lse(x, y, A)
loss.item()

In [None]:
@torch.no_grad
def manual_implementation_chunked_online_lse_bwd(x, y, A, m_global, s_global):
    x_chunks = x.view(-1, H).split(chunk_size)
    y_chunks = y.view(-1).split(chunk_size)

    Agrad = torch.zeros_like(A)
    global_x_grad = torch.zeros_like(x.view(-1, H))

    for idx, (x_chunk, y_chunk) in enumerate(zip(x_chunks, y_chunks)):
        s = s_global[idx * chunk_size: (idx+1)*chunk_size]
        m = m_global[idx * chunk_size: (idx+1)*chunk_size]

        xgrad = -A[y_chunk]
        for line, y in enumerate(y_chunk):
            Agrad[y] += -x_chunk[line]
        for v in range(V):
            z_j = (x_chunk * A[v]).sum(dim=-1)
            xgrad += ((z_j - m).exp() / s)[:, None]  * A[v][None, :]
            Agrad[v] += (((z_j - m).exp() / s)[:, None] * x_chunk).sum(dim=0)

        global_x_grad[idx * chunk_size: (idx+1)*chunk_size] = xgrad / N
    
    return Agrad / N, global_x_grad.view_as(x)

# Compare the hand-derived gradients with the float64 autograd reference (L2 distance and cosine similarity).
Agrad, xgrad = manual_implementation_chunked_online_lse_bwd(x, y, A, m, s)
torch.dist(reference_x_grad, xgrad), cosim(reference_x_grad, xgrad), torch.dist(reference_A_grad, Agrad), cosim(reference_A_grad, Agrad)

In [None]:
def manual_implementation_vectorized_bwd(x, y, A):
    """Loss plus hand-derived gradients for A and x, with all logits materialised.

    With P = row-wise softmax(x A^T) and N rows:

        dL/dx = (P @ A - A[y]) / N
        dL/dA = (P^T @ x - onehot(y)^T @ x) / N

    Returns:
        (loss, dL/dA shaped like A, dL/dx shaped like x).
    """
    x_rows = x.reshape(-1, A.shape[1])
    y_rows = y.reshape(-1)
    num_rows = y_rows.numel()

    logits = x_rows @ A.T
    row_max = logits.max(dim=-1).values
    unnorm = (logits - row_max[:, None]).exp()
    denom = unnorm.sum(dim=-1, keepdim=True)
    target_logit = (A[y_rows] * x_rows).sum(dim=-1)
    loss = (-target_logit + row_max + denom.squeeze(-1).log()).mean()

    # backward
    probs = unnorm / denom
    xgrad = (-A[y_rows] + probs @ A) / num_rows
    Agrad = probs.T @ x_rows
    # index_add_ accumulates repeated target ids. Plain advanced-index
    # `Agrad[y] += ...` would not, which is why a Python loop was needed before.
    Agrad.index_add_(0, y_rows, -x_rows)

    return loss, Agrad / num_rows, xgrad.view_as(x)

# Gradients are compared with the stored float64 reference rather than by
# re-running autograd (the commented-out lines).
# x.grad.zero_()
# A.grad.zero_()
loss, Agrad, xgrad = manual_implementation_vectorized_bwd(x, y, A)
# loss.backward()
# # loss, torch.dist(x.grad, xgrad),
loss.item(), torch.dist(reference_x_grad, xgrad), torch.dist(reference_A_grad, Agrad)

In [None]:
@torch.no_grad
def manual_implementation_chunked_online_lse_bwd_single_loop(x, y, A, m_global, s_global):
    x_chunks = x.view(-1, H).split(chunk_size)
    y_chunks = y.view(-1).split(chunk_size)

    Agrad = torch.zeros_like(A)
    global_x_grad = torch.zeros_like(x.view(-1, H))

    for idx, (x_chunk, y_chunk) in enumerate(zip(x_chunks, y_chunks)):
        s = s_global[idx * chunk_size: (idx+1)*chunk_size]
        m = m_global[idx * chunk_size: (idx+1)*chunk_size]
        xgrad = torch.zeros_like(x_chunk)

        for v in range(V):
            z_j = (x_chunk * A[v]).sum(dim=-1)
            xgrad += ((z_j - m).exp() / s)[:, None]  * A[v][None, :]
            Agrad[v] += (((z_j - m).exp() / s)[:, None] * x_chunk).sum(dim=0)

            mask = y_chunk == v
            if mask.sum() > 0:
                Agrad[v] -= x_chunk[mask].sum(dim=0)
                xgrad[mask] -= A[v]

        global_x_grad[idx * chunk_size: (idx+1)*chunk_size] = xgrad / N
    
    return Agrad / N, global_x_grad.view_as(x)

# Same gradient comparison for the single-loop backward.
Agrad, xgrad = manual_implementation_chunked_online_lse_bwd_single_loop(x, y, A, m, s)
torch.dist(reference_x_grad, xgrad), cosim(reference_x_grad, xgrad), torch.dist(reference_A_grad, Agrad), cosim(reference_A_grad, Agrad)

In [None]:
@torch.no_grad
def manual_implementation_chunked_online_lse_single_loop(x, y, A):
    loss = 0.
    x_chunks = x.view(-1, H).split(chunk_size)
    y_chunks = y.view(-1).split(chunk_size)
    num_chunks = len(y_chunks)
    m_global = torch.zeros(N)
    s_global = torch.zeros(N)

    for idx, (x_chunk, y_chunk) in enumerate(zip(x_chunks, y_chunks)):
        m = float("-inf") * torch.ones(chunk_size)
        s = torch.zeros(chunk_size)

        for v in range(V):
            m_prev = m.clone()
            z_j = (x_chunk * A[v]).sum(dim=-1)
            m = torch.maximum(m_prev, z_j)
            s = s * (m_prev - m).exp() + (z_j - m).exp()
            mask = y_chunk == v
            if mask.sum() > 0:
                loss -= z_j[mask].sum() / N

        loss += (m + s.log()).mean() / num_chunks
        m_global[idx * chunk_size: (idx+1)*chunk_size] = m
        s_global[idx * chunk_size: (idx+1)*chunk_size] = s
    
    return loss, m_global, s_global

# Loss of the single-loop forward and its distance to the float64 reference loss.
loss, m, s = manual_implementation_chunked_online_lse_single_loop(x, y, A)
loss.item(), torch.dist(loss, reference_loss).item()

In [None]:
@torch.no_grad
def manual_implementation_chunked_online_lse_single_loop_bf16(x, y, A):
    x = x.to(dtype=torch.float16)
    A = A.to(dtype=torch.float16)

    loss = 0.
    x_chunks = x.view(-1, H).split(chunk_size)
    y_chunks = y.view(-1).split(chunk_size)
    num_chunks = len(y_chunks)
    m_global = torch.zeros(N, dtype=torch.float32)
    s_global = torch.zeros(N, dtype=torch.float32)

    for idx, (x_chunk, y_chunk) in enumerate(zip(x_chunks, y_chunks)):
        m = -10e5 * torch.ones(chunk_size, dtype=torch.float32)
        s = torch.zeros(chunk_size, dtype=torch.float32)

        for v in range(V):
            m_prev = m.clone()
            z_j = (x_chunk * A[v]).to(dtype=torch.float32).sum(dim=-1)
            m = torch.maximum(m_prev, z_j)
            s = s * (m_prev - m).exp() + (z_j - m).exp()
            mask = y_chunk == v
            if mask.sum() > 0:
                loss -= z_j[mask].sum() / N

        loss += (m + s.log()).mean() / num_chunks
        m_global[idx * chunk_size: (idx+1)*chunk_size] = m
        s_global[idx * chunk_size: (idx+1)*chunk_size] = s
    
    return loss, m_global, s_global

# Low-precision forward: loss, distance to the reference, and that distance in scientific notation.
loss, m, s = manual_implementation_chunked_online_lse_single_loop_bf16(x, y, A)
loss.item(), torch.dist(loss, reference_loss).item(), f"{torch.dist(loss, reference_loss).item():2.4e}"

In [None]:
@torch.no_grad
def manual_implementation_chunked_online_lse_bwd_single_loop_bf16(x, y, A, m_global, s_global):
    x = x.to(dtype=torch.float16) # float16 is much more accurate for "sane" values of x and A
    A = A.to(dtype=torch.float16)

    x_chunks = x.view(-1, H).split(chunk_size)
    y_chunks = y.view(-1).split(chunk_size)

    Agrad = torch.zeros_like(A, dtype=torch.float32)
    global_x_grad = torch.zeros_like(x.view(-1, H), dtype=torch.float32)

    for idx, (x_chunk, y_chunk) in enumerate(zip(x_chunks, y_chunks)):
        s = s_global[idx * chunk_size: (idx+1)*chunk_size]
        m = m_global[idx * chunk_size: (idx+1)*chunk_size]
        xgrad = torch.zeros_like(x_chunk)

        for v in range(V):
            z_j = (x_chunk * A[v]).to(dtype=torch.float32).sum(dim=-1)
            xgrad += ((z_j - m).exp() / s)[:, None]  * A[v][None, :].to(dtype=torch.float32)
            Agrad[v] += (((z_j - m).exp() / s)[:, None] * x_chunk).sum(dim=0)

            mask = y_chunk == v
            if mask.sum() > 0:
                Agrad[v] -= x_chunk[mask].to(dtype=torch.float32).sum(dim=0)
                xgrad[mask] -= A[v].to(dtype=torch.float32)

        global_x_grad[idx * chunk_size: (idx+1)*chunk_size] = xgrad / N
    
    return Agrad / N, global_x_grad.view_as(x)

# Gradient comparison for the low-precision single-loop backward.
Agrad, xgrad = manual_implementation_chunked_online_lse_bwd_single_loop_bf16(x, y, A, m, s)
torch.dist(reference_x_grad, xgrad), cosim(reference_x_grad, xgrad), torch.dist(reference_A_grad, Agrad), cosim(reference_A_grad, Agrad)

In [None]:
# @torch.no_grad
# def manual_implementation_chunked_online_lse_single_loop_fwd_bwd_bf16(x, y, A):
#     x = x.to(dtype=torch.float16)
#     A = A.to(dtype=torch.float16)

#     loss = 0.
#     x_chunks = x.view(-1, H).split(chunk_size)
#     y_chunks = y.view(-1).split(chunk_size)
#     num_chunks = len(y_chunks)

#     Agrad = torch.zeros_like(A, dtype=torch.float32)
#     global_x_grad = torch.zeros_like(x.view(-1, H), dtype=torch.float32)

#     for idx, (x_chunk, y_chunk) in enumerate(zip(x_chunks, y_chunks)):
#         m = -10e5 * torch.ones(chunk_size, dtype=torch.float32)
#         s = torch.zeros(chunk_size, dtype=torch.float32)
#         xgrad = torch.zeros_like(x_chunk)
#         xgrad2 = torch.zeros_like(x_chunk)

#         for v in range(V):
#             m_prev = m.clone()
#             z_j = (x_chunk * A[v]).to(dtype=torch.float32).sum(dim=-1)
#             m = torch.maximum(m_prev, z_j)
#             s = s * (m_prev - m).exp() + (z_j - m).exp()
#             mask = y_chunk == v
#             if mask.sum() > 0:
#                 loss -= z_j[mask].sum() / N
#                 # bwd 1
#                 Agrad[v] -= x_chunk[mask].to(dtype=torch.float32).sum(dim=0)
#                 xgrad[mask] -= A[v].to(dtype=torch.float32)

#             # bwd 2
#             xgrad2 += ((z_j - m).exp())[:, None]  * A[v][None, :].to(dtype=torch.float32) # I think the -m_j is the problem here
#             Agrad[v] += (((z_j - m).exp())[:, None] * x_chunk).sum(dim=0)  # m and s are online in this version and still changing

#         global_x_grad[idx * chunk_size: (idx+1)*chunk_size] = xgrad / N + xgrad2 / s / N
#         Agrad /= s
#         loss += (m + s.log()).mean() / num_chunks
    
#     return loss, Agrad, global_x_grad.view_as(x)

# loss, Agrad, xgrad = manual_implementation_chunked_online_lse_single_loop_fwd_bwd_bf16(x, y, A)
# loss.item(), f"{torch.dist(loss, reference_loss).item():2.4e}", torch.dist(reference_x_grad, xgrad), torch.dist(reference_A_grad, Agrad)

In [None]:
# Vocabulary tile size: the number of rows of A processed together by the blockV variants.
V_chunk_size = 16

In [None]:
# Vocabulary tile size (rows of A per tile) for the blockV forward below.
V_chunk_size = 16

@torch.no_grad
def manual_implementation_chunked_online_lse_single_loop_bf16_blockV(x, y, A):
    x = x.to(dtype=torch.float16)
    A = A.to(dtype=torch.float16)

    loss = 0.
    x_chunks = x.view(-1, H).split(chunk_size)
    y_chunks = y.view(-1).split(chunk_size)
    num_chunks = len(y_chunks)
    m_global = torch.zeros(N, dtype=torch.float32)
    s_global = torch.zeros(N, dtype=torch.float32)

    A_chunks = A.split(V_chunk_size)
    v_offsets = torch.arange(V)

    for idx, (x_chunk, y_chunk) in enumerate(zip(x_chunks, y_chunks)):
        m = -10e5 * torch.ones(chunk_size, dtype=torch.float32)
        s = torch.zeros(chunk_size, dtype=torch.float32)

        for v_idx, A_v in enumerate(A_chunks):
            m_prev = m.clone()
            z_j_to_k = (x_chunk @ A_v.T).to(dtype=torch.float32) 
            m = torch.maximum(m_prev, z_j_to_k.max(dim=-1)[0])
            s = s * (m_prev - m).exp() + (z_j_to_k - m[:, None]).exp().sum(dim=-1)
            v_range = v_offsets[v_idx * V_chunk_size : (v_idx+1) * V_chunk_size]
            mask = y_chunk[:, None] == v_range[None, :]
            if mask.sum() > 0:
                loss -= z_j_to_k[mask].sum() / N

        loss += (m + s.log()).mean() / num_chunks
        m_global[idx * chunk_size: (idx+1)*chunk_size] = m
        s_global[idx * chunk_size: (idx+1)*chunk_size] = s
    
    return loss, m_global, s_global

# Vocabulary-tiled forward; m and s feed the tiled backward in the next cell.
loss, m, s = manual_implementation_chunked_online_lse_single_loop_bf16_blockV(x, y, A)
loss.item(), torch.dist(loss, reference_loss).item()# , f"{torch.dist(loss, reference_loss).item():2.4e}"

In [None]:
@torch.no_grad
def manual_implementation_chunked_online_lse_bwd_single_loop_bf16_blockV(x, y, A, m_global, s_global):
    x = x.to(dtype=torch.float16) # float16 is much more accurate for "sane" values of x and A
    A = A.to(dtype=torch.float16)

    x_chunks = x.view(-1, H).split(chunk_size)
    y_chunks = y.view(-1).split(chunk_size)

    A_chunks = A.split(V_chunk_size)
    v_offsets = torch.arange(V)

    Agrad = torch.zeros_like(A, dtype=torch.float32)
    global_x_grad = torch.zeros_like(x.view(-1, H), dtype=torch.float32)

    for idx, (x_chunk, y_chunk) in enumerate(zip(x_chunks, y_chunks)):
        s = s_global[idx * chunk_size: (idx+1)*chunk_size]
        m = m_global[idx * chunk_size: (idx+1)*chunk_size]
        xgrad = torch.zeros_like(x_chunk)

        for v_idx, A_v in enumerate(A_chunks):
            v_range = v_offsets[v_idx * V_chunk_size : (v_idx+1) * V_chunk_size]

            z_j_to_k = (x_chunk @ A_v.T).to(dtype=torch.float32) 
            xgrad += (((z_j_to_k - m[:, None]).exp() / s[:, None]).to(dtype=torch.float16) @  A_v).to(dtype=torch.float32)
            Agrad[v_range] += (((z_j_to_k - m[:, None]).exp() / s[:, None]).T.to(dtype=torch.float16) @ x_chunk).to(dtype=torch.float32)

            mask = y_chunk[:, None,] == v_range[None, :]
            if mask.sum() > 0:
                Agrad[v_range] -= torch.where(mask[:,:,None], x_chunk[:, None, :], 0.0).sum(dim=0) # reduction over N
                xgrad -= torch.where(mask[:, :, None], A_v, 0.0).sum(dim=1)

        global_x_grad[idx * chunk_size: (idx+1)*chunk_size] = xgrad / N
    
    return Agrad / N, global_x_grad.view_as(x)

# Gradient comparison for the vocabulary-tiled float16 backward.
Agrad, xgrad = manual_implementation_chunked_online_lse_bwd_single_loop_bf16_blockV(x, y, A, m, s)
torch.dist(reference_x_grad, xgrad), cosim(reference_x_grad, xgrad), torch.dist(reference_A_grad, Agrad), cosim(reference_A_grad, Agrad)

In [None]:
# Vocabulary tile size (rows of A per tile) for the per-chunk-losses blockV forward below.
V_chunk_size = 16

@torch.no_grad
def manual_implementation_chunked_online_lse_single_loop_bf16_blockV(x, y, A):
    x = x.to(dtype=torch.float16)
    A = A.to(dtype=torch.float16)


    x_chunks = x.view(-1, H).split(chunk_size)
    y_chunks = y.view(-1).split(chunk_size)
    num_chunks = len(y_chunks)
    m_global = torch.zeros(N, dtype=torch.float32)
    s_global = torch.zeros(N, dtype=torch.float32)
    losses = torch.zeros(num_chunks, dtype=torch.float32)

    A_chunks = A.split(V_chunk_size)
    v_offsets = torch.arange(V)

    for idx, (x_chunk, y_chunk) in enumerate(zip(x_chunks, y_chunks)):
        m = -10e5 * torch.ones(chunk_size, dtype=torch.float32)
        s = torch.zeros(chunk_size, dtype=torch.float32)

        for v_idx, A_v in enumerate(A_chunks):
            m_prev = m.clone()
            z_j_to_k = (x_chunk @ A_v.T).to(dtype=torch.float32) 
            m = torch.maximum(m_prev, z_j_to_k.max(dim=-1)[0])
            s = s * (m_prev - m).exp() + (z_j_to_k - m[:, None]).exp().sum(dim=-1)
            v_range = v_offsets[v_idx * V_chunk_size : (v_idx+1) * V_chunk_size]
            mask = y_chunk[:, None] == v_range[None, :]
            if mask.sum() > 0:
                losses[idx] -= z_j_to_k[mask].sum() / N

        losses[idx] += (m + s.log()).mean() / num_chunks
        m_global[idx * chunk_size: (idx+1)*chunk_size] = m
        s_global[idx * chunk_size: (idx+1)*chunk_size] = s
    
    return losses.sum(), m_global, s_global

# Per-chunk-losses variant of the vocabulary-tiled forward.
loss, m, s = manual_implementation_chunked_online_lse_single_loop_bf16_blockV(x, y, A)
loss.item(), torch.dist(loss, reference_loss).item() # , f"{torch.dist(loss, reference_loss).item():2.4e}"

In [None]:
@torch.no_grad
def manual_implementation_chunked_online_lse_single_loop_bf16_blockV_exp2(x, y, A):
    x = x.to(dtype=torch.float32)
    A = A.to(dtype=torch.float32)


    x_chunks = x.view(-1, H).split(chunk_size)
    y_chunks = y.view(-1).split(chunk_size)
    num_chunks = len(y_chunks)
    m_global = torch.zeros(N, dtype=torch.float32)
    s_global = torch.zeros(N, dtype=torch.float32)
    losses = torch.zeros(num_chunks, dtype=torch.float32)

    A_chunks = A.split(V_chunk_size)
    log2_const = 1.44269504
    

    for idx, (x_chunk, y_chunk) in enumerate(zip(x_chunks, y_chunks)):
        m = -10e5 * torch.ones(chunk_size, dtype=torch.float32)
        s = torch.zeros(chunk_size, dtype=torch.float32)
        v_offsets = torch.arange(0, V_chunk_size)

        for v_idx, A_v in enumerate(A_chunks):
            m_prev = m.clone()
            z_j_to_k = (x_chunk @ A_v.T).to(dtype=torch.float32)
            m = torch.maximum(m_prev, z_j_to_k.max(dim=-1)[0].mul(log2_const))
            s = s * (m_prev - m).exp2() + (z_j_to_k.mul(log2_const) -  m[:, None]).exp2().sum(dim=-1)

            mask = y_chunk[:, None] == v_offsets[None, :]
            losses[idx] -= torch.where(mask, z_j_to_k, 0.0).sum() / N

            v_offsets += V_chunk_size


        losses[idx] += (m + s.log2() ).mean() / num_chunks / log2_const
        m_global[idx * chunk_size: (idx+1)*chunk_size] = m
        s_global[idx * chunk_size: (idx+1)*chunk_size] = s
    
    return losses.sum(), m_global, s_global

# exp2 variant. Its m is max logit * log2(e), so it is not directly comparable
# with the natural-exp m; s has the same meaning in both variants.
loss, m_ref, s_ref = manual_implementation_chunked_online_lse_single_loop_bf16_blockV_exp2(x, y, A)
m = m_ref
s = s_ref
loss.item(), torch.dist(loss, reference_loss).item() # , f"{torch.dist(loss, reference_loss).item():2.4e}"

In [None]:
@torch.no_grad
def manual_implementation_chunked_online_lse_single_loop_bf16_blockV_2(x, y, A):
    x = x.to(dtype=torch.float32)
    A = A.to(dtype=torch.float32)


    x_chunks = x.view(-1, H).split(chunk_size)
    y_chunks = y.view(-1).split(chunk_size)
    num_chunks = len(y_chunks)
    m_global = torch.zeros(N, dtype=torch.float32)
    s_global = torch.zeros(N, dtype=torch.float32)
    losses = torch.zeros(num_chunks, dtype=torch.float32)

    A_chunks = A.split(V_chunk_size)
    
    for idx, (x_chunk, y_chunk) in enumerate(zip(x_chunks, y_chunks)):
        m = -10e5 * torch.ones(chunk_size, dtype=torch.float32)
        s = torch.zeros(chunk_size, dtype=torch.float32)
        v_offsets = torch.arange(0, V_chunk_size)

        for v_idx, A_v in enumerate(A_chunks):
            m_prev = m.clone()
            z_j_to_k = (x_chunk @ A_v.T).to(dtype=torch.float32)
            m = torch.maximum(m_prev, z_j_to_k.max(dim=-1)[0])
            s = s * (m_prev - m).exp() + (z_j_to_k-  m[:, None]).exp().sum(dim=-1)

            mask = y_chunk[:, None] == v_offsets[None, :]
            losses[idx] -= torch.where(mask, z_j_to_k, 0.0).sum() / N

            v_offsets += V_chunk_size


        losses[idx] += (m + s.log()).mean() / num_chunks
        m_global[idx * chunk_size: (idx+1)*chunk_size] = m
        s_global[idx * chunk_size: (idx+1)*chunk_size] = s
    
    return losses.sum(), m_global, s_global

# Natural-exp tiled forward. Its (m, s) become the m_ref / s_ref used by later checks.
loss, m_ref, s_ref = manual_implementation_chunked_online_lse_single_loop_bf16_blockV_2(x, y, A)
m = m_ref
s = s_ref
loss.item(), torch.dist(loss, reference_loss).item() # , f"{torch.dist(loss, reference_loss).item():2.4e}"

In [None]:
# Check the online statistics against the full float64 logit matrix:
# m + log(s) should match logsumexp(z), and s should match sum(exp(z - m)).
z = F.linear(x.double(), A.double()).view(-1, V) # N x V
torch.dist(z.logsumexp(dim=1), m + s.log()), torch.dist(s, (z - m[:, None]).exp().sum(dim=1))

In [None]:
# H_chunk_size = 16

# @torch.no_grad
# def manual_implementation_chunked_online_lse_single_loop_bf16_blockV_2_tileH(x, y, A):
#     x = x.to(dtype=torch.float16)
#     A = A.to(dtype=torch.float16)


#     x_chunks = x.view(-1, H).split(chunk_size)
#     y_chunks = y.view(-1).split(chunk_size)
#     num_chunks = len(y_chunks)
#     m_global = torch.zeros(N,  H // H_chunk_size, dtype=torch.float32)
#     s_global = torch.zeros(N,  H // H_chunk_size, dtype=torch.float32)
#     losses = torch.zeros(num_chunks, H // H_chunk_size, dtype=torch.float32)

    
    

#     for idx, (x_chunk, y_chunk) in enumerate(zip(x_chunks, y_chunks)):
#         for h_idx in range(H // H_chunk_size):
#             x_chunk_h = x_chunk[:, h_idx*H_chunk_size : (h_idx+1) * H_chunk_size]
#             A_chunks = A[:, h_idx*H_chunk_size : (h_idx+1) * H_chunk_size].split(V_chunk_size)

#             m = -10e5 * torch.ones(chunk_size, dtype=torch.float32)
#             s = torch.zeros(chunk_size, dtype=torch.float32)
#             v_offsets = torch.arange(0, V_chunk_size)
            
#             for v_idx, A_v in enumerate(A_chunks):
#                 m_prev = m.clone()
#                 z_j_to_k = (x_chunk_h.matmul(A_v.T)).to(dtype=torch.float32)  
#                 m = torch.maximum(m_prev, z_j_to_k.max(dim=-1)[0])
#                 s = s * (m_prev - m).exp() + (z_j_to_k - m[:, None]).exp().sum(dim=-1)

#                 mask = y_chunk[:, None] == v_offsets[None, :]
#                 losses[idx] -= torch.where(mask, z_j_to_k, 0.0).sum() / N

#                 v_offsets += V_chunk_size


#             losses[idx, h_idx] += (m + s.log()).mean() / num_chunks
#             m_global[idx * chunk_size: (idx+1)*chunk_size, h_idx] = m
#             s_global[idx * chunk_size: (idx+1)*chunk_size, h_idx] = s
    
#     return losses.sum(), m_global.max(dim=1)[0], s_global.sum(dim=1).log()

# loss, m, s = manual_implementation_chunked_online_lse_single_loop_bf16_blockV_2_tileH(x, y, A)
# loss.item(), torch.dist(loss, reference_loss).item(), torch.dist(m, m_ref), torch.dist(s, s_ref)

In [None]:
# V_tile_size = 16

# @torch.no_grad
# def manual_implementation_chunked_online_lse_single_loop_bf16_blockV_2_tileV(x, y, A):
#     x = x.to(dtype=torch.float16)
#     A = A.to(dtype=torch.float16)


#     x_chunks = x.view(-1, H).split(chunk_size)
#     y_chunks = y.view(-1).split(chunk_size)
#     num_chunks = len(y_chunks)

#     num_V_tiles = V // V_chunk_size // V_tile_size
#     m_global = torch.zeros(N,  num_V_tiles, dtype=torch.float32)
#     s_global = torch.zeros(N,  num_V_tiles, dtype=torch.float32)
#     losses = torch.zeros(num_chunks, num_V_tiles, dtype=torch.float32)

#     A_chunks = A.split(V_chunk_size)
    

#     for idx, (x_chunk, y_chunk) in enumerate(zip(x_chunks, y_chunks)):
#         for h_idx in range(H // H_chunk_size):
#             for v_tile_idx in range(num_V_tiles):

#                 m = -10e5 * torch.ones(chunk_size, dtype=torch.float32)
#                 s = torch.zeros(chunk_size, dtype=torch.float32)
#                 v_offsets = torch.arange(0, V_chunk_size)

#                 A_chunks_in_tile = A_chunks[v_tile_idx * V_tile_size: (v_tile_idx+1) * V_tile_size]

#                 for v_idx, A_v in enumerate(A_chunks_in_tile):
#                     m_prev = m.clone()
#                     z_j_to_k = (x_chunk.matmul(A_v.T)).to(dtype=torch.float32)  
#                     m = torch.maximum(m_prev, z_j_to_k.max(dim=-1)[0])
#                     s = s * (m_prev - m).exp() + (z_j_to_k - m[:, None]).exp().sum(dim=-1)

#                     mask = y_chunk[:, None] == v_offsets[None, :]
#                     losses[idx, v_tile_idx] -= torch.where(mask, z_j_to_k, 0.0).sum() / N

#                     v_offsets += V_chunk_size


#                 # losses[idx, v_tile_idx] += (m + s.log()).mean() / num_chunks
#                 m_global[idx * chunk_size: (idx+1)*chunk_size, v_tile_idx] = m
#                 s_global[idx * chunk_size: (idx+1)*chunk_size, v_tile_idx] = s
    
#     # Collect 
#     m_global_reduced = m_global.max(dim=1)[0]
#     # s_global_reduced = s_global * (m_global - m_global_reduced).exp().sum(dim=-1) + (s_global.log() - m_global_reduced).exp().sum(dim=-1)
#     # s_global_reduced = (m_global + s_global).sum(dim=1).log()
#     s_global_reduced = s_global.log().logsumexp(dim=1)

#     # final_loss = (losses.sum() + (m_global + s_global).sum(dim=1).log().mean() / num_chunks)
#     final_loss = (losses.sum() + (m_global_reduced + s_global_reduced.log())).mean() / num_chunks

    
#     return final_loss, m_global_reduced, s_global_reduced

# loss, m, s = manual_implementation_chunked_online_lse_single_loop_bf16_blockV_2_tileV(x, y, A)
# loss.item(), torch.dist(loss, reference_loss).item(), torch.dist(m, m_ref), torch.dist(s, s_ref)

In [None]:
# Compare s with s_ref. The tiled cells that would rebind s are commented out,
# so as the notebook stands s is still s_ref here.
cosim(s, s_ref), s, s_ref, s.mean(), s_ref.mean()

In [None]:
# Weight shape, (V, H).
A.shape

In [None]:
# Tile sizes for the blockV/blockH backward: H_chunk_size columns of the hidden
# dimension and V_chunk_size rows of A per tile.
H_chunk_size = 16
V_chunk_size = 64

@torch.no_grad
def manual_implementation_chunked_online_lse_bwd_single_loop_bf16_blockV_blockH(x, y, A, m_global, s_global):
    compute_dtype = torch.float32
    x = x.to(dtype=compute_dtype) # float16 is much more accurate for "sane" values of x and A
    A = A.to(dtype=compute_dtype)

    x_chunks = x.view(-1, H).split(chunk_size)
    y_chunks = y.view(-1).split(chunk_size)

    A_chunks = A.split(V_chunk_size)
    v_offsets = torch.arange(V)

    Agrad = torch.zeros_like(A, dtype=torch.float32)
    global_x_grad = torch.zeros_like(x.view(-1, H), dtype=torch.float32)

    for idx, (x_chunk, y_chunk) in enumerate(zip(x_chunks, y_chunks)):
        s = s_global[idx * chunk_size: (idx+1)*chunk_size]
        m = m_global[idx * chunk_size: (idx+1)*chunk_size]
        xgrad = torch.zeros_like(x_chunk)
        Nc = x_chunk.shape[0]
        Vc = A_chunks[0].shape[0]

        for v_idx, A_v in enumerate(A_chunks): # can parallelize
            v_range = v_offsets[v_idx * V_chunk_size : (v_idx+1) * V_chunk_size]

            z_j_to_k = torch.zeros(Nc, Vc)
            for h_idx in range(H // H_chunk_size):
                x_chunk_h = x_chunk[:, h_idx*H_chunk_size : (h_idx+1) * H_chunk_size]
                A_chunk_h = A_v[:, h_idx*H_chunk_size : (h_idx+1) * H_chunk_size]

                z_j_to_k += (x_chunk_h @ A_chunk_h.T).to(dtype=torch.float32)
            
            softmax_z = ((z_j_to_k - m[:, None]).exp() / s[:, None]).to(dtype=compute_dtype)
            mask = y_chunk[:, None,] == v_range[None, :]

            for h_idx in range(H // H_chunk_size):
                x_chunk_h = x_chunk[:, h_idx*H_chunk_size : (h_idx+1) * H_chunk_size]
                A_chunk_h = A_v[:, h_idx*H_chunk_size : (h_idx+1) * H_chunk_size]


                xgrad_temp = (softmax_z @  A_chunk_h).to(dtype=torch.float32)
                xgrad_temp -= torch.where(mask[:, :, None], A_chunk_h[None, :, :], 0.0).sum(dim=1)
                xgrad[:, h_idx*H_chunk_size : (h_idx+1) * H_chunk_size] += xgrad_temp


                Agrad_temp = (softmax_z.T @ x_chunk_h).to(dtype=torch.float32)
                Agrad_temp -= torch.where(mask[:,:,None], x_chunk_h[:, None, :], 0.0).sum(dim=0)
                Agrad[v_range, h_idx*H_chunk_size : (h_idx+1) * H_chunk_size] += Agrad_temp
    
        global_x_grad[idx * chunk_size: (idx+1)*chunk_size] = xgrad / N
    
    return Agrad / N, global_x_grad.view_as(x)

# Gradient comparison for the vocab- and hidden-tiled backward.
Agrad, xgrad = manual_implementation_chunked_online_lse_bwd_single_loop_bf16_blockV_blockH(x, y, A, m, s)
torch.dist(reference_x_grad, xgrad), cosim(reference_x_grad, xgrad), torch.dist(reference_A_grad, Agrad), cosim(reference_A_grad, Agrad)

In [None]:
@torch.no_grad
def manual_implementation_chunked_online_lse_bwd_single_loop_bf16_blockV(x, y, A, m_global, s_global):
    compute_dtype = torch.float32
    x = x.to(dtype=compute_dtype) # float16 is much more accurate for "sane" values of x and A
    A = A.to(dtype=compute_dtype)

    x_chunks = x.view(-1, H).split(chunk_size)
    y_chunks = y.view(-1).split(chunk_size)

    A_chunks = A.split(V_chunk_size)
    v_offsets = torch.arange(V)

    Agrad = torch.zeros_like(A, dtype=torch.float32)
    global_x_grad = torch.zeros_like(x.view(-1, H), dtype=torch.float32)

    for idx, (x_chunk, y_chunk) in enumerate(zip(x_chunks, y_chunks)):
        s = s_global[idx * chunk_size: (idx+1)*chunk_size]
        m = m_global[idx * chunk_size: (idx+1)*chunk_size]
        xgrad = torch.zeros_like(x_chunk)

        for v_idx, A_v in enumerate(A_chunks):
            v_range = v_offsets[v_idx * V_chunk_size : (v_idx+1) * V_chunk_size]

            z_j_to_k = (x_chunk @ A_v.T).to(dtype=torch.float32) 
            xgrad += (((z_j_to_k - m[:, None]).exp() / s[:, None]).to(dtype=compute_dtype) @  A_v).to(dtype=torch.float32)
            Agrad[v_range] += (((z_j_to_k - m[:, None]).exp() / s[:, None]).T.to(dtype=compute_dtype) @ x_chunk).to(dtype=torch.float32)

            mask = y_chunk[:, None,] == v_range[None, :]
            if mask.sum() > 0:
                Agrad[v_range] -= torch.where(mask[:,:,None], x_chunk[:, None, :], 0.0).sum(dim=0) # reduction over N
                xgrad -= torch.where(mask[:, :, None], A_v, 0.0).sum(dim=1)

        global_x_grad[idx * chunk_size: (idx+1)*chunk_size] = xgrad / N
    
    return Agrad / N, global_x_grad.view_as(x)

# Gradient comparison for the float32 redefinition of the vocabulary-tiled backward.
Agrad, xgrad = manual_implementation_chunked_online_lse_bwd_single_loop_bf16_blockV(x, y, A, m, s)
torch.dist(reference_x_grad, xgrad), cosim(reference_x_grad, xgrad), torch.dist(reference_A_grad, Agrad), cosim(reference_A_grad, Agrad)

In [None]:
# Tile sizes for the transposed-weight (At) backward in this cell. Print the
# row/vocab/hidden tile sizes and the full problem sizes for reference.
H_chunk_size = 16
V_chunk_size = 32
print(chunk_size, V_chunk_size, H_chunk_size)
print(N, V, H)

@torch.no_grad
def manual_implementation_chunked_online_lse_bwd_single_loop_bf16_blockV_blockH_At(x, y, A, m_global, s_global):
    compute_dtype = torch.float32
    x = x.to(dtype=compute_dtype) # float16 is much more accurate for "sane" values of x and A
    At = A.T.to(dtype=compute_dtype)

    x_chunks = x.view(-1, H).split(chunk_size)
    y_chunks = y.view(-1).split(chunk_size)

    At_chunks = At.split(V_chunk_size, dim=1)
    v_offsets = torch.arange(V)

    Atgrad = torch.zeros_like(At, dtype=torch.float32)
    global_x_grad = torch.zeros_like(x.view(-1, H), dtype=torch.float32)

    for idx, (x_chunk, y_chunk) in enumerate(zip(x_chunks, y_chunks)):
        s = s_global[idx * chunk_size: (idx+1)*chunk_size]
        m = m_global[idx * chunk_size: (idx+1)*chunk_size]
        xgrad = torch.zeros_like(x_chunk)
        Nc = x_chunk.shape[0]
        Vc = At_chunks[0].shape[1]

        for v_idx, A_v in enumerate(At_chunks): # can parallelize
            v_range = v_offsets[v_idx * V_chunk_size : (v_idx+1) * V_chunk_size]

            z_j_to_k = torch.zeros(Nc, Vc)
            for h_idx in range(H // H_chunk_size):
                x_chunk_h = x_chunk[:, h_idx*H_chunk_size : (h_idx+1) * H_chunk_size]
                A_chunk_h = A_v[h_idx*H_chunk_size : (h_idx+1) * H_chunk_size, :]

                z_j_to_k += (x_chunk_h @ A_chunk_h).to(dtype=torch.float32)
            
            softmax_z = ((z_j_to_k - m[:, None]).exp() / s[:, None]).to(dtype=compute_dtype)
            mask = (y_chunk[:, None] == v_range[None, :])[:,:,None] # needs to be N_BLOCK_SIZE x V_BLOCK_SIZE x 1 ?

            for h_idx in range(H // H_chunk_size):
                x_chunk_h = x_chunk[:, h_idx*H_chunk_size : (h_idx+1) * H_chunk_size]
                A_chunk_h = A_v[h_idx*H_chunk_size : (h_idx+1) * H_chunk_size, :]


                xgrad_temp = (softmax_z @  A_chunk_h.T).to(dtype=torch.float32)
                xgrad_temp -= torch.where(mask, A_chunk_h.T[None, :, :], 0.0).sum(dim=1)
                xgrad[:, h_idx*H_chunk_size : (h_idx+1) * H_chunk_size] += xgrad_temp


                Agrad_temp = (softmax_z.T @ x_chunk_h).to(dtype=torch.float32)
                Agrad_temp -= torch.where(mask, x_chunk_h[:, None, :], 0.0).sum(dim=0)
                Atgrad[h_idx*H_chunk_size : (h_idx+1) * H_chunk_size, v_range] += Agrad_temp.T
    
        global_x_grad[idx * chunk_size: (idx+1)*chunk_size] = xgrad / N
    
    return (Atgrad / N).T, global_x_grad.view_as(x)

Agrad, xgrad = manual_implementation_chunked_online_lse_bwd_single_loop_bf16_blockV_blockH_At(x, y, A, m, s)
# Cell output: (L2 error of dx, cosine of dx, L2 error of dA, cosine of dA)
# measured against the float64 autograd reference gradients.
(
    torch.dist(reference_x_grad, xgrad),
    cosim(reference_x_grad, xgrad),
    torch.dist(reference_A_grad, Agrad),
    cosim(reference_A_grad, Agrad),
)

In [None]:
# Tile sizes along the hidden (H) and vocabulary (V) axes.
H_chunk_size, V_chunk_size = 16, 32
print(chunk_size, V_chunk_size, H_chunk_size)
print(N, V, H)
# m and s are per-row statistics with softmax(z) = exp(z - m) / s, so the
# per-row log-sum-exp is m + log(s) and softmax(z) = exp(z - lse).
lse_global = m + torch.log(s)

@torch.no_grad
def manual_implementation_chunked_online_lse_bwd_single_loop_bf16_blockV_blockH_At_double_recomp(x, y, A, lse_global):
    compute_dtype = torch.float32
    x = x.to(dtype=compute_dtype) # float16 is much more accurate for "sane" values of x and A
    At = A.T.to(dtype=compute_dtype)

    x_chunks = x.view(-1, H).split(chunk_size)
    y_chunks = y.view(-1).split(chunk_size)

    At_chunks = At.split(V_chunk_size, dim=1)
    v_offsets = torch.arange(V)

    Atgrad = torch.zeros_like(At, dtype=torch.float32)
    global_x_grad = torch.zeros_like(x.view(-1, H), dtype=torch.float32)


    for v_idx, A_v in enumerate(At_chunks): # can parallelize
        v_range = v_offsets[v_idx * V_chunk_size : (v_idx+1) * V_chunk_size]

        for idx, (x_chunk, y_chunk) in enumerate(zip(x_chunks, y_chunks)):
            lse = lse_global[idx * chunk_size: (idx+1)*chunk_size]
            xgrad = torch.zeros_like(x_chunk)
            Nc = x_chunk.shape[0]
            Vc = At_chunks[0].shape[1]

            z_j_to_k = torch.zeros(Nc, Vc)
            for h_idx in range(H // H_chunk_size):
                x_chunk_h = x_chunk[:, h_idx*H_chunk_size : (h_idx+1) * H_chunk_size]
                A_chunk_h = A_v[h_idx*H_chunk_size : (h_idx+1) * H_chunk_size, :]

                z_j_to_k += (x_chunk_h @ A_chunk_h).to(dtype=torch.float32)
            
            lse = lse_global[idx * chunk_size: (idx+1)*chunk_size]
            softmax_z = ((z_j_to_k - lse[:, None]).exp()).to(dtype=compute_dtype)
            mask = (y_chunk[:, None] == v_range[None, :])[:,:,None] # needs to be N_BLOCK_SIZE x V_BLOCK_SIZE x 1 ?

            for h_idx in range(H // H_chunk_size):
                x_chunk_h = x_chunk[:, h_idx*H_chunk_size : (h_idx+1) * H_chunk_size]
                A_chunk_h = A_v[h_idx*H_chunk_size : (h_idx+1) * H_chunk_size, :]

                Agrad_temp = (softmax_z.T @ x_chunk_h).to(dtype=torch.float32)
                Agrad_temp -= torch.where(mask, x_chunk_h[:, None, :], 0.0).sum(dim=0)
                Atgrad[h_idx*H_chunk_size : (h_idx+1) * H_chunk_size, v_range] += Agrad_temp.T
    

    for idx, (x_chunk, y_chunk) in enumerate(zip(x_chunks, y_chunks)):
        lse = lse_global[idx * chunk_size: (idx+1)*chunk_size]
        xgrad = torch.zeros_like(x_chunk)
        Nc = x_chunk.shape[0]
        Vc = At_chunks[0].shape[1]

        for v_idx, A_v in enumerate(At_chunks): # can parallelize
            v_range = v_offsets[v_idx * V_chunk_size : (v_idx+1) * V_chunk_size]

            z_j_to_k = torch.zeros(Nc, Vc)
            for h_idx in range(H // H_chunk_size):
                x_chunk_h = x_chunk[:, h_idx*H_chunk_size : (h_idx+1) * H_chunk_size]
                A_chunk_h = A_v[h_idx*H_chunk_size : (h_idx+1) * H_chunk_size, :]

                z_j_to_k += (x_chunk_h @ A_chunk_h).to(dtype=torch.float32)
            
            lse = lse_global[idx * chunk_size: (idx+1)*chunk_size]
            softmax_z = ((z_j_to_k - lse[:, None]).exp()).to(dtype=compute_dtype)
            mask = (y_chunk[:, None] == v_range[None, :])[:,:,None] # needs to be N_BLOCK_SIZE x V_BLOCK_SIZE x 1 ?

            for h_idx in range(H // H_chunk_size):
                x_chunk_h = x_chunk[:, h_idx*H_chunk_size : (h_idx+1) * H_chunk_size]
                A_chunk_h = A_v[h_idx*H_chunk_size : (h_idx+1) * H_chunk_size, :]


                xgrad_temp = (softmax_z @  A_chunk_h.T).to(dtype=torch.float32)
                xgrad_temp -= torch.where(mask, A_chunk_h.T[None, :, :], 0.0).sum(dim=1)
                xgrad[:, h_idx*H_chunk_size : (h_idx+1) * H_chunk_size] += xgrad_temp

    
        global_x_grad[idx * chunk_size: (idx+1)*chunk_size] = xgrad / N
    
    return (Atgrad / N).T, global_x_grad.view_as(x)

Agrad, xgrad = manual_implementation_chunked_online_lse_bwd_single_loop_bf16_blockV_blockH_At_double_recomp(x, y, A, lse_global)
# Cell output: (L2 error of dx, cosine of dx, L2 error of dA, cosine of dA)
# measured against the float64 autograd reference gradients.
(
    torch.dist(reference_x_grad, xgrad),
    cosim(reference_x_grad, xgrad),
    torch.dist(reference_A_grad, Agrad),
    cosim(reference_A_grad, Agrad),
)