# ONNX Model Parameter Counter

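This notebook loads trained ONNX models, counts their parameters (the total number of elements across all initializer tensors), and prints a per-tensor breakdown of names, shapes, parameter counts, and data types.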

In [1]:
# Install and import required packages for ONNX analysis
import glob
import os
import subprocess
import sys

import numpy as np
import onnx
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Informational only: report whether this kernel can see a CUDA device.
has_cuda = torch.cuda.is_available()
has_cuda

False

In [18]:
# ONNX TensorProto.DataType enum value -> numpy-style type name.
# A static table avoids the deprecated `onnx.mapping` module. Its values are
# `np.dtype` objects, which have no `__name__`, so looking names up through it
# failed and fell through to a generic placeholder.
_ONNX_DTYPE_NAMES = {
    0: 'undefined',        # UNDEFINED
    1: 'float32',          # FLOAT
    2: 'uint8',            # UINT8
    3: 'int8',             # INT8
    4: 'uint16',           # UINT16
    5: 'int16',            # INT16
    6: 'int32',            # INT32
    7: 'int64',            # INT64
    8: 'str',              # STRING
    9: 'bool',             # BOOL
    10: 'float16',         # FLOAT16
    11: 'float64',         # DOUBLE
    12: 'uint32',          # UINT32
    13: 'uint64',          # UINT64
    14: 'complex64',       # COMPLEX64
    15: 'complex128',      # COMPLEX128
    16: 'bfloat16',        # BFLOAT16
    17: 'float8e4m3fn',    # FLOAT8E4M3FN
    18: 'float8e4m3fnuz',  # FLOAT8E4M3FNUZ
    19: 'float8e5m2',      # FLOAT8E5M2
    20: 'float8e5m2fnuz',  # FLOAT8E5M2FNUZ
    21: 'uint4',           # UINT4
    22: 'int4',            # INT4
    23: 'float4e2m1',      # FLOAT4E2M1
}


def _onnx_dtype_name(data_type):
    """Return a readable name for an ONNX TensorProto data-type code."""
    return _ONNX_DTYPE_NAMES.get(data_type, f'unknown_type_{data_type}')


def count_onnx_parameters(model_path):
    """
    Count the number of parameters in an ONNX model.

    Every tensor in ``model.graph.initializer`` is counted, including any
    constant/metadata tensors stored as initializers. Initializers inside
    subgraphs (e.g. If/Loop bodies) and ``sparse_initializer`` entries are not
    visited.

    Args:
        model_path: Path to the .onnx model file, or an already-loaded model
            object exposing ``.graph`` (e.g. an ``onnx.ModelProto``).

    Returns:
        Dictionary with:
          - 'total_parameters': int, sum of element counts of all initializers
          - 'layer_details': list of dicts with 'name', 'shape' (list of ints),
            'params' (int) and 'data_type' (str) per initializer
          - 'num_layers': number of initializer tensors (not network layers)
    """
    import math

    if hasattr(model_path, 'graph'):
        model = model_path
    else:
        import onnx
        model = onnx.load(model_path)

    layer_details = []
    for initializer in model.graph.initializer:
        shape = list(initializer.dims)
        layer_details.append({
            'name': initializer.name,
            'shape': shape,
            # math.prod([]) == 1, which is the element count of a scalar tensor.
            'params': math.prod(shape),
            'data_type': _onnx_dtype_name(initializer.data_type),
        })

    return {
        'total_parameters': sum(d['params'] for d in layer_details),
        'layer_details': layer_details,
        'num_layers': len(layer_details),
    }

def print_model_summary(model_path):
    """Print a human-readable parameter summary of an ONNX model.

    Args:
        model_path: Path to the .onnx model file (passed straight to
            ``count_onnx_parameters``).

    Returns:
        The dictionary returned by ``count_onnx_parameters``, or ``None`` if
        loading/counting raised. In that case the error is printed rather than
        re-raised, so a notebook "Run All" keeps going.
    """
    print(f"🔍 Analyzing ONNX model: {model_path}")
    print("=" * 60)

    # Only the analysis itself is best-effort; bugs in the formatting below
    # should surface instead of being reported as "error analyzing model".
    try:
        result = count_onnx_parameters(model_path)
    except Exception as e:
        print(f"❌ Error analyzing model: {e}")
        return None

    details = result['layer_details']
    total = result['total_parameters']

    print("📊 PARAMETER SUMMARY:")
    print(f"   Total Parameters: {total:,}")
    # num_layers counts initializer tensors (weights, biases, stored constants),
    # not network layers.
    print(f"   Initializer Tensors: {result['num_layers']}")
    print(f"   Model Size: ~{total * 4 / 1024 / 1024:.2f} MB (assuming float32)")

    # Widen the columns to the longest entry so long tensor names stay aligned.
    name_w = max([30] + [len(d['name']) for d in details])
    shape_w = max([20] + [len(str(d['shape'])) for d in details])

    print("\n📋 LAYER BREAKDOWN:")
    print(f"{'Layer Name':<{name_w}} {'Shape':<{shape_w}} {'Parameters':<12} {'Type'}")
    print("-" * max(80, name_w + shape_w + 24))
    for d in details:
        print(f"{d['name']:<{name_w}} {str(d['shape']):<{shape_w}} {d['params']:<12,} {d['data_type']}")

    return result

In [19]:
# Example: analyze a trained model, or list candidate .onnx files if it is missing.
model_path = "results/nature_cnn_NormalTrain_4/My Behavior.onnx"
MAX_LISTED = 5  # how many results/ models to list before summarizing the rest

if os.path.exists(model_path):
    print(f"✓ Found model: {model_path}")
    result = print_model_summary(model_path)
else:
    print(f"❌ Model not found: {model_path}")
    print("\n💡 Let's check what ONNX models are available:")

    # Sorted so the listing is deterministic; only regular files count as models.
    current_dir_models = sorted(
        f for f in os.listdir('.') if f.endswith('.onnx') and os.path.isfile(f)
    )
    if current_dir_models:
        print("📁 Current directory:")
        for found_path in current_dir_models:
            print(f"   - {found_path}")

    # Trained models are expected one level below results/ (results/<run>/*.onnx).
    # Sorting makes the "first MAX_LISTED" cut reproducible.
    results_models = sorted(glob.glob("./results/*/*.onnx"))
    if results_models:
        print("\n📁 Trained models in results/:")
        for found_path in results_models[:MAX_LISTED]:
            print(f"   - {found_path}")
        if len(results_models) > MAX_LISTED:
            print(f"   ... and {len(results_models) - MAX_LISTED} more")

    if not current_dir_models and not results_models:
        print("   No ONNX models found.")
        print("   Train some models first using: python train.py")

✓ Found model: results/nature_cnn_NormalTrain_4/My Behavior.onnx
🔍 Analyzing ONNX model: results/nature_cnn_NormalTrain_4/My Behavior.onnx
📊 PARAMETER SUMMARY:
   Total Parameters: 7,622,728
   Number of Layers: 17
   Model Size: ~29.08 MB (assuming float32)

📋 LAYER BREAKDOWN:
Layer Name                     Shape                Parameters   Type
--------------------------------------------------------------------------------
version_number.1               [1]                  1            float32
memory_size_vector             [1]                  1            float32
network_body.observation_encoder.processors.0.conv_layers.0.weight [64, 1, 6, 6]        2,304        float32
network_body.observation_encoder.processors.0.conv_layers.0.bias [64]                 64           float32
network_body.observation_encoder.processors.0.conv_layers.2.weight [128, 64, 4, 4]      131,072      float32
network_body.observation_encoder.processors.0.conv_layers.2.bias [128]                128          

In [20]:
# Summarize one trained model chosen by path; point specific_model at your own run.
specific_model = "results/nature_cnn_NormalTrain_4/My Behavior.onnx"  # Example path

print("🎯 Analyzing specific model...")
model_found = os.path.exists(specific_model)
if not model_found:
    print(f"Model not found: {specific_model}")
    print("Update the path above to point to your trained model.")
else:
    result = print_model_summary(specific_model)

🎯 Analyzing specific model...
🔍 Analyzing ONNX model: results/nature_cnn_NormalTrain_4/My Behavior.onnx
📊 PARAMETER SUMMARY:
   Total Parameters: 7,622,728
   Number of Layers: 17
   Model Size: ~29.08 MB (assuming float32)

📋 LAYER BREAKDOWN:
Layer Name                     Shape                Parameters   Type
--------------------------------------------------------------------------------
version_number.1               [1]                  1            float32
memory_size_vector             [1]                  1            float32
network_body.observation_encoder.processors.0.conv_layers.0.weight [64, 1, 6, 6]        2,304        float32
network_body.observation_encoder.processors.0.conv_layers.0.bias [64]                 64           float32
network_body.observation_encoder.processors.0.conv_layers.2.weight [128, 64, 4, 4]      131,072      float32
network_body.observation_encoder.processors.0.conv_layers.2.bias [128]                128          float32
network_body.observation_en

In [21]:
# Summarize a second trained model (transcoder3 run); point specific_model at your own run.
specific_model = "results/transcoder3_NormalTrain_1/My Behavior.onnx"  # Example path
print("🎯 Analyzing specific model...")

if os.path.exists(specific_model):
    result = print_model_summary(specific_model)
else:
    for message in (
        f"Model not found: {specific_model}",
        "Update the path above to point to your trained model.",
    ):
        print(message)

🎯 Analyzing specific model...
🔍 Analyzing ONNX model: results/transcoder3_NormalTrain_1/My Behavior.onnx
📊 PARAMETER SUMMARY:
   Total Parameters: 11,831,944
   Number of Layers: 22
   Model Size: ~45.14 MB (assuming float32)

📋 LAYER BREAKDOWN:
Layer Name                     Shape                Parameters   Type
--------------------------------------------------------------------------------
version_number.1               [1]                  1            float32
memory_size_vector             [1]                  1            float32
network_body.observation_encoder.processors.0.patch_embeddings.weight [128, 1, 6, 6]       4,608        float32
network_body.observation_encoder.processors.0.position_embeddings.weight [350, 128]           44,800       float32
network_body.observation_encoder.processors.0.layer_norm.weight [128]                128          float32
network_body.observation_encoder.processors.0.layer_norm.bias [128]                128          float32
network_body.observa

In [22]:
import torch
import torch.nn as nn

class LightweightTranscoder3(nn.Module):
    """Single-block, single-head ViT-style encoder for one visual observation.

    Pipeline:
      1. Split the image into non-overlapping ``kernel_size`` x ``kernel_size``
         patches and embed them with a strided conv.
      2. Add learned position embeddings.
      3. Apply one residual self-attention sub-block, then one residual MLP
         sub-block.
      4. Either flatten all patch tokens (default) or mean-pool them, and
         feed the result to ``dense``.

    Args:
        height, width: Spatial size of the observation the model is built for.
            Inputs are zero-padded on the bottom/right up to a multiple of
            ``kernel_size``, so no pixels are dropped.
        initial_channels: Number of input channels.
        output_size: Size of the output feature vector.
        global_pool: If True, mean-pool the patch tokens and use a small
            ``Linear(embed_size, output_size)`` head. If False (default),
            flatten all tokens into ``Linear(num_patches * embed_size, output_size)``.
    """

    def __init__(self, height: int, width: int, initial_channels: int, output_size: int,
                 global_pool: bool = False):
        super().__init__()
        self.output_size = output_size
        self.initial_channels = initial_channels
        self.global_pool = global_pool
        self.kernel_size = 12  # Larger patches -> fewer tokens
        self.embed_size = 32   # Reduced from 128
        self.head_size = 32    # Must equal embed_size: attention output is added to its residual

        # Ceiling division, because forward() zero-pads the input up to these
        # multiples. For example, 155x86 -> 156x96 -> 13 x 8 = 104 patches.
        self.num_patches_h = (height + self.kernel_size - 1) // self.kernel_size
        self.num_patches_w = (width + self.kernel_size - 1) // self.kernel_size
        self.num_patches = self.num_patches_h * self.num_patches_w

        # Patch embedding: one strided conv == a linear map per non-overlapping patch.
        self.patch_embeddings = nn.Conv2d(
            initial_channels, self.embed_size,
            kernel_size=self.kernel_size, stride=self.kernel_size,
            bias=False,  # Save params
        )

        # One learned embedding per patch position.
        self.position_embeddings = nn.Embedding(self.num_patches, self.embed_size)

        # Single-head attention projections (no bias).
        self.query = nn.Linear(self.embed_size, self.head_size, bias=False)
        self.key = nn.Linear(self.embed_size, self.head_size, bias=False)
        self.value = nn.Linear(self.embed_size, self.head_size, bias=False)

        # Lightweight MLP (2x expansion).
        self.mlp = nn.Sequential(
            nn.Linear(self.embed_size, 2 * self.embed_size),
            nn.GELU(),
            nn.Linear(2 * self.embed_size, self.embed_size),
        )

        # Output head: its input width depends on the pooling mode (see forward()).
        dense_in = self.embed_size if global_pool else self.embed_size * self.num_patches
        self.dense = nn.Linear(dense_in, output_size)

    def attention(self, hidden):
        """Single-head scaled dot-product self-attention: (B, N, E) -> (B, N, head_size)."""
        query = self.query(hidden)
        key = self.key(hidden)
        value = self.value(hidden)
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_size ** 0.5)
        weights = torch.softmax(scores, dim=-1)
        return torch.matmul(weights, value)

    def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
        """Encode a batch of observations into (B, output_size) features.

        Accepts channels-first (B, C, H, W) or channels-last (B, H, W, C) input
        with C == initial_channels. H and W must match the height/width given
        to __init__ (up to the padding remainder).

        Raises:
            ValueError: If the input's spatial size produces a different number
                of patches than the model was built for.
        """
        # Detect layout from the channel count. Comparing against embed_size
        # would wrongly permute valid channels-first input.
        if (visual_obs.shape[1] != self.initial_channels
                and visual_obs.shape[-1] == self.initial_channels):
            visual_obs = visual_obs.permute(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W)

        # Zero-pad bottom/right so H and W are multiples of the patch size.
        # Without this, the unpadded strided conv drops the remainder rows and
        # columns, and the token count no longer matches the position embeddings.
        k = self.kernel_size
        pad_h = (-visual_obs.shape[-2]) % k
        pad_w = (-visual_obs.shape[-1]) % k
        visual_obs = nn.functional.pad(visual_obs, (0, pad_w, 0, pad_h))

        hidden = self.patch_embeddings(visual_obs)   # (B, E, H/k, W/k)
        hidden = hidden.flatten(2).transpose(1, 2)   # (B, N, E)
        if hidden.size(1) != self.num_patches:
            raise ValueError(
                f"expected {self.num_patches} patches, got {hidden.size(1)}; input "
                f"spatial size must match the height/width passed to __init__"
            )

        # Position embeddings: the (N, E) table broadcasts over the batch.
        # The add is out-of-place to avoid modifying an autograd view in place.
        hidden = hidden + self.position_embeddings.weight

        hidden = hidden + self.attention(hidden)  # Self-attention residual
        hidden = hidden + self.mlp(hidden)        # MLP residual

        if self.global_pool:
            hidden = hidden.mean(dim=1)   # (B, E)
        else:
            hidden = hidden.flatten(1)    # (B, N * E)
        return self.dense(hidden)

# Example instantiation and param count (for verification)
if __name__ == "__main__":
    model = LightweightTranscoder3(height=155, width=86, initial_channels=1, output_size=256)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total params: {total_params:,} (~{total_params * 4 / 1e6:.1f} MB)")
    # Output: 867,424 (~3.5 MB). The dense head dominates: 104 patches * 32 * 256 + 256.
    with torch.no_grad():
        features = model(torch.zeros(2, 155, 86, 1))  # channels-last dummy batch
    print(f"Output shape: {tuple(features.shape)}")  # (2, 256)


Total params: 867,424 (~3.5 MB)
