In [0]:
%restart_python

In [0]:
%load_ext autoreload
%autoreload 2

In [0]:
import os
import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Make the project modules in the parent directory importable
sys.path.append(os.path.abspath('..'))

# importing modules from main directory
from contrastive_learning import ContrastiveLearner
from models import EncoderProjectionNetwork, ClassificationNetwork
from datasets_class import CreateDatasets
from supervised_learning import train, test
from unstructured_pruning import MagnitudePrune, MovementPrune, LocalMagnitudePrune, LocalMovementPrune
from bacp import BaCPLearner, create_models_for_bacp
from logger import Logger
from utils import *
from constants import *


In [0]:
# Batch size recorded in the run configs: 256 for the BaCP runs
# (512 was used for the baseline pruning runs).
BATCH_SIZE = 256
NUM_WORKERS = 24          # intended number of DataLoader worker processes
SIZE = 224                # passed to CreateDatasets.get_dataset_fn (presumably the input resolution)
BATCH_SIZE_VITL16 = 128
BATCH_SIZE_VITB16 = 512
# Backward-compatible alias for the original, misspelled name
BATCH_SIZE_VIBL16 = BATCH_SIZE_VITB16

In [0]:
# Build the CIFAR-10 datasets once; the experiment cells below reuse these loaders.
datasets = CreateDatasets('./data')

# Data for supervised learning
trainset_c10_cls_fn, testset_c10_cls_fn = datasets.get_dataset_fn('supervised', 'cifar10', SIZE)
trainset_c10_cls, testset_c10_cls = trainset_c10_cls_fn(), testset_c10_cls_fn()

# Both ViT-B and ViT-L cells share these loaders (batch size BATCH_SIZE_VITL16).
# Load in NUM_WORKERS background processes instead of the main process.
trainloader_c10_cls = DataLoader(trainset_c10_cls, BATCH_SIZE_VITL16, shuffle=True,
                                 num_workers=NUM_WORKERS)
testloader_c10_cls = DataLoader(testset_c10_cls, BATCH_SIZE_VITL16, shuffle=False,
                                num_workers=NUM_WORKERS)

# NOTE(review): the BaCP cells further down call `cap_learner.train(trainloader_c10_cl)`,
# but no contrastive-learning loader is created in this notebook. TODO: build it here.


# Baseline Accuracies

## ViT-Base-16

In [0]:
# Dense ViT-Base/16 baseline on CIFAR-10: optionally fine-tune, then evaluate
# the weights stored at save_path.
model_name = 'vitb16'
save_path = f'/dbfs/{model_name}_weights/{model_name}_sup_c10.pt'
cls_model = ClassificationNetwork(model_name, CIFAR10_CLASSES, False).to(get_device())

# Optimizer / scheduler / loss
optimizer_type = 'adam'
optimizer_cfg = dict(
    model=cls_model,
    lr=LR_VITB16,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,
)
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler, criterion = None, nn.CrossEntropyLoss()
logger = Logger(model_name, learning_type='cls')

config = dict(
    trainloader=trainloader_c10_cls,
    testloader=testloader_c10_cls,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=nn.CrossEntropyLoss(),
    epochs=EPOCHS_VITB16,
    batch_size=BATCH_SIZE,
    save_path=save_path,
    logger=logger,
)

# Flip to False to (re)train instead of only evaluating the saved weights
is_trained = True
if not is_trained:
    losses, train_accuracies, test_accuracies = train(cls_model, config)
    graph_losses_n_accs(losses, train_accuracies, test_accuracies)

# Evaluate the weights stored at save_path
load_weights(cls_model, save_path)
acc = test(cls_model, testloader_c10_cls)
print(f"\nAccuracy of model is: {acc}")

## ViT-Large-16

In [0]:
# Dense ViT-Large/16 baseline on CIFAR-10: optionally fine-tune, then evaluate
# the weights stored at save_path (the pruning cells below start from this file).
model_name = 'vitl16'
cls_model = ClassificationNetwork(model_name, CIFAR10_CLASSES, False).to(get_device())

# Initializing Hyperparameters
optimizer_cfg = {
    'model':        cls_model,
    'lr':           LR_VITL16,
    'momentum':     MOMENTUM,
    'weight_decay': WEIGHT_DECAY,
}
optimizer = set_optimizer('adam', optimizer_cfg)
save_path = f'/dbfs/{model_name}_weights/{model_name}_sup_c10.pt'

config = {
    'trainloader':      trainloader_c10_cls,
    'testloader':       testloader_c10_cls,
    'optimizer':        optimizer,
    'scheduler':        None,
    'criterion':        nn.CrossEntropyLoss(),
    'epochs':           EPOCHS_VITL16,
    # ViT-L/16 uses BATCH_SIZE_VITL16, matching the loaders and the ViT-L pruning cells
    "batch_size":       BATCH_SIZE_VITL16,
    "save_path":        save_path,
    "logger":           Logger(model_name, learning_type='cls'),
    "lambda_reg":       0,
    'recover_epochs':   0,
    'pruning_type':     "",
}

# Set True if trained
is_trained = False
if not is_trained:
    losses, train_accuracies, test_accuracies = train(cls_model, config)
    graph_losses_n_accs(losses, train_accuracies, test_accuracies)

# Evaluating model: load the saved weights BEFORE testing so the reported accuracy
# belongs to the checkpoint at save_path (when is_trained is True the in-memory
# model is freshly constructed and would otherwise be evaluated instead).
load_weights(cls_model, save_path)
acc = test(cls_model, testloader_c10_cls)
print(f"\nAccuracy of model is: {acc}")

# Pruning Accuracies

## Sparsity: 0.95

### ViT-Base-16

#### Magnitude Pruning - Movement Pruning

In [0]:
#######################################
########## MAGNITUDE PRUNING ##########
#######################################

# ViT-Base/16, LocalMagnitudePrune to TARGET_SPARSITY_LOW, starting from the
# dense CIFAR-10 checkpoint.

# Creating classification model
model_name = 'vitb16'    
cls_model = ClassificationNetwork(model_name, CIFAR10_CLASSES, False).to(get_device())
load_weights(cls_model, f'/dbfs/{model_name}_weights/{model_name}_sup_c10.pt')

# Initializing Hyperparameters
optimizer_type = 'adam'
optimizer_cfg = {
    'model':        cls_model,
    'lr':           LR_VITB16,
    'momentum':     MOMENTUM,
    'weight_decay': WEIGHT_DECAY
}
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler = None
criterion = nn.CrossEntropyLoss()
logger = Logger(model_name, learning_type='cls')
save_path = f'/dbfs/{model_name}_weights/{model_name}_magp_095.pt'

# Initializing pruning method
pruner = LocalMagnitudePrune(1, TARGET_SPARSITY_LOW)
config = {
    'trainloader':      trainloader_c10_cls,
    'testloader':       testloader_c10_cls,
    'optimizer':        optimizer,
    'scheduler':        scheduler,
    'criterion':        nn.CrossEntropyLoss(),
    'epochs':           1,
    "batch_size":       BATCH_SIZE,
    "save_path":        save_path,
    "logger":           logger,
    'recover_epochs':   50,
    # Same label as every other LocalMagnitudePrune pruning cell in this notebook
    'pruning_type':    'local_magnitude_pruning',
}

# Set False to train
is_trained = False
if not is_trained:
    losses, train_accuracies, test_accuracies = train(cls_model,
                                                      config,
                                                      pruner)
    # Persist the final post-pruning weights (written after train() returns)
    torch.save(cls_model.state_dict(), save_path)
    
    # Displaying loss-acc graph
    graph_losses_n_accs(losses, 
                        train_accuracies, 
                        test_accuracies)

# Evaluating model
load_weights(cls_model, save_path)
print(f"Sparsity of pruned model: {get_model_sparsity(cls_model):.3f}")
pruned_acc = test(cls_model, testloader_c10_cls)
print(f"Accuracy of pruned model is: {pruned_acc}")


In [0]:
######################################
########## MOVEMENT PRUNING ##########
######################################

# ViT-Base/16, LocalMovementPrune to TARGET_SPARSITY_LOW, starting from the
# dense CIFAR-10 checkpoint.

# Creating classification model
model_name = 'vitb16'    
cls_model = ClassificationNetwork(model_name, CIFAR10_CLASSES, False).to(get_device())
load_weights(cls_model, f'/dbfs/{model_name}_weights/{model_name}_sup_c10.pt')

# Initializing Hyperparameters
optimizer_type = 'adam'
optimizer_cfg = {
    'model':        cls_model,
    'lr':           LR_VITB16,
    'momentum':     MOMENTUM,
    'weight_decay': WEIGHT_DECAY
}
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler = None
criterion = nn.CrossEntropyLoss()
logger = Logger(model_name, learning_type='cls')
save_path = f'/dbfs/{model_name}_weights/{model_name}_mvmp_095.pt'

# Initializing pruning method
pruner = LocalMovementPrune(1, TARGET_SPARSITY_LOW)
config = {
    'trainloader':      trainloader_c10_cls,
    'testloader':       testloader_c10_cls,
    'optimizer':        optimizer,
    'scheduler':        scheduler,
    'criterion':        nn.CrossEntropyLoss(),
    'epochs':           1,
    "batch_size":       BATCH_SIZE,
    "save_path":        save_path,
    "logger":           logger,
    'recover_epochs':   50,
    # Same label as every other LocalMovementPrune pruning cell in this notebook
    'pruning_type':    'local_movement_pruning',
}

# Set False to train
is_trained = False
if not is_trained:
    losses, train_accuracies, test_accuracies = train(cls_model,
                                                      config,
                                                      pruner)
    # Persist the final post-pruning weights (written after train() returns)
    torch.save(cls_model.state_dict(), save_path)
    
    # Displaying loss-acc graph
    graph_losses_n_accs(losses, 
                        train_accuracies, 
                        test_accuracies)

# Evaluating model
load_weights(cls_model, save_path)
print(f"Sparsity of pruned model: {get_model_sparsity(cls_model):.3f}")
pruned_acc = test(cls_model, testloader_c10_cls)
print(f"Accuracy of pruned model is: {pruned_acc}")


### ViT-Large-16

#### Magnitude Pruning - Movement Pruning

In [0]:
#######################################
########## MAGNITUDE PRUNING ##########
#######################################

# ViT-Large/16 | LocalMagnitudePrune | target TARGET_SPARSITY_LOW
model_name = 'vitl16'
save_path = f'/dbfs/{model_name}_weights/{model_name}_magp_095.pt'

# Start from the dense fine-tuned CIFAR-10 weights
cls_model = ClassificationNetwork(model_name, CIFAR10_CLASSES, False).to(get_device())
load_weights(cls_model, f'/dbfs/{model_name}_weights/{model_name}_sup_c10.pt')

optimizer_cfg = dict(model=cls_model, lr=LR_VITL16, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
optimizer = set_optimizer('adam', optimizer_cfg)
logger = Logger(model_name, learning_type='cls')

pruner = LocalMagnitudePrune(1, TARGET_SPARSITY_LOW)
config = dict(
    trainloader=trainloader_c10_cls,
    testloader=testloader_c10_cls,
    optimizer=optimizer,
    scheduler=None,
    criterion=nn.CrossEntropyLoss(),
    epochs=1,
    batch_size=BATCH_SIZE_VITL16,
    save_path=save_path,
    logger=logger,
    recover_epochs=50,
    pruning_type='local_magnitude_pruning',
)

is_trained = False  # set True to skip pruning and only evaluate the saved weights
if not is_trained:
    losses, train_accuracies, test_accuracies = train(cls_model, config, pruner)
    # Persist the final post-pruning weights, then plot the curves
    torch.save(cls_model.state_dict(), save_path)
    graph_losses_n_accs(losses, train_accuracies, test_accuracies)

# Report sparsity and accuracy of the weights stored at save_path
load_weights(cls_model, save_path)
print(f"Sparsity of pruned model: {get_model_sparsity(cls_model):.3f}")
pruned_acc = test(cls_model, testloader_c10_cls)
print(f"Accuracy of pruned model is: {pruned_acc}")


In [0]:
######################################
########## MOVEMENT PRUNING ##########
######################################

# ViT-Large/16 | LocalMovementPrune | target TARGET_SPARSITY_LOW
model_name = 'vitl16'
save_path = f'/dbfs/{model_name}_weights/{model_name}_mvmp_095.pt'

# Start from the dense fine-tuned CIFAR-10 weights
cls_model = ClassificationNetwork(model_name, CIFAR10_CLASSES, False).to(get_device())
load_weights(cls_model, f'/dbfs/{model_name}_weights/{model_name}_sup_c10.pt')

optimizer_type = 'adam'
optimizer_cfg = dict(model=cls_model, lr=LR_VITL16, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler, criterion = None, nn.CrossEntropyLoss()
logger = Logger(model_name, learning_type='cls')

pruner = LocalMovementPrune(1, TARGET_SPARSITY_LOW)
config = dict(
    trainloader=trainloader_c10_cls,
    testloader=testloader_c10_cls,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=nn.CrossEntropyLoss(),
    epochs=1,
    batch_size=BATCH_SIZE_VITL16,
    save_path=save_path,
    logger=logger,
    recover_epochs=50,
    pruning_type='local_movement_pruning',
)

is_trained = False  # set True to skip pruning and only evaluate the saved weights
if not is_trained:
    losses, train_accuracies, test_accuracies = train(cls_model, config, pruner)
    # Persist the final post-pruning weights, then plot the curves
    torch.save(cls_model.state_dict(), save_path)
    graph_losses_n_accs(losses, train_accuracies, test_accuracies)

# Report sparsity and accuracy of the weights stored at save_path
load_weights(cls_model, save_path)
print(f"Sparsity of pruned model: {get_model_sparsity(cls_model):.3f}")
pruned_acc = test(cls_model, testloader_c10_cls)
print(f"Accuracy of pruned model is: {pruned_acc}")

## Sparsity: 0.97

### ViT-B-16

#### Magnitude Pruning - Movement Pruning

In [0]:
#######################################
########## MAGNITUDE PRUNING ##########
#######################################

# ViT-Base/16 | LocalMagnitudePrune | target TARGET_SPARSITY_MID
model_name = 'vitb16'
save_path = f'/dbfs/{model_name}_weights/{model_name}_magp_097.pt'

# Start from the dense fine-tuned CIFAR-10 weights
cls_model = ClassificationNetwork(model_name, CIFAR10_CLASSES, False).to(get_device())
load_weights(cls_model, f'/dbfs/{model_name}_weights/{model_name}_sup_c10.pt')

optimizer_type = 'adam'
optimizer_cfg = dict(model=cls_model, lr=LR_VITB16, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler, criterion = None, nn.CrossEntropyLoss()
logger = Logger(model_name, learning_type='cls')

pruner = LocalMagnitudePrune(1, TARGET_SPARSITY_MID)
config = dict(
    trainloader=trainloader_c10_cls,
    testloader=testloader_c10_cls,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=nn.CrossEntropyLoss(),
    epochs=1,
    batch_size=BATCH_SIZE,
    save_path=save_path,
    logger=logger,
    recover_epochs=50,
    pruning_type='local_magnitude_pruning',
)

is_trained = False  # set True to skip pruning and only evaluate the saved weights
if not is_trained:
    losses, train_accuracies, test_accuracies = train(cls_model, config, pruner)
    # Persist the final post-pruning weights, then plot the curves
    torch.save(cls_model.state_dict(), save_path)
    graph_losses_n_accs(losses, train_accuracies, test_accuracies)

# Report sparsity and accuracy of the weights stored at save_path
load_weights(cls_model, save_path)
print(f"Sparsity of pruned model: {get_model_sparsity(cls_model):.3f}")
pruned_acc = test(cls_model, testloader_c10_cls)
print(f"Accuracy of pruned model is: {pruned_acc}")


In [0]:
######################################
########## MOVEMENT PRUNING ##########
######################################

# ViT-Base/16 | LocalMovementPrune | target TARGET_SPARSITY_MID
model_name = 'vitb16'
save_path = f'/dbfs/{model_name}_weights/{model_name}_mvmp_097.pt'

# Start from the dense fine-tuned CIFAR-10 weights
cls_model = ClassificationNetwork(model_name, CIFAR10_CLASSES, False).to(get_device())
load_weights(cls_model, f'/dbfs/{model_name}_weights/{model_name}_sup_c10.pt')

optimizer_type = 'adam'
optimizer_cfg = dict(model=cls_model, lr=LR_VITB16, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler, criterion = None, nn.CrossEntropyLoss()
logger = Logger(model_name, learning_type='cls')

pruner = LocalMovementPrune(1, TARGET_SPARSITY_MID)
config = dict(
    trainloader=trainloader_c10_cls,
    testloader=testloader_c10_cls,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=nn.CrossEntropyLoss(),
    epochs=1,
    batch_size=BATCH_SIZE,
    save_path=save_path,
    logger=logger,
    recover_epochs=50,
    pruning_type='local_movement_pruning',
)

is_trained = False  # set True to skip pruning and only evaluate the saved weights
if not is_trained:
    losses, train_accuracies, test_accuracies = train(cls_model, config, pruner)
    # Persist the final post-pruning weights, then plot the curves
    torch.save(cls_model.state_dict(), save_path)
    graph_losses_n_accs(losses, train_accuracies, test_accuracies)

# Report sparsity and accuracy of the weights stored at save_path
load_weights(cls_model, save_path)
print(f"Sparsity of pruned model: {get_model_sparsity(cls_model):.3f}")
pruned_acc = test(cls_model, testloader_c10_cls)
print(f"Accuracy of pruned model is: {pruned_acc}")


### ViT-L-16

#### Magnitude Pruning - Movement Pruning

In [0]:
#######################################
########## MAGNITUDE PRUNING ##########
#######################################

# ViT-Large/16 | LocalMagnitudePrune | target TARGET_SPARSITY_MID
model_name = 'vitl16'
save_path = f'/dbfs/{model_name}_weights/{model_name}_magp_097.pt'

# Start from the dense fine-tuned CIFAR-10 weights
cls_model = ClassificationNetwork(model_name, CIFAR10_CLASSES, False).to(get_device())
load_weights(cls_model, f'/dbfs/{model_name}_weights/{model_name}_sup_c10.pt')

optimizer_type = 'adam'
optimizer_cfg = dict(model=cls_model, lr=LR_VITL16, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler, criterion = None, nn.CrossEntropyLoss()
logger = Logger(model_name, learning_type='cls')

pruner = LocalMagnitudePrune(1, TARGET_SPARSITY_MID)
config = dict(
    trainloader=trainloader_c10_cls,
    testloader=testloader_c10_cls,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=nn.CrossEntropyLoss(),
    epochs=1,
    batch_size=BATCH_SIZE_VITL16,
    save_path=save_path,
    logger=logger,
    # NOTE(review): every other pruning cell uses recover_epochs=50; confirm 23 is intentional
    recover_epochs=23,
    pruning_type='local_magnitude_pruning',
)

is_trained = False  # set True to skip pruning and only evaluate the saved weights
if not is_trained:
    losses, train_accuracies, test_accuracies = train(cls_model, config, pruner)
    # Persist the final post-pruning weights, then plot the curves
    torch.save(cls_model.state_dict(), save_path)
    graph_losses_n_accs(losses, train_accuracies, test_accuracies)

# Report sparsity and accuracy of the weights stored at save_path
load_weights(cls_model, save_path)
print(f"Sparsity of pruned model: {get_model_sparsity(cls_model):.3f}")
pruned_acc = test(cls_model, testloader_c10_cls)
print(f"Accuracy of pruned model is: {pruned_acc}")


In [0]:
######################################
########## MOVEMENT PRUNING ##########
######################################

# ViT-Large/16 | LocalMovementPrune | target TARGET_SPARSITY_MID
model_name = 'vitl16'
save_path = f'/dbfs/{model_name}_weights/{model_name}_mvmp_097.pt'

# Start from the dense fine-tuned CIFAR-10 weights
cls_model = ClassificationNetwork(model_name, CIFAR10_CLASSES, False).to(get_device())
load_weights(cls_model, f'/dbfs/{model_name}_weights/{model_name}_sup_c10.pt')

optimizer_type = 'adam'
optimizer_cfg = dict(model=cls_model, lr=LR_VITL16, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler, criterion = None, nn.CrossEntropyLoss()
logger = Logger(model_name, learning_type='cls')

pruner = LocalMovementPrune(1, TARGET_SPARSITY_MID)
config = dict(
    trainloader=trainloader_c10_cls,
    testloader=testloader_c10_cls,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=nn.CrossEntropyLoss(),
    epochs=1,
    batch_size=BATCH_SIZE_VITL16,
    save_path=save_path,
    logger=logger,
    recover_epochs=50,
    pruning_type='local_movement_pruning',
)

is_trained = False  # set True to skip pruning and only evaluate the saved weights
if not is_trained:
    losses, train_accuracies, test_accuracies = train(cls_model, config, pruner)
    # Persist the final post-pruning weights, then plot the curves
    torch.save(cls_model.state_dict(), save_path)
    graph_losses_n_accs(losses, train_accuracies, test_accuracies)

# Report sparsity and accuracy of the weights stored at save_path
load_weights(cls_model, save_path)
print(f"Sparsity of pruned model: {get_model_sparsity(cls_model):.3f}")
pruned_acc = test(cls_model, testloader_c10_cls)
print(f"Accuracy of pruned model is: {pruned_acc}")


## Sparsity: 0.99

### ViT-B-16

#### Magnitude Pruning - Movement Pruning

In [0]:
#######################################
########## MAGNITUDE PRUNING ##########
#######################################

# ViT-Base/16 | LocalMagnitudePrune | target TARGET_SPARSITY_HIGH
model_name = 'vitb16'
save_path = f'/dbfs/{model_name}_weights/{model_name}_magp_099.pt'

# Start from the dense fine-tuned CIFAR-10 weights
cls_model = ClassificationNetwork(model_name, CIFAR10_CLASSES, False).to(get_device())
load_weights(cls_model, f'/dbfs/{model_name}_weights/{model_name}_sup_c10.pt')

optimizer_type = 'adam'
optimizer_cfg = dict(model=cls_model, lr=LR_VITB16, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler, criterion = None, nn.CrossEntropyLoss()
logger = Logger(model_name, learning_type='cls')

pruner = LocalMagnitudePrune(1, TARGET_SPARSITY_HIGH)
config = dict(
    trainloader=trainloader_c10_cls,
    testloader=testloader_c10_cls,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=nn.CrossEntropyLoss(),
    epochs=1,
    batch_size=BATCH_SIZE,
    save_path=save_path,
    logger=logger,
    recover_epochs=50,
    pruning_type='local_magnitude_pruning',
)

is_trained = False  # set True to skip pruning and only evaluate the saved weights
if not is_trained:
    losses, train_accuracies, test_accuracies = train(cls_model, config, pruner)
    # Persist the final post-pruning weights, then plot the curves
    torch.save(cls_model.state_dict(), save_path)
    graph_losses_n_accs(losses, train_accuracies, test_accuracies)

# Report sparsity and accuracy of the weights stored at save_path
load_weights(cls_model, save_path)
print(f"Sparsity of pruned model: {get_model_sparsity(cls_model):.3f}")
pruned_acc = test(cls_model, testloader_c10_cls)
print(f"Accuracy of pruned model is: {pruned_acc}")


In [0]:
######################################
########## MOVEMENT PRUNING ##########
######################################

# ViT-Base/16 | LocalMovementPrune | target TARGET_SPARSITY_HIGH
model_name = 'vitb16'
save_path = f'/dbfs/{model_name}_weights/{model_name}_mvmp_099.pt'

# Start from the dense fine-tuned CIFAR-10 weights
cls_model = ClassificationNetwork(model_name, CIFAR10_CLASSES, False).to(get_device())
load_weights(cls_model, f'/dbfs/{model_name}_weights/{model_name}_sup_c10.pt')

optimizer_type = 'adam'
optimizer_cfg = dict(model=cls_model, lr=LR_VITB16, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler, criterion = None, nn.CrossEntropyLoss()
logger = Logger(model_name, learning_type='cls')

pruner = LocalMovementPrune(1, TARGET_SPARSITY_HIGH)
config = dict(
    trainloader=trainloader_c10_cls,
    testloader=testloader_c10_cls,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=nn.CrossEntropyLoss(),
    epochs=1,
    batch_size=BATCH_SIZE,
    save_path=save_path,
    logger=logger,
    recover_epochs=50,
    pruning_type='local_movement_pruning',
)

is_trained = False  # set True to skip pruning and only evaluate the saved weights
if not is_trained:
    losses, train_accuracies, test_accuracies = train(cls_model, config, pruner)
    # Persist the final post-pruning weights, then plot the curves
    torch.save(cls_model.state_dict(), save_path)
    graph_losses_n_accs(losses, train_accuracies, test_accuracies)

# Report sparsity and accuracy of the weights stored at save_path
load_weights(cls_model, save_path)
print(f"Sparsity of pruned model: {get_model_sparsity(cls_model):.3f}")
pruned_acc = test(cls_model, testloader_c10_cls)
print(f"Accuracy of pruned model is: {pruned_acc}")


### ViT-L-16

#### Magnitude Pruning - Movement Pruning

In [0]:
#######################################
########## MAGNITUDE PRUNING ##########
#######################################

# ViT-Large/16 | LocalMagnitudePrune | target TARGET_SPARSITY_HIGH
model_name = 'vitl16'
save_path = f'/dbfs/{model_name}_weights/{model_name}_magp_099.pt'

# Start from the dense fine-tuned CIFAR-10 weights
cls_model = ClassificationNetwork(model_name, CIFAR10_CLASSES, False).to(get_device())
load_weights(cls_model, f'/dbfs/{model_name}_weights/{model_name}_sup_c10.pt')

optimizer_type = 'adam'
optimizer_cfg = dict(model=cls_model, lr=LR_VITL16, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler, criterion = None, nn.CrossEntropyLoss()
logger = Logger(model_name, learning_type='cls')

pruner = LocalMagnitudePrune(1, TARGET_SPARSITY_HIGH)
config = dict(
    trainloader=trainloader_c10_cls,
    testloader=testloader_c10_cls,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=nn.CrossEntropyLoss(),
    epochs=1,
    batch_size=BATCH_SIZE_VITL16,
    save_path=save_path,
    logger=logger,
    recover_epochs=50,
    pruning_type='local_magnitude_pruning',
)

is_trained = False  # set True to skip pruning and only evaluate the saved weights
if not is_trained:
    losses, train_accuracies, test_accuracies = train(cls_model, config, pruner)
    # Persist the final post-pruning weights, then plot the curves
    torch.save(cls_model.state_dict(), save_path)
    graph_losses_n_accs(losses, train_accuracies, test_accuracies)

# Report sparsity and accuracy of the weights stored at save_path
load_weights(cls_model, save_path)
print(f"Sparsity of pruned model: {get_model_sparsity(cls_model):.3f}")
pruned_acc = test(cls_model, testloader_c10_cls)
print(f"Accuracy of pruned model is: {pruned_acc}")


In [0]:
######################################
########## MOVEMENT PRUNING ##########
######################################

# ViT-Large/16 | LocalMovementPrune | target TARGET_SPARSITY_HIGH
model_name = 'vitl16'
save_path = f'/dbfs/{model_name}_weights/{model_name}_mvmp_099.pt'

# Start from the dense fine-tuned CIFAR-10 weights
cls_model = ClassificationNetwork(model_name, CIFAR10_CLASSES, False).to(get_device())
load_weights(cls_model, f'/dbfs/{model_name}_weights/{model_name}_sup_c10.pt')

optimizer_type = 'adam'
optimizer_cfg = dict(model=cls_model, lr=LR_VITL16, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler, criterion = None, nn.CrossEntropyLoss()
logger = Logger(model_name, learning_type='cls')

pruner = LocalMovementPrune(1, TARGET_SPARSITY_HIGH)
config = dict(
    trainloader=trainloader_c10_cls,
    testloader=testloader_c10_cls,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=nn.CrossEntropyLoss(),
    epochs=1,
    batch_size=BATCH_SIZE_VITL16,
    save_path=save_path,
    logger=logger,
    recover_epochs=50,
    pruning_type='local_movement_pruning',
)

is_trained = False  # set True to skip pruning and only evaluate the saved weights
if not is_trained:
    losses, train_accuracies, test_accuracies = train(cls_model, config, pruner)
    # Persist the final post-pruning weights, then plot the curves
    torch.save(cls_model.state_dict(), save_path)
    graph_losses_n_accs(losses, train_accuracies, test_accuracies)

# Report sparsity and accuracy of the weights stored at save_path
load_weights(cls_model, save_path)
print(f"Sparsity of pruned model: {get_model_sparsity(cls_model):.3f}")
pruned_acc = test(cls_model, testloader_c10_cls)
print(f"Accuracy of pruned model is: {pruned_acc}")


# BaCP Accuracies

## Sparsity: 0.95

### ViT-Base-16

In [0]:
#######################################
########## MAGNITUDE PRUNING ##########
#######################################

# BaCP training stage: ViT-Base/16 with LocalMagnitudePrune to TARGET_SPARSITY_LOW.

# Creating projection models for BaCP framework
model_name = 'vitb16'

# Projection networks built from the dense fine-tuned checkpoint.
# `create_models_for_bacp` is the factory imported from `bacp` in the imports cell.
finetuned_weights = f'/dbfs/{model_name}_weights/{model_name}_sup_c10.pt'
pre_trained_model, current_model, finetuned_model = create_models_for_bacp(model_name, finetuned_weights)
print(f"Current model sparsity: {get_model_sparsity(current_model)}")

# Initializing Hyperparameters
optimizer_type = 'adam'
optimizer_cfg = {
    'model':        current_model,
    'lr':           0.0001,
    'momentum':     MOMENTUM,
    'weight_decay': WEIGHT_DECAY
}
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler = None
criterion = nn.CrossEntropyLoss()
# NOTE(review): unlike the other cells this path has no `.pt` extension; presumably
# BaCPLearner derives its own file names from this prefix — confirm.
save_path = f'/dbfs/{model_name}_weights/{model_name}_bacp_magp_095'
logger = Logger(model_name, learning_type='bacp')

# Initializing pruner
pruner = LocalMagnitudePrune(1, TARGET_SPARSITY_LOW)

config = {
    'model_name':       model_name,
    'optimizer':        optimizer,
    'scheduler':        scheduler,               
    'criterion':        criterion,
    'target_sparsity':  TARGET_SPARSITY_LOW,   
    'logger':           logger,
    'batch_size':       BATCH_SIZE,     
    'num_classes':      CIFAR10_CLASSES,
    'lambdas':          [0.25, 0.25, 0.25, 0.25], 
    'save_path':        save_path,
    'pruner':           pruner,
    'recovery_epochs':  0,
    'epochs':           20,
    'pruning_epochs':   1,
}

# `cap_learner` is reused by the downstream fine-tuning cell that follows
cap_learner = BaCPLearner(current_model, pre_trained_model, finetuned_model, config)

# Set False to train
is_trained = False
if not is_trained:
    # NOTE(review): `trainloader_c10_cl` (contrastive CIFAR-10 loader) is not created anywhere
    # in this notebook — the data cell only builds the supervised loaders. On a fresh kernel
    # this line raises NameError. TODO: create that loader before running this cell.
    cap_learner.train(trainloader_c10_cl)

In [0]:
# BaCP downstream stage (magnitude, TARGET_SPARSITY_LOW): fine-tune a classification net
# obtained from `cap_learner` (built in the previous cell) on CIFAR-10.
model_name = 'vitb16'
save_path = f'/dbfs/{model_name}_weights/{model_name}_bacp_magp_095_ds.pt'

# Classification net from the BaCP learner; the original author labels False as "unfrozen parameters"
cls_model = cap_learner.create_classification_net(False)
print(f"Current model sparsity: {get_model_sparsity(cls_model)}")

# Generate masks from model
cap_learner.generate_mask_from_model()

optimizer_type = 'adam'
optimizer_cfg = dict(model=cls_model, lr=0.0001, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler, criterion = None, nn.CrossEntropyLoss()
logger = Logger(model_name, learning_type='cls')

# Reuse the pruner held by the BaCP learner
pruner = cap_learner.get_pruner()
pruning_type = 'magnitude_pruning'

config = dict(
    trainloader=trainloader_c10_cls,
    testloader=testloader_c10_cls,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=nn.CrossEntropyLoss(),
    epochs=30,
    batch_size=BATCH_SIZE,
    save_path=save_path,
    logger=logger,
    recover_epochs=0,
    pruning_type=pruning_type,
    stop_epochs=5,
)

is_trained = False  # set True to skip fine-tuning and only evaluate the saved weights
if not is_trained:
    # The 4th positional argument (True) is interpreted by supervised_learning.train — TODO document
    losses, train_accuracies, test_accuracies = train(cls_model, config, pruner, True)
    graph_losses_n_accs(losses, train_accuracies, test_accuracies)

# Report sparsity and accuracy of the weights stored at save_path
load_weights(cls_model, save_path)
print(f"Current model sparsity: {get_model_sparsity(cls_model)}")
acc = test(cls_model, testloader_c10_cls)
print(f"\nAccuracy of model is: {acc}")

In [0]:
######################################
########## MOVEMENT PRUNING ##########
######################################

# BaCP training stage: ViT-Base/16 with LocalMovementPrune to TARGET_SPARSITY_LOW.

# Creating projection models for BaCP framework
model_name = 'vitb16'

# Projection networks built from the dense fine-tuned checkpoint.
# `create_models_for_bacp` is the factory imported from `bacp` in the imports cell.
finetuned_weights = f'/dbfs/{model_name}_weights/{model_name}_sup_c10.pt'
pre_trained_model, current_model, finetuned_model = create_models_for_bacp(model_name, finetuned_weights)
print(f"Current model sparsity: {get_model_sparsity(current_model)}")

# Initializing Hyperparameters
optimizer_type = 'adam'
optimizer_cfg = {
    'model':        current_model,
    'lr':           0.0001,
    'momentum':     MOMENTUM,
    'weight_decay': WEIGHT_DECAY
}
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler = None
criterion = nn.CrossEntropyLoss()
# NOTE(review): unlike the other cells this path has no `.pt` extension; presumably
# BaCPLearner derives its own file names from this prefix — confirm.
save_path = f'/dbfs/{model_name}_weights/{model_name}_bacp_mvmp_095'
logger = Logger(model_name, learning_type='bacp')

# Initializing pruner
pruner = LocalMovementPrune(1, TARGET_SPARSITY_LOW)

config = {
    'model_name':       model_name,
    'optimizer':        optimizer,
    'scheduler':        scheduler,               
    'criterion':        criterion,
    'target_sparsity':  TARGET_SPARSITY_LOW,   
    'logger':           logger,
    'batch_size':       BATCH_SIZE,     
    'num_classes':      CIFAR10_CLASSES,
    'lambdas':          [0.25, 0.25, 0.25, 0.25], 
    'save_path':        save_path,
    'pruner':           pruner,
    'recovery_epochs':  0,
    'epochs':           20,
    'pruning_epochs':   1,
}

# `cap_learner` is reused by the downstream fine-tuning cell that follows
cap_learner = BaCPLearner(current_model, pre_trained_model, finetuned_model, config)

# Set False to train
is_trained = False
if not is_trained:
    # NOTE(review): `trainloader_c10_cl` (contrastive CIFAR-10 loader) is not created anywhere
    # in this notebook — the data cell only builds the supervised loaders. On a fresh kernel
    # this line raises NameError. TODO: create that loader before running this cell.
    cap_learner.train(trainloader_c10_cl)

In [0]:
# Downstream fine-tuning of the BaCP model (movement pruning, "095" tag): turn the
# BaCP learner into a classifier, regenerate its pruning masks, fine-tune on CIFAR-10,
# then reload the checkpoint at save_path and report sparsity and test accuracy.

# Creating classification model with unfrozen parameters
model_name = 'vitb16'
cls_model = cap_learner.create_classification_net(False)
print(f"Current model sparsity: {get_model_sparsity(cls_model)}")

# Generate masks from model
cap_learner.generate_mask_from_model()

# Initializing Hyperparameters
optimizer_type = 'adam'
optimizer_cfg = {
    'model':        cls_model,
    'lr':           0.0001,
    'momentum':     MOMENTUM,
    'weight_decay': WEIGHT_DECAY
}
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler = None
criterion = nn.CrossEntropyLoss()
save_path = f'/dbfs/{model_name}_weights/{model_name}_bacp_mvmp_095_ds.pt'
logger = Logger(model_name, learning_type='cls')

# Initializing pruner: reuse the BaCP learner's pruner so fine-tuning keeps its masks
pruner = cap_learner.get_pruner()
pruning_type = 'movement_pruning'

config = {
    'trainloader':      trainloader_c10_cls,
    'testloader':       testloader_c10_cls,
    'optimizer':        optimizer,
    'scheduler':        scheduler,
    # Use the criterion defined above rather than constructing a second instance.
    'criterion':        criterion,
    'epochs':           30,
    "batch_size":       BATCH_SIZE,
    "save_path":        save_path,
    "logger":           logger,
    # NOTE(review): the BaCP configs spell this key 'recovery_epochs'; confirm which
    # spelling `supervised_learning.train` actually reads.
    'recover_epochs':   0,
    'pruning_type':     pruning_type,
    'stop_epochs':      5,
}

# Set to True to skip fine-tuning and only evaluate the checkpoint at save_path
is_trained = False
if not is_trained:
    losses, train_accuracies, test_accuracies = train(cls_model,
                                                      config,
                                                      pruner,
                                                      True)
    graph_losses_n_accs(losses,
                        train_accuracies,
                        test_accuracies)

# Evaluating model: reload the checkpoint and confirm the sparsity that was saved
load_weights(cls_model, save_path)
print(f"Current model sparsity: {get_model_sparsity(cls_model)}")
acc = test(cls_model, testloader_c10_cls)
print(f"\nAccuracy of model is: {acc}")

## Sparsity: 0.97

### ViT-B-16

In [0]:
#######################################
########## MAGNITUDE PRUNING ##########
#######################################
# BaCP stage at TARGET_SPARSITY_MID (checkpoints tagged "097"): build the
# pre-trained / current / fine-tuned projection models, then train the current
# model with BaCPLearner while LocalMagnitudePrune sparsifies it.

# Creating projection models for BaCP framework
model_name = 'vitb16'
target_sparsity = TARGET_SPARSITY_MID   # single source of truth for pruner + config

# Projection networks. The fine-tuned branch starts from the supervised CIFAR-10
# checkpoint written by the baseline section.
finetuned_weights = f'/dbfs/{model_name}_weights/{model_name}_sup_c10.pt'
# `create_models_for_bacp` is the factory imported from `bacp` in the import cell;
# `create_models_for_cap` is not imported there and raises NameError on a fresh kernel.
pre_trained_model, current_model, finetuned_model = create_models_for_bacp(model_name, finetuned_weights)
print(f"Current model sparsity: {get_model_sparsity(current_model)}")

# Initializing Hyperparameters
optimizer_type = 'adam'
optimizer_cfg = {
    'model':        current_model,
    'lr':           0.0001,
    'momentum':     MOMENTUM,
    'weight_decay': WEIGHT_DECAY
}
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler = None
criterion = nn.CrossEntropyLoss()
# No `.pt` extension, unlike the downstream checkpoints; presumably BaCPLearner
# appends its own suffixes (confirm in bacp.py).
save_path = f'/dbfs/{model_name}_weights/{model_name}_bacp_magp_097'
logger = Logger(model_name, learning_type='bacp')

# Initializing pruner
pruner = LocalMagnitudePrune(1, target_sparsity)

config = {
    'model_name':       model_name,
    'optimizer':        optimizer,
    'scheduler':        scheduler,
    'criterion':        criterion,
    'target_sparsity':  target_sparsity,
    'logger':           logger,
    'batch_size':       BATCH_SIZE,
    'num_classes':      CIFAR10_CLASSES,
    'lambdas':          [0.25, 0.25, 0.25, 0.25],
    'save_path':        save_path,
    'pruner':           pruner,
    'recovery_epochs':  0,
    'epochs':           20,
    'pruning_epochs':   1,
}

cap_learner = BaCPLearner(current_model, pre_trained_model, finetuned_model, config)

# Set to True to skip BaCP training.
# NOTE(review): nothing is loaded from save_path in this cell when skipping; confirm
# BaCPLearner restores its checkpoint, otherwise the next cell uses an untrained model.
is_trained = False
if not is_trained:
    # NOTE(review): `trainloader_c10_cl` (presumably the contrastive CIFAR-10 loader) is
    # not created in the data cell at the top; confirm an earlier cell defines it.
    cap_learner.train(trainloader_c10_cl)

In [0]:
# Downstream fine-tuning of the BaCP model (magnitude pruning, "097" tag): turn the
# BaCP learner into a classifier, regenerate its pruning masks, fine-tune on CIFAR-10,
# then reload the checkpoint at save_path and report sparsity and test accuracy.

# Creating classification model with unfrozen parameters
model_name = 'vitb16'
cls_model = cap_learner.create_classification_net(False)
print(f"Current model sparsity: {get_model_sparsity(cls_model)}")

# Generate masks from model
cap_learner.generate_mask_from_model()

# Initializing Hyperparameters
optimizer_type = 'adam'
optimizer_cfg = {
    'model':        cls_model,
    'lr':           0.0001,
    'momentum':     MOMENTUM,
    'weight_decay': WEIGHT_DECAY
}
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler = None
criterion = nn.CrossEntropyLoss()
save_path = f'/dbfs/{model_name}_weights/{model_name}_bacp_magp_097_ds.pt'
logger = Logger(model_name, learning_type='cls')

# Initializing pruner: reuse the BaCP learner's pruner so fine-tuning keeps its masks
pruner = cap_learner.get_pruner()
pruning_type = 'magnitude_pruning'

config = {
    'trainloader':      trainloader_c10_cls,
    'testloader':       testloader_c10_cls,
    'optimizer':        optimizer,
    'scheduler':        scheduler,
    # Use the criterion defined above rather than constructing a second instance.
    'criterion':        criterion,
    'epochs':           30,
    "batch_size":       BATCH_SIZE,
    "save_path":        save_path,
    "logger":           logger,
    # NOTE(review): the BaCP configs spell this key 'recovery_epochs'; confirm which
    # spelling `supervised_learning.train` actually reads.
    'recover_epochs':   0,
    'pruning_type':     pruning_type,
    'stop_epochs':      5,
}

# Set to True to skip fine-tuning and only evaluate the checkpoint at save_path
is_trained = False
if not is_trained:
    losses, train_accuracies, test_accuracies = train(cls_model,
                                                      config,
                                                      pruner,
                                                      True)
    graph_losses_n_accs(losses,
                        train_accuracies,
                        test_accuracies)

# Evaluating model: reload the checkpoint and confirm the sparsity that was saved
load_weights(cls_model, save_path)
print(f"Current model sparsity: {get_model_sparsity(cls_model)}")
acc = test(cls_model, testloader_c10_cls)
print(f"\nAccuracy of model is: {acc}")

In [0]:
######################################
########## MOVEMENT PRUNING ##########
######################################
# BaCP stage at TARGET_SPARSITY_MID (checkpoints tagged "097"): build the
# pre-trained / current / fine-tuned projection models, then train the current
# model with BaCPLearner while LocalMovementPrune sparsifies it.

# Creating projection models for BaCP framework
model_name = 'vitb16'
target_sparsity = TARGET_SPARSITY_MID   # single source of truth for pruner + config

# Projection networks. The fine-tuned branch starts from the supervised CIFAR-10
# checkpoint written by the baseline section.
finetuned_weights = f'/dbfs/{model_name}_weights/{model_name}_sup_c10.pt'
# `create_models_for_bacp` is the factory imported from `bacp` in the import cell;
# `create_models_for_cap` is not imported there and raises NameError on a fresh kernel.
pre_trained_model, current_model, finetuned_model = create_models_for_bacp(model_name, finetuned_weights)
print(f"Current model sparsity: {get_model_sparsity(current_model)}")

# Initializing Hyperparameters
optimizer_type = 'adam'
optimizer_cfg = {
    'model':        current_model,
    'lr':           0.0001,
    'momentum':     MOMENTUM,
    'weight_decay': WEIGHT_DECAY
}
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler = None
criterion = nn.CrossEntropyLoss()
# No `.pt` extension, unlike the downstream checkpoints; presumably BaCPLearner
# appends its own suffixes (confirm in bacp.py).
save_path = f'/dbfs/{model_name}_weights/{model_name}_bacp_mvmp_097'
logger = Logger(model_name, learning_type='bacp')

# Initializing pruner
pruner = LocalMovementPrune(1, target_sparsity)

config = {
    'model_name':       model_name,
    'optimizer':        optimizer,
    'scheduler':        scheduler,
    'criterion':        criterion,
    'target_sparsity':  target_sparsity,
    'logger':           logger,
    'batch_size':       BATCH_SIZE,
    'num_classes':      CIFAR10_CLASSES,
    'lambdas':          [0.25, 0.25, 0.25, 0.25],
    'save_path':        save_path,
    'pruner':           pruner,
    'recovery_epochs':  0,
    'epochs':           20,
    'pruning_epochs':   1,
}

cap_learner = BaCPLearner(current_model, pre_trained_model, finetuned_model, config)

# Set to True to skip BaCP training.
# NOTE(review): nothing is loaded from save_path in this cell when skipping; confirm
# BaCPLearner restores its checkpoint, otherwise the next cell uses an untrained model.
is_trained = False
if not is_trained:
    # NOTE(review): `trainloader_c10_cl` (presumably the contrastive CIFAR-10 loader) is
    # not created in the data cell at the top; confirm an earlier cell defines it.
    cap_learner.train(trainloader_c10_cl)

In [0]:
# Downstream fine-tuning of the BaCP model (movement pruning, "097" tag): turn the
# BaCP learner into a classifier, regenerate its pruning masks, fine-tune on CIFAR-10,
# then reload the checkpoint at save_path and report sparsity and test accuracy.

# Creating classification model with unfrozen parameters
model_name = 'vitb16'
cls_model = cap_learner.create_classification_net(False)
print(f"Current model sparsity: {get_model_sparsity(cls_model)}")

# Generate masks from model
cap_learner.generate_mask_from_model()

# Initializing Hyperparameters
optimizer_type = 'adam'
optimizer_cfg = {
    'model':        cls_model,
    'lr':           0.0001,
    'momentum':     MOMENTUM,
    'weight_decay': WEIGHT_DECAY
}
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler = None
criterion = nn.CrossEntropyLoss()
save_path = f'/dbfs/{model_name}_weights/{model_name}_bacp_mvmp_097_ds.pt'
logger = Logger(model_name, learning_type='cls')

# Initializing pruner: reuse the BaCP learner's pruner so fine-tuning keeps its masks
pruner = cap_learner.get_pruner()
pruning_type = 'movement_pruning'

config = {
    'trainloader':      trainloader_c10_cls,
    'testloader':       testloader_c10_cls,
    'optimizer':        optimizer,
    'scheduler':        scheduler,
    # Use the criterion defined above rather than constructing a second instance.
    'criterion':        criterion,
    'epochs':           30,
    "batch_size":       BATCH_SIZE,
    "save_path":        save_path,
    "logger":           logger,
    # NOTE(review): the BaCP configs spell this key 'recovery_epochs'; confirm which
    # spelling `supervised_learning.train` actually reads.
    'recover_epochs':   0,
    'pruning_type':     pruning_type,
    'stop_epochs':      5,
}

# Set to True to skip fine-tuning and only evaluate the checkpoint at save_path
is_trained = False
if not is_trained:
    losses, train_accuracies, test_accuracies = train(cls_model,
                                                      config,
                                                      pruner,
                                                      True)
    graph_losses_n_accs(losses,
                        train_accuracies,
                        test_accuracies)

# Evaluating model: reload the checkpoint and confirm the sparsity that was saved
load_weights(cls_model, save_path)
print(f"Current model sparsity: {get_model_sparsity(cls_model)}")
acc = test(cls_model, testloader_c10_cls)
print(f"\nAccuracy of model is: {acc}")

## Sparsity: 0.99

### ViT-B-16

In [0]:
#######################################
########## MAGNITUDE PRUNING ##########
#######################################
# BaCP stage at TARGET_SPARSITY_HIGH (checkpoints tagged "099"): build the
# pre-trained / current / fine-tuned projection models, then train the current
# model with BaCPLearner while LocalMagnitudePrune sparsifies it.

# Creating projection models for BaCP framework
model_name = 'vitb16'
target_sparsity = TARGET_SPARSITY_HIGH   # single source of truth for pruner + config

# Projection networks. The fine-tuned branch starts from the supervised CIFAR-10
# checkpoint written by the baseline section.
finetuned_weights = f'/dbfs/{model_name}_weights/{model_name}_sup_c10.pt'
# `create_models_for_bacp` is the factory imported from `bacp` in the import cell;
# `create_models_for_cap` is not imported there and raises NameError on a fresh kernel.
pre_trained_model, current_model, finetuned_model = create_models_for_bacp(model_name, finetuned_weights)
print(f"Current model sparsity: {get_model_sparsity(current_model)}")

# Initializing Hyperparameters
optimizer_type = 'adam'
optimizer_cfg = {
    'model':        current_model,
    'lr':           0.0001,
    'momentum':     MOMENTUM,
    'weight_decay': WEIGHT_DECAY
}
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler = None
criterion = nn.CrossEntropyLoss()
# No `.pt` extension, unlike the downstream checkpoints; presumably BaCPLearner
# appends its own suffixes (confirm in bacp.py).
save_path = f'/dbfs/{model_name}_weights/{model_name}_bacp_magp_099'
logger = Logger(model_name, learning_type='bacp')

# Initializing pruner
pruner = LocalMagnitudePrune(1, target_sparsity)

config = {
    'model_name':       model_name,
    'optimizer':        optimizer,
    'scheduler':        scheduler,
    'criterion':        criterion,
    'target_sparsity':  target_sparsity,
    'logger':           logger,
    'batch_size':       BATCH_SIZE,
    'num_classes':      CIFAR10_CLASSES,
    'lambdas':          [0.25, 0.25, 0.25, 0.25],
    'save_path':        save_path,
    'pruner':           pruner,
    'recovery_epochs':  0,
    'epochs':           20,
    'pruning_epochs':   1,
}

cap_learner = BaCPLearner(current_model, pre_trained_model, finetuned_model, config)

# Set to True to skip BaCP training.
# NOTE(review): nothing is loaded from save_path in this cell when skipping; confirm
# BaCPLearner restores its checkpoint, otherwise the next cell uses an untrained model.
is_trained = False
if not is_trained:
    # NOTE(review): `trainloader_c10_cl` (presumably the contrastive CIFAR-10 loader) is
    # not created in the data cell at the top; confirm an earlier cell defines it.
    cap_learner.train(trainloader_c10_cl)

In [0]:
# Downstream fine-tuning of the BaCP model (magnitude pruning, "099" tag): turn the
# BaCP learner into a classifier, regenerate its pruning masks, fine-tune on CIFAR-10,
# then reload the checkpoint at save_path and report sparsity and test accuracy.

# Creating classification model with unfrozen parameters
model_name = 'vitb16'
cls_model = cap_learner.create_classification_net(False)
print(f"Current model sparsity: {get_model_sparsity(cls_model)}")

# Generate masks from model
cap_learner.generate_mask_from_model()

# Initializing Hyperparameters
optimizer_type = 'adam'
optimizer_cfg = {
    'model':        cls_model,
    'lr':           0.0001,
    'momentum':     MOMENTUM,
    'weight_decay': WEIGHT_DECAY
}
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler = None
criterion = nn.CrossEntropyLoss()
save_path = f'/dbfs/{model_name}_weights/{model_name}_bacp_magp_099_ds.pt'
logger = Logger(model_name, learning_type='cls')

# Initializing pruner: reuse the BaCP learner's pruner so fine-tuning keeps its masks
pruner = cap_learner.get_pruner()
pruning_type = 'magnitude_pruning'

config = {
    'trainloader':      trainloader_c10_cls,
    'testloader':       testloader_c10_cls,
    'optimizer':        optimizer,
    'scheduler':        scheduler,
    # Use the criterion defined above rather than constructing a second instance.
    'criterion':        criterion,
    'epochs':           30,
    "batch_size":       BATCH_SIZE,
    "save_path":        save_path,
    "logger":           logger,
    # NOTE(review): the BaCP configs spell this key 'recovery_epochs'; confirm which
    # spelling `supervised_learning.train` actually reads.
    'recover_epochs':   0,
    'pruning_type':     pruning_type,
    'stop_epochs':      5,
}

# Set to True to skip fine-tuning and only evaluate the checkpoint at save_path
is_trained = False
if not is_trained:
    losses, train_accuracies, test_accuracies = train(cls_model,
                                                      config,
                                                      pruner,
                                                      True)
    graph_losses_n_accs(losses,
                        train_accuracies,
                        test_accuracies)

# Evaluating model: reload the checkpoint and confirm the sparsity that was saved
load_weights(cls_model, save_path)
print(f"Current model sparsity: {get_model_sparsity(cls_model)}")
acc = test(cls_model, testloader_c10_cls)
print(f"\nAccuracy of model is: {acc}")

In [0]:
######################################
########## MOVEMENT PRUNING ##########
######################################
# BaCP stage at TARGET_SPARSITY_HIGH (checkpoints tagged "099"): build the
# pre-trained / current / fine-tuned projection models, then train the current
# model with BaCPLearner while LocalMovementPrune sparsifies it.

# Creating projection models for BaCP framework
model_name = 'vitb16'
target_sparsity = TARGET_SPARSITY_HIGH   # single source of truth for pruner + config

# Projection networks. The fine-tuned branch starts from the supervised CIFAR-10
# checkpoint written by the baseline section.
finetuned_weights = f'/dbfs/{model_name}_weights/{model_name}_sup_c10.pt'
# `create_models_for_bacp` is the factory imported from `bacp` in the import cell;
# `create_models_for_cap` is not imported there and raises NameError on a fresh kernel.
pre_trained_model, current_model, finetuned_model = create_models_for_bacp(model_name, finetuned_weights)
print(f"Current model sparsity: {get_model_sparsity(current_model)}")

# Initializing Hyperparameters
optimizer_type = 'adam'
optimizer_cfg = {
    'model':        current_model,
    'lr':           0.0001,
    'momentum':     MOMENTUM,
    'weight_decay': WEIGHT_DECAY
}
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler = None
criterion = nn.CrossEntropyLoss()
# No `.pt` extension, unlike the downstream checkpoints; presumably BaCPLearner
# appends its own suffixes (confirm in bacp.py).
save_path = f'/dbfs/{model_name}_weights/{model_name}_bacp_mvmp_099'
logger = Logger(model_name, learning_type='bacp')

# Initializing pruner
pruner = LocalMovementPrune(1, target_sparsity)

config = {
    'model_name':       model_name,
    'optimizer':        optimizer,
    'scheduler':        scheduler,
    'criterion':        criterion,
    'target_sparsity':  target_sparsity,
    'logger':           logger,
    'batch_size':       BATCH_SIZE,
    'num_classes':      CIFAR10_CLASSES,
    'lambdas':          [0.25, 0.25, 0.25, 0.25],
    'save_path':        save_path,
    'pruner':           pruner,
    'recovery_epochs':  0,
    'epochs':           20,
    'pruning_epochs':   1,
}

cap_learner = BaCPLearner(current_model, pre_trained_model, finetuned_model, config)

# Set to True to skip BaCP training.
# NOTE(review): nothing is loaded from save_path in this cell when skipping; confirm
# BaCPLearner restores its checkpoint, otherwise the next cell uses an untrained model.
is_trained = False
if not is_trained:
    # NOTE(review): `trainloader_c10_cl` (presumably the contrastive CIFAR-10 loader) is
    # not created in the data cell at the top; confirm an earlier cell defines it.
    cap_learner.train(trainloader_c10_cl)

In [0]:
# Downstream fine-tuning of the BaCP model (movement pruning, "099" tag): turn the
# BaCP learner into a classifier, regenerate its pruning masks, fine-tune on CIFAR-10,
# then reload the checkpoint at save_path and report sparsity and test accuracy.

# Creating classification model with unfrozen parameters
model_name = 'vitb16'
cls_model = cap_learner.create_classification_net(False)
print(f"Current model sparsity: {get_model_sparsity(cls_model)}")

# Generate masks from model
cap_learner.generate_mask_from_model()

# Initializing Hyperparameters
optimizer_type = 'adam'
optimizer_cfg = {
    'model':        cls_model,
    'lr':           0.0001,
    'momentum':     MOMENTUM,
    'weight_decay': WEIGHT_DECAY
}
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler = None
criterion = nn.CrossEntropyLoss()
save_path = f'/dbfs/{model_name}_weights/{model_name}_bacp_mvmp_099_ds.pt'
logger = Logger(model_name, learning_type='cls')

# Initializing pruner: reuse the BaCP learner's pruner so fine-tuning keeps its masks
pruner = cap_learner.get_pruner()
pruning_type = 'movement_pruning'

config = {
    'trainloader':      trainloader_c10_cls,
    'testloader':       testloader_c10_cls,
    'optimizer':        optimizer,
    'scheduler':        scheduler,
    # Use the criterion defined above rather than constructing a second instance.
    'criterion':        criterion,
    'epochs':           30,
    "batch_size":       BATCH_SIZE,
    "save_path":        save_path,
    "logger":           logger,
    # NOTE(review): the BaCP configs spell this key 'recovery_epochs'; confirm which
    # spelling `supervised_learning.train` actually reads.
    'recover_epochs':   0,
    'pruning_type':     pruning_type,
    'stop_epochs':      5,
}

# Set to True to skip fine-tuning and only evaluate the checkpoint at save_path
is_trained = False
if not is_trained:
    losses, train_accuracies, test_accuracies = train(cls_model,
                                                      config,
                                                      pruner,
                                                      True)
    graph_losses_n_accs(losses,
                        train_accuracies,
                        test_accuracies)

# Evaluating model: reload the checkpoint and confirm the sparsity that was saved
load_weights(cls_model, save_path)
print(f"Current model sparsity: {get_model_sparsity(cls_model)}")
acc = test(cls_model, testloader_c10_cls)
print(f"\nAccuracy of model is: {acc}")