# Q6 (Bonus): Neural Collapse on Earlier Layers (NC1–NC5)

This notebook studies how **all five Neural Collapse properties** emerge progressively through network depth.

| Metric | What it measures | Where computed |
|--------|-----------------|----------------|
| **NC1** | Within-class variability collapse | ALL layers |
| **NC2** equinorm | Equality of class-mean norms | ALL layers |
| **NC2** equiangularity | Simplex ETF structure | ALL layers |
| **NC3** | Self-duality (W ≈ M) | Penultimate only (needs classifier W) |
| **NC4** | Agreement of the nearest-class-center (NCC) classifier with the network | ALL layers |
| **NC5** | ID/OOD orthogonality | ALL layers (needs OOD data) |
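For reference, NC1 as defined by Papyan et al. (2020) is $\mathrm{Tr}(\Sigma_W \Sigma_B^{\dagger})/C$, where $\Sigma_W$ is the within-class covariance and $\Sigma_B$ the between-class covariance of the features. The NumPy sketch below is a minimal, hypothetical re-implementation (`nc1_metric` is not part of `src.neural_collapse`). It assumes balanced classes, every class present, and features already pooled to shape (N, D); the project's module may differ in normalisation details or pseudo-inverse cutoff.

```python
import numpy as np

def nc1_metric(features: np.ndarray, labels: np.ndarray, num_classes: int) -> float:
    """NC1 = Tr(Sigma_W @ pinv(Sigma_B)) / C (Papyan et al., 2020).

    features: (N, D) pooled activations; labels: (N,) integer class ids.
    Assumes every class in range(num_classes) has at least one sample.
    """
    n, d = features.shape
    global_mean = features.mean(axis=0)                      # (D,)
    class_means = np.zeros((num_classes, d))
    sigma_w = np.zeros((d, d))
    for c in range(num_classes):
        fc = features[labels == c]
        class_means[c] = fc.mean(axis=0)
        centered = fc - class_means[c]
        sigma_w += centered.T @ centered                     # within-class scatter
    sigma_w /= n                                             # within-class covariance
    m_tilde = class_means - global_mean                      # (C, D) centered class means
    sigma_b = m_tilde.T @ m_tilde / num_classes              # between-class covariance
    # Sigma_B has rank <= min(C - 1, D); the cutoff discards its numerical null space
    return float(np.trace(sigma_w @ np.linalg.pinv(sigma_b, rcond=1e-10)) / num_classes)
```

Lower is more collapsed: perfectly collapsed features (every sample equal to its class mean) give NC1 = 0.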

Key insight (Papyan et al. 2020; Rangamani et al. 2023):
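- NC forms **last-to-first**: the penultimate layer collapses first
- Collapse then propagates backward through the network during extended training
- Earlier layers may never fully collapse, especially when the feature dimension D is smaller than the number of classes C, because a simplex ETF of C class means needs at least C − 1 dimensions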
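The D < C caveat follows from the geometry of the simplex ETF: C unit vectors with pairwise cosine −1/(C−1) sum to zero, so they span only C − 1 dimensions and cannot fit in a space with D < C − 1. The hypothetical helpers below (not the project's API) build an ETF and compute two NC2 statistics. For equiangularity they use Papyan et al.'s "maximal-angle" variant, the mean of |cos + 1/(C−1)| over class pairs, which may not be the variant `nc_earlier_layer` reports.

```python
import numpy as np

def simplex_etf(num_classes: int) -> np.ndarray:
    """(C, C) matrix whose rows are unit vectors with pairwise cosine -1/(C-1)."""
    c = num_classes
    return np.sqrt(c / (c - 1)) * (np.eye(c) - np.ones((c, c)) / c)

def nc2_metrics(class_means: np.ndarray) -> tuple[float, float]:
    """Return (equinorm, equiangularity) for (C, D) class means.

    equinorm:       std / mean of the norms of the centered class means.
    equiangularity: mean over pairs c != c' of |cos(mu_c, mu_c') + 1/(C-1)|.
    Both are 0 for a perfect simplex ETF.
    """
    c = class_means.shape[0]
    centered = class_means - class_means.mean(axis=0)
    norms = np.linalg.norm(centered, axis=1)
    equinorm = float(norms.std() / norms.mean())
    unit = centered / norms[:, None]
    cos = unit @ unit.T
    off_diag = cos[~np.eye(c, dtype=bool)]
    equiangularity = float(np.mean(np.abs(off_diag + 1.0 / (c - 1))))
    return equinorm, equiangularity

# An ETF for C classes has rank C - 1, so 100 classes need at least 99 dimensions
print(np.linalg.matrix_rank(simplex_etf(100)))  # 99
```

With C = 100, `layer1` (D = 64) therefore cannot host an exact ETF, while `layer2` onward (D ≥ 128) can.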

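References:
> Papyan et al., *"Prevalence of Neural Collapse during the Terminal Phase of Deep Learning Training"*, PNAS 2020.
>
> Rangamani et al., *"Feature Learning in Deep Classifiers through Intermediate Neural Collapse"*, ICML 2023.
>
> Ben Ammar et al., *"NECO: Neural Collapse Based Out-of-Distribution Detection"*, ICLR 2024.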

## Setup

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
if not os.path.exists('/content/OOD-Detection-Project---CSC_5IA23'):
    # `git clone` needs the repository URL; the branch is selected with --branch
    !git clone --branch contente https://github.com/DiegoFleury/OOD-Detection-Project---CSC_5IA23.git
%cd /content/OOD-Detection-Project---CSC_5IA23

In [None]:
!pip install -q torch torchvision matplotlib seaborn scikit-learn pyyaml imageio tqdm

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import yaml
import glob
import re
import os

from src.models import ResNet18
from src.data import get_cifar100_loaders

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
if device == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
with open('configs/config.yaml', 'r') as f:
    config = yaml.safe_load(f)

print("Configuration:")
print(yaml.dump(config, default_flow_style=False))

## 1. Load Data

- **CIFAR-100** (ID): for NC1–NC4
- **SVHN** (OOD): for NC5

In [None]:
# ID data: CIFAR-100
print("Loading CIFAR-100 (ID)...")

train_loader, val_loader, test_loader = get_cifar100_loaders(
    data_dir=config['data']['data_dir'],
    batch_size=config['training']['batch_size'],
    num_workers=config['data']['num_workers'],
    augment=False,
    val_split=config['training']['val_split']
)

print(f"Train batches: {len(train_loader)}")
print(f"Test batches:  {len(test_loader)}")

In [None]:
# OOD data: SVHN (for NC5 computation at each layer)
import torchvision
import torchvision.transforms as transforms

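# Normalize with the CIFAR-100 (ID) statistics so OOD images get the same preprocessing as ID images.
# SVHN images are already 32×32, so Resize(32) is only a safeguard.
ood_transform = transforms.Compose([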
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5071, 0.4867, 0.4408],
        std=[0.2675, 0.2565, 0.2761]
    ),
])

print("Loading SVHN (OOD for NC5)...")
svhn_dataset = torchvision.datasets.SVHN(
    root=config['data']['data_dir'], split='test',
    transform=ood_transform, download=True,
)
svhn_loader = torch.utils.data.DataLoader(
    svhn_dataset, batch_size=config['training']['batch_size'],
    shuffle=False, num_workers=config['data']['num_workers'],
)
print(f"SVHN test samples: {len(svhn_dataset)}")
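NC5 (Ben Ammar et al., 2024) measures how close the ID class means are to orthogonal to the mean of the OOD features. A minimal sketch, assuming the OrthoDev form $\mathrm{Avg}_c\,|\cos(\mu_c, \mu_G^{\mathrm{OOD}})|$ with uncentered means; the centering convention used by `nc_earlier_layer` is not stated in this notebook, so treat this as an illustration rather than the project's implementation. `nc5_orthodev` is a hypothetical name.

```python
import numpy as np

def nc5_orthodev(id_features: np.ndarray, id_labels: np.ndarray,
                 ood_features: np.ndarray, num_classes: int) -> float:
    """Average |cosine| between each ID class mean and the global OOD mean.

    0 means every ID class mean is orthogonal to the OOD mean (strong NC5);
    1 means every class mean is (anti-)aligned with it.
    Means are NOT centered here; that convention is an assumption.
    """
    class_means = np.stack([id_features[id_labels == c].mean(axis=0)
                            for c in range(num_classes)])      # (C, D)
    ood_mean = ood_features.mean(axis=0)                      # (D,)
    cos = class_means @ ood_mean / (
        np.linalg.norm(class_means, axis=1) * np.linalg.norm(ood_mean))
    return float(np.mean(np.abs(cos)))
```

Because only the OOD *mean* enters the formula, the SVHN loader does not need labels.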

## 2. Load Model

In [None]:
model = ResNet18(num_classes=config['model']['num_classes'])

checkpoint_dir = config['paths']['checkpoints']
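checkpoints = glob.glob(os.path.join(checkpoint_dir, 'resnet18_cifar100_*.pth'))
if not checkpoints:
    raise FileNotFoundError(f"No 'resnet18_cifar100_*.pth' checkpoints found in {checkpoint_dir}")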

def get_epoch_num(path):
    match = re.search(r'epoch(\d+)', path)
    return int(match.group(1)) if match else 0

latest = max(checkpoints, key=get_epoch_num)
epoch_num = get_epoch_num(latest)

ckpt = torch.load(latest, map_location=device, weights_only=False)
if isinstance(ckpt, dict) and 'model_state_dict' in ckpt:
    model.load_state_dict(ckpt['model_state_dict'])
elif isinstance(ckpt, dict) and 'state_dict' in ckpt:
    model.load_state_dict(ckpt['state_dict'])
else:
    model.load_state_dict(ckpt)

model = model.to(device)
model.eval()

print(f"✅ Loaded: {os.path.basename(latest)} (epoch {epoch_num})")

## 3. Import Module

In [None]:
from src.neural_collapse.nc_earlier_layer import (
    analyze_layers_single_checkpoint,
    analyze_layers_across_checkpoints,
    plot_nc_by_layer,
    plot_nc_layers_across_epochs,
    plot_nc_heatmap,
    save_layer_metrics_yaml,
    LayerNCResult,
    LayerNCTracker,
)

print("✅ NC earlier layer module imported!")

In [None]:
figures_dir = os.path.join(config['paths']['figures'], 'nc_layers')
metrics_dir = config['paths']['metrics']
os.makedirs(figures_dir, exist_ok=True)
os.makedirs(metrics_dir, exist_ok=True)

## 4. NC1–NC5 Across Layers (Final Checkpoint)

ResNet-18 architecture:

| Layer | Feature dim D | Spatial size | D vs C = 100 |
|-------|--------------|--------------|--------------|
| layer1 | 64 | 32×32 | D < C ⚠️ NC1 ill-conditioned |
| layer2 | 128 | 16×16 | D > C |
| layer3 | 256 | 8×8 | D > C |
| layer4 | 512 | 4×4 | D > C |
| penultimate | 512 | 1×1 (after GAP) | D > C |

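Features at each layer are global-average-pooled (GAP) to shape (B, D) before the metrics are computed, and NC5 uses SVHN as the OOD dataset.
If `ResNet18` follows the common CIFAR design, in which the penultimate features are exactly the GAP of `layer4`'s output, the `layer4` and `penultimate` rows should be (nearly) identical.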
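Global average pooling turns a (B, D, H, W) feature map into a (B, D) matrix by averaging over the spatial axes, so every layer yields one D-dimensional vector per image. The NumPy sketch below only illustrates the shapes involved; `global_avg_pool` is a hypothetical helper, and in PyTorch the equivalent is `feature_map.mean(dim=(2, 3))`, for instance inside a hook registered with `register_forward_hook`.

```python
import numpy as np

def global_avg_pool(feature_map: np.ndarray) -> np.ndarray:
    """Average a (B, D, H, W) activation tensor over H and W, giving (B, D)."""
    if feature_map.ndim != 4:
        raise ValueError(f"expected (B, D, H, W), got shape {feature_map.shape}")
    return feature_map.mean(axis=(2, 3))

# Shapes from the table above, for a batch of 2 images
for name, (d, hw) in {"layer1": (64, 32), "layer2": (128, 16),
                      "layer3": (256, 8), "layer4": (512, 4)}.items():
    pooled = global_avg_pool(np.ones((2, d, hw, hw)))
    print(name, pooled.shape)
```

Pooling discards spatial layout, so the earlier-layer metrics describe per-channel averages rather than full activation maps.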

In [None]:
print("🔬 Computing NC1–NC5 at each layer (final checkpoint)...")
print(f"   Model: ResNet-18 / CIFAR-100 (epoch {epoch_num})")
print(f"   OOD: SVHN (for NC5)")
print(f"   Device: {device}")
print()

layer_results = analyze_layers_single_checkpoint(
    model=model,
    loader=train_loader,
    device=device,
    num_classes=config['model']['num_classes'],
    layer_names=['layer1', 'layer2', 'layer3', 'layer4'],
    include_penultimate=True,
    ood_loader=svhn_loader,   # ← enables NC5 at every layer
)

# Print full table
print(f"{'Layer':<14s} {'D':>4s} {'NC1':>9s} {'NC2norm':>9s} {'NC2ang':>9s}"
      f" {'NC3':>9s} {'NC4':>9s} {'NC5':>9s}")
print("-" * 80)
for r in layer_results:
    nc3 = f"{r.nc3_w_m_dist:.4f}" if r.nc3_w_m_dist is not None else "    -"
    nc4 = f"{r.nc4_ncc_mismatch:.4f}" if r.nc4_ncc_mismatch is not None else "    -"
    nc5 = f"{r.nc5_orthodev:.4f}" if r.nc5_orthodev is not None else "    -"
    print(f"{r.layer_name:<14s} {r.feature_dim:>4d} {r.nc1:>9.4f} "
          f"{r.nc2_equinorm:>9.4f} {r.nc2_equiangularity:>9.4f} "
          f"{nc3:>9s} {nc4:>9s} {nc5:>9s}")

In [None]:
# Full bar chart: NC1–NC5 per layer (2×3 grid)
fig_bars = plot_nc_by_layer(
    layer_results,
    title_suffix=f" - Epoch {epoch_num}",
    save_dir=figures_dir,
)
plt.show()

### 4.1 Interpretation

In [None]:
print("\n" + "=" * 70)
print("LAYER-WISE NC1‚ÄìNC5 ANALYSIS")
print("=" * 70)

# NC1 progression
print("\n--- NC1: Activation Collapse (shallow → deep) ---")
max_nc1 = max(r.nc1 for r in layer_results)
for r in layer_results:
    # Bar length is proportional to NC1 relative to the largest value (at most 30 chars)
    bar_len = min(int(r.nc1 / max_nc1 * 30), 30) if max_nc1 > 0 else 0
    bar = "█" * bar_len
    note = "  ⚠️ D < C" if r.feature_dim < config['model']['num_classes'] else ""
    print(f"  {r.layer_name:<14s} (D={r.feature_dim:>3d}) NC1={r.nc1:>8.2f}  {bar}{note}")

# NC4 progression
print("\n--- NC4: NCC Agreement (shallow ‚Üí deep) ---")
for r in layer_results:
    if r.nc4_ncc_mismatch is not None:
        agreement = (1 - r.nc4_ncc_mismatch) * 100
        bar = "█" * int(agreement / 100 * 30)
        print(f"  {r.layer_name:<14s} NCC agrees with network: {agreement:.1f}%  {bar}")

# NC5 progression
print("\n--- NC5: ID/OOD Orthogonality (shallow → deep) ---")
for r in layer_results:
    if r.nc5_orthodev is not None:
        print(f"  {r.layer_name:<14s} OrthoDev = {r.nc5_orthodev:.4f}")

# NC3 (penultimate only)
pen = next((r for r in layer_results if r.layer_name == 'penultimate'), None)
if pen and pen.nc3_w_m_dist is not None:
    print(f"\n--- NC3: Self-Duality (penultimate only) ---")
    print(f"  ‖W^T − M̃‖² = {pen.nc3_w_m_dist:.4f}")

# Monotonicity check
nc1_list = [r.nc1 for r in layer_results]
monotonic = all(nc1_list[i] >= nc1_list[i+1] for i in range(len(nc1_list)-1))
print(f"\n{'✅' if monotonic else '❌'} NC1 monotonically decreasing: {monotonic}")

ratio = layer_results[0].nc1 / layer_results[-1].nc1 if layer_results[-1].nc1 > 0 else float('inf')
print(f"   NC1 ratio (shallowest / deepest): {ratio:.1f}x")

print("\n" + "=" * 70)
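The NC3 and NC4 numbers printed above can be reproduced in spirit with the hypothetical helpers below. NC3 compares the Frobenius-normalised classifier weights W (shape (C, D), as in `torch.nn.Linear.weight`) with the normalised centered class means M̃; NC4 is the fraction of samples on which the nearest-class-center rule disagrees with the network's prediction. This is a sketch under those assumptions, not the code in `nc_earlier_layer`; it centers the class means by their average, which equals the global mean only for balanced classes.

```python
import numpy as np

def nc3_self_duality(W: np.ndarray, class_means: np.ndarray) -> float:
    """|| W/||W||_F - M~/||M~||_F ||_F^2 with W and class_means both of shape (C, D)."""
    m_tilde = class_means - class_means.mean(axis=0)      # centered class means
    w_normed = W / np.linalg.norm(W)                      # Frobenius norm for 2-D input
    m_normed = m_tilde / np.linalg.norm(m_tilde)
    return float(np.linalg.norm(w_normed - m_normed) ** 2)

def nc4_ncc_mismatch(features: np.ndarray, class_means: np.ndarray,
                     network_preds: np.ndarray) -> float:
    """Fraction of samples whose nearest class mean disagrees with the network's argmax."""
    # (N, C) Euclidean distances from each sample to each class mean
    dists = np.linalg.norm(features[:, None, :] - class_means[None, :, :], axis=2)
    ncc_preds = dists.argmin(axis=1)
    return float(np.mean(ncc_preds != network_preds))
```

The broadcasted (N, C, D) distance tensor is fine for a sanity check, but for 50k CIFAR-100 samples at D = 512 it should be computed in chunks over N.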

## 5. NC Propagation Across Training Epochs

We compute NC1–NC5 at each layer for every checkpoint.
This reveals the **temporal dynamics** of how collapse propagates backward.

In [None]:
print("🔬 Analysing NC1–NC5 across layers AND epochs...")
print("   (layer1–layer4 + penultimate per checkpoint, plus an SVHN pass for NC5)")
print()

tracker = analyze_layers_across_checkpoints(
    checkpoint_dir=checkpoint_dir,
    model_class=ResNet18,
    loader=train_loader,
    device=device,
    num_classes=config['model']['num_classes'],
    layer_names=['layer1', 'layer2', 'layer3', 'layer4'],
    ood_loader=svhn_loader,  # ← enables NC5 tracking across epochs
    checkpoint_pattern='resnet18_cifar100_*.pth',
    epoch_regex=r'epoch(\d+)',
    verbose=True,
)

print("\n" + tracker.summary())

### 5.1 Line Plots: NC per Layer Across Epochs

Each line is one layer. If collapse propagates backward, `penultimate` should collapse first and `layer1` last.

In [None]:
fig_lines = plot_nc_layers_across_epochs(
    tracker, save_dir=figures_dir,
)
plt.show()

### 5.2 Heatmaps: Collapse Propagation

Each heatmap has one row per layer and one column per epoch; darker colours mean stronger collapse.
If collapse propagates backward, the bottom rows (deep layers) should darken first.
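To customise these plots (for example with a log colour scale, since NC1 can span several orders of magnitude), the layer × epoch matrix can be assembled directly from `tracker.data`. The sketch assumes, as the collapse-timeline cell in Section 5.3 does, that `tracker.data[layer][metric]` is a per-epoch list; `metric_matrix` is a hypothetical helper, not part of the project module.

```python
import numpy as np

def metric_matrix(data: dict, layer_names: list, metric: str) -> np.ndarray:
    """Stack per-epoch metric lists into a (num_layers, num_epochs) array.

    Missing entries (None, or shorter series) become NaN so they render blank.
    """
    rows = [[np.nan if v is None else float(v) for v in data[layer].get(metric, [])]
            for layer in layer_names]
    width = max((len(r) for r in rows), default=0)
    return np.array([r + [np.nan] * (width - len(r)) for r in rows], dtype=float)
```

For example, `plt.imshow(np.log10(metric_matrix(tracker.data, tracker.layer_names, 'nc1')), aspect='auto')` would give a log-scaled NC1 heatmap.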

In [None]:
# NC1 heatmap
fig_hm1 = plot_nc_heatmap(tracker, metric='nc1', save_dir=figures_dir)
plt.show()

In [None]:
# NC4 heatmap: NCC mismatch across layers
fig_hm_nc4 = plot_nc_heatmap(tracker, metric='nc4_ncc_mismatch', save_dir=figures_dir)
plt.show()

In [None]:
# NC5 heatmap: ID/OOD orthogonality across layers
fig_hm_nc5 = plot_nc_heatmap(tracker, metric='nc5_orthodev', save_dir=figures_dir)
plt.show()

In [None]:
# NC2 equiangularity heatmap
fig_hm2 = plot_nc_heatmap(tracker, metric='nc2_equiangularity', save_dir=figures_dir)
plt.show()

In [None]:
# NC2 equinorm heatmap
fig_hm3 = plot_nc_heatmap(tracker, metric='nc2_equinorm', save_dir=figures_dir)
plt.show()

### 5.3 When Does Each Layer Collapse?

In [None]:
print("\n" + "=" * 60)
print("COLLAPSE TIMELINE")
print("=" * 60)

pen_final_nc1 = tracker.data['penultimate']['nc1'][-1]
threshold = pen_final_nc1 * 2
print(f"\nThreshold: NC1 < {threshold:.2f} (= 2 × penultimate final NC1)\n")

for layer in tracker.layer_names:
    nc1_series = tracker.data[layer]['nc1']
    collapse_epoch = None
    for i, val in enumerate(nc1_series):
        if val < threshold:
            collapse_epoch = tracker.epochs[i]
            break

    if collapse_epoch is not None:
        print(f"  {layer:<14s}: NC1 < {threshold:.2f} at epoch {collapse_epoch}")
    else:
        final = nc1_series[-1] if nc1_series else float('nan')
        print(f"  {layer:<14s}: never reached threshold (final NC1 = {final:.2f})")

# NC4 timeline: when does NCC agree >90% with network?
print(f"\nNC4 timeline: when does NCC agreement > 90%?\n")
for layer in tracker.layer_names:
    nc4_series = tracker.data[layer]['nc4_ncc_mismatch']
    agree_epoch = None
    for i, val in enumerate(nc4_series):
        if not np.isnan(val) and val < 0.10:  # >90% agreement
            agree_epoch = tracker.epochs[i]
            break
    if agree_epoch is not None:
        print(f"  {layer:<14s}: >90% NCC agreement at epoch {agree_epoch}")
    else:
        final = nc4_series[-1] if nc4_series else float('nan')
        agreement = (1 - final) * 100 if not np.isnan(final) else float('nan')
        print(f"  {layer:<14s}: never reached 90% (final: {agreement:.1f}%)")

print("\n" + "=" * 60)
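Both loops in the timeline cell apply the same rule: report the first epoch at which a metric drops below a threshold. A hypothetical helper (not part of the project module) that factors this out is sketched below, with an optional `sustained` flag that also requires the metric to stay below the threshold at every later checkpoint, since a single dip can be noise.

```python
import math
from typing import Optional, Sequence

def _below(value, threshold: float) -> bool:
    """True if value is a real number (not None/NaN) strictly below threshold."""
    return value is not None and not math.isnan(value) and value < threshold

def first_epoch_below(series: Sequence[float], epochs: Sequence[int],
                      threshold: float, sustained: bool = False) -> Optional[int]:
    """First epoch whose value is below `threshold`, or None if it never is.

    With sustained=True, the value must also stay below the threshold at
    every later checkpoint, so a transient dip is not reported as collapse.
    """
    values = list(series)
    for i, (epoch, value) in enumerate(zip(epochs, values)):
        if _below(value, threshold) and (
                not sustained or all(_below(v, threshold) for v in values[i:])):
            return epoch
    return None
```

For example, `first_epoch_below(tracker.data[layer]['nc1'], tracker.epochs, threshold)` matches the NC1 loop, and passing `tracker.data[layer]['nc4_ncc_mismatch']` with `0.10` matches the NC4 loop.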

## 6. Save Metrics

In [None]:
save_layer_metrics_yaml(
    tracker,
    os.path.join(metrics_dir, 'nc_earlier_layers_metrics.yaml'),
)

# Single-checkpoint results as JSON
import json

single_results = []
for r in layer_results:
    # Cast to built-in types: json cannot serialise NumPy float32/int64 values or torch scalars
    d = {
        'layer': r.layer_name,
        'feature_dim': int(r.feature_dim),
        'nc1': float(r.nc1),
        'nc2_equinorm': float(r.nc2_equinorm),
        'nc2_equiangularity': float(r.nc2_equiangularity),
    }
    if r.nc3_w_m_dist is not None:
        d['nc3_w_m_dist'] = float(r.nc3_w_m_dist)
    if r.nc4_ncc_mismatch is not None:
        d['nc4_ncc_mismatch'] = float(r.nc4_ncc_mismatch)
    if r.nc5_orthodev is not None:
        d['nc5_orthodev'] = float(r.nc5_orthodev)
    single_results.append(d)

json_path = os.path.join(metrics_dir, 'nc_earlier_layers_final.json')
with open(json_path, 'w') as f:
    json.dump({
        'epoch': epoch_num,
        'ood_dataset': 'SVHN',
        'layers': single_results,
    }, f, indent=2)
print(f"💾 Saved: {json_path}")

## 7. Final Summary

In [None]:
print("\n" + "=" * 70)
print("EARLIER LAYERS NC ANALYSIS - FINAL SUMMARY")
print("=" * 70)

print(f"\nModel: ResNet-18 / CIFAR-100 | OOD: SVHN")
print(f"Checkpoints analyzed: {len(tracker.epochs)}")
print(f"Layers: {', '.join(tracker.layer_names)}")
print(f"Metrics: NC1, NC2 (equinorm + equiangularity), NC3, NC4, NC5")

print(f"\n--- Final metrics per layer (epoch {epoch_num}) ---")
print(f"{'Layer':<14s} {'D':>4s} {'NC1':>8s} {'NC4':>8s} {'NC5':>8s}")
print("-" * 48)
for r in layer_results:
    nc4 = f"{r.nc4_ncc_mismatch:.4f}" if r.nc4_ncc_mismatch is not None else "  -"
    nc5 = f"{r.nc5_orthodev:.4f}" if r.nc5_orthodev is not None else "  -"
    print(f"{r.layer_name:<14s} {r.feature_dim:>4d} {r.nc1:>8.4f} {nc4:>8s} {nc5:>8s}")

# Key observations
print(f"\n--- Key observations ---")
nc1_list = [r.nc1 for r in layer_results]
monotonic = all(nc1_list[i] >= nc1_list[i+1] for i in range(len(nc1_list)-1))
print(f"  NC1 monotonically decreasing (shallow→deep): {'✅ Yes' if monotonic else '❌ No'}")

nc4_vals = [(r.layer_name, r.nc4_ncc_mismatch) for r in layer_results if r.nc4_ncc_mismatch is not None]
if nc4_vals:
    best_ncc = min(nc4_vals, key=lambda x: x[1])
    print(f"  Best NCC agreement: {best_ncc[0]} ({(1-best_ncc[1])*100:.1f}%)")

nc5_vals = [(r.layer_name, r.nc5_orthodev) for r in layer_results if r.nc5_orthodev is not None]
if nc5_vals:
    best_nc5 = min(nc5_vals, key=lambda x: x[1])
    print(f"  Best ID/OOD orthogonality: {best_nc5[0]} (NC5={best_nc5[1]:.4f})")

print(f"\n--- Files saved ---")
print(f"  Figures: {figures_dir}/")
print(f"  Metrics: {metrics_dir}/")

print("\n" + "=" * 70)

## 8. Commit Results to GitHub

In [None]:
# !git add results/figures/nc_layers/
# !git add results/metrics/nc_earlier_layers_*
# !git commit -m "Add Q6 bonus: NC1-NC5 on earlier layers"
# !git push
#
# print("Results committed to GitHub!")