# GPT model smoke notebook

This notebook creates a small `GPT` model from `llmtrain.models.gpt`, prints parameter counts,
and runs a tiny forward pass to show basic model information.

In [4]:
import torch

from llmtrain.models.gpt import GPT


def count_parameters(model: torch.nn.Module) -> tuple[int, int, int]:
    """Tally the parameter elements of ``model``.

    Args:
        model: Module whose ``parameters()`` are counted.

    Returns:
        ``(total, trainable, non_trainable)`` element counts, where a
        parameter counts as trainable when its ``requires_grad`` flag is set.
    """
    n_trainable = 0
    n_frozen = 0
    # Single walk over the parameters, bucketing each by its requires_grad flag.
    for param in model.parameters():
        if param.requires_grad:
            n_trainable += param.numel()
        else:
            n_frozen += param.numel()
    return n_trainable + n_frozen, n_trainable, n_frozen

In [5]:
# Seed the global RNG *before* building the model. GPT presumably initialises
# its weights randomly (as torch.nn layers do by default), so seeding only in
# the later forward-pass cell would leave the weights different on every
# kernel restart. The value matches the seed used by the forward-pass cell.
torch.manual_seed(7)

# Small hyperparameters for a quick smoke run; the keys are passed straight
# through as keyword arguments to the GPT constructor.
cfg = {
    "vocab_size": 256,
    "block_size": 32,
    "d_model": 64,
    "n_layers": 2,
    "n_heads": 4,
    "d_ff": 256,
    "dropout": 0.1,
    "tie_embeddings": True,
}

model = GPT(**cfg)

# Element counts over model.parameters(), split by requires_grad.
total, trainable, non_trainable = count_parameters(model)

print("model:", model.__class__.__name__)
print("config:", cfg)
print("total parameters:", f"{total:,}")
print("trainable parameters:", f"{trainable:,}")
print("non-trainable parameters:", f"{non_trainable:,}")

model: GPT
config: {'vocab_size': 256, 'block_size': 32, 'd_model': 64, 'n_layers': 2, 'n_heads': 4, 'd_ff': 256, 'dropout': 0.1, 'tie_embeddings': True}
total parameters: 118,528
trainable parameters: 118,528
non-trainable parameters: 0


In [6]:
# Fix the RNG so the random token ids below are the same on every run.
torch.manual_seed(7)

batch_size = 2
seqlen = 16

# Random token ids drawn from [0, vocab_size) and an all-ones mask.
input_ids = torch.randint(0, cfg["vocab_size"], (batch_size, seqlen), dtype=torch.long)
attention_mask = torch.ones((batch_size, seqlen), dtype=torch.long)

# torch.no_grad() only disables gradient tracking; it does NOT switch off
# dropout. A freshly built nn.Module is in training mode and cfg sets
# dropout=0.1, so without eval() the smoke logits would presumably be computed
# with dropout applied. Put the model in evaluation mode first.
model.eval()
with torch.no_grad():
    logits = model(input_ids=input_ids, attention_mask=attention_mask)

# Fail loudly under "Run All" if the output shape ever regresses.
assert tuple(logits.shape) == (batch_size, seqlen, cfg["vocab_size"]), "unexpected logits shape"

print("input_ids shape:", tuple(input_ids.shape))
print("attention_mask shape:", tuple(attention_mask.shape))
print("logits shape:", tuple(logits.shape))
print("logits dtype:", logits.dtype)
print("device:", next(model.parameters()).device)
print("contains NaN:", bool(torch.isnan(logits).any()))

input_ids shape: (2, 16)
attention_mask shape: (2, 16)
logits shape: (2, 16, 256)
logits dtype: torch.float32
device: cpu
contains NaN: False


In [7]:
# Snapshot of the embedding / output-head weights and the block count.
token_weight = model.token_embedding.weight
head_weight = model.lm_head.weight

embedding_params = token_weight.numel()
lm_head_params = head_weight.numel()

# Tied weights are detected by both tensors starting at the same storage address.
weights_tied = head_weight.data_ptr() == token_weight.data_ptr()

print("token embedding params:", f"{embedding_params:,}")
print("lm_head params:", f"{lm_head_params:,}")
print("weights tied:", weights_tied)
print("number of transformer blocks:", len(model.blocks))

token embedding params: 16,384
lm_head params: 16,384
weights tied: True
number of transformer blocks: 2
