## LLM

In [1]:
import torch as torch
import numpy as np
import pickle as pkl
from tqdm.notebook import tqdm
from src.transformer import Transformer
from src.optimization import train_step, forward_and_loss, group_decay_parameters, save_checkpoint, load_checkpoint
from src.utils import saver, loader
from torch.utils.data import TensorDataset, DataLoader
from IPython.display import clear_output

# Report the PyTorch build, the CUDA toolkit it was compiled against, the cuDNN
# version, and whether a CUDA device is usable in this session.
print("PyTorch version:", torch.__version__)  
print("CUDA toolkit version PyTorch was built with:", torch.version.cuda)  
print("cuDNN version:", torch.backends.cudnn.version()) 
print("cuda available:", torch.cuda.is_available())

# Run on the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 'high' allows float32 matmuls to use faster reduced-precision kernels
# (see the torch.set_float32_matmul_precision documentation).
torch.set_float32_matmul_precision('high')

PyTorch version: 2.7.1+cu128
CUDA toolkit version PyTorch was built with: 12.8
cuDNN version: 90701
cuda available: True


## Load Tokenizer

In [2]:
# Load the tokenizer object from disk with the project's `loader` helper.
# NOTE(review): the .pkl extension suggests pickle deserialization — only load
# files from a trusted source; confirm against src.utils.loader.
tokenizer = loader("tokenizers/cnn_tokenizer.pkl")

## Initialize Model

In [3]:
# Seed torch's global RNG before the model is constructed.
torch.manual_seed(42)

# Architecture hyperparameters.
embed_dim = 64*18  # 1152, i.e. 64 dimensions for each of the 18 heads below
ff_dim = 4*embed_dim  # feed-forward width is 4x the embedding width (4608)
heads = 18
tf_blocks = 18  # presumably the number of stacked transformer blocks — confirm in src.transformer

# Build the model and move it to the selected device. The meaning of each
# keyword is defined by src.transformer.Transformer, which is not visible here.
model = Transformer(
    embed_dim=embed_dim,
    ff_dim=ff_dim,
    heads=heads,
    tf_blocks=tf_blocks,
    vocab_size=tokenizer.vocab_size,
    max_seq_len=1024,
    dropout=0.1,
    start_token_id=tokenizer.token_to_idx["<s>"],
    use_weight_tying=True
).to(device)

# Parameter groups for the optimizer. Presumably parameters whose names match
# an entry in `no_decay` are excluded from weight decay — confirm in
# src.optimization.group_decay_parameters.
optimizer_grouped_parameters = group_decay_parameters(
    model,
    weight_decay=0.1,
    no_decay=["bias", "LayerNorm.weight"],
    )

# Loss histories; nothing in the visible cells appends to them.
loss_train_list = []
loss_eval_list = []

In [4]:
# Loss function, optimizer (base learning rate 5e-5) and AMP gradient scaler.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=5e-5)
scaler = torch.amp.GradScaler("cuda")

# Schedule constants. `warmup_steps` is read by `lr_lambda` below;
# `num_epochs` and `steps_per_epoch` are not used in the visible cells.
num_epochs      = 1
steps_per_epoch = 1
warmup_steps    = 1000

def lr_lambda(step):
    """Return the learning-rate multiplier for scheduler step `step`.

    While `step` is below the module-level `warmup_steps`, the multiplier is
    `step / warmup_steps` (the denominator is floored at 1 so it is never
    zero); from `warmup_steps` onwards it is a constant 1.0.
    """
    in_warmup = step < warmup_steps
    return float(step) / float(max(1, warmup_steps)) if in_warmup else 1.0

# LambdaLR scales the optimizer's base learning rate by lr_lambda(step).
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

In [6]:
# Rebind model, optimizer and scheduler to whatever load_checkpoint returns for
# the saved checkpoint file. What exactly is restored is defined in
# src.optimization.load_checkpoint, which is not visible here.
model, optimizer, scheduler = load_checkpoint("models/checkpoint_transformer.pth", model, optimizer, scheduler)

In [7]:
import textwrap
import ipywidgets as widgets
from IPython.display import display
from torch.distributions import Categorical

class Inference:
    """Autoregressive text generation with temperature and top-k sampling.

    Attributes:
        model: Callable taking a ``(1, seq)`` LongTensor of token ids. It is
            indexed as ``model(tokens)[0, -1:]``, so it is assumed to return
            logits shaped ``(batch, seq, vocab)`` — confirm against the model.
        tokenizer: Object providing ``encode(str)``, ``decode(list)`` and a
            ``token_to_idx`` mapping that contains ``"</s>"``.
        context_length: Maximum number of tokens generated per ``run`` call.
        device: Device the prompt tensor is moved to.
    """

    def __init__(self, model, tokenizer, context_length, device):
        self.model = model
        self.tokenizer = tokenizer
        self.context_length = context_length
        self.device = device

    def run(self, text, T, k, mode=None):
        """Generate a continuation of ``text`` and stream it to a ``Display``.

        Args:
            text: Prompt string. It is lower-cased before being encoded.
            T: Sampling temperature; must be > 0.
            k: Number of highest-scoring tokens kept for sampling; must be >= 1.
            mode: ``"summary"`` wraps the prompt as ``<s><b>text<h>``,
                ``"expand"`` wraps it as ``<s><h>text<b>``; any other value
                leaves the prompt unchanged.

        Generation stops after ``context_length`` new tokens or as soon as the
        ``"</s>"`` token is sampled. The display widget is stored on
        ``self.display``; nothing is returned.
        """
        if mode == "summary":
            text = "<s><b>" + text + "<h>"
        elif mode == "expand":
            text = "<s><h>" + text + "<b>"

        token_ids = self.tokenizer.encode(text.lower())
        tokens = torch.tensor(token_ids, dtype=torch.long).reshape(1, -1).to(self.device)
        end_id = self.tokenizer.token_to_idx["</s>"]

        self.display = Display()

        # Use the instance's own model/tokenizer rather than notebook globals,
        # so the object behaves according to what it was constructed with.
        self.model.eval()
        with torch.no_grad():
            # NOTE(review): the sequence grows to prompt_len + context_length
            # tokens; confirm the model accepts inputs of that length.
            for _ in range(self.context_length):
                next_id = self.next_token(tokens, T, k)

                tokens = torch.cat([tokens, next_id.reshape(1, 1)], dim=1)
                # Re-decode the whole sequence so the widget shows the full text.
                self.display.update(self.tokenizer.decode(tokens[0].tolist()))

                if next_id.item() == end_id:
                    break

    def next_token(self, tokens, T, k):
        """Sample the next token id from the model's last-position logits.

        Args:
            tokens: ``(1, seq)`` LongTensor of token ids generated so far.
            T: Sampling temperature; must be > 0.
            k: Top-k cutoff; must be >= 1. Values larger than the vocabulary
                size are clamped to it.

        Returns:
            A tensor of shape ``(1,)`` holding the sampled token id.

        Raises:
            ValueError: If ``T <= 0`` or ``k < 1``.
        """
        if T <= 0:
            raise ValueError(f"temperature T must be > 0, got {T}")
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")

        logits = self.model(tokens)[0, -1:]

        # torch.topk raises when k exceeds the size of the selected dimension.
        k = min(int(k), logits.size(-1))
        topk_vals, _ = torch.topk(logits, k=k)
        kth_value = topk_vals[:, -1:]

        # Everything below the k-th largest logit gets zero probability.
        logits = torch.where(logits >= kth_value, logits, -torch.inf)
        return Categorical(logits=logits / T).sample()


class Display:
    """Read-only text area widget that shows text wrapped to 80 columns."""

    def __init__(self):
        """Create the wrapper and the disabled Textarea, then render the widget."""
        self.wrapper = textwrap.TextWrapper(width=80)

        box_layout = widgets.Layout(width='80ch', height='20em')
        self.ta = widgets.Textarea(value="", layout=box_layout, disabled=True)
        display(self.ta)

    def update(self, text):
        """Replace the widget contents with `text`, newlines flattened and re-wrapped."""
        flattened = text.replace("\n", " ")
        # Assigning to .value refreshes the already-rendered widget in place.
        self.ta.value = self.wrapper.fill(flattened)


In [8]:
# Generation helper; context_length caps the number of new tokens per run() call.
inference = Inference(model, tokenizer, context_length=1024, device=device)

In [9]:
# The prompt is wrapped by hand in the same <s><h>...<b> markers that
# Inference.run adds for mode="expand", so run() is called without a mode.
text = "<s><h>A magical horse was spotted in England. It was found in a field. The magical horse is able to breath fire. Scientists don't know where it came from.<b>"

T = 1  # sampling temperature (logits are divided by T)
k = 50  # sample only among the 50 highest-scoring tokens
inference.run(text, T, k)

Textarea(value='', disabled=True, layout=Layout(height='20em', width='80ch'))

KeyboardInterrupt: 