# Ricci Flow Layer Depth Study - Master Script
## Training + Live Ricci Analysis (Combined)

This notebook combines:
- DNN training for varying layer depths (3-30)
- **Live Ricci curvature analysis** immediately after training each model
- Results saved incrementally to Google Drive

**Implements all paper metrics:**
- Forman-Ricci curvature: R(i,j) = 4 - deg(i) - deg(j)
- Geodesic mass: g_l = Σ_{i<j} γ_l(i,j)
- Aggregated Ricci coefficient (Rho)

**Based on:** `training.py`, `knn_fixed.py`, and "Deep Learning as Ricci Flow" (Baptista et al., 2024)

## 1. Mount Google Drive

In [None]:
# ============================================================================
# GOOGLE DRIVE MOUNT & PATH SETUP (Crucial for persistent storage)
# ============================================================================
from google.colab import drive
import os

# Attach Google Drive so that outputs are written to persistent storage
drive.mount('/content/drive')

# Every artifact produced by this notebook is written beneath this folder
OUTPUT_DIR = '/content/drive/MyDrive/ricci_study_combined'

# Report whether the folder was already there or had to be created
if os.path.exists(OUTPUT_DIR):
    print(f"✓ Output directory exists: {OUTPUT_DIR}")
else:
    os.makedirs(OUTPUT_DIR)
    print(f"✓ Created output directory: {OUTPUT_DIR}")

## 2. Imports & GPU Detection

In [None]:
# ============================================================================
# IMPORTS
# ============================================================================
import numpy as np
import pandas as pd
import os
import time
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

# TensorFlow/Keras
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.datasets import mnist

# Ricci Analysis
from sklearn.neighbors import NearestNeighbors
from scipy.sparse import csr_matrix, triu as sp_triu
from scipy.sparse.csgraph import shortest_path
from scipy.stats import pearsonr, spearmanr
from typing import List, Dict, Tuple

# Visualization
import matplotlib.pyplot as plt

# ============================================================================
# GPU DETECTION
# ============================================================================
print("=" * 60, "DEVICE DETECTION", "=" * 60, sep="\n")

# Ask TensorFlow which physical GPUs it can see
gpus = tf.config.list_physical_devices('GPU')
if not gpus:
    print("⚠ No GPU detected. Training will use CPU.")
    print("  Tip: Enable GPU in Kaggle/Colab settings for faster training.")
else:
    print(f"✓ GPU(s) detected: {len(gpus)}")
    for gpu in gpus:
        print(f"  - {gpu.name}")
    # Request on-demand memory allocation for each device; a RuntimeError
    # from TensorFlow is reported but does not abort the notebook.
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(f"⚠ GPU memory growth setting failed: {e}")
    else:
        print("✓ GPU memory growth enabled")

print(f"\nTensorFlow version: {tf.__version__}")
print("=" * 60)

## 3. Configuration

In [None]:
# ============================================================================
# CONFIGURATION
# ============================================================================

# The two MNIST digit classes that form the binary task
DIGIT_A = 4
DIGIT_B = 9

# Hidden-layer counts swept by the study: 3, 4, ..., 30
LAYER_DEPTHS = [*range(3, 31)]

# Architecture families; 'width' is the hidden-layer width and 'bottleneck'
# switches build_model to its fixed 50 -> 25 layout
ARCHITECTURES = dict(
    narrow=dict(width=25, bottleneck=False),
    wide=dict(width=50, bottleneck=False),
    bottleneck=dict(width=50, bottleneck=True),
)

# Training parameters
NUM_MODELS = 25       # independently trained models per (architecture, depth)
EPOCHS = 50
BATCH_SIZE = 32
EARLY_STOP_ACCURACY = 0.99

# Ricci analysis parameters
K_VALUE = 200         # ~10% of test samples for kNN
# NOTE(review): ACC_THRESHOLD is not referenced by the code visible in this
# notebook; confirm whether filtering by accuracy is applied elsewhere.
ACC_THRESHOLD = 0.0   # Use all models (set higher to filter)

# Single CSV that accumulates one row per finished configuration
RESULTS_CSV = os.path.join(OUTPUT_DIR, 'ricci_combined_results.csv')

print(
    "Configuration:",
    f"  Digits: {DIGIT_A} vs {DIGIT_B}",
    f"  Layer Depths: {LAYER_DEPTHS[0]}-{LAYER_DEPTHS[-1]}",
    f"  Architectures: {list(ARCHITECTURES.keys())}",
    f"  Models per config: {NUM_MODELS}",
    f"  k for kNN: {K_VALUE}",
    f"  Total configurations: {len(ARCHITECTURES) * len(LAYER_DEPTHS)}",
    sep="\n",
)

## 4. Load MNIST Data

In [None]:
# ============================================================================
# DATA LOADING
# ============================================================================

print("Loading MNIST data...")
(x_train_full, y_train_full), (x_test_full, y_test_full) = mnist.load_data()

# Boolean masks selecting only the two digits of interest
train_mask = np.isin(y_train_full, [DIGIT_A, DIGIT_B])
test_mask = np.isin(y_test_full, [DIGIT_A, DIGIT_B])

# Flatten each image to a 784-vector and rescale pixel values by 1/255
x_train = x_train_full[train_mask].reshape(-1, 784).astype('float32') / 255.0
x_test = x_test_full[test_mask].reshape(-1, 784).astype('float32') / 255.0

# Binary labels: 1 for DIGIT_B, 0 for DIGIT_A
y_train = np.where(y_train_full[train_mask] == DIGIT_B, 1, 0).astype('int32')
y_test = np.where(y_test_full[test_mask] == DIGIT_B, 1, 0).astype('int32')

print("\nDataset loaded:")
print(f"  Training samples: {x_train.shape[0]}")
print(f"  Test samples: {x_test.shape[0]}")
print(f"  Feature dimension: {x_train.shape[1]}")

## 5. Model Architecture

In [None]:
# ============================================================================
# MODEL ARCHITECTURE & TRAINING
# ============================================================================

class StopAt99(tf.keras.callbacks.Callback):
    """Stop training once the reported accuracy reaches EARLY_STOP_ACCURACY.

    The check reads the ``'accuracy'`` entry of the per-epoch ``logs`` dict,
    i.e. the metric name used in ``model.compile(metrics=['accuracy'])``.
    """

    def on_epoch_end(self, epoch, logs=None):
        """Flag the model to stop when this epoch's accuracy meets the threshold.

        Args:
            epoch: Index of the epoch that just finished (unused).
            logs: Metric dict for the epoch; may be ``None``.
        """
        # ``logs`` is declared optional, so fall back to an empty dict rather
        # than calling .get() on None.
        logs = logs or {}
        if logs.get('accuracy', 0) >= EARLY_STOP_ACCURACY:
            self.model.stop_training = True


def build_model(arch_config: Dict, depth: int, input_dim: int = 784) -> Sequential:
    """Assemble and compile a fully connected binary classifier.

    Args:
        arch_config: Mapping with keys ``'width'`` (hidden-layer width) and
            ``'bottleneck'`` (bool). When ``'bottleneck'`` is true the first
            hidden layer has 50 units and all later hidden layers have 25,
            regardless of ``'width'``.
        depth: Total number of hidden layers (one first layer plus
            ``depth - 1`` further layers).
        input_dim: Size of each flattened input vector.

    Returns:
        A compiled ``Sequential`` model with ReLU hidden layers, a single
        sigmoid output unit, binary cross-entropy loss, and RMSprop optimizer.
    """
    width = arch_config['width']
    if arch_config['bottleneck']:
        first_width, hidden_width = 50, 25
    else:
        first_width = hidden_width = width

    # Input-facing hidden layer, then depth-1 further hidden layers, then output
    layers = [Dense(first_width, activation='relu', input_shape=(input_dim,))]
    layers.extend(Dense(hidden_width, activation='relu') for _ in range(depth - 1))
    layers.append(Dense(1, activation='sigmoid'))

    model = Sequential(layers)
    model.compile(
        loss='binary_crossentropy',
        optimizer=RMSprop(),
        metrics=['accuracy'],
    )
    return model

print("✓ Model architecture functions defined")

## 6. kNN Graph Building Functions

In [None]:
# ============================================================================
# kNN GRAPH BUILDING (from knn_fixed.py)
# ============================================================================

def build_knn_graph(X: np.ndarray, k: int) -> csr_matrix:
    """Build a symmetric, unweighted kNN adjacency matrix without self-loops.

    Args:
        X: Sample array; a 1-D input is treated as a single feature column.
            Non-float32/float64 input is cast to float32.
        k: ``n_neighbors`` passed to ``NearestNeighbors`` (Euclidean metric).

    Returns:
        CSR adjacency in which an edge exists when either endpoint lists the
        other among its neighbours; the diagonal is cleared.
    """
    if X.ndim == 1:
        X = X.reshape(-1, 1)
    if X.dtype not in (np.float32, np.float64):
        X = X.astype(np.float32, copy=False)

    # NOTE(review): X is passed back as the query set, so each sample is
    # presumably returned as one of its own k neighbours; after the diagonal
    # is cleared a node may contribute only k-1 directed edges. Confirm this
    # is the intended meaning of k.
    neighbors = NearestNeighbors(n_neighbors=k, metric='euclidean').fit(X)
    directed = neighbors.kneighbors_graph(X, mode='connectivity')

    # Element-wise max with the transpose keeps an edge if either direction has it
    undirected = directed.maximum(directed.T)

    # Drop self-loops and purge the explicit zeros that leaves behind
    undirected.setdiag(0)
    undirected.eliminate_zeros()

    return undirected.tocsr()

print("✓ build_knn_graph() defined")

## 7. Geodesic Mass & Forman-Ricci

In [None]:
# ============================================================================
# GEODESIC MASS (Paper Eq. 7)
# g_l = sum_{i<j} gamma_l(i, j)
# ============================================================================

def sum_shortest_paths(A: csr_matrix) -> float:
    """Return the geodesic mass of an undirected, unweighted graph.

    The geodesic mass is the sum of hop-count shortest-path distances over
    all unordered node pairs (i < j). Pairs with no connecting path
    (infinite distance) are left out of the sum.

    Args:
        A: Sparse adjacency matrix of the graph.

    Returns:
        The summed finite pairwise distances as a Python float.
    """
    pairwise = shortest_path(A, directed=False, unweighted=True)

    # Strict upper triangle: each unordered pair exactly once, no diagonal
    upper = pairwise[np.triu_indices(pairwise.shape[0], k=1)]

    # Unreachable pairs show up as inf; keep only the finite distances
    reachable = upper[np.isfinite(upper)]
    return float(reachable.sum())


# ============================================================================
# FORMAN-RICCI CURVATURE (Paper Eq. 4, 6)
# R(i, j) = 4 - deg(i) - deg(j)
# Ric_l = sum_{(i,j) in E_l} R_l(i, j)
# ============================================================================

def global_forman_ricci(A: csr_matrix) -> float:
    """Sum the Forman-Ricci curvature over every edge of a graph.

    Each edge (i, j) contributes ``4 - deg(i) - deg(j)``, where the degree of
    a node is its row sum in ``A``. Edges are taken from the strict upper
    triangle so each undirected edge is counted once.

    Args:
        A: Sparse symmetric adjacency matrix.

    Returns:
        The total edge curvature as a Python float.
    """
    degrees = np.asarray(A.sum(axis=1)).ravel()

    # One (row, col) entry per undirected edge
    edges = sp_triu(A, k=1).tocoo()
    src, dst = edges.row, edges.col

    per_edge = 4.0 - degrees[src] - degrees[dst]
    return float(np.sum(per_edge))

print("✓ sum_shortest_paths() defined")
print("✓ global_forman_ricci() defined")

## 8. Aggregated Ricci Coefficient (Rho)

In [None]:
# ============================================================================
# AGGREGATED RICCI COEFFICIENT (Rho)
# Computes mean Ricci curvature across all layers for a model
# ============================================================================

def compute_aggregated_ricci(activations: List[np.ndarray], X0: np.ndarray, k: int) -> float:
    """Average the global Forman-Ricci curvature over a model's kNN graphs.

    One kNN graph is built on the input data ``X0`` and one on each entry of
    ``activations``; the global Forman-Ricci value of every graph is computed
    and the arithmetic mean of all of them is returned.

    NOTE(review): the input-space graph (l = 0) is part of the mean. Confirm
    this matches the layer range of the paper's aggregated coefficient.

    Args:
        activations: Per-layer activation arrays, in layer order.
        X0: The input samples the activations were computed from.
        k: Neighbour count forwarded to ``build_knn_graph``.

    Returns:
        Mean global curvature across the input graph and all layer graphs.
    """
    # Input space first, then each layer's representation in order
    point_clouds = [X0, *(np.asarray(layer_out) for layer_out in activations)]

    per_graph = [
        global_forman_ricci(build_knn_graph(cloud, k))
        for cloud in point_clouds
    ]
    return float(np.mean(per_graph))

print("✓ compute_aggregated_ricci() defined")

## 9. Main Training + Live Analysis Loop

In [None]:
# ============================================================================
# MAIN TRAINING + LIVE ANALYSIS LOOP
# ============================================================================

def _config_already_recorded(arch_name: str, depth: int) -> bool:
    """Return True when RESULTS_CSV already has a row for this configuration."""
    # A missing or zero-byte file holds no results; reading a zero-byte file
    # with pandas would raise instead of returning an empty frame.
    if not os.path.exists(RESULTS_CSV) or os.path.getsize(RESULTS_CSV) == 0:
        return False
    existing_df = pd.read_csv(RESULTS_CSV)
    matches = (existing_df['architecture'] == arch_name) & (existing_df['depth'] == depth)
    return bool(matches.any())


def _hidden_activations(model, inputs) -> List[np.ndarray]:
    """Feed inputs through every layer except the last, collecting each output."""
    activations = []
    current = inputs
    for layer in model.layers[:-1]:  # Exclude output layer
        current = layer(current).numpy()
        activations.append(current)
    return activations


def train_and_analyze_one_config(arch_name: str, depth: int):
    """Train NUM_MODELS models for one configuration and record Ricci statistics.

    For each model: build and fit it on the training split, evaluate test
    accuracy, extract hidden-layer activations on ``x_test``, and compute the
    aggregated Ricci coefficient. The mean and standard deviation of accuracy
    and rho are then appended as a single row to RESULTS_CSV.

    The configuration is skipped when RESULTS_CSV already contains a row with
    the same architecture and depth.

    Args:
        arch_name: Key into ``ARCHITECTURES``.
        depth: Number of hidden layers passed to ``build_model``.

    Returns:
        None. Results are written to RESULTS_CSV and summarised on stdout.
    """
    arch_config = ARCHITECTURES[arch_name]

    # Check if already done
    if _config_already_recorded(arch_name, depth):
        print(f"  [SKIP] {arch_name}/depth_{depth} already in CSV")
        return

    # Storage
    accuracy_list = []
    rho_list = []

    # Train and analyze each model
    for _ in range(NUM_MODELS):
        # Build and train model
        model = build_model(arch_config, depth)
        model.fit(
            x_train, y_train,
            epochs=EPOCHS,
            batch_size=BATCH_SIZE,
            validation_split=0.2,
            callbacks=[StopAt99()],
            verbose=0
        )

        # Evaluate
        _, acc = model.evaluate(x_test, y_test, verbose=0)
        accuracy_list.append(acc)

        # Compute Aggregated Ricci (Rho) immediately, while the model is alive
        activations = _hidden_activations(model, x_test)
        rho_list.append(compute_aggregated_ricci(activations, x_test, K_VALUE))

        # Clear memory before the next model is built
        del model
        tf.keras.backend.clear_session()

    # Create result row
    result = {
        'architecture': arch_name,
        'depth': depth,
        'k': K_VALUE,
        'mean_accuracy': np.mean(accuracy_list),
        'std_accuracy': np.std(accuracy_list),
        'mean_rho': np.mean(rho_list),
        'std_rho': np.std(rho_list),
        'n_models': NUM_MODELS
    }

    # Save incrementally (append mode). The header is written when the file is
    # absent or empty, so a zero-byte file never ends up with headerless rows.
    write_header = (
        not os.path.exists(RESULTS_CSV) or os.path.getsize(RESULTS_CSV) == 0
    )
    pd.DataFrame([result]).to_csv(
        RESULTS_CSV,
        mode='a',
        header=write_header,
        index=False
    )

    print(f"  ✓ {arch_name}/depth_{depth}: acc={result['mean_accuracy']:.4f}±{result['std_accuracy']:.4f}, "
          f"rho={result['mean_rho']:.2e}±{result['std_rho']:.2e}")

print("✓ train_and_analyze_one_config() defined")

## 10. Run Training + Analysis

In [None]:
# ============================================================================
# RUN TRAINING + LIVE ANALYSIS
# ============================================================================

print(
    "=" * 80,
    "TRAINING + LIVE RICCI ANALYSIS",
    f"Total configurations: {len(ARCHITECTURES) * len(LAYER_DEPTHS)}",
    f"Models per configuration: {NUM_MODELS}",
    f"Results will be saved to: {RESULTS_CSV}",
    "=" * 80,
    sep="\n",
)

start_time = time.time()

# Sweep every architecture over every depth; each call appends its own row
for arch_name in ARCHITECTURES:
    print(f"\n{'=' * 60}", f"ARCHITECTURE: {arch_name.upper()}", "=" * 60, sep="\n")

    for depth in tqdm(LAYER_DEPTHS, desc=f"{arch_name} depths"):
        train_and_analyze_one_config(arch_name, depth)

total_time = time.time() - start_time

print(
    f"\n{'=' * 80}",
    "COMPLETE!",
    f"Total time: {total_time/60:.1f} minutes",
    f"Results saved to: {RESULTS_CSV}",
    "=" * 80,
    sep="\n",
)

## 11. View Results

In [None]:
# ============================================================================
# VIEW RESULTS
# ============================================================================

# Print the accumulated results table, or a hint when nothing has been saved yet
if not os.path.exists(RESULTS_CSV):
    print("No results file found. Run training first.")
else:
    results_df = pd.read_csv(RESULTS_CSV)
    print("Results Summary:")
    print(results_df.to_string(index=False))

## 12. Visualization: Accuracy vs Depth

In [None]:
# ============================================================================
# VISUALIZATION: ACCURACY VS DEPTH
# ============================================================================

if os.path.exists(RESULTS_CSV):
    results_df = pd.read_csv(RESULTS_CSV)

    fig, ax = plt.subplots(1, 1, figsize=(12, 6))

    # One error-bar curve per architecture, ordered by depth
    for arch_name in ARCHITECTURES:
        arch_data = results_df.loc[results_df['architecture'] == arch_name].sort_values('depth')
        if arch_data.empty:
            continue
        ax.errorbar(
            arch_data['depth'],
            arch_data['mean_accuracy'],
            yerr=arch_data['std_accuracy'],
            label=arch_name,
            marker='o',
            capsize=3,
        )

    ax.set_title(f'MNIST {DIGIT_A} vs {DIGIT_B}: Accuracy vs Layer Depth', fontsize=14)
    ax.set_xlabel('Number of Hidden Layers (Depth)', fontsize=12)
    ax.set_ylabel('Test Accuracy', fontsize=12)
    ax.grid(True, alpha=0.3)
    ax.legend()

    plt.tight_layout()
    # Write the figure to disk before showing it
    plt.savefig(os.path.join(OUTPUT_DIR, 'accuracy_vs_depth.png'), dpi=150, bbox_inches='tight')
    plt.show()

    print(f"\n✓ Plot saved to: {OUTPUT_DIR}/accuracy_vs_depth.png")

## 13. Visualization: Rho vs Depth

In [None]:
# ============================================================================
# VISUALIZATION: RHO VS DEPTH
# ============================================================================

if os.path.exists(RESULTS_CSV):
    results_df = pd.read_csv(RESULTS_CSV)

    fig, ax = plt.subplots(1, 1, figsize=(12, 6))

    # One error-bar curve per architecture, ordered by depth
    for arch_name in ARCHITECTURES:
        arch_data = results_df.loc[results_df['architecture'] == arch_name].sort_values('depth')
        if arch_data.empty:
            continue
        ax.errorbar(
            arch_data['depth'],
            arch_data['mean_rho'],
            yerr=arch_data['std_rho'],
            label=arch_name,
            marker='o',
            capsize=3,
        )

    ax.set_title(f'MNIST {DIGIT_A} vs {DIGIT_B}: Ricci Curvature vs Layer Depth', fontsize=14)
    ax.set_xlabel('Number of Hidden Layers (Depth)', fontsize=12)
    ax.set_ylabel('Aggregated Ricci Coefficient (ρ)', fontsize=12)
    ax.grid(True, alpha=0.3)
    ax.legend()

    plt.tight_layout()
    # Write the figure to disk before showing it
    plt.savefig(os.path.join(OUTPUT_DIR, 'rho_vs_depth.png'), dpi=150, bbox_inches='tight')
    plt.show()

    print(f"\n✓ Plot saved to: {OUTPUT_DIR}/rho_vs_depth.png")

## 14. Visualization: Accuracy vs Rho Correlation

In [None]:
# ============================================================================
# VISUALIZATION: ACCURACY VS RHO CORRELATION
# ============================================================================

if os.path.exists(RESULTS_CSV):
    results_df = pd.read_csv(RESULTS_CSV)

    fig, ax = plt.subplots(1, 1, figsize=(10, 8))

    # Fixed colour per architecture; unknown names fall back to gray
    colors = dict(narrow='blue', wide='green', bottleneck='red')

    for arch_name in ARCHITECTURES:
        arch_data = results_df[results_df['architecture'] == arch_name]
        if arch_data.empty:
            continue
        ax.scatter(
            arch_data['mean_rho'],
            arch_data['mean_accuracy'],
            label=arch_name,
            color=colors.get(arch_name, 'gray'),
            alpha=0.7,
            s=50,
        )

    # The rank correlation is only reported when more than two rows exist
    if len(results_df) <= 2:
        ax.set_title('Accuracy vs Aggregated Ricci Coefficient', fontsize=14)
    else:
        rho_corr, p_val = spearmanr(results_df['mean_rho'], results_df['mean_accuracy'])
        ax.set_title(f'Accuracy vs Aggregated Ricci Coefficient\nSpearman ρ = {rho_corr:.3f}, p = {p_val:.2e}', fontsize=14)

    ax.set_xlabel('Aggregated Ricci Coefficient (ρ)', fontsize=12)
    ax.set_ylabel('Test Accuracy', fontsize=12)
    ax.grid(True, alpha=0.3)
    ax.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, 'accuracy_vs_rho.png'), dpi=150, bbox_inches='tight')
    plt.show()

    print(f"\n✓ Plot saved to: {OUTPUT_DIR}/accuracy_vs_rho.png")

## 15. Summary Statistics

In [None]:
# ============================================================================
# SUMMARY STATISTICS
# ============================================================================

if os.path.exists(RESULTS_CSV):
    results_df = pd.read_csv(RESULTS_CSV)

    print("=" * 60, "SUMMARY STATISTICS", "=" * 60, sep="\n")

    # Per-architecture extremes of mean accuracy and the span of mean rho
    for arch_name in ARCHITECTURES:
        arch_data = results_df[results_df['architecture'] == arch_name]
        if arch_data.empty:
            continue
        print(f"\n{arch_name.upper()}:")
        print(f"  Configurations: {len(arch_data)}")
        print(f"  Best accuracy: {arch_data['mean_accuracy'].max():.4f} at depth {arch_data.loc[arch_data['mean_accuracy'].idxmax(), 'depth']}")
        print(f"  Worst accuracy: {arch_data['mean_accuracy'].min():.4f} at depth {arch_data.loc[arch_data['mean_accuracy'].idxmin(), 'depth']}")
        print(f"  Rho range: [{arch_data['mean_rho'].min():.2e}, {arch_data['mean_rho'].max():.2e}]")

    # Rank correlation across all rows, reported only with more than two rows
    if len(results_df) > 2:
        rho_corr, p_val = spearmanr(results_df['mean_rho'], results_df['mean_accuracy'])
        print(
            f"\n{'=' * 60}",
            "OVERALL SPEARMAN CORRELATION (Accuracy vs Rho):",
            f"  ρ = {rho_corr:.4f}, p-value = {p_val:.2e}",
            "=" * 60,
            sep="\n",
        )