In [1]:
import os
import json
import sys
import pandas as pd
import torch

# Root of the basic_UNet checkout; the `src.*` imports below resolve relative to it.
# Set the BASIC_UNET_ROOT environment variable to run on another machine instead
# of editing this hard-coded default.
PROJECT_ROOT = os.environ.get("BASIC_UNET_ROOT", "/mnt/hdd/ttoxopeus/basic_UNet")

# Prepend (rather than append) so the project's generically named `src` package
# takes precedence over any installed distribution that also ships a top-level
# `src`. The membership check keeps the cell idempotent when re-run.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
    print(f"✅ Added project root to sys.path:\n   {PROJECT_ROOT}")
else:
    print(f"ℹ️ Project root already in sys.path:\n   {PROJECT_ROOT}")


from src.training.eval import evaluate
from src.models.unet import UNet
from src.pruning.model_inspect import model_to_dataframe_with_l1, get_pruning_masks_blockwise, compute_actual_prune_ratios, compute_l1_norms, compute_l1_stats, inspect_model_l1
from src.pruning.rebuild import rebuild_pruned_unet, find_prev_conv_name

✅ Added project root to sys.path:
   /mnt/hdd/ttoxopeus/basic_UNet


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the model: build the UNet and restore the exp48 baseline checkpoint.
# The constructor arguments must match the ones used at training time, because
# load_state_dict is strict by default and raises on missing/unexpected keys.
# NOTE(review): out_ch=4 presumably matches the ACDC label count — confirm.
CHECKPOINT_PATH = os.path.join(
    PROJECT_ROOT, "results", "UNet_ACDC", "exp48", "baseline", "training", "final_model.pth"
)

model = UNet(in_ch=1, out_ch=4, enc_features=[64, 128, 256, 512, 512], dec_features=None, bottleneck_out=None)
# The file is a plain state_dict (it is passed straight to load_state_dict), so
# weights_only=True is sufficient and prevents a tampered .pth from running
# arbitrary pickle code.
state = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=True)
model.load_state_dict(state)
model.eval()  # inference mode, e.g. BatchNorm layers use running statistics
print("✅ Model loaded successfully.")

✅ Model loaded successfully.


In [3]:
# Per-layer L1-norm summary of the loaded model.
# save_dir is relative to the notebook's working directory; presumably this is
# where inspect_model_l1 writes its artefacts — confirm.
df_summary = inspect_model_l1(model, save_dir="./l1_analysis")

# Show every layer row rather than pandas' truncated view.
# NOTE: this sets a global option that stays in effect for all later cells.
pd.options.display.max_rows = None
display(df_summary)

Unnamed: 0,Layer,Type,Shape,In Ch,Out Ch,Num Params,Mean L1,Min L1,Max L1,L1 Std,Block Ratio,Post-Prune Ratio
0,encoders.0.net.0,Conv2d,"(64, 1, 3, 3)",1,64,576,1.513661,0.846551,2.152828,0.273044,,
1,encoders.0.net.1,BatchNorm2d,"(64,)",64,64,128,,,,,,
2,encoders.0.net.3,Conv2d,"(64, 64, 3, 3)",64,64,36864,19.995201,13.506686,27.711575,2.814086,,
3,encoders.0.net.4,BatchNorm2d,"(64,)",64,64,128,,,,,,
4,encoders.1.net.0,Conv2d,"(128, 64, 3, 3)",64,128,73728,20.709234,13.457222,31.478107,4.045705,,
5,encoders.1.net.1,BatchNorm2d,"(128,)",128,128,256,,,,,,
6,encoders.1.net.3,Conv2d,"(128, 128, 3, 3)",128,128,147456,32.650726,20.267021,48.277267,7.106401,,
7,encoders.1.net.4,BatchNorm2d,"(128,)",128,128,256,,,,,,
8,encoders.2.net.0,Conv2d,"(256, 128, 3, 3)",128,256,294912,36.69519,22.007919,52.998169,6.053425,,
9,encoders.2.net.1,BatchNorm2d,"(256,)",256,256,512,,,,,,


In [5]:
# Rebuild the per-layer L1 table from the lower-level helpers. With
# remove_nan_layers=False, rows whose L1 stats are NaN (the BatchNorm rows in the
# output) are kept in the frame.
# NOTE(review): Mean L1 here differs from the In[3] table for the same layers
# (e.g. encoders.0.net.0: 1.551331 vs 1.513661), and In[4] is missing from the
# saved notebook. Either a deleted cell changed `model`, or inspect_model_l1 and
# compute_l1_norms/compute_l1_stats measure L1 differently — confirm before
# comparing the two tables.
norms = compute_l1_norms(model)
stats = compute_l1_stats(norms)
df = model_to_dataframe_with_l1(model, stats, remove_nan_layers=False)

# Untruncated display of all layer rows (global pandas option).
pd.options.display.max_rows = None
display(df)

Unnamed: 0,Layer,Type,Shape,In Ch,Out Ch,Num Params,Mean L1,Min L1,Max L1,L1 Std,Block Ratio,Post-Prune Ratio
0,encoders.0.net.0,Conv2d,"(64, 1, 3, 3)",1,64,576,1.551331,0.941511,2.025421,0.252934,,
1,encoders.0.net.1,BatchNorm2d,"(64,)",64,64,128,,,,,,
2,encoders.0.net.3,Conv2d,"(64, 64, 3, 3)",64,64,36864,20.288008,14.83642,26.34874,2.82051,,
3,encoders.0.net.4,BatchNorm2d,"(64,)",64,64,128,,,,,,
4,encoders.1.net.0,Conv2d,"(128, 64, 3, 3)",64,128,73728,20.560497,13.653543,29.073994,4.398217,,
5,encoders.1.net.1,BatchNorm2d,"(128,)",128,128,256,,,,,,
6,encoders.1.net.3,Conv2d,"(128, 128, 3, 3)",128,128,147456,32.678024,21.367718,49.799278,6.812839,,
7,encoders.1.net.4,BatchNorm2d,"(128,)",128,128,256,,,,,,
8,encoders.2.net.0,Conv2d,"(256, 128, 3, 3)",128,256,294912,34.095963,20.778049,52.540321,6.267731,,
9,encoders.2.net.1,BatchNorm2d,"(256,)",256,256,512,,,,,,
