# Debug Router Parameter Issue

This notebook helps debug why the `router` parameter is not being saved/loaded correctly in the checkpoint.

In [None]:
import torch
import torch.nn as nn
from dyna.model.model import DynaConfig, DynaFormer, DynaLM, ComposerDynaModel
from dyna.model.model_config import ExecutionMode
from transformers import AutoTokenizer

# Small MoE test configuration reused by the model-construction cells below.
# `enable_early_exit=True` is the option that should make the model create
# its `router` parameter -- that expectation is what this notebook checks.
config_options = {
    "execution_mode": ExecutionMode.moe,
    "d_model": 412,
    "n_layers": 2,
    "n_heads": 4,
    "enable_early_exit": True,
}
config = DynaConfig(**config_options)

print("Config created:", config.execution_mode)

In [None]:
# Instantiate the bare DynaFormer and list everything it registers.
model = DynaFormer(config)

# Parameters and buffers share the same "name: shape" layout, so one loop over
# (heading, iterator) pairs produces both listings.
registered_tensors = (
    ("DynaFormer parameters:", model.named_parameters()),
    ("\nDynaFormer buffers:", model.named_buffers()),
)
for heading, named_tensors in registered_tensors:
    print(heading)
    for tensor_name, tensor in named_tensors:
        print(f"  {tensor_name}: {tensor.shape}")

In [None]:
# Build the language-model wrapper; the tokenizer supplies the EOS token id
# that DynaLM takes as its second argument.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-1.7B")
lm_model = DynaLM(config, tokenizer.eos_token_id)

# List every parameter and, in the same pass, remember the ones whose name
# mentions 'router'.
print("DynaLM parameters:")
router_params = []
for param_name, tensor in lm_model.named_parameters():
    print(f"  {param_name}: {tensor.shape}")
    if 'router' in param_name:
        router_params.append(param_name)

print("\nLooking for 'router' parameters:")
print(f"Found router parameters: {router_params}")

In [None]:
# Wrap the same config in the Composer model and look for router parameters.
composer_model = ComposerDynaModel(config, tokenizer)

print("ComposerDynaModel parameters:")
composer_router_params = []
for param_name, _ in composer_model.named_parameters():
    if 'router' in param_name:
        composer_router_params.append(param_name)
print(f"Found router parameters: {composer_router_params}")

# Only the parameters whose qualified name mentions 'transformer' are listed.
print("\nFull parameter structure:")
for param_name, tensor in composer_model.named_parameters():
    if 'transformer' not in param_name:
        continue
    print(f"  {param_name}: {tensor.shape}")

In [None]:
# Snapshot the Composer model's state dict and inspect its key names.
state_dict = composer_model.state_dict()

print("Keys in state_dict containing 'router':")
router_keys = list(filter(lambda key: 'router' in key, state_dict))
print(router_keys)

print("\nKeys in state_dict containing 'transformer':")
transformer_keys = list(filter(lambda key: 'transformer' in key, state_dict))

# Show the first ten matches and summarise how many were left out.
preview_keys, hidden_keys = transformer_keys[:10], transformer_keys[10:]
for key in preview_keys:
    print(f"  {key}")
if hidden_keys:
    print(f"  ... and {len(hidden_keys)} more")

In [None]:
# Check how (and whether) `router` is registered on the bare DynaFormer.
print("DynaFormer module structure:")
has_router = hasattr(model, 'router')
router = getattr(model, 'router', None)
print(f"Has router attribute: {has_router}")
print(f"Router is parameter: {isinstance(router, nn.Parameter)}")

# A plain attribute is not enough for checkpointing: state_dict() only
# contains registered parameters and persistent buffers, so report whether
# the name shows up in either registry of this module.
in_parameters = 'router' in dict(model.named_parameters(recurse=False))
in_buffers = 'router' in dict(model.named_buffers(recurse=False))
print(f"Router registered as parameter: {in_parameters}")
print(f"Router registered as buffer: {in_buffers}")

print("\nRouter parameter details:")
if isinstance(router, torch.Tensor):
    print(f"  Shape: {router.shape}")
    print(f"  Requires grad: {router.requires_grad}")
    print(f"  Device: {router.device}")
    print(f"  Data type: {router.dtype}")
elif has_router:
    # The attribute exists but is not a tensor (for example None or a
    # sub-module); reading .shape/.dtype on it would raise, so report its
    # type instead of crashing the cell.
    print(f"  Router is not a tensor (type: {type(router).__name__})")
else:
    print("  No router attribute on the model")

In [None]:
# Compare the parameter path the checkpoint expects with what the model exposes.
print("Expected parameter path in checkpoint: 'model.transformer.router'")
print("Actual parameter paths:")
actual_router_paths = [
    name for name, _ in composer_model.named_parameters() if 'router' in name
]
print(f"  {actual_router_paths}")

# Check the model hierarchy
print(f"composer_model.model type: {type(composer_model.model)}")
print(f"composer_model.model.transformer type: {type(composer_model.model.transformer)}")

# Check if router exists at the expected path.  Only the attribute lookup is
# guarded: a router that exists but is not a tensor must not be reported as
# "not found" just because it has no `.shape`.
try:
    router_param = composer_model.model.transformer.router
except AttributeError as e:
    print(f"Router not found at expected path: {e}")
else:
    router_shape = getattr(router_param, 'shape', None)
    if router_shape is not None:
        print(f"Found router at expected path: {router_shape}")
    else:
        print(
            "Found router at expected path, but it is not a tensor "
            f"(type: {type(router_param).__name__})"
        )

In [None]:
# Try to understand why the parameter might not be saved: list, module by
# module, the parameters that each sub-module registers directly.
print("Checking parameter registration in modules:")

def check_module_params(module, prefix=""):
    """Print a module tree together with the parameters each module owns.

    For ``module`` itself and then, depth-first, for every descendant, one
    ``Module:`` line is printed followed by one ``Parameter:`` line per
    parameter registered directly on that module (``recurse=False``).
    Parameters registered on the root module are included, so a parameter
    that lives on the object passed in is not silently left out of the report.

    Args:
        module: The ``torch.nn.Module`` to inspect.
        prefix: Dotted path used to qualify the printed names. When empty,
            the root module is labelled ``<root>`` and names are unqualified.
    """
    label = prefix if prefix else "<root>"
    print(f"Module: {label} ({type(module).__name__})")

    # Only parameters owned by this module; descendants report their own.
    for param_name, param in module.named_parameters(recurse=False):
        qualified_name = f"{prefix}.{param_name}" if prefix else param_name
        print(f"  Parameter: {qualified_name} - {param.shape}")

    # Recurse into children
    for name, child in module.named_children():
        child_prefix = f"{prefix}.{name}" if prefix else name
        check_module_params(child, child_prefix)

# Walk the model wrapped by the Composer model (`composer_model.model`).
check_module_params(composer_model.model)

In [None]:
# Check if there are any missing keys when loading
import copy

# Round-trip check: build a twin model from a deep copy of the config so the
# first model's weights can be loaded into it.
config2 = copy.deepcopy(config)
composer_model2 = ComposerDynaModel(config2, tokenizer)

# Take a fresh snapshot of the first model's weights.
state_dict = composer_model.state_dict()

# Load into the twin with strict=False; the call's result is unpacked into
# the missing and unexpected key lists, which are then printed.  Any error
# raised while loading is reported rather than propagated.
try:
    load_report = composer_model2.load_state_dict(state_dict, strict=False)
    missing_keys, unexpected_keys = load_report
    print(f"Missing keys: {missing_keys}")
    print(f"Unexpected keys: {unexpected_keys}")
except Exception as e:
    print(f"Error loading state dict: {e}")

## Analysis

Based on the error traceback, the issue seems to be that the checkpoint contains optimizer state for `model.transformer.router`, but when loading, this parameter doesn't exist in the current model state.

This could happen if:
1. The router parameter wasn't properly saved in the model state dict
2. The model architecture changed between saving and loading
3. There's a mismatch in the parameter naming between save and load

Let's investigate further...

## Analysis

Based on the error traceback, the issue seems to be that the checkpoint contains optimizer state for `model.transformer.router`, but when loading, this parameter doesn't exist in the current model state.

This could happen if:
1. The router parameter wasn't properly saved in the model state dict
2. The model architecture changed between saving and loading
3. There's a mismatch in the parameter naming between save and load

Let's investigate further...

In [None]:
# Check if there are any missing keys when loading
import copy

# NOTE(review): this cell repeats the load round-trip performed in an earlier
# cell of this notebook.
# Build a twin model from a deep copy of the config so the first model's
# weights can be loaded into it.
config2 = copy.deepcopy(config)
composer_model2 = ComposerDynaModel(config2, tokenizer)

# Take a fresh snapshot of the first model's weights.
state_dict = composer_model.state_dict()

# Load into the twin with strict=False; the call's result is unpacked into
# the missing and unexpected key lists, which are then printed.  Any error
# raised while loading is reported rather than propagated.
try:
    load_report = composer_model2.load_state_dict(state_dict, strict=False)
    missing_keys, unexpected_keys = load_report
    print(f"Missing keys: {missing_keys}")
    print(f"Unexpected keys: {unexpected_keys}")
except Exception as e:
    print(f"Error loading state dict: {e}")

In [None]:
# Try to understand why the parameter might not be saved: list, module by
# module, the parameters that each sub-module registers directly.
# NOTE(review): this cell repeats an earlier cell of this notebook.
print("Checking parameter registration in modules:")

# NOTE(review): `check_module_params` is also defined in an earlier cell of
# this notebook; this later definition is the one in effect after it runs.
def check_module_params(module, prefix=""):
    """Print a module tree together with the parameters each module owns.

    For ``module`` itself and then, depth-first, for every descendant, one
    ``Module:`` line is printed followed by one ``Parameter:`` line per
    parameter registered directly on that module (``recurse=False``).
    Parameters registered on the root module are included, so a parameter
    that lives on the object passed in is not silently left out of the report.

    Args:
        module: The ``torch.nn.Module`` to inspect.
        prefix: Dotted path used to qualify the printed names. When empty,
            the root module is labelled ``<root>`` and names are unqualified.
    """
    label = prefix if prefix else "<root>"
    print(f"Module: {label} ({type(module).__name__})")

    # Only parameters owned by this module; descendants report their own.
    for param_name, param in module.named_parameters(recurse=False):
        qualified_name = f"{prefix}.{param_name}" if prefix else param_name
        print(f"  Parameter: {qualified_name} - {param.shape}")

    # Recurse into children
    for name, child in module.named_children():
        child_prefix = f"{prefix}.{name}" if prefix else name
        check_module_params(child, child_prefix)

# Walk the model wrapped by the Composer model (`composer_model.model`).
check_module_params(composer_model.model)