In [3]:
import os
import sys

PROJECT_ROOT = "/media/ttoxopeus/basic_UNet"

if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)
    print(f"✅ Added project root to sys.path:\n   {PROJECT_ROOT}")
else:
    print(f"ℹ️ Project root already in sys.path:\n   {PROJECT_ROOT}")


from src.models.unet import UNet
from src.pruning.model_inspect import model_to_dataframe_with_l1, get_pruning_masks_blockwise, compute_actual_prune_ratios
from src.pruning.rebuild import rebuild_pruned_unet, find_prev_conv_name

ℹ️ Project root already in sys.path:
   /media/ttoxopeus/basic_UNet


In [4]:
# Create baseline model
model = UNet(in_ch=1, out_ch=4, features=[64, 128, 256, 512])
df = model_to_dataframe_with_l1(model)
display(df)

Unnamed: 0,Layer,Type,Shape,In Ch,Out Ch,Num Params,Mean L1,Min L1,Max L1,L1 Std,Block Ratio,Post-Prune Ratio
0,encoders.0.net.0,Conv2d,"(64, 1, 3, 3)",1.0,64.0,576,1.542949,1.020199,2.098652,0.253058,,
1,encoders.0.net.3,Conv2d,"(64, 64, 3, 3)",64.0,64.0,36864,12.042466,11.409177,12.74559,0.29756,,
2,encoders.1.net.0,Conv2d,"(128, 64, 3, 3)",64.0,128.0,73728,12.004503,11.072059,12.81948,0.32425,,
3,encoders.1.net.3,Conv2d,"(128, 128, 3, 3)",128.0,128.0,147456,16.97381,16.42378,17.564632,0.275449,,
4,encoders.2.net.0,Conv2d,"(256, 128, 3, 3)",128.0,256.0,294912,16.99577,16.111313,17.875732,0.299822,,
5,encoders.2.net.3,Conv2d,"(256, 256, 3, 3)",256.0,256.0,589824,24.004383,23.349602,25.077776,0.292504,,
6,encoders.3.net.0,Conv2d,"(512, 256, 3, 3)",256.0,512.0,1179648,24.002745,23.199892,24.840578,0.297298,,
7,encoders.3.net.3,Conv2d,"(512, 512, 3, 3)",512.0,512.0,2359296,33.960526,33.220943,34.96822,0.292697,,
8,bottleneck.net.0,Conv2d,"(1024, 512, 3, 3)",512.0,1024.0,4718592,33.940369,33.008186,34.70388,0.274648,,
9,bottleneck.net.3,Conv2d,"(1024, 1024, 3, 3)",1024.0,1024.0,9437184,48.011463,46.962585,49.167828,0.286315,,


In [5]:
block_ratios = {
    # --- Encoder DoubleConvs ---
    "encoders.0": 0.1,
    "encoders.1": 0.2,
    "encoders.2": 0.3,
    "encoders.3": 0.4,

    # --- Bottleneck ---
    "bottleneck": 0.5,

    # --- Decoder DoubleConvs only (skip ConvTranspose2d ones) ---
    "decoders.1": 0.4,
    "decoders.3": 0.3,
    "decoders.5": 0.2,
    "decoders.7": 0.1,
}

masks = get_pruning_masks_blockwise(df, block_ratios, default_ratio=0.25)

Block encoders.0 | Layer encoders.0.net.0 | ratio=0.10 | threshold=1.1280
Block encoders.0 | Layer encoders.0.net.3 | ratio=0.10 | threshold=11.5428
Block encoders.1 | Layer encoders.1.net.0 | ratio=0.20 | threshold=11.4215
Block encoders.1 | Layer encoders.1.net.3 | ratio=0.20 | threshold=16.6520
Block encoders.2 | Layer encoders.2.net.0 | ratio=0.30 | threshold=16.6406
Block encoders.2 | Layer encoders.2.net.3 | ratio=0.30 | threshold=23.8681
Block encoders.3 | Layer encoders.3.net.0 | ratio=0.40 | threshold=23.8562
Block encoders.3 | Layer encoders.3.net.3 | ratio=0.40 | threshold=33.9199
Block bottleneck.net | Layer bottleneck.net.0 | ratio=0.25 | threshold=33.4321
Block bottleneck.net | Layer bottleneck.net.3 | ratio=0.25 | threshold=47.5139
Block decoders.1 | Layer decoders.1.net.0 | ratio=0.40 | threshold=47.7930
Block decoders.1 | Layer decoders.1.net.3 | ratio=0.40 | threshold=33.7407
Block decoders.3 | Layer decoders.3.net.0 | ratio=0.30 | threshold=33.5337
Block decoders.3 |

In [None]:
pruned_model = rebuild_pruned_unet(
    model,
    masks, 
    save_path="/media/ttoxopeus/basic_UNet/results/UNet_ACDC/exp1/pruned/pruned_model.pth")

🔧 Rebuilding pruned UNet architecture...
Encoder features (after pruning): [57, 102, 179, 307]
Bottleneck out_channels: 768
💾 Saved pruned model to: /media/ttoxopeus/basic_UNet/results/UNet_ACDC/exp1/pruned
✅ UNet successfully rebuilt.


In [11]:
post_ratios = compute_actual_prune_ratios(model, pruned_model)


In [12]:
df_pruned = model_to_dataframe_with_l1(
    pruned_model,
    remove_nan_layers=True,
    block_ratios=block_ratios,
    post_prune_ratios=post_ratios
)
display(df_pruned)

Unnamed: 0,Layer,Type,Shape,In Ch,Out Ch,Num Params,Mean L1,Min L1,Max L1,L1 Std,Block Ratio,Post-Prune Ratio
0,encoders.0.net.0,Conv2d,"(57, 1, 3, 3)",1.0,57.0,513,1.546916,0.774313,2.057778,0.297653,0.1,0.1094
1,encoders.0.net.3,Conv2d,"(57, 57, 3, 3)",57.0,57.0,29241,10.657154,10.057249,11.190207,0.282252,0.1,0.1094
2,encoders.1.net.0,Conv2d,"(102, 57, 3, 3)",57.0,102.0,52326,10.632481,9.881829,11.336972,0.277702,0.2,0.2031
3,encoders.1.net.3,Conv2d,"(102, 102, 3, 3)",102.0,102.0,93636,13.493286,12.710206,14.209844,0.263084,0.2,0.2031
4,encoders.2.net.0,Conv2d,"(179, 102, 3, 3)",102.0,179.0,164322,13.521672,12.914568,14.091743,0.258456,0.3,0.3008
5,encoders.2.net.3,Conv2d,"(179, 179, 3, 3)",179.0,179.0,288369,16.767097,16.065771,17.299786,0.226833,0.3,0.3008
6,encoders.3.net.0,Conv2d,"(307, 179, 3, 3)",179.0,307.0,494577,16.765408,16.133646,17.492386,0.243084,0.4,0.4004
7,encoders.3.net.3,Conv2d,"(307, 307, 3, 3)",307.0,307.0,848241,20.374737,19.750858,20.969288,0.226352,0.4,0.4004
8,bottleneck.net.0,Conv2d,"(768, 307, 3, 3)",307.0,614.0,2121984,20.360487,19.677406,21.134785,0.224747,0.5,0.4004
9,bottleneck.net.3,Conv2d,"(768, 768, 3, 3)",614.0,614.0,5308416,36.016911,35.133488,36.638866,0.239787,0.5,0.4004
