In [0]:
%load_ext autoreload
%autoreload 2
# Enable autoreload so edits to the project's .py modules (trainer, bacp, utils, ...) are
# picked up without restarting the kernel. Details:
# https://docs.databricks.com/en/files/workspace-modules.html#autoreload-for-python-modules
# To disable autoreload, run: %autoreload 0

In [0]:
import os
import sys
import warnings

# Make the project root (the parent directory) importable. The membership guard keeps
# re-runs of this cell from appending the same entry to sys.path again and again.
project_root = os.path.abspath('..')
if project_root not in sys.path:
    sys.path.append(project_root)

# Export the Weights & Biases API key, read from a Databricks secret scope, into the
# WANDB_API_KEY environment variable (presumably consumed by the trainers' W&B logging).
# `dbutils` is the Databricks notebook global, so this cell only runs on Databricks.
# NOTE(review): the secret is looked up under 'WANDA_API_KEY' (A, not B) -- confirm this
# matches the key name actually stored in 'haroon-scope'.
wandb_api_key = dbutils.secrets.get(scope='haroon-scope', key='WANDA_API_KEY')
os.environ["WANDB_API_KEY"] = wandb_api_key

# Suppresses *all* Python warnings for the rest of the kernel session, from every library.
# Narrow this filter if a run starts behaving unexpectedly.
warnings.filterwarnings('ignore')

from trainer import Trainer, TrainingArguments
from bacp import BaCPTrainingArguments, BaCPTrainer
from utils import set_seed

In [0]:

from utils import set_seed

# Dense ResNet-34 / CIFAR-10 baseline configuration: SGD at lr 0.1, 100 epochs,
# 'linear_with_warmup' LR schedule.
# Seeded for reproducibility, consistent with every other experiment cell in this notebook.
set_seed()
args = TrainingArguments(
    model_name='resnet34',
    model_type='cv',
    dataset_name='cifar10',
    num_classes=10,
    batch_size=512,
    optimizer_type='sgd',
    learning_rate=0.1,
    epochs=100,
    scheduler_type='linear_with_warmup',
)
trainer = Trainer(args)

# Training is off by default (it was previously commented out). Set to True to
# (re)train the baseline before evaluating.
# NOTE(review): no `trained_weights` is passed, so unless Trainer loads a default
# checkpoint, evaluate() below scores a freshly initialised model -- confirm.
RUN_TRAINING = False
if RUN_TRAINING:
    trainer.train()
trainer.evaluate()

In [0]:
from trainer import Trainer, TrainingArguments

# Ablation arm without DyReLU: ResNet-34 / CIFAR-10, magnitude pruning to 99.99% sparsity
# on a cubic sparsity schedule (10 recovery epochs, retrain=True), starting from the
# checkpoint in `trained_weights`. Logged as experiment 'wo_dyrelu'.
set_seed()
args = TrainingArguments(
    model_name='resnet34',
    model_type='cv',
    dataset_name='cifar10',
    num_classes=10,
    batch_size=512,
    optimizer_type='sgd',
    learning_rate=0.1,

    pruning_type='magnitude_pruning',
    target_sparsity=0.9999,
    sparsity_scheduler='cubic',
    recovery_epochs=10,
    retrain=True,
    dyrelu_enabled=False,

    trained_weights='/dbfs/research/bacp/resnet34/cifar10/resnet34_cifar10_.pt',
    experiment_type='wo_dyrelu'

)
trainer = Trainer(args)
# To inspect the architecture, use display(trainer.model): a bare mid-cell expression is not shown.
trainer.train()
trainer.evaluate()

In [0]:
from trainer import Trainer, TrainingArguments

# Ablation arm with DyReLU: identical to the 'wo_dyrelu' configuration except
# dyrelu_enabled=True. Logged as experiment 'w_dyrelu'.
set_seed()
args = TrainingArguments(
    model_name='resnet34',
    model_type='cv',
    dataset_name='cifar10',
    num_classes=10,
    batch_size=512,
    optimizer_type='sgd',
    learning_rate=0.1,

    pruning_type='magnitude_pruning',
    target_sparsity=0.9999,
    sparsity_scheduler='cubic',
    recovery_epochs=10,
    retrain=True,
    dyrelu_enabled=True,

    trained_weights='/dbfs/research/bacp/resnet34/cifar10/resnet34_cifar10_.pt',
    experiment_type='w_dyrelu'

)
trainer = Trainer(args)
# To inspect the architecture, use display(trainer.model): a bare mid-cell expression is not shown.
trainer.train()
trainer.evaluate()

In [0]:
# Layer-by-layer summary of the model held by the most recently built `trainer`.
# torchinfo's input_size includes the batch dimension: one 3-channel 32x32 CIFAR image.
from torchinfo import summary

summary(trainer.model, input_size=(1, 3, 32, 32))

In [0]:
from bacp import BaCPTrainingArguments, BaCPTrainer
from trainer import Trainer, TrainingArguments
from utils import set_seed

# --- Stage 1: BaCP pretraining with magnitude pruning, DyReLU disabled ---------------
set_seed()
bacp_args = BaCPTrainingArguments(
    model_name='resnet34',
    model_type='cv',
    dataset_name='cifar10',
    num_classes=10,
    batch_size=512,
    optimizer_type='sgd',
    learning_rate=0.1,
    tau=0.10,
    trained_weights='/dbfs/research/bacp/resnet34/cifar10/resnet34_cifar10_.pt',
    experiment_type='bacp_pretrain_wo_dyrelu',
    enable_tqdm=True,

    pruning_type='magnitude_pruning',
    target_sparsity=0.9999,
    sparsity_scheduler='cubic',
    recovery_epochs=10,
    retrain=True,

    dyrelu_enabled=False,
)
bacp_trainer = BaCPTrainer(bacp_args)
bacp_trainer.train()

# --- Stage 2: fine-tune from the BaCP checkpoint, reusing its pruner ----------------
# Re-seed so fine-tuning starts from the same RNG state as in the DyReLU sibling cells,
# which also call set_seed() at this point. Without it, this arm's fine-tuning randomness
# depends on how much RNG state stage 1 consumed, and the ablation is not like-for-like.
set_seed()
training_args = TrainingArguments(
    model_name=bacp_trainer.model_name,
    model_type=bacp_trainer.model_type,
    dataset_name=bacp_trainer.dataset_name,
    num_classes=bacp_trainer.num_classes,
    batch_size=bacp_trainer.batch_size,
    optimizer_type=bacp_trainer.optimizer_type,
    learning_rate=0.01,
    epochs=100,
    trained_weights=bacp_trainer.save_path,
    experiment_type='bacp_finetune_wo_dyrelu',

    pruning_module=bacp_trainer.get_pruner(),
    pruning_type=bacp_trainer.pruning_type,
    target_sparsity=bacp_trainer.target_sparsity,
    sparsity_scheduler=bacp_trainer.sparsity_scheduler,

    dyrelu_enabled=False,
)
trainer = Trainer(training_args)
trainer.train()
trainer.evaluate()


In [0]:
import os

from bacp import BaCPTrainingArguments, BaCPTrainer
from trainer import Trainer, TrainingArguments
from utils import set_seed

# --- Stage 1: BaCP pretraining with magnitude pruning, DyReLU enabled ----------------
set_seed()
bacp_args = BaCPTrainingArguments(
    model_name='resnet34',
    model_type='cv',
    dataset_name='cifar10',
    num_classes=10,
    batch_size=512,
    optimizer_type='sgd',
    learning_rate=0.1,
    tau=0.10,
    trained_weights='/dbfs/research/bacp/resnet34/cifar10/resnet34_cifar10_.pt',
    experiment_type='bacp_pretrain_w_dyrelu',
    enable_tqdm=True,

    pruning_type='magnitude_pruning',
    target_sparsity=0.9999,
    sparsity_scheduler='cubic',
    recovery_epochs=10,
    retrain=True,

    dyrelu_enabled=True,
)
bacp_trainer = BaCPTrainer(bacp_args)

# Pretraining is skipped by default. In that case stage 2 reuses the checkpoint an earlier
# run left at bacp_trainer.save_path, so fail fast with a clear message if it is missing,
# rather than erroring (or silently falling back) somewhere inside Trainer.
RUN_BACP_PRETRAINING = False
if RUN_BACP_PRETRAINING:
    bacp_trainer.train()
elif not (bacp_trainer.save_path and os.path.exists(bacp_trainer.save_path)):
    raise FileNotFoundError(
        f"BaCP checkpoint not found at {bacp_trainer.save_path!r}; "
        "set RUN_BACP_PRETRAINING = True to create it."
    )

# --- Stage 2: fine-tune from the BaCP checkpoint, reusing its pruner ----------------
set_seed()
training_args = TrainingArguments(
    model_name=bacp_trainer.model_name,
    model_type=bacp_trainer.model_type,
    dataset_name=bacp_trainer.dataset_name,
    num_classes=bacp_trainer.num_classes,
    batch_size=bacp_trainer.batch_size,
    optimizer_type=bacp_trainer.optimizer_type,
    learning_rate=0.01,
    epochs=100,
    trained_weights=bacp_trainer.save_path,
    experiment_type='bacp_finetune_w_dyrelu',

    pruning_module=bacp_trainer.get_pruner(),

    dyrelu_enabled=True,
)
trainer = Trainer(training_args)
trainer.train()
trainer.evaluate()

In [0]:
import os

from bacp import BaCPTrainingArguments, BaCPTrainer
from trainer import Trainer, TrainingArguments
from utils import set_seed

# --- Stage 1: BaCP pretraining with magnitude pruning and DyReLU phasing -------------
set_seed()
bacp_args = BaCPTrainingArguments(
    model_name='resnet34',
    model_type='cv',
    dataset_name='cifar10',
    num_classes=10,
    batch_size=512,
    optimizer_type='sgd',
    learning_rate=0.1,
    tau=0.10,
    trained_weights='/dbfs/research/bacp/resnet34/cifar10/resnet34_cifar10_.pt',
    experiment_type='bacp_pretrain_w_dyrelu_phasing',
    enable_tqdm=True,

    pruning_type='magnitude_pruning',
    target_sparsity=0.9999,
    sparsity_scheduler='cubic',
    recovery_epochs=10,
    retrain=True,
    dyrelu_phasing_en=True,
)
bacp_trainer = BaCPTrainer(bacp_args)

# Pretraining is skipped by default. In that case stage 2 reuses the checkpoint an earlier
# run left at bacp_trainer.save_path, so fail fast with a clear message if it is missing.
RUN_BACP_PRETRAINING = False
if RUN_BACP_PRETRAINING:
    bacp_trainer.train()
elif not (bacp_trainer.save_path and os.path.exists(bacp_trainer.save_path)):
    raise FileNotFoundError(
        f"BaCP checkpoint not found at {bacp_trainer.save_path!r}; "
        "set RUN_BACP_PRETRAINING = True to create it."
    )

# --- Stage 2: fine-tune (lr 0.005) from the BaCP checkpoint, reusing its pruner -----
set_seed()
training_args = TrainingArguments(
    model_name=bacp_trainer.model_name,
    model_type=bacp_trainer.model_type,
    dataset_name=bacp_trainer.dataset_name,
    num_classes=bacp_trainer.num_classes,
    batch_size=bacp_trainer.batch_size,
    optimizer_type=bacp_trainer.optimizer_type,
    learning_rate=0.005,
    epochs=100,
    trained_weights=bacp_trainer.save_path,
    experiment_type='bacp_finetune_w_dyrelu_phasing',

    pruning_module=bacp_trainer.get_pruner(),
    # NOTE(review): the other fine-tune cells spell this flag `dyrelu_enabled=`. Confirm that
    # `dyrelu_en` is an accepted TrainingArguments field and not a typo.
    dyrelu_en=True,

)
trainer = Trainer(training_args)
trainer.train()
trainer.evaluate()

In [0]:

# Evaluation-only pass for the 'bacp_finetune_w_dyrelu_phasing' configuration.
# Hidden dependency: `bacp_trainer` must already exist from the previous cell. The model and
# data settings, checkpoint path and pruner are all read from it.
set_seed()
training_args = TrainingArguments(
    model_name=bacp_trainer.model_name,
    model_type=bacp_trainer.model_type,
    dataset_name=bacp_trainer.dataset_name,
    num_classes=bacp_trainer.num_classes,
    batch_size=bacp_trainer.batch_size,
    optimizer_type=bacp_trainer.optimizer_type,
    learning_rate=0.005,
    epochs=100,
    trained_weights=bacp_trainer.save_path,
    experiment_type='bacp_finetune_w_dyrelu_phasing',
    pruning_module=bacp_trainer.get_pruner(),
)
trainer = Trainer(training_args)

RUN_TRAINING = False  # flip to True to fine-tune before evaluating
if RUN_TRAINING:
    trainer.train()
trainer.evaluate()

In [0]:
from bacp import BaCPTrainingArguments, BaCPTrainer
from trainer import Trainer, TrainingArguments
from utils import set_seed

# BaCP v3 on ResNet-34 / CIFAR-10: magnitude pruning to 99.99% sparsity, then an AdamW
# fine-tuning stage, with DyReLU -> ReLU activation phasing. The arguments are built from
# named groups so that each concern can be read and edited on its own.
base_config = dict(
    model_name='resnet34',
    model_type='cv',
    dataset_name='cifar10',
    num_classes=10,
    batch_size=512,
    optimizer_type='sgd',
    learning_rate=0.1,
    tau=0.07,
    trained_weights='/dbfs/research/bacp/resnet34/cifar10/resnet34_cifar10_baseline.pt',
    experiment_type='bacp_v3_with_dyrelu_phasing',
    enable_tqdm=True,
)
pruning_config = dict(
    pruning_type='magnitude_pruning',
    target_sparsity=0.9999,
    sparsity_scheduler='cubic',
    recovery_epochs=10,
    retrain=True,
)
# Settings for the fine-tuning stage that runs after BaCP training.
finetune_config = dict(
    enable_finetune=True,
    ft_epochs=100,
    ft_optimizer_type='adamw',
    ft_learning_rate=0.001,
)

set_seed()
bacp_args = BaCPTrainingArguments(
    **base_config,
    **pruning_config,
    **finetune_config,
    dyrelu_phasing_en=True,  # DyReLU adapters, phased out to plain ReLU
)
bacp_trainer = BaCPTrainer(bacp_args)
bacp_trainer.finetune()

In [0]:
# Display the architecture of the BaCP v3 model built in the previous cell (requires that cell to have run).
bacp_trainer.model

In [0]:
# Inspect how sparsity is distributed across the BaCP model's parameters.
# NOTE(review): a later cell imports these same helpers from `unstructured_pruning`; confirm
# which module is current. `check_model_sparsity` is imported but not used here.
from pruning_factory import check_model_sparsity, check_sparsity_distribution
check_sparsity_distribution(bacp_trainer.model)

## Baseline Accuracy

In [0]:
# ResNet-34 / CIFAR-10 baseline via the project script. It runs as a separate Python process,
# so it shares no state with this kernel.
!python ../scripts/baseline_script.py \
    --model_name resnet34  --model_type cv \
    --dataset_name cifar10 --num_classes 10

In [0]:
from trainer import Trainer, TrainingArguments
from utils import set_seed

# ResNet-50 / CIFAR-10: magnitude pruning to 99.95% sparsity on a cubic schedule, with
# DyReLU phasing. Seeded like the other experiment cells so the run is reproducible.
# NOTE(review): unlike the ResNet-34 pruning cells, no `trained_weights` is passed. Confirm
# whether Trainer picks up a default baseline checkpoint.
set_seed()
args = TrainingArguments(
    model_name='resnet50',
    model_type='cv',
    dataset_name='cifar10',
    num_classes=10,
    batch_size=512,
    optimizer_type='sgd',
    learning_rate=0.01,

    pruning_type='magnitude_pruning',
    target_sparsity=0.9995,
    sparsity_scheduler='cubic',
    recovery_epochs=10,
    retrain=True,

    # NOTE(review): other cells spell DyReLU flags `dyrelu_enabled` / `dyrelu_phasing_en`.
    # Confirm that `dyrelu_phase_enabled` is still an accepted TrainingArguments field.
    dyrelu_phase_enabled=True,
)
trainer = Trainer(args)
trainer.train()

In [0]:
# Evaluate the ResNet-50 trainer built and trained in the previous cell.
trainer.evaluate()

## Pruning Accuracies

In [0]:
# Magnitude-prune the ResNet-50 CIFAR-10 baseline checkpoint to 99.95% sparsity (lr 0.01),
# without DyReLU phasing. The next cell is the same run with --dyrelu_phase_enabled.
!python ../scripts/pruning_script.py \
    --model_name resnet50  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --pruning_type magnitude_pruning --target_sparsity 0.9995 \
    --trained_weights /dbfs/research/resnet50/cifar10/resnet50_cifar10_baseline.pt \
    --learning_rate 0.01

In [0]:
# Magnitude-prune the ResNet-50 CIFAR-10 baseline checkpoint to 99.95% sparsity (lr 0.01),
# with DyReLU phasing enabled via --dyrelu_phase_enabled.
!python ../scripts/pruning_script.py \
    --model_name resnet50  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --pruning_type magnitude_pruning --target_sparsity 0.9995 \
    --trained_weights /dbfs/research/resnet50/cifar10/resnet50_cifar10_baseline.pt \
    --learning_rate 0.01 --dyrelu_phase_enabled

In [0]:
from models import ClassificationAndEncoderNetwork
from utils import load_weights

# Rebuild a 10-class ResNet-50 network and load the checkpoint whose filename indicates the
# 0.95 magnitude-pruning run. The global `model` is reused by the following inspection cells.
model = ClassificationAndEncoderNetwork('resnet50', 10)
weights = '/dbfs/research/resnet50/cifar10/resnet50_cifar10_magnitude_pruning_0.95_pruning.pt'
load_weights(model, weights)

In [0]:
# Layer summary of the ResNet-50 loaded above (one 3x32x32 input; torchinfo counts the batch
# dimension). `summary` is imported here as well, so this cell does not depend on the
# earlier ResNet-34 summary cell having been run first.
from torchinfo import summary

summary(model, input_size=(1, 3, 32, 32))

In [0]:
# Inspect how sparsity is distributed across the loaded ResNet-50's parameters.
# NOTE(review): an earlier cell imports these helpers from `pruning_factory`; confirm which
# module is current. `check_model_sparsity` is imported but not used here.
from unstructured_pruning import check_model_sparsity, check_sparsity_distribution

check_sparsity_distribution(model)

In [0]:
!python ../scripts/pruning_script.py \
    --model_name resnet50  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --pruning_type magnitude_pruning --target_sparsity 0.95 \
    --trained_weights /dbfs/research/resnet50/cifar10/resnet50_cifar10_baseline.pt \
    --learning_rate 0.01

!python ../scripts/pruning_script.py \
    --model_name resnet50  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --pruning_type magnitude_pruning --target_sparsity 0.97 \
    --trained_weights /dbfs/research/resnet50/cifar10/resnet50_cifar10_baseline.pt \
    --learning_rate 0.01

!python ../scripts/pruning_script.py \
    --model_name resnet50  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --pruning_type magnitude_pruning --target_sparsity 0.99 \
    --trained_weights /dbfs/research/resnet50/cifar10/resnet50_cifar10_baseline.pt \
    --learning_rate 0.01

In [0]:
!python ../scripts/pruning_script.py \
    --model_name resnet50  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --pruning_type snip_pruning --target_sparsity 0.95 \
    --trained_weights /dbfs/research/resnet50/cifar10/resnet50_cifar10_baseline.pt \
    --learning_rate 0.01

!python ../scripts/pruning_script.py \
    --model_name resnet50  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --pruning_type snip_pruning --target_sparsity 0.97 \
    --trained_weights /dbfs/research/resnet50/cifar10/resnet50_cifar10_baseline.pt \
    --learning_rate 0.01
    
!python ../scripts/pruning_script.py \
    --model_name resnet50  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --pruning_type snip_pruning --target_sparsity 0.99 \
    --trained_weights /dbfs/research/resnet50/cifar10/resnet50_cifar10_baseline.pt \
    --learning_rate 0.01

In [0]:
!python ../scripts/pruning_script.py \
    --model_name resnet50  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --pruning_type wanda_pruning --target_sparsity 0.95 \
    --trained_weights /dbfs/research/resnet50/cifar10/resnet50_cifar10_baseline.pt \
    --learning_rate 0.01

!python ../scripts/pruning_script.py \
    --model_name resnet50  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --pruning_type wanda_pruning --target_sparsity 0.97 \
    --trained_weights /dbfs/research/resnet50/cifar10/resnet50_cifar10_baseline.pt \
    --learning_rate 0.01

!python ../scripts/pruning_script.py \
    --model_name resnet50  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --pruning_type wanda_pruning --target_sparsity 0.99 \
    --trained_weights /dbfs/research/resnet50/cifar10/resnet50_cifar10_baseline.pt \
    --learning_rate 0.01

## BaCP

In [0]:
# BaCP with magnitude pruning to 95% sparsity on the ResNet-50 CIFAR-10 baseline checkpoint.
# NOTE(review): the BaCP magnitude sweep cell further down repeats this exact command as its
# 0.95 run. If both are executed, the second run may overwrite this one's outputs.
!python ../scripts/bacp_script.py \
    --model_name resnet50  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --pruning_type magnitude_pruning --target_sparsity 0.95 \
    --trained_weights /dbfs/research/resnet50/cifar10/resnet50_cifar10_baseline.pt

In [0]:
from trainer import TrainingArguments, Trainer
from unstructured_pruning import MagnitudePrune, check_model_sparsity
from models import ClassificationAndEncoderNetwork
from utils import load_weights, set_seed

# Manually fine-tune the ResNet-50 BaCP-pretrained checkpoint (magnitude pruning, 0.95),
# handing the Trainer a pruner whose masks mirror that checkpoint's zero pattern.
# Seeded like the other experiment cells so the run is reproducible.
set_seed()
bacp_pretrained_weights = '/dbfs/research/resnet50/cifar10/resnet50_cifar10_magnitude_pruning_0.95_bacp_pretraining.pt'

model = ClassificationAndEncoderNetwork('resnet50', 10)
pruning_module = MagnitudePrune(model, 5, 0.95)

load_weights(model, bacp_pretrained_weights)
print(check_model_sparsity(model))
# Despite the name, these are keep-masks: 1.0 where a loaded parameter is non-zero and 0.0
# where it is exactly zero.
# NOTE(review): the masks cover every named parameter (including biases and norm weights),
# so any such parameter that happens to be 0 is presumably masked too -- confirm intended.
zero_masks = {name: (param != 0).float() for name, param in model.named_parameters()}
pruning_module.masks = zero_masks

# NOTE(review): `pruning_module` wraps this cell's `model`, while Trainer is given the same
# checkpoint path via `trained_weights`. Confirm that the Trainer applies these masks to the
# model it actually trains.
training_args = TrainingArguments(
    model_name='resnet50',
    model_type='cv',
    dataset_name='cifar10',
    num_classes=10,
    batch_size=512,
    epochs=50,
    optimizer_type='sgd',
    learning_rate=0.005,
    experiment_type='testing_bacp_finetuning',
    pruning_module=pruning_module,
    trained_weights=bacp_pretrained_weights,
)

trainer = Trainer(training_args)
trainer.train()
trainer.evaluate()

In [0]:
!python ../scripts/bacp_script.py \
    --model_name resnet50  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --pruning_type magnitude_pruning --target_sparsity 0.95 \
    --trained_weights /dbfs/research/resnet50/cifar10/resnet50_cifar10_baseline.pt

!python ../scripts/bacp_script.py \
    --model_name resnet50  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --pruning_type magnitude_pruning --target_sparsity 0.97 \
    --trained_weights /dbfs/research/resnet50/cifar10/resnet50_cifar10_baseline.pt

!python ../scripts/bacp_script.py \
    --model_name resnet50  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --pruning_type magnitude_pruning --target_sparsity 0.99 \
    --trained_weights /dbfs/research/resnet50/cifar10/resnet50_cifar10_baseline.pt

In [0]:
!python ../scripts/bacp_script.py \
    --model_name resnet50  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --pruning_type snip_pruning --target_sparsity 0.95 \
    --trained_weights /dbfs/research/resnet50/cifar10/resnet50_cifar10_baseline.pt
 
!python ../scripts/bacp_script.py \
    --model_name resnet50  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --pruning_type snip_pruning --target_sparsity 0.97 \
    --trained_weights /dbfs/research/resnet50/cifar10/resnet50_cifar10_baseline.pt 

!python ../scripts/bacp_script.py \
    --model_name resnet50  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --pruning_type snip_pruning --target_sparsity 0.99 \
    --trained_weights /dbfs/research/resnet50/cifar10/resnet50_cifar10_baseline.pt

In [0]:
!python ../scripts/bacp_script.py \
    --model_name resnet50  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --pruning_type wanda_pruning --target_sparsity 0.95 \
    --trained_weights /dbfs/research/resnet50/cifar10/resnet50_cifar10_baseline.pt

!python ../scripts/bacp_script.py \
    --model_name resnet50  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --pruning_type wanda_pruning --target_sparsity 0.97 \
    --trained_weights /dbfs/research/resnet50/cifar10/resnet50_cifar10_baseline.pt

!python ../scripts/bacp_script.py \
    --model_name resnet50  --model_type cv \
    --dataset_name cifar10 --num_classes 10 \
    --pruning_type wanda_pruning --target_sparsity 0.99 \
    --trained_weights /dbfs/research/resnet50/cifar10/resnet50_cifar10_baseline.pt