# Ricci Flow Layer Depth Study - Colab Version
**Complete notebook: Training + Full Ricci Analysis (based on knn_fixed.py)**

Implements all paper metrics:
- Forman-Ricci curvature: R(i,j) = 4 - deg(i) - deg(j)
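- Geodesic mass: g_l = Σ_{i<j} γ_l(i,j), the sum of shortest-path lengths over all node pairs of the layer-l kNN graph
- Pearson correlation between Δg_l = g_l - g_{l-1} and Ric_{l-1}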

## 1. Mount Google Drive

In [None]:
# Mount Google Drive so every artifact of this study is written under MyDrive
# instead of the Colab VM's local disk.
from google.colab import drive
drive.mount('/content/drive')
# Root folder for all outputs (per-configuration folders, CSV summaries, figure).
OUTPUT_DIR = '/content/drive/MyDrive/ricci_layer_depth_study'
import os
# exist_ok=True makes re-running this cell harmless when the folder already exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f'✓ Output: {OUTPUT_DIR}')

## 2. Imports & GPU

In [None]:
import numpy as np
import pandas as pd
import os
import time
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.datasets import mnist

from sklearn.neighbors import NearestNeighbors
from scipy.sparse import csr_matrix, triu as sp_triu
from scipy.sparse.csgraph import shortest_path
from scipy.stats import pearsonr, spearmanr
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple

# Report whether TensorFlow sees at least one GPU device, plus the TF version.
gpus = tf.config.list_physical_devices('GPU')
print(f'GPU: {"Yes" if gpus else "No"}, TF: {tf.__version__}')

## 3. Configuration

In [None]:
# Binary task: DIGIT_A is labelled 0 and DIGIT_B is labelled 1 when the data is loaded.
DIGIT_A, DIGIT_B = 4, 9
# Hidden-layer counts covered by the study (one training/analysis cell per depth).
LAYER_DEPTHS = list(range(3, 31))
# 'bn' marks the bottleneck layout, not batch normalisation: train_one_config
# builds a 50-unit first hidden layer followed by 25-unit layers when it is
# True and does not read 'width' in that case.  Otherwise every hidden layer
# has 'width' units.
ARCHITECTURES = {
    'narrow': {'width': 25, 'bn': False},
    'wide': {'width': 50, 'bn': False},
    'bottleneck': {'width': 50, 'bn': True},
}
# Independently initialised networks trained per (architecture, depth).
NUM_MODELS = 25
EPOCHS, BATCH_SIZE = 50, 32
K_VALUE = 200  # ~10% of test samples
ACC_THRESHOLD = 0.0  # Use all models (set higher to filter)
# Append-only summary files; one row is added per finished configuration.
TRAINING_CSV = os.path.join(OUTPUT_DIR, 'training_checkpoint.csv')
RICCI_CSV = os.path.join(OUTPUT_DIR, 'ricci_full_results.csv')
# Derive the depth range from LAYER_DEPTHS so the message cannot go stale.
print(f'Config: {DIGIT_A}vs{DIGIT_B}, depths {LAYER_DEPTHS[0]}-{LAYER_DEPTHS[-1]}, k={K_VALUE}')

## 4. Load Data

In [None]:
# Load MNIST and reduce it to the two-digit problem DIGIT_A vs DIGIT_B.
(x_train_full, y_train_full), (x_test_full, y_test_full) = mnist.load_data()

# Boolean row selectors: True where the label is one of the two digits.
train_mask = np.isin(y_train_full, (DIGIT_A, DIGIT_B))
test_mask = np.isin(y_test_full, (DIGIT_A, DIGIT_B))

# Flatten every image to a 784-vector, cast to float32, then scale by 1/255.
x_train = x_train_full[train_mask].reshape(-1, 784).astype('float32')
x_train = x_train / 255.0
x_test = x_test_full[test_mask].reshape(-1, 784).astype('float32')
x_test = x_test / 255.0

# Binary targets: 1 for DIGIT_B, 0 for DIGIT_A.
y_train = (y_train_full[train_mask] == DIGIT_B).astype('int32')
y_test = (y_test_full[test_mask] == DIGIT_B).astype('int32')

print(f'Train: {len(x_train)}, Test: {len(x_test)}')

## 5. Training Functions

In [None]:
class StopAt99(tf.keras.callbacks.Callback):
    """Keras callback that stops training once training accuracy reaches 0.99.

    The check uses the 'accuracy' entry of the epoch logs (the training
    metric, not 'val_accuracy').
    """

    def on_epoch_end(self, e, logs=None):
        """Request a stop when the epoch's 'accuracy' is at least 0.99.

        Args:
            e: Epoch index supplied by Keras (unused).
            logs: Metric dict for the finished epoch; may be None.
        """
        # `logs` defaults to None, so guard before calling .get(); a missing
        # payload is treated as "accuracy unknown" and training continues.
        if (logs or {}).get('accuracy', 0) >= 0.99:
            self.model.stop_training = True

def train_one_config(arch_name, depth):
    """Train NUM_MODELS networks for one (architecture, depth) pair and save the results.

    Each network has `depth` ReLU hidden layers followed by a single sigmoid
    output unit.  For the bottleneck layout (cfg['bn'] is True) the first
    hidden layer has 50 units and the remaining ones 25; otherwise every
    hidden layer has cfg['width'] units.

    Files written to OUTPUT_DIR/<arch_name>/depth_<depth>/:
      - model_predict.npy: object array with one entry per model, each a list
        of the hidden-layer activations on x_test (output layer excluded).
      - x_test.csv: the test inputs the activations were computed on.
      - accuracy.npy: test accuracy of every model (written last, see below).
    One summary row is also appended to TRAINING_CSV.

    Args:
        arch_name: Key into ARCHITECTURES.
        depth: Number of hidden layers.

    Returns:
        None.  Returns early, without training, when all three output files
        already exist.
    """
    cfg = ARCHITECTURES[arch_name]
    out_dir = os.path.join(OUTPUT_DIR, arch_name, f'depth_{depth}')
    os.makedirs(out_dir, exist_ok=True)
    predict_path = os.path.join(out_dir, 'model_predict.npy')
    x_test_path = os.path.join(out_dir, 'x_test.csv')
    accuracy_path = os.path.join(out_dir, 'accuracy.npy')

    # Skip only when every artifact the Ricci analysis reads is present; a run
    # interrupted half-way through saving must be redone, not reported as done.
    if all(os.path.exists(p) for p in (predict_path, x_test_path, accuracy_path)):
        print(f'✓ {arch_name}/depth_{depth} already trained')
        return

    first_width = 50 if cfg['bn'] else cfg['width']
    hidden_width = 25 if cfg['bn'] else cfg['width']

    model_predict = np.empty(NUM_MODELS, dtype=object)
    accuracy_list = []
    for j in range(NUM_MODELS):
        # Many models are built in a loop (and across many cells); drop the
        # Keras global state of the previous one so memory does not keep growing.
        tf.keras.backend.clear_session()

        model = Sequential()
        model.add(Dense(first_width, activation='relu', input_shape=(784,)))
        for _ in range(depth - 1):
            model.add(Dense(hidden_width, activation='relu'))
        model.add(Dense(1, activation='sigmoid'))
        model.compile(loss='binary_crossentropy', optimizer=RMSprop(), metrics=['accuracy'])
        model.fit(x_train, y_train, epochs=EPOCHS, batch_size=BATCH_SIZE, validation_split=0.2, callbacks=[StopAt99()], verbose=0)
        _, acc = model.evaluate(x_test, y_test, verbose=0)
        accuracy_list.append(acc)

        # Push the test set through the hidden layers one at a time and keep
        # every intermediate representation; the sigmoid output is excluded.
        acts = []
        inp = x_test
        for layer in model.layers[:-1]:
            inp = layer(inp).numpy()
            acts.append(inp)
        model_predict[j] = acts

    np.save(predict_path, model_predict)
    pd.DataFrame(x_test).to_csv(x_test_path, index=False, header=None)
    # accuracy.npy is saved last so its presence implies the other files exist.
    np.save(accuracy_path, np.array(accuracy_list))

    row = {'architecture': arch_name, 'depth': depth, 'mean_acc': np.mean(accuracy_list), 'std_acc': np.std(accuracy_list)}
    # Append mode; the header is written only when the CSV is first created.
    pd.DataFrame([row]).to_csv(TRAINING_CSV, mode='a', header=not os.path.exists(TRAINING_CSV), index=False)
    print(f'✓ {arch_name}/depth_{depth}: acc={row["mean_acc"]:.4f}')

print('✓ Training functions ready')

## 6. Ricci Analysis Functions (from knn_fixed.py)

In [None]:
# ============================================================================
# kNN GRAPH BUILDING (knn_fixed.py lines 37-54)
# ============================================================================
def build_knn_graph(X: np.ndarray, k: int) -> csr_matrix:
    """Build a symmetric, unweighted kNN adjacency matrix (CSR) from the rows of X.

    A 1-D input is treated as a single feature column, and inputs that are
    neither float32 nor float64 are cast to float32.  An edge i-j is kept when
    either point lists the other among its k Euclidean nearest neighbours
    (symmetrisation by element-wise maximum).  The diagonal is cleared so the
    graph has no self-loops.

    NOTE(review): the fitted points are also passed as the query set, so each
    point is presumably returned as one of its own k neighbours and then
    removed with the diagonal, leaving k-1 true neighbours per row.  Confirm
    against the scikit-learn docs before changing anything, since the results
    are meant to follow knn_fixed.py.
    """
    points = X.reshape(-1, 1) if X.ndim == 1 else X
    if points.dtype not in (np.float32, np.float64):
        points = points.astype(np.float32, copy=False)

    finder = NearestNeighbors(n_neighbors=k, metric='euclidean')
    finder.fit(points)
    directed = finder.kneighbors_graph(points, mode='connectivity')

    # Undirected graph: keep an edge if it exists in either direction.
    adjacency = directed.maximum(directed.T)
    adjacency.setdiag(0)
    adjacency.eliminate_zeros()
    return adjacency.tocsr()

# ============================================================================
# GEODESIC MASS (knn_fixed.py lines 57-71)
# g_l = sum_{i<j} gamma_l(i,j) - Paper Eq. 7
# ============================================================================
def sum_shortest_paths(A: csr_matrix) -> float:
    """Return the geodesic mass: the sum of hop distances over all pairs i<j.

    Distances are unweighted shortest paths on the undirected graph A.  Pairs
    with no connecting path (infinite distance) are left out of the sum and a
    warning with their count is printed.
    """
    hops = shortest_path(A, directed=False, unweighted=True)

    # Strict upper triangle: every unordered pair once, no diagonal.
    rows, cols = np.triu_indices_from(hops, k=1)
    pair_lengths = hops[rows, cols]

    reachable = np.isfinite(pair_lengths)
    if not reachable.all():
        print(f'[WARN] Disconnected: {(~reachable).sum()} inf distances ignored')
        pair_lengths = pair_lengths[reachable]
    return float(pair_lengths.sum())

# ============================================================================
# FORMAN-RICCI CURVATURE (knn_fixed.py lines 74-83)
# R(i,j) = 4 - deg(i) - deg(j) - Paper Eq. 4
# Ric_l = sum of R over all edges - Paper Eq. 6
# ============================================================================
def global_forman_ricci(A: csr_matrix) -> float:
    """Return the global Ricci coefficient of graph A.

    Every edge (i, j) contributes 4 - deg(i) - deg(j), where deg is the row
    sum of A.  Edges are taken from the strict upper triangle so that each
    undirected edge is counted once.
    """
    degree = np.asarray(A.sum(axis=1)).ravel()
    edges = sp_triu(A, k=1).tocoo()
    deg_i = degree[edges.row]
    deg_j = degree[edges.col]
    per_edge = (4.0 - deg_i) - deg_j
    return float(per_edge.sum())

print('✓ Graph functions ready')

In [None]:
# ============================================================================
# ANALYZE MODEL LAYERS (knn_fixed.py lines 90-108)
# Returns BOTH g and Ric for all layers
# ============================================================================
def analyze_model_layers(activations: List[np.ndarray], X0: np.ndarray, k: int) -> Dict[str, np.ndarray]:
    """Compute geodesic mass and global Ricci coefficient for every layer of one model.

    Index 0 of both result arrays is the baseline computed on the raw test
    input X0; indices 1..L follow the order of `activations`.

    Returns:
        {'g': geodesic mass per layer, 'Ric': Ricci coefficient per layer},
        both float arrays of length len(activations) + 1.
    """
    representations = [X0, *(np.asarray(layer_out) for layer_out in activations)]

    masses, curvatures = [], []
    for features in representations:
        graph = build_knn_graph(features, k)
        masses.append(sum_shortest_paths(graph))
        curvatures.append(global_forman_ricci(graph))

    return {
        'g': np.array(masses, dtype=float),
        'Ric': np.array(curvatures, dtype=float),
    }

# ============================================================================
# COLLECT ACROSS MODELS (knn_fixed.py lines 111-140)
# Returns DataFrames for Δg_l and Ric_{l-1}
# ============================================================================
def collect_across_models(models: List, X0: np.ndarray, k: int,
                          acc: np.ndarray, acc_threshold: float) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Gather per-layer Ric_{l-1} and Δg_l for every model above the accuracy threshold.

    Models with acc > acc_threshold are analysed.  If none qualifies, a warning
    is printed and all models are used instead.

    Returns:
        (mfr, msc), both with columns ['layer', 'mod', 'ssr']:
          - mfr: 'ssr' holds Ric_{l-1}, stored under layer index l-1.
          - msc: 'ssr' holds Δg_l = g_l - g_{l-1}, stored under layer index l.
    """
    selected = np.nonzero(acc > acc_threshold)[0]
    if selected.size == 0:
        print(f'[WARN] No models > {acc_threshold}, using all')
        selected = np.arange(len(models))

    curvature_rows, delta_rows = [], []
    for m in selected:
        result = analyze_model_layers(models[m], X0, k)
        delta_g = np.diff(result['g'])  # entry l-1 is Δg_l
        # zip stops after the L deltas, so the last Ricci value (layer L) is unused.
        for layer, (dg, ric_prev) in enumerate(zip(delta_g, result['Ric']), start=1):
            delta_rows.append({'layer': layer, 'mod': int(m), 'ssr': float(dg)})
            curvature_rows.append({'layer': layer - 1, 'mod': int(m), 'ssr': float(ric_prev)})

    columns = ['layer', 'mod', 'ssr']
    msc = pd.DataFrame(delta_rows, columns=columns)      # Δg_l
    mfr = pd.DataFrame(curvature_rows, columns=columns)  # Ric_{l-1}
    return mfr, msc

# ============================================================================
# CORRELATION REPORT (knn_fixed.py lines 143-164)
# Correlates Δg_l with Ric_{l-1} - Paper Eq. 5
# ============================================================================
def correlation_report(mfr: pd.DataFrame, msc: pd.DataFrame) -> Dict[str, float]:
    """Pearson correlation between Δg_l (msc) and Ric_{l-1} (mfr).

    The layer index of mfr is shifted by +1 so that Ric_{l-1} lines up with
    Δg_l, and the frames are inner-joined on (mod, layer).

    Returns:
        'r_all'/'p_all': correlation and p-value over all joined rows.
        'r_skip'/'p_skip': the same with layer 1 excluded.
        A pair is NaN when fewer than two rows are available for it.
    """
    ric_prev = mfr.assign(layer=mfr['layer'] + 1)
    joined = msc.merge(ric_prev, on=['mod', 'layer'], how='inner', suffixes=('_dg', '_fr'))

    def _pearson(frame):
        return pearsonr(frame['ssr_dg'].values, frame['ssr_fr'].values)

    if len(joined) < 2:
        return {'r_all': np.nan, 'p_all': np.nan, 'r_skip': np.nan, 'p_skip': np.nan}

    r_all, p_all = _pearson(joined)

    beyond_first = joined[joined['layer'] != 1]
    if len(beyond_first) >= 2:
        r_skip, p_skip = _pearson(beyond_first)
    else:
        r_skip, p_skip = np.nan, np.nan

    return {'r_all': float(r_all), 'p_all': float(p_all),
            'r_skip': float(r_skip), 'p_skip': float(p_skip)}

# ============================================================================
# AGGREGATED RHO (mean Ricci across layers)
# ============================================================================
def compute_aggregated_rho(activations: List[np.ndarray], X0: np.ndarray, k: int) -> float:
    """Return the mean global Ricci coefficient over the input and all layers.

    Args:
        activations: Per-layer activation matrices of one model.
        X0: Raw test input, used as layer 0.
        k: Number of neighbours for the kNN graphs.

    Returns:
        Mean of Ric_l for l = 0..L, the same value analyze_model_layers would
        give for np.mean(res['Ric']).
    """
    # Only curvature is needed here, so build the graphs directly and skip the
    # all-pairs shortest-path computation that analyze_model_layers also runs.
    ric_values = [global_forman_ricci(build_knn_graph(X0, k))]
    ric_values.extend(
        global_forman_ricci(build_knn_graph(np.asarray(Xl), k)) for Xl in activations
    )
    return float(np.mean(np.array(ric_values, dtype=float)))

print('✓ Analysis functions ready')

In [None]:
# ============================================================================
# FULL RICCI ANALYSIS FOR ONE CONFIG
# ============================================================================
def ricci_one_config(arch_name, depth):
    """Run the full Ricci analysis for one (architecture, depth) pair.

    Reads model_predict.npy, accuracy.npy and x_test.csv from
    OUTPUT_DIR/<arch_name>/depth_<depth>/, writes mfr.csv and msc.csv to that
    folder, and appends one summary row to RICCI_CSV.  Returns early when
    RICCI_CSV already holds a row for this configuration.
    """
    # Resume support: nothing to do if this configuration is already summarised.
    if os.path.exists(RICCI_CSV):
        done = pd.read_csv(RICCI_CSV)
        already = (done['architecture'] == arch_name) & (done['depth'] == depth)
        if already.any():
            print(f'✓ {arch_name}/depth_{depth} ricci done')
            return

    out_dir = os.path.join(OUTPUT_DIR, arch_name, f'depth_{depth}')

    def _artifact(filename):
        return os.path.join(out_dir, filename)

    # allow_pickle is required because train_one_config stores an object array
    # (one Python list of activations per model); only load files written by
    # this notebook.
    model_predict = np.load(_artifact('model_predict.npy'), allow_pickle=True)
    accuracy = np.load(_artifact('accuracy.npy'))
    X0 = pd.read_csv(_artifact('x_test.csv'), header=None).values

    # Per-layer Ric_{l-1} and Δg_l, then their correlation.
    mfr, msc = collect_across_models(model_predict, X0, K_VALUE, accuracy, ACC_THRESHOLD)
    stats = correlation_report(mfr, msc)

    # Aggregated rho is computed for every model, independent of the accuracy filter.
    rho_list = [compute_aggregated_rho(acts, X0, K_VALUE) for acts in model_predict]

    mfr.to_csv(_artifact('mfr.csv'), index=False)
    msc.to_csv(_artifact('msc.csv'), index=False)

    row = {
        'architecture': arch_name,
        'depth': depth,
        'k': K_VALUE,
        'mean_acc': np.mean(accuracy),
        'std_acc': np.std(accuracy),
        'mean_rho': np.mean(rho_list),
        'std_rho': np.std(rho_list),
        **{key: stats[key] for key in ('r_all', 'p_all', 'r_skip', 'p_skip')},
        'n_models': len(model_predict),
    }
    # Append mode; the header is written only when the CSV is first created.
    write_header = not os.path.exists(RICCI_CSV)
    pd.DataFrame([row]).to_csv(RICCI_CSV, mode='a', header=write_header, index=False)
    print(f'✓ {arch_name}/depth_{depth}: rho={row["mean_rho"]:.2e}, r_all={stats["r_all"]:.4f}')

print('✓ ricci_one_config() ready')

---
# TRAINING: NARROW

In [None]:
# One cell per depth for the 'narrow' architecture (ARCHITECTURES: width 25,
# used for every hidden layer).  train_one_config returns immediately when the
# configuration's saved outputs are already present, so cells can be re-run
# after an interrupted session.
train_one_config('narrow', 3)

In [None]:
train_one_config('narrow', 4)

In [None]:
train_one_config('narrow', 5)

In [None]:
train_one_config('narrow', 6)

In [None]:
train_one_config('narrow', 7)

In [None]:
train_one_config('narrow', 8)

In [None]:
train_one_config('narrow', 9)

In [None]:
train_one_config('narrow', 10)

In [None]:
train_one_config('narrow', 11)

In [None]:
train_one_config('narrow', 12)

In [None]:
train_one_config('narrow', 13)

In [None]:
train_one_config('narrow', 14)

In [None]:
train_one_config('narrow', 15)

In [None]:
train_one_config('narrow', 16)

In [None]:
train_one_config('narrow', 17)

In [None]:
train_one_config('narrow', 18)

In [None]:
train_one_config('narrow', 19)

In [None]:
train_one_config('narrow', 20)

In [None]:
train_one_config('narrow', 21)

In [None]:
train_one_config('narrow', 22)

In [None]:
train_one_config('narrow', 23)

In [None]:
train_one_config('narrow', 24)

In [None]:
train_one_config('narrow', 25)

In [None]:
train_one_config('narrow', 26)

In [None]:
train_one_config('narrow', 27)

In [None]:
train_one_config('narrow', 28)

In [None]:
train_one_config('narrow', 29)

In [None]:
train_one_config('narrow', 30)

---
# TRAINING: WIDE

In [None]:
# One cell per depth for the 'wide' architecture (ARCHITECTURES: width 50,
# used for every hidden layer).  Already-trained configurations are skipped
# by train_one_config.
train_one_config('wide', 3)

In [None]:
train_one_config('wide', 4)

In [None]:
train_one_config('wide', 5)

In [None]:
train_one_config('wide', 6)

In [None]:
train_one_config('wide', 7)

In [None]:
train_one_config('wide', 8)

In [None]:
train_one_config('wide', 9)

In [None]:
train_one_config('wide', 10)

In [None]:
train_one_config('wide', 11)

In [None]:
train_one_config('wide', 12)

In [None]:
train_one_config('wide', 13)

In [None]:
train_one_config('wide', 14)

In [None]:
train_one_config('wide', 15)

In [None]:
train_one_config('wide', 16)

In [None]:
train_one_config('wide', 17)

In [None]:
train_one_config('wide', 18)

In [None]:
train_one_config('wide', 19)

In [None]:
train_one_config('wide', 20)

In [None]:
train_one_config('wide', 21)

In [None]:
train_one_config('wide', 22)

In [None]:
train_one_config('wide', 23)

In [None]:
train_one_config('wide', 24)

In [None]:
train_one_config('wide', 25)

In [None]:
train_one_config('wide', 26)

In [None]:
train_one_config('wide', 27)

In [None]:
train_one_config('wide', 28)

In [None]:
train_one_config('wide', 29)

In [None]:
train_one_config('wide', 30)

---
# TRAINING: BOTTLENECK

In [None]:
# One cell per depth for the 'bottleneck' architecture: train_one_config builds
# a 50-unit first hidden layer followed by 25-unit hidden layers when 'bn' is
# True.  Already-trained configurations are skipped.
train_one_config('bottleneck', 3)

In [None]:
train_one_config('bottleneck', 4)

In [None]:
train_one_config('bottleneck', 5)

In [None]:
train_one_config('bottleneck', 6)

In [None]:
train_one_config('bottleneck', 7)

In [None]:
train_one_config('bottleneck', 8)

In [None]:
train_one_config('bottleneck', 9)

In [None]:
train_one_config('bottleneck', 10)

In [None]:
train_one_config('bottleneck', 11)

In [None]:
train_one_config('bottleneck', 12)

In [None]:
train_one_config('bottleneck', 13)

In [None]:
train_one_config('bottleneck', 14)

In [None]:
train_one_config('bottleneck', 15)

In [None]:
train_one_config('bottleneck', 16)

In [None]:
train_one_config('bottleneck', 17)

In [None]:
train_one_config('bottleneck', 18)

In [None]:
train_one_config('bottleneck', 19)

In [None]:
train_one_config('bottleneck', 20)

In [None]:
train_one_config('bottleneck', 21)

In [None]:
train_one_config('bottleneck', 22)

In [None]:
train_one_config('bottleneck', 23)

In [None]:
train_one_config('bottleneck', 24)

In [None]:
train_one_config('bottleneck', 25)

In [None]:
train_one_config('bottleneck', 26)

In [None]:
train_one_config('bottleneck', 27)

In [None]:
train_one_config('bottleneck', 28)

In [None]:
train_one_config('bottleneck', 29)

In [None]:
train_one_config('bottleneck', 30)

---
# RICCI ANALYSIS: NARROW

In [None]:
# One cell per depth.  ricci_one_config reads the files saved by
# train_one_config for the same architecture/depth (model_predict.npy,
# accuracy.npy, x_test.csv), so the matching training cell must have been run.
# Configurations already listed in RICCI_CSV are skipped.
ricci_one_config('narrow', 3)

In [None]:
ricci_one_config('narrow', 4)

In [None]:
ricci_one_config('narrow', 5)

In [None]:
ricci_one_config('narrow', 6)

In [None]:
ricci_one_config('narrow', 7)

In [None]:
ricci_one_config('narrow', 8)

In [None]:
ricci_one_config('narrow', 9)

In [None]:
ricci_one_config('narrow', 10)

In [None]:
ricci_one_config('narrow', 11)

In [None]:
ricci_one_config('narrow', 12)

In [None]:
ricci_one_config('narrow', 13)

In [None]:
ricci_one_config('narrow', 14)

In [None]:
ricci_one_config('narrow', 15)

In [None]:
ricci_one_config('narrow', 16)

In [None]:
ricci_one_config('narrow', 17)

In [None]:
ricci_one_config('narrow', 18)

In [None]:
ricci_one_config('narrow', 19)

In [None]:
ricci_one_config('narrow', 20)

In [None]:
ricci_one_config('narrow', 21)

In [None]:
ricci_one_config('narrow', 22)

In [None]:
ricci_one_config('narrow', 23)

In [None]:
ricci_one_config('narrow', 24)

In [None]:
ricci_one_config('narrow', 25)

In [None]:
ricci_one_config('narrow', 26)

In [None]:
ricci_one_config('narrow', 27)

In [None]:
ricci_one_config('narrow', 28)

In [None]:
ricci_one_config('narrow', 29)

In [None]:
ricci_one_config('narrow', 30)

---
# RICCI ANALYSIS: WIDE

In [None]:
# One cell per depth for 'wide'; needs the saved training outputs of the same
# configuration and skips configurations already listed in RICCI_CSV.
ricci_one_config('wide', 3)

In [None]:
ricci_one_config('wide', 4)

In [None]:
ricci_one_config('wide', 5)

In [None]:
ricci_one_config('wide', 6)

In [None]:
ricci_one_config('wide', 7)

In [None]:
ricci_one_config('wide', 8)

In [None]:
ricci_one_config('wide', 9)

In [None]:
ricci_one_config('wide', 10)

In [None]:
ricci_one_config('wide', 11)

In [None]:
ricci_one_config('wide', 12)

In [None]:
ricci_one_config('wide', 13)

In [None]:
ricci_one_config('wide', 14)

In [None]:
ricci_one_config('wide', 15)

In [None]:
ricci_one_config('wide', 16)

In [None]:
ricci_one_config('wide', 17)

In [None]:
ricci_one_config('wide', 18)

In [None]:
ricci_one_config('wide', 19)

In [None]:
ricci_one_config('wide', 20)

In [None]:
ricci_one_config('wide', 21)

In [None]:
ricci_one_config('wide', 22)

In [None]:
ricci_one_config('wide', 23)

In [None]:
ricci_one_config('wide', 24)

In [None]:
ricci_one_config('wide', 25)

In [None]:
ricci_one_config('wide', 26)

In [None]:
ricci_one_config('wide', 27)

In [None]:
ricci_one_config('wide', 28)

In [None]:
ricci_one_config('wide', 29)

In [None]:
ricci_one_config('wide', 30)

---
# RICCI ANALYSIS: BOTTLENECK

In [None]:
# One cell per depth for 'bottleneck'; needs the saved training outputs of the
# same configuration and skips configurations already listed in RICCI_CSV.
ricci_one_config('bottleneck', 3)

In [None]:
ricci_one_config('bottleneck', 4)

In [None]:
ricci_one_config('bottleneck', 5)

In [None]:
ricci_one_config('bottleneck', 6)

In [None]:
ricci_one_config('bottleneck', 7)

In [None]:
ricci_one_config('bottleneck', 8)

In [None]:
ricci_one_config('bottleneck', 9)

In [None]:
ricci_one_config('bottleneck', 10)

In [None]:
ricci_one_config('bottleneck', 11)

In [None]:
ricci_one_config('bottleneck', 12)

In [None]:
ricci_one_config('bottleneck', 13)

In [None]:
ricci_one_config('bottleneck', 14)

In [None]:
ricci_one_config('bottleneck', 15)

In [None]:
ricci_one_config('bottleneck', 16)

In [None]:
ricci_one_config('bottleneck', 17)

In [None]:
ricci_one_config('bottleneck', 18)

In [None]:
ricci_one_config('bottleneck', 19)

In [None]:
ricci_one_config('bottleneck', 20)

In [None]:
ricci_one_config('bottleneck', 21)

In [None]:
ricci_one_config('bottleneck', 22)

In [None]:
ricci_one_config('bottleneck', 23)

In [None]:
ricci_one_config('bottleneck', 24)

In [None]:
ricci_one_config('bottleneck', 25)

In [None]:
ricci_one_config('bottleneck', 26)

In [None]:
ricci_one_config('bottleneck', 27)

In [None]:
ricci_one_config('bottleneck', 28)

In [None]:
ricci_one_config('bottleneck', 29)

In [None]:
ricci_one_config('bottleneck', 30)

---
# View Results

In [None]:
# Load the accumulated per-configuration summary and print headline numbers.
ricci_df = pd.read_csv(RICCI_CSV)
print(ricci_df.to_string(index=False))
print(f'\nOverall r_all mean: {ricci_df["r_all"].mean():.4f}')
# pearsonr needs at least two observations; this cell is often run while only
# one configuration has been analysed, so report that instead of raising.
if len(ricci_df) >= 2:
    print(f'Accuracy vs Rho correlation: {pearsonr(ricci_df["mean_acc"], ricci_df["mean_rho"])[0]:.4f}')
else:
    print('Accuracy vs Rho correlation: n/a (needs at least 2 analysed configurations)')

In [None]:
# 2x2 overview figure, one colour per architecture, saved to OUTPUT_DIR.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
colors = {'narrow': 'blue', 'wide': 'green', 'bottleneck': 'red'}

# (axis, x column, y column, x label, y label, title, draw as connected line?)
panels = [
    # 1. Accuracy vs Rho
    (axes[0, 0], 'mean_rho', 'mean_acc', 'Mean Ricci (ρ)', 'Accuracy',
     'Accuracy vs Ricci Coefficient', False),
    # 2. r_all vs Depth
    (axes[0, 1], 'depth', 'r_all', 'Depth', 'r_all (Δg vs Ric)',
     'Correlation Strength vs Depth', True),
    # 3. Rho vs Depth
    (axes[1, 0], 'depth', 'mean_rho', 'Depth', 'Mean Ricci (ρ)',
     'Ricci Coefficient vs Depth', True),
    # 4. Accuracy vs Depth
    (axes[1, 1], 'depth', 'mean_acc', 'Depth', 'Accuracy',
     'Accuracy vs Depth', True),
]

for ax, x_col, y_col, x_label, y_label, title, as_line in panels:
    for arch, color in colors.items():
        # Rows are appended to the CSV in execution order; sort by depth so
        # line plots connect consecutive depths even if cells ran out of order.
        d = ricci_df[ricci_df['architecture'] == arch].sort_values('depth')
        if as_line:
            ax.plot(d[x_col], d[y_col], 'o-', c=color, label=arch)
        else:
            ax.scatter(d[x_col], d[y_col], c=color, label=arch, s=50)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'full_analysis.png'), dpi=150)
plt.show()