# Ricci Flow Layer Depth Study
## DNN Training for Varying Layer Depths (3-30)

This notebook trains DNNs of varying depth across three architectures (Narrow, Wide, Bottleneck) on the MNIST 4-vs-9 binary classification task.

**Purpose:** Generate activations for Ricci curvature analysis to study how layer depth affects Ricci flow-like behavior.

**Based on:** `training.py` and the theoretical framework from "Deep Learning as Ricci Flow" (Baptista et al., 2024)

## 1. Setup & GPU Detection

In [None]:
# ============================================================================
# IMPORTS
# ============================================================================
import numpy as np
import pandas as pd
import os
import time
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

# TensorFlow/Keras imports
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.datasets import mnist

# ============================================================================
# GPU DETECTION (for Kaggle/Colab)
# ============================================================================
# Banner for the device report
print("=" * 60, "DEVICE DETECTION", "=" * 60, sep="\n")

# Ask TensorFlow which GPUs are visible to this process
gpus = tf.config.list_physical_devices('GPU')
if not gpus:
    print("⚠ No GPU detected. Training will use CPU.")
    print("  Tip: Enable GPU in Kaggle/Colab settings for faster training.")
else:
    print(f"✓ GPU(s) detected: {len(gpus)}")
    for device in gpus:
        print(f"  - {device.name}")
    # Turn on memory growth for every GPU (intended to avoid OOM); a
    # RuntimeError from TensorFlow is reported instead of aborting the cell.
    try:
        for device in gpus:
            tf.config.experimental.set_memory_growth(device, True)
    except RuntimeError as e:
        print(f"⚠ GPU memory growth setting failed: {e}")
    else:
        print("✓ GPU memory growth enabled")

print(f"\nTensorFlow version: {tf.__version__}")
print("=" * 60)

## 2. Configuration

In [None]:
# ============================================================================
# CONFIGURATION
# ============================================================================

# Binary task: the two MNIST digits to separate
# (4 vs 9 are visually similar, which makes this a hard problem)
DIGIT_A, DIGIT_B = 4, 9

# Hidden-layer counts swept by the study: 3, 4, ..., 30
LAYER_DEPTHS = [*range(3, 31)]

# Architecture families; 'bottleneck' means 50 neurons in the first layer
# and 25 in the rest
ARCHITECTURES = {
    'narrow': dict(width=25, bottleneck=False),
    'wide': dict(width=50, bottleneck=False),
    'bottleneck': dict(width=50, bottleneck=True),
}

# Training parameters (based on training.py)
NUM_MODELS = 25             # models per configuration, for statistical robustness
EPOCHS = 50
BATCH_SIZE = 32
VALIDATION_SPLIT = 0.2
EARLY_STOP_ACCURACY = 0.99  # stop once training accuracy reaches 99%

# Root folder for everything the notebook writes
OUTPUT_DIR = 'layer_depth_study_outputs'

# Echo the effective settings so the run log records them
print(
    "Configuration:",
    f"  Problem: MNIST {DIGIT_A} vs {DIGIT_B}",
    f"  Layer depths: {LAYER_DEPTHS[0]} to {LAYER_DEPTHS[-1]} ({len(LAYER_DEPTHS)} values)",
    f"  Architectures: {list(ARCHITECTURES.keys())}",
    f"  Models per config: {NUM_MODELS}",
    f"  Early stopping: {EARLY_STOP_ACCURACY*100}% training accuracy",
    f"  Total training runs: {len(ARCHITECTURES) * len(LAYER_DEPTHS) * NUM_MODELS}",
    sep="\n",
)

## 3. Data Loading (MNIST 4 vs 9)

In [None]:
# ============================================================================
# DATA LOADING
# ============================================================================

def load_mnist_binary(digit_a, digit_b):
    """
    Return the MNIST train/test splits restricted to two digits.

    Images are flattened to 784-long float32 vectors and divided by 255.0;
    labels are recoded as int32 with digit_a -> 0 and digit_b -> 1.

    Returns:
        x_train, y_train, x_test, y_test
    """
    def _prepare(images, labels):
        # Keep only the samples whose label is one of the two digits
        keep = (labels == digit_a) | (labels == digit_b)
        # 28x28 -> 784, cast, then normalize to [0, 1]
        features = images[keep].reshape(-1, 784).astype('float32') / 255.0
        # Binary target: 1 where the label is digit_b, 0 otherwise
        targets = (labels[keep] == digit_b).astype('int32')
        return features, targets

    (train_images, train_labels), (test_images, test_labels) = mnist.load_data()

    x_train, y_train = _prepare(train_images, train_labels)
    x_test, y_test = _prepare(test_images, test_labels)
    return x_train, y_train, x_test, y_test

# Fetch the two-digit subset once; later cells reuse these arrays
print("Loading MNIST data...")
x_train, y_train, x_test, y_test = load_mnist_binary(DIGIT_A, DIGIT_B)

# Report split sizes, per-class counts and the flattened feature width
print(
    f"\nDataset loaded:",
    f"  Training samples: {x_train.shape[0]} (Class 0: {np.sum(y_train==0)}, Class 1: {np.sum(y_train==1)})",
    f"  Test samples: {x_test.shape[0]} (Class 0: {np.sum(y_test==0)}, Class 1: {np.sum(y_test==1)})",
    f"  Feature dimension: {x_train.shape[1]}",
    sep="\n",
)

## 4. Model Architecture Builders

In [None]:
# ============================================================================
# ARCHITECTURE BUILDERS
# ============================================================================

def build_dnn(arch_name, arch_config, depth, input_dim):
    """
    Build and compile a fully connected binary classifier.

    The network has `depth` ReLU hidden layers followed by a single sigmoid
    output unit, and is compiled with binary cross-entropy, RMSprop and the
    'accuracy' metric (based on training.py).

    Hidden-layer widths:
    - Non-bottleneck (narrow/wide): arch_config['width'] for every layer
    - Bottleneck: arch_config['width'] for the first layer, then
      arch_config['bottleneck_width'] (25 if the key is absent) for the rest

    Args:
        arch_name: architecture label; not used when building the model,
            kept so existing callers do not need to change
        arch_config: dict with 'width' (int) and 'bottleneck' (bool), plus an
            optional 'bottleneck_width' (int)
        depth: number of hidden layers, must be >= 1
        input_dim: number of input features

    Returns:
        The compiled Sequential model.

    Raises:
        ValueError: if depth is smaller than 1.
    """
    # Without this check, depth <= 0 would silently yield one hidden layer
    if depth < 1:
        raise ValueError(f"depth must be >= 1, got {depth}")

    first_width = arch_config['width']
    if arch_config['bottleneck']:
        rest_width = arch_config.get('bottleneck_width', 25)
    else:
        rest_width = first_width

    model = Sequential()

    # First hidden layer also declares the input shape
    model.add(Dense(units=first_width, activation='relu', input_shape=(input_dim,)))

    # Remaining hidden layers (depth - 1 more layers)
    for _ in range(depth - 1):
        model.add(Dense(units=rest_width, activation='relu'))

    # Output layer
    model.add(Dense(units=1, activation='sigmoid'))

    model.compile(
        loss='binary_crossentropy',
        optimizer=RMSprop(),
        metrics=['accuracy']
    )

    return model


# Smoke test: build every architecture once at depth 5 and report its size
print("Testing architecture builders...")
for arch_name in ARCHITECTURES:
    arch_config = ARCHITECTURES[arch_name]
    test_model = build_dnn(arch_name, arch_config, depth=5, input_dim=784)
    print(f"  {arch_name}: {len(test_model.layers)} layers ({test_model.count_params()} params)")
print("✓ All architectures build successfully")

## 5. Training Function with Early Stopping

In [None]:
# ============================================================================
# TRAINING FUNCTION
# ============================================================================

class AccuracyThresholdCallback(tf.keras.callbacks.Callback):
    """
    Stop training once the epoch's training accuracy reaches a threshold.

    Args:
        threshold: value the 'accuracy' entry of the epoch logs is compared
            against (training stops when accuracy >= threshold)
    """
    def __init__(self, threshold=0.99):
        super().__init__()
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        """Request a stop when logs['accuracy'] is present and >= threshold."""
        # logs defaults to None and may lack 'accuracy' (for example when the
        # model was compiled without that metric); treat both as "keep going"
        accuracy = (logs or {}).get('accuracy')
        if accuracy is not None and accuracy >= self.threshold:
            self.model.stop_training = True


def train_single_model(arch_name, arch_config, depth, x_train, y_train, x_test, y_test):
    """
    Train a single DNN model and extract its hidden-layer activations.

    The model is built with build_dnn, fitted with the module-level EPOCHS,
    BATCH_SIZE and VALIDATION_SPLIT settings, and stopped early once the
    training accuracy reaches EARLY_STOP_ACCURACY.

    Args:
        arch_name: architecture label passed through to build_dnn
        arch_config: architecture settings passed through to build_dnn
        depth: number of hidden layers
        x_train, y_train: training features and binary labels
        x_test, y_test: test features and binary labels

    Returns:
        activations: list of numpy arrays (one per hidden layer), computed
            on x_test
        accuracy: test accuracy
    """
    # This function is called once per model in a long loop; clear the global
    # state Keras accumulates for earlier models so it does not keep growing.
    tf.keras.backend.clear_session()

    # Build model
    model = build_dnn(arch_name, arch_config, depth, input_dim=x_train.shape[1])

    # Early stopping callback (stop at the configured training accuracy)
    early_stop = AccuracyThresholdCallback(threshold=EARLY_STOP_ACCURACY)

    # Train model (based on training.py)
    model.fit(
        x_train, y_train,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        validation_split=VALIDATION_SPLIT,
        callbacks=[early_stop],
        verbose=0
    )

    # Evaluate on test set
    _, accuracy = model.evaluate(x_test, y_test, verbose=0)

    # Extract activations from all hidden layers (exclude output layer) by
    # feeding each layer's output into the next one
    # Based on training.py: model_predict[j] = activations
    activations = []
    current_input = x_test
    for layer in model.layers[:-1]:  # Exclude output layer
        current_output = layer(current_input)
        activations.append(current_output.numpy())
        current_input = current_output

    return activations, accuracy


# Confirm the cell ran and echo the configured stopping threshold
print(
    "Training function defined.",
    f"  Early stopping: Training accuracy >= {EARLY_STOP_ACCURACY*100}%",
    sep="\n",
)

## 6. Output Directory Structure

In [None]:
# ============================================================================
# OUTPUT STRUCTURE HELPER
# ============================================================================

def create_output_dirs():
    """
    Create one output folder per (architecture, depth) pair.

    Resulting layout:

    layer_depth_study_outputs/
    ├── narrow/
    │   ├── depth_3/
    │   │   └── models_b25/
    │   ├── depth_4/
    │   └── ...
    ├── wide/
    │   └── ...
    └── bottleneck/
        └── ...

    Folders that already exist are left as they are.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # The leaf folder name is the same for every configuration
    leaf = f"models_b{NUM_MODELS}"
    for architecture in ARCHITECTURES:
        for n_layers in LAYER_DEPTHS:
            os.makedirs(
                os.path.join(OUTPUT_DIR, architecture, f"depth_{n_layers}", leaf),
                exist_ok=True,
            )

    print(f"✓ Output directory structure created: {OUTPUT_DIR}/")
    print(f"  Total folders: {len(ARCHITECTURES) * len(LAYER_DEPTHS)}")


def save_model_outputs(arch_name, depth, model_predict, accuracy_list, x_test, y_test):
    """
    Save the outputs of one (architecture, depth) configuration.

    Files written under OUTPUT_DIR/<arch_name>/depth_<depth>/models_b<NUM_MODELS>/
    (layout intended for knn_fixed.py, per the original notes):
    - model_predict.npy: object array of activation lists
    - accuracy.npy: array of accuracy values
    - x_test.csv: test features (headerless)
    - y_test.csv: test labels (headerless)

    Args:
        arch_name: architecture label used as a folder name
        depth: number of hidden layers, used in the folder name
        model_predict: per-model activations to store
        accuracy_list: per-model test accuracies
        x_test, y_test: test features and labels written alongside
    """
    output_path = os.path.join(OUTPUT_DIR, arch_name, f"depth_{depth}", f"models_b{NUM_MODELS}")

    # Do not rely on the directory cell having run with the same settings
    os.makedirs(output_path, exist_ok=True)

    np.save(os.path.join(output_path, "model_predict.npy"), model_predict)
    np.save(os.path.join(output_path, "accuracy.npy"), np.array(accuracy_list))
    # header=False is the documented way to omit the column-name row
    pd.DataFrame(x_test).to_csv(os.path.join(output_path, "x_test.csv"), index=False, header=False)
    pd.DataFrame(y_test).to_csv(os.path.join(output_path, "y_test.csv"), index=False, header=False)


# Create the per-architecture / per-depth folder tree under OUTPUT_DIR now,
# before the training loop starts saving results into it
create_output_dirs()

## 7. Main Training Loop

In [None]:
# ============================================================================
# MAIN TRAINING LOOP
# ============================================================================

# One row of accuracy statistics per (architecture, depth); becomes the summary CSV
summary_results = []

# Number of configurations to train, plus a running counter
total_configs = len(ARCHITECTURES) * len(LAYER_DEPTHS)
config_count = 0

print("=" * 80)
print("TRAINING STARTED")
print(f"Total configurations: {total_configs}")
print(f"Models per configuration: {NUM_MODELS}")
print("=" * 80)

start_time_total = time.time()

# Outer sweep: architecture family
for arch_name, arch_config in ARCHITECTURES.items():
    print(f"\n{'='*60}")
    print(f"ARCHITECTURE: {arch_name.upper()}")
    print(f"{'='*60}")

    # Inner sweep: number of hidden layers
    for depth in tqdm(LAYER_DEPTHS, desc=f"{arch_name} depths"):
        config_count += 1

        # Activations (object array, one slot per model) and test accuracies
        model_predict = np.empty(NUM_MODELS, dtype=object)
        accuracy_list = []

        # Train NUM_MODELS independent models for this configuration
        for j in range(NUM_MODELS):
            model_predict[j], accuracy = train_single_model(
                arch_name, arch_config, depth,
                x_train, y_train, x_test, y_test
            )
            accuracy_list.append(accuracy)

        # Persist this configuration before moving on (layout for knn_fixed.py)
        save_model_outputs(arch_name, depth, model_predict, accuracy_list, x_test, y_test)

        # Accuracy statistics over the NUM_MODELS runs
        scores = np.asarray(accuracy_list)
        mean_acc, std_acc = scores.mean(), scores.std()

        summary_results.append({
            'architecture': arch_name,
            'depth': depth,
            'num_models': NUM_MODELS,
            'mean_accuracy': mean_acc,
            'std_accuracy': std_acc,
            'min_accuracy': scores.min(),
            'max_accuracy': scores.max()
        })

        # Log a progress line whenever the depth is a multiple of 7
        if not depth % 7:
            tqdm.write(f"  depth={depth}: mean_acc={mean_acc:.4f} ± {std_acc:.4f}")

total_time = time.time() - start_time_total

print(f"\n{'='*80}")
print(f"TRAINING COMPLETE!")
print(f"Total time: {total_time/60:.1f} minutes")
print(f"{'='*80}")

## 8. Save Summary CSV

In [None]:
# ============================================================================
# SAVE SUMMARY CSV
# ============================================================================

# Tabulate the per-configuration statistics gathered by the training loop
summary_df = pd.DataFrame(summary_results)

# Write the table next to the other outputs (no index column)
summary_csv_path = os.path.join(OUTPUT_DIR, 'training_summary.csv')
summary_df.to_csv(summary_csv_path, index=False)

print(
    f"Summary saved to: {summary_csv_path}",
    f"\nSummary statistics:",
    summary_df.to_string(index=False),
    sep="\n",
)

## 9. Quick Visualization

In [None]:
# ============================================================================
# QUICK VISUALIZATION
# ============================================================================
import matplotlib.pyplot as plt

# Mean test accuracy vs. depth, one error-bar series per architecture
fig, ax = plt.subplots(1, 1, figsize=(12, 6))

for arch_name in ARCHITECTURES:
    arch_data = summary_df[summary_df['architecture'] == arch_name]
    ax.errorbar(
        arch_data['depth'],
        arch_data['mean_accuracy'],
        yerr=arch_data['std_accuracy'],  # spread across the trained models
        label=arch_name,
        marker='o',
        capsize=3
    )

ax.set_xlabel('Number of Hidden Layers (Depth)', fontsize=12)
ax.set_ylabel('Test Accuracy', fontsize=12)
ax.set_title(f'MNIST {DIGIT_A} vs {DIGIT_B}: Accuracy vs Layer Depth', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3)
# Half a unit of padding around the configured depth range, so the axis
# follows LAYER_DEPTHS instead of a hard-coded 3-30 window
ax.set_xlim([min(LAYER_DEPTHS) - 0.5, max(LAYER_DEPTHS) + 0.5])

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'accuracy_vs_depth.png'), dpi=150, bbox_inches='tight')
plt.show()

print(f"\n✓ Visualization saved to: {OUTPUT_DIR}/accuracy_vs_depth.png")

## 10. Verification & Next Steps

In [None]:
# ============================================================================
# VERIFICATION
# ============================================================================

print("=" * 60, "VERIFICATION", "=" * 60, sep="\n")

# Spot-check one saved configuration: the narrow architecture at depth 5
sample_path = os.path.join(OUTPUT_DIR, 'narrow', 'depth_5', f'models_b{NUM_MODELS}')
print(f"\nChecking sample output: {sample_path}")

# Read both arrays back; model_predict.npy is an object array, so it is
# loaded with allow_pickle=True
sample_acc = np.load(os.path.join(sample_path, 'accuracy.npy'))
sample_model = np.load(os.path.join(sample_path, 'model_predict.npy'), allow_pickle=True)

print(
    f"  accuracy.npy shape: {sample_acc.shape}",
    f"  accuracy range: [{sample_acc.min():.4f}, {sample_acc.max():.4f}]",
    f"  model_predict.npy: {len(sample_model)} models",
    f"  Activations per model: {len(sample_model[0])} layers",
    f"  Activation shapes: {[a.shape for a in sample_model[0]]}",
    sep="\n",
)

print("\n" + "=" * 60, "NEXT STEPS", "=" * 60, sep="\n")
print("""
The outputs are now ready for Ricci curvature analysis using knn_fixed.py.

For each architecture/depth combination:
  1. Load model_predict.npy and x_test.csv
  2. Build kNN graphs on activations
  3. Compute Forman-Ricci curvature
  4. Correlate with accuracy from training_summary.csv
""")