In [29]:
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from torchvision.transforms import v2
from datasets import load_dataset
import tiktoken

In [2]:
torch.manual_seed(42)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [3]:
dataset = load_dataset("roneneldan/TinyStories", split="train+validation")
train = dataset["text"]
dataset

Dataset({
    features: ['text'],
    num_rows: 2141709
})

In [4]:
encoder = tiktoken.get_encoding("cl100k_base")

In [30]:
batch_size = 256
batch_num = 1
data_dir = "data"

In [None]:
batch = list()
for i, text in enumerate(train):
    encoding = torch.tensor(encoder.encode(text), dtype=torch.int64)
    batch.append(encoding)
    
    if ((i+1) % batch_size == 0) or ((i+1) == len(train)):
        padded_batch = pad_sequence(batch, batch_first=True)
        torch.save(padded_batch, f"{data_dir}/batch_{batch_num}.pt")
        batch_num += 1
        batch.clear()
    
    torch.save(padded_batch, f"{data_dir}/batch_{batch_num}.pt")
        
        print(f"\rSaving data {(i/len(train))*100:.2f}% complete.", end="")



Saving data 100.00% complete.

In [None]:
class TinyStoriesDataset(Dataset):
    def __init__(self, data_dir):
        self.data_dir
    def __len__(self):
        return 2141709
    def __getitem__(self, idx): 

In [5]:
train_tokens = torch.load(f"data/batch_1.pt").to(device)
train_tokens.shape

torch.Size([32, 241])

In [6]:
d_model  = 128
d_query  = 64
n_heads  = 4
n_vocab  = encoder.n_vocab
n_layers = 4

In [7]:
print(n_vocab)

100277


In [8]:
# taken from https://pytorch-tutorials-preview.netlify.app/beginner/transformer_tutorial.html
import math

class PositionalEncoding(nn.Module):

    def __init__(self, 
                 d_model: int, 
                 dropout: float = 0.1, 
                 max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``
        """
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

In [9]:
class SelfAttention(nn.Module):
    def __init__(self, 
                 d_model: int, 
                 d_query: int = 128, 
                 n_heads: int = 8,
                 device: torch.device = torch.device("cpu")):
        super().__init__()
        self.device = device

        self.W_q = nn.Linear(d_model, d_query)
        self.W_k = nn.Linear(d_model, d_query)
        self.W_v = nn.Linear(d_model, d_model)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):

        q = self.W_q(x)
        k = self.W_k(x)
        v = self.W_v(x)

        attention_pattern = torch.matmul(q, torch.transpose(k, 1, 2))
        
        seq_len = attention_pattern.shape[-1]
        mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool().to(self.device)
        attention_pattern = torch.masked_fill(attention_pattern, mask, float("-inf"))

        attention_pattern = self.softmax(attention_pattern)

        output = torch.matmul(attention_pattern, v)
        
        return output



In [10]:
class MultilayerPerceptron(nn.Module):
    def __init__(self, 
                 d_model: int, 
                 d_up: int = 256):
        super().__init__()

        self.up = nn.Linear(d_model, d_up)
        self.relu = nn.ReLU()
        self.down = nn.Linear(d_up, d_model)
    
    def forward(self, x):

        output = self.up(x)
        output = self.relu(output)
        output = self.down(output)

        output = output + x

        return output

In [11]:
class Transformer(nn.Module):
    def __init__(self, 
                 n_vocab: int, 
                 d_model: int = 128, 
                 d_query: int = 128, 
                 n_heads: int = 8, 
                 n_layers: int = 4, 
                 d_up: int = 256,
                 device: torch.device = torch.device("cpu")):
        super().__init__()

        self.embedding = nn.Embedding(n_vocab, d_model)
        self.pe = PositionalEncoding(d_model, max_len=50000)

        self.self_attention = SelfAttention(d_model, d_query, n_heads, device)

        self.mlp = MultilayerPerceptron(d_model, d_up)

        self.unembedding = nn.Linear(d_model, n_vocab)

    def forward(self, x):
        x = self.embedding(x)
        x = self.pe(x)

        x = self.self_attention(x)

        x = self.mlp(x)

        x = self.unembedding(x)
        
        return x

In [None]:
def collate_fn_padding(batch):

    batch = pad_sequence(batch, batch_first=True)
    return batch
train_loader = DataLoader(train, batch_size=1, shuffle=True)

In [13]:
model = Transformer(n_vocab, d_model, d_query, n_heads, n_layers, device=device).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

In [14]:
output = model(train_tokens)
output.shape

torch.Size([32, 241, 100277])

In [15]:
output[:,:-1,:].reshape(-1, output.shape[-1]).shape

torch.Size([7680, 100277])

In [16]:
train_tokens[:,1:].reshape(-1).shape

torch.Size([7680])

In [17]:
output.dtype

torch.float32

In [18]:
train_tokens.dtype

torch.int32

In [19]:
def train_model(model, optimizer, criterion, device, train_loader, epoch):
    model.train()
    for idx, inputs in enumerate(train_loader):
        
        targets = inputs[:,1:]
        outputs = model(inputs)[:,:-1,:]

        targets = targets.reshape(-1)
        outputs = outputs.reshape(-1, outputs.shape[-1])
        
        loss = criterion(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if idx % 25 == 0:
            print(f"Epoch [{epoch}].[{idx}] Loss: {loss}")

In [20]:
train_model(model, optimizer, criterion, device, train_loader, 1)

Epoch [1].[0] Loss: 11.620615005493164
Epoch [1].[25] Loss: 10.795742988586426


In [21]:
output = model(train_tokens[:8])
print(f"output: {output.shape}")

output: torch.Size([8, 241, 100277])


In [27]:
output[0,-4,:].argmax()

tensor(0, device='cuda:0')

In [25]:
predicted_word = encoder.decode([output[0,-1,:].argmax()])
print(predicted_word)

!
