# LLA Demo Notebook

End-to-end demo for the Loss-Landscape-and-Analysis (LLA) pipeline in this repo.

What this notebook does:
- Loads the lightweight pipeline in `lla_pipeline.py`.
- Builds a small model/dataset from config.
- Runs a very short fine-tune to produce checkpoints.
- Computes Hessian axes, evaluates a loss plane, approximates a SAM surface,
  estimates a Hessian spectrum, evaluates a simple mode-connectivity curve,
  and collects basic rank statistics at checkpoints.

Outputs are written to `experiments/Hessian_and_landscape_plot_with_plasticity_loss/results/demo_notebook`.

In [1]:
from pathlib import Path
import json, os, time, sys
import numpy as np
import matplotlib.pyplot as plt
import torch

# Load YAML config with robust path resolution and set up imports
import yaml
# Try to resolve the experiment directory whether the working directory is repo root or the notebook's folder.
try:
    nb_dir = Path(__file__).resolve().parent
except NameError:
    nb_dir = Path.cwd()

# If we're already inside the experiment folder (where this notebook lives)
if (nb_dir / 'lla_pipeline.py').exists() and (nb_dir / 'cfg' / 'config.yaml').exists():
    exp_dir = nb_dir
# Otherwise, assume repo-root working dir and compose the experiment path
elif (nb_dir / 'experiments' / 'Hessian_and_landscape_plot_with_plasticity_loss').exists():
    exp_dir = nb_dir / 'experiments' / 'Hessian_and_landscape_plot_with_plasticity_loss'
elif (Path.cwd() / 'experiments' / 'Hessian_and_landscape_plot_with_plasticity_loss').exists():
    exp_dir = Path.cwd() / 'experiments' / 'Hessian_and_landscape_plot_with_plasticity_loss'
else:
    raise FileNotFoundError('Cannot locate experiment directory. Please run this notebook from the project workspace or the experiment folder.')

# Ensure we can import lla_pipeline.py directly
if str(exp_dir) not in sys.path:
    sys.path.insert(0, str(exp_dir))

# Import the pipeline utilities (from local file in this folder)
from lla_pipeline import (
    prepare_data_and_model,
    quick_finetune_and_checkpoint,
    get_hessian_axes,
    plot_plane,
    compute_spectrum,
    plot_sam_surface,
    fit_mode_connectivity,
    plot_mode_curve,
    compute_rank_stats_at_checkpoints,
)

cfg_path = exp_dir / 'cfg' / 'config.yaml'
with open(cfg_path, 'r') as f:
    cfg = yaml.safe_load(f)

# Light overrides for fast demo
cfg['seed'] = 123
cfg['lla']['training']['epochs_short'] = 0  # keep at 0 or 1 for speed
cfg['lla']['training']['max_checkpoints'] = 4
cfg['lla']['evaluation_data']['eval_batch_size'] = 128
cfg['lla']['planes']['grid_resolution'] = 21
cfg['lla']['spectrum']['top_k'] = 3
cfg['lla']['spectrum']['hutchinson_probes'] = 4

device = cfg.get('device', 'cuda:0' if torch.cuda.is_available() else 'cpu')
cfg['device'] = device
print('Using device:', device)

Using device: cuda:0


In [2]:
# Prepare data and model
train_loader, eval_loader, model = prepare_data_and_model(cfg)
print('Batches per epoch (train):', len(train_loader))
print('Eval batch size:', cfg['lla']['evaluation_data']['eval_batch_size'])

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Batches per epoch (train): 196
Eval batch size: 128
Batches per epoch (train): 196
Eval batch size: 128


In [4]:
# Create output directory for this demo run
out_dir = exp_dir / 'results' / 'demo_notebook'
out_dir.mkdir(parents=True, exist_ok=True)
print('Output dir:', out_dir)

# Short fine-tune and checkpoints
ckpts = quick_finetune_and_checkpoint(cfg, model, train_loader, out_dir)
print('Checkpoints:', [Path(p).name for p in ckpts])

# Load last checkpoint as base state
last_ckpt = ckpts[-1]
base = torch.load(last_ckpt, map_location=device)
base_state = base.get('state_dict', base)
model.load_state_dict({**model.state_dict(), **base_state}, strict=False)
print('Loaded base state from', Path(last_ckpt).name)

Output dir: /hdda/models/my_own_models/architectural_and_learning_on_loss_landscape/experiments/Hessian_and_landscape_plot_with_plasticity_loss/results/demo_notebook
Checkpoints: ['ckpt_final.pt']
Loaded base state from ckpt_final.pt


In [6]:
# Compute Hessian axes and plot a loss plane
v1, v2, meta = get_hessian_axes(model, eval_loader, cfg)
print('Hessian meta:', meta)
planes_dir = out_dir / 'planes'
planes_dir.mkdir(exist_ok=True)
plane_info = plot_plane(model, base_state, (v1, v2), eval_loader, cfg, planes_dir, label='hessian')
print('Plane info:', plane_info)

# Show the heatmap
img_path = Path(plane_info['png'])
img = plt.imread(img_path)
plt.figure(figsize=(5,4))
plt.imshow(img)
plt.axis('off')
plt.show()

Hessian meta: {'rayleigh_top1': 0.35403069853782654, 'rayleigh_top2': 0.29969140887260437, 'dim': 178250, 'power_iters': 20, 'algorithm': 'power_iteration+deflation', 'time_seconds': 0.28147125244140625}
[TIMING][plane:hessian] grid_eval_forward_passes=441 total_time=0.62s avg_per_point=0.0014s
Plane info: {'grid_resolution': 21, 'span': {'alpha': [-1.0, 1.0], 'beta': [-1.0, 1.0]}, 'npy': '/hdda/models/my_own_models/architectural_and_learning_on_loss_landscape/experiments/Hessian_and_landscape_plot_with_plasticity_loss/results/demo_notebook/planes/plane_hessian.npy', 'png': '/hdda/models/my_own_models/architectural_and_learning_on_loss_landscape/experiments/Hessian_and_landscape_plot_with_plasticity_loss/results/demo_notebook/planes/plane_hessian.png'}
[TIMING][plane:hessian] grid_eval_forward_passes=441 total_time=0.62s avg_per_point=0.0014s
Plane info: {'grid_resolution': 21, 'span': {'alpha': [-1.0, 1.0], 'beta': [-1.0, 1.0]}, 'npy': '/hdda/models/my_own_models/architectural_and_lea

In [7]:
# Hessian spectrum (top-k) + Hutchinson trace
spec = compute_spectrum(model, eval_loader, cfg, out_dir)
print(json.dumps(spec, indent=2))

[TIMING][spectrum] dim=178250 topk(k=3)=0.36s hutchinson(probes=4)=0.04s esd_slq(probes=8,m=50)=4.71s total=5.11s
[METRICS][spectrum] top2_eigs≈ [0.2517, 0.1399] (k=3)
{
  "top_k": [
    0.2516879737377167,
    0.13988856971263885,
    -0.14011691510677338
  ],
  "hutchinson_trace": -4.267018139362335,
  "dim": 178250,
  "k_used": 3,
  "probes": 4,
  "timing_seconds": {
    "topk": 0.36278820037841797,
    "hutchinson_trace": 0.03934645652770996,
    "esd_slq": 4.7093658447265625,
    "total": 5.11150050163269
  },
  "esd": {
    "n_grid": 200,
    "probes": 8,
    "m": 50,
    "lanczos_m": 50,
    "avg_steps": 50.0,
    "lam_min": -0.347922095656395,
    "lam_max": 0.3886063784360886,
    "sigma": 0.01339143680168152,
    "png": "/hdda/models/my_own_models/architectural_and_learning_on_loss_landscape/experiments/Hessian_and_landscape_plot_with_plasticity_loss/results/demo_notebook/spectrum_esd.png",
    "npy": "/hdda/models/my_own_models/architectural_and_learning_on_loss_landscape/ex

In [8]:
# SAM approximate surface on the same plane
sam_info = plot_sam_surface(model, base_state, (v1, v2), eval_loader, cfg, out_dir)
print(sam_info)
img = plt.imread(sam_info['png'])
plt.figure(figsize=(5,4))
plt.imshow(img)
plt.axis('off')
plt.show()

[TIMING][sam_surface] grid_eval_points=441 rho=0.05 total_time=1.63s avg_per_point=0.0037s (includes grad norm)
{'npy': '/hdda/models/my_own_models/architectural_and_learning_on_loss_landscape/experiments/Hessian_and_landscape_plot_with_plasticity_loss/results/demo_notebook/sam_surface.npy', 'png': '/hdda/models/my_own_models/architectural_and_learning_on_loss_landscape/experiments/Hessian_and_landscape_plot_with_plasticity_loss/results/demo_notebook/sam_surface.png', 'rho': 0.05}


In [9]:
# Mode connectivity curve between first and last ckpt
def _model_fn():
    from src.models.model_factory import model_factory
    return model_factory(cfg['net'])

A_state = torch.load(ckpts[0], map_location=device)
A = A_state.get('state_dict', A_state)
B_state = torch.load(ckpts[-1], map_location=device)
B = B_state.get('state_dict', B_state)
thetaM, _ = fit_mode_connectivity(_model_fn, A, B, cfg, out_dir)
mode_info = plot_mode_curve(_model_fn, A, thetaM, B, cfg, out_dir)
print(mode_info)
img = plt.imread(mode_info['png'])
plt.figure(figsize=(5,4))
plt.plot(np.load(mode_info['npy']))
plt.xlabel('t index')
plt.ylabel('loss')
plt.title('Mode connectivity (loss vs. t)')
plt.show()

AttributeError: 'dict' object has no attribute 'type'

In [None]:
# Rank stats across checkpoints (keep list short for speed)
subset_ckpts = ckpts[-3:] if len(ckpts) > 3 else ckpts
rank_stats = compute_rank_stats_at_checkpoints(cfg['net'], subset_ckpts, cfg, eval_loader)
print('Computed rank stats for', len(rank_stats['checkpoints']), 'checkpoints')
print(json.dumps(rank_stats['checkpoints'][0], indent=2)[:1000])  # preview first checkpoint only
# Save full JSON
(out_dir / 'rank_stats_demo.json').write_text(json.dumps(rank_stats, indent=2))
print('Saved to', out_dir / 'rank_stats_demo.json')

## Notes
- Adjust `grid_resolution`, `top_k`, `epochs_short`, and `eval_batch_size` above for fidelity vs. speed.
- All functions are lightweight and self-contained (no external repos).
- Artifacts are saved under the `results/demo_notebook` folder for easy review.

## 3D views of 2D loss surfaces

Below we render 3D surfaces for the Hessian-aligned plane, random-direction plane, random-base plane, and the SAM robust surface.
Each figure saves alongside the 2D heatmaps as *_3d.png.

In [None]:
# Render 3D: Hessian-aligned plane
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

plane_npy = planes_dir / 'plane_hessian.npy'
Z = np.load(plane_npy)
# alpha/beta extents from spec span in cfg (or infer from shape/grid)
res = Z.shape[0]
a0, a1 = -cfg['lla']['planes'].get('span', 1.0), cfg['lla']['planes'].get('span', 1.0)
b0, b1 = a0, a1
X, Y = np.meshgrid(np.linspace(a0, a1, res), np.linspace(b0, b1, res), indexing='ij')

fig = plt.figure(figsize=(6,5))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X, Y, Z, cmap='viridis', linewidth=0, antialiased=True)
ax.set_xlabel('alpha'); ax.set_ylabel('beta'); ax.set_zlabel('loss')
ax.set_title('Loss plane 3D: hessian')
plt.show()

In [None]:
# Render 3D: random-direction plane
Z = np.load(planes_dir / 'plane_random_dirs.npy')
res = Z.shape[0]
a0, a1 = -cfg['lla']['planes'].get('span', 1.0), cfg['lla']['planes'].get('span', 1.0)
b0, b1 = a0, a1
X, Y = np.meshgrid(np.linspace(a0, a1, res), np.linspace(b0, b1, res), indexing='ij')

fig = plt.figure(figsize=(6,5))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X, Y, Z, cmap='viridis', linewidth=0, antialiased=True)
ax.set_xlabel('alpha'); ax.set_ylabel('beta'); ax.set_zlabel('loss')
ax.set_title('Loss plane 3D: random_dirs')
plt.show()

In [None]:
# Render 3D: random-base plane
Z = np.load(planes_dir / 'plane_random_base.npy')
res = Z.shape[0]
a0, a1 = -cfg['lla']['planes'].get('span', 1.0), cfg['lla']['planes'].get('span', 1.0)
b0, b1 = a0, a1
X, Y = np.meshgrid(np.linspace(a0, a1, res), np.linspace(b0, b1, res), indexing='ij')

fig = plt.figure(figsize=(6,5))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X, Y, Z, cmap='viridis', linewidth=0, antialiased=True)
ax.set_xlabel('alpha'); ax.set_ylabel('beta'); ax.set_zlabel('loss')
ax.set_title('Loss plane 3D: random_base')
plt.show()

In [None]:
# Render 3D: SAM robust surface
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

Z = np.load(out_dir / 'sam_surface.npy')
res = Z.shape[0]
a0, a1 = -cfg['lla']['planes'].get('span', 1.0), cfg['lla']['planes'].get('span', 1.0)
b0, b1 = a0, a1
X, Y = np.meshgrid(np.linspace(a0, a1, res), np.linspace(b0, b1, res), indexing='ij')

fig = plt.figure(figsize=(6,5))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X, Y, Z, cmap='magma', linewidth=0, antialiased=True)
ax.set_xlabel('alpha'); ax.set_ylabel('beta'); ax.set_zlabel('robust loss')
ax.set_title('SAM approx surface (3D)')
plt.show()