In [1]:
import torch

# Run on a CUDA GPU when one is present and on the CPU otherwise.
# The Apple-silicon (mps) backend stays disabled because we yet again find
# mps to be drastically slower:
# elif torch.backends.mps.is_available():
#     torch._dynamo.disable()  # https://github.com/pytorch/pytorch/issues/149184
#     device = torch.device("mps")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"{device=}")

In [2]:
from transformers import AutoTokenizer
from datasets import load_dataset

# Tokenizer published under the "openai-community/gpt2" identifier; its
# vocabulary size also determines the model's embedding/output dimensions below.
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
# WikiText corpus, "wikitext-103-v1" configuration; indexed by split name later on.
dataset = load_dataset(path="wikitext", name="wikitext-103-v1")

In [3]:
from pathlib import Path
from datasets import load_from_disk

# TODO: Use BooksCorpus dataset and context_length=512
# Number of tokens per training example: `tokenize` truncates to this length
# and discards every chunk that turns out shorter.
context_length = 128

def tokenize(batch):
    """Tokenize a batch of texts into chunks of exactly ``context_length`` tokens.

    Args:
        batch: Mapping with a ``"text"`` entry holding the raw strings.

    Returns:
        A dict with a single ``"input_ids"`` list containing only the chunks
        whose reported length equals ``context_length``; shorter chunks are
        dropped.
    """
    # TODO: Sequence packing
    encoded = tokenizer(
        batch["text"],
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
    )
    full_length_chunks = []
    for chunk_length, chunk_ids in zip(encoded["length"], encoded["input_ids"]):
        if chunk_length != context_length:
            # Too short to fill the context window; skipped instead of padded.
            continue
        full_length_chunks.append(chunk_ids)
    return {"input_ids": full_length_chunks}

# The tokenized corpus is cached on disk: build and save it when the cache
# directory is missing, otherwise just load the saved copy.
if not Path("tokenized-wiki-ds.hf").exists():
    tokenized_ds = dataset.map(
        tokenize,
        batched=True,
        remove_columns=dataset["train"].column_names,
    )
    tokenized_ds.save_to_disk("tokenized-wiki-ds.hf")
else:
    tokenized_ds = load_from_disk("tokenized-wiki-ds.hf")
tokenized_ds

DatasetDict({
    test: Dataset({
        features: ['input_ids'],
        num_rows: 1184
    })
    train: Dataset({
        features: ['input_ids'],
        num_rows: 509994
    })
    validation: Dataset({
        features: ['input_ids'],
        num_rows: 1037
    })
})

In [4]:
from torch import nn

class PositionalEmbedding(nn.Module):
    """Fixed (non-learned) sinusoidal positional embedding.

    For position ``p`` and pair index ``i`` the output holds
    ``sin(p / 10_000 ** (2 * i / embedding_dim))`` in channel ``2 * i`` and the
    matching cosine in channel ``2 * i + 1``.

    Args:
        embedding_dim: Number of channels produced per position. Odd values
            are supported: the final (unpaired) channel is a sine channel.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        # One wavelength scale per (sin, cos) channel pair. Kept as a plain
        # tensor attribute (not a parameter or buffer) so the module's
        # state_dict stays empty, exactly as before.
        self.numerators = 10_000 ** (
            torch.arange(start=0, end=embedding_dim, step=2).float() / embedding_dim
        )

    def forward(self, input_ids: torch.Tensor):
        """Return positional embeddings matching ``input_ids``.

        Args:
            input_ids: Tensor of shape ``(batch_size, seq_len)``; only its
                shape and device are used.

        Returns:
            Tensor of shape ``(batch_size, seq_len, embedding_dim)`` on the
            same device as ``input_ids`` (an expanded view shared across the
            batch dimension), computed without gradient tracking.
        """
        with torch.no_grad():
            positions = torch.arange(
                input_ids.shape[1],
                device=input_ids.device,
            ).float()
            inverse_frequencies = 1 / self.numerators.to(input_ids.device)
            # Outer product: one row per position, one column per channel pair.
            angles = positions.unsqueeze(1) * inverse_frequencies.unsqueeze(0)
            # Interleave sin/cos so even channels are sines and odd ones cosines.
            # `flatten` (unlike `view(len, -1)`) also copes with seq_len == 0.
            interleaved = torch.stack(
                [torch.sin(angles), torch.cos(angles)], dim=-1
            ).flatten(start_dim=1)
            # An odd embedding_dim produces one surplus cosine channel; trim it
            # so the result can always be added to the token embedding.
            embeddings = interleaved[:, : self.embedding_dim]
            return embeddings.unsqueeze(0).expand(input_ids.shape[0], -1, -1)


# Hand-wired, untrained forward pass: build each component separately and
# check that the pipeline yields one next-token prediction per prompt.
token_embedder = nn.Embedding(
    num_embeddings=tokenizer.vocab_size,
    embedding_dim=512,
    device=device,
)
positional_embedder = PositionalEmbedding(embedding_dim=512)
transformer_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, device=device)
transformer = nn.TransformerEncoder(transformer_layer, num_layers=6)
decoder = nn.Linear(512, tokenizer.vocab_size, device=device)

src = ["Hi, my name", "The United States of"]
tokenized = tokenizer(src, return_tensors="pt").to(device)
embedded = token_embedder(tokenized["input_ids"]) + positional_embedder(tokenized["input_ids"])
transformed = transformer(
    # Transformer expects (seq_len, batch_size, features)
    embedded.transpose(0, 1),
    # Additive causal mask so a position cannot attend to later positions.
    mask=nn.Transformer.generate_square_subsequent_mask(
        tokenized["input_ids"].shape[1],
        device=device,
    ),
    # Skipping is_causal since seems troublesome: https://github.com/pytorch/pytorch/issues/96941
)
# Back to (batch_size, seq_len, features)
logits = decoder(transformed.transpose(0, 1))
# Most likely token after the last position of each prompt.
result = tokenizer.batch_decode(logits[:, -1, :].argmax(dim=-1))
result



[' networks', ' networks']

In [5]:
import torch.nn.functional as F

class MyGPT(nn.Module):
    """Causally-masked transformer language model.

    Token embeddings plus sinusoidal positional embeddings are passed through
    a stack of ``nn.TransformerEncoderLayer`` blocks under a square
    subsequent (causal) mask and projected back onto the vocabulary. The
    vocabulary size is read from the notebook-level ``tokenizer``.

    Args:
        d_model: Width of the embeddings and transformer layers.
        nhead: Number of attention heads per layer.
        num_layers: Number of stacked transformer layers.
        dim_feedforward: Hidden width of each layer's feed-forward sub-block.
        device: Device on which the learnable submodules are created.
    """

    def __init__(self, d_model, nhead, num_layers, dim_feedforward, device):
        super().__init__()
        self.device = device
        vocab_size = tokenizer.vocab_size
        self.token_embedder = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=d_model,
            device=device,
        )
        self.positional_embedder = PositionalEmbedding(embedding_dim=d_model)
        layer_template = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            device=device,
        )
        self.transformer = nn.TransformerEncoder(layer_template, num_layers=num_layers)
        self.decoder = nn.Linear(d_model, vocab_size, device=device)

    def forward(self, input_ids: torch.Tensor):
        """Compute next-token logits for every position.

        Args:
            input_ids: Token ids of shape ``(batch_size, seq_len)``.

        Returns:
            Logits of shape ``(batch_size, seq_len, vocab_size)``.
        """
        seq_len = input_ids.shape[1]
        causal_mask = nn.Transformer.generate_square_subsequent_mask(
            seq_len, device=input_ids.device
        )
        embedded = self.token_embedder(input_ids) + self.positional_embedder(input_ids)
        # Transformer expects (seq_len, batch_size, features)
        hidden = self.transformer(embedded.transpose(0, 1), mask=causal_mask)
        # Back to (batch_size, seq_len, features) before projecting to the vocabulary.
        return self.decoder(hidden.transpose(0, 1))

def nucleus_sample(probits: torch.Tensor, prob_threshold: float):
    """Sample one token id per row with nucleus (top-p) sampling.

    Tokens are ranked by probability and kept while the probability mass of
    the tokens ranked *before* them is below ``prob_threshold``; sampling then
    happens among the kept tokens in proportion to their probabilities.
    The most likely token is always kept, so ``prob_threshold <= 0`` degrades
    to greedy selection instead of failing on an all-zero distribution.

    Args:
        probits: Probabilities with the vocabulary on the last dimension
            (1-D or 2-D, as required by ``torch.multinomial``). Not modified.
        prob_threshold: Cumulative probability mass that defines the nucleus.

    Returns:
        Integer tensor of sampled vocabulary indices with a trailing
        dimension of size 1.
    """
    sorted_probs, sorted_indices = torch.sort(probits, dim=-1, descending=True)
    # Mass accumulated strictly before each token in the sorted order.
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    outside_nucleus = mass_before >= prob_threshold
    # Never drop the top-ranked token: otherwise a threshold <= 0 would zero
    # out every probability and torch.multinomial would raise.
    outside_nucleus[..., 0] = False
    nucleus_probs = sorted_probs.masked_fill(outside_nucleus, 0)
    # multinomial accepts unnormalized weights, so no renormalization is needed.
    sampled_sorted_idx = torch.multinomial(nucleus_probs, num_samples=1)
    # Map positions in the sorted order back to vocabulary indices.
    return sorted_indices.gather(dim=-1, index=sampled_sorted_idx)

def greedy_sample(probits: torch.Tensor):
    """Pick the highest-scoring index along the last dimension.

    Args:
        probits: Scores with the vocabulary on the last dimension.

    Returns:
        Integer tensor of argmax indices; the last dimension is kept with
        size 1.
    """
    return torch.argmax(probits, dim=-1, keepdim=True)

def stream(model, input_ids: torch.Tensor, max_length, prob_threshold: float, temperature: float):
    """Generate tokens autoregressively, yielding each new token id as an int.

    Args:
        model: Callable mapping token ids of shape ``(batch, seq_len)`` to
            logits of shape ``(batch, seq_len, vocab)``.
        input_ids: Prompt token ids. The batch size must be 1 because each
            sampled id is converted with ``.item()``. Not modified.
        max_length: Number of tokens to generate.
        prob_threshold: Nucleus (top-p) mass passed to ``nucleus_sample``.
        temperature: Divisor applied to the logits before the softmax. A
            value ``<= 0`` selects greedy decoding (``greedy_sample``) rather
            than dividing by zero.

    Yields:
        The id of each newly generated token, in generation order.
    """
    # TODO: KV-cache to avoid quadratic computational complexity in `max_length`
    output_ids = input_ids.clone()
    for _ in range(max_length):
        # Only the forward pass and sampling run under no_grad. The yield is
        # kept outside the context so gradient tracking is not left disabled
        # in the caller while this generator is suspended.
        with torch.no_grad():
            next_token_logits = model(output_ids)[:, -1, :]
            # TODO: If repetitions resurface as an issue: repetition_penalty and no_repeat_ngram_size
            if temperature > 0:
                next_token_shaped_logits = next_token_logits / temperature
                next_token_probits = F.softmax(next_token_shaped_logits, dim=-1)
                next_token_id = nucleus_sample(probits=next_token_probits, prob_threshold=prob_threshold)
            else:
                # Zero-temperature limit: the argmax of the logits equals the
                # argmax of the probabilities, so no softmax is needed.
                next_token_id = greedy_sample(probits=next_token_logits)
            output_ids = torch.cat([output_ids, next_token_id], dim=1)
        yield next_token_id.item()

def print_stream(
    model, tokenizer, prompt: str, max_length=50, prob_threshold: float = 0.95, temperature: float = 1.0
):
    """Print ``prompt`` followed by a sampled continuation, token by token.

    Generation stops after ``max_length`` tokens or as soon as the tokenizer's
    end-of-sequence token is sampled (that token is not printed). A trailing
    newline is always emitted.

    Args:
        model: Language model passed to ``stream``.
        tokenizer: Tokenizer used to encode the prompt and decode each token.
        prompt: Text to continue.
        max_length: Maximum number of tokens to generate.
        prob_threshold: Nucleus (top-p) mass forwarded to ``stream``.
        temperature: Sampling temperature forwarded to ``stream``.
    """
    # Place the prompt on the device the model's weights live on, so the call
    # works for whichever model is passed in. Objects that expose no
    # parameters fall back to the notebook-level `device`.
    try:
        target_device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        target_device = device

    print(prompt, end="", flush=True)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(target_device)
    for token in stream(
        model=model, input_ids=input_ids, max_length=max_length, prob_threshold=prob_threshold, temperature=temperature
    ):
        if token == tokenizer.eos_token_id:
            break
        print(tokenizer.decode(token), end="", flush=True)
    print("", flush=True)


# Small, freshly initialized (untrained) instance; the next cells run it on
# `src` and sample from it. Ending the cell with `model` displays its layers.
model = MyGPT(d_model=512, nhead=8, num_layers=6, dim_feedforward=2048, device=device)
model

MyGPT(
  (token_embedder): Embedding(50257, 512)
  (positional_embedder): PositionalEmbedding()
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (decoder): Linear(in_features=512, out_features=50257, bias=True)
)

In [14]:
# Most likely next token for each prompt in `src`, given the model's current weights.
tokenizer.batch_decode(
    model(tokenizer(src, return_tensors="pt")["input_ids"].to(device))[:, -1, :].argmax(dim=-1)
)

['umbn', ' cyt']

In [15]:
# Ten independent samples from the same prompt, separated by blank lines.
for _ in range(10):
    print_stream(
        model=model,
        tokenizer=tokenizer,
        prompt="The United States of",
        max_length=50,
    )
    print("", flush=True)

The United States of disinteg Salvador McAuliffe pred Christian predicio SalvadorViceUnd cyt continuation catchy stimul Launch Launch Launch Launch Launch Launch Launch Launch Launch Launch Launch Launch Launch Launch Byzantine cyt208 Launch swallow Ride Ride Ride AurTue Aur Launch Launch Launch Launch Launch Launch Launch Launch Launch Launch walked
The United States of chall experien LaunchRel 345coolCookBuilt Hundredscoolér hectcoolCookBuilt Hundredscool mentor Launch Launch Launch Launch Launch Launch Launch Launch swallow cyt Launch walked cyt Morgan radiobleacher pepper Launch radioConsider radio cyt Launch Launch Launch Byzantine Launch refusal normal cartLeave Launch
The United States ofViceUnd nexusIB 104699 Lehakh experien Launch fentanyl batters Launch Launch Launch Launch Launch Launch Launch Launch Launch Launch Launch Launch closely cyt208 Launch Byzantine Launch Byzantine cyt208 Launch radio cyt Launch walked208 Launch Launch Launch Launch Launch Launch Launch Launch rad

In [16]:
# Peek at the first five raw training texts. Slicing the rows first means
# only those five examples are read, not the entire "text" column.
dataset["train"][:5]["text"]

['',
 ' = Valkyria Chronicles III = \n',
 '',
 ' Senjō no Valkyria 3 : <unk> Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . Released in January 2011 in Japan , it is the third game in the Valkyria series . Employing the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the " Nameless " , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " <unk> Raven " . \n',
 " The game began development in 2010 , carrying over a large portion of the work done on Valkyria Chronicles II . While it retained the standard features of the series , it also underwent multiple adjustments , such as making the game more forgiving

In [18]:
# Reuse the end-of-sequence token as the padding token, then show how
# `tokenizer.pad` batches examples of different lengths.
tokenizer.pad_token = tokenizer.eos_token
print(tokenizer.pad_token_id)
tokenized = [
    {"input_ids": tokenizer(text)["input_ids"]}
    for text in ["Hi", "my name is", "What, my name is"]
]
print(tokenized)
tokenizer.pad(tokenized, return_tensors="pt")

50256
[{'input_ids': [17250]}, {'input_ids': [1820, 1438, 318]}, {'input_ids': [2061, 11, 616, 1438, 318]}]


{'input_ids': tensor([[17250, 50256, 50256, 50256, 50256],
        [ 1820,  1438,   318, 50256, 50256],
        [ 2061,    11,   616,  1438,   318]]), 'attention_mask': tensor([[1, 0, 0, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1]])}

In [21]:
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
from tqdm import tqdm

# All hyperparams in this cell were copied from GPT paper (although we use a different dataset)
model = MyGPT(d_model=768, nhead=12, num_layers=12, dim_feedforward=3072, device=device)

# Wrap the model for data-parallel execution when several GPUs are visible.
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs")
    model = nn.DataParallel(model)

train_batch_size = 64
tokenizer.pad_token = tokenizer.eos_token


def pad_collate(examples):
    """Pad a list of tokenized examples into a single batch of PyTorch tensors."""
    return tokenizer.pad(examples, return_tensors="pt")


train_dl = DataLoader(
    tokenized_ds["train"],
    batch_size=train_batch_size,
    shuffle=True,
    collate_fn=pad_collate,
)
# No gradients are kept during validation, so a batch twice as large is used.
validation_dl = DataLoader(
    tokenized_ds["validation"],
    batch_size=train_batch_size * 2,
    shuffle=False,
    collate_fn=pad_collate,
)

optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=2.5e-4,
    betas=(0.9, 0.98),
    eps=1e-9,
)

num_epochs = 100
warmup_steps = 2000
total_steps = num_epochs * len(train_dl)
decay_steps = total_steps - warmup_steps

# Learning-rate schedule: linear warm-up over `warmup_steps` updates, then
# cosine annealing down to zero over the remaining `decay_steps` updates.
warmup_scheduler = LinearLR(
    optimizer,
    start_factor=1.0 / warmup_steps,
    end_factor=1.0,
    total_iters=warmup_steps,
)
cosine_scheduler = CosineAnnealingLR(optimizer, T_max=decay_steps, eta_min=0.0)
scheduler = SequentialLR(
    optimizer=optimizer,
    schedulers=[warmup_scheduler, cosine_scheduler],
    milestones=[warmup_steps],
)

# How often (in batches, counted within an epoch) to log the training loss,
# print a sample continuation, and run the validation pass.
log_frequency = 25
stream_frequency = 100
eval_frequency = 250

# NOTE(review): a later cell loads "model_weights.pth", but no visible cell
# writes it -- confirm where the checkpoint is saved (and, when the model is
# wrapped in nn.DataParallel, that the unwrapped module's weights are saved).
model.train()
train_losses = []
eval_losses = []
for epoch_i in range(num_epochs):
    for batch_i, batch in enumerate(train_dl):
        # Copy the batch to the device once; inputs and targets are the same
        # sequence shifted by one token (next-token prediction).
        batch_ids = batch.input_ids.to(device)
        X: torch.Tensor = batch_ids[:, :-1].contiguous()
        y: torch.Tensor = batch_ids[:, 1:].contiguous()
        logits = model(X)
        loss = nn.functional.cross_entropy(
            logits.view(-1, logits.shape[-1]),
            y.view(-1),
            ignore_index=tokenizer.pad_token_id,
        )
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

        if batch_i % log_frequency == 0:
            # Read the scalar back from the device a single time.
            loss_value = loss.item()
            train_losses.append(loss_value)
            print(f"Batch {batch_i + 1}/{len(train_dl)} in epoch {epoch_i + 1}/{num_epochs}: Loss {loss_value}")

        if batch_i % stream_frequency == 0:
            # Sample in eval mode so dropout does not perturb generation.
            model.eval()
            print_stream(model=model, tokenizer=tokenizer, prompt="In 1814, the", max_length=50)
            print("", flush=True)
            model.train()

        if batch_i % eval_frequency == 0:
            model.eval()
            # Accumulate the summed loss and the number of scored tokens so the
            # average is per token; a mean of per-batch means would give the
            # smaller final batch too much weight.
            val_loss_sum = 0.0
            val_token_count = 0
            with torch.no_grad():
                for validation_batch in validation_dl:
                    val_ids = validation_batch.input_ids.to(device)
                    X_val: torch.Tensor = val_ids[:, :-1].contiguous()
                    y_val: torch.Tensor = val_ids[:, 1:].contiguous()
                    val_logits = model(X_val)
                    val_loss_sum += nn.functional.cross_entropy(
                        val_logits.view(-1, val_logits.shape[-1]),
                        y_val.view(-1),
                        ignore_index=tokenizer.pad_token_id,
                        reduction="sum",
                    ).item()
                    # Targets equal to the ignored index do not contribute.
                    val_token_count += (y_val != tokenizer.pad_token_id).sum().item()
            avg_val_loss = val_loss_sum / max(val_token_count, 1)
            eval_losses.append(avg_val_loss)
            print(f"Avg. validation Loss {avg_val_loss}")
            model.train()
    print("=" * 40 + f"COMPLETED EPOCH {epoch_i + 1}/{num_epochs}" + "=" * 40)

Batch 1/7969 in epoch 1/100: Loss 11.03641414642334
In 1814, the holidaysceptiveceptiveceptiveceptiveceptiveceptiveceptive CSV CSV CSV CSV CSV CSV CSV CSV CSV CSV CSV CSV CSV drifting drifting drifting drifting drifting drifting drifting drifting drifting drifting CSV CSV CSV CSV CSV CSV CSV CSV CSV CSV CSV CSV doomed drifting drifting doomed drifting doomed drifting
Avg. validation Loss 11.022744178771973
Batch 26/7969 in epoch 1/100: Loss 10.633322715759277
Batch 51/7969 in epoch 1/100: Loss 10.01376724243164


KeyboardInterrupt: 

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# x positions: the i-th recorded loss is placed at i times its logging interval.
train_loss_batch_i = log_frequency * np.arange(len(train_losses))
eval_loss_batch_i = eval_frequency * np.arange(len(eval_losses))

fig, ax = plt.subplots()
ax.plot(train_loss_batch_i, train_losses, "--o", label="Train Loss")
ax.plot(eval_loss_batch_i, eval_losses, "--o", label="Eval Loss")
ax.set_xlabel("Batch number")
ax.set_ylabel("Loss")
ax.set_title("Loss over batches")
ax.legend()
ax.grid()
plt.show()

In [10]:
# Rebuild the architecture used for training and restore the saved weights.
model = MyGPT(d_model=768, nhead=12, num_layers=12, dim_feedforward=3072, device=device)
# weights_only=True limits unpickling to tensors and plain containers, which
# is all a state_dict needs and avoids running arbitrary pickled code.
model.load_state_dict(torch.load("model_weights.pth", map_location=device, weights_only=True))
# A freshly constructed module is in training mode; switch to eval so dropout
# is disabled while sampling.
model.eval()

for _ in range(10):
    print_stream(model=model, tokenizer=tokenizer, prompt="He first", max_length=50, prob_threshold=0.95, temperature=0.3)
    print("", flush=True)

He first met the young woman in the early 1960s , when she was still a teenager , and she became a professional in the early 1960s . She was soon joined by her sister , who became her husband 's manager and her mother . She was the

He first appeared in the media in the second issue of the magazine , in October , the first issue of the magazine , and in the first issue of the magazine , she was the first to be published in the magazine . She was also the first female editor of

He first appeared in the first season of the American television series The Simpsons , where he was played by John Frasier . Frasier 's role in the episode was based on the character of Homer , a former member of the show 's staff .

He first met with the British Army on 9 March 1917 , and was accepted into the Royal Australian Air Force ( RAAF ) on 10 March . Posted to the Middle East in April 1915 , he was posted to No. 1 Squadron , flying Sopwith Camels

He first appeared in the second season of the American 