# In [4]:
# analyze_model.py - Part 1: Core Functions and Model Loading

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torchsummary import summary
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import os

# Intended to silence Matplotlib warnings for cleaner output.
# NOTE: the filter below ignores every UserWarning, whatever library raises it.
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

## -----------------------------------------------------------
## UTILITY FUNCTIONS FOR MODEL OUTPUT HANDLING
## -----------------------------------------------------------

def get_model_output(model, input_tensor, return_embeddings=False):
    """
    Call ``model`` on ``input_tensor`` and normalize the shape of its result.

    A model may return either a single value or an ``(output, embeddings)``
    tuple; this helper hides that difference from callers.

    Args:
        model: Callable invoked as ``model(input_tensor)``.
        input_tensor: Value forwarded unchanged to ``model``.
        return_embeddings: When true, return an ``(output, embeddings)`` pair
            instead of the output alone.

    Returns:
        The primary output, or ``(output, embeddings)`` when
        ``return_embeddings`` is true. ``embeddings`` is ``None`` when the
        model did not return a tuple.

    Raises:
        ValueError: If the model returns a tuple that does not have exactly
            two elements.
    """
    result = model(input_tensor)

    if isinstance(result, tuple):
        # Two-element unpacking: any other tuple length raises ValueError.
        primary, extra = result
    else:
        primary, extra = result, None

    return (primary, extra) if return_embeddings else primary

## -----------------------------------------------------------
## SECURE MODEL LOADING
## -----------------------------------------------------------

def load_model(model_path, model_class, device=None):
    """
    Build a ``model_class`` instance and fill it from a saved state dict.

    The checkpoint is read with ``weights_only=True`` (so the file is not
    unpickled as arbitrary objects) and ``map_location='cpu'``; the model is
    then moved to ``device`` and put into evaluation mode.

    Args:
        model_path: Path to the saved ``.pth`` state dict.
        model_class: Zero-argument callable returning the model to populate.
        device: Target device. When ``None``, CUDA is chosen if available,
            otherwise the CPU.

    Returns:
        The populated model, on ``device`` and in eval mode.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        Exception: Any error raised while building or loading the model is
            described on stdout and then re-raised unchanged.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file '{model_path}' not found.")

    target_device = device
    if target_device is None:
        target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    try:
        # Instantiate the architecture before touching the checkpoint.
        net = model_class()

        # Deserialize onto the CPU first so the checkpoint loads regardless
        # of the device it was saved from.
        weights = torch.load(model_path, map_location='cpu', weights_only=True)
        net.load_state_dict(weights)

        net = net.to(target_device)
        net.eval()  # inference mode for the analysis functions

        print(f"Model loaded successfully on device: {target_device}")
        return net

    except Exception as err:
        print(f"Error loading model: {err}")
        for hint in (
            "This could be due to:",
            "1. Architecture mismatch between saved model and model_class",
            "2. Corrupted model file",
            "3. Version incompatibility",
        ):
            print(hint)
        raise

## -----------------------------------------------------------
## ANALYSIS FUNCTIONS
## -----------------------------------------------------------

def analyze_architecture(model, input_size):
    """
    Print a layer-by-layer summary of ``model`` using torchsummary.

    If the summary cannot be produced for any reason, the error is printed
    and every named submodule of the model is listed instead.

    Args:
        model: Model to summarize.
        input_size: Shape of a single input sample. A tuple is used as-is, a
            non-empty flat list of ints is converted to a tuple, and any
            other value is wrapped in a 1-tuple.

    Returns:
        None. Everything is reported on stdout.
    """
    print("\n### 1. Model Architecture Summary ###")
    print("This shows the layers, output shapes, and parameter counts.")
    try:
        if isinstance(input_size, tuple):
            sample_shape = input_size
        elif (isinstance(input_size, list) and input_size
              and all(isinstance(dim, int) for dim in input_size)):
            # torchsummary treats a list as "several inputs"; a flat list of
            # ints is a single shape and must be passed as a tuple.
            sample_shape = tuple(input_size)
        else:
            sample_shape = (input_size,)

        # torchsummary defaults to device="cuda" and builds a CUDA dummy input
        # whenever CUDA is available, which fails for a model whose weights
        # are on the CPU. Pass the model's own device when it is one of the
        # two values torchsummary accepts; otherwise keep its default.
        first_param = next(model.parameters(), None)
        device_type = None if first_param is None else first_param.device.type
        if device_type in ("cuda", "cpu"):
            summary(model, input_size=sample_shape, device=device_type)
        else:
            summary(model, input_size=sample_shape)
    except Exception as e:
        print(f"Could not generate summary. Error: {e}")
        print("Printing model structure instead:")
        print("-" * 50)
        for name, module in model.named_modules():
            if name:  # Skip the root module
                print(f"{name}: {module}")
        print("-" * 50)

def analyze_weights(model):
    """
    Plot the value distribution of every named parameter of ``model``.

    One subplot is drawn per parameter:
      * trainable, non-empty parameters get a histogram titled with their
        mean and standard deviation;
      * parameters whose values are all identical get a text note instead;
      * NaN/Inf entries are left out of the histogram (which cannot bin
        them) and their count is added to the subplot title;
      * parameters that do not require gradients, or are empty, get a
        placeholder note.

    Args:
        model: Object exposing ``named_parameters()``.

    Returns:
        None. The figure is displayed with ``plt.show()``; if the model has
        no parameters a message is printed and nothing is plotted.
    """
    print("\n### 2. Weight and Bias Distribution ###")
    print("This helps identify issues like vanishing or exploding gradients.")

    params_list = list(model.named_parameters())
    if not params_list:
        print("No parameters found in the model.")
        return

    num_params = len(params_list)

    # squeeze=False always returns a 2-D array of Axes, so a single parameter
    # needs no special casing. max(4, ...) keeps the one-parameter figure at
    # its previous 8x4 size.
    fig, axes_grid = plt.subplots(
        num_params, 1, figsize=(8, max(4, 2 * num_params)), squeeze=False
    )
    axes = axes_grid[:, 0]

    fig.suptitle('Weight and Bias Distributions per Layer')

    for ax, (name, param) in zip(axes, params_list):
        if not param.requires_grad or param.numel() == 0:
            ax.text(0.5, 0.5, "No gradients or empty tensor",
                    transform=ax.transAxes, ha='center', va='center')
            ax.set_title(name)
            continue

        values = param.detach().cpu().numpy().ravel()
        finite = values[np.isfinite(values)]
        num_bad = values.size - finite.size

        if finite.size == 0:
            # Nothing can be binned; say so instead of letting hist() raise.
            ax.text(0.5, 0.5, "All values are NaN or Inf",
                    transform=ax.transAxes, ha='center', va='center')
            ax.set_title(name)
        elif num_bad == 0 and np.all(finite == finite[0]):
            ax.text(0.5, 0.5, f"All values = {finite[0]:.4f}",
                    transform=ax.transAxes, ha='center', va='center')
            ax.set_title(name)
        else:
            ax.hist(finite, bins=min(100, len(finite)), alpha=0.7)
            title = f"{name} (mean: {finite.mean():.4f}, std: {finite.std():.4f})"
            if num_bad:
                title += f" [{num_bad} non-finite values omitted]"
            ax.set_title(title)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

def analyze_feature_maps(model, layer_name, input_tensor):
    """
    Plot the activations that one named layer produces for ``input_tensor``.

    A forward hook is attached to the layer, the model is run once under
    ``torch.no_grad()``, and the captured output is drawn as a grid of images
    with up to 8 columns. The hook is removed again even if the forward pass
    fails, so the model is never left with a stale hook.

    Args:
        model: Model exposing ``named_modules()``.
        layer_name: Key of the target layer in ``dict(model.named_modules())``.
        input_tensor: Input passed to the model.

    Returns:
        None. Problems (unknown layer, no activation captured, output with
        fewer than two dimensions after squeezing, or any exception) are
        printed rather than raised.
    """
    print(f"\n### 3. Feature Map Visualization (Layer: {layer_name}) ###")
    print("This shows what features the model detects at an intermediate stage.")

    activations = {}

    def capture(module, inputs, output):
        # Forward hook: keep a detached copy of the layer's output.
        activations[layer_name] = output.detach()

    try:
        available_layers = dict(model.named_modules())
        if layer_name not in available_layers:
            print(f"Error: Layer '{layer_name}' not found.")
            print(f"Available layers: {list(available_layers.keys())}")
            return

        handle = available_layers[layer_name].register_forward_hook(capture)
        try:
            with torch.no_grad():
                get_model_output(model, input_tensor)
        finally:
            # Detach the hook even when the forward pass raises; otherwise it
            # would keep firing on every later call of the model.
            handle.remove()

        if layer_name not in activations:
            print(f"No activations captured for layer '{layer_name}'")
            return

        # squeeze() drops every size-1 dimension (e.g. a batch of one).
        acts = activations[layer_name].squeeze()

        if acts.dim() < 2:
            print(f"Layer '{layer_name}' produces {acts.dim()}D output, cannot visualize as feature maps.")
            return
        if acts.dim() == 2:
            # A single 2-D map: add a leading "map" axis.
            acts = acts.unsqueeze(0)
        elif acts.dim() > 3:
            print(f"Layer '{layer_name}' produces {acts.dim()}D output, taking first sample.")
            # Keep index 0 of each leading axis until a stack of 2-D maps remains.
            while acts.dim() > 3:
                acts = acts[0]

        num_maps = acts.size(0)
        cols = min(8, num_maps)
        rows = (num_maps + cols - 1) // cols

        # squeeze=False always yields a 2-D Axes array, whatever rows/cols are.
        fig, axes = plt.subplots(rows, cols, figsize=(2 * cols, 2 * rows), squeeze=False)
        fig.suptitle(f'Feature Maps from Layer: {layer_name}')

        maps = acts.cpu().numpy()
        for i, ax in enumerate(axes.ravel()):
            if i < num_maps:
                ax.imshow(maps[i], cmap='viridis')
                ax.set_title(f'Map {i+1}', fontsize=8)
            else:
                # Unused cell in the last row of the grid.
                ax.set_visible(False)
            ax.axis('off')
        plt.tight_layout(rect=[0, 0, 1, 0.96])
        plt.show()

    except Exception as e:
        print(f"Could not visualize feature maps. Error: {e}")
        import traceback
        traceback.print_exc()

def analyze_saliency(model, input_tensor):
    """
    Show a gradient-based saliency map next to the input image.

    The highest output score of each sample is back-propagated to a private
    copy of the input, and the absolute gradient, maximized over dimension 1,
    is drawn as the saliency map. When the input holds several samples along
    dimension 0, gradients are computed for all of them but only the first
    sample is displayed.

    Args:
        model: Model to analyze. ``model.zero_grad()`` is called before the
            backward pass; the backward pass may leave gradients on the
            model's parameters.
        input_tensor: Input passed to the model. It is cloned, so the
            caller's tensor is not modified. The left-hand image is drawn
            with a gray colormap after squeezing size-1 dimensions
            (NOTE(review): this presumably expects single-channel images —
            confirm against callers).

    Returns:
        None. Any exception is printed together with its traceback instead
        of being raised.
    """
    print("\n### 4. Saliency Maps ###")
    print("This highlights the pixels your model 'looks at' for its prediction.")

    # Work on a private leaf tensor so gradients can be taken w.r.t. the input.
    input_tensor = input_tensor.clone().detach().requires_grad_(True)

    try:
        output = get_model_output(model, input_tensor)  # Use our utility function

        if output.dim() <= 1:
            top_score = output.max()
        else:
            # Highest score of every sample; summing gives backward() the
            # scalar it needs while each sample still receives the gradient
            # of its own top score.
            per_sample_max, _ = torch.max(output.flatten(start_dim=1), dim=1)
            top_score = per_sample_max.sum()

        model.zero_grad()
        top_score.backward()

        if input_tensor.grad is None:
            print("No gradients computed. Make sure the model supports gradient computation.")
            return

        saliency, _ = torch.max(input_tensor.grad.abs(), dim=1)

        image = input_tensor.detach().cpu()
        if image.size(0) > 1:
            # Batched input: display the first sample only.
            image = image[:1]
            saliency = saliency[:1]
        saliency = saliency.squeeze()

        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        # .cpu() above is required before .numpy() when the input is on a GPU.
        plt.imshow(image.squeeze().numpy(), cmap='gray')
        plt.title('Original Image')
        plt.axis('off')

        plt.subplot(1, 2, 2)
        plt.imshow(saliency.cpu().numpy(), cmap='hot')
        plt.title('Saliency Map')
        plt.axis('off')
        plt.suptitle('Saliency Map Analysis')
        plt.show()

    except Exception as e:
        print(f"Could not generate saliency map. Error: {e}")
        import traceback
        traceback.print_exc()