In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from beartype import beartype as typechecker
import numpy as np
from jaxtyping import UInt, Float, jaxtyped
from torch import Tensor
import random
import numpy as np

In [3]:
def set_seed():
    torch.manual_seed(4)
    random.seed(42)
    np.random.seed(42)

set_seed()

In [4]:
# Positional embedding as defined by Vaswani et al., 2017.
class PosEmbed(nn.Module):

    def __init__(self, d_model=4, n_tokens=8, debug=False):
        super().__init__()
        self.debug = debug
        self.d_model = d_model
        self.n_tokens = n_tokens

        posEmbedLst = []
        for pos in range(n_tokens):
            row = []
            for dim in range(d_model):
                if dim % 2 == 0:
                    row += [self.sin(pos, dim)]
                else:
                    row += [self.cos(pos, dim)]
            
            posEmbedLst.append(row.copy())

        self.posEmbed : Float[Tensor, "n_tokens d_model"] = nn.Parameter((
            torch.tensor(posEmbedLst)
        ))
        if debug:
            print(f"{self.posEmbed=}")

    def sin(self, pos, evenDim):
        return torch.sin(
            torch.tensor(pos/self.n_tokens**(evenDim/self.d_model))
        )
    
    def cos(self, pos, oddDim):
        return torch.cos(
            torch.tensor(pos/self.n_tokens**(oddDim/self.d_model))
        )
    
    @jaxtyped(typechecker=typechecker)
    def forward(
        self, input: Float[Tensor, "*batch_size n_tokens d_model"]
    ) -> Float[Tensor, "*batch_size n_tokens d_model"]:
        # I actually dont need and input!
        r = input + self.posEmbed
        if self.debug:
            display(f"{r=}")
        return r

In [5]:
a = torch.ones((8, 4))
PosEmbed().forward(a)

tensor([[1.0000, 2.0000, 1.0000, 2.0000],
        [1.8415, 1.8284, 1.3462, 1.9780],
        [1.9093, 1.3724, 1.6496, 1.9129],
        [1.1411, 0.7886, 1.8727, 1.8076],
        [0.2432, 0.2774, 1.9878, 1.6668],
        [0.0411, 0.0142, 1.9807, 1.4966],
        [0.7206, 0.0894, 1.8523, 1.3045],
        [1.6570, 0.4772, 1.6184, 1.0991]], grad_fn=<AddBackward0>)

In [6]:
# Input shape: n_tokens x d_model
class Head(nn.Module):

    def __init__(self, d_model=4, n_tokens=8, n_head=2, debug=False):
        super().__init__()
        self.d_model = d_model
        self.n_tokens = n_tokens
        self.n_head = n_head
        self.debug = debug

        self.d_head = d_head = d_model // n_head
        self.wQ: Float[Tensor, "d_model d_head"] = nn.Parameter(
            torch.rand((d_model, d_head))
        )
        self.wK: Float[Tensor, "d_model d_head"] = nn.Parameter(
            torch.rand((d_model, d_head))
        )
        self.wV: Float[Tensor, "d_model d_head"] = nn.Parameter(
            torch.rand((d_model, d_head))
        )

    @jaxtyped(typechecker=typechecker)
    def forward(
        self, input: Float[Tensor, "*batch_size n_tokens d_model"]
    ) -> Float[Tensor, "*batch_size n_tokens d_head"]:
        Q: Float[Tensor, "*batch_size n_tokens d_head"] = input @ self.wQ
        K: Float[Tensor, "*batch_size n_tokens d_head"] = input @ self.wK
        A: Float[Tensor, "*batch_size n_tokens n_token"] = (
            Q @ K.transpose(dim0=-2, dim1=-1)
        )
        if self.debug:
            display(A)
        masked: Float[Tensor, "*batch_size n_tokens n_tokens"] = (
            A + (torch.ones_like(A) * float("-inf")).triu(diagonal=1)
        )
        if self.debug:
            display(masked)
        scores: Float[Tensor, "*batch_size n_tokens n_token"] = (
            torch.softmax(masked / self.d_model**(1/2), dim=-1)
        )
        if self.debug:
            display(scores)
        V: Float[Tensor, "*batch_size n_tokens d_head"] = input @ self.wV
        output: Float[Tensor, "*batch_size n_tokens d_head"] = scores @ V
        return output

In [7]:
set_seed()
a = torch.ones((1, 8, 4))

In [8]:
class FeedForward(nn.Module):

    def __init__(self, d_model=4, n_tokens=8, mlp_factor=4, debug=False):
        super().__init__()
        self.d_model = d_model
        self.n_tokens = n_tokens
        self.mlp_factor = mlp_factor
        self.debug = debug

        self.encode: Float[Tensor, "d_model d_model*mlp_factor"] = nn.Parameter(
            torch.rand((d_model, d_model*mlp_factor))
        )
        self.decode: Float[Tensor, "d_model*mlp_factor d_model"] = nn.Parameter(
            torch.rand((d_model*mlp_factor, d_model))
        )

    def forward(
            self, input: Float[Tensor, "*batch_size n_tokens d_model"]
    ) -> Float[Tensor, "*batch_size n_tokens d_model"]:
        output: Float[Tensor, "*batch_size n_tokens d_model"] = (
            input @ self.encode @ self.decode 
        )
        return output 

In [9]:
set_seed()
FeedForward().forward(torch.ones((8, 4)))

tensor([[15.1447, 19.8279, 22.9614, 16.4457],
        [15.1447, 19.8279, 22.9614, 16.4457],
        [15.1447, 19.8279, 22.9614, 16.4457],
        [15.1447, 19.8279, 22.9614, 16.4457],
        [15.1447, 19.8279, 22.9614, 16.4457],
        [15.1447, 19.8279, 22.9614, 16.4457],
        [15.1447, 19.8279, 22.9614, 16.4457],
        [15.1447, 19.8279, 22.9614, 16.4457]], grad_fn=<MmBackward0>)

In [10]:
class LayerNorm(nn.Module):

    def __init__(self, d_model=4, n_tokens=8, debug=False):
        super().__init__()
        self.d_model = d_model
        self.n_tokens = n_tokens
        self.debug = debug

    def forward(
        self, input: Float[Tensor, "*batch_size n_tokens d_model"]
    ) -> Float[Tensor, "*batch_size n_tokens d_model"]:
        output = (
            (input - input.mean(dim=-1, keepdim=True))
            /input.std(dim=-1, keepdim=True)
        )
        return output

In [11]:
set_seed()
ln_ed = LayerNorm().forward(torch.rand((8, 4)))
ln_ed.mean(-1), ln_ed.std(-1)

(tensor([ 1.4901e-08,  0.0000e+00,  4.4703e-08,  6.7055e-08, -1.4901e-08,
         -1.1176e-08,  1.0431e-07,  1.1176e-07]),
 tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]))

In [12]:
class HiddenLayer(nn.Module):

    def __init__(self, d_model=4, n_tokens=8, n_head=2, mlp_factor=4, debug=False):
        super().__init__()
        self.d_model = d_model
        self.n_tokens = n_tokens
        self.n_head = n_head
        self.mlp_factor = mlp_factor
        self.debug = debug

        self.ln = LayerNorm(
            d_model=self.d_model, n_tokens=self.n_tokens
        )
        self.heads = nn.ModuleList([
            Head(
                d_model=d_model, n_tokens=n_tokens, n_head=n_head, debug=debug
            ) for _ in range(n_head)
        ])
        self.wO: Float[Tensor, "d_model d_model"] = nn.Parameter(torch.rand((
            d_model, d_model                                                         
        )))
        self.ffw = FeedForward(
            d_model=d_model, n_tokens=n_tokens,
            mlp_factor=mlp_factor, debug=debug
        )

    def forward(
        self, input: Float[Tensor, "*batch_size n_tokens d_model"]
    ) -> Float[Tensor, "*batch_size n_tokens d_model"]:
        normInput = self.ln(input.clone())
        outsLst: list[Float[Tensor, "*batch_size n_token d_model/n_head"]] = [
            h(normInput.clone()).tolist() for h in self.heads
        ]
        outs: Float[Tensor, "n_head *batch_size n_tokens d_model/n_head"] = (
            torch.tensor(outsLst)
        )
        for i in range(0, len(outs.shape) - 3):
            outs = outs.transpose(i, i+1)
        outsTrans: Float[Tensor, "*batch_size n_tokens n_head d_model/n_head"] = (
            outs.transpose(-3, -2).contiguous()
        )
        outsCat: Float[Tensor, "*batch_size n_tokens d_model"] = outsTrans.view(
            tuple([*outsTrans.shape[:-2], self.d_model])
        )
        outsProj: Float[Tensor, "*batch_size n_tokens d_model"] = outsCat @ self.wO
        outsSum: Float[Tensor, "*batch_size n_tokens d_model"] = outsProj + input.clone()
        outsNorm: Float[Tensor, "*batch_size n_tokens d_model"] = self.ln(outsSum)
        output: Float[Tensor, "*batch_size n_tokens d_model"] = self.ffw(outsNorm)
        return output + input


In [13]:
set_seed()
[name for name, _  in list(HiddenLayer().named_parameters())]

['wO',
 'heads.0.wQ',
 'heads.0.wK',
 'heads.0.wV',
 'heads.1.wQ',
 'heads.1.wK',
 'heads.1.wV',
 'ffw.encode',
 'ffw.decode']

In [14]:
class Transformer(nn.Module):

    def __init__(
        self, num_layers=5, d_model=4, n_vocab=16,
        n_tokens=8, n_head=2, mlp_factor=4, debug=False
    ):
        super().__init__()
        self.num_layers = num_layers
        self.d_model = d_model
        self.n_vocab = n_vocab
        self.n_tokens = n_tokens
        self.n_head = n_head
        self.mlp_factor = mlp_factor
        self.debug = debug

        self.embed: Float[Tensor, "n_vocab d_model"] = nn.Parameter(
            torch.rand((n_vocab, d_model))
        )
        # self.embed[-1, :] = torch.tensor([float("-inf")])
        self.posEmbed = PosEmbed(
            d_model=d_model, n_tokens=n_tokens, debug=debug
        )
        self.hiddenLayers = nn.ModuleList([
            HiddenLayer(
                d_model=d_model, n_tokens=n_tokens, n_head=n_head,
                mlp_factor=mlp_factor, debug=debug
            ) for _ in range(num_layers)
        ])
        self.unembed: Float[Tensor, "d_model n_vocab"] = nn.Parameter(
            torch.rand((d_model, n_vocab))
        )
    
    def forward(
        self, input: Float[Tensor, "*batch_size n_tokens n_vocab"]
    ) -> Float[Tensor, "*batch_size n_tokens n_vocab"]:
        mask: Float[Tensor, "n_vocab d_model"] = torch.ones_like(self.embed)
        mask[-1, :] = float("0")
        maskedEmbed: Float[Tensor, "n_vocab d_model"] = self.embed *  mask
        emb: Float[Tensor, "*batch_size n_tokens d_model"] = self.posEmbed(
            input @ maskedEmbed
        )
        out: Float[Tensor, "*batch_size n_tokens d_model"] = emb
        for hl in self.hiddenLayers:
            out = hl(out)
        logits: Float[Tensor, "*batch_size n_tokens n_vocab"] = out @ self.unembed
        return logits
    
    def inference(
        self, lst: list[int]
    ):
        tsr = torch.tensor([lst[i] if i < len(lst) else 15 for i in range(8)])
        inp = F.one_hot(
            tsr, num_classes=self.n_vocab
        ).to(dtype=torch.float)
        return self.forward(inp).argmax(-1)

In [15]:
set_seed()
# Transformer().forward(torch.zeros((8, 16)))
# Transformer().forward(torch.rand((2, 8, 16)))
Transformer().inference([1, 3, 2, 2])

tensor([ 9,  0,  7,  9,  9,  9,  9, 11])

In [16]:
# [(name, t.shape) for name, t in Transformer().named_parameters()]

In [30]:
from torch.optim import Adam
import wandb

class TransformerTrainer():
    
    def __init__(
        self, model, optim_choice=Adam, epochs=1, lr=1e-7,
        batch_size=4, debug=False
    ):
        self.model = model
        self.optim_choice = optim_choice
        self.epochs = epochs
        self.lr=lr
        self.batch_size = batch_size
        self.debug = debug
    
    def getLoss(
        self, batch: Float[Tensor, "*batch_size n_tokens"]
    ):
        inp = F.one_hot(
            batch, num_classes=self.model.n_vocab,
        ).to(dtype=torch.float)
        logits: Float[Tensor, "*batch_size n_tokens n_vocab"] = self.model(inp)
        log_probs = logits.log_softmax(-1)
        correct_log_probs = log_probs[..., :-1,:].gather(
                dim=-1, index=batch[..., 1:].unsqueeze(-1)
        ).squeeze(-1)
        loss = log_probs.neg()
        
        return loss
        
    
    def train(
        self, batches: list[Float[Tensor, "n_tokens"]]
    ):
        # wandb.init()
        batches = [
            torch.stack(batches[i:i+self.batch_size], dim=0)
            for i in range(0, len(batches), self.batch_size)
        ]
        optimizer = self.optim_choice(self.model.parameters(), lr=self.lr) 
        for i, batch in enumerate(batches):
            optimizer.zero_grad()
            loss = self.getLoss(batch)
            loss.backward()
            optimizer.step()
            if i % 100 == 0: print(f"{loss=}")
            # wandb.log({
            #     "loss": loss
            # })
        

In [31]:
TransformerTrainer(Transformer()).getLoss(torch.randint(0, 16, (8, )))

16
inp.shape=torch.Size([8, 16])
tensor([15,  9,  8, 10,  6, 15,  6])
tensor([[ -3.0457,  -3.8715,  -3.3720,  -1.8847,  -5.2333,  -3.0917,  -2.9264,
          -3.0351,  -3.3404,  -2.1282,  -2.9633,  -2.8837,  -2.1623,  -2.3566,
          -2.2443,  -4.1902],
        [ -4.0541,  -4.2653,  -4.5126,  -1.3380,  -6.9582,  -4.0187,  -2.7762,
          -3.3060,  -3.5709,  -2.0323,  -3.8455,  -3.1923,  -2.4218,  -1.9716,
          -2.1190,  -4.8727],
        [ -3.9446,  -4.3051,  -2.1538,  -3.4117,  -4.8789,  -1.6333,  -3.1526,
          -4.5522,  -3.9608,  -3.4527,  -2.1557,  -3.0251,  -1.4247,  -3.1890,
          -2.9318,  -4.4430],
        [ -4.1827,  -6.4136,  -3.0013, -11.6580,  -0.1088,  -3.8150,  -9.4205,
          -7.2648,  -7.6408,  -9.2219,  -4.7358,  -7.2731,  -7.6362, -10.7488,
          -9.4079,  -5.7923],
        [ -3.9642,  -6.6282,  -3.3176, -11.6749,  -0.0875,  -4.3227,  -9.6471,
          -7.2775,  -7.6907,  -9.3055,  -4.8339,  -7.4193,  -7.9271, -10.8712,
          -9.5650,  

In [582]:
debugModel = Transformer(num_layers=5, d_model=8, n_vocab=16, n_head=4)
# debugModel = Transformer()
for _ in range(1):
    display(debugModel.inference([0,1,2,3,4]))
    pass
debugBatch = [torch.tensor([(x + i) % 15 for i in range(8)], dtype=torch.int64) for x in range(9990)]
# display(debugBatch[:10])
trainer = TransformerTrainer(debugModel, epochs=1, lr=1e-3)
trainer.train(debugBatch)
print("after\n")
for _ in range(1):
    display(debugModel.inference([0,1,2,3,4]))
    display(debugModel.inference([3,4,5,6,7]))
# debugBatch
# debugModel.parameters()

tensor([ 2,  0,  3, 12, 13,  3,  2,  2])

loss=tensor(6.8083, grad_fn=<MeanBackward0>)
loss=tensor(2.7811, grad_fn=<MeanBackward0>)
loss=tensor(2.7604, grad_fn=<MeanBackward0>)
loss=tensor(2.4596, grad_fn=<MeanBackward0>)
loss=tensor(2.4979, grad_fn=<MeanBackward0>)
loss=tensor(2.4074, grad_fn=<MeanBackward0>)
loss=tensor(2.1470, grad_fn=<MeanBackward0>)
loss=tensor(2.3901, grad_fn=<MeanBackward0>)
loss=tensor(2.2935, grad_fn=<MeanBackward0>)
loss=tensor(2.0583, grad_fn=<MeanBackward0>)
loss=tensor(2.1675, grad_fn=<MeanBackward0>)
loss=tensor(2.1607, grad_fn=<MeanBackward0>)
loss=tensor(1.9552, grad_fn=<MeanBackward0>)
loss=tensor(1.9228, grad_fn=<MeanBackward0>)
loss=tensor(1.9755, grad_fn=<MeanBackward0>)
loss=tensor(1.9383, grad_fn=<MeanBackward0>)
loss=tensor(1.9066, grad_fn=<MeanBackward0>)
loss=tensor(1.9366, grad_fn=<MeanBackward0>)
loss=tensor(1.9347, grad_fn=<MeanBackward0>)
loss=tensor(1.9064, grad_fn=<MeanBackward0>)
loss=tensor(1.9182, grad_fn=<MeanBackward0>)
loss=tensor(1.9307, grad_fn=<MeanBackward0>)
loss=tenso

tensor([ 1,  2,  3,  4,  5,  4,  4, 13])

tensor([ 4,  5,  6,  7,  8,  6,  6, 14])

In [None]:
# for name, t in debugModel.named_parameters():
#     display(f"{name}, {t}")