# Model Architecture Explorer

This notebook provides an interactive exploration of model architectures, encoders, and classification heads.
It allows you to visualize model structures, analyze parameters, and experiment with different configurations.

## Setup and Imports

Import all necessary libraries for model exploration.

In [None]:
import os
import sys
import warnings
from pathlib import Path
from typing import Dict, Any, List, Tuple, Optional
from collections import OrderedDict

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, AutoConfig
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets
from tqdm.auto import tqdm

# Silence every warning globally so widget and print output stays readable.
# NOTE(review): this also hides genuinely useful warnings (e.g. deprecations) — consider narrowing.
warnings.filterwarnings('ignore')

# Put ./src (relative to the kernel's working directory) on the import path.
# NOTE(review): the project imports later in this notebook use the `src.` prefix, which
# presumably resolves through the working directory rather than this entry — confirm it is needed.
if 'src' not in sys.path:
    sys.path.append('src')

# Short environment report.
print("‚úÖ All imports successful!")
for label, value in (
    ("PyTorch version", torch.__version__),
    ("CUDA available", torch.cuda.is_available()),
):
    print(f"{label}: {value}")

## Load Dependencies

Run the configuration-management notebook (`01_Configuration_Management.ipynb`) and import the project's model, encoder, and head classes.

In [None]:
# Load configuration management.
# `%run` executes the other notebook inside this kernel's namespace (path is relative to the
# working directory). `ExperimentConfig`, which later cells use but never import, presumably
# comes from it — confirm.
%run 01_Configuration_Management.ipynb

# Import project modules
# NOTE(review): get_encoder_class, MultiLabelClassificationHead, RegressionHead and
# get_head_class are not referenced in the visible cells of this notebook.
from src.models.model import EvidenceModel
from src.models.encoders import (
    RobertaEncoder, BertEncoder, DebertaEncoder,
    get_encoder_class
)
from src.models.heads import (
    MultiLabelClassificationHead, RegressionHead,
    get_head_class
)

print("‚úÖ Dependencies loaded!")

## Model Architecture Explorer

Pick an encoder, model size, LoRA settings, and classification-head options with widgets, then build the model and inspect its parameters.

In [None]:
def create_architecture_explorer():
    """Build the interactive model-architecture explorer widget.

    "Create Model" turns the current control values into an ``ExperimentConfig``,
    instantiates ``EvidenceModel(config.model)`` and stores the pair in the notebook
    globals ``current_model`` / ``current_config``. "Analyze Architecture" then runs
    ``analyze_model_architecture`` on that pair. All printed text and plots from both
    buttons go into the shared Output area, because output produced in widget
    callbacks is otherwise not shown under the cell in JupyterLab.

    Returns:
        ipywidgets.VBox: the controls followed by the shared output area.
    """
    import traceback

    # Checkpoint for every (encoder type, size) pair offered by the dropdowns.
    model_mapping = {
        ('roberta', 'base'): 'roberta-base',
        ('roberta', 'large'): 'roberta-large',
        ('bert', 'base'): 'bert-base-uncased',
        ('bert', 'large'): 'bert-large-uncased',
        ('deberta', 'base'): 'microsoft/deberta-base',
        ('deberta', 'large'): 'microsoft/deberta-large'
    }

    # Encoder selection
    encoder_type = widgets.Dropdown(
        options=['roberta', 'bert', 'deberta'],
        value='roberta',
        description='Encoder Type:'
    )
    model_size = widgets.Dropdown(
        options=['base', 'large'],
        value='base',
        description='Model Size:'
    )

    # LoRA configuration (sliders only become active when LoRA is ticked)
    use_lora = widgets.Checkbox(
        value=False,
        description='Use LoRA'
    )
    lora_r = widgets.IntSlider(
        value=16,
        min=4,
        max=64,
        step=4,
        description='LoRA r:',
        disabled=True
    )
    lora_alpha = widgets.IntSlider(
        value=32,
        min=8,
        max=128,
        step=8,
        description='LoRA Œ±:',
        disabled=True
    )

    # Head configuration
    num_labels = widgets.IntSlider(
        value=10,
        min=2,
        max=50,
        description='Num Labels:'
    )
    hidden_size = widgets.Dropdown(
        options=[256, 512, 768, 1024],
        value=512,
        description='Hidden Size:'
    )
    dropout = widgets.FloatSlider(
        value=0.1,
        min=0.0,
        max=0.5,
        step=0.05,
        description='Dropout:'
    )

    # Buttons
    create_button = widgets.Button(
        description='üèóÔ∏è Create Model',
        button_style='primary'
    )
    analyze_button = widgets.Button(
        description='üìä Analyze Architecture',
        button_style='info',
        disabled=True
    )

    output = widgets.Output()

    def on_lora_change(change):
        # r / alpha are only meaningful while LoRA is enabled.
        lora_r.disabled = not change['new']
        lora_alpha.disabled = not change['new']

    use_lora.observe(on_lora_change, names='value')

    def build_config():
        """Translate the current widget values into a fresh ExperimentConfig."""
        config = ExperimentConfig()

        encoder_cfg = config.model.encoder
        encoder_cfg.type = encoder_type.value
        encoder_cfg.pretrained_model_name_or_path = model_mapping[
            (encoder_type.value, model_size.value)
        ]
        encoder_cfg.lora.enabled = use_lora.value
        if use_lora.value:
            encoder_cfg.lora.r = lora_r.value
            encoder_cfg.lora.alpha = lora_alpha.value

        head_layers = config.model.heads.symptom_labels.layers
        head_layers.hidden_size = hidden_size.value
        head_layers.dropout = dropout.value

        # NOTE(review): only `config.model` is passed to EvidenceModel, so this label list
        # may not change the model's output size — confirm against EvidenceModel.
        config.data.multi_label_fields = [f"label_{i}" for i in range(num_labels.value)]
        return config

    def on_create_clicked(b):
        global current_model, current_config
        with output:
            output.clear_output()
            # Block analysis until a model matching the current controls exists, so a
            # failed re-create cannot leave a stale model one click away.
            analyze_button.disabled = True

            try:
                print(f"üèóÔ∏è Creating Model Architecture")
                print("=" * 40)

                config = build_config()

                print(f"   Encoder: {encoder_type.value}-{model_size.value}")
                print(f"   Model: {config.model.encoder.pretrained_model_name_or_path}")
                print(f"   LoRA: {'Enabled' if use_lora.value else 'Disabled'}")
                if use_lora.value:
                    print(f"     r={lora_r.value}, Œ±={lora_alpha.value}")
                print(f"   Labels: {num_labels.value}")
                print(f"   Hidden size: {hidden_size.value}")
                print(f"   Dropout: {dropout.value}")

                print(f"\nü§ñ Instantiating model...")
                model = EvidenceModel(config.model)

                total_params = sum(p.numel() for p in model.parameters())
                trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

                print(f"   ‚úÖ Model created successfully!")
                print(f"   Total parameters: {total_params:,}")
                print(f"   Trainable parameters: {trainable_params:,}")
                print(f"   Frozen parameters: {total_params - trainable_params:,}")

                # Publish for the analyze button (and for later cells).
                current_model = model
                current_config = config
                analyze_button.disabled = False

            except Exception as e:
                print(f"‚ùå Error creating model: {e}")
                traceback.print_exc()

    def on_analyze_clicked(b):
        # Render into the same Output area; callback output outside it is not displayed.
        with output:
            if 'current_model' not in globals():
                print("‚ùå No model created. Please create a model first.")
                return
            try:
                analyze_model_architecture(current_model, current_config)
            except Exception as e:
                print(f"Error analyzing model: {e}")
                traceback.print_exc()

    create_button.on_click(on_create_clicked)
    analyze_button.on_click(on_analyze_clicked)

    # Layout
    controls = widgets.VBox([
        widgets.HTML("<h3>Model Architecture Explorer</h3>"),
        widgets.HBox([encoder_type, model_size]),
        widgets.HBox([use_lora, lora_r, lora_alpha]),
        widgets.HBox([num_labels, hidden_size, dropout]),
        widgets.HBox([create_button, analyze_button])
    ])

    return widgets.VBox([controls, output])

# Display architecture explorer
architecture_explorer = create_architecture_explorer()
display(architecture_explorer)

## Model Analysis Functions

Detailed analysis of model architecture and parameters.

In [None]:
def _collect_parameter_stats(model: nn.Module) -> pd.DataFrame:
    """Return one row per module that directly owns parameters, largest first.

    Walks ``model.named_parameters()`` instead of leaf modules. This counts every
    parameter exactly once, including parameters registered directly on container
    (non-leaf) modules, and skips repeats of shared/tied parameters.
    """
    columns = ['Module', 'Type', 'Total Params', 'Trainable Params', 'Frozen Params', 'Trainable %']

    # owner module path -> [total, trainable]
    counts = {}
    for param_name, param in model.named_parameters():
        owner = param_name.rpartition('.')[0]  # '' for parameters registered on the root
        bucket = counts.setdefault(owner, [0, 0])
        bucket[0] += param.numel()
        if param.requires_grad:
            bucket[1] += param.numel()

    rows = []
    for owner, (n_total, n_trainable) in counts.items():
        if n_total == 0:
            continue
        rows.append({
            'Module': owner or '(root)',
            'Type': model.get_submodule(owner).__class__.__name__,
            'Total Params': n_total,
            'Trainable Params': n_trainable,
            'Frozen Params': n_total - n_trainable,
            'Trainable %': n_trainable / n_total * 100,
        })

    return (
        pd.DataFrame(rows, columns=columns)
        .sort_values('Total Params', ascending=False)
        .reset_index(drop=True)
    )


def _short_module_name(name: str) -> str:
    """Last two dotted components of a module path, e.g. 'query.weight' style labels for plots."""
    return '.'.join(name.split('.')[-2:])


def _plot_parameter_stats(param_df: pd.DataFrame, total_params: int, trainable_params: int) -> None:
    """Draw a 2x2 overview: share by module type, trainable vs frozen, top-10 modules, trainable %."""
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle('Model Parameter Analysis', fontsize=16)

    # 1. Parameter share by module type
    by_type = param_df.groupby('Type')['Total Params'].sum().sort_values(ascending=False)
    axes[0, 0].pie(by_type.values, labels=by_type.index, autopct='%1.1f%%')
    axes[0, 0].set_title('Parameters by Module Type')

    # 2. Trainable vs frozen
    axes[0, 1].pie(
        [trainable_params, total_params - trainable_params],
        labels=['Trainable', 'Frozen'],
        colors=['lightgreen', 'lightcoral'],
        autopct='%1.1f%%',
    )
    axes[0, 1].set_title('Trainable vs Frozen Parameters')

    # 3. Top modules by parameter count (largest bar at the top)
    top_10 = param_df.head(10)
    y_pos = np.arange(len(top_10))
    axes[1, 0].barh(y_pos, top_10['Total Params'], color='skyblue')
    axes[1, 0].set_yticks(y_pos)
    axes[1, 0].set_yticklabels([_short_module_name(n) for n in top_10['Module']])
    axes[1, 0].invert_yaxis()
    axes[1, 0].set_xlabel('Parameters')
    axes[1, 0].set_title('Top 10 Modules by Parameter Count')

    # 4. Trainable percentage for the 15 largest non-trivial modules
    sizeable = param_df[param_df['Total Params'] > 1000].head(15)
    x_pos = np.arange(len(sizeable))
    axes[1, 1].bar(x_pos, sizeable['Trainable %'], color='lightgreen')
    axes[1, 1].set_xticks(x_pos)
    axes[1, 1].set_xticklabels([_short_module_name(n) for n in sizeable['Module']], rotation=45, ha='right')
    axes[1, 1].set_ylabel('Trainable %')
    axes[1, 1].set_title('Trainable Percentage by Module')

    plt.tight_layout()
    plt.show()


def _print_memory_estimate(total_params: int, trainable_params: int) -> None:
    """Print a rough fp32 training-memory estimate (weights + gradients + Adam states).

    Activations, buffers and framework overhead are not included.
    """
    bytes_per_param = 4  # float32
    mib = 1024 * 1024

    param_memory_mb = total_params * bytes_per_param / mib
    gradient_memory_mb = trainable_params * bytes_per_param / mib
    # Adam keeps two fp32 moment tensors per trainable parameter.
    optimizer_memory_mb = trainable_params * 2 * bytes_per_param / mib
    total_memory_mb = param_memory_mb + gradient_memory_mb + optimizer_memory_mb

    print(f"\nüíæ Memory Estimation:")
    print(f"   Model parameters: {param_memory_mb:.1f} MB")
    print(f"   Gradients: {gradient_memory_mb:.1f} MB")
    print(f"   Optimizer states: {optimizer_memory_mb:.1f} MB")
    print(f"   Total (approx): {total_memory_mb:.1f} MB ({total_memory_mb/1024:.2f} GB)")


def analyze_model_architecture(model: nn.Module, config: 'ExperimentConfig'):
    """Print and plot a parameter breakdown of ``model``.

    Args:
        model: the instantiated model to inspect.
        config: experiment config. Its ``model.encoder`` section (type, checkpoint name,
            LoRA settings) is echoed in the report.

    Returns:
        pandas.DataFrame with one row per module that directly owns parameters
        (columns: Module, Type, Total Params, Trainable Params, Frozen Params,
        Trainable %), sorted by Total Params descending. It is empty for a
        parameter-less model, in which case no plots are drawn.
    """
    encoder_cfg = config.model.encoder

    print(f"üìä Model Architecture Analysis")
    print("=" * 50)

    # Model summary
    print(f"\nüèóÔ∏è Model Structure:")
    print(f"   Model class: {model.__class__.__name__}")
    print(f"   Encoder type: {encoder_cfg.type}")
    print(f"   Pretrained model: {encoder_cfg.pretrained_model_name_or_path}")

    # Parameter analysis
    print(f"\nüìà Parameter Analysis:")
    param_df = _collect_parameter_stats(model)
    total_params = int(param_df['Total Params'].sum())
    trainable_params = int(param_df['Trainable Params'].sum())
    frozen_params = total_params - trainable_params

    if total_params == 0:
        # Nothing to plot, and every percentage below would divide by zero.
        print("   Model has no parameters; nothing to analyze.")
        return param_df

    print(f"   Total parameters: {total_params:,}")
    print(f"   Trainable parameters: {trainable_params:,} ({trainable_params/total_params*100:.1f}%)")
    print(f"   Frozen parameters: {frozen_params:,} ({frozen_params/total_params*100:.1f}%)")

    print(f"\nüîù Top Parameter-Heavy Modules:")
    top_modules = param_df.head(10)
    for module_name, n_params, pct in zip(
        top_modules['Module'], top_modules['Total Params'], top_modules['Trainable %']
    ):
        print(f"   {module_name}: {int(n_params):,} params ({pct:.1f}% trainable)")

    _plot_parameter_stats(param_df, total_params, trainable_params)
    _print_memory_estimate(total_params, trainable_params)

    # LoRA analysis if enabled
    if encoder_cfg.lora.enabled:
        print(f"\nüîß LoRA Analysis:")
        print(f"   LoRA rank (r): {encoder_cfg.lora.r}")
        print(f"   LoRA alpha: {encoder_cfg.lora.alpha}")
        print(f"   LoRA dropout: {encoder_cfg.lora.dropout}")

        # Heuristic: any parameter whose name contains 'lora' is treated as an adapter weight.
        lora_params = sum(p.numel() for name, p in model.named_parameters() if 'lora' in name.lower())
        if lora_params > 0:
            print(f"   LoRA parameters: {lora_params:,} ({lora_params/total_params*100:.2f}% of total)")
            print(f"   Parameter reduction: {total_params / lora_params:.1f}x")

    return param_df

def _read_backbone_dims(model_name: str) -> Tuple[Any, Any, Any]:
    """Return (hidden_size, num_hidden_layers, num_attention_heads) from the HF config.

    Best-effort: any failure to load the config (offline, unknown checkpoint)
    yields 'N/A' for all three values.
    """
    try:
        hf_config = AutoConfig.from_pretrained(model_name)
    except Exception:
        return 'N/A', 'N/A', 'N/A'
    return tuple(
        getattr(hf_config, attr, 'N/A')
        for attr in ('hidden_size', 'num_hidden_layers', 'num_attention_heads')
    )


def compare_model_architectures(architectures: Optional[List[Dict[str, str]]] = None):
    """Instantiate several encoders via EvidenceModel and compare their sizes.

    Args:
        architectures: list of dicts with keys 'name' (display label), 'type'
            (encoder type for the config) and 'model' (checkpoint name). Defaults
            to RoBERTa-base, BERT-base and DeBERTa-base.

    Returns:
        pandas.DataFrame with one row per architecture that loaded successfully,
        or None if none did. Also displays the table and two bar charts.
    """
    print(f"üîÑ Model Architecture Comparison")
    print("=" * 50)

    if architectures is None:
        architectures = [
            {'name': 'RoBERTa-base', 'type': 'roberta', 'model': 'roberta-base'},
            {'name': 'BERT-base', 'type': 'bert', 'model': 'bert-base-uncased'},
            {'name': 'DeBERTa-base', 'type': 'deberta', 'model': 'microsoft/deberta-base'},
        ]

    comparison_data = []

    for arch in architectures:
        try:
            print(f"\nüìä Analyzing {arch['name']}...")

            config = ExperimentConfig()
            config.model.encoder.type = arch['type']
            config.model.encoder.pretrained_model_name_or_path = arch['model']

            model = EvidenceModel(config.model)
            total_params = sum(p.numel() for p in model.parameters())
            trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            # Release the weights before the next checkpoint is loaded.
            del model

            hidden_size, num_layers, num_heads = _read_backbone_dims(arch['model'])

            comparison_data.append({
                'Architecture': arch['name'],
                'Total Parameters': total_params,
                'Trainable Parameters': trainable_params,
                'Hidden Size': hidden_size,
                'Layers': num_layers,
                'Attention Heads': num_heads,
                'Memory (MB)': total_params * 4 / (1024 * 1024)  # fp32 weights only
            })

            print(f"   ‚úÖ {arch['name']}: {total_params:,} parameters")

        except Exception as e:
            print(f"   ‚ùå Error with {arch['name']}: {e}")

    if not comparison_data:
        return None

    comparison_df = pd.DataFrame(comparison_data)

    print(f"\nüìã Architecture Comparison:")
    display(comparison_df)

    # Cycle the original palette so any number of architectures gets colors.
    palette = ['skyblue', 'lightcoral', 'lightgreen']
    bar_colors = [palette[i % len(palette)] for i in range(len(comparison_df))]

    fig, axes = plt.subplots(1, 2, figsize=(15, 6))
    panels = (
        (axes[0], 'Total Parameters', 'Total Parameters', 'Parameter Count Comparison'),
        (axes[1], 'Memory (MB)', 'Memory Usage (MB)', 'Memory Usage Comparison'),
    )
    for ax, column, ylabel, title in panels:
        ax.bar(comparison_df['Architecture'], comparison_df[column], color=bar_colors)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.tick_params(axis='x', rotation=45)

    plt.tight_layout()
    plt.show()

    return comparison_df

# Create comparison button
compare_button = widgets.Button(
    description='üîÑ Compare Architectures',
    button_style='warning'
)
# Output produced inside widget callbacks is only shown when captured by an Output widget.
compare_output = widgets.Output()

def on_compare_clicked(b):
    with compare_output:
        compare_output.clear_output()
        compare_model_architectures()

compare_button.on_click(on_compare_clicked)

print("\nüîÑ Architecture Comparison:")
display(widgets.VBox([compare_button, compare_output]))

## Encoder Deep Dive

Detailed exploration of different encoder types.

In [None]:
def explore_encoder_details():
    """Print the main architecture hyper-parameters of each supported encoder.

    For each encoder wrapper class this loads the Hugging Face config of its default
    checkpoint with ``AutoConfig.from_pretrained`` and prints its dimensions.
    Attributes missing from a config are shown as 'N/A'. A config that fails to load
    is reported and the loop continues with the next encoder.
    """
    print(f"üîç Encoder Architecture Deep Dive")
    print("=" * 50)

    encoder_configs = {
        'RoBERTa': {
            'class': RobertaEncoder,
            'model': 'roberta-base',
            'description': 'Robustly Optimized BERT Pretraining Approach'
        },
        'BERT': {
            'class': BertEncoder,
            'model': 'bert-base-uncased',
            'description': 'Bidirectional Encoder Representations from Transformers'
        },
        'DeBERTa': {
            'class': DebertaEncoder,
            'model': 'microsoft/deberta-base',
            'description': 'Decoding-enhanced BERT with Disentangled Attention'
        }
    }

    # (printed label, HF config attribute) shown for every encoder.
    common_fields = [
        ('Hidden size', 'hidden_size'),
        ('Layers', 'num_hidden_layers'),
        ('Attention heads', 'num_attention_heads'),
        ('Intermediate size', 'intermediate_size'),
        ('Max position embeddings', 'max_position_embeddings'),
        ('Vocab size', 'vocab_size'),
    ]
    # DeBERTa's disentangled-attention settings. `position_buckets` only exists on
    # DeBERTa-v2/v3 configs (the former `position_bucket_size` is not a config attribute).
    deberta_fields = [
        ('Relative attention', 'relative_attention'),
        ('Max relative positions', 'max_relative_positions'),
        ('Position attention types', 'pos_att_type'),
        ('Position buckets', 'position_buckets'),
    ]

    for name, info in encoder_configs.items():
        print(f"\nü§ñ {name} Encoder:")
        print(f"   Description: {info['description']}")
        print(f"   Model: {info['model']}")
        print(f"   Wrapper class: {info['class'].__name__}")

        try:
            hf_config = AutoConfig.from_pretrained(info['model'])
        except Exception as e:
            print(f"     ‚ùå Could not load config: {e}")
            continue

        print(f"   Architecture Details:")
        fields = common_fields + (deberta_fields if name == 'DeBERTa' else [])
        for label, attr in fields:
            print(f"     {label}: {getattr(hf_config, attr, 'N/A')}")

def explore_head_architectures():
    """Print a fixed textual overview of the prediction heads and pooling strategies.

    The text is static and does not inspect any model instance.
    """
    # (section title, purpose, output width, final activation) for each head kind.
    head_specs = [
        ("\nüè∑Ô∏è  Multi-Label Classification Head:",
         "Multi-label classification with sigmoid activation",
         "num_labels",
         "Sigmoid (for multi-label)"),
        ("\nüìä Regression Head:",
         "Continuous value prediction",
         "1",
         "None (linear output)"),
    ]
    pooling_strategies = {
        'cls': 'Use [CLS] token representation',
        'mean': 'Average pooling over all tokens',
        'max': 'Max pooling over all tokens',
        'attention': 'Attention-weighted pooling'
    }

    print("\nüéØ Classification Head Architectures")
    print("=" * 50)

    for title, purpose, out_dim, final_activation in head_specs:
        print(title)
        print(f"   Purpose: {purpose}")
        print("   Architecture:")
        print("     Input: [batch_size, hidden_size]")
        print("     Hidden layer: Linear(hidden_size, head_hidden_size)")
        print("     Activation: ReLU or GELU")
        print("     Dropout: Configurable dropout rate")
        print(f"     Output: Linear(head_hidden_size, {out_dim})")
        print(f"     Final activation: {final_activation}")

    print("\nüèä Pooling Strategies:")
    for key in pooling_strategies:
        print(f"   {key.upper()}: {pooling_strategies[key]}")

# Create exploration buttons
encoder_button = widgets.Button(
    description='üîç Explore Encoders',
    button_style='info'
)

head_button = widgets.Button(
    description='üéØ Explore Heads',
    button_style='info'
)

# Output produced inside widget callbacks is only shown when captured by an Output widget,
# so both buttons render their reports here.
deep_dive_output = widgets.Output()

def on_encoder_clicked(b):
    with deep_dive_output:
        deep_dive_output.clear_output()
        explore_encoder_details()

def on_head_clicked(b):
    with deep_dive_output:
        deep_dive_output.clear_output()
        explore_head_architectures()

encoder_button.on_click(on_encoder_clicked)
head_button.on_click(on_head_clicked)

print("\nüîç Architecture Deep Dive:")
display(widgets.VBox([widgets.HBox([encoder_button, head_button]), deep_dive_output]))

# Closing summary of what the notebook offers.
print("\n‚úÖ Model Architecture Explorer complete!")
print("\nThis notebook provides:")
for feature in (
    "Interactive model architecture exploration",
    "Detailed parameter analysis and visualization",
    "Model architecture comparison",
    "Encoder and head architecture deep dive",
    "Memory usage estimation",
    "LoRA configuration analysis",
):
    print(f"‚Ä¢ {feature}")