In [62]:
import torch

# Prefer the GPU when one is present, but fall back to the CPU so the notebook
# still runs end-to-end on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the default generator and every CUDA generator for reproducible runs.
# The CUDA call is documented as a silent no-op when CUDA is unavailable.
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

In [63]:
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader

# Context length of the model, in tokens.
MAX_LENGTH = 64

# Rotten Tomatoes reviews; the train, validation and test splits are used below.
dataset = load_dataset("rotten_tomatoes")

# GPT-2 ships without a padding token, so the end-of-sequence token doubles as padding.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token


def tokenize_function(examples):
    """Tokenize a batch of texts into next-token-prediction inputs and labels.

    Every text is padded/truncated to ``MAX_LENGTH + 1`` tokens so it can be split
    into inputs (all tokens but the last) and labels (all tokens but the first),
    each ``MAX_LENGTH`` long.

    :param examples: Batch mapping with a ``"text"`` entry holding a list of strings.
    :return: The tokenizer output with ``input_ids`` and ``attention_mask`` trimmed to
        ``MAX_LENGTH`` entries and an added ``labels`` entry. Label positions that
        correspond to padding hold ``-100`` instead of a token id.
    """
    # Default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``: targets with this
    # value do not contribute to the loss or its gradient.
    ignore_index = -100

    # One extra token so that shifting by one still leaves MAX_LENGTH positions.
    tokenized = tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH + 1,
    )
    full_ids = tokenized["input_ids"]
    full_masks = tokenized["attention_mask"]

    tokenized["input_ids"] = [ids[:-1] for ids in full_ids]
    tokenized["attention_mask"] = [mask[:-1] for mask in full_masks]

    # The pad token is the EOS token, so padding cannot be recognised by id; use the
    # attention mask of the shifted targets to blank out padded label positions.
    # Otherwise the loss is dominated by the trivial task of predicting padding.
    tokenized["labels"] = [
        [token if keep else ignore_index for token, keep in zip(ids[1:], mask[1:])]
        for ids, mask in zip(full_ids, full_masks)
    ]
    return tokenized



# Tokenize every split in batches; the raw "text" column is dropped afterwards.
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"],
)

# Have each split return torch tensors for the three model-facing columns.
for split in ("train", "validation", "test"):
    tokenized_dataset[split].set_format(
        type="torch",
        columns=["input_ids", "attention_mask", "labels"],
    )

# One DataLoader per split; only the training split is shuffled.
BATCH_SIZE = 16

train_dataloader = DataLoader(
    tokenized_dataset["train"],
    shuffle=True,
    batch_size=BATCH_SIZE,
)
validation_dataloader = DataLoader(
    tokenized_dataset["validation"],
    batch_size=BATCH_SIZE,
)
test_dataloader = DataLoader(
    tokenized_dataset["test"],
    batch_size=BATCH_SIZE,
)

# Your DataLoaders are now ready for use in your training loop for next-token prediction.


In [64]:
import torch
from torch import nn

class TransformerBlock(nn.Module):
    """Post-norm transformer block: self-attention followed by a feedforward network.

    :param embed_dim: Width of the representations entering and leaving the block.
    :param num_heads: Number of attention heads passed to ``nn.MultiheadAttention``.
    :param dropout: Dropout probability applied to the output of both sub-layers.
    :param causal: When True (the default) a position may only attend to itself and
        to earlier positions, which next-token prediction requires: without the mask
        each position can read the very token it is asked to predict. Pass False
        for bidirectional attention.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0, causal=True):
        super().__init__()
        self.causal = causal

        # Linear projections for Q, K, V
        self.queries = nn.LazyLinear(out_features=embed_dim)
        self.keys = nn.LazyLinear(out_features=embed_dim)
        self.values = nn.LazyLinear(out_features=embed_dim)

        # Multi-head attention block (batch_first=True means inputs shape [batch, seq, embed_dim])
        self.att_block = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)

        # Feedforward network
        self.feedforward = nn.Sequential(
            nn.LazyLinear(out_features=embed_dim * 4),
            nn.LeakyReLU(),
            nn.LazyLinear(out_features=embed_dim)
        )

        # Layer normalization layers
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        # Dropout layer for regularization
        self.dropout = nn.Dropout(dropout)

    def forward(self, token_sequence):
        """Apply the attention and feedforward sub-layers to a sequence.

        :param token_sequence: Tensor whose last two dimensions are
            ``[seq_length, embed_dim]`` (``[batch, seq, embed_dim]`` for batched input).
        :return: Tensor with the same leading dimensions and ``embed_dim`` features.
        """
        # Self-attention sub-layer with residual connection and normalization.
        residual = token_sequence

        # Linear projections for Q, K, V.
        q = self.queries(token_sequence)
        k = self.keys(token_sequence)
        v = self.values(token_sequence)

        # Boolean mask where True marks a (query, key) pair that must NOT attend:
        # everything strictly above the diagonal, i.e. every future position.
        attn_mask = None
        if self.causal:
            seq_length = token_sequence.size(-2)
            attn_mask = torch.triu(
                torch.ones(seq_length, seq_length, dtype=torch.bool, device=token_sequence.device),
                diagonal=1,
            )

        # Apply multi-head attention. The attention weights are never used, so
        # skip computing them.
        att_output, _ = self.att_block(q, k, v, attn_mask=attn_mask, need_weights=False)
        att_output = self.dropout(att_output)

        # Add residual and normalize.
        x = self.norm1(residual + att_output)

        # Feedforward sub-layer with residual connection and normalization.
        residual2 = x
        ff_output = self.feedforward(x)
        ff_output = self.dropout(ff_output)
        x = self.norm2(residual2 + ff_output)

        return x


In [65]:
import torch
from torch import nn


class Decoder(nn.Module):
    """Output head that maps each hidden state to one logit per vocabulary entry.

    :param vocab_size: Number of output features produced for every position.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # The input width is inferred by the lazy layer on the first forward pass.
        projection = nn.LazyLinear(out_features=vocab_size)
        self.decoder = nn.Sequential(projection)

    def forward(self, attended_sequence):
        """Project ``attended_sequence`` along its last dimension to ``vocab_size`` logits."""
        return self.decoder(attended_sequence)

In [66]:
import torch
from torch import nn

class Model(nn.Module):
    """Language model: token + position embeddings, transformer blocks, vocabulary head.

    :param vocab_size: Number of distinct token ids, also the width of the output logits.
    :param embed_dim: Width of the embeddings and of every transformer block.
    :param num_heads: Attention heads per transformer block.
    :param num_transformer_blocks: How many transformer blocks are stacked.
    :param max_sequence_length: Number of rows in the positional embedding table.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, num_transformer_blocks, max_sequence_length):
        super().__init__()

        # Learned lookup tables for token identity and absolute position.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.positional_embedding = nn.Embedding(max_sequence_length, embed_dim)

        # Identically configured blocks, applied one after another in forward().
        blocks = []
        for _ in range(num_transformer_blocks):
            blocks.append(TransformerBlock(embed_dim=embed_dim, num_heads=num_heads))
        self.transformer_blocks = nn.ModuleList(blocks)

        self.decoder = Decoder(vocab_size=vocab_size)

    def forward(self, tokenized_sequence):
        """
        :param tokenized_sequence: Tensor of size [batch_size, sequence_length] with token IDs.
        :return: Logits of shape [batch_size, sequence_length, vocab_size].
        """
        batch_size, seq_length = tokenized_sequence.shape

        # Token embeddings: [batch_size, seq_length, embed_dim].
        token_embeds = self.embedding(tokenized_sequence)

        # Position ids 0..seq_length-1, repeated for every sequence in the batch.
        position_ids = torch.arange(seq_length, device=token_embeds.device)
        position_ids = position_ids[None, :].expand(batch_size, seq_length)

        hidden = token_embeds + self.positional_embedding(position_ids)

        # Run the stack of transformer blocks in order.
        for layer in self.transformer_blocks:
            hidden = layer(hidden)

        # Vocabulary logits for every position.
        return self.decoder(hidden)


In [67]:
# Build the language model sized to the tokenizer's vocabulary, then move it to the
# selected device.
model = Model(
    max_sequence_length=MAX_LENGTH,
    num_transformer_blocks=10,
    num_heads=8,
    embed_dim=768,
    vocab_size=len(tokenizer.get_vocab()),
)
model = model.to(device)

In [None]:
import torch
from torch import nn
from tqdm import tqdm

# Training configuration.
EPOCHS = 10
LEARNING_RATE = 0.001

LOSS_FN = nn.CrossEntropyLoss()
OPTIMIZER = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Per-epoch mean losses for train/validation; per-batch losses for the final test pass.
train_losses = []
validation_losses = []
test_losses = []

for epoch in range(EPOCHS):
    print(f"E {(epoch + 1)}/{EPOCHS} - {((epoch + 1) / EPOCHS) * 100:.3f}%")

    # ---- Training pass ----
    train_loss = 0

    model.train()

    for batch in tqdm(train_dataloader, desc="\tTrain"):
        model_input = batch["input_ids"].to(device)
        expected_output = batch["labels"].to(device)

        output = model(model_input)
        # CrossEntropyLoss expects the class dimension second: [batch, vocab, seq].
        loss = LOSS_FN(output.permute(0, 2, 1), expected_output)
        train_loss += loss.item()
        loss.backward()
        OPTIMIZER.step()
        OPTIMIZER.zero_grad()

    train_loss /= len(train_dataloader)
    train_losses.append(train_loss)

    # ---- Validation pass ----
    validation_loss = 0

    # Evaluation mode disables dropout; model.train() at the top of the next epoch
    # switches it back.
    model.eval()

    with torch.no_grad():
        for index, batch in enumerate(tqdm(validation_dataloader, desc="\tValidation")):
            model_input = batch["input_ids"].to(device)
            expected_output = batch["labels"].to(device)

            output = model(model_input)
            loss = LOSS_FN(output.permute(0, 2, 1), expected_output)
            validation_loss += loss.item()

            # Show the greedy (argmax) predictions for the first batch only, as a
            # quick qualitative check once per epoch.
            if index == 0:
                model_letters = torch.argmax(output, dim=-1)
                for seq in model_letters:
                    decoded_text = tokenizer.decode(seq.tolist(), skip_special_tokens=False)
                    print(decoded_text)

    # Average over the validation batches (not the training batches) so that the
    # value is comparable with train_loss.
    validation_loss /= len(validation_dataloader)
    validation_losses.append(validation_loss)


# ---- Final test pass: one loss value per batch, plus the decoded predictions ----
model.eval()

with torch.no_grad():
    for batch in tqdm(test_dataloader, desc="\tTest"):
        model_input = batch["input_ids"].to(device)
        expected_output = batch["labels"].to(device)

        output = model(model_input)
        loss = LOSS_FN(output.permute(0, 2, 1), expected_output)

        test_losses.append(loss.item())

        model_letters = torch.argmax(output, dim=-1)
        for seq in model_letters:
            decoded_text = tokenizer.decode(seq.tolist(), skip_special_tokens=False)
            print(decoded_text)



E 1/10 - 10.000%


	Train:  13%|█▎        | 70/534 [00:04<00:34, 13.48it/s]

In [None]:
import matplotlib.pyplot as plt

# Two panels: loss curves recorded during training (one point per epoch) and the
# per-batch losses from the final pass over the test split.
# A 1x2 grid already comes back as a 1-D array of axes, so no flattening is needed.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

axes[0].plot(train_losses, label="Train loss")
axes[0].plot(validation_losses, label="Validation loss")
axes[0].set_title("Loss during training")
axes[0].set_xlabel("Epoch")
axes[0].set_ylabel("Cross-entropy loss")
axes[0].legend()

axes[1].plot(test_losses, label="Test loss")
axes[1].set_title("Post training loss")
axes[1].set_xlabel("Test batch")
axes[1].set_ylabel("Cross-entropy loss")
axes[1].legend()

plt.tight_layout()

plt.show()