In [1]:
from flash_attn.flash_attn_triton import flash_attn_func
from torch import nn
import torch

class Attention(nn.Module):
    """Scaled dot-product attention over already-projected ``q``, ``k``, ``v``.

    Computes ``softmax(q @ k^T / sqrt(d_k)) @ v`` where ``d_k`` is the size of
    the last dimension of ``k``. The module holds no learnable parameters.
    """

    def __init__(self):
        super().__init__()
        # Normalise over the key axis (last axis of the score matrix).
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None):
        """Apply attention.

        Args:
            q: Queries, shape ``(..., L_q, d_k)``.
            k: Keys, shape ``(..., L_k, d_k)``.
            v: Values, shape ``(..., L_k, d_v)``.
            mask: Optional boolean tensor broadcastable to the score matrix
                ``(..., L_q, L_k)``. Positions where it is ``True`` are set to
                ``-inf`` before the softmax, i.e. they are blocked. A query
                row with every position blocked produces NaN.

        Returns:
            Tensor of shape ``(..., L_q, d_v)``.
        """
        # matmul matches bmm for 3-D inputs and additionally accepts extra
        # leading (e.g. head) dimensions.
        scores = torch.matmul(q, k.transpose(-2, -1))
        if mask is not None:
            scores = scores.masked_fill(mask, float("-inf"))
        # A plain Python float avoids allocating a tensor on every call and
        # does not round the scale to float32 when the inputs are float64.
        scale = k.shape[-1] ** 0.5
        return torch.matmul(self.softmax(scores / scale), v)

class ScaledDotProductAttention(nn.Module):
    """Single-head self-attention: project ``x`` to q/k/v, then attend.

    The three projections are bias-free linear maps from ``dim_in`` to
    ``dim_qk`` (queries and keys) and ``dim_v`` (values).
    """

    def __init__(self, dim_in, dim_qk, dim_v, *args, **kwargs):
        super().__init__(*args, **kwargs)

        def projection(dim_out):
            # All projections share the input width and omit the bias term.
            return nn.Linear(dim_in, dim_out, bias=False)

        self.W_q = projection(dim_qk)
        self.W_k = projection(dim_qk)
        self.W_v = projection(dim_v)
        self.attention = Attention()

    def forward(self, x, mask=None):
        """Project ``x`` and return the attention output.

        ``mask`` is forwarded unchanged to the attention module. The shapes of
        the projected tensors are printed on every call.
        """
        queries = self.W_q(x)
        keys = self.W_k(x)
        values = self.W_v(x)

        print(f"q: {queries.shape}, k: {keys.shape}, v: {values.shape}")

        attended = self.attention(queries, keys, values, mask=mask)
        return attended

def causal_mask(T: int, device=None, dtype=torch.bool):
    """Build a ``T x T`` boolean mask that is True strictly above the diagonal.

    True marks a blocked position. ``dtype`` only controls the intermediate
    all-ones tensor; the returned mask is always converted to bool.
    """
    ones = torch.ones(T, T, dtype=dtype, device=device)
    strictly_upper = torch.triu(ones, diagonal=1)
    return strictly_upper.bool()

# Smoke test: single-head attention on a random batch with a causal mask.
att = ScaledDotProductAttention(
    dim_in=8,
    dim_qk=6,
    dim_v=4,
)

# Random input of shape (4, 5, 8); the last axis matches dim_in.
sample_x = torch.randn(4, 5, 8)
# The mask is square with side sample_x.shape[-2] (here 5).
res = att(sample_x, mask=causal_mask(sample_x.shape[-2]))
print(f"res: {res}, shape: {res.shape}")

q: torch.Size([4, 5, 6]), k: torch.Size([4, 5, 6]), v: torch.Size([4, 5, 4])
res: tensor([[[ 0.4613,  0.2435, -1.5595,  0.8940],
         [ 0.3192,  0.1920, -0.6335,  0.6146],
         [-0.1549, -0.0481, -1.1615,  0.5916],
         [-0.1328,  0.0587, -0.6640,  0.1650],
         [-0.2915,  0.3614, -0.7433,  0.1308]],

        [[ 0.3559, -0.0821, -1.0970, -0.2304],
         [-0.0574,  0.0532, -0.8686, -0.2913],
         [-0.1693,  0.1506, -0.8161, -0.2438],
         [-0.3403,  0.0036, -0.4348, -0.2176],
         [-0.3482,  0.0253, -0.1736, -0.1784]],

        [[-0.6685, -0.0771, -0.5147, -0.3106],
         [-0.5942,  0.0873, -0.0999, -0.1095],
         [ 0.8770,  0.2911,  0.9244,  0.0208],
         [-0.4599,  0.2917,  0.1165, -0.2777],
         [-0.2725,  0.2296,  0.0399, -0.2364]],

        [[ 0.6283,  0.6459,  2.0024, -0.0666],
         [ 0.4360,  0.1287,  1.1921,  0.3967],
         [-0.7163,  0.2043,  0.3544,  0.2952],
         [-0.3025,  0.3566,  0.5062,  0.1068],
         [-0.1466, 

In [33]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a separate q/k/v projection per head.

    Each of the ``num_heads`` heads owns its own bias-free linear projections.
    The head outputs are concatenated along the last axis and mapped to
    ``dim_out`` by a final bias-free linear layer.
    """

    def __init__(self, dim_in, dim_qk, dim_v, dim_out, num_heads, *args, **kwargs):
        super().__init__(*args, **kwargs)

        def per_head_projections(dim_proj):
            # One independent projection per head.
            return nn.ModuleList(
                [nn.Linear(dim_in, dim_proj, bias=False) for _ in range(num_heads)]
            )

        self.W_q = per_head_projections(dim_qk)
        self.W_k = per_head_projections(dim_qk)
        self.W_v = per_head_projections(dim_v)

        # Concatenated heads have width dim_v * num_heads.
        self.W_out = nn.Linear(dim_v * num_heads, dim_out, bias=False)

        self.attention = Attention()

    def forward(self, x, mask=None):
        """Run every head on ``x`` (sharing ``mask``) and merge the results."""
        head_outputs = [
            self.attention(proj_q(x), proj_k(x), proj_v(x), mask=mask)
            for proj_q, proj_k, proj_v in zip(self.W_q, self.W_k, self.W_v)
        ]
        merged = torch.cat(head_outputs, dim=-1)
        return self.W_out(merged)

# Smoke test: 4-head attention mapping 8 input features to 10 output features.
mha = MultiHeadAttention(
    dim_in=8,
    dim_qk=6,
    dim_v=4,
    dim_out=10,
    num_heads=4,
)

# Random input of shape (4, 5, 8); no mask is passed here.
sample_x = torch.randn(4, 5, 8)
result = mha(sample_x)
print(f"shape: {result.shape}")

shape: torch.Size([4, 5, 10])


In [32]:
from torch.cuda.amp import autocast

class MultiHeadFlashAttention(nn.Module):
    """Multi-head self-attention that delegates to ``flash_attn_func``.

    q, k and v are each produced by one fused bias-free projection of width
    ``dim_head * num_heads`` and then viewed as separate heads. The attention
    output is projected back to ``dim_in``.
    """

    def __init__(self, dim_in, num_heads, dim_head, causal=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        inner_dim = dim_head * num_heads
        self.W_q = nn.Linear(dim_in, inner_dim, bias=False)
        self.W_k = nn.Linear(dim_in, inner_dim, bias=False)
        self.W_v = nn.Linear(dim_in, inner_dim, bias=False)

        self.W_out = nn.Linear(inner_dim, dim_in, bias=False)

        # Not referenced by forward(), which calls flash_attn_func instead.
        self.attention = Attention()

        self.num_heads = num_heads
        self.dim_head = dim_head
        self.dim_in = dim_in
        self.causal = causal

    def forward(self, x):
        """Attend over ``x`` inside a CUDA autocast region.

        The first two axes of ``x`` are read as ``(batch, sequence)``; the
        result has shape ``(batch, sequence, -1)`` after the output projection.
        """
        batch, seq_len = x.shape[:2]
        per_head_shape = (batch, seq_len, self.num_heads, self.dim_head)
        with torch.autocast('cuda'):
            # Split the fused projection width into (heads, dim_head).
            q = self.W_q(x).view(per_head_shape)
            k = self.W_k(x).view(per_head_shape)
            v = self.W_v(x).view(per_head_shape)

            attended = flash_attn_func(q, k, v, causal=self.causal)
            # Merge the head axes back into one feature axis.
            return self.W_out(attended.view(batch, seq_len, -1))

# Smoke test for the flash-attention variant: 8 heads of width 64 on 512 features.
# NOTE(review): the module is moved to "cuda" unconditionally, while the input
# below falls back to CPU when CUDA is unavailable. On a CPU-only host the
# .to("cuda") call would presumably fail first, so confirm whether the
# fallback is actually intended.
mha = MultiHeadFlashAttention(
    dim_in=512,
    num_heads=8,
    dim_head=64
).to("cuda")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Random input of shape (4, 5, 512), created on CPU and then moved to `device`.
sample_x = torch.randn(4, 5, 512).to(device)
result = mha(sample_x)
print(f"shape: {result.shape}")

shape: torch.Size([4, 5, 512])


In [26]:
class TransformerEncoderLayer(nn.Module):
    """One encoder block: multi-head self-attention followed by a feed-forward net.

    Each sub-layer's output is passed through its own LayerNorm and then added
    to that sub-layer's input. Every sub-layer keeps the feature width at
    ``dim_in``.
    """

    def __init__(self, dim_in, dim_qk, dim_v, num_heads, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # dim_out equals dim_in so the residual addition lines up.
        attention_config = dict(
            dim_in=dim_in,
            dim_qk=dim_qk,
            dim_v=dim_v,
            dim_out=dim_in,
            num_heads=num_heads,
        )
        self.mha = MultiHeadAttention(**attention_config)

        self.mha_norm = nn.LayerNorm(dim_in)
        self.activation = nn.GELU()
        # Two linear layers of width dim_in with the shared activation between them.
        first_linear = nn.Linear(dim_in, dim_in)
        second_linear = nn.Linear(dim_in, dim_in)
        self.ffn = nn.Sequential(first_linear, self.activation, second_linear)
        self.ff_norm = nn.LayerNorm(dim_in)

    def forward(self, x):
        """Apply the attention and feed-forward sub-layers with residuals."""
        attended = self.mha(x)
        after_attention = self.mha_norm(attended) + x

        fed_forward = self.ffn(after_attention)
        return self.ff_norm(fed_forward) + after_attention

class TransformerEncoder(nn.Module):
    """A stack of ``num_layers`` encoder layers applied one after another."""

    def __init__(self, dim_in, dim_qk, dim_v, num_heads, num_layers, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            # Every layer is built with the same hyper-parameters.
            self.layers.append(
                TransformerEncoderLayer(dim_in=dim_in, dim_qk=dim_qk, dim_v=dim_v, num_heads=num_heads)
            )

    def forward(self, x):
        """Feed ``x`` through each layer in order and return the final output."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden

# Smoke test: a 4-layer encoder with 4 heads per layer on 8-dimensional features.
model = TransformerEncoder(
    dim_in=8,
    dim_qk=6,
    dim_v=4,
    num_heads=4,
    num_layers=4,
)

# Random input of shape (4, 5, 8); the last axis matches dim_in.
sample_x = torch.randn(4, 5, 8)
result = model(sample_x)
print(f"shape: {result.shape}, result: {result}")

shape: torch.Size([4, 5, 8]), result: tensor([[[-0.3201,  0.0625,  1.6362,  3.5891, -0.1413, -1.2206, -3.2086,
           1.0505],
         [ 1.5041, -1.1012,  2.5851,  5.8168, -1.7879, -2.2085, -4.2433,
           0.0286],
         [ 1.3578, -0.3948,  2.2239,  4.9473, -2.4554, -0.6486, -3.6839,
          -2.6097],
         [-0.6036, -1.3961, -1.3277,  4.9567,  0.5778, -0.3719, -5.2545,
           3.0133],
         [-0.6930, -0.3979, -1.5184,  5.1135, -0.1629,  0.2616, -6.4386,
           1.7458]],

        [[-1.9933,  0.2171,  1.4529,  4.4601, -0.3574, -0.3835, -3.5845,
           0.9353],
         [ 0.6849, -0.4494,  2.7661,  3.8220, -3.2947,  0.2268, -3.2978,
          -2.3724],
         [ 0.3149, -2.2079, -0.2901,  5.0354, -0.7963, -0.1709, -6.6625,
           2.0712],
         [ 1.7548, -1.8341,  1.8345,  4.6398, -2.3439,  0.8804, -4.2612,
          -1.5792],
         [ 0.6798,  0.6331,  2.6075,  3.2651, -3.4048, -0.1269, -4.0909,
          -1.3368]],

        [[-1.9374, -0.7035, 

In [39]:
class TransformerDecoderLayer(nn.Module):
    """One decoder block: masked attention, unmasked attention, feed-forward.

    Both attention sub-layers operate on the running hidden state; the second
    one receives the output of the first and no mask. Each sub-layer's output
    is passed through its own LayerNorm and then added to that sub-layer's
    input.
    """

    def __init__(self, dim_in, dim_qk, dim_v, num_heads, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Both attention modules use identical hyper-parameters; dim_out equals
        # dim_in so the residual additions line up.
        attention_config = dict(
            dim_in=dim_in,
            dim_qk=dim_qk,
            dim_v=dim_v,
            dim_out=dim_in,
            num_heads=num_heads,
        )

        self.masked_mha = MultiHeadAttention(**attention_config)
        self.masked_mha_norm = nn.LayerNorm(dim_in)

        self.mha = MultiHeadAttention(**attention_config)
        self.mha_norm = nn.LayerNorm(dim_in)

        self.activation = nn.GELU()
        # Two linear layers of width dim_in with the shared activation between them.
        first_linear = nn.Linear(dim_in, dim_in)
        second_linear = nn.Linear(dim_in, dim_in)
        self.ffn = nn.Sequential(first_linear, self.activation, second_linear)
        self.ff_norm = nn.LayerNorm(dim_in)

    def forward(self, x, mask):
        """Apply the three sub-layers; ``mask`` is used only by the first one."""
        masked_attended = self.masked_mha(x, mask=mask)
        after_masked = self.masked_mha_norm(masked_attended) + x

        attended = self.mha(after_masked)
        after_attention = self.mha_norm(attended) + after_masked

        fed_forward = self.ffn(after_attention)
        return self.ff_norm(fed_forward) + after_attention

class TransformerDecoder(nn.Module):
    """A stack of ``num_layers`` decoder layers applied one after another."""

    def __init__(self, dim_in, dim_qk, dim_v, num_heads, num_layers, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            # Every layer is built with the same hyper-parameters.
            self.layers.append(
                TransformerDecoderLayer(dim_in=dim_in, dim_qk=dim_qk, dim_v=dim_v, num_heads=num_heads)
            )

    def forward(self, x, mask):
        """Feed ``x`` through each layer in order, passing the same ``mask`` to all."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return hidden

# Smoke test: a 4-layer decoder with 4 heads per layer and a causal mask.
model = TransformerDecoder(
    dim_in=8,
    dim_qk=6,
    dim_v=4,
    num_heads=4,
    num_layers=4,
)

# Random input of shape (4, 5, 8); the mask is square with side sample_x.shape[-2].
sample_x = torch.randn(4, 5, 8)
result = model(sample_x, mask=causal_mask(sample_x.shape[-2]))
print(f"shape: {result.shape}, result: {result}")

shape: torch.Size([4, 5, 8]), result: tensor([[[-2.1156,  0.3588, -0.1507,  3.9490, -0.3605, -0.3159, -1.5909,
          -4.3717],
         [ 0.2610, -2.9844,  3.2863,  3.9364,  2.0252, -1.0371, -4.7748,
          -1.9584],
         [ 2.7874, -2.9923,  3.7135,  4.7617,  4.1527, -1.1137, -4.7065,
          -3.9936],
         [ 3.7710, -5.5868,  1.4442,  3.5567,  4.2420,  0.3629, -3.0979,
          -3.2620],
         [ 2.4560, -6.7225,  3.0729,  3.6182,  4.5079, -0.8883, -3.1035,
          -4.3668]],

        [[ 0.0565, -2.5177,  4.2840,  4.7264, -1.3165,  0.0393,  0.7770,
          -4.8648],
         [ 2.0736, -5.3310,  5.5753,  2.2674,  1.2505, -5.0134, -1.5506,
          -0.4883],
         [ 4.6582, -2.4041,  6.4920,  4.5495, -1.9759, -4.1674, -4.6975,
          -4.9908],
         [ 3.5562, -1.6611,  5.5223,  5.5378, -0.5072, -3.8585, -4.5194,
          -5.0245],
         [ 5.5365, -2.4497,  4.8449,  4.3338,  1.4866, -2.8502, -3.9409,
          -5.0749]],

        [[ 2.6910, -4.1993, 