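## LLM

Train a Transformer language model on the CNN/DailyMail corpus without regularization (dropout 0, weight decay 0), saving a checkpoint after each training epoch.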

In [7]:
import pickle as pkl
from itertools import islice

import torch as torch
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output
from torch.utils.data import TensorDataset, DataLoader
from tqdm.notebook import tqdm

from transformer_kristianwold.transformer import Transformer
from transformer_kristianwold.optimization import train_step, forward_and_loss, group_decay_parameters, save_checkpoint, load_checkpoint
from transformer_kristianwold.utils import saver, loader

# Report the PyTorch / CUDA / cuDNN build this notebook is running on.
for label, value in (
    ("PyTorch version:", torch.__version__),
    ("CUDA toolkit version PyTorch was built with:", torch.version.cuda),
    ("cuDNN version:", torch.backends.cudnn.version()),
    ("cuda available:", torch.cuda.is_available()),
):
    print(label, value)

# Use the first GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Allow float32 matmuls to use faster reduced-precision internal kernels (e.g. TF32) where supported.
torch.set_float32_matmul_precision('high')

PyTorch version: 2.7.1+cu128
CUDA toolkit version PyTorch was built with: 12.8
cuDNN version: 90701
cuda available: True


## Load and batch data

In [None]:
# Load the serialized tokenizer and read off the values the model needs.
tokenizer = loader("../../tokenizers/cnn_tokenizer3.pkl")

start_token_id = tokenizer.token_to_idx["<s>"]
vocab_size = tokenizer.vocab_size

print(f"Start token id: {start_token_id}")
print(f"Vocab size: {vocab_size}")


Start token id: 24070
Vocab size: 24074


In [None]:
# Each split is stored as two .pkl files ("highlight_first" and "highlight_last");
# they are joined end to end along dim 0.
# NOTE(review): torch.cat copies, and the per-file tensors stay referenced below,
# so host memory for each split is roughly doubled; `del` them if unused.
corpus_train1, corpus_train2 = (
    loader("../../corpus/cnn_dailymail_highlight_first_train.pkl"),
    loader("../../corpus/cnn_dailymail_highlight_last_train.pkl"),
)
corpus_train = torch.cat([corpus_train1, corpus_train2], dim=0)

corpus_test1, corpus_test2 = (
    loader("../../corpus/cnn_dailymail_highlight_first_test.pkl"),
    loader("../../corpus/cnn_dailymail_highlight_last_test.pkl"),
)
corpus_test = torch.cat([corpus_test1, corpus_test2], dim=0)

In [10]:
def batch_data(corpus, batch_length=1024):
    """
    Split a sequence tensor into consecutive rows of ``batch_length`` elements.

    Trailing elements that do not fill a complete row are dropped.

    Args:
        corpus: Tensor to split; expected to be 1-D (e.g. the concatenated
            token corpus).
        batch_length: Number of elements per row. Must be positive.

    Returns:
        For 1-D input, a tensor of shape
        ``(len(corpus) // batch_length, batch_length)``.

    Raises:
        ValueError: If ``batch_length`` is not positive.
    """
    if batch_length <= 0:
        raise ValueError(f"batch_length must be positive, got {batch_length}")

    num_rows = len(corpus) // batch_length
    # Trim to a whole number of rows before reshaping.
    corpus_truncated = corpus[:num_rows * batch_length]
    # reshape (unlike view) also works when the input is not contiguous.
    return corpus_truncated.reshape(-1, batch_length)

In [11]:
# Cut each split into fixed-length rows of 1024 (same value as the model's max_seq_len=1024).
corpus_train_batched, corpus_test_batched = (
    batch_data(split, batch_length=1024) for split in (corpus_train, corpus_test)
)

In [12]:
# Sequences per micro-batch; the training loop passes accum_steps to train_step
# to accumulate gradients over several of these.
micro_batch_size = 3

loader_train = DataLoader(
    corpus_train_batched,
    batch_size=micro_batch_size,
    shuffle=True,       # reshuffle every epoch
    drop_last=True      # drop the last incomplete batch
)

# Shuffled on purpose. Each evaluation reads only the first few batches of a
# fresh iterator, and corpus_test is the "first" file followed by the "last"
# file. Without shuffling, every evaluation would reuse the same leading
# sequences from the "first" file only.
loader_test = DataLoader(
    corpus_test_batched,
    batch_size=micro_batch_size,
    shuffle=True,
    drop_last=True      # drop the last incomplete batch
)

## Initialize Transformer

In [None]:
# Seed torch's global RNG. This fixes weight initialization here and the
# DataLoader shuffling order later.
torch.manual_seed(42)

# 18 blocks and 18 heads; embed_dim is 64 per head (presumably the per-head width).
heads = 18
tf_blocks = 18
embed_dim = 64 * heads
ff_dim = 4 * embed_dim

model = Transformer(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    ff_dim=ff_dim,
    heads=heads,
    tf_blocks=tf_blocks,
    max_seq_len=1024,
    dropout=0.0,  # no dropout (unregularized run)
    start_token_id=start_token_id,
    use_weight_tying=True,
)
model = model.to(device)

# weight_decay=0.0: no weight decay (unregularized run).
optimizer_grouped_parameters = group_decay_parameters(
    model,
    weight_decay=0.0,
    no_decay=["bias", "LayerNorm.weight"],
)

# Checkpoint path written by the first training cell.
filename = "../../models/checkpoint_transformer_no_regularization_1epoch.pth"

print(f"Number of parameters: {model.num_parameters()}")

Number of parameters: 315778058


In [14]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=5e-5)
# Loss scaling is only meaningful for CUDA mixed precision. Disable it explicitly
# on CPU instead of relying on GradScaler's "CUDA not available" fallback warning.
scaler = torch.amp.GradScaler("cuda", enabled=(device.type == "cuda"))
loss_train_list = []
loss_test_list = []

num_epochs      = 1
steps_per_epoch = len(loader_train)  # micro-batches per epoch (not referenced by the training cells)
warmup_steps    = 1000               # scheduler steps of linear LR warmup (see lr_lambda)

def lr_lambda(step):
    """
    LR multiplier for LambdaLR.

    Ramps linearly from 0 to 1 over the first `warmup_steps` scheduler steps,
    then stays at 1.0.
    """
    if step >= warmup_steps:
        return 1.0
    return float(step) / float(max(1, warmup_steps))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

### Save Model

In [15]:
# Manual checkpoint save (disabled). Uncomment to call save_checkpoint on the
# current model, optimizer, scheduler and loss lists.
# NOTE(review): this path has no "../../" prefix, unlike the other checkpoint
# paths in this notebook, so it would resolve relative to the working directory.
# save_checkpoint(model, 
#                 optimizer, 
#                 scheduler,
#                 loss_train_list,
#                 loss_test_list, 
#                 filename="models/checkpoint_transformer.pth")

### Load Model

In [None]:
# Resume from a checkpoint (disabled). Uncomment to rebind model, optimizer,
# scheduler and the loss lists to the values returned by load_checkpoint.
# NOTE(review): "checkpoint_transformer_3epoch.pth" does not match the
# "checkpoint_transformer_no_regularization_*epoch.pth" files written by the
# training cells; confirm which run is meant to be resumed.
#[model, 
#optimizer, 
#scheduler, 
#loss_train_list, 
#loss_test_list] = load_checkpoint("../../models/checkpoint_transformer_3epoch.pth", 
#                                  model, 
#                                  optimizer, 
#                                  scheduler, 
#                                  loss_train_list, 
#                                  loss_test_list)

In [17]:
# Train for `num_epochs` epochs, logging train/test loss every `eval_every`
# micro-batch steps and checkpointing to `filename`.
model.train()
device = next(model.parameters()).device
accum_steps = 40   # passed to train_step (presumably the gradient-accumulation window)
eval_every = 500   # micro-batch steps between loss logs
save_every = 5000  # micro-batch steps between checkpoint saves
eval_batches = 40  # test batches averaged per evaluation

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    # Start each epoch with clean gradients.
    # NOTE(review): presumably train_step steps the optimizer only every
    # accum_steps micro-batches. If so, a partial window left at the end of an
    # epoch would otherwise leak into the next epoch's first update.
    optimizer.zero_grad()
    for step, batch in enumerate(tqdm(loader_train, desc="Training")):
        batch = batch.to(device)
        # NOTE(review): presumably the loss of this micro-batch only, so the
        # logged training loss is noisy compared with the averaged test loss.
        loss_train = train_step(model,
                                batch,
                                criterion,
                                optimizer,
                                scaler,
                                scheduler,
                                accum_steps,
                                step).item()

        if (step+1) % eval_every == 0:
            lr = scheduler.get_last_lr()[0]
            model.eval()
            try:
                with torch.no_grad():
                    # A fresh iterator over the shuffled test loader gives a new
                    # random sample each time. islice stops early instead of
                    # raising StopIteration if the loader has fewer batches.
                    loss_test = np.mean([forward_and_loss(model, test_batch.to(device), criterion).item()
                                         for test_batch in islice(loader_test, eval_batches)])
            finally:
                model.train()  # restore training mode even if evaluation fails
            print(f"Step {step+1}, Loss: {loss_train:.3f}, Loss_eval: {loss_test:.3f}, Learning Rate: {lr:.3e}")

            loss_train_list.append(loss_train)
            loss_test_list.append(loss_test)

        if (step+1) % save_every == 0:
            save_checkpoint(model,
                            optimizer,
                            scheduler,
                            loss_train_list,
                            loss_test_list,
                            filename=filename)

    save_checkpoint(model,
                    optimizer,
                    scheduler,
                    loss_train_list,
                    loss_test_list,
                    filename=filename)

Epoch 1/1


Training:   0%|          | 0/175316 [00:00<?, ?it/s]

Step 500, Loss: 10.028, Loss_eval: 10.034, Learning Rate: 6.000000e-07
Step 1000, Loss: 9.842, Loss_eval: 9.838, Learning Rate: 1.250000e-06
Step 1500, Loss: 9.666, Loss_eval: 9.648, Learning Rate: 1.850000e-06
Step 2000, Loss: 9.444, Loss_eval: 9.482, Learning Rate: 2.500000e-06
Step 2500, Loss: 9.373, Loss_eval: 9.383, Learning Rate: 3.100000e-06
Step 3000, Loss: 9.250, Loss_eval: 9.305, Learning Rate: 3.750000e-06
Step 3500, Loss: 9.233, Loss_eval: 9.230, Learning Rate: 4.350000e-06
Step 4000, Loss: 9.156, Loss_eval: 9.163, Learning Rate: 5.000000e-06
Step 4500, Loss: 9.034, Loss_eval: 9.072, Learning Rate: 5.600000e-06
Step 5000, Loss: 8.974, Loss_eval: 8.964, Learning Rate: 6.250000e-06
Step 5500, Loss: 8.841, Loss_eval: 8.866, Learning Rate: 6.850000e-06
Step 6000, Loss: 8.749, Loss_eval: 8.748, Learning Rate: 7.500000e-06
Step 6500, Loss: 8.605, Loss_eval: 8.636, Learning Rate: 8.100000e-06
Step 7000, Loss: 8.485, Loss_eval: 8.504, Learning Rate: 8.750000e-06
Step 7500, Loss: 8.

In [None]:
# Second pass over the training data; checkpoints go to a separate file.
filename = "../../models/checkpoint_transformer_no_regularization_2epoch.pth"

model.train()
device = next(model.parameters()).device
accum_steps = 40   # passed to train_step (presumably the gradient-accumulation window)
eval_every = 500   # micro-batch steps between loss logs
save_every = 5000  # micro-batch steps between checkpoint saves
eval_batches = 40  # test batches averaged per evaluation

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    # Start each epoch with clean gradients.
    # NOTE(review): presumably train_step steps the optimizer only every
    # accum_steps micro-batches. If so, a partial window left at the end of an
    # epoch would otherwise leak into the next epoch's first update.
    optimizer.zero_grad()
    for step, batch in enumerate(tqdm(loader_train, desc="Training")):
        batch = batch.to(device)
        # NOTE(review): presumably the loss of this micro-batch only, so the
        # logged training loss is noisy compared with the averaged test loss.
        loss_train = train_step(model,
                                batch,
                                criterion,
                                optimizer,
                                scaler,
                                scheduler,
                                accum_steps,
                                step).item()

        if (step+1) % eval_every == 0:
            lr = scheduler.get_last_lr()[0]
            model.eval()
            try:
                with torch.no_grad():
                    # A fresh iterator over the shuffled test loader gives a new
                    # random sample each time. islice stops early instead of
                    # raising StopIteration if the loader has fewer batches.
                    loss_test = np.mean([forward_and_loss(model, test_batch.to(device), criterion).item()
                                         for test_batch in islice(loader_test, eval_batches)])
            finally:
                model.train()  # restore training mode even if evaluation fails
            print(f"Step {step+1}, Loss: {loss_train:.3f}, Loss_eval: {loss_test:.3f}, Learning Rate: {lr:.3e}")

            loss_train_list.append(loss_train)
            loss_test_list.append(loss_test)

        if (step+1) % save_every == 0:
            save_checkpoint(model,
                            optimizer,
                            scheduler,
                            loss_train_list,
                            loss_test_list,
                            filename=filename)

    save_checkpoint(model,
                    optimizer,
                    scheduler,
                    loss_train_list,
                    loss_test_list,
                    filename=filename)

Epoch 1/1


Training:   0%|          | 0/175316 [00:00<?, ?it/s]

Step 500, Loss: 4.177, Loss_eval: 3.956, Learning Rate: 5.000000e-05
Step 1000, Loss: 3.959, Loss_eval: 3.962, Learning Rate: 5.000000e-05
Step 1500, Loss: 3.600, Loss_eval: 3.953, Learning Rate: 5.000000e-05
Step 2000, Loss: 3.891, Loss_eval: 3.959, Learning Rate: 5.000000e-05
Step 2500, Loss: 4.048, Loss_eval: 3.919, Learning Rate: 5.000000e-05
Step 3000, Loss: 4.040, Loss_eval: 3.975, Learning Rate: 5.000000e-05
Step 3500, Loss: 3.719, Loss_eval: 3.919, Learning Rate: 5.000000e-05
Step 4000, Loss: 3.706, Loss_eval: 3.912, Learning Rate: 5.000000e-05
Step 4500, Loss: 4.089, Loss_eval: 3.885, Learning Rate: 5.000000e-05
Step 5000, Loss: 3.784, Loss_eval: 3.937, Learning Rate: 5.000000e-05
Step 5500, Loss: 3.968, Loss_eval: 3.974, Learning Rate: 5.000000e-05
Step 6000, Loss: 4.275, Loss_eval: 4.011, Learning Rate: 5.000000e-05
Step 6500, Loss: 3.970, Loss_eval: 3.945, Learning Rate: 5.000000e-05
Step 7000, Loss: 4.137, Loss_eval: 3.901, Learning Rate: 5.000000e-05
Step 7500, Loss: 3.78

In [None]:
# Third pass over the training data; checkpoints go to a separate file.
filename = "../../models/checkpoint_transformer_no_regularization_3epoch.pth"

model.train()
device = next(model.parameters()).device
accum_steps = 40   # passed to train_step (presumably the gradient-accumulation window)
eval_every = 500   # micro-batch steps between loss logs
save_every = 5000  # micro-batch steps between checkpoint saves
eval_batches = 40  # test batches averaged per evaluation

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    # Start each epoch with clean gradients.
    # NOTE(review): presumably train_step steps the optimizer only every
    # accum_steps micro-batches. If so, a partial window left at the end of an
    # epoch would otherwise leak into the next epoch's first update.
    optimizer.zero_grad()
    for step, batch in enumerate(tqdm(loader_train, desc="Training")):
        batch = batch.to(device)
        # NOTE(review): presumably the loss of this micro-batch only, so the
        # logged training loss is noisy compared with the averaged test loss.
        loss_train = train_step(model,
                                batch,
                                criterion,
                                optimizer,
                                scaler,
                                scheduler,
                                accum_steps,
                                step).item()

        if (step+1) % eval_every == 0:
            lr = scheduler.get_last_lr()[0]
            model.eval()
            try:
                with torch.no_grad():
                    # A fresh iterator over the shuffled test loader gives a new
                    # random sample each time. islice stops early instead of
                    # raising StopIteration if the loader has fewer batches.
                    loss_test = np.mean([forward_and_loss(model, test_batch.to(device), criterion).item()
                                         for test_batch in islice(loader_test, eval_batches)])
            finally:
                model.train()  # restore training mode even if evaluation fails
            print(f"Step {step+1}, Loss: {loss_train:.3f}, Loss_eval: {loss_test:.3f}, Learning Rate: {lr:.3e}")

            loss_train_list.append(loss_train)
            loss_test_list.append(loss_test)

        if (step+1) % save_every == 0:
            save_checkpoint(model,
                            optimizer,
                            scheduler,
                            loss_train_list,
                            loss_test_list,
                            filename=filename)

    save_checkpoint(model,
                    optimizer,
                    scheduler,
                    loss_train_list,
                    loss_test_list,
                    filename=filename)

Epoch 1/1


Training:   0%|          | 0/175316 [00:00<?, ?it/s]

Step 500, Loss: 3.303, Loss_eval: 3.555, Learning Rate: 5.000000e-05
Step 1000, Loss: 3.390, Loss_eval: 3.495, Learning Rate: 5.000000e-05
Step 1500, Loss: 3.425, Loss_eval: 3.539, Learning Rate: 5.000000e-05
Step 2000, Loss: 3.247, Loss_eval: 3.494, Learning Rate: 5.000000e-05
Step 2500, Loss: 3.468, Loss_eval: 3.560, Learning Rate: 5.000000e-05
Step 3000, Loss: 3.184, Loss_eval: 3.509, Learning Rate: 5.000000e-05
Step 3500, Loss: 3.779, Loss_eval: 3.531, Learning Rate: 5.000000e-05
Step 4000, Loss: 3.466, Loss_eval: 3.503, Learning Rate: 5.000000e-05
Step 4500, Loss: 3.421, Loss_eval: 3.492, Learning Rate: 5.000000e-05
Step 5000, Loss: 3.500, Loss_eval: 3.517, Learning Rate: 5.000000e-05
Step 5500, Loss: 3.399, Loss_eval: 3.472, Learning Rate: 5.000000e-05
Step 6000, Loss: 3.729, Loss_eval: 3.494, Learning Rate: 5.000000e-05
Step 6500, Loss: 3.541, Loss_eval: 3.496, Learning Rate: 5.000000e-05
Step 7000, Loss: 3.388, Loss_eval: 3.539, Learning Rate: 5.000000e-05
Step 7500, Loss: 3.48