# Loading and Inspecting the QAT GPT-2 Checkpoint

This notebook shows how to load and inspect the saved QAT GPT-2 checkpoint (`.pth` file), rebuild the model from its stored configuration, and run a quick inference check.

In [2]:
import torch
import json
from pathlib import Path
import sys
import os

# Add parent directory to path for imports.
# os.path.abspath('') resolves to the current working directory, so this is
# the directory one level above wherever the notebook kernel was started.
parent_dir = os.path.dirname(os.path.abspath(''))
# Only append once: re-running this cell in the same kernel would otherwise
# keep stacking duplicate entries onto sys.path.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

## 1. Load the Checkpoint File

In [4]:
# Location of the saved checkpoint; edit this to point at your own file
checkpoint_path = 'qat_gpt2_8bit_fp32_20250917_113554.pth'

# Report whether the file is present and, when it is, its size in MB
if not os.path.exists(checkpoint_path):
    print(f"✗ File not found at: {checkpoint_path}")
    print("Please update the path to your checkpoint file")
else:
    print(f"✓ Checkpoint file found at: {checkpoint_path}")
    size_in_mb = os.path.getsize(checkpoint_path) / (1024*1024)
    print(f"File size: {size_in_mb:.2f} MB")

✓ Checkpoint file found at: qat_gpt2_8bit_fp32_20250917_113554.pth
File size: 505.80 MB


In [5]:
# Pick the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Deserialize the checkpoint file, remapping stored tensors onto `device`.
# NOTE(review): torch.load relies on pickle-based deserialization, so only
# open checkpoint files that come from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location=device)
print("\n✓ Checkpoint loaded successfully!")

Using device: cuda

✓ Checkpoint loaded successfully!


## 2. Inspect Checkpoint Structure

In [6]:
# Print a numbered list (starting at 1) of the checkpoint's top-level keys
print("Top-level keys in checkpoint:")
print("="*50)
for position, name in enumerate(checkpoint, start=1):
    print(f"{position}. {name}")

Top-level keys in checkpoint:
1. model_state_dict
2. model_config
3. training_config
4. timestamp


In [None]:
# Walk through every checkpoint entry and print a summary suited to its type
print("Detailed checkpoint contents:")
print("="*50)

for entry_name, entry in checkpoint.items():
    print(f"\n'{entry_name}':")

    if isinstance(entry, dict):
        print(f"  Type: dict with {len(entry)} items")
        if entry_name in ('model_state_dict', 'model'):
            # Weight dictionaries are large, so only preview the first five keys
            preview = list(entry)[:5]
            print(f"  Sample keys: {preview}")
        elif len(entry) < 200:
            # Small dictionaries are listed in full; bigger ones only get the
            # item count printed above
            for item_name, item in entry.items():
                if isinstance(item, torch.Tensor):
                    print(f"    {item_name}: tensor shape {item.shape}")
                else:
                    print(f"    {item_name}: {item}")

    elif isinstance(entry, list):
        print(f"  Type: list with {len(entry)} items")
        print(f"  Contents: {entry}")

    elif isinstance(entry, torch.Tensor):
        print("  Type: tensor")
        print(f"  Shape: {entry.shape}")
        print(f"  Dtype: {entry.dtype}")

    else:
        # Every remaining type reports its class name; plain scalars and
        # strings are short enough to also show their value
        print(f"  Type: {type(entry).__name__}")
        if isinstance(entry, (int, float, str)):
            print(f"  Value: {entry}")

Detailed checkpoint contents:

'model_state_dict':
  Type: dict with 1571 items
  Sample keys: ['wte.weight', 'wpe.weight', 'h.0.ln_1.weight', 'h.0.ln_1.bias', 'h.0.attn.bias']

'model_config':
  Type: dict with 22 items

'training_config':
  Type: dict with 17 items
    train_split: train[:5000]
    val_split: validation[:1000]
    batch_size: 8
    max_seq_length: 256
    doc_stride: 128
    learning_rate: 0.0001
    weight_decay: 0.01
    adam_epsilon: 1e-08
    adam_betas: (0.9, 0.999)
    max_grad_norm: 1.0
    num_iterations: 150
    gradient_accumulation_steps: 8
    eval_interval: 50
    save_interval: 100
    use_amp: True
    empty_cache_interval: 25
    num_workers: 0

'timestamp':
  Type: str
  Value: 20250917_113554


## 3. Load Configuration from JSON

In [10]:
# Load the training stats JSON which contains configuration
config_path = 'qat_training_stats_20250917_113554.json'

# Start from an empty dict so that later cells testing `... in training_stats`
# still run (and simply find nothing) when the JSON file is missing, instead
# of failing with a NameError.
training_stats = {}

if os.path.exists(config_path):
    # JSON is UTF-8 by specification; do not depend on the platform default encoding
    with open(config_path, 'r', encoding='utf-8') as f:
        training_stats = json.load(f)

    print("Training stats/config loaded from JSON:")
    print("="*50)

    # Show config if it exists
    if 'model_config' in training_stats:
        config = training_stats['model_config']
        print("\nModel Configuration:")
        for key, value in config.items():
            print(f"  {key}: {value}")

    # Show training parameters
    if 'training_config' in training_stats:
        print("\nTraining Parameters:")
        for key, value in training_stats['training_config'].items():
            print(f"  {key}: {value}")
else:
    print(f"Config file not found at: {config_path}")

Training stats/config loaded from JSON:

Model Configuration:
  quantization_bits: 8
  n_layer: 6
  n_embd: 768
  n_head: 12

Training Parameters:
  train_split: train[:5000]
  val_split: validation[:1000]
  batch_size: 8
  max_seq_length: 256
  doc_stride: 128
  learning_rate: 0.0001
  weight_decay: 0.01
  adam_epsilon: 1e-08
  adam_betas: [0.9, 0.999]
  max_grad_norm: 1.0
  num_iterations: 150
  gradient_accumulation_steps: 8
  eval_interval: 50
  save_interval: 100
  use_amp: True
  num_workers: 0


## 4. Create Model and Load Weights

In [None]:
from transformers import GPT2Config
from shared.models import SwitchableQATGPT2

# Locate the model configuration.
# The original cell only looked for training_stats['config'], but the JSON
# loaded earlier in this notebook is read through the keys 'model_config' and
# 'training_config', so that lookup found nothing and `gpt2_config` was never
# defined. Candidates are tried in this order:
#   1. training_stats['config']       - the key this cell originally expected
#   2. checkpoint['model_config']     - stored in the same file as the weights
#   3. training_stats['model_config'] - copy kept in the JSON stats file
# NOTE(review): the checkpoint copy is preferred over the JSON copy on the
# assumption that it is the more complete of the two - confirm.
config_candidates = (
    ("training_stats['config']", training_stats, 'config'),
    ("checkpoint['model_config']", checkpoint, 'model_config'),
    ("training_stats['model_config']", training_stats, 'model_config'),
)
config_dict = None
for source_label, container, config_key in config_candidates:
    if config_key in container:
        config_dict = container[config_key]
        print(f"Using model configuration from {source_label}")
        break

if config_dict is None:
    # Fail here with a clear message rather than with a NameError on
    # `gpt2_config` in the next cell.
    raise ValueError(
        "No model configuration found: expected 'config' or 'model_config' "
        "in the training stats JSON, or 'model_config' in the checkpoint"
    )

# Create base GPT2 config; the fallbacks apply only to keys the stored
# configuration does not provide
gpt2_config = GPT2Config(
    vocab_size=config_dict.get('vocab_size', 50257),
    n_positions=config_dict.get('n_positions', 1024),
    n_embd=config_dict.get('n_embd', 768),
    n_layer=config_dict.get('n_layer', 12),
    n_head=config_dict.get('n_head', 12),
)

# Add custom QAT attributes: every stored entry is copied onto the config
# object, including ones GPT2Config itself does not define
for key, value in config_dict.items():
    setattr(gpt2_config, key, value)

print("GPT2Config created with attributes:")
important_attrs = ['n_layer', 'n_embd', 'n_head', 'vocab_size',
                  'lora_rank', 'lora_alpha', 'lora_dropout']
for attr in important_attrs:
    if hasattr(gpt2_config, attr):
        print(f"  {attr}: {getattr(gpt2_config, attr)}")

In [None]:
# Use the bit widths stored in the checkpoint when present; otherwise fall
# back to 4, 8 and 16
if 'bit_widths' in checkpoint:
    bit_widths = checkpoint['bit_widths']
else:
    bit_widths = [4, 8, 16]
print(f"\nCreating model with bit widths: {bit_widths}")

# Build the model from the GPT-2 config, then move it onto the selected device.
# initialize_weights=False is passed through as-is - presumably because the
# weights are loaded from the checkpoint afterwards (confirm against the model class).
model = SwitchableQATGPT2(
    gpt2_config,
    bit_widths=bit_widths,
    initialize_weights=False,
)
model = model.to(device)
print("✓ Model created successfully")

In [None]:
# Locate the weights: try the known wrapper keys in order, and if neither is
# present treat the checkpoint object itself as the state dict
for wrapper_key in ('model_state_dict', 'model'):
    if wrapper_key in checkpoint:
        state_dict = checkpoint[wrapper_key]
        break
else:
    state_dict = checkpoint

# strict=False tolerates key mismatches; they are reported below instead of raising
load_report = model.load_state_dict(state_dict, strict=False)
missing_keys, unexpected_keys = load_report

print("\nModel weights loaded!")
if not (missing_keys or unexpected_keys):
    print("✓ All weights loaded perfectly!")
else:
    # Show the count plus up to five examples for each kind of mismatch
    if missing_keys:
        print(f"\n⚠ Missing keys: {len(missing_keys)}")
        print("Sample missing keys:", missing_keys[:5])
    if unexpected_keys:
        print(f"\n⚠ Unexpected keys: {len(unexpected_keys)}")
        print("Sample unexpected keys:", unexpected_keys[:5])

## 5. Test the Model

In [None]:
# Put the model into evaluation mode before running inference
model.eval()

# Tokenize a short prompt with the stock GPT-2 tokenizer
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
test_text = "The quick brown fox"
inputs = tokenizer(test_text, return_tensors='pt').to(device)

print(f"Test input: '{test_text}'")
print(f"Input shape: {inputs['input_ids'].shape}")

# Forward pass once per bit width, with gradient tracking disabled
with torch.no_grad():
    for bit_width in bit_widths:
        # Only switch precision when the model exposes the hook; without it
        # every iteration runs the model unchanged
        if hasattr(model, 'set_precision'):
            model.set_precision(bit_width)

        outputs = model(**inputs)
        # Accept either an output object with `.logits` or an indexable result
        if hasattr(outputs, 'logits'):
            logits = outputs.logits
        else:
            logits = outputs[0]

        print(f"\nBit width {bit_width}:")
        print(f"  Output shape: {logits.shape}")
        print(f"  Output range: [{logits.min().item():.3f}, {logits.max().item():.3f}]")

        # Highest-scoring token at the last position of the first sequence
        last_position_logits = logits[0, -1]
        next_token_id = last_position_logits.argmax().item()
        next_token = tokenizer.decode([next_token_id])
        print(f"  Predicted next token: '{next_token}'")

## 6. Alternative: Load Checkpoint for diagnose_model_issues.py

In [None]:
# This shows how to modify the checkpoint to work with diagnose_model_issues.py
# The script expects 'config' key in the checkpoint

# Find a configuration to embed.
# The original cell only looked for training_stats['config'], but the JSON
# loaded earlier in this notebook is read through 'model_config' /
# 'training_config', so the cell always took the "nothing to do" branch.
# Candidates are tried in this order: the originally expected JSON key, the
# configuration stored inside the checkpoint, then the JSON 'model_config'.
# NOTE(review): assumes diagnose_model_issues.py accepts the same dict layout
# as 'model_config' - confirm against the script.
config_for_checkpoint = None
if 'config' in training_stats:
    config_for_checkpoint = training_stats['config']
elif 'model_config' in checkpoint:
    config_for_checkpoint = checkpoint['model_config']
elif 'model_config' in training_stats:
    config_for_checkpoint = training_stats['model_config']

if 'config' in checkpoint:
    print("Checkpoint already has 'config' key")
elif config_for_checkpoint is None:
    print("No model configuration found in the JSON stats or the checkpoint")
else:
    print("Creating modified checkpoint with 'config' key...")

    # Shallow copy: the new dict shares the tensors with `checkpoint`, and the
    # in-memory original keeps its current keys
    modified_checkpoint = checkpoint.copy()
    modified_checkpoint['config'] = config_for_checkpoint

    # Save the modified checkpoint.
    # Insert the suffix before the file extension only. str.replace('.pth', ...)
    # would rewrite every '.pth' occurrence in the path, and would leave the
    # path unchanged (overwriting the original file) if the extension differed.
    path_root, path_ext = os.path.splitext(checkpoint_path)
    modified_path = f"{path_root}_with_config{path_ext}"
    torch.save(modified_checkpoint, modified_path)

    print(f"✓ Saved modified checkpoint to: {modified_path}")
    print("\nYou can now use this file with diagnose_model_issues.py:")
    print(f"python test/diagnose_model_issues.py --model_path {modified_path}")