In [1]:
from dataclasses import dataclass
import torch
import torch.nn as nn
import numpy as np
import requests
import re

@dataclass()
class Config:
    """Hyper-parameters shared by every module of the language model.

    Fields are positional in the order listed, so ``Config(512, 7030, ...)``
    and keyword construction are both supported.
    """

    d_model: int      # embedding / residual-stream width (Embedding, LayerNorm, Attention)
    d_vocab: int      # number of token ids (Embedding rows, lm_head outputs)
    d_hidden: int     # inner width of the MLP (fc1 output / fc2 input)
    max_seq_len: int  # side length of the causal-mask buffer built in Attention
    numTrans: int     # number of Transformer blocks stacked in LanguageModel

In [None]:
class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear.

    Acts on the last dimension only (d_model -> d_hidden -> d_model), so any
    leading dimensions pass through unchanged.
    """

    def __init__(self, config: Config):
        super().__init__()
        # Attribute names double as state_dict keys; creation order fixes the
        # order in which the two Linear layers draw their random init.
        self.fc1 = nn.Linear(config.d_model, config.d_hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(config.d_hidden, config.d_model)

    def forward(self, x):
        hidden = self.fc1(x)
        hidden = self.act(hidden)
        return self.fc2(hidden)

class Attention(nn.Module):
    """Single-head causal self-attention with merged QK and OV matrices.

    Rather than separate query/key/value/output projections, the layer keeps
    two ``d_model x d_model`` parameters:

    * ``Wqk`` -- bilinear form giving the attention logit ``x_i @ Wqk @ x_j``;
    * ``Wov`` -- maps the attention-weighted mix of inputs to the output.

    ``forward`` accepts ``x`` of shape ``(T, d_model)`` or a batched
    ``(..., T, d_model)`` tensor with ``T <= max_seq_len`` and returns a tensor
    of the same shape. Position ``i`` only attends to positions ``j <= i``.

    Raises:
        ValueError: if ``T`` exceeds the ``max_seq_len`` the mask was built for.
    """

    def __init__(self, config: "Config"):
        super().__init__()
        d_model = config.d_model
        # Zero-mean init with std 1/sqrt(d_model). The previous torch.rand init
        # (uniform on [0, 1), every entry positive) produced attention logits far
        # too large in magnitude, saturating the softmax, and amplified the
        # residual stream at every layer.
        init_std = d_model ** -0.5
        self.Wqk = nn.Parameter(torch.randn(d_model, d_model) * init_std)
        self.Wov = nn.Parameter(torch.randn(d_model, d_model) * init_std)
        # Standard 1/sqrt(d) temperature keeps the logits O(1) at initialisation.
        self.scale = d_model ** -0.5

        # Additive causal mask: 0 on/below the diagonal, -inf above it, so the
        # softmax assigns zero weight to future positions. Same buffer name and
        # values as before, so existing state_dicts still load.
        mask = torch.triu(torch.ones(config.max_seq_len, config.max_seq_len), diagonal=1)
        mask = mask.masked_fill(mask == 1, float("-inf"))
        self.register_buffer("M", mask)

    def forward(self, x):
        T = x.size(-2)
        max_len = self.M.size(0)
        if T > max_len:
            raise ValueError(f"sequence length {T} exceeds max_seq_len={max_len}")
        # (..., T, d) @ (d, d) @ (..., d, T) -> (..., T, T) logits.
        # transpose(-2, -1), unlike .T, is also correct for batched input.
        logits = (x @ self.Wqk @ x.transpose(-2, -1)) * self.scale + self.M[:T, :T]
        # Normalise over keys (last dim). The diagonal is never masked, so every
        # row keeps a finite entry and the softmax cannot produce NaNs.
        weights = torch.softmax(logits, dim=-1)
        return weights @ x @ self.Wov

class Transformer(nn.Module):
    """Pre-LayerNorm residual block.

    Computes ``h = x + Attn(LN1(x))`` followed by ``h + MLP(LN2(h))``; the
    output has the same shape as the input.
    """

    def __init__(self, config: Config):
        super().__init__()
        # Construction order is preserved: it fixes both the RNG draws for the
        # initial weights and the order of model.parameters().
        self.attn = Attention(config)
        self.mlp = MLP(config)
        self.ln1 = nn.LayerNorm(config.d_model)
        self.ln2 = nn.LayerNorm(config.d_model)

    def forward(self, x):
        after_attn = x + self.attn(self.ln1(x))
        return after_attn + self.mlp(self.ln2(after_attn))
    
class LanguageModel(nn.Module):
    """Decoder-only language model: token embedding -> Transformer stack -> vocab logits.

    NOTE(review): no positional embedding is added; position information only
    enters through the causal mask inside each attention layer. There is also
    no final LayerNorm before ``lm_head``. Both are common in GPT-style models
    but would change the parameter set, so they are left as TODOs.
    """

    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.d_vocab, config.d_model)
        self.tbs = nn.ModuleList([Transformer(config) for _ in range(config.numTrans)])
        self.lm_head = nn.Linear(config.d_model, config.d_vocab)

    def forward(self, x_tokens):
        """Map token ids to next-token logits.

        Args:
            x_tokens: integer tensor of token ids. The training cell passes a
                1-D tensor of length ``max_seq_len``.

        Returns:
            Logits whose last dimension is ``d_vocab``, e.g. ``(T, d_vocab)``
            for a 1-D input of length ``T``.
        """
        hidden = self.embedding(x_tokens)
        # Iterate the registered blocks themselves rather than
        # range(self.config.numTrans): config is a shared mutable object, and
        # changing numTrans after construction must not skip or over-index blocks.
        for block in self.tbs:
            hidden = block(hidden)
        return self.lm_head(hidden)

In [5]:
# Pride and Prejudice (Project Gutenberg eBook #1342) -- just a demo dataset;
# more creative datasets can be swapped in later.
url = "https://www.gutenberg.org/files/1342/1342-0.txt"
r = requests.get(url, timeout=30)  # never hang the notebook on a stalled connection
r.raise_for_status()  # fail loudly instead of tokenizing an HTML error page
# Gutenberg "-0.txt" files are UTF-8. If the server omits a charset, requests
# falls back to ISO-8859-1 for text/* responses, which would mojibake curly
# quotes into word characters (e.g. "â") and pollute the vocabulary.
r.encoding = "utf-8"
text = r.text

In [9]:
# Word-level tokenization: lowercase, then keep maximal runs of word characters.
# `text` is left untouched so this cell is idempotent on re-run.
# NOTE(review): \w also matches "_", and Gutenberg plain text commonly marks
# italics as _word_, so "_pride_" and "pride" may end up as separate entries
# (consider r"[^\W_]+"). The Gutenberg licence header/footer is presumably
# included in the corpus as well -- TODO confirm and strip if so.
tokens = re.findall(r"\b\w+\b", text.lower())

# A sorted vocabulary gives a deterministic token <-> id mapping across runs.
vocab = sorted(set(tokens))
token2id = {token: idx for idx, token in enumerate(vocab)}
id2token = dict(enumerate(vocab))

assert len(vocab) == len(token2id) == len(id2token), "vocab contains duplicates"
assert all(id2token[token2id[tok]] == tok for tok in vocab), "mappings are not inverses"

print(len(vocab))

7030


In [15]:
config = Config(
    d_model=512,
    d_vocab=len(vocab),  # one output logit per distinct word
    d_hidden=256,        # NOTE(review): narrower than d_model; 4 * d_model is the usual choice
    max_seq_len=2048,
    numTrans=10,
)

# Encode the whole corpus as one flat list of integer ids.
token_ids = list(map(token2id.__getitem__, tokens))

print(len(token_ids))

128769


In [None]:
# Training hyper-parameters (previously inline magic numbers).
NUM_STEPS = 1000
LOG_EVERY = 50
LEARNING_RATE = 1e-3
SEED = 0

torch.manual_seed(SEED)            # reproducible parameter initialisation
rng = np.random.default_rng(SEED)  # reproducible window sampling

model = LanguageModel(config)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.CrossEntropyLoss()

seq_len = config.max_seq_len
# Each sample needs seq_len inputs plus one extra token for the shifted targets.
# rng.integers excludes its upper bound, so the largest start is
# len(token_ids) - seq_len - 1 and the final corpus token can still be a
# target. (The previous bound, len - seq_len - 1, never used the last token.)
max_start = len(token_ids) - seq_len
if max_start < 1:
    raise ValueError(
        f"corpus has {len(token_ids)} tokens; need at least max_seq_len + 1 = {seq_len + 1}"
    )
token_tensor = torch.tensor(token_ids, dtype=torch.long)  # convert once, slice per step

model.train()
for step in range(NUM_STEPS):
    start = int(rng.integers(0, max_start))
    window = token_tensor[start:start + seq_len + 1]
    x_ids, y_ids = window[:-1], window[1:]  # targets are inputs shifted by one
    logits = model(x_ids)                   # (seq_len, d_vocab)
    loss = loss_fn(logits, y_ids)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % LOG_EVERY == 0:
        # .item() must be *called*; the old f-string printed the bound method.
        print(f"Step {step}: Loss is {loss.item():.4f}")


torch.Size([2048, 512])


IndexError: Target 6287 is out of bounds.

In [None]:
max_num_tokens = 50
prompt_text