# Topological Data Analysis of CNN Weight Spaces

## Overview

This notebook applies persistent homology to CNN weight spaces. We analyze 36,468 trained CNN models from the Merged zoo dataset to characterize the geometric and topological structure of the space their weights occupy.

### Dataset Information
- **Total Models**: 36,468 trained CNNs
- **Parameters per Model**: 2,464 (2,416 weights + 48 biases)
- **Activation Functions**: 6 types (gelu, relu, silu, tanh, sigmoid, leakyrelu) - 6,078 models each
- **Training Epochs**: 6 checkpoints (11, 16, 21, 26, 31, 36)
- **MNIST Classes**: 10 digits (0-9), one-hot encoded

### Analysis Strategy

We employ **two complementary grouping strategies** to understand different aspects of the weight space:

#### **Group A: Activation Function Analysis**
Groups models by their activation function to understand how different non-linearities shape the weight space topology.
- 6 groups: gelu, relu, silu, tanh, sigmoid, leakyrelu
- Each group contains 6,078 models
- **Goal**: Identify topological signatures unique to each activation function

#### **Group B: MNIST Class Analysis**
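Groups models by whether a given MNIST digit appears in their training labels, read from the per-digit 0/1 indicator columns. A model trained on several digits belongs to several groups, so the groups overlap.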
- 10 groups: digits 0-9
- Variable group sizes depending on label combinations
- **Goal**: Detect class-specific topological patterns in weight spaces
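
The overlapping membership of Group B can be sketched on a made-up three-model table. The toy values, the model names, and the `groups` dictionary are illustrative assumptions, not taken from the dataset:

```python
import pandas as pd

# Toy stand-in for the zoo table: one row per model, one 0/1 column per digit.
# Only three digit columns are shown; the real table has "0" through "9".
toy = pd.DataFrame(
    {"0": [1, 0, 1], "1": [1, 1, 0], "2": [0, 1, 0]},
    index=["model_a", "model_b", "model_c"],
)

# A model joins every group whose indicator is 1, so groups may share models.
groups = {digit: toy.index[toy[digit] == 1].tolist() for digit in toy.columns}

for digit, members in groups.items():
    print(f"digit {digit}: {members}")
```

Here `model_a` lands in both the digit-0 and digit-1 groups, which is why group sizes vary and do not sum to the number of models.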

### Topological Methods

1. **Vietoris-Rips Persistence**: Compute H0 (connected components) and H1 (loops) persistence diagrams
2. **Distance Metrics**: Bottleneck and Wasserstein distances between persistence diagrams
3. **Vectorized Representations**: Persistence landscapes, Betti curves, persistence images
4. **Epoch Evolution**: Track topological changes across training epochs
5. **Multi-parameter Persistence**: 2D filtration using distance and weight norm
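
For intuition about item 1: in a Vietoris-Rips filtration every point is born at scale 0, and a connected component dies at the scale where it merges into another one. Those finite H0 death scales coincide with the edge lengths of a minimum spanning tree of the point cloud. The sketch below computes them with NumPy and a small union-find. It is an illustration only (the notebook itself relies on gudhi), and `h0_death_times` is a hypothetical helper name:

```python
import numpy as np

def h0_death_times(points):
    """Finite H0 death scales of a Vietoris-Rips filtration (Kruskal + union-find)."""
    points = np.asarray(points, dtype=float)
    n = len(points)
    # Pairwise Euclidean distances.
    dist = np.linalg.norm(points[:, None, :] - points[None, :, :], axis=-1)
    # All edges (i < j), sorted by length.
    i_idx, j_idx = np.triu_indices(n, k=1)
    order = np.argsort(dist[i_idx, j_idx], kind="stable")

    parent = list(range(n))

    def find(a):
        # Root lookup with path halving.
        while parent[a] != a:
            parent[a] = parent[parent[a]]
            a = parent[a]
        return a

    deaths = []
    for e in order:
        ra, rb = find(int(i_idx[e])), find(int(j_idx[e]))
        if ra != rb:
            # Two components merge at this scale, so one H0 class dies here.
            parent[ra] = rb
            deaths.append(float(dist[i_idx[e], j_idx[e]]))
    return deaths

# Two tight pairs far apart: two early merges, then one late merge.
cloud = np.array([[0.0, 0.0], [1.0, 0.0], [10.0, 0.0], [11.0, 0.0]])
deaths = h0_death_times(cloud)
print(deaths)
```

A point cloud of n points yields n - 1 finite deaths; the one component that survives to the end is the essential H0 class with infinite death. The large gap between the early deaths and the last one is the kind of signal a persistence diagram makes visible.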

### Computational Considerations

- **Memory Management**: Models are subsampled to 100 per group to manage RAM usage
- **Dimensionality Reduction**: PCA reduces 2,464-d weight vectors to 20-d before persistence computation
- **CPU Optimization**: Parallel processing disabled for gudhi/multipers to avoid thread conflicts
- **Figure Quality**: High-DPI (250) figures with large sizes for readability
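
The subsample-then-project step from the first two bullets can be sketched with NumPy alone. The example uses random data in place of the real weight matrix and an SVD-based PCA in place of scikit-learn's `PCA`. The `pca_project` helper is hypothetical, and fitting the projection on the subsample (rather than on the full group) is an assumption of this sketch:

```python
import numpy as np

def pca_project(X, n_components):
    """Project the rows of X onto the top principal components (via SVD)."""
    Xc = X - X.mean(axis=0)  # centre each feature
    # Rows of Vt are principal directions, ordered by decreasing variance.
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    return Xc @ Vt[:n_components].T

rng = np.random.default_rng(42)
weights = rng.normal(size=(500, 2464))  # stand-in for one group's weight vectors

# Subsample 100 models without replacement, then reduce 2,464-d to 20-d.
idx = rng.choice(len(weights), size=100, replace=False)
reduced = pca_project(weights[idx], n_components=20)
print(reduced.shape)  # (100, 20)
```

Working on a 100 x 20 array keeps the pairwise-distance matrix and the resulting Rips complex small, which is where most of the memory would otherwise go.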

### Output Structure

```
figures/05_topology/
├── group_A_activation/     # Activation function analysis results
│   ├── pd_*.png           # Persistence diagrams
│   ├── distances_*.png    # Pairwise distance heatmaps
│   ├── landscapes_*.png   # Persistence landscapes
│   └── *.csv             # Summary statistics
└── group_B_class/         # MNIST class analysis results
    ├── pd_class_*.png
    ├── distances_*.png
    └── *.csv
```

In [None]:
# Cell 0 - Environment Setup and Imports
# =====================================

# Set matplotlib for inline display
%matplotlib inline

# Standard library imports
import os
import sys
import warnings
import json
import ast
import gc
import re
from pathlib import Path
from typing import List, Dict, Tuple, Optional, Any

# Numerical and scientific computing
import numpy as np
import pandas as pd
from scipy.spatial.distance import pdist
from scipy.stats import entropy
import matplotlib.pyplot as plt
import seaborn as sns

# Machine learning utilities
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.metrics import pairwise_distances

# Topological Data Analysis
from gudhi import RipsComplex
from gudhi.hera import bottleneck_distance, wasserstein_distance

# Try to import giotto-tda's diagram plotting (the package installs as `gtda`);
# fall back to the custom Matplotlib implementation otherwise
try:
    from gtda.plotting import plot_diagram
    HAS_GIOTTO_VIZ = True
except ImportError:
    HAS_GIOTTO_VIZ = False
    print("Note: gtda.plotting not available, using custom persistence plotting")

# Try to import vectorized representations from giotto-tda
try:
    from gtda.diagrams import PersistenceLandscape, BettiCurve, PersistenceImage
    HAS_GIOTTO_TDA = True
except ImportError:
    HAS_GIOTTO_TDA = False
    print("Note: giotto-tda not available, persistence landscapes will be skipped")

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore')

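# Cap CPU threads. BLAS/OpenMP runtimes usually read these variables once, when
# they are first loaded, so in a fresh kernel set them before importing NumPy
# for the cap to take effect
os.environ['OMP_NUM_THREADS'] = '4'
os.environ['MKL_NUM_THREADS'] = '4'
os.environ['NUMEXPR_NUM_THREADS'] = '4'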

# Configure matplotlib for high-quality figures
plt.rcParams['figure.dpi'] = 250
plt.rcParams['savefig.dpi'] = 250
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 10
plt.rcParams['axes.titlesize'] = 14
plt.rcParams['axes.labelsize'] = 12
plt.rcParams['xtick.labelsize'] = 10
plt.rcParams['ytick.labelsize'] = 10
plt.rcParams['legend.fontsize'] = 10
plt.rcParams['axes.titleweight'] = 'bold'
plt.rcParams['axes.labelweight'] = 'bold'

# Path configurations
ROOT = Path("/home/aymen/Documents/GitHub/Federated-Continual-learning-/New")
DATA_DIR = ROOT / "data"
MERGED_ZOO = DATA_DIR / "Merged zoo.csv"
FIG_DIR = ROOT / "notebooks_sandbox" / "figures" / "05_topology"

# Create figure directories
FIG_DIR_A = FIG_DIR / "group_A_activation"
FIG_DIR_B = FIG_DIR / "group_B_class"
FIG_DIR.mkdir(parents=True, exist_ok=True)
FIG_DIR_A.mkdir(parents=True, exist_ok=True)
FIG_DIR_B.mkdir(parents=True, exist_ok=True)

# Analysis parameters
RANDOM_SEED = 42
N_SUBSAMPLE = 100
PCA_DIM = 20
MAX_DIM = 1  # Compute H0 and H1

# Column definitions
ACTIVATION_COLS = ["gelu", "relu", "silu", "tanh", "sigmoid", "leakyrelu"]
DIGIT_COLS = [str(i) for i in range(10)]

# Epoch list for analysis
EPOCHS = [11, 16, 21, 26, 31, 36]

print("=" * 70)
print("  TOPOLOGICAL DATA ANALYSIS - ENVIRONMENT SETUP")
print("=" * 70)
print(f"\nPython version: {sys.version}")
print(f"Working directory: {ROOT}")
print(f"Data directory: {DATA_DIR}")
print(f"Figure directory: {FIG_DIR}")
print(f"\nAnalysis configuration:")
print(f"  Random seed: {RANDOM_SEED}")
print(f"  Subsample size: {N_SUBSAMPLE} models per group")
print(f"  PCA dimension: {PCA_DIM}")
print(f"  Max homology dimension: {MAX_DIM}")
print(f"  Epochs to analyze: {EPOCHS}")
print(f"\nPackage availability:")
print(f"  giotto_tdaviz: {HAS_GIOTTO_VIZ}")
print(f"  giotto_tda: {HAS_GIOTTO_TDA}")
print(f"\nOutput directories:")
print(f"  Group A (Activations): {FIG_DIR_A}")
print(f"  Group B (Classes): {FIG_DIR_B}")
print("\n" + "=" * 70)
print("✓ Environment configured successfully")
print("=" * 70)

## Data Loading and Preprocessing

Load the Merged zoo CSV file and prepare it for topological analysis.

**Data Validation Steps**:
1. Load CSV and verify dimensions
2. Extract weight and bias columns (2,464 parameters total)
3. Parse activation function labels (one-hot encoded → string)
4. Parse MNIST class labels (one-hot encoded)
5. Convert epoch and accuracy to numeric types
6. Display sample rows for verification

**Expected Output**: DataFrame with 36,468 rows × 2,483 columns
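
Steps 3 and 4 (coercing the indicator columns and decoding the one-hot activation label) can be sketched on a toy table. The shortened `ACTS` list, the toy values, and the vectorized `argmax` decoding are assumptions for illustration; the notebook's own cell uses a row-wise helper instead:

```python
import numpy as np
import pandas as pd

ACTS = ["gelu", "relu", "tanh"]  # shortened list for the example

toy = pd.DataFrame({
    "gelu": ["1", "0", "0", "0"],
    "relu": ["0", "1", "0", "x"],  # "x" is a deliberately malformed entry
    "tanh": ["0", "0", "1", "0"],
})

# Coerce to integers; unparseable entries become 0.
onehot = toy[ACTS].apply(pd.to_numeric, errors="coerce").fillna(0).astype(int)

# Pick the first column holding a 1, or "unknown" if no column does.
hits = onehot.to_numpy() == 1
names = np.array(ACTS)[hits.argmax(axis=1)]
activation = np.where(hits.any(axis=1), names, "unknown")
print(activation.tolist())
```

The last row shows why the "unknown" fallback matters: a malformed entry is coerced to 0, leaving a row with no active indicator.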

In [None]:
# Cell 1 - Load and Validate Merged Zoo Data

print("Loading Merged zoo.csv...")

# Load the dataset
df = pd.read_csv(MERGED_ZOO)

print("Dataset Dimensions:")
print(f"  Rows:    {len(df):,}")
print(f"  Columns: {len(df.columns):,}")

# Extract parameter columns
weight_cols = [c for c in df.columns if c.startswith("weight ")]
bias_cols = [c for c in df.columns if c.startswith("bias ")]
param_cols = weight_cols + bias_cols

print("")
print("Parameter Breakdown:")
print(f"  Weights: {len(weight_cols)}")
print(f"  Biases:  {len(bias_cols)}")
print(f"  Total:   {len(param_cols)}")

# Convert activation and digit columns to numeric (one-hot encoded)
for c in ACTIVATION_COLS + DIGIT_COLS:
    if c in df.columns:
        df[c] = pd.to_numeric(df[c], errors="coerce").fillna(0).astype(int)

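# Convert metadata columns
df["Accuracy"] = pd.to_numeric(df["Accuracy"], errors="coerce")
# Coerce epochs, then fail with a clear message: casting NaN to int would raise anyway
df["epoch"] = pd.to_numeric(df["epoch"], errors="coerce")
if df["epoch"].isna().any():
    raise ValueError(f"{df['epoch'].isna().sum()} rows have a non-numeric epoch")
df["epoch"] = df["epoch"].astype(int)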

# Extract activation function name from one-hot encoding
def get_activation(row):
    """Extract activation function name from one-hot encoded columns."""
    for act in ACTIVATION_COLS:
        if row.get(act, 0) == 1:
            return act
    return "unknown"

df["activation"] = df.apply(get_activation, axis=1)

# Parse label tuples (for reference, though we'll use one-hot encoding for class analysis)
def parse_label(s):
    """Parse label string into sorted tuple of integers."""
    try:
        return tuple(sorted(ast.literal_eval(s)))
    except Exception:
        return None

df["label_tuple"] = df["label"].apply(parse_label)

# Data validation summary
print("")
print("Data Validation:")
print(f"  Unique epochs:  {sorted(df['epoch'].unique())}")
print(f"  Unique labels:  {df['label_tuple'].nunique()}")
print("  Activation distribution:")
for act in ACTIVATION_COLS:
    count = (df["activation"] == act).sum()
    print(f"    {act:12s}: {count:5d} models")

print("")
print("MNIST Class distribution (models containing each digit):")
for digit in DIGIT_COLS:
    count = (df[digit] == 1).sum()
    print(f"    Digit {digit}: {count:5d} models")

print("")
print("Data loaded successfully!")
print("=" * 70)
print("Sample Data (first 5 rows):")
print("=" * 70)

# Display sample with key columns
display_cols = ["label", "activation", "epoch", "Accuracy"] + DIGIT_COLS
display(df[display_cols].head())

print("")
print(f"Memory usage: {df.memory_usage(deep=True).sum() / 1024**2:.1f} MB")

## Helper Functions for Persistence Analysis

Define utility functions for topological data analysis with memory-efficient implementations.

**Function Catalog**:

1. **`subsample_group()`**: Randomly sample models from a group to manage memory
2. **`compute_persistence_diagram()`**: Calculate Vietoris-Rips persistence using gudhi
3. **`plot_persistence_diagrams()`**: Visualize persistence diagrams with high-quality formatting
4. **`persistence_summary()`**: Extract statistical features from persistence diagrams
5. **`smart_annotate_heatmap()`**: Annotate heatmap cells selectively, depending on matrix size
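
As a worked example of the statistics in item 4, the sketch below computes feature lifetimes (death minus birth) for a toy H0 diagram. Dropping the infinite-death class before averaging is one way to keep the statistics finite, and the `stats` dictionary keys are illustrative rather than the helper's exact output:

```python
import numpy as np

# Toy H0 diagram as (birth, death) rows; the last class never dies.
dgm = np.array([[0.0, 0.5], [0.0, 2.0], [0.0, np.inf]])

finite = dgm[np.isfinite(dgm[:, 1])]  # drop essential (infinite) classes
lifetimes = finite[:, 1] - finite[:, 0]

stats = {
    "n_features": len(dgm),
    "mean_persistence": float(lifetimes.mean()),
    "max_persistence": float(lifetimes.max()),
    "total_persistence": float(lifetimes.sum()),
}
print(stats)
```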

**Memory Management**:
- Automatic garbage collection after large computations
- Figures closed after saving to free memory
- PCA used for dimensionality reduction before persistence computation

In [None]:
# Cell 2 - Helper Functions for Topological Analysis

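def subsample_group(df_group, param_cols, n=100, seed=42):
    """Randomly sample up to n models from a group; return their parameters as an array."""
    if len(df_group) <= n:
        return df_group[param_cols].values
    # A local RandomState draws the same indices as np.random.seed(seed) followed by
    # np.random.choice, but leaves NumPy's global random state untouched
    rng = np.random.RandomState(seed)
    indices = rng.choice(len(df_group), n, replace=False)
    return df_group.iloc[indices][param_cols].values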

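def compute_persistence_diagram(X, max_dim=1, max_edge=1.0):
    """Compute Vietoris-Rips persistence diagrams (H0 up to H{max_dim}) from a point cloud.

    Returns (diagrams, simplex_tree), where diagrams[dim] is an (n, 2) array of
    (birth, death) pairs. Classes that never die keep death = inf.
    """
    empty = np.empty((0, 2))
    if len(X) < 3:
        return {dim: empty.copy() for dim in range(max_dim + 1)}, None

    rips = RipsComplex(points=X, max_edge_length=max_edge)
    # H_k only needs simplices up to dimension k + 1
    simplex_tree = rips.create_simplex_tree(max_dimension=max_dim + 1)

    # gudhi returns a list of (dimension, (birth, death)) pairs
    persistence = simplex_tree.persistence()

    diagrams = {}
    for dim in range(max_dim + 1):
        dim_pers = [(b, d) for p_dim, (b, d) in persistence if p_dim == dim and d > b]
        diagrams[dim] = np.array(dim_pers) if dim_pers else empty.copy()

    return diagrams, simplex_tree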

def plot_persistence_diagrams(diagrams, title="Persistence Diagrams", 
                             save_path=None, figsize=(16, 7)):
    """Plot persistence diagrams with high-quality formatting."""
    fig, axes = plt.subplots(1, len(diagrams), figsize=figsize)
    if len(diagrams) == 1:
        axes = [axes]
    
    colors = ['#e74c3c', '#3498db']
    
    for idx, (dim, dgm) in enumerate(diagrams.items()):
        ax = axes[idx]
        
        if len(dgm) > 0:
            ax.scatter(dgm[:, 0], dgm[:, 1], c=colors[dim], s=30, alpha=0.7, 
                      edgecolors='black', linewidth=0.5)
            
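            # Plot diagonal; ignore infinite deaths when choosing the axis range
            finite_vals = dgm[np.isfinite(dgm)]
            max_val = max(finite_vals.max(), 1.0) if finite_vals.size else 1.0
            ax.plot([0, max_val], [0, max_val], 'k--', alpha=0.3, linewidth=1)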
            
            # Set limits
            ax.set_xlim(0, max_val * 1.1)
            ax.set_ylim(0, max_val * 1.1)
        else:
            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)
        
        ax.set_xlabel("Birth", fontweight='bold')
        ax.set_ylabel("Death", fontweight='bold')
        ax.set_title(f"H{dim} Persistence", fontweight='bold')
        ax.grid(True, alpha=0.3)
        ax.set_aspect('equal')
    
    fig.suptitle(title, fontsize=16, fontweight='bold')
    plt.tight_layout()
    
    if save_path:
        fig.savefig(save_path, dpi=250, bbox_inches='tight')
        print(f"  ✓ Saved: {save_path.name}")
    
    plt.show()
    return fig

def persistence_summary(diagrams):
    """Extract summary statistics from persistence diagrams."""
    summary = {}
    
    for dim, dgm in diagrams.items():
        # Essential classes (death = inf) are counted as features but left out of
        # the lifetime statistics, which would otherwise be infinite
        summary[f'H{dim}_n_features'] = len(dgm)
        finite = dgm[np.isfinite(dgm[:, 1])] if len(dgm) > 0 else dgm
        # With no finite features, a single zero lifetime makes every statistic 0
        lifetimes = finite[:, 1] - finite[:, 0] if len(finite) > 0 else np.zeros(1)
        summary[f'H{dim}_mean_persistence'] = np.mean(lifetimes)
        summary[f'H{dim}_std_persistence'] = np.std(lifetimes)
        summary[f'H{dim}_max_persistence'] = np.max(lifetimes)
        summary[f'H{dim}_total_persistence'] = np.sum(lifetimes)
    
    return summary

def smart_annotate_heatmap(data, ax, fmt=".3f", fontsize=9, threshold=0.1):
    """Annotate off-diagonal heatmap cells, with fewer annotations for larger matrices."""
    n = data.shape[0]
    
    if n <= 4:
        # Small matrix - annotate all off-diagonal cells
        for i in range(n):
            for j in range(n):
                if i != j and not np.isnan(data[i, j]):
                    # format(value, fmt) applies the spec; fmt.format(value) would
                    # print the literal string ".3f"
                    ax.text(j + 0.5, i + 0.5, format(data[i, j], fmt),
                           ha='center', va='center', fontsize=fontsize,
                           fontweight='bold', color='white' if data[i, j] > threshold else 'black')
    elif n <= 6:
        # Medium matrix - annotate only values above the 75th percentile
        cutoff = np.nanpercentile(data, 75)
        for i in range(n):
            for j in range(n):
                if i != j and not np.isnan(data[i, j]) and data[i, j] > cutoff:
                    ax.text(j + 0.5, i + 0.5, format(data[i, j], fmt),
                           ha='center', va='center', fontsize=fontsize,
                           fontweight='bold')
    # Large matrix - no annotations for readability

print("✓ Helper functions defined successfully")