In [1]:
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from torchvision.transforms import v2
from datasets import load_dataset
import tiktoken
import time
import math

In [2]:
# taken from https://pytorch-tutorials-preview.netlify.app/beginner/transformer_tutorial.html

class PositionalEncoding(nn.Module):

    def __init__(self, 
                 d_model: int, 
                 dropout: float = 0.1, 
                 max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``
        """
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

In [3]:
class SelfAttention(nn.Module):
    def __init__(self, 
                 d_model: int, 
                 d_query: int = 128, 
                 n_heads: int = 8,
                 device: torch.device = torch.device("cpu")):
        super().__init__()
        self.device = device

        self.W_q = nn.Linear(d_model, d_query)
        self.W_k = nn.Linear(d_model, d_query)
        self.W_v = nn.Linear(d_model, d_model)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):

        q = self.W_q(x)
        k = self.W_k(x)
        v = self.W_v(x)

        attention_pattern = torch.matmul(q, torch.transpose(k, 1, 2))
        
        seq_len = attention_pattern.shape[-1]
        mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool().to(self.device)
        attention_pattern = torch.masked_fill(attention_pattern, mask, float("-inf"))

        attention_pattern = self.softmax(attention_pattern)

        output = torch.matmul(attention_pattern, v)
        
        return output



In [4]:
class MultilayerPerceptron(nn.Module):
    def __init__(self, 
                 d_model: int, 
                 d_up: int = 256):
        super().__init__()

        self.up = nn.Linear(d_model, d_up)
        self.relu = nn.ReLU()
        self.down = nn.Linear(d_up, d_model)
    
    def forward(self, x):

        output = self.up(x)
        output = self.relu(output)
        output = self.down(output)

        output = output + x

        return output

In [5]:
class Transformer(nn.Module):
    def __init__(self, 
                 n_vocab: int, 
                 d_model: int = 128, 
                 d_query: int = 128, 
                 n_heads: int = 8, 
                 n_layers: int = 4, 
                 d_up: int = 256,
                 device: torch.device = torch.device("cpu")):
        super().__init__()

        self.embedding = nn.Embedding(n_vocab, d_model)
        self.pe = PositionalEncoding(d_model, max_len=50000)

        self.self_attention = SelfAttention(d_model, d_query, n_heads, device)

        self.mlp = MultilayerPerceptron(d_model, d_up)

        self.unembedding = nn.Linear(d_model, n_vocab)

    def forward(self, x):
        x = self.embedding(x)
        x = self.pe(x)

        x = self.self_attention(x)

        x = self.mlp(x)

        x = self.unembedding(x)
        
        return x

In [6]:
def train_model(model, optimizer, criterion, device, train_loader, accum_steps, epoch):
    model.train()
    
    start_time = time.time()
    for idx, inputs in enumerate(train_loader):
        inputs = inputs.to(device)
        targets = inputs[:,1:]
        outputs = model(inputs)[:,:-1,:]

        targets = targets.reshape(-1)
        outputs = outputs.reshape(-1, outputs.shape[-1])
        
        loss = criterion(outputs, targets) / accum_steps
        loss.backward()

        if (idx + 1) % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        if (idx + 1) % (accum_steps * 4) == 0:
            print(f"Epoch [{epoch}].[{idx}] Loss: {loss * accum_steps}")

        if (idx + 1) % (accum_steps * 16) == 0:
            elapsed_time = time.time() - start_time
            print(f"TIME: {elapsed_time / (accum_steps * 16)} seconds per batch")
            start_time = time.time()

            allocated = torch.cuda.memory_allocated() / 1e9
            reserved = torch.cuda.memory_reserved() / 1e9
            peak = torch.cuda.max_memory_allocated() / 1e9
            print(f"USAGE: Allocated {allocated:.2f}GB, Reserved {reserved:.2f}GB, Peak: {peak:.2f}GB")
            

In [7]:
torch.manual_seed(42)
torch.set_float32_matmul_precision('high')
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [8]:
dataset = load_dataset("roneneldan/TinyStories", split="train+validation")
dataset

Dataset({
    features: ['text'],
    num_rows: 2141709
})

In [9]:
encoder = tiktoken.get_encoding("cl100k_base")

In [10]:
def tokenize(sequence):
    sequence["text"] = torch.tensor(encoder.encode(sequence["text"]), dtype=torch.int64)
    return sequence

tokenized_dataset = dataset.map(tokenize, num_proc=8).with_format("torch")
tokenized_dataset = tokenized_dataset.train_test_split(test_size=0.2, shuffle=True)

In [11]:
train = tokenized_dataset["train"]["text"]
test = tokenized_dataset["test"]["text"]

In [12]:
# hyperparameters
batch_size = 2
accum_steps = 8
d_model  = 128
d_query  = 64
d_up = 256
n_heads  = 4
n_layers = 4

n_vocab  = encoder.n_vocab

In [13]:
print(n_vocab)

100277


In [14]:
def collate_fn_padding(batch):
    batch = pad_sequence(batch, batch_first=True)
    return batch

train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, collate_fn=collate_fn_padding)

In [15]:
model = Transformer(n_vocab=n_vocab, 
                    d_model=d_model, 
                    d_query=d_query, 
                    n_heads=n_heads, 
                    n_layers=n_layers, 
                    d_up=d_up, 
                    device=device).to(device)
model = torch.compile(model)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

In [None]:
train_model(model, optimizer, criterion, device, train_loader, accum_steps, 1)

W0117 14:20:55.099000 18924 torch/_inductor/utils.py:1613] [0/0] Not enough SMs to use max_autotune_gemm mode


Epoch [1].[31] Loss: 11.574928283691406
Epoch [1].[63] Loss: 11.445477485656738
Epoch [1].[95] Loss: 11.340519905090332
Epoch [1].[127] Loss: 11.365422248840332
TIME: 0.0678132139146328 seconds per batch
USAGE: Allocated 0.48GB, Reserved 4.07GB, Peak: 2.57GB
Epoch [1].[159] Loss: 11.288846969604492
Epoch [1].[191] Loss: 11.1409912109375
Epoch [1].[223] Loss: 10.636605262756348
Epoch [1].[255] Loss: 10.760150909423828
TIME: 0.017053373157978058 seconds per batch
USAGE: Allocated 0.57GB, Reserved 4.00GB, Peak: 3.67GB
Epoch [1].[287] Loss: 10.391779899597168
Epoch [1].[319] Loss: 10.745203018188477
Epoch [1].[351] Loss: 10.914483070373535
Epoch [1].[383] Loss: 10.921076774597168
TIME: 0.014722956344485283 seconds per batch
USAGE: Allocated 0.51GB, Reserved 4.80GB, Peak: 3.67GB
Epoch [1].[415] Loss: 10.497237205505371
Epoch [1].[447] Loss: 10.192230224609375
Epoch [1].[479] Loss: 9.598983764648438
Epoch [1].[511] Loss: 10.059351921081543
TIME: 0.016624536365270615 seconds per batch
USAGE: 

In [None]:
output = model(train[:8])
print(f"output: {output.shape}")

In [None]:
output[0,-4,:].argmax()

In [None]:
predicted_word = encoder.decode([output[0,-1,:].argmax()])
print(predicted_word)