In [252]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from beartype import beartype as typechecker
import numpy as np
from jaxtyping import UInt, Float, jaxtyped
from torch import Tensor
import random
import numpy as np

In [253]:
def set_seed():
    torch.manual_seed(4)
    random.seed(42)
    np.random.seed(42)

set_seed()

In [286]:
# Positional embedding as defined by Vaswani et al., 2017.
class PosEmbed(nn.Module):

    def __init__(self, d_model=4, n_tokens=8, debug=False):
        super().__init__()
        self.debug = debug
        self.d_model = d_model
        self.n_tokens = n_tokens

        posEmbedLst = []
        for pos in range(n_tokens):
            row = []
            for dim in range(d_model):
                if dim % 2 == 0:
                    row += [self.sin(pos, dim)]
                else:
                    row += [self.cos(pos, dim)]
            
            posEmbedLst.append(row.copy())

        self.posEmbed : Float[Tensor, "n_tokens d_model"] = nn.Parameter((
            torch.tensor(posEmbedLst)
        ))
        if debug:
            print(f"{self.posEmbed=}")

    def sin(self, pos, evenDim):
        return torch.sin(
            torch.tensor(pos/self.n_tokens**(evenDim/self.d_model))
        )
    
    def cos(self, pos, oddDim):
        return torch.cos(
            torch.tensor(pos/self.n_tokens**(oddDim/self.d_model))
        )
    
    @jaxtyped(typechecker=typechecker)
    def forward(
        self, input: Float[Tensor, "*batch_size n_tokens d_model"]
    ) -> Float[Tensor, "*batch_size n_tokens d_model"]:
        # I actually dont need and input!
        r = input + self.posEmbed
        if self.debug:
            display(f"{r=}")
        return r

In [287]:
a = torch.ones((8, 4))
PosEmbed().forward(a)

tensor([[1.0000, 2.0000, 1.0000, 2.0000],
        [1.8415, 1.8284, 1.3462, 1.9780],
        [1.9093, 1.3724, 1.6496, 1.9129],
        [1.1411, 0.7886, 1.8727, 1.8076],
        [0.2432, 0.2774, 1.9878, 1.6668],
        [0.0411, 0.0142, 1.9807, 1.4966],
        [0.7206, 0.0894, 1.8523, 1.3045],
        [1.6570, 0.4772, 1.6184, 1.0991]], grad_fn=<AddBackward0>)

In [288]:
# Input shape: n_tokens x d_model
class Head(nn.Module):

    def __init__(self, d_model=4, n_tokens=8, n_head=2, debug=False):
        super().__init__()
        self.d_model = d_model
        self.n_tokens = n_tokens
        self.n_head = n_head
        self.debug = debug

        self.d_head = d_head = d_model // n_head
        self.wQ: Float[Tensor, "d_model d_head"] = nn.Parameter(
            torch.rand((d_model, d_head))
        )
        self.wK: Float[Tensor, "d_model d_head"] = nn.Parameter(
            torch.rand((d_model, d_head))
        )
        self.wV: Float[Tensor, "d_model d_head"] = nn.Parameter(
            torch.rand((d_model, d_head))
        )

    @jaxtyped(typechecker=typechecker)
    def forward(
        self, input: Float[Tensor, "*batch_size n_tokens d_model"]
    ) -> Float[Tensor, "*batch_size n_tokens d_head"]:
        Q: Float[Tensor, "*batch_size n_tokens d_head"] = input @ self.wQ
        K: Float[Tensor, "*batch_size n_tokens d_head"] = input @ self.wK
        A: Float[Tensor, "*batch_size n_tokens n_token"] = (
            Q @ K.transpose(dim0=-2, dim1=-1)
        )
        if self.debug:
            display(A)
        masked: Float[Tensor, "*batch_size n_tokens n_tokens"] = (
            A + (torch.ones_like(A) * float("-inf")).triu(diagonal=1)
        )
        if self.debug:
            display(masked)
        scores: Float[Tensor, "*batch_size n_tokens n_token"] = (
            torch.softmax(masked / self.d_model**(1/2), dim=-1)
        )
        if self.debug:
            display(scores)
        V: Float[Tensor, "*batch_size n_tokens d_head"] = input @ self.wV
        output: Float[Tensor, "*batch_size n_tokens d_head"] = scores @ V
        return output

In [289]:
set_seed()
a = torch.ones((1, 8, 4))

In [290]:
class FeedForward(nn.Module):

    def __init__(self, d_model=4, n_tokens=8, mlp_factor=4, debug=False):
        super().__init__()
        self.d_model = d_model
        self.n_tokens = n_tokens
        self.mlp_factor = mlp_factor
        self.debug = debug

        self.encode: Float[Tensor, "d_model d_model*mlp_factor"] = nn.Parameter(
            torch.rand((d_model, d_model*mlp_factor))
        )
        self.decode: Float[Tensor, "d_model*mlp_factor d_model"] = nn.Parameter(
            torch.rand((d_model*mlp_factor, d_model))
        )

    def forward(
            self, input: Float[Tensor, "*batch_size n_tokens d_model"]
    ) -> Float[Tensor, "*batch_size n_tokens d_model"]:
        output: Float[Tensor, "*batch_size n_tokens d_model"] = (
            input @ self.encode @ self.decode 
        )
        return output 

In [291]:
set_seed()
FeedForward().forward(torch.ones((8, 4)))

tensor([[15.1447, 19.8279, 22.9614, 16.4457],
        [15.1447, 19.8279, 22.9614, 16.4457],
        [15.1447, 19.8279, 22.9614, 16.4457],
        [15.1447, 19.8279, 22.9614, 16.4457],
        [15.1447, 19.8279, 22.9614, 16.4457],
        [15.1447, 19.8279, 22.9614, 16.4457],
        [15.1447, 19.8279, 22.9614, 16.4457],
        [15.1447, 19.8279, 22.9614, 16.4457]], grad_fn=<MmBackward0>)

In [292]:
class LayerNorm(nn.Module):

    def __init__(self, d_model=4, n_tokens=8, debug=False):
        super().__init__()
        self.d_model = d_model
        self.n_tokens = n_tokens
        self.debug = debug

    def forward(
        self, input: Float[Tensor, "*batch_size n_tokens d_model"]
    ) -> Float[Tensor, "*batch_size n_tokens d_model"]:
        output = (
            (input - input.mean(dim=-1, keepdim=True))
            /input.std(dim=-1, keepdim=True)
        )
        return output

In [293]:
set_seed()
ln_ed = LayerNorm().forward(torch.rand((8, 4)))
ln_ed.mean(-1), ln_ed.std(-1)

(tensor([ 1.4901e-08,  0.0000e+00,  4.4703e-08,  6.7055e-08, -1.4901e-08,
         -1.1176e-08,  1.0431e-07,  1.1176e-07]),
 tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]))

In [294]:
class HiddenLayer(nn.Module):

    def __init__(self, d_model=4, n_tokens=8, n_head=2, mlp_factor=4, debug=False):
        super().__init__()
        self.d_model = d_model
        self.n_tokens = n_tokens
        self.n_head = n_head
        self.mlp_factor = mlp_factor
        self.debug = debug

        self.ln = LayerNorm(
            d_model=self.d_model, n_tokens=self.n_tokens
        )
        self.heads = nn.ModuleList([
            Head(
                d_model=d_model, n_tokens=n_tokens, n_head=n_head, debug=debug
            ) for _ in range(n_head)
        ])
        self.wO: Float[Tensor, "d_model d_model"] = nn.Parameter(torch.rand((
            d_model, d_model                                                         
        )))
        self.ffw = FeedForward(
            d_model=d_model, n_tokens=n_tokens,
            mlp_factor=mlp_factor, debug=debug
        )

    def forward(
        self, input: Float[Tensor, "*batch_size n_tokens d_model"]
    ) -> Float[Tensor, "*batch_size n_tokens d_model"]:
        normInput = self.ln(input.clone())
        outsLst: list[Float[Tensor, "*batch_size n_token d_model/n_head"]] = [
            h(normInput.clone()).tolist() for h in self.heads
        ]
        outs: Float[Tensor, "*batch_size n_head n_tokens d_model/n_head"] = (
            torch.tensor(outsLst)
        )
        outsTrans: Float[Tensor, "*batch_size n_tokens n_head d_model/n_head"] = (
            outs.transpose(-3, -2).contiguous()
        )
        outsCat: Float[Tensor, "*batch_size n_tokens d_model"] = outsTrans.view(
            tuple([*outsTrans.shape[:-2], self.d_model])
        )
        outsProj: Float[Tensor, "*batch_size n_tokens d_model"] = outsCat @ self.wO
        outsSum: Float[Tensor, "*batch_size n_tokens d_model"] = outsProj + input.clone()
        outsNorm: Float[Tensor, "*batch_size n_tokens d_model"] = self.ln(outsSum)
        output: Float[Tensor, "*batch_size n_tokens d_model"] = self.ffw(outsNorm)
        return output + input


In [296]:
set_seed()
[name for name, _  in list(HiddenLayer().named_parameters())]

['wO',
 'heads.0.wQ',
 'heads.0.wK',
 'heads.0.wV',
 'heads.1.wQ',
 'heads.1.wK',
 'heads.1.wV',
 'ffw.encode',
 'ffw.decode']

In [251]:
class Transformer(nn.Module):

    def __init__(
        self, num_layers=5, d_model=4, n_vocab=16,
        n_tokens=8, n_head=2, mlp_factor=4, debug=False
    ):
        super().__init__()
        self.num_layers = num_layers
        self.d_model = d_model
        self.n_vocab = n_vocab
        self.n_tokens = n_tokens
        self.n_head = n_head
        self.mlp_factor = mlp_factor
        self.debug = debug

        self.embed: Float[Tensor, "n_vocab d_model"] = torch.rand(
            (n_vocab, d_model)
        )
        self.posEmbed = PosEmbed(
            d_model=d_model, n_tokens=n_tokens, debug=debug
        )
        self.hiddenLayers = nn.ModuleList([
            HiddenLayer(
                d_model=d_model, n_tokens=n_tokens, n_head=n_head,
                mlp_factor=mlp_factor, debug=debug
            ) for _ in range(num_layers)
        ])
        self.unembed: Float[Tensor, "d_model n_vocab"] = torch.rand(
            (d_model, n_vocab)
        )
        self.parameters += [
            self.embed, self.hiddenLayers, self.unembed
        ]
    
    def forward(
        self, input: Float[Tensor, "*batch_size n_tokens n_vocab"]
    ) -> Float[Tensor, "*batch_size n_tokens n_vocab"]:
        emb: Float[Tensor, "*batch_size n_tokens d_model"] = self.posEmbed(
            input @ self.embed
        )
        out = emb
        for hl in self.hiddenLayers:
            out = hl(out)
        logits: Float[Tensor, "*batch_size n_tokens n_vocab"] = out @ self.unembed
        return logits

In [243]:
set_seed()
Transformer().forward(torch.rand((8, 16)))
Transformer().forward(torch.rand((2, 8, 16)))

tensor([[[-5.8537e-01, -2.9655e-01, -4.0775e-01, -1.1655e+00, -8.5163e-01,
          -1.8397e+00, -1.0331e+00, -1.0017e+00, -2.2193e+00, -4.9268e-01,
          -1.9748e+00,  2.5402e-02, -8.8963e-01, -1.7210e+00, -1.4942e+00,
          -4.2577e-01],
         [ 1.1494e+00,  1.9409e+00,  8.3373e-01,  9.7188e-01,  6.4855e-01,
           8.5670e-01,  6.6286e-01,  7.1370e-01,  1.2318e+00,  1.8213e+00,
           2.9599e-01,  3.1077e+00,  6.3227e-01, -3.4496e-02,  1.2049e+00,
           1.6290e+00],
         [ 4.3700e+00,  6.9234e+00,  3.8467e+00,  5.0054e+00,  3.9598e+00,
           7.0750e+00,  4.9282e+00,  5.1374e+00,  9.2090e+00,  7.6520e+00,
           6.0487e+00,  9.9352e+00,  4.4578e+00,  4.6415e+00,  7.6291e+00,
           6.4958e+00],
         [ 1.3201e+01,  1.9243e+01,  1.2162e+01,  1.6287e+01,  1.1958e+01,
           2.2672e+01,  1.5373e+01,  1.5418e+01,  2.8978e+01,  2.0808e+01,
           2.0165e+01,  2.6545e+01,  1.3626e+01,  1.5247e+01,  2.3094e+01,
           1.7414e+01],
    

In [247]:
from torch.optim import Adam
class TransformerTrainer():
    
    def __init__(self, model, optim_choice=Adam, epochs=1, lr=1e-7, debug=False):
        self.model = model
        self.optim_choice = optim_choice
        self.epochs=epochs
        self.lr=lr
    
    def getLoss(
        self, batch: Float[Tensor, "*batch_size n_tokens"]
    ):
        inp = F.one_hot(
            batch, num_classes=self.model.n_vocab
        ).to(dtype=torch.float)
        logits: Float[Tensor, "*batch_size n_tokens n_vocab"] = self.model(inp)
        probs = logits.softmax(-1)
        correctProbs : Float[Tensor, "*batch_size n_tokens-1"] = (
            probs[:-1][..., range(probs.shape[-2] - 1), labels[1:]]
        )
        loss: Float[Tensor, "*batch_size n_tokens-1"] = (
            correctProbs.log().neg()
        )
        return loss
    
    def train(
        self, batches: list[Float[Tensor, "*batch_size n_tokens"]]
    ):
        optimizer = self.optim_choice(self.model.parameters(), lr=self.lr) 
        for bath in batches:
            optimizer.zero_grad()
            loss = self.get_loss(batch)
            loss.backward()
            optimizer.step()
        

In [245]:
a = torch.rand((2))
a, a.log()

(tensor([0.6346, 0.0803]), tensor([-0.4548, -2.5219]))

In [246]:
set_seed()
debugModel = Transformer()
labels = torch.randint(0, 16, (8, ))
inp = F.one_hot(labels, num_classes=16).to(dtype=torch.float)
logits = debugModel(inp)
probs = logits.softmax(-1)
correctLogits = probs[:-1][..., range(probs.shape[-2] - 1), labels[1:]]
loss = correctLogits.log().neg()
labels[1:], inp.shape, probs.shape, correctLogits, loss
# torch.tensor([[1, 2], [3, 4]])[..., torch.tensor([1, 0])]

(tensor([ 8, 10,  9, 12, 12,  6, 14]),
 torch.Size([8, 16]),
 torch.Size([8, 16]),
 tensor([0.0620, 0.0263, 0.0490, 0.0066, 0.0124, 0.0199, 0.0300]),
 tensor([2.7811, 3.6381, 3.0168, 5.0167, 4.3897, 3.9162, 3.5056]))

In [250]:
debugModel = Transformer()
debugBatch = [torch.randint(0, 16, (8,)) for _ in range(5)]
# TransformerTrainer(debugModel).train(debugBatch)
debugModel.parameters()

[]