In [50]:
import numpy as np

In [51]:
import torch
import torch.nn as nn
device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [52]:

from torch.utils.data import Dataset, DataLoader
import re
from datasets import load_dataset
from transformers import AutoModelForCausalLM  


model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(device)
dataset = load_dataset("rajpurkar/squad")

In [53]:
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})


In [54]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('gpt2data')

# Create sequences
SEQ_LEN = 150
class TextDataset(Dataset):
    def __init__(self, dataset, split='train'):
        self.data = dataset[split]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        example = self.data[idx]
        prompt = f"Question: {example['question']}\nContext: {example['context']}\nAnswer:"
        answer = f"{prompt} {example['answers']['text'][0]}"
        
        prompt_encodings = tokenizer(
            prompt,
            truncation=True,
            max_length=SEQ_LEN,
            padding = False,
            return_tensors=None
        ).input_ids

        answer_encodings = tokenizer(
            answer,
            truncation=True,
            max_length=SEQ_LEN,
            padding = False,
            return_tensors=None
        ).input_ids

        inputs = prompt_encodings + [tokenizer.eos_token_id] * (SEQ_LEN - len(prompt_encodings))
        # Later set -100 token to be ignored when calculating the loss
        labels = [-100] * len(prompt_encodings) + answer_encodings[len(prompt_encodings):] + [-100] * (SEQ_LEN - len(answer_encodings))
       
        mask = torch.cat([
            torch.ones(len(prompt_encodings), dtype=torch.long),
            torch.zeros(SEQ_LEN - len(prompt_encodings), dtype=torch.long)
        ])
       
        return torch.tensor(inputs, dtype=torch.long), torch.tensor(labels, dtype=torch.long), mask

train_datasets = TextDataset(dataset)
print(train_datasets[100])
train_loader = DataLoader(train_datasets, batch_size=64, shuffle=True)


(tensor([24361,    25,   554,   644,   614,   750,   262,  1074,  1085,   416,
         6102,  1133,  4631,   710,  1592,   262,  8049,  8693,    30,   198,
        21947,    25,  1881,   286,   262,  1388,  5059,  3386,   287,   262,
         3349,   286,   262,  2059,   373,   663,  4346,  1074,    11,   262,
        23382, 20377, 19098,  8685,    13,  6102,  1133,  4631,   710,  2627,
         1182,  3985,   287, 25859,    13,  4698,  4631,   710,    11,   262,
         8685,   561,  1281,   257,  1700,   286, 13343,  7864,    11,  1105,
         9089,    11,   290,  1936,  8470,    13,  5856,   465,  1511,   812,
          262,  8685,  1839,  1115,  2260, 27459,    11,   550,  1936, 41445,
         7028,    11,  1839,   262,  8049,  8693,   287, 36864,    11,   290,
         4635,  1938,   884,   355,  4502,   402,  3974,   290,   262,   366,
        15137, 18455,  3653,  1911,  6102,  1133,  4631,   710,   468,   262,
         4511,  5442,  5873, 20262,  3459,    16,     8,   287,



Let's see what the first pair of input/output sequences look like.

In [55]:
next(iter(train_loader))

[tensor([[24361,    25,  6350,  ..., 50256, 50256, 50256],
         [24361,    25,  1867,  ...,  1678,   357,    67],
         [24361,    25, 28470,  ..., 50256, 50256, 50256],
         ...,
         [24361,    25,  1867,  ..., 50256, 50256, 50256],
         [24361,    25,   554,  ..., 50256, 50256, 50256],
         [24361,    25,  1867,  ...,    72,    12,    33]]),
 tensor([[-100, -100, -100,  ..., -100, -100, -100],
         [-100, -100, -100,  ..., -100, -100, -100],
         [-100, -100, -100,  ..., -100, -100, -100],
         ...,
         [-100, -100, -100,  ..., -100, -100, -100],
         [-100, -100, -100,  ..., -100, -100, -100],
         [-100, -100, -100,  ..., -100, -100, -100]]),
 tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 1, 1, 1]])]

In [56]:
from tqdm import tqdm

def train_gpt(model, dataloader, optimizer, criterion, epochs, device):
    model.to(device)
    model.train()

    for epoch in range(epochs):
        total_loss = 0

        data_loader_with_progress = tqdm(
            iterable=dataloader, ncols=120, desc=f"Epoch {epoch+1}/{epochs}"
        )
        for batch_number, (inputs, targets, mask) in enumerate(data_loader_with_progress):
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            logits= model(inputs, attention_mask=mask).logits
            loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            if (batch_number % 100 == 0) or (batch_number == len(dataloader) - 1):
                data_loader_with_progress.set_postfix(
                    {
                        "avg loss": f"{total_loss/(batch_number+1):.4f}",
                    }
                )            


In [None]:
criterion = nn.CrossEntropyLoss(ignore_index=-100) 
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train_gpt(model, train_loader, optimizer, criterion, epochs=2, device=device)

Epoch 1/2:   0%|                                                                               | 0/1369 [00:00<?, ?it/s]

We can now use the trained GPT to generate text.  The model will generate a sequence of tokens based on the input prompt. We can use the inverse mapping from our vocabulary to "translate" the tokens to natural text.

In [None]:
class TextGenerator:
    def __init__(self, model, top_k=10):
        self.model = model
        self.model.to(device)

    def sample_from(self, probs, temperature):
        probs[1] = 0  # Mask out UNK token (index 1) to prevent generating <UNK>
        probs = torch.nn.functional.softmax(probs/temperature, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1).item()
        return next_id, probs

    def generate(self, start_prompt, max_tokens, temperature):
        self.model.eval()
        generated_tokens = tokenizer(
            start_prompt,
            truncation=True,
            max_length=SEQ_LEN,
            padding = False,
            return_tensors=None
        ).input_ids

        info = []

        with torch.no_grad():
            while len(generated_tokens) < max_tokens:
                x = torch.tensor([generated_tokens], dtype=torch.long)
                x = x.to(device)
                logits = self.model(x).logits
                last_logits = logits[0, -1] # .cpu().numpy()
                sample_token, probs = self.sample_from(last_logits, temperature)
                generated_tokens.append(sample_token)
                info.append({
                    "prompt": start_prompt,
                    "word_probs": probs,
                })
                if sample_token == 0:
                    break
        print("GEN", generated_tokens)
        generated_words = tokenizer.decode(generated_tokens)
        print("generated text:" + " ".join(generated_words))
        return info


In [None]:
text_generator = TextGenerator(model)
info = text_generator.generate("captain ", max_tokens=150, temperature=3.0)

In [None]:
# Save checkpoint every epoch
import os

checkpoint = {
    'model_state_dict': model.state_dict(),
}

# Save latest checkpoint
checkpoint_path = os.path.join("gpt2_checkpoints", f'gpt2.pth')
os.makedirs("checkpoints", exist_ok=True)
torch.save(checkpoint, checkpoint_path)
print(f"Checkpoint saved: {checkpoint_path}")