In [None]:
import torch

# Seed the CPU generator and every CUDA generator so runs are repeatable.
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

# This notebook targets a GPU; the automatic CPU fallback is commented out:
# device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cuda"

In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorForLanguageModeling
from torch.utils.data import DataLoader

# 1. Fetch the Rotten Tomatoes dataset from the Hugging Face hub.
dataset = load_dataset("rotten_tomatoes")

# 2. GPT-2 tokenizer. EOS is reused as the padding token so that the collator
#    below can pad variable-length batches.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Token budget per example; tokenize_function derives its truncation length from it.
MAX_LENGTH = 128


def tokenize_function(examples):
    """Tokenize a batch of raw texts, truncating but not padding.

    Each sequence is truncated to ``MAX_LENGTH + 1`` tokens and returned
    UNSHIFTED. ``CausalDataCollator`` later splits every padded row into inputs
    ``ids[:, :-1]`` and labels ``ids[:, 1:]``, so the model sees at most
    ``MAX_LENGTH`` input positions.

    The shift must happen exactly once. The previous version also dropped the
    last token here, which shifted twice: the final real token of every text
    (typically the closing punctuation) was never seen or predicted, and a
    1-token text became an all-padding row.

    Args:
        examples: batch dict from ``Dataset.map(batched=True)`` with a ``"text"``
            list of strings.

    Returns:
        The tokenizer's batch output (``input_ids`` and ``attention_mask`` lists).
    """
    return tokenizer(
        examples["text"],
        padding=False,  # padding is done per batch by the collator
        truncation=True,
        max_length=MAX_LENGTH + 1,  # +1: the collator's shift consumes one token
    )


# Tokenize every split in batches; the raw "text" column is dropped afterwards.
tokenized = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"],
)

# 3. Expose only the model inputs, as torch tensors, to the DataLoaders.
for split in ("train", "validation", "test"):
    tokenized[split].set_format(type="torch", columns=["input_ids", "attention_mask"])


# 4. Collator: pad dynamically, then build shifted next-token targets.
class CausalDataCollator:
    """Collate tokenized examples into a shifted next-token-prediction batch.

    ``DataCollatorForLanguageModeling(mlm=False)`` pads ``input_ids`` and
    ``attention_mask`` to the longest example in the batch. The padded batch is
    then shifted by one position:

    * ``input_ids`` / ``attention_mask`` -> columns ``[:, :-1]``
    * ``labels`` -> ``input_ids[:, 1:]``, set to -100 wherever the target
      position is padding, so ``nn.CrossEntropyLoss`` (default
      ``ignore_index=-100``) skips it.

    The base collator's own ``labels`` are overwritten.

    Args:
        tokenizer: tokenizer used for padding (must define a pad token).
        **kwargs: forwarded to ``DataCollatorForLanguageModeling``.
    """

    def __init__(self, tokenizer, **kwargs):
        self.base = DataCollatorForLanguageModeling(
            tokenizer=tokenizer, mlm=False, **kwargs
        )

    def __call__(self, examples):
        batch = self.base(examples)
        inputs = batch["input_ids"]
        masks = batch["attention_mask"]

        # Shift for causal LM: position t predicts token t + 1.
        batch["input_ids"] = inputs[:, :-1]
        batch["attention_mask"] = masks[:, :-1]

        # Ignore padded targets using the attention mask, not pad_token_id.
        # The pad token doubles as EOS in this notebook, so an id-based test
        # would also hide genuine EOS targets. It also removes the dependency
        # on the global `tokenizer`.
        batch["labels"] = inputs[:, 1:].masked_fill(masks[:, 1:] == 0, -100)
        return batch


data_collator = CausalDataCollator(tokenizer=tokenizer, mlm_probability=0.0)

BATCH_SIZE = 128
# One loader per split, built in train/validation/test order; only the
# training split is shuffled.
train_dataloader, valid_dataloader, test_dataloader = (
    DataLoader(
        tokenized[split_name],
        batch_size=BATCH_SIZE,
        shuffle=(split_name == "train"),
        collate_fn=data_collator,
    )
    for split_name in ("train", "validation", "test")
)


In [None]:
import torch
from torch import nn


class DecoderBlocks(nn.Module):
    """Stack of GPT-style pre-LN decoder blocks.

    Each block is an ``nn.TransformerEncoderLayer`` run with a causal attention
    mask, so position i can attend only to positions <= i.

    Args:
        embed_dim: model width (``d_model``).
        num_heads: attention heads per block (must divide ``embed_dim``).
        num_blocks: number of stacked blocks; each block has its own weights.
        dropout: dropout probability inside every block.
        max_seq_len: longest sequence ``forward`` accepts (size of the
            pre-computed causal mask).
    """

    def __init__(self, embed_dim, num_heads, num_blocks, dropout=0.1, max_seq_len=1024):
        super().__init__()
        # Construct a FRESH layer for every block. Repeating a single layer
        # object inside the ModuleList would make all blocks share one set of
        # weights, silently reducing the model to a single weight-tied block.
        self.blocks = nn.ModuleList(
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dim_feedforward=4 * embed_dim,
                dropout=dropout,
                activation="gelu",
                batch_first=True,
                norm_first=True,  # pre-LN like GPT-2
            )
            for _ in range(num_blocks)
        )

        # Boolean mask, True strictly above the diagonal: True = "may not attend".
        causal = torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=1).bool()
        self.register_buffer("causal_mask", causal)

    def forward(self, x, padding_mask=None):
        """Apply every block with causal + key-padding masking.

        Args:
            x: [batch, seq_len, embed_dim] input embeddings.
            padding_mask: optional [batch, seq_len] bool mask, True for pad positions.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_seq_len``.
        """
        seq_len = x.size(1)
        max_seq_len = self.causal_mask.size(0)
        if seq_len > max_seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len={max_seq_len}"
            )
        # Top-left [seq_len, seq_len] slice of the pre-computed bool mask.
        causal_mask = self.causal_mask[:seq_len, :seq_len]

        for block in self.blocks:
            x = block(
                src=x,
                src_mask=causal_mask,  # causal mask
                src_key_padding_mask=padding_mask,  # pad mask
            )
        return x


In [None]:
import torch
from torch import nn


class Decoder(nn.Module):
    """Linear head projecting hidden states onto vocabulary logits.

    The input width is inferred on the first forward call (``nn.LazyLinear``).

    Args:
        vocab_size: number of output logits per position.
    """

    def __init__(self, vocab_size):
        super().__init__()
        projection = nn.LazyLinear(out_features=vocab_size)
        # The Sequential wrapper keeps parameter names as ``decoder.0.*``,
        # which existing checkpoints rely on.
        self.decoder = nn.Sequential(projection)

    def forward(self, attended_sequence):
        """Return logits with the last dimension mapped to ``vocab_size``."""
        return self.decoder(attended_sequence)

In [None]:
import torch
from torch import nn


class Model(nn.Module):
    """Decoder-only language model.

    The architecture is: token embeddings plus learned positional embeddings,
    then a stack of causal ``DecoderBlocks``, then a linear vocabulary head
    (``Decoder``).

    Args:
        vocab_size: number of token ids and of output logits.
        embed_dim: hidden width.
        num_heads: attention heads per block.
        num_transformer_blocks: number of decoder blocks.
        max_sequence_length: longest input covered by the positional table.
        dropout: dropout inside the decoder blocks (default 0.0, as before).
    """

    def __init__(self, vocab_size, embed_dim, num_heads, num_transformer_blocks, max_sequence_length,
                 dropout=0.0):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.positional_embedding = nn.Embedding(max_sequence_length, embed_dim)

        self.transformer_blocks = DecoderBlocks(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_blocks=num_transformer_blocks,
            dropout=dropout,
            max_seq_len=max_sequence_length,
        )

        self.decoder = Decoder(vocab_size=vocab_size)

    def forward(self, input_ids, attention_mask=None):
        """Compute next-token logits.

        Args:
            input_ids: [batch, seq_len] token ids.
            attention_mask: optional [batch, seq_len] tensor, 0 at padding
                positions. ``None`` means no padding.

        Returns:
            [batch, seq_len, vocab_size] logits.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_sequence_length``. Without
                this check the positional lookup fails with an opaque index
                error (a device-side assert on CUDA).
        """
        seq_length = input_ids.size(1)
        max_positions = self.positional_embedding.num_embeddings
        if seq_length > max_positions:
            raise ValueError(
                f"sequence length {seq_length} exceeds max_sequence_length={max_positions}"
            )

        token_embeds = self.embedding(input_ids)  # [batch, seq_len, embed_dim]
        positions = torch.arange(seq_length, device=input_ids.device)
        # [seq_len, embed_dim] broadcasts across the batch dimension.
        x = token_embeds + self.positional_embedding(positions)

        pad_mask = None if attention_mask is None else attention_mask == 0
        x = self.transformer_blocks(x, pad_mask)

        # Projection to vocabulary logits.
        return self.decoder(x)


In [None]:
# Model hyper-parameters. They also appear in the checkpoint filename written by
# save_model. No trailing commas: `X = 32,` silently makes X the 1-tuple (32,),
# which produced filenames like "vocabsize_(50257,)".
VOCAB_SIZE = len(tokenizer.get_vocab())
EMBED_DIM = 32
NUM_HEADS = 1
NUM_TRANSFORMER_BLOCKS = 1
# The collator yields at most MAX_LENGTH input positions, so the positional
# table must cover exactly that many.
MAX_SEQ_LEN = MAX_LENGTH

model = Model(
    vocab_size=VOCAB_SIZE,
    embed_dim=EMBED_DIM,
    num_heads=NUM_HEADS,
    num_transformer_blocks=NUM_TRANSFORMER_BLOCKS,
    max_sequence_length=MAX_SEQ_LEN,
).to(device=device)

In [None]:
# import torch
#
# model_path = "models/vocabsize_(50257,)_embeddim_(768,)_numheads_(12,)_numtransformerblocks_(12,)_maxseqlen_(1024,)_E660_timestamp_20250406_070150.pt"
#
# state_dict = torch.load(model_path, map_location=device)
# model.load_state_dict(state_dict)

In [None]:
import torch
from datetime import datetime


def save_model(curr_epoch, directory="models"):
    """Save the global ``model``'s state_dict to a timestamped checkpoint.

    The filename encodes the module-level hyper-parameter constants, the epoch
    and the current time. ``directory`` is created if needed, because
    ``torch.save`` fails when the parent directory does not exist.

    Args:
        curr_epoch: epoch number recorded in the filename.
        directory: destination folder (default ``"models"``, as before).

    Returns:
        Path of the written checkpoint file.
    """
    import os

    os.makedirs(directory, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = os.path.join(
        directory,
        f"vocabsize_{VOCAB_SIZE}_embeddim_{EMBED_DIM}_numheads_{NUM_HEADS}"
        f"_numtransformerblocks_{NUM_TRANSFORMER_BLOCKS}_maxseqlen_{MAX_SEQ_LEN}"
        f"_E{curr_epoch}_timestamp_{timestamp}.pt",
    )

    torch.save(model.state_dict(), filename)

    print(f"Model saved as {filename}")
    return filename

In [None]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torch.nn.utils import clip_grad_norm_
from transformers import get_cosine_schedule_with_warmup
from torch.amp import autocast, GradScaler

scaler = GradScaler()

EPOCHS = 100
LEARNING_RATE = 0.01
SAVE_EVERY = 10  # epochs between checkpoints

# Default ignore_index=-100 skips the padded label positions set by the collator.
LOSS_FN = nn.CrossEntropyLoss()

# The vocabulary head is an nn.LazyLinear. Materialise it with a tiny dry run so
# the optimizer is built over real parameters, not UninitializedParameter
# placeholders. The dry run is skipped if a checkpoint was already loaded.
if any(isinstance(p, nn.parameter.UninitializedParameter) for p in model.parameters()):
    with torch.no_grad():
        _dry_ids = torch.zeros((1, 1), dtype=torch.long, device=device)
        model(_dry_ids, torch.ones_like(_dry_ids))

OPTIMIZER = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

num_training_steps = EPOCHS * len(train_dataloader)
num_warmup_steps = int(0.1 * num_training_steps)  # 10% warmup

SCHEDULER = get_cosine_schedule_with_warmup(
    OPTIMIZER,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
    num_cycles=0.5  # half-cosine
)

train_losses = []
validation_losses = []
test_losses = []
lr_values = []


def _batch_loss(batch):
    """Move a collated batch to `device`, run the model under autocast.

    Returns:
        A ``(logits, labels, loss)`` tuple.
    """
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    labels = batch["labels"].to(device)
    with autocast("cuda"):
        logits = model(input_ids, attention_mask)  # [B, S, V]
        # Labels are already shifted by the collator; flatten to [B*S, V] vs [B*S].
        loss = LOSS_FN(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
    return logits, labels, loss


def _show_first_example(prefix, logits, labels):
    """Print argmax predictions vs. targets for the first sequence of a batch.

    Only non-ignored (label != -100) positions are printed, so the two lines
    align token for token.
    """
    keep = labels[0] != -100
    predicted = logits[0].argmax(dim=-1)[keep]
    expected = labels[0][keep]
    print(f"{prefix} output: {tokenizer.decode(predicted.tolist(), skip_special_tokens=False)}")
    print(f"{prefix} expected: {tokenizer.decode(expected.tolist(), skip_special_tokens=False)}")


def _evaluate(dataloader, desc, prefix, per_batch_losses=None):
    """Compute the mean per-batch loss over `dataloader` in eval mode.

    Prints the first example of the first batch. If `per_batch_losses` is given,
    each batch loss is appended to it.
    """
    model.eval()
    total = 0.0
    with torch.no_grad():
        for index, batch in enumerate(tqdm(dataloader, desc=f"\t{desc}")):
            logits, labels, loss = _batch_loss(batch)
            total += loss.item()
            if per_batch_losses is not None:
                per_batch_losses.append(loss.item())
            if index == 0:
                _show_first_example(prefix, logits, labels)
    return total / len(dataloader)


for epoch in range(EPOCHS):
    print(f"E {(epoch + 1)}/{EPOCHS} - {((epoch + 1) / EPOCHS) * 100:.3f}%")

    model.train()
    OPTIMIZER.zero_grad()
    train_loss = 0.0

    train_bar = tqdm(train_dataloader, desc="\tTrain")
    for index, batch in enumerate(train_bar):
        current_lr = SCHEDULER.get_last_lr()[0]
        lr_values.append(current_lr)
        if index % 10 == 0:
            train_bar.set_description(f"\tTrain, LR: {current_lr:.10f}")

        logits, labels, loss = _batch_loss(batch)
        train_loss += loss.item()

        scaler.scale(loss).backward()
        # Unscale before clipping so the norm is measured on the true gradients.
        scaler.unscale_(OPTIMIZER)
        clip_grad_norm_(model.parameters(), max_norm=1.0)

        scaler.step(OPTIMIZER)
        scaler.update()
        SCHEDULER.step()
        OPTIMIZER.zero_grad()

        if index == 0:
            _show_first_example("Train", logits, labels)

    train_loss /= len(train_dataloader)
    train_losses.append(train_loss)
    print(f"\tTrain Loss: {train_loss:.5f}")

    validation_loss = _evaluate(valid_dataloader, "Validation", "Valid")
    validation_losses.append(validation_loss)
    print(f"\tValidation Loss: {validation_loss:.5f}")

    # Checkpoint AFTER the epoch has trained. The E{n} tag then matches the saved
    # weights, and the final epoch is saved too. Previously the save ran at the
    # start of the epoch, so it was labelled one epoch ahead.
    if (epoch + 1) % SAVE_EVERY == 0:
        save_model(curr_epoch=epoch + 1)

# Final evaluation on the held-out test split; test_losses keeps per-batch values.
test_loss = _evaluate(test_dataloader, "Test", "Test", per_batch_losses=test_losses)
print(f"\tTest Loss: {test_loss:.5f}")

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, figsize=(10, 10))
axes = axes.flatten()
ax_loss, ax_test, ax_lr, ax_spare = axes

# One point per epoch (index 0 = first epoch).
ax_loss.plot(train_losses, label="Train loss")
ax_loss.plot(validation_losses, label="Validation loss")
ax_loss.set(title="Loss during training", xlabel="Epoch index", ylabel="Mean cross-entropy")
ax_loss.legend()

# One point per test batch.
ax_test.plot(test_losses, label="Test loss")
ax_test.set(title="Post training loss", xlabel="Test batch", ylabel="Cross-entropy")
ax_test.legend()

# One point per training step (recorded before each optimizer step).
ax_lr.plot(lr_values, label="LR values")
ax_lr.set(title="Learning rate values over time", xlabel="Training step", ylabel="Learning rate")
ax_lr.legend()

# Only three panels are needed; hide the empty fourth one.
ax_spare.axis("off")

plt.tight_layout()

plt.show()