## Import statements

In [None]:
import torch
from torch import nn
import inspect
import logging
from transformers import PreTrainedTokenizerFast
from torch.nn.utils import clip_grad_norm_
from transformers import get_cosine_schedule_with_warmup
from torch.amp import autocast, GradScaler
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorForLanguageModeling
from torch.utils.data import DataLoader

# INFO rather than DEBUG: this notebook only emits info/warning messages, and a
# DEBUG root logger can also surface verbose debug output from imported libraries.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Fixed seeds so weight initialisation and data shuffling are reproducible.
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

# Prefer the GPU, but fall back to CPU so the notebook still runs without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

## Decoder block architecture

In [None]:
class CasualMaskedDecoderBlocks(nn.Module):
    """Stack of pre-norm self-attention blocks restricted by a causal attention mask.

    Every block is an ``nn.TransformerEncoderLayer``; together with the causal mask
    (no position may attend to a later one) this forms a GPT-style decoder-only stack.
    The class name keeps its original spelling so existing callers keep working.
    """

    def __init__(
            self,
            embed_dim: int,
            num_heads: int,
            num_blocks: int,
            max_seq_len: int,
            dropout: float = 0.0,
            activation_function: str = "gelu",
            ffw_network_multiplier: int = 4,
    ):
        """
        :param embed_dim: Embedding dimension of the tokens in the sequence.
        :param num_heads: Number of heads in each decoder block; must divide embed_dim.
        :param num_blocks: Number of decoder blocks (each with its own weights).
        :param max_seq_len: Maximum expected sequence length; sizes the causal mask.
        :param dropout: Probability of dropout inside each block.
        :param activation_function: Feed-forward activation, "relu" or "gelu".
        :param ffw_network_multiplier: multiplier for embed_dim to get the dimensionality for the feedforward network.
        """
        super().__init__()

        assert embed_dim % num_heads == 0, f'{self.__class__.__name__}.{inspect.currentframe().f_code.co_name}: embed_dim ({embed_dim}) is not divisible by num_heads ({num_heads})'

        assert activation_function in ("relu",
                                       "gelu"), f'{self.__class__.__name__}.{inspect.currentframe().f_code.co_name}: activation_function expected to be "relu"/"gelu", received "{activation_function}" instead'

        self.max_seq_len = max_seq_len
        self.embed_dim = embed_dim

        # Construct a fresh layer for every block. Putting one instance into the
        # list several times would register the same module repeatedly, so all
        # "blocks" would share a single set of weights.
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dim_feedforward=ffw_network_multiplier * embed_dim,
                dropout=dropout,
                activation=activation_function,
                batch_first=True,
                norm_first=True
            )
            for _ in range(num_blocks)
        ])

        # Boolean mask, True strictly above the diagonal: query position i is
        # blocked from key positions j > i (True = "not allowed to attend").
        causal = torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=1).bool()
        self.register_buffer("causal_mask", causal)

    def forward(self, tok_seq, padding_mask=None) -> torch.Tensor:
        """
        :param tok_seq: torch.Tensor of size [batch, seq_len, embed_dim] representing the pre-attended token sequence.
        :param padding_mask: Bool mask of size [batch, seq_len]; True for padding tokens.
        :return: Attended token sequence after num_blocks amount of decoder blocks, same size as tok_seq.
        """
        assert tok_seq.dim() == 3, f'{self.__class__.__name__}.{inspect.currentframe().f_code.co_name}: tok_seq tensor should be of size [batch, seq_len, embed_dim], received a tensor with {tok_seq.dim()} dimensions instead'
        batch_size, seq_len, embed_dim = tok_seq.size()
        assert seq_len <= self.max_seq_len, f'{self.__class__.__name__}.{inspect.currentframe().f_code.co_name}: length of inputted sequence ({seq_len}) exceeds maximum expected sequence length ({self.max_seq_len})'
        assert embed_dim == self.embed_dim, f'{self.__class__.__name__}.{inspect.currentframe().f_code.co_name}: received embed_dim ({embed_dim}) does not match the expected embed_dim ({self.embed_dim})'
        if padding_mask is not None:
            assert padding_mask.dim() == 2, f'{self.__class__.__name__}.{inspect.currentframe().f_code.co_name}: padding_mask tensor should be of size [batch, seq_len], received a tensor with {padding_mask.dim()} dimensions instead'
            pm_batch_size, pm_seq_len = padding_mask.size()
            assert batch_size == pm_batch_size and seq_len == pm_seq_len, f'{self.__class__.__name__}.{inspect.currentframe().f_code.co_name}: dimension mismatch between tok_seq ([{batch_size},{seq_len},{embed_dim}]) and padding_mask ([{pm_batch_size}, {pm_seq_len}])'

        # Crop the precomputed mask to the actual sequence length.
        causal_mask = self.causal_mask[:seq_len, :seq_len]

        for block in self.blocks:
            tok_seq = block(
                src=tok_seq,
                src_mask=causal_mask,
                src_key_padding_mask=padding_mask
            )

        return tok_seq

## Attended token decoder

In [None]:
class AttendedTokenDecoder(nn.Module):
    """Language-model head: maps attended token embeddings to vocabulary logits."""

    def __init__(
            self,
            embed_dim: int,
            vocabulary_size: int,
    ):
        """
        :param embed_dim: Size of the last dimension of the incoming token embeddings.
        :param vocabulary_size: Number of logits produced per token.
        """
        super().__init__()

        self.embed_dim, self.vocabulary_size = embed_dim, vocabulary_size

        # One linear projection, wrapped in nn.Sequential so the parameters are
        # registered as ``decoder.0.weight`` / ``decoder.0.bias``.
        projection = nn.Linear(in_features=embed_dim, out_features=vocabulary_size)
        self.decoder = nn.Sequential(projection)

    def forward(self, att_tok_seq):
        """
        :param att_tok_seq: Tensor of size [batch, seq_len, embed_dim].
        :return: Unnormalised logits of size [batch, seq_len, vocabulary_size].
        """
        ndim = att_tok_seq.dim()
        assert ndim == 3, f'{type(self).__name__}.{inspect.currentframe().f_code.co_name}: att_tok_seq tensor should be of size [batch, seq_len, embed_dim], received a tensor with {ndim} dimensions instead'

        _, _, embed_dim = att_tok_seq.size()
        assert embed_dim == self.embed_dim, f'{type(self).__name__}.{inspect.currentframe().f_code.co_name}: received embed_dim ({embed_dim}) does not match expected embed_dim ({self.embed_dim})'

        return self.decoder(att_tok_seq)

## Combined GPT model

In [None]:
class GPTModel(nn.Module):
    """Decoder-only (GPT-style) language model.

    Token embeddings plus learned positional embeddings are passed through causally
    masked transformer blocks and projected to per-token vocabulary logits.
    """

    def __init__(
            self,
            embed_dim: int,
            num_heads: int,
            num_blocks: int,
            max_seq_len: int,
            vocab_size: int,
            tokenizer,
            dropout: float = 0.0,
            activation_function: str = "gelu",
            ffw_network_multiplier: int = 4,
    ):
        """
        :param embed_dim: Width of the token/positional embeddings and of every block.
        :param num_heads: Attention heads per block; must divide embed_dim.
        :param num_blocks: Number of decoder blocks.
        :param max_seq_len: Longest supported input; sizes the positional embedding table and causal mask.
        :param vocab_size: Number of token ids (rows of the token embedding, width of the output logits).
        :param tokenizer: Only ``tokenizer.pad_token_id`` is read, to build the padding mask; it may be None.
        :param dropout: Dropout probability inside the decoder blocks.
        :param activation_function: Feed-forward activation, "relu" or "gelu".
        :param ffw_network_multiplier: Feed-forward hidden size is ffw_network_multiplier * embed_dim.
        """
        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_blocks = num_blocks
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.dropout = dropout
        self.activation_function = activation_function
        self.ffw_network_multiplier = ffw_network_multiplier

        # None when the tokenizer defines no pad token; forward() then masks nothing.
        self.pad_token_id = tokenizer.pad_token_id

        if embed_dim < self.vocab_size:
            logging.warning(
                f'{self.__class__.__name__}.{inspect.currentframe().f_code.co_name}: embed_dim is smaller than vocabulary_size, consider increasing embed_dim')

        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.positional_embedding = nn.Embedding(max_seq_len, embed_dim)

        self.decoder_blocks = CasualMaskedDecoderBlocks(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_blocks=num_blocks,
            max_seq_len=max_seq_len,
            dropout=dropout,
            activation_function=activation_function,
            ffw_network_multiplier=ffw_network_multiplier,
        )

        self.decoder = AttendedTokenDecoder(
            embed_dim=embed_dim,
            vocabulary_size=vocab_size,
        )

    def forward(self, tokenized_sequence):
        """Compute next-token logits for a batch of token ids.

        :param tokenized_sequence: Integer token ids of size [batch, seq_len] with seq_len <= max_seq_len.
        :return: Logits of size [batch, seq_len, vocab_size].
        """
        assert tokenized_sequence.dim() == 2, f'{self.__class__.__name__}.{inspect.currentframe().f_code.co_name}: tokenized_sequence tensor should be of size [batch, seq_len], received a tensor with {tokenized_sequence.dim()} dimensions instead'
        batch_size, seq_length = tokenized_sequence.size()
        # Check before the positional lookup, which would otherwise fail with an opaque index error.
        assert seq_length <= self.max_seq_len, f'{self.__class__.__name__}.{inspect.currentframe().f_code.co_name}: length of inputted sequence ({seq_length}) exceeds maximum expected sequence length ({self.max_seq_len})'

        # True where the token equals the pad id. NOTE: if the pad id doubles as
        # another special token (e.g. EOS), those tokens are masked as well.
        if self.pad_token_id is None:
            padding_mask = None
        else:
            padding_mask = tokenized_sequence.eq(self.pad_token_id)

        embedded_sequence = self.token_embedding(tokenized_sequence)

        # Positions 0..seq_length-1, shared by every sequence in the batch.
        position_ids = (
            torch.arange(seq_length, device=embedded_sequence.device)
            .unsqueeze(0)
            .expand(batch_size, seq_length)
        )
        positional_embeddings = self.positional_embedding(position_ids)

        sequence = embedded_sequence + positional_embeddings

        sequence = self.decoder_blocks(sequence, padding_mask)

        logits = self.decoder(sequence)

        return logits

## Initialization

In [None]:
# Load the tokenizer saved on disk at this relative path.
tokenizer_model_path = "../../saved_models/tokenizers/rotten_tomatoes_bpe_style"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_model_path)

# Model hyper-parameters (also reused by the data pipeline below).
EMBED_DIM = 768
NUM_HEADS = 8
NUM_BLOCKS = 8
MAX_SEQ_LENGTH = 256
VOCAB_SIZE = len(tokenizer.get_vocab())

model = GPTModel(
    tokenizer=tokenizer,
    vocab_size=VOCAB_SIZE,
    max_seq_len=MAX_SEQ_LENGTH,
    num_blocks=NUM_BLOCKS,
    num_heads=NUM_HEADS,
    embed_dim=EMBED_DIM,
)
model = model.to(device)

## Data loading

In [None]:
# Load the "rotten_tomatoes" dataset with Hugging Face `datasets`; the "train",
# "validation" and "test" splits are used further down.
dataset = load_dataset("rotten_tomatoes")


def tokenize_function(examples):
    """Tokenize a batch of raw texts for causal language modelling.

    Sequences are truncated to ``MAX_SEQ_LENGTH + 1`` tokens and left otherwise
    intact. The one-position shift (inputs = ids[:-1], labels = ids[1:]) is done
    by the collator, so trimming here as well would shift twice and drop the final
    token of every example (even short ones) from the training targets.

    :param examples: Batch dict from ``Dataset.map(batched=True)`` with a "text" column.
    :return: Unpadded tokenizer output (``input_ids``, ``attention_mask``, ...).
    """
    return tokenizer(
        examples["text"],
        padding=False,
        truncation=True,
        max_length=MAX_SEQ_LENGTH + 1
    )


# Tokenize every split in batches; the raw "text" column is dropped, other columns remain.
tokenized = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"],
)

# Expose only the model inputs from each split, returned as torch tensors.
for split in ("train", "validation", "test"):
    tokenized[split].set_format(type="torch", columns=["input_ids", "attention_mask"])


class CausalDataCollator:
    def __init__(self, tokenizer, **kwargs):
        self.base = DataCollatorForLanguageModeling(
            tokenizer=tokenizer, mlm=False, **kwargs
        )

    def __call__(self, examples):
        batch = self.base(examples)
        inputs = batch["input_ids"]
        masks = batch["attention_mask"]
        batch["input_ids"] = inputs[:, :-1]
        batch["attention_mask"] = masks[:, :-1]
        labels = inputs[:, 1:].clone()
        labels = labels.masked_fill(labels == tokenizer.pad_token_id, -100)
        batch["labels"] = labels
        return batch


# Shared collator: pads each batch and builds shifted next-token labels.
data_collator = CausalDataCollator(tokenizer=tokenizer, mlm_probability=0.0)

BATCH_SIZE = 128
# One loader per split; only the training split is shuffled.
train_dataloader, valid_dataloader, test_dataloader = (
    DataLoader(
        tokenized[split_name],
        batch_size=BATCH_SIZE,
        shuffle=(split_name == "train"),
        collate_fn=data_collator,
    )
    for split_name in ("train", "validation", "test")
)

## Training loop

In [None]:
NUM_EPOCHS = 10
LEARNING_RATE = 0.001

scaler = GradScaler()
# Default ignore_index (-100) skips the padded label positions written by the collator.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

num_training_steps = NUM_EPOCHS * len(train_dataloader)
num_warmup_steps = int(0.1 * num_training_steps)  # first 10% of steps are warmup

scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
    num_cycles=0.5
)

train_losses = []    # mean training loss per epoch
valid_losses = []    # mean validation loss per epoch
test_losses = []     # loss of every test batch (after training)
learning_rates = []  # learning rate read before every training step


def _forward_batch(batch):
    """Move one collated batch to `device`, run the model under autocast; return (logits, labels, loss)."""
    input_ids = batch["input_ids"].to(device)
    labels = batch["labels"].to(device)
    with autocast("cuda"):
        logits = model(input_ids)
        # Flatten [batch, seq_len, vocab] / [batch, seq_len] for CrossEntropyLoss.
        loss = loss_fn(
            logits.reshape(-1, logits.size(-1)),
            labels.reshape(-1)
        )
    return logits, labels, loss


def _decode_target(labels):
    """Decode the first sequence's labels, skipping ignored (-100) positions."""
    return tokenizer.decode(
        [token for token in labels[0].tolist() if token != -100],
        skip_special_tokens=False
    )


def _decode_prediction(logits):
    """Decode the argmax (greedy) prediction for the first sequence of a batch."""
    return tokenizer.decode(torch.argmax(logits, dim=-1)[0].tolist(), skip_special_tokens=False)


optimizer.zero_grad()

for epoch in range(NUM_EPOCHS):
    logging.info(f"E {(epoch + 1)}/{NUM_EPOCHS} - {((epoch + 1) / NUM_EPOCHS) * 100:.3f}%")

    # Validation below switches to eval mode, so re-enter training mode every epoch.
    model.train()
    train_loss = 0

    for index, batch in enumerate(train_dataloader):
        last_lr = scheduler.get_last_lr()[0]
        learning_rates.append(last_lr)
        if (index + 1) % 50 == 0:
            logging.info(f"\tCurrent LR: {last_lr:10f}")

        logits, labels, loss = _forward_batch(batch)

        train_loss += loss.item()
        scaler.scale(loss).backward()
        # Unscaled gradients are what clipping needs; clipping itself is currently disabled.
        scaler.unscale_(optimizer)
        # clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        optimizer.zero_grad()

        if index == 0:
            print(f"\tTrain expected: {_decode_target(labels)}")
            print(f"\tTrain output: {_decode_prediction(logits)}")

    train_loss /= len(train_dataloader)
    train_losses.append(train_loss)
    logging.info(f"\tTrain Loss: {train_loss:.5f}")

    # Evaluation: no autograd graphs are needed, and eval() disables dropout.
    model.eval()
    valid_loss = 0
    with torch.no_grad():
        for index, batch in enumerate(valid_dataloader):
            logits, labels, loss = _forward_batch(batch)
            valid_loss += loss.item()

            if index == 0:
                print(f"\tValid expected: {_decode_target(labels)}")
                print(f"\tValid output: {_decode_prediction(logits)}")

    valid_loss /= len(valid_dataloader)
    valid_losses.append(valid_loss)
    logging.info(f"\tValid Loss: {valid_loss:.5f}")

model.eval()
test_loss = 0
with torch.no_grad():
    for batch in test_dataloader:
        logits, labels, loss = _forward_batch(batch)
        batch_loss = loss.item()
        test_loss += batch_loss
        test_losses.append(batch_loss)
        print(f"\tTest output: {_decode_prediction(logits)}")

# Average over the number of *test* batches.
test_loss /= len(test_dataloader)
logging.info(f"\tTest Loss: {test_loss:.5f}")

In [None]:
import matplotlib.pyplot as plt

# 2x2 grid: three panels are drawn, the fourth is hidden below.
fig, axes = plt.subplots(2, 2, figsize=(10, 10))

axes = axes.flatten()

# Mean train / validation loss, one point per epoch.
axes[0].plot(train_losses, label="Train loss")
axes[0].plot(valid_losses, label="Validation loss")
axes[0].set_title("Loss during training")
axes[0].set_xlabel("Epoch")
axes[0].legend()

# Loss of each test batch, recorded after training.
axes[1].plot(test_losses, label="Test loss")
axes[1].set_title("Post training loss")
axes[1].set_xlabel("Test batch")
axes[1].legend()

# Learning rate read before every training step.
axes[2].plot(learning_rates, label="LR values")
axes[2].set_title("Learning rate values over time")
axes[2].set_xlabel("Training step")
axes[2].legend()

# Nothing is plotted in the last cell; hide it instead of showing empty axes.
axes[3].set_visible(False)

plt.tight_layout()

plt.show()