# Trained vs Random-Init GPT Completion

This notebook loads a trained checkpoint and compares text generation against a random-init model with the exact same architecture.

By default it looks for:
- `../config.yaml`
- `../step_002400.pt`

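If your files are elsewhere, update `CONFIG_PATH` and `CHECKPOINT_PATH` in the next cell. When a default path does not exist, the notebook falls back to a recursive search of the project root for a file with the same name and uses the last match in sorted order.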

In [1]:
from pathlib import Path
import sys

import torch
import yaml

# Treat the parent directory as the project root when the kernel was started
# inside a "notebooks" folder; otherwise the working directory is the root.
working_dir = Path.cwd()
PROJECT_ROOT = working_dir.resolve()
if working_dir.name == "notebooks":
    PROJECT_ROOT = PROJECT_ROOT.parent

# Make the project's "src" directory importable (only added once).
src_dir = str(PROJECT_ROOT / "src")
if src_dir not in sys.path:
    sys.path.insert(0, src_dir)

# Default locations, interpreted relative to the current working directory.
CONFIG_PATH = Path("../config.yaml")
CHECKPOINT_PATH = Path("../step_002400.pt")

def resolve_existing_path(default_path: Path, fallback_pattern: str) -> Path:
    """Locate a file, preferring ``default_path`` and falling back to a search.

    Args:
        default_path: Path tried first, joined onto the current working
            directory and resolved.
        fallback_pattern: Glob pattern searched recursively under
            ``PROJECT_ROOT`` when the default path does not exist.

    Returns:
        The resolved default path if it exists, otherwise the last match of
        the recursive search in sorted order.

    Raises:
        FileNotFoundError: If neither the default path nor any fallback match
            exists.
    """
    preferred = (Path.cwd() / default_path).resolve()
    if not preferred.exists():
        found = list(PROJECT_ROOT.rglob(fallback_pattern))
        found.sort()
        if len(found) == 0:
            raise FileNotFoundError(
                f"Could not find {default_path} and no fallback match for pattern '{fallback_pattern}'."
            )
        # Sorted order puts the lexicographically greatest path last.
        return found[-1]
    return preferred

# Resolve both inputs up front so a missing file fails here, before any model code runs.
resolved_config_path = resolve_existing_path(
    default_path=CONFIG_PATH,
    fallback_pattern="config.yaml",
)
resolved_checkpoint_path = resolve_existing_path(
    default_path=CHECKPOINT_PATH,
    fallback_pattern="step_002400.pt",
)

# Echo what was actually picked up; the fallback search may select a different file than the default.
print("Project root:", PROJECT_ROOT, sep=" ")
print("Config path:", resolved_config_path, sep=" ")
print("Checkpoint path:", resolved_checkpoint_path, sep=" ")

Project root: /Users/GabrielLevaillant/Desktop/Pet-Projects/local-llm-training-k8s
Config path: /Users/GabrielLevaillant/Desktop/Pet-Projects/local-llm-training-k8s/runs/20260219_172720_a45befc_gpt-wikitext-local/config.yaml
Checkpoint path: /Users/GabrielLevaillant/Desktop/Pet-Projects/local-llm-training-k8s/runs/20260219_172720_a45befc_gpt-wikitext-local/checkpoints/step_002400.pt


In [2]:
from llmtrain.config.schemas import RunConfig
from llmtrain.registry import initialize_registries
from llmtrain.registry.models import get_model_adapter

# Parse the run configuration and validate it against the project schema.
with resolved_config_path.open("r", encoding="utf-8") as f:
    cfg_dict = yaml.safe_load(f)
cfg = RunConfig.model_validate(cfg_dict)

# NOTE(security): weights_only=False unpickles arbitrary Python objects, which can
# execute code. Only load checkpoints you produced yourself or otherwise trust.
payload = torch.load(resolved_checkpoint_path, map_location="cpu", weights_only=False)
checkpoint_config = payload.get("config")
if isinstance(checkpoint_config, dict) and checkpoint_config != cfg.model_dump():
    print("Warning: checkpoint config differs from config.yaml; using config.yaml for model build.")

initialize_registries()
adapter_cls = get_model_adapter(cfg.model.name)
adapter = adapter_cls()

# Use the device named in the config, degrading to CPU when MPS is not available.
device_name = cfg.run.device
if device_name == "mps" and not torch.backends.mps.is_available():
    print("MPS requested but unavailable; falling back to CPU.")
    device_name = "cpu"
device = torch.device(device_name)

# Trained model: build the architecture, then overwrite its weights from the checkpoint.
trained_model = adapter.build_model(cfg)
trained_model.load_state_dict(payload["model_state_dict"])
trained_model = trained_model.to(device).eval()

# Random-init baseline: same architecture, no checkpoint. Seed the global RNG first so
# the baseline weights (and therefore its completions) are the same on every
# Restart & Run All instead of changing from run to run.
RANDOM_INIT_SEED = 0
torch.manual_seed(RANDOM_INIT_SEED)
random_model = adapter.build_model(cfg)
random_model = random_model.to(device).eval()

tokenizer = adapter.build_tokenizer(cfg)
if tokenizer is None:
    raise RuntimeError("Tokenizer is required for text generation in this notebook.")

print("Model adapter:", adapter_cls.__name__)
print("Model type:", cfg.model.name)
print("Device:", device)
print("Checkpoint step:", payload.get("step"))

Model adapter: GPTAdapter
Model type: gpt
Device: mps
Checkpoint step: 2400


In [5]:
def count_parameters(model: torch.nn.Module) -> int:
    """Return the total number of elements across all parameters of ``model``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

def generate_text(
    model: torch.nn.Module,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 48,
    temperature: float = 0.8,
    top_k: int | None = 40,
    seed: int = 1234,
) -> str:
    torch.manual_seed(seed)
    if device.type == "mps":
        # Keep MPS and CPU randomness aligned for repeatable sampling.
        torch.mps.manual_seed(seed)

    encoded = tokenizer.encode(prompt)
    x = torch.tensor([encoded], dtype=torch.long, device=device)

    for _ in range(max_new_tokens):
        x_cond = x[:, -cfg.model.block_size :]
        with torch.no_grad():
            logits = model(x_cond)
            next_logits = logits[:, -1, :]

        if temperature <= 0:
            next_token = torch.argmax(next_logits, dim=-1, keepdim=True)
        else:
            next_logits = next_logits / temperature
            if top_k is not None and top_k > 0:
                k = min(top_k, next_logits.size(-1))
                values, _ = torch.topk(next_logits, k=k)
                cutoff = values[:, -1].unsqueeze(-1)
                next_logits = torch.where(
                    next_logits < cutoff,
                    torch.full_like(next_logits, float("-inf")),
                    next_logits,
                )

            probs = torch.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

        x = torch.cat((x, next_token), dim=1)

    return tokenizer.decode(x[0].tolist())

# Both models come from the same config, so these two totals should match.
trained_param_count = count_parameters(trained_model)
random_param_count = count_parameters(random_model)
print("Parameter count (trained):", f"{trained_param_count:,}")
print("Parameter count (random): ", f"{random_param_count:,}")

Parameter count (trained): 40,691,328
Parameter count (random):  40,691,328


In [27]:
# Sampling settings shared by both models so the comparison is like-for-like.
prompt = "Here is a list of my favorite things"
max_new_tokens = 5
temperature = 0.8
top_k = 40
seed = 7

sampling_kwargs = {
    "prompt": prompt,
    "max_new_tokens": max_new_tokens,
    "temperature": temperature,
    "top_k": top_k,
    "seed": seed,
}

# Same prompt, same seed, same sampling knobs: only the weights differ.
trained_completion = generate_text(trained_model, tokenizer, **sampling_kwargs)
random_completion = generate_text(random_model, tokenizer, **sampling_kwargs)

print("=== Prompt ===")
print(prompt)
print("\n=== Trained model completion ===")
print(trained_completion)
print("\n=== Random-init model completion ===")
print(random_completion)

=== Prompt ===
Here is a list of my favorite things

=== Trained model completion ===
Here is a list of my favorite things of the second @-

=== Random-init model completion ===
Here is a list of my favorite thingsdebdullahSurv TrinityAlbert


In [28]:
def top_next_tokens(model: torch.nn.Module, tokenizer, text: str, k: int = 10) -> list[tuple[str, float]]:
    ids = tokenizer.encode(text)
    x = torch.tensor([ids[-cfg.model.block_size :]], dtype=torch.long, device=device)
    with torch.no_grad():
        logits = model(x)[:, -1, :]
        probs = torch.softmax(logits, dim=-1)
    top_probs, top_ids = torch.topk(probs, k=min(k, probs.size(-1)), dim=-1)
    out: list[tuple[str, float]] = []
    for token_id, p in zip(top_ids[0].tolist(), top_probs[0].tolist()):
        token_text = tokenizer.decode([token_id]).replace("\n", "\\n")
        out.append((token_text, float(p)))
    return out

# Score the same prompt with both models.
trained_top = top_next_tokens(trained_model, tokenizer, prompt, k=10)
random_top = top_next_tokens(random_model, tokenizer, prompt, k=10)

# One heading per model followed by its candidates, one per line.
for heading, candidates in (
    ("Top next-token candidates (trained):", trained_top),
    ("\nTop next-token candidates (random-init):", random_top),
):
    print(heading)
    for token, prob in candidates:
        print(f"{token!r}: {prob:.4f}")

Top next-token candidates (trained):
' .': 0.1339
' ,': 0.0880
' that': 0.0518
' to': 0.0491
' and': 0.0362
' in': 0.0268
" '": 0.0261
' of': 0.0225
' "': 0.0216
' as': 0.0143

Top next-token candidates (random-init):
' Sie': 0.0001
' bookmark': 0.0001
' stationed': 0.0001
' things': 0.0001
' Melissa': 0.0001
' tsunami': 0.0001
' happen': 0.0001
' inspir': 0.0001
' reprim': 0.0001
' Finding': 0.0001
