In [0]:
# Automatically reload edited project modules (trainer, bacp, utils, ...) before running code,
# so changes to the .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Enables autoreload; learn more at https://docs.databricks.com/en/files/workspace-modules.html#autoreload-for-python-modules
# To disable autoreload, run %autoreload 0

In [0]:
import os
import sys

# Hugging Face environment variables must be set BEFORE anything imports `datasets`:
# `datasets` reads HF_DATASETS_CACHE into its module-level config the first time the
# package is imported, so assigning it after `from datasets... import` (as this cell
# used to do) leaves the cache at its default location.
# NOTE(review): if `datasets` was already imported earlier in this kernel session,
# restart the kernel for the cache location to take effect.
os.environ["HF_DATASETS_CACHE"] = "/dbfs/hf_datasets"
# Disable Hugging Face tokenizers' internal parallelism.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Make the project root (one directory up) importable for the local modules below.
sys.path.append(os.path.abspath('..'))

import torch
import torch.nn as nn
import torch.optim as optim
from datasets.utils.logging import disable_progress_bar
disable_progress_bar()

from trainer import Trainer, TrainingArguments
from bacp import BaCPTrainer, BaCPTrainingArguments
# NOTE(review): star imports; names used below such as get_device, BATCH_SIZE and
# TARGET_SPARSITY_* presumably come from these two modules.
from utils import *
from constants import *

from ablation_modules import TemperatureSweep, LearningRateSweep, BaCPLearningRateSweep, BaCPDataViewSweep

device = get_device()
print(f"{device = }")


In [0]:
# Learning-rate sweep for magnitude pruning of VGG-11 on CIFAR-10: low target sparsity,
# cubic sparsity schedule, starting from the `_baseline.pt` checkpoint.
model_name, model_task = 'vgg11', 'cifar10'
finetuned_weights = f"/dbfs/research/{model_name}/{model_task}/{model_name}_{model_task}_baseline.pt"

sweep_config = {
    # model / data / starting checkpoint
    'model_name': model_name,
    'model_task': model_task,
    'batch_size': BATCH_SIZE,
    'finetuned_weights': finetuned_weights,
    # pruning setup
    'pruning_type': 'magnitude_pruning',
    'target_sparsity': TARGET_SPARSITY_LOW,
    'sparsity_scheduler': 'cubic',
}
lr_sweeper = LearningRateSweep(**sweep_config)
lr_sweeper.sweep()
# Dump whatever the sweeper recorded during sweep().
print(lr_sweeper.history)

In [0]:
# Short BaCP learning-rate sweep: VGG-11 / CIFAR-10, magnitude pruning to the low target
# sparsity (cubic schedule), with every phase limited to a single epoch.
model_name = 'vgg11'
model_task = 'cifar10'
finetuned_weights = f"/dbfs/research/{model_name}/{model_task}/{model_name}_{model_task}_baseline.pt"

lr_sweeper = BaCPLearningRateSweep(
    model_name=model_name,
    model_task=model_task,
    batch_size=BATCH_SIZE,
    finetuned_weights=finetuned_weights,
    pruning_type='magnitude_pruning',
    target_sparsity=TARGET_SPARSITY_LOW,
    sparsity_scheduler='cubic',
    epochs=1,
    # NOTE(review): this was `retraining_epoch=1`. The 10-epoch run of the same sweeper
    # in this notebook passes `recovery_epochs`, so the old keyword was presumably a
    # stale/typo name that either raises TypeError or is silently ignored (leaving that
    # phase at its default length). Confirm against BaCPLearningRateSweep's signature.
    recovery_epochs=1,
    finetune_epochs=1,
)
lr_sweeper.sweep()
print(lr_sweeper.history)

In [0]:
# BaCP learning-rate sweep: VGG-11 / CIFAR-10, magnitude pruning to the low target
# sparsity (cubic schedule); 1 main epoch, 10 recovery epochs, 10 fine-tune epochs.
model_name, model_task = 'vgg11', 'cifar10'
finetuned_weights = f"/dbfs/research/{model_name}/{model_task}/{model_name}_{model_task}_baseline.pt"

bacp_sweep_config = {
    # model / data / starting checkpoint
    'model_name': model_name,
    'model_task': model_task,
    'batch_size': BATCH_SIZE,
    'finetuned_weights': finetuned_weights,
    # pruning setup
    'pruning_type': 'magnitude_pruning',
    'target_sparsity': TARGET_SPARSITY_LOW,
    'sparsity_scheduler': 'cubic',
    # per-phase epoch budgets
    'epochs': 1,
    'recovery_epochs': 10,
    'finetune_epochs': 10,
}
lr_sweeper = BaCPLearningRateSweep(**bacp_sweep_config)
lr_sweeper.sweep()
print(lr_sweeper.history)

In [0]:
# BaCP data-view sweep: VGG-11 / CIFAR-10, magnitude pruning to the low target sparsity
# (cubic schedule) with fixed (optimizer, learning-rate) pairs for both phases.
model_name, model_task = 'vgg11', 'cifar10'
finetuned_weights = f"/dbfs/research/{model_name}/{model_task}/{model_name}_{model_task}_baseline.pt"

dataview_config = {
    'model_name': model_name,
    'model_task': model_task,
    'batch_size': BATCH_SIZE,
    # (optimizer type, learning rate) pairs
    'opt_type_and_lr': ('sgd', 0.1),
    'finetune_opt_type_and_lr': ('adamw', 0.0001),
    'finetuned_weights': finetuned_weights,
    # pruning setup
    'pruning_type': 'magnitude_pruning',
    'target_sparsity': TARGET_SPARSITY_LOW,
    'sparsity_scheduler': 'cubic',
    # epoch budgets
    'epochs': 1,
    'finetune_epochs': 10,
}
dv_sweeper = BaCPDataViewSweep(**dataview_config)
dv_sweeper.sweep()
print(dv_sweeper.history)

In [0]:
# BaCP pruning followed by masked fine-tuning.
# VGG-11 / CIFAR-10, magnitude pruning to TARGET_SPARSITY_LOW (cubic schedule), starting
# from the `_baseline.pt` checkpoint; fine-tuning uses AdamW @ 1e-4 for 50 epochs.
MODEL_NAME, MODEL_TASK = 'vgg11', 'cifar10'
finetuned_weights = f"/dbfs/research/{MODEL_NAME}/{MODEL_TASK}/{MODEL_NAME}_{MODEL_TASK}_baseline.pt"

# Phase switches: set to False to skip a phase.
RUN_BACP_PRUNING = True
RUN_FINETUNE = True

# ---- Phase 1: BaCP pruning (SGD @ 0.1) ---------------------------------------
model_identity = {'model_name': MODEL_NAME, 'model_task': MODEL_TASK, 'batch_size': BATCH_SIZE}
bacp_training_args = BaCPTrainingArguments(
    **model_identity,
    optimizer_type='sgd',
    learning_rate=0.1,
    pruning_type='magnitude_pruning',
    target_sparsity=TARGET_SPARSITY_LOW,
    sparsity_scheduler='cubic',
    finetuned_weights=finetuned_weights,
    learning_type='bacp_pruning',
)
bacp_trainer = BaCPTrainer(bacp_training_args=bacp_training_args)
if RUN_BACP_PRUNING:
    bacp_trainer.train()

# ---- Phase 2: fine-tune from BaCP's saved weights, reusing its pruner --------
bacp_trainer.generate_mask_from_model()
finetune_identity = {
    'model_name': bacp_trainer.model_name,
    'model_task': bacp_trainer.model_task,
    'batch_size': bacp_trainer.batch_size,
}
training_args = TrainingArguments(
    **finetune_identity,
    optimizer_type='adamw',
    learning_rate=0.0001,
    pruner=bacp_trainer.get_pruner(),
    pruning_type=bacp_trainer.pruning_type,
    target_sparsity=bacp_trainer.target_sparsity,
    epochs=50,
    finetuned_weights=bacp_trainer.save_path,
    finetune=True,
    learning_type="bacp_finetune",
)
trainer = Trainer(training_args)
if RUN_FINETUNE:
    trainer.train()

metrics = trainer.evaluate()
print(f"\n{metrics}")


In [0]:
# BaCP pruning followed by masked fine-tuning.
# VGG-11 / CIFAR-10, magnitude pruning to TARGET_SPARSITY_HIGH (cubic schedule), starting
# from the `_baseline.pt` checkpoint; fine-tuning uses AdamW @ 1e-4 for 50 epochs.
MODEL_NAME, MODEL_TASK = 'vgg11', 'cifar10'
finetuned_weights = f"/dbfs/research/{MODEL_NAME}/{MODEL_TASK}/{MODEL_NAME}_{MODEL_TASK}_baseline.pt"

# Phase switches: set to False to skip a phase.
RUN_BACP_PRUNING = True
RUN_FINETUNE = True

# ---- Phase 1: BaCP pruning (SGD @ 0.1) ---------------------------------------
model_identity = {'model_name': MODEL_NAME, 'model_task': MODEL_TASK, 'batch_size': BATCH_SIZE}
bacp_training_args = BaCPTrainingArguments(
    **model_identity,
    optimizer_type='sgd',
    learning_rate=0.1,
    pruning_type='magnitude_pruning',
    target_sparsity=TARGET_SPARSITY_HIGH,
    sparsity_scheduler='cubic',
    finetuned_weights=finetuned_weights,
    learning_type='bacp_pruning',
)
bacp_trainer = BaCPTrainer(bacp_training_args=bacp_training_args)
if RUN_BACP_PRUNING:
    bacp_trainer.train()

# ---- Phase 2: fine-tune from BaCP's saved weights, reusing its pruner --------
bacp_trainer.generate_mask_from_model()
finetune_identity = {
    'model_name': bacp_trainer.model_name,
    'model_task': bacp_trainer.model_task,
    'batch_size': bacp_trainer.batch_size,
}
training_args = TrainingArguments(
    **finetune_identity,
    optimizer_type='adamw',
    learning_rate=0.0001,
    pruner=bacp_trainer.get_pruner(),
    pruning_type=bacp_trainer.pruning_type,
    target_sparsity=bacp_trainer.target_sparsity,
    epochs=50,
    finetuned_weights=bacp_trainer.save_path,
    finetune=True,
    learning_type="bacp_finetune",
)
trainer = Trainer(training_args)
if RUN_FINETUNE:
    trainer.train()

metrics = trainer.evaluate()
print(f"\n{metrics}")


In [0]:
# Contrastive baseline for VGG-11 / CIFAR-10, in two stages:
#   1. training with the contrastive criterion (currently switched off), then
#   2. fine-tuning from the saved contrastive checkpoint.
MODEL_NAME, MODEL_TASK = 'vgg11', 'cifar10'

# Stage switches. Stage 1 is off, presumably because its checkpoint (loaded in
# stage 2) already exists; note its Trainer is still constructed either way.
RUN_CONTRASTIVE_PRETRAIN = False
RUN_CONTRASTIVE_FINETUNE = True

# Options shared by both stages.
shared_options = {
    'model_name': MODEL_NAME,
    'model_task': MODEL_TASK,
    'batch_size': BATCH_SIZE,
    'optimizer_type': 'sgd',
    'learning_rate': 0.01,
}

# ---- Stage 1: contrastive training -------------------------------------------
training_args = TrainingArguments(
    **shared_options,
    epochs=50,
    learning_type='contrastive_baseline',
    criterion_type='contrastive',
)
trainer = Trainer(training_args=training_args)
if RUN_CONTRASTIVE_PRETRAIN:
    trainer.train()

# ---- Stage 2: fine-tune from the contrastive checkpoint ------------------------
training_args = TrainingArguments(
    **shared_options,
    finetuned_weights='/dbfs/research/vgg11/cifar10/vgg11_cifar10_contrastive_baseline.pt',
    epochs=50,
    learning_type='contrastive_baseline_finetune',
    finetune=True,
)
trainer = Trainer(training_args=training_args)
if RUN_CONTRASTIVE_FINETUNE:
    trainer.train()

In [0]:
# BaCP pruning followed by masked fine-tuning, starting from the contrastive baseline.
# VGG-11 / CIFAR-10, magnitude pruning to TARGET_SPARSITY_LOW (cubic schedule);
# fine-tuning uses SGD @ 0.01 for 50 epochs.
MODEL_NAME, MODEL_TASK = 'vgg11', 'cifar10'
finetuned_weights = '/dbfs/research/vgg11/cifar10/vgg11_cifar10_contrastive_baseline.pt'

# Phase switches: set to False to skip a phase.
RUN_BACP_PRUNING = True
RUN_FINETUNE = True

# ---- Phase 1: BaCP pruning (SGD @ 0.1) ---------------------------------------
model_identity = {'model_name': MODEL_NAME, 'model_task': MODEL_TASK, 'batch_size': BATCH_SIZE}
bacp_training_args = BaCPTrainingArguments(
    **model_identity,
    optimizer_type='sgd',
    learning_rate=0.1,
    pruning_type='magnitude_pruning',
    target_sparsity=TARGET_SPARSITY_LOW,
    sparsity_scheduler='cubic',
    finetuned_weights=finetuned_weights,
    learning_type='bacp_pruning',
)
bacp_trainer = BaCPTrainer(bacp_training_args=bacp_training_args)
if RUN_BACP_PRUNING:
    bacp_trainer.train()

# ---- Phase 2: fine-tune from BaCP's saved weights, reusing its pruner --------
bacp_trainer.generate_mask_from_model()
finetune_identity = {
    'model_name': bacp_trainer.model_name,
    'model_task': bacp_trainer.model_task,
    'batch_size': bacp_trainer.batch_size,
}
training_args = TrainingArguments(
    **finetune_identity,
    optimizer_type='sgd',
    learning_rate=0.01,
    pruner=bacp_trainer.get_pruner(),
    pruning_type=bacp_trainer.pruning_type,
    target_sparsity=bacp_trainer.target_sparsity,
    epochs=50,
    finetuned_weights=bacp_trainer.save_path,
    finetune=True,
    learning_type="bacp_finetune",
)
trainer = Trainer(training_args)
if RUN_FINETUNE:
    trainer.train()

metrics = trainer.evaluate()
print(f"\n{metrics}")


In [0]:
# BaCP pruning followed by masked fine-tuning, starting from the contrastive baseline.
# VGG-11 / CIFAR-10, magnitude pruning to TARGET_SPARSITY_MID (cubic schedule);
# fine-tuning uses SGD @ 0.01 for 50 epochs.
MODEL_NAME, MODEL_TASK = 'vgg11', 'cifar10'
finetuned_weights = '/dbfs/research/vgg11/cifar10/vgg11_cifar10_contrastive_baseline.pt'

# Phase switches: set to False to skip a phase.
RUN_BACP_PRUNING = True
RUN_FINETUNE = True

# ---- Phase 1: BaCP pruning (SGD @ 0.1) ---------------------------------------
model_identity = {'model_name': MODEL_NAME, 'model_task': MODEL_TASK, 'batch_size': BATCH_SIZE}
bacp_training_args = BaCPTrainingArguments(
    **model_identity,
    optimizer_type='sgd',
    learning_rate=0.1,
    pruning_type='magnitude_pruning',
    target_sparsity=TARGET_SPARSITY_MID,
    sparsity_scheduler='cubic',
    finetuned_weights=finetuned_weights,
    learning_type='bacp_pruning',
)
bacp_trainer = BaCPTrainer(bacp_training_args=bacp_training_args)
if RUN_BACP_PRUNING:
    bacp_trainer.train()

# ---- Phase 2: fine-tune from BaCP's saved weights, reusing its pruner --------
bacp_trainer.generate_mask_from_model()
finetune_identity = {
    'model_name': bacp_trainer.model_name,
    'model_task': bacp_trainer.model_task,
    'batch_size': bacp_trainer.batch_size,
}
training_args = TrainingArguments(
    **finetune_identity,
    optimizer_type='sgd',
    learning_rate=0.01,
    pruner=bacp_trainer.get_pruner(),
    pruning_type=bacp_trainer.pruning_type,
    target_sparsity=bacp_trainer.target_sparsity,
    epochs=50,
    finetuned_weights=bacp_trainer.save_path,
    finetune=True,
    learning_type="bacp_finetune",
)
trainer = Trainer(training_args)
if RUN_FINETUNE:
    trainer.train()

metrics = trainer.evaluate()
print(f"\n{metrics}")


In [0]:
# BaCP pruning followed by masked fine-tuning, starting from the contrastive baseline.
# VGG-11 / CIFAR-10, magnitude pruning to TARGET_SPARSITY_HIGH (cubic schedule);
# fine-tuning uses SGD @ 0.01 for 50 epochs.
MODEL_NAME, MODEL_TASK = 'vgg11', 'cifar10'
finetuned_weights = '/dbfs/research/vgg11/cifar10/vgg11_cifar10_contrastive_baseline.pt'

# Phase switches: set to False to skip a phase.
RUN_BACP_PRUNING = True
RUN_FINETUNE = True

# ---- Phase 1: BaCP pruning (SGD @ 0.1) ---------------------------------------
model_identity = {'model_name': MODEL_NAME, 'model_task': MODEL_TASK, 'batch_size': BATCH_SIZE}
bacp_training_args = BaCPTrainingArguments(
    **model_identity,
    optimizer_type='sgd',
    learning_rate=0.1,
    pruning_type='magnitude_pruning',
    target_sparsity=TARGET_SPARSITY_HIGH,
    sparsity_scheduler='cubic',
    finetuned_weights=finetuned_weights,
    learning_type='bacp_pruning',
)
bacp_trainer = BaCPTrainer(bacp_training_args=bacp_training_args)
if RUN_BACP_PRUNING:
    bacp_trainer.train()

# ---- Phase 2: fine-tune from BaCP's saved weights, reusing its pruner --------
bacp_trainer.generate_mask_from_model()
finetune_identity = {
    'model_name': bacp_trainer.model_name,
    'model_task': bacp_trainer.model_task,
    'batch_size': bacp_trainer.batch_size,
}
training_args = TrainingArguments(
    **finetune_identity,
    optimizer_type='sgd',
    learning_rate=0.01,
    pruner=bacp_trainer.get_pruner(),
    pruning_type=bacp_trainer.pruning_type,
    target_sparsity=bacp_trainer.target_sparsity,
    epochs=50,
    finetuned_weights=bacp_trainer.save_path,
    finetune=True,
    learning_type="bacp_finetune",
)
trainer = Trainer(training_args)
if RUN_FINETUNE:
    trainer.train()

metrics = trainer.evaluate()
print(f"\n{metrics}")


In [0]:
# BaCP pruning followed by masked fine-tuning, starting from the contrastive baseline.
# VGG-11 / CIFAR-10, movement pruning to TARGET_SPARSITY_LOW (cubic schedule);
# fine-tuning uses SGD @ 0.01 for 50 epochs.
MODEL_NAME = 'vgg11'
MODEL_TASK = 'cifar10'
finetuned_weights = '/dbfs/research/vgg11/cifar10/vgg11_cifar10_contrastive_baseline.pt'

# NOTE(review): this cell previously had the pruning phase disabled (`if False:`) and
# fine-tuned with AdamW at learning_rate=0.00, i.e. 50 epochs in which no weight can
# change. Both settings now match the sibling contrastive-baseline cells (pruning on,
# SGD @ 0.01 fine-tuning). If you set RUN_BACP_PRUNING = False to reuse an earlier
# run's output, note that generate_mask_from_model() below then operates on a model
# that was not pruned in this session — confirm that is what you want.
RUN_BACP_PRUNING = True
RUN_FINETUNE = True

# ---- Phase 1: BaCP pruning (SGD @ 0.1) ---------------------------------------
bacp_training_args = BaCPTrainingArguments(
    model_name=MODEL_NAME,
    model_task=MODEL_TASK,
    batch_size=BATCH_SIZE,
    optimizer_type='sgd',
    learning_rate=0.1,
    pruning_type='movement_pruning',
    target_sparsity=TARGET_SPARSITY_LOW,
    sparsity_scheduler='cubic',
    finetuned_weights=finetuned_weights,
    learning_type='bacp_pruning'
)
bacp_trainer = BaCPTrainer(bacp_training_args=bacp_training_args)
if RUN_BACP_PRUNING:
    bacp_trainer.train()

# ---- Phase 2: fine-tune from BaCP's saved weights, reusing its pruner --------
bacp_trainer.generate_mask_from_model()
training_args = TrainingArguments(
    model_name=bacp_trainer.model_name,
    model_task=bacp_trainer.model_task,
    batch_size=bacp_trainer.batch_size,
    optimizer_type='sgd',
    learning_rate=0.01,
    pruner=bacp_trainer.get_pruner(),
    pruning_type=bacp_trainer.pruning_type,
    target_sparsity=bacp_trainer.target_sparsity,
    epochs=50,
    finetuned_weights=bacp_trainer.save_path,
    finetune=True,
    learning_type="bacp_finetune",
)
trainer = Trainer(training_args)
if RUN_FINETUNE:
    trainer.train()

metrics = trainer.evaluate()
print(f"\n{metrics}")


In [0]:
# BaCP pruning followed by masked fine-tuning, starting from the contrastive baseline.
# VGG-11 / CIFAR-10, movement pruning to TARGET_SPARSITY_MID (cubic schedule);
# fine-tuning uses SGD @ 0.01 for 50 epochs.
MODEL_NAME, MODEL_TASK = 'vgg11', 'cifar10'
finetuned_weights = '/dbfs/research/vgg11/cifar10/vgg11_cifar10_contrastive_baseline.pt'

# Phase switches: set to False to skip a phase.
RUN_BACP_PRUNING = True
RUN_FINETUNE = True

# ---- Phase 1: BaCP pruning (SGD @ 0.1) ---------------------------------------
model_identity = {'model_name': MODEL_NAME, 'model_task': MODEL_TASK, 'batch_size': BATCH_SIZE}
bacp_training_args = BaCPTrainingArguments(
    **model_identity,
    optimizer_type='sgd',
    learning_rate=0.1,
    pruning_type='movement_pruning',
    target_sparsity=TARGET_SPARSITY_MID,
    sparsity_scheduler='cubic',
    finetuned_weights=finetuned_weights,
    learning_type='bacp_pruning',
)
bacp_trainer = BaCPTrainer(bacp_training_args=bacp_training_args)
if RUN_BACP_PRUNING:
    bacp_trainer.train()

# ---- Phase 2: fine-tune from BaCP's saved weights, reusing its pruner --------
bacp_trainer.generate_mask_from_model()
finetune_identity = {
    'model_name': bacp_trainer.model_name,
    'model_task': bacp_trainer.model_task,
    'batch_size': bacp_trainer.batch_size,
}
training_args = TrainingArguments(
    **finetune_identity,
    optimizer_type='sgd',
    learning_rate=0.01,
    pruner=bacp_trainer.get_pruner(),
    pruning_type=bacp_trainer.pruning_type,
    target_sparsity=bacp_trainer.target_sparsity,
    epochs=50,
    finetuned_weights=bacp_trainer.save_path,
    finetune=True,
    learning_type="bacp_finetune",
)
trainer = Trainer(training_args)
if RUN_FINETUNE:
    trainer.train()

metrics = trainer.evaluate()
print(f"\n{metrics}")


In [0]:
# BaCP pruning followed by masked fine-tuning, starting from the contrastive baseline.
# VGG-11 / CIFAR-10, movement pruning to TARGET_SPARSITY_HIGH (cubic schedule);
# fine-tuning uses SGD @ 0.01 for 50 epochs.
MODEL_NAME, MODEL_TASK = 'vgg11', 'cifar10'
finetuned_weights = '/dbfs/research/vgg11/cifar10/vgg11_cifar10_contrastive_baseline.pt'

# Phase switches: set to False to skip a phase.
RUN_BACP_PRUNING = True
RUN_FINETUNE = True

# ---- Phase 1: BaCP pruning (SGD @ 0.1) ---------------------------------------
model_identity = {'model_name': MODEL_NAME, 'model_task': MODEL_TASK, 'batch_size': BATCH_SIZE}
bacp_training_args = BaCPTrainingArguments(
    **model_identity,
    optimizer_type='sgd',
    learning_rate=0.1,
    pruning_type='movement_pruning',
    target_sparsity=TARGET_SPARSITY_HIGH,
    sparsity_scheduler='cubic',
    finetuned_weights=finetuned_weights,
    learning_type='bacp_pruning',
)
bacp_trainer = BaCPTrainer(bacp_training_args=bacp_training_args)
if RUN_BACP_PRUNING:
    bacp_trainer.train()

# ---- Phase 2: fine-tune from BaCP's saved weights, reusing its pruner --------
bacp_trainer.generate_mask_from_model()
finetune_identity = {
    'model_name': bacp_trainer.model_name,
    'model_task': bacp_trainer.model_task,
    'batch_size': bacp_trainer.batch_size,
}
training_args = TrainingArguments(
    **finetune_identity,
    optimizer_type='sgd',
    learning_rate=0.01,
    pruner=bacp_trainer.get_pruner(),
    pruning_type=bacp_trainer.pruning_type,
    target_sparsity=bacp_trainer.target_sparsity,
    epochs=50,
    finetuned_weights=bacp_trainer.save_path,
    finetune=True,
    learning_type="bacp_finetune",
)
trainer = Trainer(training_args)
if RUN_FINETUNE:
    trainer.train()

metrics = trainer.evaluate()
print(f"\n{metrics}")
