In [1]:
import torch
from torch import tensor, nn, optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR

In [2]:
with open("data/input.txt") as f:
    text = f.read()

In [3]:
# Grab the tokenizer only
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [4]:
tokens = tensor(tokenizer.encode(text))
tokens[:5]

Token indices sequence length is longer than the specified maximum sequence length for this model (338025 > 1024). Running this sequence through the model will result in indexing errors


tensor([ 5962, 22307,    25,   198,  8421])

In [5]:
set(['hi', 'hi'])

{'hi'}

In [6]:
from datasets import load_dataset

ds = load_dataset("myothiha/jokes")

README.md:   0%|          | 0.00/21.0 [00:00<?, ?B/s]

train.csv:   0%|          | 0.00/19.7M [00:00<?, ?B/s]

validation.csv:   0%|          | 0.00/2.18M [00:00<?, ?B/s]

test.csv:   0%|          | 0.00/2.43M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/187641 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/20850 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/23166 [00:00<?, ? examples/s]

In [7]:
len(set(ds['train']['text']))

187641

In [8]:
text = ' [SEP] '.join(ds['train']['text'])

In [9]:
vocab_size = tokenizer.vocab_size
context_length = 32
n_embs = 256
n_heads = 8
n_blocks = 8
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [10]:
len(text)

19034660

In [11]:
class TextDataLoader():
    def __init__(self, text, context_size, tokenizer, batch_size=1, device='cpu'):
        v = tokenizer(text, return_tensors='pt')
        self.tokens = v.input_ids # The attention mask will be handled manually later
        self.batch_size = batch_size
        self.context_length = context_length
        self.device = device
        
        self.position = 0
        
    def __iter__(self):
        self.reset()
        return self
        
    def __next__(self):
        B, T = self.batch_size, self.context_length
        if self.position + B * T + 1 < len(self.tokens[0]):
            tokens = self.tokens[0][self.position: self.position + B * T + 1]
            self.position += self.context_length + B * T + 1
            x = tokens[:-1].view(B, T)
            y = tokens[1:].view(B, T)
            return x.to(self.device), y.to(self.device)
        else:
            raise StopIteration
            
    def __len__(self):
        return len(self.tokens[0]) // self.context_length // self.batch_size
    
    def reset(self):
        self.position = 0
        

The `head_size` matches that of the embeddings if it is a single head.

If multi-headed attention is used, then the `head_size` would equal number of embeddings // number of heads.

In [12]:
class SelfAttentionHead(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.k = nn.Linear(n_embs, head_size, bias=False)
        self.q = nn.Linear(n_embs, head_size, bias=False)
        self.v = nn.Linear(n_embs, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(context_length, context_length)))
        
    def forward(self, x):
        B, T, C = x.shape
        k = self.k(x) # (B, T, head_size)
        q = self.q(x) # (B, T, head_size)
        v = self.v(x) # (B, T, head_size)
        
        attn = q @ k.transpose(-2, -1) * C ** -0.5 # (B, T, T)
        
        wei = attn.masked_fill(self.tril[:T, :T] == 0, float("-inf")) # Causal masking, blocks future tokens from being seen
        wei = wei.softmax(dim=-1) # (B, T, T)
        
        out = wei @ v # (B, T, head_size)
        return out

In [13]:
class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads):
        super().__init__()
        self.heads = nn.ModuleList([SelfAttentionHead(n_embs // n_heads) for i in range(n_heads)])
        self.proj = nn.Linear(n_embs, n_embs)
        
    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj(out)
        return out

In [14]:
class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embs, n_embs * 4),
            nn.ReLU(),
            nn.Linear(n_embs * 4, n_embs)
        )
        
    def forward(self, x):
        out = self.net(x)
        return out

In [15]:
class TransformerBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embs)
        self.mha = MultiHeadAttention(n_heads)
        self.ln2 = nn.LayerNorm(n_embs)
        self.ffwd = FeedForward()
        
    def forward(self, x):
        x = self.mha(self.ln1(x)) + x
        x = self.ffwd(self.ln2(x)) + x
        
        return x
        

In [16]:
class GPT(nn.Module):
    def __init__(self):
        super().__init__()
        self.tk_emb = nn.Embedding(vocab_size, n_embs)
        self.pos_emb = nn.Embedding(context_length, n_embs)
        self.blocks = nn.ModuleList([TransformerBlock() for i in range(n_blocks)])
        self.ln_f = nn.LayerNorm(n_embs)
        self.fc = nn.Linear(n_embs, vocab_size)
        
        self.tk_emb.weight = self.fc.weight
        
    def forward(self, x, targets=None):
        B, T = x.shape
        tk_emb = self.tk_emb(x) # (B, T, C)
        pos_tns = torch.arange(T, device=device) # T
        pos_emb = self.pos_emb(pos_tns) # (T, C)
        x = pos_emb + tk_emb # (B, T, C) + (T, C)
        
        for block in self.blocks:
            x = block(x)
        
        x = self.ln_f(x)
        logits = self.fc(x) # (B, T, vocab_size)
        if targets is None:
            return logits
        else:
            loss = F.cross_entropy(logits.view(B * T, -1), targets.view(B*T))
            return logits, loss
        
    def generate(self):
        pass

In [17]:
torch.set_float32_matmul_precision('high')

In [18]:
model = GPT().to(device)

In [19]:
dl = TextDataLoader(text, 32, tokenizer, batch_size=32, device=device)

In [20]:
xb, yb = next(iter(dl))
xb.shape, yb.shape

(torch.Size([32, 32]), torch.Size([32, 32]))

In [21]:
# model(xb).shape

In [22]:
lr = 1e-4
opt = optim.Adam(model.parameters(), lr)

In [23]:
epochs = 5
sched = CosineAnnealingLR(opt, epochs * len(dl), lr * 0.5)
for i in range(epochs):
    for step, (xb, yb) in enumerate(dl):
        opt.zero_grad()
        xb = xb.to(device)
        yb = yb.to(device)
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            logits, loss = model(xb, yb)
        loss.backward()
        opt.step()
        sched.step()
        if step % 200 == 0 or step == len(dl) - 1:
            print(f"Epoch: {i}, Step: {step}, Loss: {loss}")

Epoch: 0, Step: 0, Loss: 11.018798828125
Epoch: 0, Step: 200, Loss: 6.7135009765625
Epoch: 0, Step: 400, Loss: 6.4780731201171875
Epoch: 0, Step: 600, Loss: 6.050868034362793
Epoch: 0, Step: 800, Loss: 5.910508632659912
Epoch: 0, Step: 1000, Loss: 5.696419715881348
Epoch: 0, Step: 1200, Loss: 5.576383113861084
Epoch: 0, Step: 1400, Loss: 5.382604598999023
Epoch: 0, Step: 1600, Loss: 5.582192420959473
Epoch: 0, Step: 1800, Loss: 5.180304050445557
Epoch: 0, Step: 2000, Loss: 5.430119514465332
Epoch: 0, Step: 2200, Loss: 5.166388511657715
Epoch: 0, Step: 2400, Loss: 5.11775541305542
Epoch: 0, Step: 2600, Loss: 5.062450885772705
Epoch: 0, Step: 2800, Loss: 4.690340042114258
Epoch: 0, Step: 3000, Loss: 4.733290672302246
Epoch: 0, Step: 3200, Loss: 5.401444911956787
Epoch: 0, Step: 3400, Loss: 5.17732572555542
Epoch: 0, Step: 3600, Loss: 4.903584957122803
Epoch: 0, Step: 3800, Loss: 4.650193214416504
Epoch: 0, Step: 4000, Loss: 4.683797359466553
Epoch: 0, Step: 4200, Loss: 4.553931713104248


In [28]:
sum(p.numel() for p in model.parameters())

19236689

In [25]:
def generate(idx, max_tokens):
    model.eval()
    tokens = idx
    for i in range(max_tokens):
        logits = model(tokens[:, -context_length:])
        topk_values, topk_indices = torch.topk(logits[:, -1, :], 20)
        probs = topk_values.softmax(dim=-1)
        sample = torch.multinomial(probs, 1)
        token = torch.gather(topk_indices, 1, sample)
        tokens = torch.cat((tokens, token), dim=-1)
    
    return tokens

In [None]:
start_tokens = tokenizer.encode("What do you call a", return_tensors='pt').to(device)
print(tokenizer.decode(generate(start_tokens, 2000)[0].tolist()))

In [27]:
torch.save(model.state_dict(), "models/jokes_transformer")