# Description

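In this notebook, I visualize the architecture of a PyTorch Llama 3.2 model: a text-based layer summary, a graph of the full forward pass, and a closer look at the first Transformer block.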

In [None]:
import os
# Hide every CUDA device so PyTorch runs on the CPU. This is set before
# `import torch` so the setting is already in place when torch looks for GPUs.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import sys 
import time
import torch
# Fixed seed so the random dummy inputs created later in the notebook are reproducible.
torch.manual_seed(123)
import torch.nn as nn

# Project-local modules (presumably model.py / tokenizer.py next to this notebook — verify).
# NOTE(review): generate, text_to_token_ids, token_ids_to_text, the tokenizer names,
# compute_rope_params, `time` and `make_dot` are not used by the cells in this notebook.
from model import Llama3Model, generate, text_to_token_ids, token_ids_to_text
from tokenizer import Llama3Tokenizer, ChatFormat, clean_text
from model import LLAMA32_CONFIG_1B, LLAMA32_CONFIG_3B
from model import compute_rope_params

# Visualization tools: torchinfo (text layer summary), torchviz (autograd graph),
# torchview (module-level forward graph).
from torchinfo import summary
from torchviz import make_dot
from torchview import draw_graph

In [None]:
# ===== Hyper-parameters =====
# Checkpoint to load; its file name ("1B" / "3B") decides which config is used.
MODEL_FILE = "model/llama3.2-1B-base.pth"
# Context window the model is built with (support goes up to 131_072).
MODEL_CONTEXT_LENGTH = 8_192

# Text-generation settings and tokenizer path.
# NOTE(review): not used by the visualization cells in this notebook.
MAX_NEW_TOKENS = 100
TEMPERATURE = 0.0
TOP_K = 1
TOKENIZER_FILE = "tokenizer.model"

# Run on the CPU and make it the default device for newly created tensors.
device = "cpu"
torch.set_default_device(device)

# 1. Layer summary (text-based)

In [None]:
# Load model.
# Misconfiguration raises instead of calling `sys.exit(0)`: exit status 0 means
# "success", and an exception stops "Run All" with a readable message.
if not os.path.exists(MODEL_FILE):
    raise FileNotFoundError(f"[ERROR] Model does not exist: {MODEL_FILE}")

# Pick the architecture config from the checkpoint name. Copy it so that the
# context-length override below does not mutate the dict imported from `model`.
if "1B" in MODEL_FILE:
    llama32_config = dict(LLAMA32_CONFIG_1B)
elif "3B" in MODEL_FILE:
    llama32_config = dict(LLAMA32_CONFIG_3B)
else:
    raise ValueError(
        f"[ERROR] Cannot infer model size (1B/3B) from file name: {MODEL_FILE}"
    )

llama32_config["context_length"] = MODEL_CONTEXT_LENGTH

model = Llama3Model(llama32_config)
checkpoint = torch.load(MODEL_FILE, weights_only=True, map_location=device)
model.load_state_dict(checkpoint)
model.to(device)
model.eval()  # this notebook only runs inference / tracing

# In-memory size of the parameters, using each tensor's real element size
# (e.g. 2 bytes for bfloat16) instead of assuming 4-byte float32.
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())

print(f"Model loaded successfully from {MODEL_FILE}")
print(f"Model size (in GB): {param_bytes / 1024**3:.2f}")

In [None]:
# Random token ids in [0, vocab_size) forming a (batch_size, input_seq_len) batch.
batch_size, input_seq_len = 1, 16
vocab_size = model.cfg["vocab_size"]
token_ids_shape = (batch_size, input_seq_len)
dummy_input = torch.randint(0, vocab_size, token_ids_shape, dtype=torch.int32)
dummy_input = dummy_input.to(device)

print(f"Shape of dummy input: {dummy_input.size()}")

In [None]:
# Per-layer table of input/output shapes and parameter counts. torchinfo builds
# its own dummy input from `input_size` and `dtypes`.
summary(
    model,
    input_size=(batch_size, input_seq_len),
    dtypes=[torch.int32],
    col_names=["input_size", "output_size", "num_params"],
)

# 2. Graph visualization

In [None]:
# Fresh random token-id batch for the graph cells of this section.
# NOTE(review): same recipe as the section-1 cell; kept so this section runs on its own.
batch_size = 1
input_seq_len = 16
vocab_size = model.cfg["vocab_size"]
dummy_input = torch.randint(
    low=0,
    high=vocab_size,
    size=(batch_size, input_seq_len),
    dtype=torch.int32,
).to(device)

print(f"Shape of dummy input: {dummy_input.shape}")
print(dummy_input.device)

In [None]:
# One full forward pass on the dummy token ids, just to check the output shape.
y = model(dummy_input)
print(f"Shape of model output: {y.size()}")

In [None]:
# Module-level forward graph of the whole model, rendered to visualization/.
graph = draw_graph(
    model,
    input_data=dummy_input,
    filename="visualization/llama_forward",
    save_graph=True,     # write the rendered graph to disk
    depth=1,             # expansion depth; raise it to see inside nested modules
    expand_nested=True,  # draw boundaries around nested modules
)

# 3. Inspecting the first Transformer block (weights, forward pass and graph)

In [None]:
# Take the first block of `model.trf_blocks` and list each learnable tensor with its shape.
transformer_block_0 = model.trf_blocks[0]
block_0_params = transformer_block_0.named_parameters()
for idx, (name, param) in enumerate(block_0_params):
    print("{}: {} | Size: {}".format(idx, name, param.size()))

In [None]:
# Run the first Transformer block on its own. It is fed hidden states rather
# than token ids, so build a random activation tensor of shape
# (batch, seq_len, emb_dim) in the same dtype as the block's weights.
batch_size = 1
input_seq_len = 16
# Read from the config instead of hard-coding 2048, so it follows whichever
# config (1B or 3B) was loaded.
embed_dim = model.cfg["emb_dim"]
head_dim = embed_dim // model.cfg["n_heads"]
block_dtype = next(transformer_block_0.parameters()).dtype

# randn instead of randint(0, vocab_size, dtype=bfloat16): bf16 only represents
# integers exactly up to 2**8, and token ids are not meaningful activations.
dummy_input = torch.randn(batch_size, input_seq_len, embed_dim,
                          device=device, dtype=block_dtype)
print(f"Shape of dummy input: {dummy_input.shape}")

# True strictly above the diagonal; presumably the block masks out positions
# where the mask is True (causal attention) — confirm in model.py.
mask = torch.triu(torch.ones(input_seq_len, input_seq_len, device=device,
                             dtype=torch.bool), diagonal=1)

# Identity RoPE tables (rotation angle 0: cos=1, sin=0). Assuming the block
# applies RoPE as x*cos + rotate(x)*sin, this leaves q/k unchanged, whereas
# all-zero cos/sin would zero q and k entirely.
# NOTE(review): `compute_rope_params` (imported) could supply real tables.
cos = torch.ones((input_seq_len, head_dim), device=device, dtype=block_dtype)
sin = torch.zeros((input_seq_len, head_dim), device=device, dtype=block_dtype)

output_transformer_block = transformer_block_0(dummy_input, mask, cos, sin)
print(f"Shape of transformer output: {output_transformer_block.shape}")

In [None]:
class BlockWrapper(nn.Module):
    """Thin nn.Module shell around a single Transformer block.

    Its forward takes (x, mask, cos, sin) positionally and hands them to the
    wrapped block unchanged, returning whatever the block returns. Used in
    this notebook as the root module when drawing one block with torchview.
    """

    def __init__(self, transformer_block):
        super(BlockWrapper, self).__init__()
        # Registered as the child module named "block".
        self.block = transformer_block

    def forward(self, x, mask, cos, sin):
        inner_block = self.block
        return inner_block(x, mask, cos, sin)

In [None]:
# Wrap the first block so torchview can trace it with its four positional inputs.
wrapper = BlockWrapper(transformer_block_0)

# Draw the wrapped block's forward graph two levels deep and save it to disk.
block_inputs = (dummy_input, mask, cos, sin)
graph = draw_graph(
    wrapper,
    input_data=block_inputs,
    depth=2,
    expand_nested=True,
    save_graph=True,
    filename="visualization/transformer_block",
)