In [32]:
BATCH_SIZE = 16

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

Mounted at /content/drive


In [1]:
# !pip install transformer-lens fancy_einsum

In [2]:
from transformer_lens.utils import get_corner, gelu_new

In [3]:
# Create a PyTorch dataset for IMDB movie reviews using HuggingFace's datasets library
# https://huggingface.co/datasets/imdb
import math
import re
import torch
import torch.nn as nn
import einops
from fancy_einsum import einsum
import pandas as pd
import tqdm.auto as tqdm
from dataclasses import dataclass
from torch.utils.data import Dataset
from collections import defaultdict
from torch.nn import Embedding
from transformers import AutoTokenizer, DataCollatorForLanguageModeling, default_data_collator
from datasets import load_dataset, concatenate_datasets

In [8]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# config = GPT2Config(vocab_size=len(tokenizer), n_embd=512, n_layer=4, n_head=8)

In [33]:
@dataclass
class Config:
    d_model: int = 768
    debug: bool = False
    layer_norm_eps: float = 1e-5
    d_vocab: int = 50257
    init_range: float = 0.02
    n_ctx: int = 15642
    d_head: int = 64
    d_mlp: int = 3072
    n_heads: int = 12
    n_layers: int = 12

cfg = Config()
print(cfg)

Config(d_model=768, debug=False, layer_norm_eps=1e-05, d_vocab=50257, init_range=0.02, n_ctx=15642, d_head=64, d_mlp=3072, n_heads=12, n_layers=12)


In [34]:
batch_size = BATCH_SIZE
num_epochs = 1
max_steps = 1000
log_every = 10
lr = 1e-3
weight_decay = 1e-2
model_cfg = Config(debug=False, d_model=256, n_heads=4, d_head=64, d_mlp=1024, n_layers=2, n_ctx=256, d_vocab=cfg.d_vocab)

In [35]:
def rand_float_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg)
    random_input = torch.randn(shape)
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    print("Output shape:", output.shape)
    print()
    return output

def rand_int_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg)
    random_input = torch.randint(100, 1000, shape)
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    print("Output shape:", output.shape)
    print()
    return output

In [36]:
class MLP(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(torch.empty((cfg.d_model, cfg.d_mlp)))
        nn.init.normal_(self.W_in, std=self.cfg.init_range)
        self.b_in = nn.Parameter(torch.zeros((cfg.d_mlp)))
        self.W_out = nn.Parameter(torch.empty((cfg.d_mlp, cfg.d_model)))
        nn.init.normal_(self.W_out, std=self.cfg.init_range)
        self.b_out = nn.Parameter(torch.zeros((cfg.d_model)))

    def forward(self, normalized_resid_mid):
        # normalized_resid_mid: [batch, position, d_model]
        # if self.cfg.debug: print("Normalized_resid_mid:", normalized_resid_mid.shape)
        pre = einsum("batch position d_model, d_model d_mlp -> batch position d_mlp", normalized_resid_mid, self.W_in) + self.b_in
        post = gelu_new(pre)
        mlp_out = einsum("batch position d_mlp, d_mlp d_model -> batch position d_model", post, self.W_out) + self.b_out
        return mlp_out

rand_float_test(MLP, [2, 4, 768])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])


tensor([[[ 7.0771e-01,  2.5544e-02, -4.8398e-01,  ..., -7.7187e-01,
          -1.6858e-01,  1.9644e-01],
         [ 9.4483e-02, -1.2457e-01,  6.8828e-02,  ..., -7.0338e-01,
          -4.4167e-01,  4.3592e-04],
         [ 3.8401e-01,  1.2869e-01,  6.2888e-02,  ..., -6.6646e-01,
          -4.7603e-02, -6.3357e-01],
         [-3.1750e-01,  1.9309e-01, -1.5287e-01,  ...,  2.1369e-01,
          -2.2808e-01,  3.7091e-01]],

        [[ 3.9809e-01,  1.1421e-01, -4.0952e-01,  ..., -3.3643e-01,
          -5.7297e-02, -2.6065e-01],
         [-3.1438e-04, -2.7293e-01, -2.2632e-01,  ...,  6.2189e-02,
           8.5641e-02,  2.6140e-03],
         [-1.8383e-01,  2.1350e-01,  2.8224e-01,  ...,  2.5073e-01,
          -1.2176e-01,  3.9101e-01],
         [-6.5818e-01,  1.2267e+00,  2.8108e-01,  ...,  5.3583e-01,
          -3.5283e-01, -2.9918e-01]]], grad_fn=<AddBackward0>)

In [37]:
class PosEmbed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(torch.empty((cfg.n_ctx, cfg.d_model)))
        nn.init.normal_(self.W_pos, std=self.cfg.init_range)

    def forward(self, tokens):
        # tokens: [batch, position]
        if self.cfg.debug: print("Tokens:", tokens.shape)
        pos_embed = self.W_pos[:tokens.size(1), :] # [position, d_model]
        pos_embed = einops.repeat(pos_embed, "position d_model -> batch position d_model", batch=tokens.size(0))
        if self.cfg.debug: print("pos_embed:", pos_embed.shape)
        return pos_embed

rand_int_test(PosEmbed, [2, 4])

Input shape: torch.Size([2, 4])
Tokens: torch.Size([2, 4])
pos_embed: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])


tensor([[[ 0.0073,  0.0249, -0.0099,  ...,  0.0178,  0.0282, -0.0433],
         [-0.0008,  0.0033,  0.0023,  ..., -0.0084,  0.0018, -0.0210],
         [-0.0092,  0.0175,  0.0208,  ..., -0.0210, -0.0081,  0.0201],
         [-0.0040,  0.0416, -0.0110,  ...,  0.0071,  0.0244,  0.0157]],

        [[ 0.0073,  0.0249, -0.0099,  ...,  0.0178,  0.0282, -0.0433],
         [-0.0008,  0.0033,  0.0023,  ..., -0.0084,  0.0018, -0.0210],
         [-0.0092,  0.0175,  0.0208,  ..., -0.0210, -0.0081,  0.0201],
         [-0.0040,  0.0416, -0.0110,  ...,  0.0071,  0.0244,  0.0157]]],
       grad_fn=<ReshapeAliasBackward0>)

In [38]:
class LayerNorm(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(torch.ones(cfg.d_model))
        self.b = nn.Parameter(torch.zeros(cfg.d_model))

    def forward(self, residual):
        # residual: [batch, position, d_model]
        # if self.cfg.debug: print("Residual:", residual.shape)
        residual = residual - einops.reduce(residual, "batch position d_model -> batch position 1", "mean")
        # Calculate the variance, square root it. Add in an epsilon to prevent divide by zero.
        scale = (einops.reduce(residual.pow(2), "batch position d_model -> batch position 1", "mean") + cfg.layer_norm_eps).sqrt()
        normalized = residual / scale
        normalized = normalized * self.w + self.b
        # if self.cfg.debug: print("Normalized:", residual.shape)
        return normalized

In [39]:
class Attention(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_Q = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        self.b_Q = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        self.W_K = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        self.b_K = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        self.W_V = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        self.b_V = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))

        self.W_O = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        self.b_O = nn.Parameter(torch.zeros((cfg.d_model)))

        self.register_buffer("IGNORE", torch.tensor(-1e5, dtype=torch.float32))

    def forward(self, normalized_resid_pre):
        # normalized_resid_pre: [batch, position, d_model]
        if self.cfg.debug: print("Normalized_resid_pre:", normalized_resid_pre.shape)

        q = einsum("batch query_pos d_model, n_heads d_model d_head -> batch query_pos n_heads d_head", normalized_resid_pre, self.W_Q) + self.b_Q
        k = einsum("batch key_pos d_model, n_heads d_model d_head -> batch key_pos n_heads d_head", normalized_resid_pre, self.W_K) + self.b_K

        attn_scores = einsum("batch query_pos n_heads d_head, batch key_pos n_heads d_head -> batch n_heads query_pos key_pos", q, k)
        attn_scores = attn_scores / math.sqrt(self.cfg.d_head)
        attn_scores = self.apply_causal_mask(attn_scores)

        pattern = attn_scores.softmax(dim=-1) # [batch, n_head, query_pos, key_pos]

        v = einsum("batch key_pos d_model, n_heads d_model d_head -> batch key_pos n_heads d_head", normalized_resid_pre, self.W_V) + self.b_V

        z = einsum("batch n_heads query_pos key_pos, batch key_pos n_heads d_head -> batch query_pos n_heads d_head", pattern, v)

        attn_out = einsum("batch query_pos n_heads d_head, n_heads d_head d_model -> batch query_pos d_model", z, self.W_O) + self.b_O
        return attn_out

    def apply_causal_mask(self, attn_scores):
        # attn_scores: [batch, n_heads, query_pos, key_pos]
        mask = torch.triu(torch.ones(attn_scores.size(-2), attn_scores.size(-1), device=attn_scores.device), diagonal=1).bool()
        attn_scores.masked_fill_(mask, self.IGNORE)
        return attn_scores

rand_float_test(Attention, [2, 4, 768])

Input shape: torch.Size([2, 4, 768])
Normalized_resid_pre: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])


tensor([[[ 0.5883,  0.4072,  0.7434,  ...,  0.3714,  0.0497, -0.0960],
         [ 0.2710,  0.3625,  0.1091,  ...,  0.0810, -0.0357, -0.4937],
         [ 0.1203,  0.0571, -0.0880,  ...,  0.0924, -0.1721, -0.2260],
         [ 0.1385,  0.0764, -0.1383,  ...,  0.0642, -0.0483, -0.3116]],

        [[ 0.7955,  0.0508, -0.1467,  ..., -0.7288,  0.3035, -0.0901],
         [ 0.4759, -0.0549, -0.0384,  ..., -0.2907,  0.2928, -0.3600],
         [ 0.2403, -0.1736, -0.2889,  ..., -0.0220,  0.0970, -0.0989],
         [ 0.3066, -0.0805, -0.1398,  ..., -0.1417,  0.1320, -0.0782]]],
       grad_fn=<AddBackward0>)

In [40]:
class FeedForward(nn.Module):
    def __init__(self, embedding_dim, hidden_dim):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embedding_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

In [41]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(self, resid_pre):
        # resid_pre [batch, position, d_model]
        normalized_resid_pre = self.ln1(resid_pre)
        attn_out = self.attn(normalized_resid_pre)
        resid_mid = resid_pre + attn_out

        normalized_resid_mid = self.ln2(resid_mid)
        mlp_out = self.mlp(normalized_resid_mid)
        resid_post = resid_mid + mlp_out
        return resid_post
rand_float_test(TransformerBlock, [2, 4, 768])

Input shape: torch.Size([2, 4, 768])
Normalized_resid_pre: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])


tensor([[[ 0.6263, -0.4525, -0.2144,  ..., -0.2297,  0.9080,  1.2107],
         [ 0.0373,  0.1283,  1.2095,  ...,  0.7652,  0.7788,  2.7103],
         [ 0.1047, -0.1235,  0.2026,  ...,  1.2542, -0.5981,  2.6094],
         [-1.3330,  1.6202,  0.6015,  ..., -0.7816,  0.4968,  0.0255]],

        [[-1.7777, -0.5475,  0.9965,  ...,  1.2684,  0.5034,  0.0781],
         [ 0.4971, -0.5653,  0.7481,  ...,  3.9081, -0.3278,  0.9421],
         [ 0.7433, -0.7672,  0.2922,  ...,  1.0622,  0.0452,  1.1211],
         [-0.5060, -1.8409, -0.1003,  ..., -0.5613, -0.3160,  0.8080]]],
       grad_fn=<AddBackward0>)

In [42]:
class Embed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(torch.empty((cfg.d_vocab, cfg.d_model)))
        nn.init.normal_(self.W_E, std=self.cfg.init_range)

    def forward(self, tokens):
        # tokens: [batch, position]
        # if self.cfg.debug: print("Tokens:", tokens.shape)
        embed = self.W_E[tokens, :] # [batch, position, d_model]
        # if self.cfg.debug: print("Embeddings:", embed.shape)
        return embed

rand_int_test(Embed, [2, 4])

Input shape: torch.Size([2, 4])
Output shape: torch.Size([2, 4, 768])


tensor([[[ 0.0510, -0.0019, -0.0130,  ...,  0.0188,  0.0114,  0.0218],
         [-0.0120,  0.0063, -0.0195,  ..., -0.0316,  0.0137,  0.0178],
         [-0.0075,  0.0239,  0.0104,  ...,  0.0303,  0.0098, -0.0036],
         [ 0.0070,  0.0128,  0.0147,  ...,  0.0036, -0.0314, -0.0186]],

        [[ 0.0268,  0.0023,  0.0467,  ...,  0.0229, -0.0345,  0.0034],
         [ 0.0338, -0.0112,  0.0073,  ..., -0.0168,  0.0087,  0.0139],
         [-0.0229,  0.0126,  0.0442,  ..., -0.0051,  0.0256, -0.0336],
         [ 0.0049,  0.0242,  0.0134,  ...,  0.0158, -0.0133,  0.0236]]],
       grad_fn=<IndexBackward0>)

In [43]:
class Unembed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(torch.empty((cfg.d_model, cfg.d_vocab)))
        nn.init.normal_(self.W_U, std=self.cfg.init_range)
        self.b_U = nn.Parameter(torch.zeros((cfg.d_vocab), requires_grad=False))

    def forward(self, normalized_resid_final):
        # normalized_resid_final [batch, position, d_model]
        if self.cfg.debug: print("Normalized_resid_final:", normalized_resid_final.shape)
        logits = einsum("batch position d_model, d_model d_vocab -> batch position d_vocab", normalized_resid_final, self.W_U) + self.b_U
        return logits

rand_float_test(Unembed, [2, 4, 768])

Input shape: torch.Size([2, 4, 768])
Normalized_resid_final: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 50257])


tensor([[[-0.0557, -0.9152, -0.1448,  ..., -0.6397,  0.5445, -0.1708],
         [ 0.8018, -0.5568, -0.8694,  ...,  0.6254, -0.2636, -0.3345],
         [-0.0671, -0.2100, -0.8385,  ...,  0.1795,  0.2128,  0.5714],
         [-0.1715,  0.6015, -0.1744,  ...,  0.2500,  0.0871, -0.2267]],

        [[ 0.7630, -0.3896,  0.7479,  ..., -0.0585, -0.3484, -0.2877],
         [-0.2778,  0.2537,  0.8654,  ...,  0.7089,  0.2914, -0.6436],
         [-0.3794, -1.1478,  1.0959,  ...,  0.2955,  0.3531,  0.8792],
         [-0.6239,  0.6491,  0.1131,  ..., -0.5089,  0.0508, -0.4537]]],
       grad_fn=<AddBackward0>)

In [44]:
class GPT2(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)

    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        residual = embed + pos_embed
        for block in self.blocks:
            residual = block(residual)
        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        # logits have shape [batch, position, logits]
        return logits

rand_int_test(GPT2, [2, 4])

Input shape: torch.Size([2, 4])
Tokens: torch.Size([2, 4])
pos_embed: torch.Size([2, 4, 768])
Normalized_resid_pre: torch.Size([2, 4, 768])
Normalized_resid_pre: torch.Size([2, 4, 768])
Normalized_resid_pre: torch.Size([2, 4, 768])
Normalized_resid_pre: torch.Size([2, 4, 768])
Normalized_resid_pre: torch.Size([2, 4, 768])
Normalized_resid_pre: torch.Size([2, 4, 768])
Normalized_resid_pre: torch.Size([2, 4, 768])
Normalized_resid_pre: torch.Size([2, 4, 768])
Normalized_resid_pre: torch.Size([2, 4, 768])
Normalized_resid_pre: torch.Size([2, 4, 768])
Normalized_resid_pre: torch.Size([2, 4, 768])
Normalized_resid_pre: torch.Size([2, 4, 768])
Normalized_resid_final: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 50257])


tensor([[[-0.5213, -0.8311, -0.0675,  ..., -0.5825,  0.0222, -0.1156],
         [-0.1897, -0.7643,  0.4709,  ..., -0.1320,  0.0366, -0.5127],
         [-0.1808, -0.8806,  0.4033,  ..., -0.3110,  0.2349, -0.3973],
         [-0.2245, -1.1128,  0.6590,  ..., -0.0811,  0.2280, -0.2946]],

        [[ 0.4035, -0.7987, -0.2036,  ..., -0.1238,  0.4460, -0.1607],
         [ 0.2073, -0.6002, -0.2648,  ...,  0.0331,  0.1240,  0.1212],
         [ 0.2192, -0.8463,  0.1544,  ...,  0.2274,  0.1696, -0.0165],
         [ 0.2839, -0.6639,  0.3670,  ...,  0.1804, -0.1207,  0.1614]]],
       grad_fn=<AddBackward0>)

In [45]:
# !pip install lightning

In [46]:
import torch.optim as optim
import torch.nn.functional as F
from transformers import GPT2Tokenizer, TextDataset, DataCollatorForLanguageModeling
import pytorch_lightning as pl
import torchmetrics

In [47]:
# define model hyperparameters
# vocab_size = 50257 + 1 # Add 1 for the padding token
# vocab_size = vocab_size_from_imdb + 1
vocab_size = cfg.d_vocab
num_heads = cfg.n_heads
hidden_dim = 768
num_layers = cfg.n_layers
embedding_dim = cfg.d_model

In [48]:
vocab_size

50257

In [49]:
def sanity_check_hyperparameters(embedding_dim, num_heads):
    if embedding_dim % num_heads != 0:
        raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by num_heads ({num_heads}).")
    print("Hyperparameters are set correctly!")

# Example usage
sanity_check_hyperparameters(embedding_dim, num_heads)

Hyperparameters are set correctly!


In [50]:
def lm_cross_entropy_loss(logits, tokens):
    # Measure next token loss
    # Logits have shape [batch, position, d_vocab]
    # Tokens have shape [batch, position]
    log_probs = logits.log_softmax(dim=-1)
    pred_log_probs = log_probs[:, :-1].gather(dim=-1, index=tokens[:, 1:].unsqueeze(-1)).squeeze(-1)
    return -pred_log_probs.mean()

In [51]:
model = GPT2(cfg)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

## Load the dataset

In [52]:
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from datasets import load_dataset, Dataset
import pandas as pd
import matplotlib.pyplot as plt

def load_mt_kaggle_csv(filepath):
    """Load a CSV file into a Pandas DataFrame."""
    try:
        df = pd.read_csv(filepath)
        print("CSV file successfully loaded!")
        return df
    except Exception as e:
        print(f"Error occurred: {e}")
        return None
    
df = load_mt_kaggle_csv('data/mtsamples.csv')
df = df.dropna(subset=['transcription'])
transcription_df = df[['transcription']]

CSV file successfully loaded!


In [53]:
from datasets import Features, Value

def tokenize_function(examples):
    # Add the eos_token to each transcription
    transcriptions_with_eos = [t + tokenizer.eos_token for t in examples['transcription']]
    
    # Tokenize using the tokenizer
    tokens = tokenizer(transcriptions_with_eos, padding='max_length', truncation=True, max_length=512, return_tensors="pt")
    
    return tokens

dataset = Dataset.from_pandas(transcription_df)
tokenized_datasets = dataset.map(tokenize_function, batched=True)

Map:   0%|          | 0/4966 [00:00<?, ? examples/s]

In [54]:
from torch.utils.data import DataLoader

# Ensure the data format is set correctly
tokenized_datasets.set_format(type='torch', columns=['input_ids'])

# Create the DataLoader
train_dataloader = DataLoader(tokenized_datasets, shuffle=True, batch_size=BATCH_SIZE, num_workers=0)

In [55]:
# Sanity Check
# Iterate over a single batch from the train_dataloader
for batch in train_dataloader:
    input_ids = batch['input_ids']
    
    # Check tensor size
    assert input_ids.size() == (BATCH_SIZE, 512), f"Expected size ({BATCH_SIZE}, 512), but got {input_ids.size()}"
    
    # Check tensor stride
    assert input_ids.stride() == (512, 1), f"Expected stride (512, 1), but got {input_ids.stride()}"
    
    print("Sanity check passed!")
    break  # We only want to check one batch for this sanity check


Sanity check passed!


In [56]:
import pytorch_lightning as pl
import torch.optim as optim

class GPT2ReviewModel(pl.LightningModule):
    def __init__(self, model, tokenizer, loss_fn, learning_rate=3e-4):
        super(GPT2ReviewModel, self).__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.loss_fn = loss_fn
        self.learning_rate = learning_rate

    def forward(self, input_data, mask):
        return self.model(input_data, attention_mask=mask)

    def training_step(self, batch, batch_idx):
        input_data = batch["input_ids"][:, :-1]
        targets = batch["input_ids"][:, 1:]
        
        mask = (input_data != self.tokenizer.pad_token_id).float().unsqueeze(1).unsqueeze(2)
        
        outputs = self(input_data, mask)
        logits = outputs.logits
        logits = logits.view(-1, logits.shape[2])
        targets = targets.contiguous().view(-1)  # Ensure the tensor is contiguous
        
        loss = self.loss_fn(logits, targets)
        return loss

    def configure_optimizers(self):
        optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
        return optimizer


In [None]:
import pytorch_lightning as pl

loss_fn = nn.CrossEntropyLoss()
model_instance = GPT2ReviewModel(model, tokenizer, loss_fn)

trainer = pl.Trainer(max_epochs=1)  
trainer.fit(model_instance, train_dataloader)