In [1]:
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from torchvision.transforms import v2
from datasets import load_dataset
import tiktoken
import time
import math
import os

In [2]:
# taken from https://pytorch-tutorials-preview.netlify.app/beginner/transformer_tutorial.html

class PositionalEncoding(nn.Module):

    def __init__(self, 
                 d_model: int, 
                 dropout: float = 0.1, 
                 max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``
        """
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

In [3]:
class SelfAttention(nn.Module):
    def __init__(self, 
                 d_model: int, 
                 d_query: int = 128, 
                 n_heads: int = 8,
                 device: torch.device = torch.device("cpu")):
        super().__init__()
        self.device = device
        self.n_heads = n_heads

        self.W_q = nn.Linear(d_model, d_query)
        self.W_k = nn.Linear(d_model, d_query)
        self.W_v = nn.Linear(d_model, d_model)

        self.scaling_factor = math.sqrt(d_query)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        
        
        x = x.unsqueeze(1).repeat([1, self.n_heads, 1, 1])
        
        q = self.W_q(x)
        k = self.W_k(x)
        v = self.W_v(x)

        attention_pattern = torch.matmul(q, torch.transpose(k, -2, -1)) / self.scaling_factor
        
        seq_len = attention_pattern.shape[-1]
        mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool().to(self.device)
        attention_pattern = torch.masked_fill(attention_pattern, mask, float("-inf"))

        attention_pattern = self.softmax(attention_pattern)

        output = torch.sum(torch.matmul(attention_pattern, v), dim=1)
        
        return output



In [4]:
class MultilayerPerceptron(nn.Module):
    def __init__(self, 
                 d_model: int, 
                 d_up: int = 256):
        super().__init__()

        self.up = nn.Linear(d_model, d_up)
        self.relu = nn.ReLU()
        self.down = nn.Linear(d_up, d_model)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:

        output = self.up(x)
        output = self.relu(output)
        output = self.down(output)

        output = output + x

        return output

In [5]:
class Transformer(nn.Module):
    def __init__(self, 
                 n_vocab: int, 
                 d_model: int = 128, 
                 d_query: int = 128, 
                 n_heads: int = 8, 
                 n_layers: int = 4, 
                 d_up: int = 256,
                 device: torch.device = torch.device("cpu")):
        super().__init__()

        self.embedding = nn.Embedding(n_vocab, d_model)
        self.pe = PositionalEncoding(d_model, max_len=50000)

        self.attention_layers = nn.ModuleList([layer for _ in range(n_layers) for layer in 
                                               (SelfAttention(d_model, d_query, n_heads, device), 
                                                nn.LayerNorm(d_model),
                                                MultilayerPerceptron(d_model, d_up))])

        # self.self_attention = SelfAttention(d_model, d_query, n_heads, device)
        # self.mlp = MultilayerPerceptron(d_model, d_up)

        self.unembedding = nn.Linear(d_model, n_vocab)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.embedding(x)
        x = self.pe(x)

        for layer in self.attention_layers:
            x = layer(x)

        x = self.unembedding(x)
        
        return x

In [6]:
class ModelCheckpoint():
    def __init__(self, model_state, optim_state, epoch: int, batch: int, rng_state: torch.Tensor):
        self.model_state = model_state
        self.optim_state = optim_state
        self.epoch       = epoch
        self.batch       = batch
        self.rng_state   = rng_state
    def save(self, file_path):
        torch.save({
            "model_state": self.model_state,
            "optim_state": self.optim_state,
            "epoch"      : self.epoch,
            "batch"      : self.batch,
            "rng_state"  : self.rng_state
        }, file_path)
        print(f"Model checkpoint saved at {file_path} at epoch {self.epoch} batch {self.batch}")
    @staticmethod
    def load(file_path):
        checkpoint = torch.load(file_path)
        return ModelCheckpoint(checkpoint["model_state"], 
                               checkpoint["optim_state"], 
                               checkpoint["epoch"], 
                               checkpoint["batch"], 
                               checkpoint["rng_state"])

In [7]:
class CheckpointRandomSampler(torch.utils.data.RandomSampler):
    def __init__(self, data_source, checkpoint: ModelCheckpoint = None):
        super().__init__(data_source)
        self.start_idx = 0 if checkpoint == None else checkpoint.batch
    def __iter__(self):
        batch_idxs = list(super().__iter__())
        for idx in batch_idxs[self.start_idx:]:
            yield idx
        return
    def __len__(self):
        return super().__len__() - self.start_idx

In [8]:
def validate_model(model, 
                   device, 
                   criterion, 
                   test_loader):
    with torch.no_grad():
        avg_loss = 0
        for idx, inputs in enumerate(test_loader):
            inputs = inputs.to(device)
            targets = inputs[:,1:]
            outputs = model(inputs)[:,:-1,:]

            targets = targets.reshape(-1)
            outputs = outputs.reshape(-1, outputs.shape[-1])
            
            loss = criterion(outputs, targets)

            avg_loss += loss

        avg_loss /= len(test_loader)
        print(f"VALIDATE: Average Loss: {avg_loss}")

def train_model(model: nn.Module, 
                optimizer, 
                criterion, 
                device, 
                train_loader, 
                validation_loader, 
                accum_steps,
                num_epochs: int = 4,
                checkpoint: ModelCheckpoint = None):
    
    epoch = 0
    start_batch = 0

    if checkpoint != None:
        model.load_state_dict(checkpoint.model_state)
        optimizer.load_state_dict(checkpoint.optim_state)
        epoch = checkpoint.epoch
        start_batch = checkpoint.batch
        torch.set_rng_state(checkpoint.rng_state)

    print(f"Starting training on epoch {epoch}, batch {start_batch}")

    model.train()

    for epoch in range(epoch, num_epochs):
        start_time = time.time()
        for idx, inputs in enumerate(train_loader):
            batch_idx = start_batch + idx
            inputs = inputs.to(device)
            targets = inputs[:,1:]
            outputs = model(inputs)[:,:-1,:]

            targets = targets.reshape(-1)
            outputs = outputs.reshape(-1, outputs.shape[-1])
            
            loss = criterion(outputs, targets) / accum_steps
            loss.backward()

            if (batch_idx + 1) % accum_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

            if (batch_idx + 1) % (accum_steps * 4) == 0:
                print(f"Epoch [{epoch}].[{batch_idx}] Loss: {loss * accum_steps}")

            if (batch_idx + 1) % (accum_steps * 16) == 0:
                elapsed_time = time.time() - start_time
                print(f"TIME: {elapsed_time / (accum_steps * 16)} seconds per batch")
                start_time = time.time()

                allocated = torch.cuda.memory_allocated() / 1e9
                reserved = torch.cuda.memory_reserved() / 1e9
                peak = torch.cuda.max_memory_allocated() / 1e9
                print(f"USAGE: Allocated {allocated:.2f}GB, Reserved {reserved:.2f}GB, Peak: {peak:.2f}GB")

            if (batch_idx + 1) % (accum_steps * 32) == 0:
                validate_model(model, device, criterion, validation_loader)
                checkpoint = ModelCheckpoint(model.state_dict(), 
                                            optimizer.state_dict(), 
                                            epoch, 
                                            batch_idx, 
                                            torch.get_rng_state())
                checkpoint.save("checkpoint.pt")

            

In [9]:
torch.manual_seed(42)
torch.set_float32_matmul_precision('high')
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [10]:
dataset = load_dataset("roneneldan/TinyStories", split="train+validation")
dataset

Dataset({
    features: ['text'],
    num_rows: 2141709
})

In [11]:
encoder = tiktoken.get_encoding("cl100k_base")

In [None]:
def tokenize(sequence, encoder):
    # NOTE: <EOS> token is encoder.n_vocab
    sequence["text"] = torch.tensor(encoder.encode(sequence["text"]) + [encoder.n_vocab], dtype=torch.int64)
    return sequence

tokenized_dataset = dataset.map(tokenize, num_proc=8, fn_kwargs={"encoder": encoder}).with_format("torch")
tokenized_dataset = tokenized_dataset.train_test_split(test_size=0.2, shuffle=True)

In [13]:
train = tokenized_dataset["train"]["text"]

test_set = tokenized_dataset["test"].train_test_split(test_size=0.001, shuffle=True)
test = test_set["train"]["text"]
validation = test_set["test"]["text"]

In [None]:
# hyperparameters
batch_size = 2
accum_steps = 32
d_model  = 256
d_query  = 64
d_up = 512
n_heads  = 4
n_layers = 4

# NOTE: add 1 for <EOS> token
n_vocab  = encoder.n_vocab + 1

In [15]:
print(n_vocab)

100277


In [16]:
def collate_fn_padding(batch):
    batch = pad_sequence(batch, batch_first=True)
    return batch

checkpoint_path = "checkpoint.pt"

checkpoint = None
if os.path.exists(checkpoint_path):
    checkpoint = ModelCheckpoint.load(checkpoint_path)


sampler = CheckpointRandomSampler(train, checkpoint)

train_loader = DataLoader(train, batch_size=batch_size, sampler=sampler, collate_fn=collate_fn_padding)
test_loader = DataLoader(test, batch_size=batch_size, shuffle=True, collate_fn=collate_fn_padding)
validation_loader = DataLoader(validation, batch_size=batch_size, shuffle=True, collate_fn=collate_fn_padding)

model = Transformer(n_vocab=n_vocab, 
                    d_model=d_model, 
                    d_query=d_query, 
                    n_heads=n_heads, 
                    n_layers=n_layers, 
                    d_up=d_up, 
                    device=device).to(device)
model = torch.compile(model)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

In [17]:
train_model(model=model, 
            optimizer=optimizer, 
            criterion=criterion, 
            device=device, 
            train_loader=train_loader, 
            validation_loader=validation_loader, 
            accum_steps=accum_steps,
            checkpoint=checkpoint)

Starting training on epoch 0, batch 236543


W0120 22:51:36.530000 139836 torch/_inductor/utils.py:1613] [0/0] Not enough SMs to use max_autotune_gemm mode


Epoch [0].[236543] Loss: 2.5811259746551514
TIME: 0.013678241521120071 seconds per batch
USAGE: Allocated 1.15GB, Reserved 1.95GB, Peak: 1.72GB
VALIDATE: Average Loss: 2.99733567237854
Model checkpoint saved at checkpoint.pt at epoch 0 batch 236543
Epoch [0].[236671] Loss: 2.6559653282165527
Epoch [0].[236799] Loss: 3.5133161544799805
Epoch [0].[236927] Loss: 3.46524715423584
Epoch [0].[237055] Loss: 2.581388235092163
TIME: 0.03390326304361224 seconds per batch
USAGE: Allocated 1.12GB, Reserved 5.08GB, Peak: 4.63GB
Epoch [0].[237183] Loss: 3.7092299461364746
Epoch [0].[237311] Loss: 3.644249200820923
Epoch [0].[237439] Loss: 3.072859525680542
Epoch [0].[237567] Loss: 3.0596001148223877
TIME: 0.02334205713123083 seconds per batch
USAGE: Allocated 1.17GB, Reserved 5.88GB, Peak: 4.63GB
VALIDATE: Average Loss: 2.993046998977661
Model checkpoint saved at checkpoint.pt at epoch 0 batch 237567
Epoch [0].[237695] Loss: 3.2633605003356934
Epoch [0].[237823] Loss: 3.0824472904205322
Epoch [0].[2

KeyboardInterrupt: 

In [21]:
def generate_response(model, encoder: tiktoken.Encoding, device: torch.device, prompt: str):
    sequence = encoder.encode(prompt)
    with torch.no_grad():
        num_sentences = 0
        while num_sentences < 15:
            input = torch.tensor(sequence, dtype=torch.int64).unsqueeze(0).to(device)
            output = model(input)
            output = output[0,-1,:].argmax().item()
            sequence = sequence + [output]

            if output == encoder.encode("."):
                num_sentences += 1

        response = encoder.decode(sequence)
        return response

In [20]:
encoder.encode("hello")

[15339]

In [22]:
# checkpoint = ModelCheckpoint.load("checkpoint.pt")
# model.load_state_dict(checkpoint.model_state)

prompt = "Tell me a story about a hungry man"

response = generate_response(model, encoder=encoder, device=device, prompt=prompt)
print(response)

KeyboardInterrupt: 

In [None]:
output = model(encoder.encode(""))