In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F
%matplotlib inline

In [3]:
# Hyperparameters for the character-level transformer trained below.
batchSize = 32  # sequences per training batch
contextSize = 8  # maximum number of tokens the model sees at once
trainingSteps = 3000
evaluationSteps = 200  # batches averaged per split when estimating loss
evaluationInterval = 500  # report losses every this many training steps
evalautionInterval = evaluationInterval  # misspelled alias kept for cells that still use it
learningRate = 1e-3
device = "cuda" if torch.cuda.is_available() else "cpu"
embeddingSize = 32
nHeads = 4  # each Block uses headSize = embeddingSize // nHeads
nBlocks = 6
dropoutRate = 0.3

# Seed torch's RNG so weight init, batch sampling and text generation are reproducible
# under Restart & Run All.
seed = 1337
torch.manual_seed(seed)

In [4]:
# Read the corpus and build a character-level tokenizer: every distinct character
# becomes one token id, assigned in sorted character order.
with open("input.txt", "r", encoding="utf-8") as file:
    text = file.read()

vocabulary = sorted(set(text))
vocabularySize = len(vocabulary)
stoi = {ch: idx for idx, ch in enumerate(vocabulary)}  # character -> token id
itos = dict(enumerate(vocabulary))  # token id -> character


def encode(x):
    """Map a string to the list of its token ids."""
    return [stoi[ch] for ch in x]


def decode(x):
    """Map a sequence of token ids back to a string."""
    return ''.join(itos[idx] for idx in x)


# Encode the whole corpus once; the first 90% is for training, the rest for validation.
data = torch.tensor(encode(text), dtype=torch.long)
split = int(len(data) * 0.9)
trainData, valData = data[:split], data[split:]

In [5]:
def getBatch(data):
    """Sample a random batch of (input, target) windows from a 1-D token tensor.

    Each of the ``batchSize`` rows of ``X`` is a slice of ``contextSize`` consecutive
    tokens from ``data``; the matching row of ``Y`` is the same slice shifted one token
    to the right, i.e. the next-token targets.

    Args:
        data: 1-D tensor of token ids (e.g. ``trainData`` or ``valData``).

    Returns:
        ``(X, Y)``, both of shape ``(batchSize, contextSize)``, moved to ``device``.

    Raises:
        ValueError: if ``data`` is too short to hold a single input/target window.
    """
    # A start index i is valid while i + 1 + contextSize <= len(data), giving
    # len(data) - contextSize valid starts. randint's upper bound is exclusive, so it
    # is passed exactly that count (the old "- 1" silently dropped the last window).
    nStarts = len(data) - contextSize
    if nStarts < 1:
        raise ValueError(
            f"need at least {contextSize + 1} tokens to build a batch, got {len(data)}"
        )
    indexes = torch.randint(nStarts, (batchSize,))
    X = torch.stack([data[i: i + contextSize] for i in indexes])
    Y = torch.stack([data[i + 1: i + 1 + contextSize] for i in indexes])
    return X.to(device), Y.to(device)

In [6]:
@torch.no_grad()
def estimateLoss(model, evaluationSteps):
    """Estimate the mean loss of ``model`` on the training and validation splits.

    Averages the loss over ``evaluationSteps`` random batches drawn with ``getBatch``
    from ``trainData`` and then from ``valData``. Runs without gradient tracking and
    with the model in eval mode (dropout disabled). The model's previous train/eval
    mode is restored afterwards, even if an exception is raised.

    Args:
        model: module whose call ``model(X, Y)`` returns ``(logits, loss)``.
        evaluationSteps: number of batches to average per split; must be >= 1.

    Returns:
        ``(trainLoss, valLoss)`` as 0-dim tensors.

    Raises:
        ValueError: if ``evaluationSteps`` is less than 1 (the mean would be NaN).
    """
    if evaluationSteps < 1:
        raise ValueError(f"evaluationSteps must be >= 1, got {evaluationSteps}")
    wasTraining = model.training
    model.eval()
    try:
        losses = []
        for data in (trainData, valData):
            splitLosses = torch.zeros(evaluationSteps)
            for i in range(evaluationSteps):
                X, Y = getBatch(data)
                _, loss = model(X, Y)
                splitLosses[i] = loss.item()
            losses.append(splitLosses.mean())
    finally:
        # Restore the caller's mode rather than unconditionally switching to train().
        model.train(wasTraining)
    return losses[0], losses[1]

In [7]:
class Head(nn.Module):
    """One head of causal (masked) scaled dot-product self-attention.

    Projects the input to keys, queries and values of width ``headSize`` and lets each
    position attend only to itself and earlier positions. Reads the notebook-level
    ``embeddingSize``, ``contextSize`` and ``dropoutRate``.
    """

    def __init__(self, headSize):
        super().__init__()
        # Bias-free projections; creation order (key, query, value) fixes init order.
        self.key = nn.Linear(embeddingSize, headSize, bias=False)
        self.query = nn.Linear(embeddingSize, headSize, bias=False)
        self.value = nn.Linear(embeddingSize, headSize, bias=False)
        # Lower-triangular ones: entry (i, j) is 1 iff j <= i, i.e. j is not in the future.
        causal = torch.tril(torch.ones(contextSize, contextSize))
        self.register_buffer('mask', causal)
        self.dropout = nn.Dropout(dropoutRate)

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, embeddingSize); returns (B, T, headSize)."""
        _, seqLen, _ = x.shape
        keys = self.key(x)
        queries = self.query(x)
        scale = keys.shape[-1] ** -0.5
        scores = queries @ keys.transpose(-2, -1) * scale  # (B, T, T)
        future = self.mask[:seqLen, :seqLen] == 0
        weights = F.softmax(scores.masked_fill(future, float('-inf')), dim=-1)
        weights = self.dropout(weights)
        return weights @ self.value(x)
    
class MultiHeadAttention(nn.Module):
    """Runs ``nHeads`` causal attention heads in parallel and mixes their outputs.

    Head outputs are concatenated along the last dimension (width ``headSize * nHeads``),
    projected back to ``embeddingSize`` and passed through dropout.
    """

    def __init__(self, headSize):
        super().__init__()
        heads = [Head(headSize) for _ in range(nHeads)]
        self.heads = nn.ModuleList(heads)
        self.projection = nn.Linear(headSize * nHeads, embeddingSize)
        self.dropout = nn.Dropout(dropoutRate)

    def forward(self, x):
        headOutputs = [head(x) for head in self.heads]
        mixed = self.projection(torch.cat(headOutputs, dim=-1))
        return self.dropout(mixed)
    
class FeedForward(nn.Module):
    """Position-wise MLP: expand to ``4 * embeddingSize``, ReLU, project back, dropout.

    Reads the notebook-level ``dropoutRate``.
    """

    def __init__(self, embeddingSize):
        super().__init__()
        hidden = 4 * embeddingSize
        layers = [
            nn.Linear(embeddingSize, hidden),
            nn.ReLU(),
            nn.Linear(hidden, embeddingSize),
            nn.Dropout(dropoutRate),
        ]
        # A single Sequential keeps parameter names as net.0.* / net.2.*.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP along the last dimension of ``x``."""
        return self.net(x)
    
class Block(nn.Module):
    """Pre-LayerNorm transformer block: self-attention then feed-forward, each residual.

    Reads the notebook-level ``embeddingSize`` and ``nHeads``.
    """

    def __init__(self):
        super().__init__()
        # Per-head width; nHeads heads of this width are concatenated in the attention.
        headSize = embeddingSize // nHeads
        self.selfAttention = MultiHeadAttention(headSize)
        self.feedforward = FeedForward(embeddingSize)
        self.layernorm1 = nn.LayerNorm(embeddingSize)
        self.layernorm2 = nn.LayerNorm(embeddingSize)

    def forward(self, x):
        attended = x + self.selfAttention(self.layernorm1(x))
        return attended + self.feedforward(self.layernorm2(attended))
    
class SmallLanguageModel(nn.Module):
    """Character-level decoder-only transformer language model.

    Token and learned position embeddings are summed, passed through ``nBlocks``
    ``Block`` layers and a final LayerNorm, then projected to logits over the
    vocabulary. Reads the notebook-level ``vocabularySize``, ``embeddingSize``,
    ``contextSize`` and ``nBlocks``.
    """

    def __init__(self):
        super().__init__()
        self.tokenEmbeddingTable = nn.Embedding(vocabularySize, embeddingSize)
        self.positionEmbeddingTable = nn.Embedding(contextSize, embeddingSize)
        self.blocks = nn.Sequential(*[Block() for _ in range(nBlocks)])
        self.finalLayerNorm = nn.LayerNorm(embeddingSize)
        self.languageModellingHead = nn.Linear(embeddingSize, vocabularySize)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, 0.02) and zero any Linear bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, data, targets=None):
        """Compute next-token logits and, when ``targets`` is given, the loss.

        Args:
            data: (B, T) tensor of token ids; T must not exceed ``contextSize``
                (the position table has only ``contextSize`` rows).
            targets: optional (B, T) tensor of next-token ids.

        Returns:
            ``(logits, loss)``. Without targets, logits are (B, T, vocabularySize) and
            loss is None. With targets, logits are flattened to (B*T, vocabularySize)
            and loss is the mean cross-entropy.
        """
        B, T = data.shape

        tokenEmbeddings = self.tokenEmbeddingTable(data)  # (B, T, C)
        # Build position ids on the input's device (not the global `device`) so the
        # model still works after being moved to another device.
        positions = torch.arange(T, device=data.device)
        positionEmbeddings = self.positionEmbeddingTable(positions)  # (T, C), broadcast over B
        x = tokenEmbeddings + positionEmbeddings  # (B, T, C)
        x = self.blocks(x)  # (B, T, C)
        x = self.finalLayerNorm(x)  # (B, T, C)
        logits = self.languageModellingHead(x)  # (B, T, vocabularySize)

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        # reshape (unlike view) also accepts non-contiguous targets.
        logits = logits.reshape(B * T, C)
        loss = F.cross_entropy(logits, targets.reshape(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, sequence, nNewTokens):
        """Autoregressively sample ``nNewTokens`` tokens and append them to ``sequence``.

        Sampling runs without autograd and in eval mode, so dropout is disabled. The
        module's previous train/eval mode is restored afterwards.

        Args:
            sequence: (B, T) tensor of token ids used as the prompt.
            nNewTokens: number of tokens to sample.

        Returns:
            (B, T + nNewTokens) tensor: the prompt followed by the sampled tokens.
        """
        wasTraining = self.training
        self.eval()
        try:
            for _ in range(nNewTokens):
                # The position table covers only contextSize positions, so feed at
                # most the last contextSize tokens.
                croppedContext = sequence[:, -contextSize:]
                logits, _ = self(croppedContext)
                logits = logits[:, -1, :]  # (B, vocabularySize): scores for the next token
                probs = F.softmax(logits, dim=-1)
                nextToken = torch.multinomial(probs, num_samples=1)  # (B, 1)
                sequence = torch.cat((sequence, nextToken), dim=1)  # (B, T+1)
        finally:
            self.train(wasTraining)
        return sequence

In [8]:
# Instantiate the model on the configured device, report its size, and set up AdamW.
model = SmallLanguageModel().to(device)
parameterCount = sum(parameter.numel() for parameter in model.parameters())
print(f"Number of parameters: {parameterCount/1e3}k")
optimizer = torch.optim.AdamW(model.parameters(), lr=learningRate)

Number of parameters: 80.193k


In [9]:
# Main optimisation loop: AdamW on random training batches, with averaged
# train/val loss reported every evaluation interval and on the final step.
for i in range(trainingSteps):
    shouldEvaluate = i % evalautionInterval == 0 or i == trainingSteps - 1
    if shouldEvaluate:
        trainLoss, valLoss = estimateLoss(model, evaluationSteps)
        print(f"step {i}: train loss {trainLoss:.4f}, val loss {valLoss:.4f}")

    Xb, Yb = getBatch(trainData)
    logits, loss = model(Xb, Yb)

    # Clear stale gradients, backpropagate, then apply the update.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


step 0: train loss 4.1745, val loss 4.1740
step 500: train loss 2.4456, val loss 2.4523
step 1000: train loss 2.3466, val loss 2.3519
step 1500: train loss 2.2758, val loss 2.2815
step 2000: train loss 2.2318, val loss 2.2459
step 2500: train loss 2.1997, val loss 2.2103
step 2999: train loss 2.1575, val loss 2.1836


In [10]:
# Sample 500 tokens starting from token id 0. Switch to eval mode so dropout is off
# while sampling and skip autograd bookkeeping. Then restore the previous mode so a
# later re-run of the training cell still trains with dropout.
wasTraining = model.training
model.eval()
with torch.no_grad():
    context = torch.zeros((1, 1), dtype=torch.long, device=device)
    generatedIds = model.generate(context, nNewTokens=500)[0].tolist()
model.train(wasTraining)
print(decode(generatedIds))


ARICHESDDARDIMIZESDOBY INGUNI
HASANd earde af des.
Whotce sod nothe,
And nave haud mod lave pored, a to dir the day deachtoes; will youliay, Mve,
A by utle ofy
Word, wplows mof por?
QOUNGA blucies your; neplartty?
Thenme you danst Hland and Sef ofor taurphanweng.
in no wit, and;

Youks duatoun is of my hat ank stay's for:
Ne to o mching ond sulce.

CRUMWe wardes, venses! temoer, thow, tis bemiisd In no god kny pleafonTyevaod'?

lOy nof your beOhthan wall you
Prastatioces
Caiusven, hiBle;
And, me
