# In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import random
from tqdm import tqdm
import copy

# Seed NumPy, PyTorch and Python's `random` with the same fixed value so that
# repeated runs draw the same random numbers from each library.
np.random.seed(42)
torch.manual_seed(42)
random.seed(42)

# Define the function f(x1, x2) - using a complex non-linear function
def f(x1, x2):
    """Evaluate the target surface at (x1, x2).

    The surface is the sum of an oscillating part, sin(3*x1) + 0.5*cos(4*x2),
    and a saddle-shaped quadratic part, 0.5*(x1**2 - x2**2). Only NumPy
    ufuncs and arithmetic are used, so scalars and arrays are both accepted.
    """
    oscillation = np.sin(3 * x1) + 0.5 * np.cos(4 * x2)
    saddle = 0.5 * (x1**2 - x2**2)
    return oscillation + saddle

# Classification threshold: generate_samples labels a point 1.0 when
# f(x1, x2) >= T and 0.0 otherwise.
T = 0.5

# Generate random samples in the unit square
def generate_samples(n_samples):
    """Draw labelled training points from the square [-1, 1] x [-1, 1].

    Args:
        n_samples: Number of points to draw.

    Returns:
        A tuple ``(X, y)`` where ``X`` has shape ``(n_samples, 2)`` with
        coordinates drawn uniformly from [-1, 1), and ``y`` is a float32
        array of shape ``(n_samples,)`` holding 1.0 where ``f(x1, x2) >= T``
        and 0.0 elsewhere.
    """
    X = np.random.uniform(-1, 1, (n_samples, 2))
    # f is built from NumPy ufuncs, so it can be evaluated on whole columns
    # at once instead of point by point in a Python loop.
    y_values = f(X[:, 0], X[:, 1])
    y = (y_values >= T).astype(np.float32)
    return X, y

# Define the MLP model with named modules
class MLP(nn.Module):
    """Fully connected network that exposes every intermediate activation.

    Sub-modules live in ``self.layers`` (an ``nn.ModuleDict``) under the keys
    ``linear1, relu1, ..., linear<k+1>, sigmoid`` where ``k`` is the number of
    hidden layers. ``self.layer_info`` maps ``'input'`` and each of those keys
    to a display name and width, and ``self.layer_names`` lists the keys in
    forward order, starting with ``'input'``.
    """

    def __init__(self, input_size, hidden_sizes, output_size):
        super().__init__()

        self.layers = nn.ModuleDict()
        self.layer_info = {
            'input': {'name': 'Input Layer', 'size': input_size}
        }

        # Build the full stack as (key, module, display name, width) entries;
        # modules are instantiated in forward order.
        stack = []
        fan_in = input_size
        for number, width in enumerate(hidden_sizes, start=1):
            stack.append((f'linear{number}', nn.Linear(fan_in, width),
                          f'Hidden Layer {number} (Linear)', width))
            stack.append((f'relu{number}', nn.ReLU(),
                          f'Hidden Layer {number} (ReLU)', width))
            fan_in = width

        stack.append((f'linear{len(hidden_sizes)+1}', nn.Linear(fan_in, output_size),
                      'Output Layer (Linear)', output_size))
        stack.append(('sigmoid', nn.Sigmoid(),
                      'Output Layer (Sigmoid)', output_size))

        for key, module, label, width in stack:
            self.layers[key] = module
            self.layer_info[key] = {'name': label, 'size': width}

        # Forward order, with the pseudo-layer 'input' first
        self.layer_names = ['input', *self.layers.keys()]

    def forward(self, x):
        """Run the network and return ``(output, activations)``.

        ``activations`` maps ``'input'`` to ``x`` and every key of
        ``self.layers`` to the tensor produced by that sub-module.
        """
        activations = {'input': x}
        current = x

        for key in self.layer_names[1:]:  # 'input' has no module
            current = self.layers[key](current)
            activations[key] = current

        return current, activations

# Function to train model and save snapshots at specified intervals
def train_model_with_snapshots(X_train, y_train, save_epochs, max_epochs):
    """Fit a 2-20-10-1 MLP with Adam and keep copies of it along the way.

    Args:
        X_train: Training inputs, converted with ``torch.FloatTensor``.
        y_train: Training targets, converted with ``torch.FloatTensor`` and
            reshaped to a single column.
        save_epochs: Container of epoch numbers; a deep copy of the model is
            stored after each epoch found in it.
        max_epochs: Number of passes over the data.

    Returns:
        ``(model_snapshots, loss_fig)``: a dict mapping epoch number to a
        deep-copied model (key 0 is the untrained model) and a Plotly figure
        of the mean per-batch BCE loss for every epoch.
    """
    # Fixed hyperparameters
    lr = 0.01
    minibatch = 32

    features = torch.FloatTensor(X_train)
    targets = torch.FloatTensor(y_train).view(-1, 1)

    net = MLP(2, [20, 10], 1)
    bce = nn.BCELoss()
    adam = optim.Adam(net.parameters(), lr=lr)

    loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(features, targets),
        batch_size=minibatch,
        shuffle=True,
    )

    # Epoch 0 is the freshly initialised network
    model_snapshots = {0: copy.deepcopy(net)}
    mean_losses = []

    for epoch in tqdm(range(1, max_epochs + 1)):
        net.train()
        batch_losses = []

        for inputs, labels in loader:
            # The model returns (outputs, activations); only outputs are needed here
            predictions, _ = net(inputs)
            batch_loss = bce(predictions, labels)

            adam.zero_grad()
            batch_loss.backward()
            adam.step()

            batch_losses.append(batch_loss.item())

        mean_losses.append(sum(batch_losses) / len(loader))

        if epoch in save_epochs:
            model_snapshots[epoch] = copy.deepcopy(net)

    # Loss curve: one point per epoch, epochs numbered from 1
    curve = go.Scatter(
        x=list(range(1, len(mean_losses) + 1)),
        y=mean_losses,
        mode='lines+markers',
        name='Training Loss'
    )
    loss_fig = go.Figure()
    loss_fig.add_trace(curve)
    loss_fig.update_layout(
        title=f'Training Loss (n={len(X_train)})',
        xaxis_title='Epoch',
        yaxis_title='Binary Cross-Entropy Loss',
        width=800,
        height=400,
    )

    return model_snapshots, loss_fig

# Function to create scatter plot grid of hidden features
def plot_column_scatter_grid(data, y_labels, column_names=None, 
                            marker_size=5, opacity=0.7,
                            height=800, width=800,
                            title="Scatter Plot Matrix",
                            color_map=None):
    """Draw an n x n matrix of pairwise column scatter plots, coloured by class.

    The subplot in row ``i``, column ``j`` plots column ``j`` of ``data`` on
    the x-axis against column ``i`` on the y-axis, with one trace per distinct
    value in ``y_labels``.

    Args:
        data: 2-D array-like of shape (samples, columns).
        y_labels: Array-like with one class label per sample.
        column_names: Axis titles, one per column; defaults to
            "Feature 1", "Feature 2", ...
        marker_size: Marker size for every trace.
        opacity: Marker opacity for every trace.
        height: Figure height.
        width: Figure width.
        title: Figure title.
        color_map: Optional mapping from class label to colour; a built-in
            eight-colour palette is cycled when omitted.

    Returns:
        The assembled Plotly figure.

    Raises:
        ValueError: If ``column_names`` or ``y_labels`` has the wrong length.
    """
    points = np.asarray(data)
    labels = np.asarray(y_labels)
    n_cols = points.shape[1]

    if column_names is None:
        column_names = [f"Feature {k+1}" for k in range(n_cols)]

    if len(column_names) != n_cols:
        raise ValueError(f"Number of column names ({len(column_names)}) "
                         f"doesn't match number of columns in data ({n_cols})")

    if len(labels) != points.shape[0]:
        raise ValueError(f"Number of labels ({len(labels)}) "
                         f"doesn't match number of samples in data ({points.shape[0]})")

    classes = np.unique(labels)

    if color_map is None:
        palette = ['#636EFA', '#EF553B', '#00CC96', '#AB63FA', '#FFA15A', '#19D3F3', '#FF6692', '#B6E880']
        color_map = {cls: palette[pos % len(palette)] for pos, cls in enumerate(classes)}

    fig = make_subplots(
        rows=n_cols, 
        cols=n_cols,
        shared_xaxes=False, 
        shared_yaxes=False
    )

    # Row masks per class, computed once; classes that match no sample are dropped
    members = [(cls, labels == cls) for cls in classes]
    members = [(cls, mask) for cls, mask in members if np.any(mask)]

    for row in range(n_cols):
        for col in range(n_cols):
            # Only the top-left cell contributes legend entries
            in_legend = (row == 0 and col == 0)
            for cls, mask in members:
                scatter = go.Scatter(
                    x=points[mask, col],
                    y=points[mask, row],
                    mode='markers',
                    marker=dict(
                        size=marker_size,
                        opacity=opacity,
                        color=color_map[cls]
                    ),
                    name=f'Class {cls}',
                    showlegend=in_legend
                )
                fig.add_trace(scatter, row=row+1, col=col+1)

    # Titles go on the bottom row (x) and the first column (y) only
    for col in range(n_cols):
        fig.update_xaxes(title_text=column_names[col], row=n_cols, col=col+1)
    for row in range(n_cols):
        fig.update_yaxes(title_text=column_names[row], row=row+1, col=1)

    legend_above_plot = dict(
        orientation="h",
        yanchor="bottom",
        y=1.02,
        xanchor="right",
        x=1
    )
    fig.update_layout(
        height=height, 
        width=width,
        title_text=title,
        legend=legend_above_plot
    )

    return fig

# Function to extract hidden layer activations for all epochs and layers
def extract_hidden_layers(model_snapshots, X_train):
    """Run every model snapshot on ``X_train`` and collect its activations.

    Args:
        model_snapshots: Dict mapping epoch number to a callable model whose
            call returns ``(output, activations)``, where ``activations`` maps
            layer name to a tensor.
        X_train: Inputs, converted with ``torch.FloatTensor`` before each call.

    Returns:
        Dict mapping epoch number to a dict of layer name -> NumPy array.
    """
    hidden_activations = {}

    # Inference only: disabling autograd avoids building a computation graph
    # that would be thrown away immediately.
    with torch.no_grad():
        for epoch, model in model_snapshots.items():
            x = torch.FloatTensor(X_train)
            _, activations_dict = model(x)

            # detach() is kept so this also works if a model re-enables grad internally
            hidden_activations[epoch] = {
                layer_name: act.detach().numpy()
                for layer_name, act in activations_dict.items()
            }

    return hidden_activations

# Function to sample features from activation data
def sample_features(data, max_features, rng=None):
    """Pick at most ``max_features`` columns of ``data``.

    Args:
        data: 2-D array of shape (samples, features).
        max_features: Largest number of columns to keep.
        rng: Optional ``np.random.RandomState``; a new one seeded with 42 is
            used when omitted.

    Returns:
        ``(selected, indices)``. When ``data`` has no more than
        ``max_features`` columns, ``selected`` is ``data`` itself and
        ``indices`` is the list ``[0, ..., n_features - 1]``. Otherwise
        ``indices`` is an ascending array of distinct column indices drawn
        without replacement and ``selected`` is ``data[:, indices]``.
    """
    generator = np.random.RandomState(42) if rng is None else rng
    total = data.shape[1]

    if total > max_features:
        # Ascending order keeps the columns in their original left-to-right layout
        chosen = np.sort(generator.choice(total, max_features, replace=False))
        return data[:, chosen], chosen

    # Nothing to drop: hand back the input unchanged
    return data, list(range(total))

def _layer_display_names(layer_names):
    """Map activation keys to the human-readable labels shown on the layer slider."""
    # The output layer is the last 'linear*' key that is not followed by a ReLU;
    # hidden linear layers are always followed by their ReLU.
    output_linear = None
    linear_keys = [name for name in layer_names if name != 'input' and 'linear' in name]
    if linear_keys:
        candidate = linear_keys[-1]
        trailing = layer_names[layer_names.index(candidate) + 1:]
        if not any('relu' in name for name in trailing):
            output_linear = candidate

    labels = {}
    for name in layer_names:
        digits = ''.join(filter(str.isdigit, name))
        if name == 'input':
            label = 'Input Layer'
        elif 'linear' in name:
            if name == output_linear or not digits:
                label = 'Output Layer (Linear)'
            else:
                label = f'Hidden Layer {digits} (Linear)'
        elif 'relu' in name:
            label = f'Hidden Layer {digits} (ReLU)' if digits else 'ReLU Activation'
        elif 'sigmoid' in name:
            label = 'Output Layer (Sigmoid)'
        else:
            label = name.capitalize()
        labels[name] = label
    return labels


def _grid_axis_titles(grid_size, feature_names):
    """Return layout entries titling an n x n sub-grid and blanking every other axis."""
    n_features = len(feature_names)
    titles = {}
    for row in range(grid_size):
        for col in range(grid_size):
            # make_subplots numbers axes row by row; the first one has no suffix
            axis_number = row * grid_size + col + 1
            suffix = '' if axis_number == 1 else str(axis_number)
            on_bottom_edge = row == n_features - 1 and col < n_features
            on_left_edge = col == 0 and row < n_features
            titles[f'xaxis{suffix}'] = {
                'title': {'text': feature_names[col] if on_bottom_edge else ''}
            }
            titles[f'yaxis{suffix}'] = {
                'title': {'text': feature_names[row] if on_left_edge else ''}
            }
    return titles


def _grid_frame_traces(columns, class_masks, grid_size):
    """Build one scatter trace per (row, col, class), empty outside the layer's sub-grid."""
    n_features = columns.shape[1]
    traces = []
    for row in range(grid_size):
        for col in range(grid_size):
            occupied = row < n_features and col < n_features
            for mask in class_masks:
                # showlegend is deliberately left unset so the base figure's
                # legend settings survive when the frame is applied.
                traces.append(go.Scatter(
                    x=columns[mask, col] if occupied else [],
                    y=columns[mask, row] if occupied else [],
                    mode='markers'
                ))
    return traces


def _animate_step(frame_name, label):
    """Return a slider step that jumps to the named frame."""
    return {
        "args": [
            [frame_name],
            {"frame": {"duration": 300, "redraw": True}, "mode": "immediate", "transition": {"duration": 300}}
        ],
        "label": label,
        "method": "animate"
    }


def _slider(prefix, steps, active, top_pad, y):
    """Return the layout dict for one slider."""
    return {
        "active": active,
        "yanchor": "top",
        "xanchor": "left",
        "currentvalue": {
            "font": {"size": 16},
            "prefix": prefix,
            "visible": True,
            "xanchor": "right"
        },
        "transition": {"duration": 300},
        "pad": {"b": 10, "t": top_pad},
        "len": 0.9,
        "x": 0.1,
        "y": y,
        "steps": steps
    }


def create_dual_slider_visualization(hidden_activations, y_train, max_features=5):
    """Create an interactive scatter-matrix figure with an epoch and a layer slider.

    For every layer, at most ``max_features`` feature columns are sampled once
    (from the first epoch's activations) and the same columns are shown for
    every epoch, so frames of one layer are directly comparable and the axis
    titles always name the plotted columns.

    The subplot grid is sized for the widest sampled layer. Frames of narrower
    layers fill the top-left sub-grid and send empty traces and blank titles
    to the remaining cells, so nothing from a previous frame stays visible.

    One frame named ``epoch{epoch}_layer{layer}`` is created per (epoch,
    layer) pair. The epoch slider steps through the initial layer's frames and
    the layer slider steps through the first epoch's frames; the two sliders
    do not combine their positions. The Play button runs through all frames.

    Args:
        hidden_activations: Dict mapping epoch -> dict mapping layer name ->
            2-D array of shape (samples, features).
        y_train: One class label per sample.
        max_features: Maximum number of feature columns shown per layer.

    Returns:
        The Plotly figure with frames, sliders and Play/Pause buttons attached.

    Raises:
        ValueError: If ``hidden_activations`` contains no epochs or no layers.
    """
    if not hidden_activations:
        raise ValueError("hidden_activations must contain at least one epoch")

    y_train = np.asarray(y_train)
    epochs = sorted(hidden_activations)

    # Layer names are taken from the first epoch (expected to match all epochs)
    layer_names = list(hidden_activations[epochs[0]].keys())
    if not layer_names:
        raise ValueError("hidden_activations must contain at least one layer")

    # Start on the first hidden layer when there is one, else on the only layer
    initial_layer_name = layer_names[min(1, len(layer_names) - 1)]
    layer_display_names = _layer_display_names(layer_names)

    # Sample the feature columns of each layer exactly once
    rng = np.random.RandomState(42)
    layer_features = {}
    for layer_name in layer_names:
        reference = np.asarray(hidden_activations[epochs[0]][layer_name])
        _, sampled_indices = sample_features(reference, max_features, rng)
        indices = [int(idx) for idx in sampled_indices]
        layer_features[layer_name] = {
            'indices': indices,
            'names': [f"Feature {idx+1}" for idx in indices]
        }

    grid_size = max(len(info['indices']) for info in layer_features.values())
    axis_titles = {
        layer_name: _grid_axis_titles(grid_size, info['names'])
        for layer_name, info in layer_features.items()
    }

    # One boolean row mask per class, in the same order (and with the same
    # empty-class filtering) that plot_column_scatter_grid uses for its traces
    class_masks = [y_train == cls for cls in np.unique(y_train)]
    class_masks = [mask for mask in class_masks if np.any(mask)]

    def select_columns(epoch, layer_name):
        activations = np.asarray(hidden_activations[epoch][layer_name])
        return activations[:, layer_features[layer_name]['indices']]

    # Base figure: the initial layer at the first epoch. If that layer is
    # narrower than the grid, pad with NaN columns (which draw no points) so
    # the figure still has one trace per grid cell and class.
    base_columns = select_columns(epochs[0], initial_layer_name)
    base_names = list(layer_features[initial_layer_name]['names'])
    missing = grid_size - base_columns.shape[1]
    if missing > 0:
        filler = np.full((base_columns.shape[0], missing), np.nan)
        base_columns = np.hstack([base_columns, filler])
        base_names += [''] * missing

    fig = plot_column_scatter_grid(
        base_columns,
        y_train,
        column_names=base_names,
        title=f"{layer_display_names[initial_layer_name]} Features (Epoch {epochs[0]})"
    )
    # Place the titles exactly where the matching frame will place them
    fig.update_layout(**axis_titles[initial_layer_name])

    # One frame per (epoch, layer); every frame carries a trace for every
    # grid cell and class, in the base figure's trace order
    frames = []
    for epoch in tqdm(epochs, desc="Preparing visualization data"):
        for layer_name in layer_names:
            frame_layout = {
                'title': {'text': f"{layer_display_names[layer_name]} Features (Epoch {epoch})"}
            }
            frame_layout.update(axis_titles[layer_name])
            frames.append(go.Frame(
                data=_grid_frame_traces(select_columns(epoch, layer_name), class_masks, grid_size),
                name=f"epoch{epoch}_layer{layer_name}",
                layout=frame_layout
            ))

    fig.frames = frames

    epoch_steps = [
        _animate_step(f"epoch{epoch}_layer{initial_layer_name}", str(epoch))
        for epoch in epochs
    ]
    layer_steps = [
        _animate_step(f"epoch{epochs[0]}_layer{layer_name}", layer_display_names[layer_name])
        for layer_name in layer_names
    ]

    sliders = [
        _slider("Epoch: ", epoch_steps, active=0, top_pad=50, y=0),
        _slider("Layer: ", layer_steps, active=layer_names.index(initial_layer_name),
                top_pad=120, y=0.1)
    ]

    fig.update_layout(
        sliders=sliders,
        updatemenus=[
            {
                "buttons": [
                    {
                        "args": [None, {"frame": {"duration": 300, "redraw": True}, "fromcurrent": True, "transition": {"duration": 300}}],
                        "label": "Play",
                        "method": "animate"
                    },
                    {
                        "args": [[None], {"frame": {"duration": 0, "redraw": True}, "mode": "immediate"}],
                        "label": "Pause",
                        "method": "animate"
                    }
                ],
                "direction": "left",
                "pad": {"r": 10, "t": 87},
                "showactive": False,
                "type": "buttons",
                "x": 0.1,
                "xanchor": "right",
                "y": 0,
                "yanchor": "top"
            }
        ],
        transition={"duration": 300},
        # Extra margin at the bottom leaves room for the two sliders
        margin=dict(l=50, r=50, t=100, b=150)
    )

    fig.update_layout(
        hovermode="closest",
        clickmode="event+select"
    )

    # Usage hint shown above the plot area
    fig.add_annotation(
        text="Use sliders to explore features across epochs and layers",
        xref="paper", yref="paper",
        x=0.5, y=1.05,
        showarrow=False
    )

    return fig

# Main function to run the experiment with dual-slider visualization
def run_experiment_with_dual_sliders(max_features=5, n_samples=200, max_epochs=500,
                                     snapshot_every=10):
    """Generate data, train with snapshots, and show the loss and activation figures.

    Args:
        max_features: Maximum number of feature columns plotted per layer.
        n_samples: Number of training points to generate.
        max_epochs: Number of training epochs.
        snapshot_every: Spacing, in epochs, between saved model snapshots.
            The untrained model (epoch 0) and the final epoch are always saved.

    Returns:
        ``(model_snapshots, hidden_activations)`` as produced by
        ``train_model_with_snapshots`` and ``extract_hidden_layers``.

    Raises:
        ValueError: If ``snapshot_every`` is smaller than 1.
    """
    if snapshot_every < 1:
        raise ValueError("snapshot_every must be at least 1")

    # Generate training data
    X_train, y_train = generate_samples(n_samples)

    # range() stops before max_epochs, so add the final epoch explicitly;
    # otherwise the fully trained model would never be captured.
    save_epochs = set(range(0, max_epochs, snapshot_every)) | {max_epochs}

    # Train model and capture snapshots
    print("Training models and capturing snapshots...")
    model_snapshots, loss_fig = train_model_with_snapshots(X_train, y_train, save_epochs, max_epochs)

    # Extract hidden layer activations for all epochs and layers
    print("Extracting hidden layer activations...")
    hidden_activations = extract_hidden_layers(model_snapshots, X_train)

    # Create interactive visualization with dual sliders
    print("Creating interactive visualization...")
    interactive_fig = create_dual_slider_visualization(hidden_activations, y_train, max_features=max_features)

    # Display the figures
    loss_fig.show()
    interactive_fig.show()

    return model_snapshots, hidden_activations

if __name__ == "__main__":
    # Maximum number of features plotted per layer; change to show more or fewer.
    max_features = 8
    # These assignments already happen at module scope, so the results are
    # module-level names without any `global` declaration.
    model_snapshots, hidden_activations = run_experiment_with_dual_sliders(max_features)