## LLM: training and sampling a Transformer language model on CNN/DailyMail articles

In [1]:
import itertools
import os
import pickle as pkl
import textwrap

import ipywidgets as widgets
import numpy as np
import torch as torch
from IPython.display import clear_output, display
from torch.distributions import Categorical
from torch.utils.data import TensorDataset, DataLoader
from tqdm.notebook import tqdm

from src.optimization import train_step, forward_and_loss, group_decay_parameters, save_checkpoint, load_checkpoint
from src.transformer import Transformer
from src.utils import saver, loader

# Report which torch / CUDA / cuDNN build is in use, then pick the compute device.
for label, value in (
    ("PyTorch version:", torch.__version__),
    ("CUDA toolkit version PyTorch was built with:", torch.version.cuda),
    ("cuDNN version:", torch.backends.cudnn.version()),
):
    print(label, value)

# First GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

PyTorch version: 2.7.1+cu128
CUDA toolkit version PyTorch was built with: 12.8
cuDNN version: 90701


## Load Data

In [2]:
# NOTE(review): `loader` comes from src.utils; if it unpickles, only point it at trusted files.
tokenizer_file = "cnn_tokenizer.pkl"
tokenizer = loader(tokenizer_file)

In [10]:
# The training token stream is stored as four pickled shards; load and concatenate them.
# NOTE(review): `loader` comes from src.utils; if it unpickles, only load trusted files.
train_shards = [
    torch.tensor(loader(f"corpus/cnn_dailymail_article_train_tokens{shard}.pkl"))
    for shard in range(1, 5)
]
corpus_train = torch.cat(train_shards, dim=0)
# torch.cat copies; drop the per-shard tensors so the corpus is not held in memory twice.
del train_shards

corpus_test = torch.tensor(loader("corpus/cnn_dailymail_article_test_tokens.pkl"))

In [11]:
def batch_data(corpus, batch_length=1024, offset=None):
    """
    Reshape a 1-D token sequence into rows of ``batch_length`` consecutive tokens.

    Tokens after the last complete row are discarded.

    Args:
        corpus: 1-D tensor of token ids (the slicing/reshaping below assumes 1-D).
        batch_length: number of tokens per row.
        offset: optional shift in ``[0, batch_length)``. When given, a second set of
            rows starting ``offset`` tokens into the truncated corpus is appended, so
            its row boundaries fall inside the first set's rows (overlap augmentation).
            The shifted set has ``max(rows - 1, 0)`` rows.

    Returns:
        2-D tensor of shape ``(n, batch_length)``: the aligned rows, followed by the
        shifted rows when ``offset`` is not None.

    Raises:
        ValueError: if ``offset`` is outside ``[0, batch_length)``. Such values would
            otherwise silently yield no shifted rows (the slice comes out empty).
    """
    n_rows = len(corpus) // batch_length
    corpus_truncated = corpus[: n_rows * batch_length]  # trim to a multiple of batch_length
    corpus_batched = corpus_truncated.reshape(-1, batch_length)

    # overlapping batches augmentation
    if offset is not None:
        if not 0 <= offset < batch_length:
            raise ValueError(
                f"offset must satisfy 0 <= offset < batch_length ({batch_length}), got {offset}"
            )
        # The last shifted window would run past the truncated corpus, so only
        # n_rows - 1 shifted windows are complete.
        n_offset_rows = max(n_rows - 1, 0)
        corpus_offset = corpus_truncated[offset : offset + n_offset_rows * batch_length]
        corpus_batched = torch.cat(
            (corpus_batched, corpus_offset.reshape(-1, batch_length)), dim=0
        )

    return corpus_batched

In [12]:
# Cut both token streams into fixed-length rows. The training set also gets a second copy
# shifted by half a row, so row boundaries land in different places.
seq_len = 1024
corpus_train_batched = batch_data(corpus_train, batch_length=seq_len, offset=seq_len // 2)
corpus_test_batched = batch_data(corpus_test, batch_length=seq_len, offset=None)

In [13]:
# Mini-batches of 3 rows each. The training cell passes `accum_steps` to train_step,
# presumably to accumulate several of these into one optimizer update.
loader_train = DataLoader(
    corpus_train_batched,
    batch_size=3,
    shuffle=True,       # reshuffle every epoch
    drop_last=True,     # drop the last incomplete batch
)

loader_test = DataLoader(
    corpus_test_batched,
    batch_size=3,
    # Deliberately shuffled: each evaluation in the training loop reads only the first few
    # batches, so shuffling makes every evaluation draw a fresh random subset.
    shuffle=True,
    drop_last=True,     # drop the last incomplete batch
)

## Initialize Model

In [14]:
# Seed torch's global RNG before any parameters are created.
torch.manual_seed(42)

# Architecture hyper-parameters. embed_dim is 64 per head (presumably head width = embed_dim / heads).
heads = 18
tf_blocks = 18
embed_dim = 64 * heads
ff_dim = 4 * embed_dim

model_config = dict(
    embed_dim=embed_dim,
    ff_dim=ff_dim,
    heads=heads,
    tf_blocks=tf_blocks,
    vocab_size=tokenizer.vocab_size,
    max_seq_len=1024,
    dropout=0.1,
    start_token_id=tokenizer.token_to_idx["<s>"],
    use_weight_tying=True,
)
model = Transformer(**model_config).to(device)

# NOTE(review): `no_decay` entries are presumably matched against parameter names by
# group_decay_parameters; confirm the Transformer's layer-norm weights really contain
# "LayerNorm.weight", otherwise they will receive weight decay.
optimizer_grouped_parameters = group_decay_parameters(
    model,
    weight_decay=0.1,
    no_decay=["bias", "LayerNorm.weight"],
)

In [15]:
# Run length and warm-up schedule settings.
num_epochs = 1
warmup_steps = 1000
steps_per_epoch = len(loader_train)  # batches per epoch (not referenced by the loop below)
loss_train = []  # NOTE(review): not referenced by the training loop below

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=5e-5)
# Mixed-precision loss scaler; torch disables it (with a warning) when CUDA is unavailable.
scaler = torch.amp.GradScaler("cuda")


def lr_lambda(step):
    """LambdaLR multiplier: linear ramp from 0 to 1 over `warmup_steps`, then constant 1."""
    if step >= warmup_steps:
        return 1.0
    return float(step) / float(max(1, warmup_steps))


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

In [16]:
# Resume from a saved checkpoint when one exists. On a fresh run (no file yet), keep the
# newly initialized model/optimizer/scheduler instead of failing.
# NOTE(review): `scaler` is not passed to load_checkpoint, so its state is not restored.
checkpoint_path = "checkpoint_transformer.pth"
if os.path.exists(checkpoint_path):
    model, optimizer, scheduler = load_checkpoint(checkpoint_path, model, optimizer, scheduler)
else:
    print(f"No checkpoint found at {checkpoint_path}; starting from freshly initialized weights.")

## Train Loop

In [17]:
# Training loop. Each micro-batch goes through train_step, which receives `accum_steps` and
# the micro-batch index (presumably to step the optimizer once per `accum_steps` micro-batches).
# A held-out loss estimate is printed every `eval_every` micro-batches, and checkpoints are saved periodically.
accum_steps = 40                  # micro-batches per optimizer update, passed to train_step
eval_every = 100                  # micro-batches between evaluation printouts
eval_batches = accum_steps        # test batches averaged per evaluation
checkpoint_every = 5000           # micro-batches between checkpoint saves
checkpoint_path = "checkpoint_transformer.pth"

optimizer.zero_grad()
model.train()
device = next(model.parameters()).device

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    for step, batch in enumerate(tqdm(loader_train, desc="Training")):
        batch = batch.to(device)
        loss = train_step(model,
                          batch,
                          criterion,
                          optimizer,
                          scaler,
                          scheduler,
                          accum_steps,
                          step)

        if (step + 1) % eval_every == 0:
            model.eval()  # disable dropout for the held-out estimate
            lr = scheduler.get_last_lr()[0]
            with torch.no_grad():
                # islice stops early instead of raising StopIteration mid-training when the
                # test loader has fewer than `eval_batches` batches.
                loss_eval = [
                    forward_and_loss(model, test_batch.to(device), criterion).item()
                    for test_batch in itertools.islice(loader_test, eval_batches)
                ]
            eval_str = f"{np.mean(loss_eval):<.4f}" if loss_eval else "n/a"
            print(f"Step {step+1}, Loss: {loss.item():<.4f}, Loss_eval: {eval_str}, Learning Rate: {lr:4e}")
            model.train()

        if (step + 1) % checkpoint_every == 0:
            save_checkpoint(model,
                            optimizer,
                            scheduler,
                            filename=checkpoint_path)

# Final save, so the micro-batches after the last periodic checkpoint are not lost.
# NOTE(review): if len(loader_train) is not a multiple of accum_steps, the trailing partial
# accumulation may never reach the optimizer; this depends on train_step, so confirm.
save_checkpoint(model, optimizer, scheduler, filename=checkpoint_path)


Epoch 1/1


Training:   0%|          | 0/175315 [00:00<?, ?it/s]

Step 100, Loss: 3.4477, Loss_eval: 3.5698, Learning Rate: 5.000000e-05
Step 200, Loss: 3.0590, Loss_eval: 3.5081, Learning Rate: 5.000000e-05
Step 300, Loss: 3.4448, Loss_eval: 3.5213, Learning Rate: 5.000000e-05
Step 400, Loss: 3.2959, Loss_eval: 3.5556, Learning Rate: 5.000000e-05
Step 500, Loss: 3.4306, Loss_eval: 3.5298, Learning Rate: 5.000000e-05
Step 600, Loss: 3.0973, Loss_eval: 3.5499, Learning Rate: 5.000000e-05
Step 700, Loss: 3.1289, Loss_eval: 3.5220, Learning Rate: 5.000000e-05
Step 800, Loss: 3.3314, Loss_eval: 3.5511, Learning Rate: 5.000000e-05
Step 900, Loss: 3.3305, Loss_eval: 3.5418, Learning Rate: 5.000000e-05
Step 1000, Loss: 3.1556, Loss_eval: 3.5685, Learning Rate: 5.000000e-05
Step 1100, Loss: 3.2533, Loss_eval: 3.5340, Learning Rate: 5.000000e-05
Step 1200, Loss: 3.3586, Loss_eval: 3.5337, Learning Rate: 5.000000e-05
Step 1300, Loss: 3.5454, Loss_eval: 3.5311, Learning Rate: 5.000000e-05
Step 1400, Loss: 3.2435, Loss_eval: 3.5580, Learning Rate: 5.000000e-05
S

In [27]:
text = "  <s><h>"  # prompt; lower-cased before encoding
prompt_ids = tokenizer.encode(text.lower())
tokens = torch.tensor(prompt_ids, dtype=torch.long).view(1, -1).to(device)  # (1, prompt_len)

In [28]:
# Autoregressive top-k sampling from the trained model, streamed into a read-only text area.
# NOTE: `tokens` keeps growing across runs of this cell; re-run the prompt cell to start over.
wrapper = textwrap.TextWrapper(width=80)

# create a read-only text area
ta = widgets.Textarea(
    value="",
    layout=widgets.Layout(width='80ch', height='20em'),
    disabled=True
)
display(ta)

T = 1                  # sampling temperature (must be > 0)
k = 50                 # keep only the k highest logits at each step
max_new_tokens = 1024  # upper bound on the number of sampled tokens
# Longest input fed to the model. This matches max_seq_len=1024 given to Transformer and the
# 1024-token training rows; longer inputs presumably exceed the model's positional range.
context_len = 1024
end_id = tokenizer.token_to_idx["</s>"]

model.eval()           # the training cell leaves the model in train mode (dropout active)
with torch.no_grad():  # sampling needs no autograd graph
    for i in range(max_new_tokens):
        # presumably model(...) returns (batch, seq, vocab) logits; keep the last position
        logits = model(tokens[:, -context_len:])[0, -1:]
        topk_vals, _ = torch.topk(logits, k=min(k, logits.size(-1)))
        kth_value = topk_vals[:, -1:]  # keep the dim so it broadcasts over the vocab axis

        # mask everything below the k-th largest logit, then sample at temperature T
        logits = torch.where(logits >= kth_value, logits, -torch.inf)
        dist = Categorical(logits=logits / T)
        idx = dist.sample()
        tokens = torch.cat([tokens, idx.reshape(1, 1)], dim=1)

        text = tokenizer.decode(tokens[0].tolist())
        ta.value = wrapper.fill(text.replace("\n", " "))  # this updates in-place

        if idx[0] == end_id:
            break

Textarea(value='', disabled=True, layout=Layout(height='20em', width='80ch'))