In [1]:
import torch
from torch import tensor, nn, optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR

In [3]:
from transformers import AutoTokenizer

model_name = "openai-community/gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_name)

In [4]:
from datasets import load_dataset

ds = load_dataset("myothiha/jokes")

In [5]:
len(set(ds['train']['text']))

187641

In [6]:
text = '<|endoftext|>'.join(ds['train']['text'])

In [7]:
vocab_size = tokenizer.vocab_size
context_length = 32
n_embs = 128
n_heads = 16
n_blocks = 8
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [8]:
len(text)

20160500

In [9]:
class TextDataLoader():
    def __init__(self, text, context_length, tokenizer, batch_size=1, device='cpu'):
        v = tokenizer(text, return_tensors='pt')
        self.tokens = v.input_ids # The attention mask will be handled manually later
        self.batch_size = batch_size
        self.context_length = context_length
        self.device = device
        
        self.position = 0
        
    def __iter__(self):
        self.reset()
        return self
        
    def __next__(self):
        B, T = self.batch_size, self.context_length
        if self.position + B * T + 1 < len(self.tokens[0]):
            tokens = self.tokens[0][self.position: self.position + B * T + 1]
            self.position += B * T + 1
            x = tokens[:-1].view(B, T)
            y = tokens[1:].view(B, T)
            return x.to(self.device), y.to(self.device)
        else:
            raise StopIteration
            
    def __len__(self):
        return len(self.tokens[0]) // (self.context_length + 1) // self.batch_size
    
    def reset(self):
        self.position = 0
        

The `head_size` matches that of the embeddings if it is a single head.

If multi-headed attention is used, then the `head_size` would equal number of embeddings // number of heads.

In [10]:
class SelfAttentionHead(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.k = nn.Linear(n_embs, head_size, bias=False)
        self.q = nn.Linear(n_embs, head_size, bias=False)
        self.v = nn.Linear(n_embs, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(context_length, context_length)))
        
    def forward(self, x):
        B, T, C = x.shape
        k = self.k(x) # (B, T, head_size)
        q = self.q(x) # (B, T, head_size)
        v = self.v(x) # (B, T, head_size)
        
#         attn = q @ k.transpose(-2, -1) * C ** -0.5 # (B, T, T)
        
#         wei = attn.masked_fill(self.tril[:T, :T] == 0, float("-inf")) # Causal masking, blocks future tokens from being seen
#         wei = wei.softmax(dim=-1) # (B, T, T)
        
#         out = wei @ v # (B, T, head_size)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return out

In [11]:
class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads):
        super().__init__()
        self.heads = nn.ModuleList([SelfAttentionHead(n_embs // n_heads) for i in range(n_heads)])
        self.proj = nn.Linear(n_embs, n_embs)
        
    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj(out)
        return out

In [12]:
class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embs, n_embs * 4),
            nn.GELU(),
            nn.Linear(n_embs * 4, n_embs)
        )
        
    def forward(self, x):
        out = self.net(x)
        return out

In [13]:
class TransformerBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embs)
        self.mha = MultiHeadAttention(n_heads)
        self.ln2 = nn.LayerNorm(n_embs)
        self.ffwd = FeedForward()
        
    def forward(self, x):
        x = self.mha(self.ln1(x)) + x
        x = self.ffwd(self.ln2(x)) + x
        
        return x
        

In [14]:
class GPT(nn.Module):
    def __init__(self):
        super().__init__()
        self.tk_emb = nn.Embedding(vocab_size, n_embs)
        self.pos_emb = nn.Embedding(context_length, n_embs)
        self.blocks = nn.Sequential(*[TransformerBlock() for i in range(n_blocks)])
        self.ln_f = nn.LayerNorm(n_embs)
        self.fc = nn.Linear(n_embs, vocab_size)
        
        self.tk_emb.weight = self.fc.weight
        
    def forward(self, x, targets=None):
        B, T = x.shape
        tk_emb = self.tk_emb(x) # (B, T, C)
        pos_tns = torch.arange(T, device=device) # T
        pos_emb = self.pos_emb(pos_tns) # (T, C)
        x = pos_emb + tk_emb # (B, T, C) + (T, C)
        
        x = self.blocks(x)
        
        x = self.ln_f(x)
        logits = self.fc(x) # (B, T, vocab_size)
        if targets is None:
            return logits
        else:
            loss = F.cross_entropy(logits.view(B * T, -1), targets.view(B*T))
            return logits, loss
        
    def generate(self):
        pass

In [15]:
def initialize(mod):
    if isinstance(mod, nn.Linear):
        torch.nn.init.kaiming_normal_(mod.weight)
        if mod.bias is not None:
            torch.nn.init.zeros_(mod.bias)

In [16]:
torch.set_float32_matmul_precision('high')


In [27]:
dl = TextDataLoader(text, context_length, tokenizer, batch_size=32, device=device)

In [34]:
len(dl), len(dl.tokens[0])

(4425, 4673544)

In [35]:
model = GPT().to(device)
model.apply(initialize)
lr = 7e-4
opt = optim.AdamW(model.parameters(), lr)

sum(p.numel() for p in model.parameters())

8070609

In [36]:
xb, yb = next(iter(dl))
model(x=xb, targets=yb)

(tensor([[[-0.1014,  0.4470,  3.4385,  ..., -1.4988,  2.6929, -0.9296],
          [ 0.4989,  0.4790,  1.0835,  ..., -2.0187,  3.2798,  0.0661],
          [ 0.2977,  1.6409,  0.7137,  ..., -1.7743,  3.7793, -0.5917],
          ...,
          [ 1.2270,  1.0936,  0.1648,  ..., -1.2649,  2.5656, -1.3216],
          [ 0.6957,  0.4822, -1.5492,  ..., -1.6905,  3.2747, -0.7123],
          [ 1.9623,  0.7844,  0.5086,  ..., -1.9630,  1.8272,  0.3836]],
 
         [[-0.2554,  0.1965,  3.8129,  ..., -1.3376,  2.5233, -1.2478],
          [ 0.7708,  0.3430,  1.2166,  ..., -1.5564,  3.2044, -0.0356],
          [ 0.5501,  1.3278,  0.8737,  ..., -1.6529,  3.5425, -0.5093],
          ...,
          [ 1.9153,  0.8677,  0.9645,  ..., -0.9940,  2.0926, -1.9137],
          [ 0.2208,  0.4363, -0.2117,  ..., -1.3752,  2.6518, -1.7363],
          [ 1.9874,  0.9801,  1.0197,  ..., -1.5811,  1.9959, -0.2106]],
 
         [[-0.2723,  0.3089,  3.5559,  ..., -1.5968,  2.6581, -1.1554],
          [ 0.7796,  0.2458,

In [None]:
# Training loop for pre-training
epochs = 5
sched = CosineAnnealingLR(opt, epochs * len(dl), lr * 0.01)
for i in range(epochs):
    for step, (xb, yb) in enumerate(dl):
        opt.zero_grad()
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            logits, loss = model(xb, yb)
        loss.backward()
        opt.step()
        sched.step()
        if step % (len(dl)//10) == 0 or step == len(dl):
            print(f"Epoch: {i}, Step: {step}, Loss: {loss}")

Epoch: 0, Step: 0, Loss: 11.87548828125
Epoch: 0, Step: 442, Loss: 6.692626476287842
Epoch: 0, Step: 884, Loss: 5.980865001678467
Epoch: 0, Step: 1326, Loss: 5.669127464294434
Epoch: 0, Step: 1768, Loss: 5.592933177947998
Epoch: 0, Step: 2210, Loss: 5.366819381713867
Epoch: 0, Step: 2652, Loss: 5.201066970825195
Epoch: 0, Step: 3094, Loss: 5.23137903213501
Epoch: 0, Step: 3536, Loss: 4.969971179962158
Epoch: 0, Step: 3978, Loss: 5.185560703277588
Epoch: 0, Step: 4420, Loss: 4.611307621002197
Epoch: 0, Step: 4425, Loss: 5.228513717651367
Epoch: 1, Step: 0, Loss: 5.138219833374023
Epoch: 1, Step: 442, Loss: 5.167189121246338
Epoch: 1, Step: 884, Loss: 4.924767971038818
Epoch: 1, Step: 1326, Loss: 4.846548080444336


In [32]:
def generate(idx, max_tokens):
    model.eval()
    tokens = idx
    for i in range(max_tokens):
        logits = model(tokens[:, -context_length:])
        topk_values, topk_indices = torch.topk(logits[:, -1, :], 50)
        probs = topk_values.softmax(dim=-1)
        sample = torch.multinomial(probs, 1)
        token = torch.gather(topk_indices, 1, sample)
        tokens = torch.cat((tokens, token), dim=-1)
    
    return tokens

In [33]:
start_tokens = tokenizer.encode("What do you call", return_tensors='pt').to(device)
print(tokenizer.decode(generate(start_tokens, 2000)[0].tolist()))

What do you call a girl who was an common like a bad? A chicken out between a cow? The bartender doesn't a big.
<|endoftext|>If a man is no new job, she'll be a new lot.
<|endoftext|>What's your computer got on me? The difference between his mother's and that it's a woman
<|endoftext|>They don't take a job. The bartender asks 'pats a picture.
<|endoftext|>Why can't gay and black people get up? Because he lost two things on his house.
<|endoftext|>I went to the only I have a job. I'm in bed, you don't tell it.
<|endoftext|>My phone is a dyslexon, but I want to eat.
<|endoftext|>What's the difference between a pirate and a joke? The difference between a bar and 1, and both he says, there's a long time, it's a man in.
<|endoftext|>My wife's best friend who had to look a long car? My father asked me I'll just like a group.
<|endoftext|>Did you hear about the people about the guy who told you get it the toilet? They haven't believe he's only.
<|endoftext|>I bet I could pick some things abou

In [None]:
# torch.save(model.state_dict(), "models/jokes_transformer")