In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer
import random

class MiniGPT(nn.Module):
    def __init__(self, vocab_size, seq_len, d_model=128, n_heads=4, n_layers=2):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Parameter(torch.zeros(1, seq_len, d_model))
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        B, T = x.size()
        tok_emb = self.token_emb(x)             # (B, T, d_model)
        x = tok_emb + self.pos_emb[:, :T, :]    # add positional embedding
        x = x.transpose(0, 1)                   # (T, B, d_model) for transformer
        #mask = torch.triu(torch.ones(T, T), diagonal=1).bool().to(x.device)  # shape (T, T), adding causal attention mask
        x = self.transformer(x)                 # (T, B, d_model)
        x = x.transpose(0, 1)                   # (B, T, d_model)
        logits = self.lm_head(x)                # (B, T, vocab_size)
        return logits

class ToyDataset(Dataset):
    def __init__(self, texts, tokenizer, seq_len=128):
        self.examples = []
        for text in texts:
            tokens = tokenizer.encode(text)
            for i in range(0, len(tokens) - seq_len):
                input_seq = tokens[i:i+seq_len]
                target_seq = tokens[i+1:i+seq_len+1]
                self.examples.append((input_seq, target_seq))

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        x, y = self.examples[idx]
        return torch.tensor(x), torch.tensor(y)

def train(model, dataloader, optimizer, device):
    model.train()
    total_loss = 0
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

def generate(model, tokenizer, prompt, max_new_tokens=50):
    model.eval()
    input_ids = tokenizer(prompt, return_tensors='pt')['input_ids'].to(next(model.parameters()).device)
    generated = input_ids

    for _ in range(max_new_tokens):
        with torch.no_grad():
            logits = model(generated)
            next_token_logits = logits[:, -1, :]
            probs = F.softmax(next_token_logits, dim=-1)
            #next_token = torch.multinomial(probs, num_samples=1) #randomness
            next_token = torch.argmax(probs, dim=-1, keepdim=True) #deterministic
            generated = torch.cat([generated, next_token], dim=1)

    return tokenizer.decode(generated[0])

if __name__ == "__main__":
    # Setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    # Dummy text data
    #texts = ["The sky is blue.", "The cat sat on the mat.", "The quick brown fox jumps over the lazy dog."] * 100
    texts = ["Roy lives in Sweden.", "Roy is a PhD student.", "Roy works in Ericsson and Chalmers."] * 100

    # Parameters
    seq_len = 8
    batch_size = 8
    vocab_size = tokenizer.vocab_size

    # DataLoader
    dataset = ToyDataset(texts, tokenizer, seq_len=seq_len)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Model & Optimizer
    model = MiniGPT(vocab_size=vocab_size, seq_len=64).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    # Training Loop
    for epoch in range(100):
        loss = train(model, dataloader, optimizer, device)
        print(f"Epoch {epoch+1}, Loss: {loss:.4f}")

    # Generation Example
    prompt = "Roy is."
    output = generate(model, tokenizer, prompt, max_new_tokens=20)
    print("\nGenerated text:")
    print(output)


Epoch 1, Loss: 7.2804
Epoch 2, Loss: 4.2393
Epoch 3, Loss: 3.4378
Epoch 4, Loss: 2.7730
Epoch 5, Loss: 2.1170
Epoch 6, Loss: 1.5112
Epoch 7, Loss: 1.0032
Epoch 8, Loss: 0.6411
Epoch 9, Loss: 0.4131
Epoch 10, Loss: 0.2794
Epoch 11, Loss: 0.1999
Epoch 12, Loss: 0.1514
Epoch 13, Loss: 0.1206
Epoch 14, Loss: 0.0996
Epoch 15, Loss: 0.0840
Epoch 16, Loss: 0.0723
Epoch 17, Loss: 0.0636
Epoch 18, Loss: 0.0563
Epoch 19, Loss: 0.0500
Epoch 20, Loss: 0.0453
Epoch 21, Loss: 0.0410
Epoch 22, Loss: 0.0374
Epoch 23, Loss: 0.0345
Epoch 24, Loss: 0.0317
Epoch 25, Loss: 0.0294
Epoch 26, Loss: 0.0273
Epoch 27, Loss: 0.0257
Epoch 28, Loss: 0.0240
Epoch 29, Loss: 0.0227
Epoch 30, Loss: 0.0214
Epoch 31, Loss: 0.0201
Epoch 32, Loss: 0.0190
Epoch 33, Loss: 0.0181
Epoch 34, Loss: 0.0171
Epoch 35, Loss: 0.0163
Epoch 36, Loss: 0.0156
Epoch 37, Loss: 0.0148
Epoch 38, Loss: 0.0142
Epoch 39, Loss: 0.0136
Epoch 40, Loss: 0.0131
Epoch 41, Loss: 0.0126
Epoch 42, Loss: 0.0121
Epoch 43, Loss: 0.0117
Epoch 44, Loss: 0.01