In [0]:
%load_ext autoreload
%autoreload 2
# Enables autoreload; learn more at https://docs.databricks.com/en/files/workspace-modules.html#autoreload-for-python-modules
# To disable autoreload; run %autoreload 0

In [0]:
# Use the %pip magic (not `!pip`) so wandb is installed into the environment of the running notebook kernel.
# NOTE(review): no version is pinned — consider `wandb==<version>` for reproducible runs.
# NOTE(review): on older Databricks runtimes %pip may reset notebook Python state; if so, run this cell first.
%pip install wandb

In [0]:
import os
import sys
import warnings

# Put the parent of the current working directory on the import path so modules located there can be imported.
sys.path.append(os.path.abspath('..'))
# Read the Weights & Biases API key from a Databricks secret scope instead of hard-coding it, then expose it
# through the WANDB_API_KEY environment variable.
# NOTE(review): the secret key is spelled 'WANDA_API_KEY' while the env var is 'WANDB_API_KEY' — confirm the
# secret really is stored under that name.
wandb_api_key = dbutils.secrets.get(scope='haroon-scope', key='WANDA_API_KEY')
os.environ["WANDB_API_KEY"] = wandb_api_key
# Suppresses every warning category from this point on.
warnings.filterwarnings('ignore')

from trainer import Trainer, TrainingArguments
from bacp import BaCPTrainingArguments, BaCPTrainer
from utils import set_seed

## Testing

In [0]:
import torch
import torch.nn as nn
import numpy as np
from torch.optim import SGD
from model_factory import ClassificationAndEncoderNetwork
from dataset_factory import load_cv_dataloaders
from pruning_factory import check_sparsity_distribution, check_model_sparsity, layer_check, RigLPruner
from utils import set_seed

set_seed()

# --- Model, optimizer and data ---
model = ClassificationAndEncoderNetwork('resnet34', 10)
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
# NOTE(review): the positional arguments are passed through unchanged; see load_cv_dataloaders for their meaning.
data = load_cv_dataloaders('cifar10', 'supervised', 128, 32, 1, 24, '/dbfs/cache')
trainloader = data['trainloader']
# Built once: the loss module does not need to be re-created for every batch.
criterion = torch.nn.CrossEntropyLoss()

# --- Sparsity schedule settings ---
delta_t = 100         # prune/regrow frequency (every Δt steps)
log_every = 100       # progress is printed every `log_every` steps

# --- Sparsity range ---
s = 0.90

T_end = 10
# Mask updates happen only while epoch <= Tc; `or 1` keeps Tc non-zero when T_end is very small.
Tc = int((2 / 3) * T_end) or 1
vals = []             # pruner.s_target recorded at every mask update (plotted in a later cell)
pruner = RigLPruner(model, T_end, s, 'f_decay')
print(f"{check_model_sparsity(model) = }")
for epoch in range(T_end):
    correct, total = 0, 0
    for t, (batch, label) in enumerate(trainloader):
        # Move data
        batch, label = batch.to('cuda'), label.to('cuda')

        # Forward + loss
        optimizer.zero_grad()
        outputs = model(batch)
        loss = criterion(outputs, label)
        loss.backward()

        # The mask update runs after backward() and before optimizer.step(),
        # every delta_t steps, for epochs 0..Tc inclusive.
        if epoch <= Tc and t % delta_t == 0:
            pruner.ratio_step(epoch, Tc)
            pruner.prune(model)
            vals.append(pruner.s_target)

        optimizer.step()
        # Called after every optimizer step; presumably re-applies the masks to the updated weights — TODO confirm.
        pruner.apply_mask(model)

        preds = outputs.argmax(dim=1)
        correct += (preds == label).sum().item()
        total += label.size(0)
        acc = (correct / total) * 100

        if t % log_every == 0:
            print(
                f"[Epoch {epoch} | Step {t}] "
                f"Loss={loss.item():.4f}, "
                # Fixed-point format: a bare `.3` means 3 significant digits and prints 100.0 as `1e+02`.
                f"Acc={acc:.2f}, "
                f"s_target={pruner.s_target:.5f}, "
                f"Sparsity={check_model_sparsity(model):.5f}, "
                f"s_curr={pruner.s_curr:.5f}"
            )

    # Guard against an empty loader so the summary line cannot raise ZeroDivisionError.
    epoch_acc = correct / total if total else 0.0
    print(f"Epoch {epoch} done. Accuracy: {epoch_acc:.4f}\n")



In [0]:
import matplotlib.pyplot as plt

# `vals` holds the pruner.s_target values appended by the training cell, one per prune/regrow update.
fig, ax = plt.subplots()
ax.plot(vals)
ax.set(
    title="Target sparsity per prune/regrow update",
    xlabel="prune/regrow update index",
    ylabel="s_target",
)
# Render the figure explicitly so the cell does not echo the Line2D repr.
plt.show()

In [0]:
# Measure top-1 accuracy of `model` over data['testloader'] with gradients disabled.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch, label in data['testloader']:
        batch = batch.to('cuda')
        label = label.to('cuda')
        outputs = model(batch)
        # Element-wise comparison of the arg-max class against the labels.
        hits = outputs.argmax(dim=1).eq(label)
        correct += hits.sum().item()
        total += label.size(0)
print(f"Test Accuracy: {100 * correct / total:.2f}%")

In [0]:

@torch.no_grad()
def rigl_global_mask_update(pruner, model):
    """Run one global prune/regrow pass and store the new masks on ``pruner``.

    This logic was pasted into the notebook as a bare method body. In that form
    the cell cannot run: ``self`` is undefined at notebook scope and a top-level
    ``return`` is a SyntaxError. It is wrapped in a function so it can be called
    as ``rigl_global_mask_update(pruner, model)``.

    Args:
        pruner: Object exposing ``masks`` (dict mapping parameter name to a mask
            tensor, used here as a 0/1 multiplier) and ``s_target`` (fraction of
            the masked weights used to size the prune and regrow sets).
        model: Module whose ``named_parameters()`` are scored. Parameters whose
            ``.grad`` is ``None`` get a regrow score of zero.

    Returns:
        None. ``pruner.masks`` is updated in place; progress is printed.
    """
    prune_imps, prune_cache = [], {}
    regrow_imps, regrow_cache = [], {}

    for name, param in model.named_parameters():
        if name not in pruner.masks:
            continue
        mask = pruner.masks[name]

        # Active weight importance: |w| for kept weights, 0 for masked ones.
        prune_imp = torch.abs(param) * mask
        prune_cache[name] = prune_imp
        prune_imps.append(prune_imp.reshape(-1))

        # Inactive gradient importance: |grad| for masked weights, 0 for kept ones.
        grad = param.grad
        if grad is None:
            regrow_imp = torch.zeros_like(prune_imp)
        else:
            regrow_imp = torch.abs(grad.data) * (1.0 - mask)
        regrow_cache[name] = regrow_imp
        regrow_imps.append(regrow_imp.reshape(-1))

    if not prune_imps:
        # torch.cat() raises on an empty list, so stop early when nothing is masked.
        print('Warning: no masked parameters found, no pruning will be performed')
        return

    global_prune_imps = torch.cat(prune_imps)
    global_regrow_imps = torch.cat(regrow_imps)
    total_weights = global_prune_imps.numel()

    k = int(pruner.s_target * total_weights)
    print(f'Pruning {k} weights out of {total_weights}')
    if k == 0:
        print('Warning: k is zero, no pruning will be performed')
        return

    # Threshold = k-th smallest prune score (top-k of the negated scores).
    # NOTE(review): already-masked weights score 0 and therefore count towards k,
    # and the comparisons below are strict, so ties at a threshold are left
    # untouched — confirm this matches the intended schedule.
    prune_values, _ = torch.topk(-global_prune_imps, k)
    prune_thresh = -prune_values.min().item()

    # Threshold = k-th largest regrow score.
    regrow_values, _ = torch.topk(global_regrow_imps, k)
    regrow_thresh = regrow_values.min().item()

    for name, param in model.named_parameters():
        if name not in pruner.masks:
            continue
        new_mask = pruner.masks[name].clone()
        # Sparsity printed here is the layer's sparsity *before* this update.
        sparsity = 1.0 - float(new_mask.sum()) / float(new_mask.numel())
        print(f'Layer {name} sparsity: {sparsity}')

        new_mask[prune_cache[name] < prune_thresh] = 0
        new_mask[regrow_cache[name] > regrow_thresh] = 1
        pruner.masks[name] = new_mask

    print(check_model_sparsity(model))

## TEST A: `baseline_standard_relu` — baseline run, then magnitude pruning

In [0]:
# Run baseline_script.py for resnet34 on cifar10 (10 classes) under the experiment tag
# `baseline_standard_relu`, with W&B logging and the Databricks environment flag enabled.
!python ../scripts/baseline_script.py \
    --model_name resnet34  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --log_to_wandb  --experiment_type baseline_standard_relu \
    --databricks_env

In [0]:
# Run pruning_script.py with magnitude pruning to a 0.9995 target sparsity (cubic scheduler), starting from
# the `baseline_standard_relu` checkpoint given via --trained_weights.
# NOTE(review): the checkpoint path is timestamped, so it presumably has to be updated whenever the
# baseline cell is re-run.
!python ../scripts/pruning_script.py \
    --model_name resnet34  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --log_to_wandb  --experiment_type baseline_standard_relu \
    --databricks_env \
    --trained_weights /dbfs/research/bacp/resnet34/cifar10/20251223/resnet34_cifar10_baseline_standard_relu_20251223_101434.pt \
    --pruning_type magnitude_pruning --target_sparsity 0.9995 --sparsity_scheduler cubic

## TEST B: `baseline_dyrelu_phasing` (`--dyrelu_phasing_en`) — baseline run, then magnitude pruning

In [0]:
# Run baseline_script.py for resnet34 on cifar10 with --dyrelu_phasing_en set, under the experiment tag
# `baseline_dyrelu_phasing`, with W&B logging and the Databricks environment flag enabled.
!python ../scripts/baseline_script.py \
    --model_name resnet34  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --log_to_wandb  --experiment_type baseline_dyrelu_phasing \
    --databricks_env --dyrelu_phasing_en

In [0]:
# Run pruning_script.py with magnitude pruning to a 0.9995 target sparsity (cubic scheduler) and
# --dyrelu_phasing_en set, starting from the `baseline_dyrelu_phasing` checkpoint given via --trained_weights.
# NOTE(review): the checkpoint path is timestamped, so it presumably has to be updated whenever the
# baseline cell is re-run.
!python ../scripts/pruning_script.py \
    --model_name resnet34  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --log_to_wandb  --experiment_type baseline_dyrelu_phasing \
    --databricks_env --dyrelu_phasing_en \
    --trained_weights /dbfs/research/bacp/resnet34/cifar10/20251223/resnet34_cifar10_baseline_dyrelu_phasing_20251223_100219.pt \
    --pruning_type magnitude_pruning --target_sparsity 0.9995 --sparsity_scheduler cubic

## TEST C: `baseline_dyrelu` (`--dyrelu_en`) — baseline run, then magnitude pruning

In [0]:
# Run baseline_script.py for resnet34 on cifar10 with --dyrelu_en set, under the experiment tag
# `baseline_dyrelu`, with W&B logging and the Databricks environment flag enabled.
!python ../scripts/baseline_script.py \
    --model_name resnet34  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --log_to_wandb  --experiment_type baseline_dyrelu \
    --databricks_env --dyrelu_en

In [0]:
# Run pruning_script.py with magnitude pruning to a 0.9995 target sparsity (cubic scheduler) and
# --dyrelu_en set, starting from the `baseline_dyrelu` checkpoint given via --trained_weights.
# NOTE(review): the checkpoint path is timestamped, so it presumably has to be updated whenever the
# baseline cell is re-run.
!python ../scripts/pruning_script.py \
    --model_name resnet34  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --log_to_wandb  --experiment_type baseline_dyrelu \
    --databricks_env --dyrelu_en \
    --trained_weights /dbfs/research/bacp/resnet34/cifar10/20251223/resnet34_cifar10_baseline_dyrelu_20251223_102425.pt \
    --pruning_type magnitude_pruning --target_sparsity 0.9995 --sparsity_scheduler cubic

## Pruning Accuracies

In [0]:
from model_factory import ClassificationAndEncoderNetwork
from utils import load_weights

model = ClassificationAndEncoderNetwork('resnet50', 10)

In [0]:
import torch
import torch.nn as nn
import types

def _scaled_conv_forward(self, x):
    """Convolve ``x`` with the shared weight multiplied by this layer's own scalar ``scaler``."""
    return self._conv_forward(x, self.weight * self.scaler, self.bias)


def apply_weight_sharing_resnet(model, R=2):
    """Tie the conv weights of later blocks in each ResNet stage to one master block.

    For each of ``layer1`` .. ``layer4`` found on ``model.model`` that has more
    than ``R`` blocks, block ``R - 1`` is the master. Every later block has each
    direct ``nn.Conv2d`` child with a same-named, same-shaped counterpart in the
    master re-pointed at the master's ``weight`` Parameter. Each tied conv also
    gets a learnable scalar ``scaler`` (initialised to 1.0) and a patched
    ``forward`` that convolves with ``weight * scaler``.

    Args:
        model: Wrapper object whose ``model`` attribute holds the stages.
        R: 1-based index of the master block inside each stage; must be >= 1.

    Returns:
        None. ``model`` is modified in place.

    Raises:
        ValueError: If ``R`` is smaller than 1. Previously this selected the
            last block as master, deleted its own weight and failed with an
            AttributeError after partially modifying the model.
    """
    if R < 1:
        raise ValueError(f"R must be >= 1, got {R}")

    master_idx = R - 1
    layer_names = ['layer1', 'layer2', 'layer3', 'layer4']

    for layer_name in layer_names:
        if not hasattr(model.model, layer_name):
            continue

        layer_container = getattr(model.model, layer_name)
        num_blocks = len(layer_container)
        if num_blocks <= R:
            # No block after the master, so there is nothing to share in this stage.
            continue

        master_block = layer_container[master_idx]
        for i in range(master_idx + 1, num_blocks):
            block = layer_container[i]
            for name, m_child in master_block.named_children():
                if not isinstance(m_child, nn.Conv2d):
                    continue

                s_child = getattr(block, name, None)
                # Only tie conv-to-conv; a same-named child of another type is left untouched.
                if not isinstance(s_child, nn.Conv2d):
                    continue
                if m_child.weight.shape != s_child.weight.shape:
                    continue

                # Drop the follower's own Parameter and register the master's in its place.
                del s_child.weight
                s_child.weight = m_child.weight

                if not hasattr(s_child, 'scaler'):
                    s_child.scaler = nn.Parameter(torch.tensor(1.0).to(m_child.weight.device))

                # Bind the shared helper as this instance's forward (defined once, not per iteration).
                s_child.forward = types.MethodType(_scaled_conv_forward, s_child)

apply_weight_sharing_resnet(model)

In [0]:
import torch.nn as nn

def weight_sharing_test(model):
    """Print a report showing whether any ``nn.Conv2d`` layers share a weight object.

    Each conv whose qualified name contains both "layer" and "conv" is listed
    with the ``id()`` of its weight and whether it has a ``scaler`` attribute.
    The summary then compares the number of distinct weight objects against the
    number of conv layers.

    Args:
        model: Module to inspect.

    Returns:
        None. The report is written to stdout.
    """
    seen_weight_ids = set()

    for module_name, layer in model.named_modules():
        if not isinstance(layer, nn.Conv2d):
            continue

        weight_id = id(layer.weight)
        seen_weight_ids.add(weight_id)

        scaler_status = 'No'
        if hasattr(layer, 'scaler'):
            scaler_status = "Yes"

        if 'conv' in module_name and "layer" in module_name:
            print(f"Layer: {module_name:<30} | Weight Addr: {weight_id} | Has Scaler: {scaler_status}")

    # Counted once and reused for both the summary line and the verdict.
    conv_total = sum(1 for layer in model.modules() if isinstance(layer, nn.Conv2d))
    unique_total = len(seen_weight_ids)

    print(f"\nTotal Conv2d Layers: {conv_total}")
    print(f"Unique Weight Tensors: {unique_total}")

    if unique_total >= conv_total:
        print("❌ FAILURE: Every layer has a unique weight. Sharing FAILED.")
    else:
        print("✅ SUCCESS: Fewer unique weights than layers. Sharing is active.")

weight_sharing_test(model)

In [0]:
# Run pruning_script.py on resnet34 / cifar10 with magnitude pruning: 0.99 target sparsity, cubic scheduler,
# 5 epochs, 10 recovery epochs, learning rate 0.01, starting from the shared baseline checkpoint.
!python ../scripts/pruning_script.py \
    --model_name resnet34  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --pruning_type magnitude_pruning --target_sparsity 0.99 --sparsity_scheduler cubic \
    --epochs 5 --recovery_epochs 10 \
    --trained_weights /dbfs/research/bacp/resnet34/cifar10/resnet34_cifar10_baseline.pt \
    --learning_rate 0.01

In [0]:
# Run pruning_script.py on resnet34 / cifar10 with `snip_pruning`: 0.99 target sparsity, cubic scheduler,
# 5 epochs, 10 recovery epochs, learning rate 0.01, starting from the shared baseline checkpoint.
!python ../scripts/pruning_script.py \
    --model_name resnet34  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --pruning_type snip_pruning --target_sparsity 0.99 --sparsity_scheduler cubic \
    --epochs 5 --recovery_epochs 10 \
    --trained_weights /dbfs/research/bacp/resnet34/cifar10/resnet34_cifar10_baseline.pt \
    --learning_rate 0.01

In [0]:


# Two runs of pruning_script.py on resnet50 / cifar10 with `snip_pruning`, at target sparsity 0.97 and 0.99.
# No scheduler or epoch flags are passed, so the script's own defaults apply.
# NOTE(review): this checkpoint lives under /dbfs/research/resnet50/..., without the `bacp/` path segment
# used by the resnet34 cells — confirm the path is correct.
!python ../scripts/pruning_script.py \
    --model_name resnet50  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --pruning_type snip_pruning --target_sparsity 0.97 \
    --trained_weights /dbfs/research/resnet50/cifar10/resnet50_cifar10_baseline.pt \
    --learning_rate 0.01
    
!python ../scripts/pruning_script.py \
    --model_name resnet50  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --pruning_type snip_pruning --target_sparsity 0.99 \
    --trained_weights /dbfs/research/resnet50/cifar10/resnet50_cifar10_baseline.pt \
    --learning_rate 0.01

In [0]:
# Three runs of pruning_script.py on resnet50 / cifar10 with `wanda_pruning`, at target sparsity 0.95, 0.97
# and 0.99. No scheduler or epoch flags are passed, so the script's own defaults apply.
# NOTE(review): this checkpoint lives under /dbfs/research/resnet50/..., without the `bacp/` path segment
# used by the resnet34 cells — confirm the path is correct.
!python ../scripts/pruning_script.py \
    --model_name resnet50  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --pruning_type wanda_pruning --target_sparsity 0.95 \
    --trained_weights /dbfs/research/resnet50/cifar10/resnet50_cifar10_baseline.pt \
    --learning_rate 0.01

!python ../scripts/pruning_script.py \
    --model_name resnet50  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --pruning_type wanda_pruning --target_sparsity 0.97 \
    --trained_weights /dbfs/research/resnet50/cifar10/resnet50_cifar10_baseline.pt \
    --learning_rate 0.01

!python ../scripts/pruning_script.py \
    --model_name resnet50  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --pruning_type wanda_pruning --target_sparsity 0.99 \
    --trained_weights /dbfs/research/resnet50/cifar10/resnet50_cifar10_baseline.pt \
    --learning_rate 0.01

## BaCP

    def __init__(
        self, 
        model_name:         str, 
        num_classes:        int, 
        num_out_features:   int = None, 
        device:             str = 'cuda', 
        adapt:              bool = True, 
        pretrained:         bool = True, 
        freeze:             bool = False, 
        dyrelu_en:          bool = False,
        dyrelu_phasing_en:  bool = False
        ):

In [0]:
from model_factory import *
from utils import *
from pruning_factory import *
# Build a fresh resnet34 network and load the checkpoint named `rigl_pruning_0.99_bacp`, then report how
# sparsity is distributed across it.
# NOTE(review): the positional arguments (10, 128, 'cuda', True) presumably map to num_classes,
# num_out_features, device and adapt — confirm against the constructor.
model = ClassificationAndEncoderNetwork('resnet34', 10, 128, 'cuda', True)
load_weights(model, '/dbfs/research/bacp/resnet34/cifar10/20251205/resnet34_cifar10_rigl_pruning_0.99_bacp_20251205_101159.pt')

check_sparsity_distribution(model)


In [0]:
from model_factory import *
from utils import *
from pruning_factory import *
# Same inspection as for the `rigl_pruning_0.99_bacp` checkpoint, but loading the file with the `_finetune`
# suffix from the same run timestamp.
# NOTE(review): the positional arguments (10, 128, 'cuda', True) presumably map to num_classes,
# num_out_features, device and adapt — confirm against the constructor.
model = ClassificationAndEncoderNetwork('resnet34', 10, 128, 'cuda', True)
load_weights(model, '/dbfs/research/bacp/resnet34/cifar10/20251205/resnet34_cifar10_rigl_pruning_0.99_bacp_20251205_101159_finetune.pt')

check_sparsity_distribution(model)


In [0]:
# Run bacp_script.py on resnet34 / cifar10 with `rigl_pruning`: 0.9995 target sparsity, `f_decay` scheduler,
# 0 recovery epochs, fine-tuning enabled, --dyrelu_phasing_en set, W&B logging on, starting from the shared
# baseline checkpoint.
!python ../scripts/bacp_script.py \
    --model_name resnet34 --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --pruning_type rigl_pruning --target_sparsity 0.9995 --sparsity_scheduler f_decay --recovery_epochs 0 \
    --trained_weights /dbfs/research/bacp/resnet34/cifar10/resnet34_cifar10_baseline.pt \
    --enable_finetune \
    --log_to_wandb \
    --databricks_env --dyrelu_phasing_en


59.17% w/ 99.97% sparsity using RigL on BaCP: 150 epochs (50 - 100)



In [0]:
from model_factory import *
from utils import *
from pruning_factory import *
# Build a fresh resnet34 network and load the `rigl_pruning_0.9995_bacp ... _finetune` checkpoint, then
# report how sparsity is distributed across it.
# NOTE(review): the positional arguments (10, 128, 'cuda', True) presumably map to num_classes,
# num_out_features, device and adapt — confirm against the constructor.
model = ClassificationAndEncoderNetwork('resnet34', 10, 128, 'cuda', True)
load_weights(model, '/dbfs/research/bacp/resnet34/cifar10/20251210/resnet34_cifar10_rigl_pruning_0.9995_bacp_20251210_102030_finetune.pt')

check_sparsity_distribution(model)
