In [40]:
import torch
import torch.nn as nn
from beartype import beartype as typechecker
import numpy as np
from jaxtyping import Float, jaxtyped
from torch import Tensor

In [41]:
# Positional embedding as defined by Vaswani et al., 2017.
class PosEmbed(nn.Module):

    def __init__(self, d_model=4, n_tokens=8, debug=False):
        super().__init__()
        self.debug = debug
        self.d_model = d_model
        self.n_tokens = n_tokens

        posEmbedLst = []
        for pos in range(n_tokens):
            row = []
            for dim in range(d_model):
                if dim % 2 == 0:
                    row += [self.sin(pos, dim)]
                else:
                    row += [self.cos(pos, dim)]
            
            posEmbedLst.append(row.copy())

        self.posEmbed : Float[Tensor, "n_tokens d_model"] = (
            torch.tensor(posEmbedLst)
        )
        if debug:
            print(f"{self.posEmbed=}")

    def sin(self, pos, evenDim):
        return torch.sin(
            torch.tensor(pos/self.n_tokens**(evenDim/self.d_model))
        )
    
    def cos(self, pos, oddDim):
        return torch.cos(
            torch.tensor(pos/self.n_tokens**(oddDim/self.d_model))
        )
    
    @jaxtyped(typechecker=typechecker)
    def forward(
        self, input: Float[Tensor, "n_tokens d_model"]
    ) -> Float[Tensor, "n_tokens d_model"]:
        # I actually dont need and input!
        r = input + self.posEmbed
        if self.debug:
            print(f"{r=}")
        return r

In [42]:
a = torch.ones((8, 4))
PosEmbed().forward(a)

tensor([[1.0000, 2.0000, 1.0000, 2.0000],
        [1.8415, 1.8284, 1.3462, 1.9780],
        [1.9093, 1.3724, 1.6496, 1.9129],
        [1.1411, 0.7886, 1.8727, 1.8076],
        [0.2432, 0.2774, 1.9878, 1.6668],
        [0.0411, 0.0142, 1.9807, 1.4966],
        [0.7206, 0.0894, 1.8523, 1.3045],
        [1.6570, 0.4772, 1.6184, 1.0991]])

In [43]:
# Input shape: n_tokens x d_model
class Head(nn.Module):

    def __init__(self, d_model=4, n_tokens=8, n_head=2, debug=False):
        super().__init__()
        self.d_model = d_model
        self.n_tokens = n_tokens
        self.n_head = n_head
        self.debug = debug

        self.d_head = d_head = d_model // n_head
        self.wQ: Float[Tensor, "d_model d_head"] = (
            torch.rand((d_model, d_head))
        )
        self.wK: Float[Tensor, "d_model d_head"] = (
            torch.rand((d_model, d_head))
        )
        self.wV: Float[Tensor, "d_model d_head"] = (
            torch.rand((d_model, d_head))
        )

    @jaxtyped(typechecker=typechecker)
    def forward(
        self, input: Float[Tensor, "n_tokens d_model"]
    ) -> Float[Tensor, "n_tokens d_head"]:
        Q: Float[Tensor, "n_tokens d_head"] = input @ self.wQ
        K: Float[Tensor, "n_tokens d_head"] = input @ self.wK
        A: Float[Tensor, "n_tokens n_token"] = Q @ K.transpose(dim0=-2, dim1=-1)
        scores: Float[Tensor, "n_tokens n_token"] = (
            torch.softmax(A / self.d_model**(1/2), dim=-1)
        )
        V: Float[Tensor, "n_tokens d_head"] = input @ self.wV
        output: Float[Tensor, "n_tokens d_head"] = scores @ V
        return output

In [44]:
a = torch.ones((8, 4))
Head().forward(a)

tensor([[2.4602, 1.4841],
        [2.4602, 1.4841],
        [2.4602, 1.4841],
        [2.4602, 1.4841],
        [2.4602, 1.4841],
        [2.4602, 1.4841],
        [2.4602, 1.4841],
        [2.4602, 1.4841]])

In [45]:
class FeedForward(nn.Module):

    def __init__(self, d_model=4, n_tokens=8, mlp_factor=4, debug=False):
        super().__init__()
        self.d_model = d_model
        self.n_tokens = n_tokens
        self.mlp_factor = mlp_factor
        self.debug = debug

        self.encode: Float[Tensor, "d_model d_model*mlp_factor"] = (
            torch.rand((d_model, d_model*mlp_factor))
        )
        self.decode: Float[Tensor, "d_model*mlp_factor d_model"] = (
            torch.rand((d_model*mlp_factor, d_model))
        )

    def forward(
            self, input: Float[Tensor, "n_tokens d_model"]
    ) -> Float[Tensor, "n_tokens d_model"]:
        output: Float[Tensor, "n_tokens d_model"] = (
            input @ self.encode @ self.decode 
        )
        return output 

In [48]:
FeedForward().forward(torch.ones((8, 4)))

tensor([[18.2701, 15.6138, 15.3038, 14.6008],
        [18.2701, 15.6138, 15.3038, 14.6008],
        [18.2701, 15.6138, 15.3038, 14.6008],
        [18.2701, 15.6138, 15.3038, 14.6008],
        [18.2701, 15.6138, 15.3038, 14.6008],
        [18.2701, 15.6138, 15.3038, 14.6008],
        [18.2701, 15.6138, 15.3038, 14.6008],
        [18.2701, 15.6138, 15.3038, 14.6008]])

In [None]:
class LayerNorm(nn.Module):

    def __init__(self, d_model=4, n_tokens=8, debug=False):
        self.d_model = d_model
        self.n_tokens = n_tokens
        self.debug = debug

    def forward(
        self, input: Float[Tensor, "n_tokens d_model"]
    ) -> Float[Tensor, "n_tokens d_model"]:
        output = (input - input.mean(dim=-1))/input.std(dim=-1)
        return output

In [51]:
class HiddenLayer(nn.Module):

    def __init__(self, d_model=4, n_tokens=8, n_head=2, mlp_factor=4, debug=False):
        super().__init__()
        self.d_model = d_model
        self.n_tokens = n_tokens
        self.n_head = n_head
        self.mlp_factor = mlp_factor
        self.debug = debug

        self.heads = nn.ModuleList([
            Head(
                d_model=d_model, n_tokens=n_tokens, n_head=n_head, debug=debug
            ) for _ in range(n_head)
        ])
        self.wO: Float[Tensor, "d_model d_model"] = torch.rand((
            d_model, d_model                                                         
        ))
        self.ffw = FeedForward(
            d_model=d_model, n_tokens=n_tokens,
            mlp_factor=mlp_factor, debug=debug
        )

    def forward(
        self, input: Float[Tensor, "n_tokens d_model"]
    ) -> Float[Tensor, "n_tokens d_model"]:
        outsLst: list[Float[Tensor, "n_token d_model/n_head"]] = [
            h(input.clone()).tolist() for h in self.heads
        ]
        outs: Float[Tensor, "n_head n_tokens d_model/n_head"] = torch.tensor(
            outsLst
        )
        outsCat: Float[Tensor, "n_tokens d_model"] = outs.view(
            -1, self.d_model
        )
        outsProj: Float[Tensor, "n_tokens d_model"] = outsCat @ self.wO
        output: Float[Tensor, "n_tokens d_model"] = self.ffw(outsProj)
        return output


In [52]:
HiddenLayer().forward(torch.ones((8, 4)))

tensor([[ 93.8969,  97.5771, 102.0062,  83.3961],
        [ 93.8969,  97.5771, 102.0062,  83.3961],
        [ 93.8969,  97.5771, 102.0062,  83.3961],
        [ 93.8969,  97.5771, 102.0062,  83.3961],
        [ 78.6800,  81.8623,  85.5446,  69.9256],
        [ 78.6800,  81.8623,  85.5446,  69.9256],
        [ 78.6800,  81.8623,  85.5446,  69.9256],
        [ 78.6800,  81.8623,  85.5446,  69.9256]])

In [12]:
class Transformer(nn.Module):

    def __init__(
        self, num_layers=5, d_model=4, n_vocab=16,
        n_tokens=8, n_head=2, mlp_factor=4, debug=False
    ):
        super().__init__()
        self.num_layers = num_layers
        self.d_model = d_model
        self.n_vocab = n_vocab
        self.n_tokens = n_tokens
        self.n_head = n_head
        self.mlp_factor = mlp_factor
        self.debug = debug

        self.embed: Float[Tensor, "n_vocab d_model"] = (
            torch.rand((n_vocab, d_model))
        )
        self.posEmbed = PosEmbed(
            d_model=d_model, n_tokens=n_tokens, debug=debug
        )
        self.hiddenLayers = nn.ModuleList([
            HiddenLayer(
                d_model=d_model, n_tokens=n_tokens, n_head=n_head,
                mlp_factor=mlp_factor, debug=debug
            ) for _ in range(num_layers)
        ])
