In [3]:
import torch
import torch.nn as nn

def list_leaf_layers(model: nn.Module):
    """
    Enumerate the leaf modules of ``model`` in registration order.

    A leaf is a submodule with no child modules of its own (e.g. Conv2d,
    ReLU, MaxPool2d, Linear, BatchNorm1d). The root module itself is never
    included, even when it has no children.

    Returns:
        list of (idx, name, module) tuples, where ``idx`` counts 0, 1, 2, ...
        over the returned leaves and ``name`` is the dotted path reported by
        ``model.named_modules()``.
    """
    leaves = [
        (qualified_name, module)
        for qualified_name, module in model.named_modules()
        if qualified_name != "" and next(module.children(), None) is None
    ]
    return [
        (position, qualified_name, module)
        for position, (qualified_name, module) in enumerate(leaves)
    ]

In [4]:
def print_layer_table(model: nn.Module, max_rows=80):
    """
    Print one row per leaf layer of ``model`` (index, dotted name, class
    name), showing at most ``max_rows`` rows.

    Returns:
        The full, untruncated list produced by ``list_leaf_layers(model)``.
    """
    all_layers = list_leaf_layers(model)
    rows = [
        f"{position:>3}  {qualified_name:<40}  {type(module).__name__}"
        for position, qualified_name, module in all_layers[:max_rows]
    ]
    for row in rows:
        print(row)
    return all_layers

In [5]:
class LayerIndexLogger:
    """
    Record the outputs of selected leaf layers during a forward pass.

    Layers are addressed by the index that ``list_leaf_layers(model)`` assigns
    them. A forward hook is registered on each picked layer; every time that
    layer runs, its output is detached, moved to CPU and stored in
    ``self.latest[idx]`` (overwriting the previous value).

    Can be used as a context manager so hooks are always removed:

        with LayerIndexLogger(model, [0, 6]) as logger:
            acts = logger.snapshot(x)
    """

    def __init__(self, model: nn.Module, picked_indices):
        """
        Args:
            model: the network to instrument.
            picked_indices: iterable of leaf-layer indices to hook.

        Raises:
            ValueError: if an index does not correspond to any leaf layer
                (it would otherwise silently never be recorded).
        """
        self.model = model
        self.picked = set(picked_indices)
        self.latest = {}
        self.handles = []

        layers = list_leaf_layers(model)
        self.idx_to_name = {idx: name for idx, name, _ in layers}

        unknown = sorted(self.picked.difference(self.idx_to_name))
        if unknown:
            raise ValueError(
                f"Unknown leaf-layer indices {unknown}; the model has "
                f"{len(layers)} leaf layers (valid indices 0..{len(layers) - 1})"
            )

        for idx, name, m in layers:
            if idx in self.picked:
                self.handles.append(m.register_forward_hook(self._make_hook(idx)))

    def _make_hook(self, idx):
        """Build a forward hook that stores the output of layer ``idx``."""
        def hook(mod, inp, out):
            self.latest[idx] = self._detach_to_cpu(out)
        return hook

    @staticmethod
    def _detach_to_cpu(out):
        """Detach and move a tensor, or a tuple/list of tensors, to CPU.

        Non-tensor values are returned unchanged.
        """
        if isinstance(out, torch.Tensor):
            return out.detach().cpu()
        if isinstance(out, list):
            return [LayerIndexLogger._detach_to_cpu(o) for o in out]
        if isinstance(out, tuple):
            return tuple(LayerIndexLogger._detach_to_cpu(o) for o in out)
        return out

    @torch.no_grad()
    def snapshot(self, x, eval_mode: bool = True):
        """
        Run one forward pass on ``x`` and return the captured outputs.

        Args:
            x: input passed directly to ``model(x)``.
            eval_mode: if True (default) the model runs in eval mode. Note that
                BatchNorm layers then normalize with their running statistics,
                which for a freshly initialized model are PyTorch's defaults
                (mean 0, var 1) and so barely change the activations. Pass
                False to run in train mode instead: BatchNorm uses batch
                statistics (and, by default, updates its running stats) and
                dropout is active.

        Returns:
            dict {layer_idx: output} for the picked layers that ran. It is a
            copy, so later forward passes through the hooked model do not
            modify it.

        The root module's previous train/eval mode is restored afterwards.
        """
        self.latest = {}
        was_training = self.model.training
        self.model.train(not eval_mode)
        try:
            _ = self.model(x)
        finally:
            # Don't leave the caller's model silently switched to eval mode.
            self.model.train(was_training)
        return dict(self.latest)

    def close(self):
        """Remove all registered hooks. Safe to call more than once."""
        for h in self.handles:
            h.remove()
        self.handles.clear()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

In [6]:
def to_BD(t: torch.Tensor):
    """
    Reshape an activation tensor to 2-D [B, D] for per-feature histograms.

    The first dimension is kept as the batch dimension and all remaining
    dimensions are flattened into D (e.g. [B, C, H, W] -> [B, C*H*W]).
    A 2-D tensor is returned unchanged; a 1-D tensor [B] is treated as a
    single feature and returned as [B, 1].

    Raises:
        ValueError: if ``t`` is a 0-d scalar (there is no batch dimension).
    """
    if t.dim() == 0:
        raise ValueError("to_BD expects a tensor with a batch dimension, got a 0-d scalar")
    if t.dim() == 1:
        # flatten(start_dim=1) raises "dimension out of range" on 1-D input.
        return t.unsqueeze(1)
    if t.dim() == 2:
        return t
    return t.flatten(start_dim=1)

In [40]:
import numpy as np
import matplotlib.pyplot as plt

def ridgeline_subplot(
    ax, A_BD, color,
    max_features=140,
    bins=120,
    xlim=(-8, 8),
    dy=0.012,
    lw=0.9,
    alpha=0.55,
    rng=None,
):
    """
    Draw a ridgeline plot of per-feature activation histograms on ``ax``.

    A random subset of up to ``max_features`` columns of ``A_BD`` is chosen
    and ordered by mean (lowest mean at the bottom). Each feature's density
    histogram over ``xlim`` is drawn as a filled area in ``color`` with a faint
    black outline, stacked ``dy`` above the previous ridge.

    Args:
        ax: matplotlib Axes to draw on.
        A_BD: np.ndarray of shape [B, D] (B samples, D features).
        color: fill colour used for every ridge.
        max_features: maximum number of features (columns) to draw.
        bins: number of histogram bins over ``xlim``.
        xlim: (low, high) histogram range and x-axis limits; values outside
            this range are ignored.
        dy: vertical offset between consecutive ridges.
        lw: outline line width.
        alpha: fill opacity.
        rng: randomness for feature selection. None (default) uses the global
            ``np.random`` state; an int seeds ``np.random.default_rng``; a
            ``np.random.Generator`` is used as is.
    """
    _, D = A_BD.shape
    if rng is None:
        chooser = np.random
    elif isinstance(rng, (int, np.integer)):
        chooser = np.random.default_rng(rng)
    else:
        chooser = rng
    idx = chooser.choice(D, size=min(D, max_features), replace=False)

    # Visual ordering: lower-mean features are stacked lower.
    idx = idx[np.argsort(A_BD[:, idx].mean(axis=0))]

    for k, j in enumerate(idx):
        counts, edges = np.histogram(A_BD[:, j], bins=bins, range=xlim)
        total = counts.sum()
        if total > 0:
            # Same normalization as np.histogram(..., density=True).
            hist = counts / np.diff(edges) / total
        else:
            # No value of this feature lies inside xlim: density=True would
            # produce 0/0 = NaN (plus a RuntimeWarning); draw a flat ridge.
            hist = np.zeros(len(counts), dtype=float)
        centers = (edges[:-1] + edges[1:]) / 2
        baseline = k * dy
        y = hist + baseline

        ax.fill_between(
            centers,
            baseline,
            y,
            color=color,
            alpha=alpha,
            linewidth=0
        )

        # Faint black outline on top of the filled area.
        ax.plot(
            centers,
            y,
            color="#000000",
            linewidth=lw,
            alpha=0.15
        )

    ax.set_xlim(*xlim)
    ax.set_yticks([])
    ax.grid(True, alpha=0.12)

In [None]:
def plot_layers_ridgeline_compare(
    acts_std, acts_bn,
    picked_indices,
    idx_to_name,
    save_path="layers_compare.png",
    color_std="#E4572E",
    color_bn="#1F77B4",
    max_features=140,
    bins=120,
    xlim=(-8, 8),
    dy=0.012,
    figsize=None,
    idx_to_name_bn=None,
):
    """
    Save a 2-column grid of ridgeline plots comparing two models layer by layer.

    Row r shows layer ``picked_indices[r]``: the left column draws
    ``acts_std`` ("Standard"), the right column ``acts_bn``
    ("Standard + BatchNorm"). Each activation tensor is flattened to [B, D]
    with ``to_BD`` and drawn with ``ridgeline_subplot``.

    Args:
        acts_std, acts_bn: dicts {layer_idx: tensor} (e.g. from
            ``LayerIndexLogger.snapshot``) containing every picked index.
        picked_indices: sequence of layer indices, one figure row each.
        idx_to_name: {layer_idx: name} used for the row labels.
        save_path: output file for the figure (saved at dpi=200).
        color_std, color_bn: fill colours of the two columns.
        max_features, bins, xlim, dy: forwarded to ``ridgeline_subplot``.
        figsize: figure size; defaults to (8, 4.5 * number_of_rows).
        idx_to_name_bn: optional {layer_idx: name} for the BN model. The two
            models generally have different leaf layers, so one index may
            name different layers; when given, the right column gets its own
            label.

    Returns:
        ``save_path``.

    Raises:
        ValueError: if ``picked_indices`` is empty.
        KeyError: if a picked index is missing from ``acts_std``/``acts_bn``.
    """
    picked_indices = list(picked_indices)
    n_rows = len(picked_indices)
    if n_rows == 0:
        raise ValueError("picked_indices must contain at least one layer index")
    if figsize is None:
        figsize = (8, 4.5 * n_rows)

    # squeeze=False always yields a 2-D [n_rows, 2] array of Axes.
    fig, axes = plt.subplots(n_rows, 2, figsize=figsize, sharex=True, squeeze=False)
    try:
        axes[0, 0].set_title("Standard")
        axes[0, 1].set_title("Standard + BatchNorm")

        for r, idx in enumerate(picked_indices):
            for dict_name, acts in (("acts_std", acts_std), ("acts_bn", acts_bn)):
                if idx not in acts:
                    raise KeyError(
                        f"layer index {idx} not found in {dict_name}; "
                        f"available indices: {sorted(acts)}"
                    )
            A_std = to_BD(acts_std[idx]).detach().cpu().numpy()
            A_bn = to_BD(acts_bn[idx]).detach().cpu().numpy()

            ridgeline_subplot(axes[r, 0], A_std, color_std, max_features=max_features, bins=bins, xlim=xlim, dy=dy)
            ridgeline_subplot(axes[r, 1], A_bn, color_bn, max_features=max_features, bins=bins, xlim=xlim, dy=dy)

            axes[r, 0].set_ylabel(f"Layer #{idx}  ({idx_to_name.get(idx, '???')})")
            if idx_to_name_bn is not None:
                axes[r, 1].set_ylabel(f"Layer #{idx}  ({idx_to_name_bn.get(idx, '???')})")

        axes[-1, 0].set_xlabel("Activation value")
        axes[-1, 1].set_xlabel("Activation value")

        fig.tight_layout()
        fig.savefig(save_path, dpi=200)
    finally:
        # Release the figure even if plotting/saving fails.
        plt.close(fig)
    return save_path

In [52]:
import sys
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn

# The notebook lives one level below the project root; build every path from
# that root instead of a machine-specific absolute path.
PROJECT_ROOT = Path("..").resolve()
if str(PROJECT_ROOT) not in sys.path:  # keep the cell idempotent on re-run
    sys.path.append(str(PROJECT_ROOT))

from data import get_dataloaders
from models.mlp import MLP
from models.cnn import ConvMLP

DATA_ROOT = PROJECT_ROOT / "data" / "cifar-10-batches-py"
FIGURE_PATH = PROJECT_ROOT / "image" / "layers_mlp.png"
SEED = 0

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(SEED)
# ridgeline_subplot() samples features with np.random, so seed NumPy too to
# make the saved figure reproducible.
np.random.seed(SEED)

# 1) Models
# model_std = ConvMLP(in_channels=3, input_size=32, use_bn=False).to(device)
# model_bn  = ConvMLP(in_channels=3, input_size=32, use_bn=True).to(device)

model_std = MLP(input_dim=3 * 32 * 32, hidden_dims=(512, 256, 128), use_bn=False).to(device)
model_bn = MLP(input_dim=3 * 32 * 32, hidden_dims=(512, 256, 128), use_bn=True).to(device)

# 2) OPTIONAL: print the table of leaf layers
print("=== Leaf layers (Standard) ===")
layers = print_layer_table(model_std, max_rows=120)

# 3) Choose layers.
# Leaf indices are model-specific: the BN model has extra BatchNorm leaves, so
# one index generally names different layers in the two models. Pick the k-th
# nn.Linear of each model instead; in the standard model ordinals [0, -1] give
# indices [0, 6] (see the table above). Edit LINEAR_ORDINALS to compare others.
LINEAR_ORDINALS = [0, -1]  # first Linear and last (output) Linear


def linear_leaf_indices(model: nn.Module, ordinals):
    """Leaf-layer indices of the k-th nn.Linear of ``model`` for each k in ``ordinals``."""
    linear_idx = [idx for idx, _, m in list_leaf_layers(model) if isinstance(m, nn.Linear)]
    return [linear_idx[k] for k in ordinals]


picked = linear_leaf_indices(model_std, LINEAR_ORDINALS)
picked_bn = linear_leaf_indices(model_bn, LINEAR_ORDINALS)

train_loader, test_loader = get_dataloaders(root=str(DATA_ROOT))

# 4) One fixed batch from the training loader
fixed_x, fixed_y = next(iter(train_loader))
fixed_x = fixed_x.to(device)

# 5) Snapshot activations; hooks are removed even if a forward pass fails.
# NOTE(review): snapshot() puts the models in eval mode. For freshly
# initialized models, BatchNorm layers (presumably nn.BatchNorm1d with default
# running stats — confirm in models/mlp.py) then normalize with their initial
# running mean 0 / var 1 rather than batch statistics, so they barely change
# the activations. Train the models first for a meaningful comparison.
logger_std = LayerIndexLogger(model_std, picked)
logger_bn = LayerIndexLogger(model_bn, picked_bn)
try:
    acts_std = logger_std.snapshot(fixed_x)
    acts_bn_raw = logger_bn.snapshot(fixed_x)
finally:
    logger_std.close()
    logger_bn.close()

# Re-key the BN activations by the standard model's indices so that row r of
# the figure compares the same Linear position in both models (row labels use
# the standard model's layer names).
acts_bn = {i_std: acts_bn_raw[i_bn] for i_std, i_bn in zip(picked, picked_bn)}

# 6) Plot ridgeline
FIGURE_PATH.parent.mkdir(parents=True, exist_ok=True)
out = plot_layers_ridgeline_compare(
    acts_std=acts_std,
    acts_bn=acts_bn,
    picked_indices=picked,
    idx_to_name=logger_std.idx_to_name,
    save_path=str(FIGURE_PATH),
    max_features=30,
    bins=30,
    xlim=(-6, 6),
    dy=0.05,
)
print("Saved:", out)

=== Leaf layers (Standard) ===
  0  net.0                                     Linear
  1  net.1                                     ReLU
  2  net.2                                     Linear
  3  net.3                                     ReLU
  4  net.4                                     Linear
  5  net.5                                     ReLU
  6  net.6                                     Linear
Saved: C:\Users\ryans\Desktop\projet_programation\Advanced_ML_BN_2025\image\layers_mlp.png
