In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class multihead(nn.Module):
    """Multi-head self-attention built on ``F.scaled_dot_product_attention``.

    The input is projected to queries, keys and values (no bias), split into
    ``n_heads`` heads of width ``d_out // n_heads``, attended, and the heads
    are concatenated back to width ``d_out``. No mask is applied
    (``is_causal=False``) and there is no output projection.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Total output width across all heads; must be divisible by
            ``n_heads``.
        context_size: Stored on the module; not used by the computation here.
        n_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights while
            the module is in training mode.

    Raises:
        AssertionError: If ``d_out`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_in, d_out, context_size, n_heads , dropout=0.0):
        super().__init__()

        # The check is on d_out, so the message must name d_out as well.
        assert d_out % n_heads == 0, "d_out must be divisible by n_heads"

        self.d_in = d_in
        self.d_out = d_out
        self.n_heads = n_heads
        self.context_size = context_size
        self.d_head = d_out // n_heads
        self.dropout = dropout

        self.q = nn.Linear(d_in, d_out, bias=False)
        self.k = nn.Linear(d_in, d_out, bias=False)
        self.v = nn.Linear(d_in, d_out, bias=False)

    def _split_heads(self, projected, b, num_tokens):
        """Reshape (b, num_tokens, d_out) into (b, n_heads, num_tokens, d_head)."""
        return projected.view(b, num_tokens, self.n_heads, self.d_head).permute(0, 2, 1, 3)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (b, num_tokens, d_in).

        Returns:
            Tensor of shape (b, num_tokens, d_out).
        """
        b, num_tokens, _ = x.size()

        q = self._split_heads(self.q(x), b, num_tokens)
        k = self._split_heads(self.k(x), b, num_tokens)
        v = self._split_heads(self.v(x), b, num_tokens)

        # The functional API applies dropout whenever dropout_p > 0; it does
        # not know about train()/eval(), so disable it explicitly at inference.
        dropout_p = self.dropout if self.training else 0.0

        att = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=dropout_p,
            is_causal=False,
        )

        # (b, n_heads, num_tokens, d_head) -> (b, num_tokens, d_out);
        # contiguous() is required before view() after the permute.
        att = att.permute(0, 2, 1, 3).contiguous()
        att = att.view(b, num_tokens, self.d_out)
        return att


In [3]:
class FeedForward(nn.Module):
    """Two-layer MLP: widen to ``4 * n_embed``, ReLU, project back, dropout.

    Args:
        n_embed: Size of the last dimension of both input and output.
        dropout: Probability for the trailing ``nn.Dropout`` layer.
    """

    def __init__(self, n_embed, dropout=0.0):
        super().__init__()
        hidden_width = 4 * n_embed
        # Order matters: it fixes both the state_dict keys (net.0, net.2)
        # and the order in which the linear layers are initialised.
        stages = [
            nn.Linear(n_embed, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, n_embed),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the MLP; the last dimension stays ``n_embed``."""
        out = self.net(x)
        return out

In [None]:
class Block(nn.Module):
    """Pre-norm residual block: attention sub-layer followed by feed-forward.

    Each sub-layer is applied to a layer-normalised copy of its input and the
    result is added back to that input.

    Args:
        d_in: Width of the block input. Must equal ``d_out``, because the
            input is layer-normalised at width ``d_out`` and added to the
            ``d_out``-wide attention output.
        d_out: Width of the attention output and of the block output.
        context_size: Forwarded to ``multihead``.
        n_heads: Number of attention heads, forwarded to ``multihead``.
        dropout: Dropout probability forwarded to both sub-layers.

    Raises:
        ValueError: If ``d_in`` and ``d_out`` differ.
    """

    def __init__(self, d_in, d_out, context_size, n_heads, dropout=0.0):
        super().__init__()
        # Fail early with a clear message: with mismatched widths the block
        # would build fine and then always fail inside forward() with an
        # opaque shape error from LayerNorm / the residual add.
        if d_in != d_out:
            raise ValueError(
                f"Block requires d_in == d_out for its residual connections, "
                f"got d_in={d_in}, d_out={d_out}"
            )
        self.att = multihead(d_in, d_out, context_size, n_heads, dropout)
        self.ff = FeedForward(d_out, dropout)
        self.layer_norm1 = nn.LayerNorm(d_out)
        self.layer_norm2 = nn.LayerNorm(d_out)

    def forward(self, x):
        """Apply both residual sub-layers; output has the same shape as ``x``."""
        x = x + self.att(self.layer_norm1(x))
        x = x + self.ff(self.layer_norm2(x))
        return x