In [None]:
import torch

# Fall back to the CPU when no GPU is present instead of failing at the first `.to(device)`.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the default generator and every CUDA device; the CUDA call is silently ignored
# when CUDA is unavailable.
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorForLanguageModeling
from torch.utils.data import DataLoader

# Token budget applied when truncating examples in tokenize_function.
MAX_LENGTH = 1024

# GPT-2 tokenizer; its end-of-text token doubles as the padding token so batches can be padded.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Rotten Tomatoes reviews (the train / validation / test splits are used below).
dataset = load_dataset("rotten_tomatoes")


def tokenize_function(examples):
    """Tokenize a batch of raw reviews for causal language modelling.

    Args:
        examples: batch dict passed by ``Dataset.map(batched=True)``;
            ``examples["text"]`` is a list of strings.

    Returns:
        The tokenizer output (``input_ids`` / ``attention_mask`` lists), unpadded and
        truncated to at most ``MAX_LENGTH`` tokens per example.

    Every token is kept. ``DataCollatorForLanguageModeling(mlm=False)`` copies
    ``input_ids`` into ``labels`` without shifting them, so the next-token shift must be
    applied where logits are compared with labels, not here. Dropping a token here would
    discard data and could leave one-token texts empty.
    """
    # Padding is deferred to the collator so each batch is padded only to its own longest row.
    return tokenizer(
        examples["text"],
        padding=False,
        truncation=True,
        max_length=MAX_LENGTH,
    )


# Tokenize every split; the raw "text" column is not needed afterwards.
tokenized = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

# Return only the model inputs, as torch tensors, when indexing any split.
tokenized.set_format(type="torch", columns=["input_ids", "attention_mask"])

# Causal-LM collator (mlm=False): pads input_ids / attention_mask per batch and adds
# `labels` copied from input_ids.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, mlm_probability=0.0)

BATCH_SIZE = 32

# Settings shared by all three loaders; only the shuffling differs.
_loader_kwargs = dict(batch_size=BATCH_SIZE, collate_fn=data_collator)
train_dataloader = DataLoader(tokenized["train"], shuffle=True, **_loader_kwargs)
valid_dataloader = DataLoader(tokenized["validation"], shuffle=False, **_loader_kwargs)
test_dataloader = DataLoader(tokenized["test"], shuffle=False, **_loader_kwargs)


In [None]:
import torch
from torch import nn


class TransformerBlock(nn.Module):
    """Post-norm causal self-attention block followed by a position-wise feed-forward layer.

    Args:
        embed_dim: model width; passed to ``nn.MultiheadAttention``, which requires it to
            be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention and feed-forward outputs
            before each residual addition.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        # Per-token Q/K/V projections applied before nn.MultiheadAttention (which also has
        # its own internal input projection). Input width is inferred on first forward.
        self.queries = nn.LazyLinear(out_features=embed_dim)
        self.keys = nn.LazyLinear(out_features=embed_dim)
        self.values = nn.LazyLinear(out_features=embed_dim)

        # batch_first=True: inputs and outputs are [batch, seq, embed_dim].
        self.att_block = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)

        # Position-wise feed-forward network with a 4x hidden expansion.
        self.feedforward = nn.Sequential(
            nn.LazyLinear(out_features=embed_dim * 4),
            nn.LeakyReLU(),
            nn.LazyLinear(out_features=embed_dim)
        )

        # One LayerNorm after each residual addition (post-norm).
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, key_padding_mask=None):
        """Apply causally-masked self-attention, then the feed-forward sub-layer.

        Args:
            x: float tensor [batch, seq, embed_dim].
            key_padding_mask: optional bool tensor [batch, seq]; True marks key positions
                (padding) that must not be attended to.

        Returns:
            Tensor with the same shape as ``x``.
        """
        seq_len = x.size(1)
        # True strictly above the diagonal: query position i may not attend to any key j > i.
        future_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1
        )

        q, k, v = self.queries(x), self.keys(x), self.values(x)

        # The masked result is the one used downstream; attention weights are not needed.
        att_output, _ = self.att_block(
            q, k, v,
            attn_mask=future_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        x = self.norm1(x + self.dropout(att_output))

        ff_output = self.feedforward(x)
        return self.norm2(x + self.dropout(ff_output))


In [None]:
import torch
from torch import nn


class Decoder(nn.Module):
    """Final linear projection from hidden states to vocabulary logits.

    Args:
        vocab_size: number of logits produced per position.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # Input width is inferred on the first forward call (LazyLinear).
        projection = nn.LazyLinear(out_features=vocab_size)
        # Wrapped in nn.Sequential so parameter names stay "decoder.0.weight/bias".
        self.decoder = nn.Sequential(projection)

    def forward(self, attended_sequence):
        """Project ``[..., hidden]`` activations to ``[..., vocab_size]`` logits."""
        return self.decoder(attended_sequence)

In [None]:
import torch
from torch import nn


class Model(nn.Module):
    """Decoder-only language model.

    Token embeddings plus learned absolute positional embeddings feed a stack of
    ``TransformerBlock`` modules, and a ``Decoder`` projects the result to vocabulary logits.

    Args:
        vocab_size: number of token ids (token-embedding rows and logit width).
        embed_dim: model width shared by the embeddings and the transformer blocks.
        num_heads: attention heads per ``TransformerBlock``.
        num_transformer_blocks: how many blocks are stacked.
        max_sequence_length: number of positional-embedding rows, i.e. the longest
            ``seq_length`` that ``forward`` accepts.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, num_transformer_blocks, max_sequence_length):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.positional_embedding = nn.Embedding(max_sequence_length, embed_dim)

        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(embed_dim=embed_dim, num_heads=num_heads)
            for _ in range(num_transformer_blocks)
        ])

        self.decoder = Decoder(vocab_size=vocab_size)

    def forward(self, input_ids, attention_mask=None):
        """Compute per-position vocabulary logits.

        Args:
            input_ids: integer tensor [batch_size, seq_length] of token ids.
            attention_mask: optional tensor [batch_size, seq_length]; positions equal to 0
                are treated as padding and hidden from attention. ``None`` means no padding.

        Returns:
            Logits tensor [batch_size, seq_length, vocab_size].

        Raises:
            ValueError: if ``seq_length`` exceeds ``max_sequence_length``.
        """
        _, seq_length = input_ids.shape
        max_positions = self.positional_embedding.num_embeddings
        if seq_length > max_positions:
            # Fail with a readable message; an out-of-range embedding lookup on CUDA
            # typically surfaces only as an opaque device-side assert.
            raise ValueError(
                f"sequence length {seq_length} exceeds max_sequence_length {max_positions}"
            )

        token_embeds = self.embedding(input_ids)  # [batch_size, seq_length, embed_dim]
        positions = torch.arange(seq_length, device=input_ids.device)
        # [seq_length, embed_dim] positional embeddings broadcast across the batch.
        x = token_embeds + self.positional_embedding(positions)

        # The padding mask is identical for every block, so build it once (True = ignore key).
        key_padding_mask = None if attention_mask is None else (attention_mask == 0)
        for block in self.transformer_blocks:
            x = block(x, key_padding_mask=key_padding_mask)

        return self.decoder(x)


In [None]:
# Model hyper-parameters. No trailing commas: `EMBED_DIM = 768,` would bind the tuple (768,),
# which would then also appear in the checkpoint filename built from these names.
VOCAB_SIZE = len(tokenizer.get_vocab())
EMBED_DIM = 768
NUM_HEADS = 12
NUM_TRANSFORMER_BLOCKS = 12
# The positional table must cover the longest tokenized input (truncated w.r.t. MAX_LENGTH).
MAX_SEQ_LEN = MAX_LENGTH

model = Model(
    vocab_size=VOCAB_SIZE,
    embed_dim=EMBED_DIM,
    num_heads=NUM_HEADS,
    num_transformer_blocks=NUM_TRANSFORMER_BLOCKS,
    max_sequence_length=MAX_SEQ_LEN,
).to(device=device)

# LazyLinear layers hold uninitialized parameters until their first forward call. Run one
# tiny dry-run batch on the target device so every parameter exists before an optimizer
# (or state_dict / parameter count) touches them.
with torch.no_grad():
    model(
        torch.zeros(1, 1, dtype=torch.long, device=device),
        torch.ones(1, 1, dtype=torch.long, device=device),
    )

In [None]:
import torch
from torch import nn
from tqdm import tqdm

EPOCHS = 15
LEARNING_RATE = 0.00001

# ignore_index defaults to -100, the value the collator writes into `labels` at padding.
LOSS_FN = nn.CrossEntropyLoss()
OPTIMIZER = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

train_losses = []
validation_losses = []
test_losses = []


def _lm_step(batch):
    """Run the model on one collated batch and compute the next-token loss.

    The collator sets ``labels`` equal to ``input_ids`` (unshifted), so the logits at
    position t are scored against the token at position t + 1. Without this shift a
    causal model can lower the loss by copying its current input token.

    Returns:
        (logits, targets, loss): logits [batch, seq-1, vocab], targets [batch, seq-1]
        with -100 at positions excluded from the loss, and the scalar loss.
    """
    model_input = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    labels = batch["labels"].to(device)

    output = model(model_input, attention_mask)
    logits = output[:, :-1, :]
    targets = labels[:, 1:]
    # CrossEntropyLoss expects the class dimension second: [batch, vocab, seq].
    loss = LOSS_FN(logits.permute(0, 2, 1), targets)
    return logits, targets, loss


def _decode_valid(token_ids, targets):
    """Decode the entries of `token_ids` at positions where `targets` is not -100 (padding)."""
    return tokenizer.decode(token_ids[targets != -100].tolist(), skip_special_tokens=False)


for epoch in range(EPOCHS):
    print(f"E {(epoch + 1)}/{EPOCHS} - {((epoch + 1) / EPOCHS) * 100:.3f}%")

    model.train()
    train_loss = 0
    for batch in tqdm(train_dataloader, desc=f"\tTrain"):
        _, _, loss = _lm_step(batch)
        train_loss += loss.item()
        loss.backward()
        OPTIMIZER.step()
        OPTIMIZER.zero_grad()

    train_loss /= len(train_dataloader)
    train_losses.append(train_loss)
    print(f"\tTrain Loss: {train_loss:.5f}")

    model.eval()
    validation_loss = 0
    with torch.no_grad():
        for index, batch in enumerate(tqdm(valid_dataloader, desc=f"\tValidation")):
            logits, targets, loss = _lm_step(batch)
            validation_loss += loss.item()

            if index == 0:
                # Greedy next-token guesses for the first validation example, followed by
                # the true continuation they are scored against.
                predictions = torch.argmax(logits, dim=-1)
                print(_decode_valid(predictions[0], targets[0]))
                print(_decode_valid(targets[0], targets[0]))

    # Average over the number of validation batches.
    validation_loss /= len(valid_dataloader)
    validation_losses.append(validation_loss)
    print(f"\tValidation Loss: {validation_loss:.5f}")

model.eval()

with torch.no_grad():
    for index, batch in enumerate(tqdm(test_dataloader, desc=f"\tTest")):
        logits, targets, loss = _lm_step(batch)
        # One entry per test batch (not per epoch).
        test_losses.append(loss.item())

        if index == 0:
            predictions = torch.argmax(logits, dim=-1)
            for seq, seq_targets in zip(predictions, targets):
                print(_decode_valid(seq, seq_targets))

In [None]:
import torch
from datetime import datetime

from pathlib import Path

# Checkpoints are written under ./models; create it so torch.save does not fail on a
# fresh checkout where the directory does not exist yet.
Path("models").mkdir(parents=True, exist_ok=True)

# Unique filename using timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"models/vocabsize_{VOCAB_SIZE}_embeddim_{EMBED_DIM}_numheads_{NUM_HEADS}_numtransformerblocks_{NUM_TRANSFORMER_BLOCKS}_maxseqlen_{MAX_SEQ_LEN}_timestamp_{timestamp}.pt"

torch.save(model.state_dict(), filename)

print(f"Model saved as {filename}")
# Reload onto the CPU only to confirm the file round-trips; just the keys are inspected,
# so there is no need to allocate a second copy of the weights on the GPU.
state_dict = torch.load(filename, map_location="cpu")
print(state_dict.keys())

In [None]:
import matplotlib.pyplot as plt

fig, (loss_ax, test_ax) = plt.subplots(1, 2, figsize=(10, 5))

# train_losses / validation_losses hold one average per epoch; number epochs from 1 to
# match the "E k/EPOCHS" progress messages.
epochs = range(1, len(train_losses) + 1)
loss_ax.plot(epochs, train_losses, label="Train loss")
loss_ax.plot(epochs, validation_losses, label="Validation loss")
loss_ax.set_title("Loss during training")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Cross-entropy loss")
loss_ax.legend()

# test_losses holds one value per test batch, not per epoch.
test_ax.plot(test_losses, label="Test loss")
test_ax.set_title("Post training loss")
test_ax.set_xlabel("Test batch")
test_ax.set_ylabel("Cross-entropy loss")
test_ax.legend()

fig.tight_layout()

plt.show()