# ViT-B/16 on CIFAR-10: Baseline, Magnitude/Movement Pruning, and BaCP Accuracies

In [0]:
%load_ext autoreload
%autoreload 2
# Automatically re-import edited project modules (e.g. bacp, CV_trainer, utils)
# before each cell executes, so changes to the .py files are picked up without
# restarting the kernel.
# Databricks notes: https://docs.databricks.com/en/files/workspace-modules.html#autoreload-for-python-modules
# To disable autoreload, run %autoreload 0

In [0]:
import sys
import os

# Environment configuration must happen BEFORE anything imports `datasets`
# (directly below, or indirectly via the project modules). `datasets.config`
# reads HF_DATASETS_CACHE once at import time, so setting it after the import
# silently leaves the default cache location in effect.
os.environ["HF_DATASETS_CACHE"] = "/dbfs/hf_datasets"
# Disable HF tokenizers' internal parallelism (it otherwise warns and turns
# itself off when the process forks, e.g. for DataLoader workers).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Make the project modules in the parent directory importable.
sys.path.append(os.path.abspath('..'))

from bacp import BaCPTrainer, BaCPTrainingArgumentsLLM, BaCPTrainingArgumentsCNN
from models import EncoderProjectionNetwork, ClassificationNetwork
from LLM_trainer import LLMTrainer, LLMTrainingArguments
from CV_trainer import *
from dataset_utils import get_glue_data
from logger import Logger

import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from tqdm import tqdm

from datasets.utils.logging import disable_progress_bar
disable_progress_bar()

# Star imports kept in their original order so name shadowing is unchanged.
from utils import *
from constants import *

device = get_device()
print(f"{device = }")

# Batch size for ViT-B/16 runs.
# NOTE(review): only the baseline cell uses this; the pruning and BaCP cells
# pass the generic BATCH_SIZE from the star imports instead. Confirm intended.
BATCH_SIZE_VIT = 512


## Baseline Accuracies

In [0]:
# Baseline (unpruned) ViT-B/16 on CIFAR-10: optional training, then evaluation.
model_name, model_task = "vitb16", "cifar10"

# Training is switched off: only trainer.evaluate() runs. Set to True to run
# the 3-epoch baseline training first.
run_training = False

training_args = CVTrainingArguments(
    model_name=model_name,
    model_task=model_task,
    batch_size=BATCH_SIZE_VIT,
    epochs=3,
    optimizer_type='adamw',
    scheduler_type='linear_with_warmup',
    learning_rate=2e-5,
    learning_type="baseline",
)
trainer = Trainer(training_args=training_args)
if run_training:
    trainer.train()

acc = trainer.evaluate()
print(f"\nAccuracy = {acc}")

## Pruning Accuracies

### Magnitude Pruning

In [0]:
# Magnitude pruning of ViT-B/16 on CIFAR-10 to TARGET_SPARSITY_LOW with a
# cubic sparsity schedule, starting from the baseline checkpoint.
model_name, model_task = "vitb16", "cifar10"

# Baseline checkpoint for this model/task, passed as `finetuned_weights`.
finetuned_weights = f"/dbfs/research/{model_name}/{model_task}/{model_name}_{model_task}_baseline.pt"

pruning_type = "magnitude_pruning"
target_sparsity = TARGET_SPARSITY_LOW
learning_type = "pruning"

# Training is switched off for this configuration: only trainer.evaluate() runs.
run_training = False

# NOTE(review): uses the generic BATCH_SIZE from the star imports rather than
# BATCH_SIZE_VIT from the setup cell. Confirm that is intended.
training_args = CVTrainingArguments(
    model_name=model_name,
    model_task=model_task,
    batch_size=BATCH_SIZE,
    finetuned_weights=finetuned_weights,
    pruning_type=pruning_type,
    target_sparsity=target_sparsity,
    learning_type=learning_type,
    optimizer_type='adamw',
    learning_rate=1e-4,
    sparsity_scheduler='cubic',
)
trainer = Trainer(training_args=training_args)
if run_training:
    trainer.train()

acc = trainer.evaluate()
print(f"\nAccuracy = {acc}")

In [0]:
# Magnitude pruning of ViT-B/16 on CIFAR-10 to TARGET_SPARSITY_MID with a
# cubic sparsity schedule, starting from the baseline checkpoint.
model_name, model_task = "vitb16", "cifar10"

# Baseline checkpoint for this model/task, passed as `finetuned_weights`.
finetuned_weights = f"/dbfs/research/{model_name}/{model_task}/{model_name}_{model_task}_baseline.pt"

pruning_type = "magnitude_pruning"
target_sparsity = TARGET_SPARSITY_MID
learning_type = "pruning"

# Run the pruning/training loop before evaluating; set False to only evaluate.
run_training = True

# NOTE(review): uses the generic BATCH_SIZE from the star imports rather than
# BATCH_SIZE_VIT from the setup cell. Confirm that is intended.
training_args = CVTrainingArguments(
    model_name=model_name,
    model_task=model_task,
    batch_size=BATCH_SIZE,
    finetuned_weights=finetuned_weights,
    pruning_type=pruning_type,
    target_sparsity=target_sparsity,
    learning_type=learning_type,
    optimizer_type='adamw',
    learning_rate=1e-4,
    sparsity_scheduler='cubic',
)
trainer = Trainer(training_args=training_args)
if run_training:
    trainer.train()

acc = trainer.evaluate()
print(f"\nAccuracy = {acc}")

In [0]:
# Magnitude pruning of ViT-B/16 on CIFAR-10 to TARGET_SPARSITY_HIGH with a
# cubic sparsity schedule, starting from the baseline checkpoint.
model_name, model_task = "vitb16", "cifar10"

# Baseline checkpoint for this model/task, passed as `finetuned_weights`.
finetuned_weights = f"/dbfs/research/{model_name}/{model_task}/{model_name}_{model_task}_baseline.pt"

pruning_type = "magnitude_pruning"
target_sparsity = TARGET_SPARSITY_HIGH
learning_type = "pruning"

# Run the pruning/training loop before evaluating; set False to only evaluate.
run_training = True

# NOTE(review): uses the generic BATCH_SIZE from the star imports rather than
# BATCH_SIZE_VIT from the setup cell. Confirm that is intended.
training_args = CVTrainingArguments(
    model_name=model_name,
    model_task=model_task,
    batch_size=BATCH_SIZE,
    finetuned_weights=finetuned_weights,
    pruning_type=pruning_type,
    target_sparsity=target_sparsity,
    learning_type=learning_type,
    optimizer_type='adamw',
    learning_rate=1e-4,
    sparsity_scheduler='cubic',
)
trainer = Trainer(training_args=training_args)
if run_training:
    trainer.train()

acc = trainer.evaluate()
print(f"\nAccuracy = {acc}")

### Movement Pruning

In [0]:
# Movement pruning of ViT-B/16 on CIFAR-10 to TARGET_SPARSITY_LOW with a
# cubic sparsity schedule, starting from the baseline checkpoint.
model_name, model_task = "vitb16", "cifar10"

# Baseline checkpoint for this model/task, passed as `finetuned_weights`.
finetuned_weights = f"/dbfs/research/{model_name}/{model_task}/{model_name}_{model_task}_baseline.pt"

pruning_type = "movement_pruning"
target_sparsity = TARGET_SPARSITY_LOW
learning_type = "pruning"

# Run the pruning/training loop before evaluating; set False to only evaluate.
run_training = True

# NOTE(review): uses the generic BATCH_SIZE from the star imports rather than
# BATCH_SIZE_VIT from the setup cell. Confirm that is intended.
training_args = CVTrainingArguments(
    model_name=model_name,
    model_task=model_task,
    batch_size=BATCH_SIZE,
    finetuned_weights=finetuned_weights,
    pruning_type=pruning_type,
    target_sparsity=target_sparsity,
    learning_type=learning_type,
    optimizer_type='adamw',
    learning_rate=1e-4,
    sparsity_scheduler='cubic',
)
trainer = Trainer(training_args=training_args)
if run_training:
    trainer.train()

acc = trainer.evaluate()
print(f"\nAccuracy = {acc}")

In [0]:
# Movement pruning of ViT-B/16 on CIFAR-10 to TARGET_SPARSITY_MID with a
# cubic sparsity schedule, starting from the baseline checkpoint.
model_name, model_task = "vitb16", "cifar10"

# Baseline checkpoint for this model/task, passed as `finetuned_weights`.
finetuned_weights = f"/dbfs/research/{model_name}/{model_task}/{model_name}_{model_task}_baseline.pt"

pruning_type = "movement_pruning"
target_sparsity = TARGET_SPARSITY_MID
learning_type = "pruning"

# Run the pruning/training loop before evaluating; set False to only evaluate.
run_training = True

# NOTE(review): uses the generic BATCH_SIZE from the star imports rather than
# BATCH_SIZE_VIT from the setup cell. Confirm that is intended.
training_args = CVTrainingArguments(
    model_name=model_name,
    model_task=model_task,
    batch_size=BATCH_SIZE,
    finetuned_weights=finetuned_weights,
    pruning_type=pruning_type,
    target_sparsity=target_sparsity,
    learning_type=learning_type,
    optimizer_type='adamw',
    learning_rate=1e-4,
    sparsity_scheduler='cubic',
)
trainer = Trainer(training_args=training_args)
if run_training:
    trainer.train()

acc = trainer.evaluate()
print(f"\nAccuracy = {acc}")

In [0]:
# Movement pruning of ViT-B/16 on CIFAR-10 to TARGET_SPARSITY_HIGH with a
# cubic sparsity schedule, starting from the baseline checkpoint.
model_name, model_task = "vitb16", "cifar10"

# Baseline checkpoint for this model/task, passed as `finetuned_weights`.
finetuned_weights = f"/dbfs/research/{model_name}/{model_task}/{model_name}_{model_task}_baseline.pt"

pruning_type = "movement_pruning"
target_sparsity = TARGET_SPARSITY_HIGH
learning_type = "pruning"

# Run the pruning/training loop before evaluating; set False to only evaluate.
run_training = True

# NOTE(review): uses the generic BATCH_SIZE from the star imports rather than
# BATCH_SIZE_VIT from the setup cell. Confirm that is intended.
training_args = CVTrainingArguments(
    model_name=model_name,
    model_task=model_task,
    batch_size=BATCH_SIZE,
    finetuned_weights=finetuned_weights,
    pruning_type=pruning_type,
    target_sparsity=target_sparsity,
    learning_type=learning_type,
    optimizer_type='adamw',
    learning_rate=1e-4,
    sparsity_scheduler='cubic',
)
trainer = Trainer(training_args=training_args)
if run_training:
    trainer.train()

acc = trainer.evaluate()
print(f"\nAccuracy = {acc}")

## BaCP Accuracies

In [0]:
# BaCP on ViT-B/16 / CIFAR-10 at TARGET_SPARSITY_LOW (magnitude pruning),
# followed by a 50-epoch fine-tuning phase driven by the resulting pruner.
model_name = "vitb16"
model_task = "cifar10"

# Baseline checkpoint for this model/task, passed as `finetuned_weights`.
finetuned_weights = f"/dbfs/research/{model_name}/{model_task}/{model_name}_{model_task}_baseline.pt"

# Pruning configuration
pruning_type = 'magnitude_pruning'
target_sparsity = TARGET_SPARSITY_LOW
# NOTE(review): not passed to BaCPTrainingArgumentsCNN below; kept so the
# notebook namespace matches the earlier cells.
learning_type = "pruning"

# Phase toggles: set either to False to skip that phase's training loop.
run_bacp_training = True
run_finetuning = True

# --- Phase 1: BaCP pruning ---
# NOTE(review): uses the generic BATCH_SIZE from the star imports rather than
# BATCH_SIZE_VIT from the setup cell. Confirm that is intended.
bacp_training_args = BaCPTrainingArgumentsCNN(
    model_name=model_name,
    model_task=model_task,
    batch_size=BATCH_SIZE,
    finetuned_weights=finetuned_weights,
    pruning_type=pruning_type,
    target_sparsity=target_sparsity,
    optimizer_type='adamw',
    learning_rate=0.001,
    sparsity_scheduler='cubic'
)
bacp_trainer = BaCPTrainer(bacp_training_args=bacp_training_args)
if run_bacp_training:
    bacp_trainer.train()

# --- Phase 2: fine-tuning ---
# Build the mask from the (possibly just-trained) BaCP model and hand its
# pruner to the fine-tuning Trainer; presumably this keeps pruned weights at
# zero during fine-tuning. Confirm against CV_trainer.
bacp_trainer.generate_mask_from_model()
pruner = bacp_trainer.get_pruner()

training_args = CVTrainingArguments(
    model_name=bacp_trainer.model_name,
    model_task=bacp_trainer.model_task,
    batch_size=bacp_trainer.batch_size,
    pruning_type=bacp_trainer.pruning_type,
    target_sparsity=bacp_trainer.target_sparsity,
    # Checkpoint path exposed by the BaCP trainer (cm_save_path).
    finetuned_weights=bacp_trainer.cm_save_path,
    epochs=50,
    pruner=pruner,
    finetune=True,
    learning_type="bacp_finetune",
    optimizer_type='adamw',
    learning_rate=0.0005,
)
# Pass by keyword, as every other cell does; this does not rely on
# `training_args` being Trainer's first positional parameter.
trainer = Trainer(training_args=training_args)
if run_finetuning:
    trainer.train()

acc = trainer.evaluate()
print(f"Accuracy = {acc}")
