# Description

In this notebook, I will visualize the architecture of Pytorch model. 

In [1]:
import os
import sys 
import time
import torch
torch.manual_seed(123)
import torch.nn as nn

from utils.model import Llama3Model, generate, text_to_token_ids, token_ids_to_text
from utils.tokenizer import Llama3Tokenizer, ChatFormat, clean_text
from utils.model import LLAMA32_CONFIG_1B, LLAMA32_CONFIG_3B
from utils.model import compute_rope_params

from torchinfo import summary

from utils.quantization import *

In [11]:
# ===== Hyper-parameter =====
MODEL_FILE = "model/llama3.2-1B-instruct.pth"
MODEL_CONTEXT_LENGTH = 8192  # Support up to 131_072
MAX_NEW_TOKENS = 100
TEMPERATURE = 0.
TOP_K = 1
TOKENIZER_FILE = "model/tokenizer.model"

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 1. Load model

In [3]:
# Load model
if os.path.exists(MODEL_FILE) == False:
    print(f"[ERROR] Model does not exist !!!")
    sys.exit(0)

if "1B" in MODEL_FILE:
    llama32_config = LLAMA32_CONFIG_1B
elif "3B" in MODEL_FILE:
    llama32_config = LLAMA32_CONFIG_3B
else:
    print(f"[ERROR] Check model file again !!!")
    sys.exit(0)

llama32_config["context_length"] = MODEL_CONTEXT_LENGTH

model = Llama3Model(llama32_config)
checkpoint = torch.load(MODEL_FILE, weights_only=True, map_location=device)
model.load_state_dict(checkpoint)
model.to(device)

print(f"Model loaded successfully from {MODEL_FILE}")
print(f"Model size (in GB): {sum(p.numel() for p in model.parameters())*4/1024/1024/1024:.2f}")

Model loaded successfully from model/llama3.2-1B-instruct.pth
Model size (in GB): 5.58


In [4]:
batch_size = 1
input_seq_len = 16
vocab_size = model.cfg["vocab_size"]
dummy_input = torch.randint(0, vocab_size, (batch_size, input_seq_len),\
                            dtype=torch.int32).to(device)

print(f"Shape of dummy input: {dummy_input.shape}")

Shape of dummy input: torch.Size([1, 16])


In [5]:
summary(model, input_size=(batch_size, input_seq_len), dtypes=[torch.int32],\
        col_names=["input_size", "output_size", "num_params"])

Layer (type:depth-idx)                        Input Shape               Output Shape              Param #
Llama3Model                                   [1, 16]                   [1, 16, 128256]           --
├─Embedding: 1-1                              [1, 16]                   [1, 16, 2048]             262,668,288
├─ModuleList: 1-2                             --                        --                        --
│    └─TransformerBlock: 2-1                  [1, 16, 2048]             [1, 16, 2048]             --
│    │    └─RMSNorm: 3-1                      [1, 16, 2048]             [1, 16, 2048]             2,048
│    │    └─GroupedQueryAttention: 3-2        [1, 16, 2048]             [1, 16, 2048]             10,485,760
│    │    └─RMSNorm: 3-3                      [1, 16, 2048]             [1, 16, 2048]             2,048
│    │    └─FeedForward: 3-4                  [1, 16, 2048]             [1, 16, 2048]             50,331,648
│    └─TransformerBlock: 2-2                  [1, 16, 2

# 2. Loading weight of first Transformer block

In [6]:
transformer_block_0 = model.trf_blocks[0]
for idx, (name, param) in enumerate(transformer_block_0.named_parameters()):
    print(f"{idx}: {name} | Size: {param.size()}")

0: att.W_key.weight | Size: torch.Size([512, 2048])
1: att.W_value.weight | Size: torch.Size([512, 2048])
2: att.W_query.weight | Size: torch.Size([2048, 2048])
3: att.out_proj.weight | Size: torch.Size([2048, 2048])
4: ff.fc1.weight | Size: torch.Size([8192, 2048])
5: ff.fc2.weight | Size: torch.Size([8192, 2048])
6: ff.fc3.weight | Size: torch.Size([2048, 8192])
7: norm1.weight | Size: torch.Size([2048])
8: norm2.weight | Size: torch.Size([2048])


## 2.1. Quantization single weight

In [7]:
matrix = transformer_block_0.att.W_query.weight.data

print(f"Shape of matrix: {matrix.shape}")
print(f"Dtype: {matrix.dtype}")
print(f"Min: {matrix.min().item()}, Max: {matrix.max().item()}")

Shape of matrix: torch.Size([2048, 2048])
Dtype: torch.bfloat16
Min: -0.67578125, Max: 0.58203125


In [8]:
num_bits = 8

scale, zero_point = compute_quantization_param(matrix, num_bits)
# print(f"Scale: {scale}, Zero point: {zero_point}")

matrix_quantized = quantize_tensor(matrix, scale, zero_point, num_bits)
print(f"Shape of matrix_quantized: {matrix_quantized.shape}")
print(f"Dtype: {matrix_quantized.dtype}")
print(f"Min: {matrix_quantized.min().item()}, Max: {matrix_quantized.max().item()}")

Shape of matrix_quantized: torch.Size([2048, 2048])
Dtype: torch.int8
Min: -128, Max: 126


In [9]:
matrix_dequantized = dequantize_tensor(matrix_quantized, scale, zero_point)
print(f"Shape of matrix_dequantized: {matrix_dequantized.shape}")
print(f"Dtype: {matrix_dequantized.dtype}")

mse_error = nn.MSELoss()(matrix, matrix_dequantized).item()
print(f"MSE error: {mse_error:.6f}")

Shape of matrix_dequantized: torch.Size([2048, 2048])
Dtype: torch.float32
MSE error: 0.000002


## 2.2. Quantization activation

In [32]:
tokenizer = Llama3Tokenizer(TOKENIZER_FILE)

def get_embedding_of_text(input_text:str, model,\
                          tokenizer, device:str) -> torch.Tensor:
    """
    Function to get the embedding output of input text
    """
    input_ids = text_to_token_ids(input_text, tokenizer)
    input_ids = torch.tensor(input_ids, dtype=torch.int32).to(device)
    with torch.no_grad():
        emb_output = model.tok_emb(input_ids)
    return emb_output

In [39]:
input_text = "What is the capital of VietNam?"
input_seq_len = len(text_to_token_ids(input_text, tokenizer))

input_embedding = get_embedding_of_text(input_text, model, tokenizer, device)
print(f"Shape of input_embedding: {input_embedding.shape}, Dtype: {input_embedding.dtype}")

Shape of input_embedding: torch.Size([1, 8, 2048]), Dtype: torch.bfloat16


  input_ids = torch.tensor(input_ids, dtype=torch.int32).to(device)


In [44]:
W_query = transformer_block_0.att.W_query.weight.data
print(f"Shape of W_query: {W_query.shape}, Dtype: {W_query.dtype}")

query_output = torch.matmul(input_embedding, W_query.T)
print(f"Shape of query_output: {query_output.shape}, Dtype: {query_output.dtype}")

Shape of W_query: torch.Size([2048, 2048]), Dtype: torch.bfloat16
Shape of query_output: torch.Size([1, 8, 2048]), Dtype: torch.bfloat16


In [47]:
scale, zero_point = compute_quantization_param(W_query, num_bits)
W_query_quantized = quantize_tensor(W_query, scale, zero_point, num_bits)
print(f"Shape of W_query_quantized: {W_query_quantized.shape},\
        Dtype: {W_query_quantized.dtype}")

W_query_dequantized = dequantize_tensor(W_query_quantized, scale, zero_point, output_dtype=W_query.dtype)
print(f"Shape of W_query_dequantized: {W_query_dequantized.shape},\
        Dtype: {W_query_dequantized.dtype}")

Shape of W_query_quantized: torch.Size([2048, 2048]),        Dtype: torch.int8
Shape of W_query_dequantized: torch.Size([2048, 2048]),        Dtype: torch.bfloat16


In [51]:
query_output_dequantized = torch.matmul(input_embedding, W_query_dequantized.T)
print(f"Shape of query_output_dequantized: {query_output_dequantized.shape}, Dtype: {query_output_dequantized.dtype}")

mse_error = nn.MSELoss()(query_output, query_output_dequantized).item()
l2_error = torch.norm(query_output - query_output_dequantized, p=2).item()
print(f"MSE error: {mse_error:.6f}")
print(f"L2 error: {l2_error:.6f}")

Shape of query_output_dequantized: torch.Size([1, 8, 2048]), Dtype: torch.bfloat16
MSE error: 0.000002
L2 error: 0.164062
