In [3]:
import torch
import torch.nn as nn
from einops import rearrange, einsum


def softmax(x: torch.Tensor, dim: int) -> torch.Tensor:
    '''Numerically stable softmax along `dim`.

    The per-slice maximum is subtracted before exponentiating so large inputs
    do not overflow; the shift cancels out in the final ratio.

    Args:
        x: input tensor of any shape.
        dim: dimension along which the outputs sum to 1.
    Returns:
        Tensor with the same shape as `x`.

    Note:
        A slice whose entries are all -inf (e.g. a fully masked attention row)
        yields NaN, because -inf - (-inf) is NaN.
    '''
    x_max, _ = torch.max(x, dim=dim, keepdim=True)
    # Exponentiate once and reuse it for both numerator and denominator.
    exp_shifted = torch.exp(x - x_max)
    return exp_shifted / exp_shifted.sum(dim=dim, keepdim=True)

def Swish(x: torch.Tensor) -> torch.Tensor:
    '''Apply the Swish activation element-wise: Swish(x) = x * sigmoid(x).

    Args:
        x: input tensor of any shape.
    Returns:
        Tensor of the same shape as `x`.
    '''
    gate = torch.sigmoid(x)
    return torch.mul(x, gate)

def scaled_dot_product_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    mask: "torch.Tensor | None" = None,
) -> torch.Tensor:
    '''Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        Q: (..., queries, d_k) query tensor; leading dims are treated as batch dims.
        K: (..., keys, d_k) key tensor; `keys` need not equal `queries`.
        V: (..., keys, d_v) value tensor.
        mask: optional tensor broadcastable to (..., queries, keys).
            Nonzero / True marks positions that may be attended to,
            0 / False marks positions that are masked out. Defaults to None
            (no masking).
    Returns:
        Output tensor of shape (..., queries, d_v).

    Note:
        A query row whose keys are all masked yields NaN, because the softmax
        then sees only -inf values.
    '''
    d_k = Q.size(-1)
    # Scale first so the -inf fill below is applied to the final logits.
    scores = einsum(Q, K, '... queries d_k, ... keys d_k -> ... queries keys') / d_k ** 0.5
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    attn = softmax(scores, dim=-1)
    return einsum(attn, V, '... queries keys, ... keys d_v -> ... queries d_v')

In [4]:
class FlashAttentionFunc(torch.autograd.Function):
    '''Custom autograd function for the linear map y = x @ weight.T.

    NOTE(review): despite its name, this computes a plain linear projection,
    not attention.

    forward arguments:
        x: (..., in_features) input tensor.
        weight: assumed (out_features, in_features) weight matrix.
        grad_out: unused; accepted only so existing call sites that pass a
            third argument keep working. Its gradient is always None.
    '''

    @staticmethod
    def forward(ctx, x, weight, grad_out):
        # Only x and weight are needed to compute the gradients.
        ctx.save_for_backward(x, weight)
        return x @ weight.T

    @staticmethod
    def backward(ctx, grad_out):
        # `grad_out` is the upstream gradient dL/dy supplied by autograd; it must
        # not be shadowed by anything unpacked from the saved tensors.
        x, weight = ctx.saved_tensors
        grad_x = grad_weight = None
        if ctx.needs_input_grad[0]:
            # dL/dx = dL/dy @ W; matmul broadcasts over leading batch dims.
            grad_x = grad_out @ weight
        if ctx.needs_input_grad[1]:
            # dL/dW = sum over all leading positions of dL/dy^T x. Flattening the
            # leading dims makes this valid for inputs of any rank (identical
            # to grad_out.T @ x in the 2-D case).
            grad_weight = (
                grad_out.reshape(-1, grad_out.shape[-1]).T
                @ x.reshape(-1, x.shape[-1])
            )
        return grad_x, grad_weight, None



tensor(10.0001)
tensor(9.9531, dtype=torch.float16)
tensor(10.0021)
tensor(10.0021)
