In [1]:
import torch
from torch import tensor, nn, optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR

In [2]:
with open("data/input.txt") as f:
    text = f.read()

In [3]:
from transformers import AutoTokenizer

model_name = "openai-community/gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_name)

In [4]:
from datasets import load_dataset

ds = load_dataset("myothiha/jokes")

In [5]:
len(set(ds['train']['text']))

187641

In [6]:
text = '<|endoftext|>'.join(ds['train']['text'])

In [7]:
vocab_size = tokenizer.vocab_size
context_length = 32
n_embs = 128
n_heads = 16
n_blocks = 8
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [8]:
len(text)

20160500

In [9]:
class TextDataLoader():
    def __init__(self, text, context_length, tokenizer, batch_size=1, device='cpu'):
        v = tokenizer(text, return_tensors='pt')
        self.tokens = v.input_ids # The attention mask will be handled manually later
        self.batch_size = batch_size
        self.context_length = context_length
        self.device = device
        
        self.position = 0
        
    def __iter__(self):
        self.reset()
        return self
        
    def __next__(self):
        B, T = self.batch_size, self.context_length
        if self.position + B * T + 1 < len(self.tokens[0]):
            tokens = self.tokens[0][self.position: self.position + B * T + 1]
            self.position += B * T + 1
            x = tokens[:-1].view(B, T)
            y = tokens[1:].view(B, T)
            return x.to(self.device), y.to(self.device)
        else:
            raise StopIteration
            
    def __len__(self):
        return len(self.tokens[0]) // (self.context_length + 1) // self.batch_size
    
    def reset(self):
        self.position = 0
        

The `head_size` matches that of the embeddings if it is a single head.

If multi-headed attention is used, then the `head_size` would equal number of embeddings // number of heads.

In [10]:
class SelfAttentionHead(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.k = nn.Linear(n_embs, head_size, bias=False)
        self.q = nn.Linear(n_embs, head_size, bias=False)
        self.v = nn.Linear(n_embs, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(context_length, context_length)))
        
    def forward(self, x):
        B, T, C = x.shape
        k = self.k(x) # (B, T, head_size)
        q = self.q(x) # (B, T, head_size)
        v = self.v(x) # (B, T, head_size)
        
#         attn = q @ k.transpose(-2, -1) * C ** -0.5 # (B, T, T)
        
#         wei = attn.masked_fill(self.tril[:T, :T] == 0, float("-inf")) # Causal masking, blocks future tokens from being seen
#         wei = wei.softmax(dim=-1) # (B, T, T)
        
#         out = wei @ v # (B, T, head_size)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return out

In [11]:
class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads):
        super().__init__()
        self.heads = nn.ModuleList([SelfAttentionHead(n_embs // n_heads) for i in range(n_heads)])
        self.proj = nn.Linear(n_embs, n_embs)
        
    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj(out)
        return out

In [12]:
class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embs, n_embs * 4),
            nn.GELU(),
            nn.Linear(n_embs * 4, n_embs)
        )
        
    def forward(self, x):
        out = self.net(x)
        return out

In [13]:
class TransformerBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embs)
        self.mha = MultiHeadAttention(n_heads)
        self.ln2 = nn.LayerNorm(n_embs)
        self.ffwd = FeedForward()
        
    def forward(self, x):
        x = self.mha(self.ln1(x)) + x
        x = self.ffwd(self.ln2(x)) + x
        
        return x
        

In [14]:
class GPT(nn.Module):
    def __init__(self):
        super().__init__()
        self.tk_emb = nn.Embedding(vocab_size, n_embs)
        self.pos_emb = nn.Embedding(context_length, n_embs)
        self.blocks = nn.Sequential(*[TransformerBlock() for i in range(n_blocks)])
        self.ln_f = nn.LayerNorm(n_embs)
        self.fc = nn.Linear(n_embs, vocab_size)
        
        self.tk_emb.weight = self.fc.weight
        
    def forward(self, x, targets=None):
        B, T = x.shape
        tk_emb = self.tk_emb(x) # (B, T, C)
        pos_tns = torch.arange(T, device=device) # T
        pos_emb = self.pos_emb(pos_tns) # (T, C)
        x = pos_emb + tk_emb # (B, T, C) + (T, C)
        
        x = self.blocks(x)
        
        x = self.ln_f(x)
        logits = self.fc(x) # (B, T, vocab_size)
        if targets is None:
            return logits
        else:
            loss = F.cross_entropy(logits.view(B * T, -1), targets.view(B*T))
            return logits, loss
        
    def generate(self):
        pass

In [15]:
def initialize(mod):
    if isinstance(mod, nn.Linear):
        torch.nn.init.kaiming_normal_(mod.weight)
        if mod.bias is not None:
            torch.nn.init.zeros_(mod.bias)

In [16]:
torch.set_float32_matmul_precision('high')


In [27]:
dl = TextDataLoader(text, context_length, tokenizer, batch_size=32, device=device)

In [28]:
len(dl), len(dl.tokens[0])

(4425, 4673544)

In [29]:
model = GPT().to(device)
model.apply(initialize)
lr = 7e-4
opt = optim.AdamW(model.parameters(), lr)

sum(p.numel() for p in model.parameters())

8070609

In [24]:
xb, yb = next(iter(dl))
model(x=xb, targets=yb)

(tensor([[[ 2.8622, -0.5032, -0.6921,  ..., -0.0037,  2.5393, -0.5748],
          [ 1.6149, -0.0974,  0.0567,  ...,  0.3497,  2.0110, -1.3228],
          [ 1.3989,  0.4020, -0.7424,  ...,  0.2187,  2.0540, -0.8872],
          ...,
          [ 0.7197,  1.0320, -0.9566,  ...,  0.9805,  3.1071,  0.2059],
          [ 2.3216,  0.5853,  0.4516,  ...,  0.8154,  2.2397,  0.2956],
          [ 2.8537,  0.1806, -0.0131,  ...,  0.7891,  2.1111, -0.5451]]],
        device='cuda:0', grad_fn=<ViewBackward0>),
 tensor(11.9861, device='cuda:0', grad_fn=<NllLossBackward0>))

In [30]:
model

GPT(
  (tk_emb): Embedding(50257, 128)
  (pos_emb): Embedding(32, 128)
  (blocks): Sequential(
    (0): TransformerBlock(
      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (mha): MultiHeadAttention(
        (heads): ModuleList(
          (0-15): 16 x SelfAttentionHead(
            (k): Linear(in_features=128, out_features=8, bias=False)
            (q): Linear(in_features=128, out_features=8, bias=False)
            (v): Linear(in_features=128, out_features=8, bias=False)
          )
        )
        (proj): Linear(in_features=128, out_features=128, bias=True)
      )
      (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (ffwd): FeedForward(
        (net): Sequential(
          (0): Linear(in_features=128, out_features=512, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=512, out_features=128, bias=True)
        )
      )
    )
    (1): TransformerBlock(
      (ln1): LayerNorm((128,), eps=1e-05, elementwise

In [None]:
# Training loop for pre-training
epochs = 1
sched = CosineAnnealingLR(opt, epochs * len(dl), lr * 0.5)
for i in range(epochs):
    for step, (xb, yb) in enumerate(dl):
        opt.zero_grad()
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            logits, loss = model(xb, yb)
        loss.backward()
        opt.step()
        sched.step()
        if step % (len(dl)//10) == 0 or step == len(dl):
            print(f"Epoch: {i}, Step: {step}, Loss: {loss}")

Epoch: 0, Step: 0, Loss: 11.863006591796875
Epoch: 0, Step: 442, Loss: 6.758306980133057
Epoch: 0, Step: 884, Loss: 5.988088607788086
Epoch: 0, Step: 1326, Loss: 5.640953540802002
Epoch: 0, Step: 1768, Loss: 5.642059326171875
Epoch: 0, Step: 2210, Loss: 5.361952304840088
Epoch: 0, Step: 2652, Loss: 5.21062707901001


In [None]:
def generate(idx, max_tokens):
    model.eval()
    tokens = idx
    for i in range(max_tokens):
        logits = model(tokens[:, -context_length:])
        topk_values, topk_indices = torch.topk(logits[:, -1, :], 50)
        probs = topk_values.softmax(dim=-1)
        sample = torch.multinomial(probs, 1)
        token = torch.gather(topk_indices, 1, sample)
        tokens = torch.cat((tokens, token), dim=-1)
    
    return tokens

In [None]:
start_tokens = tokenizer.encode("What do you call", return_tensors='pt').to(device)
print(tokenizer.decode(generate(start_tokens, 2000)[0].tolist()))

In [None]:
# torch.save(model.state_dict(), "models/jokes_transformer")