# Feature map visualization (toy)
Load the baseline, pruned, and (optionally) retrained models from a run directory and save activation grids for the selected layers.

In [1]:
from pathlib import Path
import sys, json, torch

# Make the repository importable (the `toy` and `src` packages are imported
# below) without installing it. The membership check keeps re-running this
# cell from stacking duplicate entries on sys.path.
repo_root = Path("/mnt/hdd/ttoxopeus/basic_UNet")
repo_root_str = str(repo_root)
if repo_root_str not in sys.path:
    sys.path.insert(0, repo_root_str)

from toy.feature_maps import compare_feature_maps
from toy.datasets import get_synthetic_loaders, ShapesDatasetConfig
from src.models.unet import UNet  # allowlist for torch.load

# ---- configure your run (use a run created after full-model save fix) ----
# Directory of the run whose checkpoints and summary.json are used below.
run_dir = Path("/mnt/hdd/ttoxopeus/basic_UNet/toy/results/multiclass_tinyunet_20260127_141643")  # replace
baseline_ckpt = run_dir / "baseline.pth"           # state_dict
pruned_ckpt   = run_dir / "pruned_model_r0_3.pth"  # full model
retrained_ckpt = run_dir / "retrained_pruned_r0_3.pth"  # optional full model
device = torch.device("cpu")  # force CPU to avoid device mismatches during viz

# load model config from summary
# Use a context manager so the file handle is closed as soon as the JSON is
# parsed, instead of staying open until garbage collection (ResourceWarning).
with open(run_dir / "summary.json", "r", encoding="utf-8") as summary_file:
    summary = json.load(summary_file)
features = summary["model"]["features"]  # passed on as `features=` to compare_feature_maps
out_ch = summary["model"]["out_ch"]      # passed on as `out_ch=` to compare_feature_maps
print("features", features, "out_ch", out_ch)

features [2, 4, 8] out_ch 4


In [2]:
# ---- get a sample (synthetic val) ----
# Build the synthetic dataset config, then take the first batch of the
# loader at index 1 of what get_synthetic_loaders returns.
cfg = ShapesDatasetConfig(
    num_samples=200,
    image_size=64,
    mode="multiclass",
    fg_classes=3,
    seed=42,
)
loaders = get_synthetic_loaders(cfg=cfg, batch_size=1, val_ratio=0.2, num_workers=0)
val_loader = loaders[1]

# Only the first element of the batch is needed; the second is discarded.
first_batch = next(iter(val_loader))
sample_img, _ = first_batch
sample_img = sample_img.to(device)
print("sample_img", sample_img.shape, "device", sample_img.device)

sample_img torch.Size([1, 1, 64, 64]) device cpu


In [3]:
# ---- choose layers and export ----
# Dotted module names whose activations are exported.
layers = [
    "encoders.0.net.0",
    "encoders.1.net.0",
    "encoders.2.net.0",
    "bottleneck.net.0",
]
out_dir = run_dir / "feature_maps"

# Report which checkpoints are actually present before attempting the export.
print(
    "baseline exists", baseline_ckpt.exists(),
    "pruned exists", pruned_ckpt.exists(),
    "retrained exists", retrained_ckpt.exists(),
)
print("device:", device)

# The retrained checkpoint is optional: pass None when the file is absent.
if retrained_ckpt.exists():
    retrained_arg = retrained_ckpt
else:
    retrained_arg = None

compare_feature_maps(
    baseline_ckpt=baseline_ckpt,
    pruned_ckpt=pruned_ckpt,
    retrained_ckpt=retrained_arg,
    sample=sample_img,
    layers=layers,
    features=features,
    out_ch=out_ch,
    device=device,
    out_dir=out_dir,
    max_channels=16,
)
print("Saved feature maps to", out_dir)

baseline exists True pruned exists True retrained exists True
device: cpu
[DEBUG] loaded pruned_model_r0_3.pth type: <class 'src.models.unet.UNet'>
[DEBUG] loaded retrained_pruned_r0_3.pth type: <class 'src.models.unet.UNet'>
🔧 Generating pruning masks...

Block encoders.0      | Layer encoders.0.net.0          | ratio=0.25 | prune=0 → kept 2/2
Block encoders.0      | Layer encoders.0.net.3          | ratio=0.25 | prune=0 → kept 2/2
Block encoders.1      | Layer encoders.1.net.0          | ratio=0.25 | pruned 1/4 | kept 3/4
Block encoders.1      | Layer encoders.1.net.3          | ratio=0.25 | pruned 1/4 | kept 3/4
Block encoders.2      | Layer encoders.2.net.0          | ratio=0.25 | pruned 2/8 | kept 6/8
Block encoders.2      | Layer encoders.2.net.3          | ratio=0.25 | pruned 2/8 | kept 6/8
Block bottleneck      | Layer bottleneck.net.0          | ratio=0.25 | pruned 4/16 | kept 12/16
Block bottleneck      | Layer bottleneck.net.3          | ratio=0.25 | pruned 4/16 | kep