In [0]:
%restart_python

In [0]:
# Enable IPython's autoreload extension in mode 2, so edited project modules
# are re-imported automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [0]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from contrastive_learning import ContrastiveLearner
from models import EncoderProjectionNetwork, ClassificationNetwork
from datasets_class import CreateDatasets
from supervised_learning import train, test
from unstructured_pruning import MagnitudePrune, MovementPrune, LocalMagnitudePrune, LocalMovementPrune

from logger import Logger
from utils import *


In [0]:
# ---- Data pipeline ----------------------------------------------------------
BATCH_SIZE = 512     # batch size handed to every DataLoader and training config
NUM_WORKERS = 24     # NOTE(review): not referenced by any DataLoader in this notebook
SIZE = 224           # image size passed to datasets.get_dataset_fn

# ---- Epoch budgets ----------------------------------------------------------
TRAINING_EPOCHS = 3
BIG_TRAINING_EPOCHS = 6
FINETUNING_EPOCHS = 5
CIFAR10_CLASSES = 10  # number of output classes for the CIFAR-10 classifiers
PRUNING_EPOCHS = 5

# ---- Pruning ----------------------------------------------------------------
TARGET_SPARSITY = 0.95

# ---- Contrastive loss -------------------------------------------------------
# NOTE(review): the contrastive cells read TEMP / BASE_TEMP, not these two
# names -- confirm which pair is authoritative.
TEMPERATURE = 0.07
BASE_TEMPERATURE = 0.07

# ---- Optimisation -----------------------------------------------------------
LR_VITB16 = 1e-5     # learning rate used for the ViT-B/16 runs
LR_VITL16 = 5e-6     # learning rate used for the ViT-L/16 runs

WEIGHT_DECAY = 1e-3


In [0]:
# Dataset factory rooted at the local ./data directory.
datasets = CreateDatasets('./data')

# --- Supervised learning: CIFAR-10 train/test sets and loaders ---------------
trainset_c10_cls_fn, testset_c10_cls_fn = datasets.get_dataset_fn('supervised', 'cifar10', SIZE)
trainset_c10_cls = trainset_c10_cls_fn()
testset_c10_cls = testset_c10_cls_fn()

# Only the training loader is shuffled.
# NOTE(review): NUM_WORKERS is defined in the config cell but is not passed to
# any of these loaders -- confirm whether that is intentional.
trainloader_c10_cls = DataLoader(trainset_c10_cls, batch_size=BATCH_SIZE, shuffle=True)
testloader_c10_cls = DataLoader(testset_c10_cls, batch_size=BATCH_SIZE, shuffle=False)

# --- Supervised contrastive learning: 2-view augmented CIFAR-10 --------------
trainset_c10_cl_fn, testset_c10_cl_fn = datasets.get_dataset_fn('supcon', 'cifar10', SIZE)
trainset_c10_cl = trainset_c10_cl_fn()
testset_c10_cl = testset_c10_cl_fn()

trainloader_c10_cl = DataLoader(trainset_c10_cl, batch_size=BATCH_SIZE, shuffle=True)
testloader_c10_cl = DataLoader(testset_c10_cl, batch_size=BATCH_SIZE, shuffle=False)


# Baseline Accuracies

## ViT-Base-16

In [0]:
# Baseline: ViT-B/16 classifier for CIFAR-10, trained with cross-entropy only.
model_name = 'vitb16'
cls_model = ClassificationNetwork(model_name, CIFAR10_CLASSES, False).to(get_device())

# Optimizer: Adam at the ViT-B/16 learning rate, no LR scheduler.
# NOTE(review): MOMENTUM is not defined in this notebook; presumably it is
# provided by `from utils import *` -- confirm.
optimizer_type = 'adam'
optimizer_cfg = {
    'model': cls_model,
    'lr': LR_VITB16,
    'momentum': MOMENTUM,
    'weight_decay': WEIGHT_DECAY,
}
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler = None
criterion = nn.CrossEntropyLoss()

# Checkpoint path and run logger.
save_path = f'/dbfs/{model_name}_weights/{model_name}_sup_c10.pt'
logger = Logger(model_name, learning_type='cls')

# Settings consumed by `train`; the baseline run uses no pruning
# (empty 'pruning_type', zero recovery epochs, zero regularisation weight).
config = {
    'trainloader': trainloader_c10_cls,
    'testloader': testloader_c10_cls,
    'optimizer': optimizer,
    'scheduler': scheduler,
    'criterion': criterion,
    'epochs': TRAINING_EPOCHS,
    'batch_size': BATCH_SIZE,
    'save_path': save_path,
    'logger': logger,
    'lambda_reg': 0,
    'recover_epochs': 0,
    'pruning_type': "",
}

# True  -> skip training and evaluate the checkpoint already at `save_path`.
# False -> train first, then plot the loss/accuracy curves.
is_trained = True
if not is_trained:
    losses, train_accuracies, test_accuracies = train(cls_model, config)
    graph_losses_n_accs(losses, train_accuracies, test_accuracies)

# Evaluate the stored checkpoint on the CIFAR-10 test loader.
load_weights(cls_model, save_path)
acc = test(cls_model, testloader_c10_cls)
print(f"\nAccuracy of model is: {acc}")

## ViT-Large-16

In [0]:
# Baseline: ViT-L/16 classifier for CIFAR-10, trained with cross-entropy only.
model_name = 'vitl16'
cls_model = ClassificationNetwork(model_name, CIFAR10_CLASSES, False).to(get_device())

# Optimizer: Adam at the ViT-L/16 learning rate, no LR scheduler.
# NOTE(review): MOMENTUM is not defined in this notebook; presumably it is
# provided by `from utils import *` -- confirm.
optimizer_type = 'adam'
optimizer_cfg = {
    'model':        cls_model,
    'lr':           LR_VITL16,
    'momentum':     MOMENTUM,
    'weight_decay': WEIGHT_DECAY,
}
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler = None
criterion = nn.CrossEntropyLoss()
save_path = f'/dbfs/{model_name}_weights/{model_name}_sup_c10.pt'
# NOTE(review): the checkpoint is loaded *before* the optional training below,
# unlike the ViT-B/16 baseline cell. With is_trained=False this resumes from a
# previous run instead of starting fresh -- confirm this is intended.
load_weights(cls_model, save_path)
logger = Logger(model_name, learning_type='cls')

# Settings consumed by `train`. 'pruning_type' is set to the empty string so
# this config carries the same keys as the ViT-B/16 baseline config; it was
# previously missing here.
config = {
    'trainloader':      trainloader_c10_cls,
    'testloader':       testloader_c10_cls,
    'optimizer':        optimizer,
    'scheduler':        scheduler,
    'criterion':        nn.CrossEntropyLoss(),
    'epochs':           TRAINING_EPOCHS,
    "batch_size":       BATCH_SIZE,
    "save_path":        save_path,
    "logger":           logger,
    "lambda_reg":       0,
    'recover_epochs':   0,
    'pruning_type':     "",
}

# Set True if trained (skips training and only evaluates the checkpoint).
is_trained = True
if not is_trained:
    losses, train_accuracies, test_accuracies = train(cls_model,
                                                      config)
    graph_losses_n_accs(losses, train_accuracies, test_accuracies)

# Evaluate the stored checkpoint on the CIFAR-10 test loader.
load_weights(cls_model, save_path)
acc = test(cls_model, testloader_c10_cls)
print(f"\nAccuracy of model is: {acc}")

# Pruning Accuracies

## ViT-Base-16

### Magnitude Prune

In [0]:
# Magnitude pruning of the supervised ViT-B/16 CIFAR-10 classifier.
model_name = 'vitb16'
cls_model = ClassificationNetwork(model_name, CIFAR10_CLASSES, False).to(get_device())
load_weights(cls_model, f'/dbfs/{model_name}_weights/{model_name}_sup_c10.pt')

# Optimizer: Adam, no LR scheduler.
# NOTE(review): this cell uses a literal learning rate of 0.0001 while the
# other ViT-B/16 pruning cell uses LR_VITB16 -- confirm which is intended.
# NOTE(review): MOMENTUM is not defined in this notebook; presumably it is
# provided by `from utils import *` -- confirm.
optimizer_type = 'adam'
optimizer_cfg = {
    'model':        cls_model,
    'lr':           0.0001,
    'momentum':     MOMENTUM,
    'weight_decay': WEIGHT_DECAY
}
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler = None
criterion = nn.CrossEntropyLoss()
logger = Logger(model_name, learning_type='cls')
save_path = f'/dbfs/{model_name}_weights/{model_name}_sup_magp_c10.pt'

# Pruner: local magnitude pruning, in line with this section's heading, the
# "magp" checkpoint name and the 'pruning_type' label below. The cell used to
# instantiate LocalMovementPrune(1, 0.965, target_layer='self_attention'),
# which would have stored movement-pruned weights under a magnitude-pruning
# file name.
pruner = LocalMagnitudePrune(1, 0.965)

# Settings consumed by `train`. The 'pruning_type' label is spelled like the
# one in the ViT-L/16 magnitude-pruning cell ('local_magnitude_prune').
config = {
    'trainloader':      trainloader_c10_cls,
    'testloader':       testloader_c10_cls,
    'optimizer':        optimizer,
    'scheduler':        scheduler,
    'criterion':        nn.CrossEntropyLoss(),
    'epochs':           1,
    "batch_size":       BATCH_SIZE,
    "save_path":        save_path,
    "logger":           logger,
    'lambda_reg':       0,
    'recover_epochs':   5,
    'pruning_type':    'local_magnitude_prune',
}

# Set False to train; True skips straight to evaluating the saved checkpoint.
is_trained = False
if not is_trained:
    losses, train_accuracies, test_accuracies = train(cls_model,
                                                      config,
                                                      pruner)
    # Persist the pruned weights explicitly after training.
    torch.save(cls_model.state_dict(), save_path)

    # Displaying loss-acc graph
    graph_losses_n_accs(losses,
                        train_accuracies,
                        test_accuracies)

# Evaluate the stored pruned checkpoint: overall sparsity, then test accuracy.
load_weights(cls_model, save_path)
print(f"Sparsity of pruned model: {get_model_sparsity(cls_model):.3f}")
pruned_acc = test(cls_model, testloader_c10_cls)
print(f"Accuracy of pruned model is: {pruned_acc}")


In [0]:
# Per-parameter sparsity report: for every named parameter of `cls_model`,
# print the fraction of its entries that are exactly zero.
for name, param in cls_model.named_parameters():
    nonzero_fraction = torch.count_nonzero(param).item() / param.numel()
    print(name, (1 - nonzero_fraction))

### Movement Prune

In [0]:
# Movement pruning of the supervised ViT-B/16 CIFAR-10 classifier.
model_name = 'vitb16'
cls_model = ClassificationNetwork(model_name, CIFAR10_CLASSES, False).to(get_device())
load_weights(cls_model, f'/dbfs/{model_name}_weights/{model_name}_sup_c10.pt')

# Optimizer: Adam at the ViT-B/16 learning rate, no LR scheduler.
# NOTE(review): MOMENTUM is not defined in this notebook; presumably it is
# provided by `from utils import *` -- confirm.
optimizer_type = 'adam'
optimizer_cfg = {
    'model': cls_model,
    'lr': LR_VITB16,
    'momentum': MOMENTUM,
    'weight_decay': WEIGHT_DECAY,
}
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler = None
criterion = nn.CrossEntropyLoss()
logger = Logger(model_name, learning_type='cls')
save_path = f'/dbfs/{model_name}_weights/{model_name}_sup_mvmp_c10.pt'

# Pruner: local movement pruning, constructed with PRUNING_EPOCHS and 0.965
# (presumably the target sparsity -- confirm against LocalMovementPrune).
pruner = LocalMovementPrune(PRUNING_EPOCHS, 0.965)

# Settings consumed by `train`.
config = {
    'trainloader': trainloader_c10_cls,
    'testloader': testloader_c10_cls,
    'optimizer': optimizer,
    'scheduler': scheduler,
    'criterion': criterion,
    'epochs': PRUNING_EPOCHS,
    'batch_size': BATCH_SIZE,
    'save_path': save_path,
    'logger': logger,
    'lambda_reg': 0,
    'recover_epochs': 1,
    'pruning_type': 'local_movement_prune',
}

# False -> prune/train now; True -> only evaluate the saved checkpoint.
is_trained = False
if not is_trained:
    losses, train_accuracies, test_accuracies = train(cls_model, config, pruner)
    # Persist the pruned weights explicitly after training.
    torch.save(cls_model.state_dict(), save_path)

    # Loss / accuracy curves for the run.
    graph_losses_n_accs(losses, train_accuracies, test_accuracies)

# Evaluate the stored pruned checkpoint: overall sparsity, then test accuracy.
load_weights(cls_model, save_path)
print(f"Sparsity of pruned model: {get_model_sparsity(cls_model):.3f}")
pruned_acc = test(cls_model, testloader_c10_cls)
print(f"Accuracy of pruned model is: {pruned_acc}")


## ViT-Large-16

### Magnitude Prune

In [0]:
# Magnitude pruning of the supervised ViT-L/16 CIFAR-10 classifier.
model_name = 'vitl16'
cls_model = ClassificationNetwork(model_name, CIFAR10_CLASSES, False).to(get_device())
load_weights(cls_model, f'/dbfs/{model_name}_weights/{model_name}_sup_c10.pt')

# Optimizer: Adam at the ViT-L/16 learning rate, no LR scheduler.
# NOTE(review): MOMENTUM is not defined in this notebook; presumably it is
# provided by `from utils import *` -- confirm.
optimizer_type = 'adam'
optimizer_cfg = {
    'model': cls_model,
    'lr': LR_VITL16,
    'momentum': MOMENTUM,
    'weight_decay': WEIGHT_DECAY,
}
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler = None
criterion = nn.CrossEntropyLoss()
logger = Logger(model_name, learning_type='cls')
save_path = f'/dbfs/{model_name}_weights/{model_name}_sup_magp_c10.pt'

# Pruner: local magnitude pruning, constructed with PRUNING_EPOCHS and 0.965
# (presumably the target sparsity -- confirm against LocalMagnitudePrune).
pruner = LocalMagnitudePrune(PRUNING_EPOCHS, 0.965)

# Settings consumed by `train`.
config = {
    'trainloader': trainloader_c10_cls,
    'testloader': testloader_c10_cls,
    'optimizer': optimizer,
    'scheduler': scheduler,
    'criterion': criterion,
    'epochs': PRUNING_EPOCHS,
    'batch_size': BATCH_SIZE,
    'save_path': save_path,
    'logger': logger,
    'lambda_reg': 0,
    'recover_epochs': 1,
    'pruning_type': 'local_magnitude_prune',
}

# False -> prune/train now; True -> only evaluate the saved checkpoint.
is_trained = False
if not is_trained:
    losses, train_accuracies, test_accuracies = train(cls_model, config, pruner)
    # Persist the pruned weights explicitly after training.
    torch.save(cls_model.state_dict(), save_path)

    # Loss / accuracy curves for the run.
    graph_losses_n_accs(losses, train_accuracies, test_accuracies)

# Evaluate the stored pruned checkpoint: overall sparsity, then test accuracy.
load_weights(cls_model, save_path)
print(f"Sparsity of pruned model: {get_model_sparsity(cls_model):.3f}")
pruned_acc = test(cls_model, testloader_c10_cls)
print(f"Accuracy of pruned model is: {pruned_acc}")


### Movement Prune

In [0]:
# Movement pruning of the supervised ViT-L/16 CIFAR-10 classifier.
model_name = 'vitl16'
cls_model = ClassificationNetwork(model_name, CIFAR10_CLASSES, False).to(get_device())
load_weights(cls_model, f'/dbfs/{model_name}_weights/{model_name}_sup_c10.pt')

# Optimizer: Adam at the ViT-L/16 learning rate, no LR scheduler.
# NOTE(review): MOMENTUM is not defined in this notebook; presumably it is
# provided by `from utils import *` -- confirm.
optimizer_type = 'adam'
optimizer_cfg = {
    'model': cls_model,
    'lr': LR_VITL16,
    'momentum': MOMENTUM,
    'weight_decay': WEIGHT_DECAY,
}
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler = None
criterion = nn.CrossEntropyLoss()
logger = Logger(model_name, learning_type='cls')
save_path = f'/dbfs/{model_name}_weights/{model_name}_sup_mvmp_c10.pt'

# Pruner: local movement pruning, constructed with PRUNING_EPOCHS and 0.965
# (presumably the target sparsity -- confirm against LocalMovementPrune).
pruner = LocalMovementPrune(PRUNING_EPOCHS, 0.965)

# Settings consumed by `train`.
config = {
    'trainloader': trainloader_c10_cls,
    'testloader': testloader_c10_cls,
    'optimizer': optimizer,
    'scheduler': scheduler,
    'criterion': criterion,
    'epochs': PRUNING_EPOCHS,
    'batch_size': BATCH_SIZE,
    'save_path': save_path,
    'logger': logger,
    'lambda_reg': 0,
    'recover_epochs': 1,
    'pruning_type': 'local_movement_prune',
}

# False -> prune/train now; True -> only evaluate the saved checkpoint.
is_trained = False
if not is_trained:
    losses, train_accuracies, test_accuracies = train(cls_model, config, pruner)
    # Persist the pruned weights explicitly after training.
    torch.save(cls_model.state_dict(), save_path)

    # Loss / accuracy curves for the run.
    graph_losses_n_accs(losses, train_accuracies, test_accuracies)

# Evaluate the stored pruned checkpoint: overall sparsity, then test accuracy.
load_weights(cls_model, save_path)
print(f"Sparsity of pruned model: {get_model_sparsity(cls_model):.3f}")
pruned_acc = test(cls_model, testloader_c10_cls)
print(f"Accuracy of pruned model is: {pruned_acc}")


# SupCon Learning

In [0]:
# Supervised-contrastive pre-training: ViT-B/16 encoder with a projection
# network built with 128 as its second argument (presumably the projection
# dimension -- confirm against EncoderProjectionNetwork).
# NOTE(review): the model is not moved to a device here; presumably
# ContrastiveLearner handles placement -- confirm.
model_name = 'vitb16'
projection_model = EncoderProjectionNetwork(model_name, 128)

# Optimizer: SGD at the ViT-B/16 learning rate, no LR scheduler.
# NOTE(review): MOMENTUM is not defined in this notebook; presumably it is
# provided by `from utils import *` -- confirm.
optimizer_type = 'sgd'
optimizer_cfg = {
    'model': projection_model,
    'lr': LR_VITB16,
    'momentum': MOMENTUM,
    'weight_decay': WEIGHT_DECAY,
}
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler = None
save_path = f'/dbfs/{model_name}_weights/{model_name}_supcon_c10.pt'
logger = Logger(model_name, learning_type='supcon')

# Settings consumed by ContrastiveLearner (two augmented views, SupCon loss).
# NOTE(review): TEMP / BASE_TEMP are not defined in this notebook, while the
# config cell defines TEMPERATURE / BASE_TEMPERATURE. Presumably TEMP and
# BASE_TEMP come from `from utils import *` -- confirm which pair is intended.
config = {
    'n_views': 2,
    'optimizer': optimizer,
    'epochs': TRAINING_EPOCHS,
    'scheduler': scheduler,
    'batch_size': BATCH_SIZE,
    'temperature': TEMP,
    'base_temperature': BASE_TEMP,
    'loss_type': 'supcon',
}

supcon_learner = ContrastiveLearner(projection_model, config)

# False -> run contrastive training now; True -> skip it.
is_trained = False
if not is_trained:
    supcon_learner.train(trainloader_c10_cl, save_path, logger)


In [0]:
# Downstream task: ViT-B/16 classifier initialised from the SupCon backbone.
# The model is placed on the device once, at construction.
model_name = 'vitb16'
cls_model = ClassificationNetwork(model_name, CIFAR10_CLASSES, False).to(get_device())

# Copy the projection model's backbone weights into the classifier.
load_projection_model_weights(cls_model, f'/dbfs/{model_name}_weights/{model_name}_supcon_c10.pt')

# Optimizer: Adam with a literal learning rate of 0.0001, no LR scheduler.
# NOTE(review): MOMENTUM and LINEAR_EPOCHS are not defined in this notebook;
# presumably they are provided by `from utils import *` -- confirm.
optimizer_type = 'adam'
optimizer_cfg = {
    'model': cls_model,
    'lr': 0.0001,
    'momentum': MOMENTUM,
    'weight_decay': WEIGHT_DECAY,
}
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler = None
criterion = nn.CrossEntropyLoss()

# NOTE(review): this is the same path the ViT-B/16 baseline cell uses for its
# checkpoint, so training here overwrites the baseline weights -- confirm
# that this is intended.
save_path = f'/dbfs/{model_name}_weights/{model_name}_sup_c10.pt'
logger = Logger(model_name, learning_type='cls')

# Settings consumed by `train`; no pruning, non-zero regularisation weight.
config = {
    'trainloader': trainloader_c10_cls,
    'testloader': testloader_c10_cls,
    'optimizer': optimizer,
    'scheduler': scheduler,
    'criterion': criterion,
    'epochs': LINEAR_EPOCHS,
    'batch_size': BATCH_SIZE,
    'save_path': save_path,
    'logger': logger,
    'lambda_reg': 5e-4,
    'recover_epochs': 0,
    'pruning_type': "",
    'stop_epochs': 25,
}

# False -> train now; True -> only evaluate the saved checkpoint.
is_trained = False
if not is_trained:
    losses, train_accuracies, test_accuracies = train(cls_model, config)
    graph_losses_n_accs(losses, train_accuracies, test_accuracies)

# Evaluate the stored checkpoint on the CIFAR-10 test loader.
load_weights(cls_model, save_path)
acc = test(cls_model, testloader_c10_cls)
print(f"\nAccuracy of model is: {acc}")

# BaCP Accuracies

## ViT-Base-16

### Magnitude Prune

In [0]:
from bacp import BaCPLearner, create_models_for_cap

# BaCP with magnitude pruning on ViT-B/16: build the three projection models
# from the SupCon checkpoint.
model_name = 'vitb16'
finetuned_weights = f'/dbfs/{model_name}_weights/{model_name}_supcon_c10.pt'
pre_trained_model, current_model, finetuned_model = create_models_for_cap(model_name, finetuned_weights)

# Fine-tuned classification network loaded from the supervised checkpoint.
cls_model = ClassificationNetwork(model_name, CIFAR10_CLASSES, False).to(get_device())
load_weights(cls_model, f'/dbfs/{model_name}_weights/{model_name}_sup_c10.pt')
print(f"Current model sparsity: {get_model_sparsity(current_model)}")

# Optimizer: SGD over `current_model` at the ViT-B/16 learning rate.
# NOTE(review): MOMENTUM, TEMP, BASE_TEMP, TARGET_SPARSITY_MID, BACP_EPOCHS
# and LAMBDAS are not defined in this notebook; presumably they are provided
# by `from utils import *` -- confirm.
optimizer_type = 'sgd'
optimizer_cfg = {
    'model': current_model,
    'lr': LR_VITB16,
    'momentum': MOMENTUM,
    'weight_decay': WEIGHT_DECAY,
}
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler = None
criterion = nn.CrossEntropyLoss()
save_path = f'/dbfs/{model_name}_weights/{model_name}_bacp_magp_c10'
logger = Logger(model_name, learning_type='bacp')

# Pruner: local magnitude pruning, constructed with 5 and 0.965.
# NOTE(review): the pruner receives 0.965 while the config below passes
# TARGET_SPARSITY_MID as 'target_sparsity' -- confirm the two agree.
pruner = LocalMagnitudePrune(5, 0.965)

# Settings consumed by BaCPLearner.
config = {
    'model_name': model_name,
    'n_views': 2,
    'optimizer': optimizer,
    'scheduler': scheduler,
    'criterion': criterion,
    'temperature': TEMP,
    'base_temperature': BASE_TEMP,
    'target_sparsity': TARGET_SPARSITY_MID,
    'logger': logger,
    'epochs': BACP_EPOCHS,
    'batch_size': BATCH_SIZE,
    'num_classes': CIFAR10_CLASSES,   # change this to match the dataloader
    'lambdas': LAMBDAS,               # None lambdas => lambdas become learnable parameters
    'save_path': save_path,
    'pruner': pruner,
}

cap_learner = BaCPLearner(current_model, pre_trained_model, finetuned_model, cls_model, config)

# False -> run BaCP training now; True -> skip it.
is_trained = False
if not is_trained:
    cap_learner.cap_train(trainloader_c10_cl)

In [0]:
# Classification stage after BaCP (magnitude pruning): build the classifier
# from the BaCP learner. The False argument requests unfrozen parameters
# (per the original cell comment -- confirm against create_classification_net).
model_name = 'vitb16'
cls_model = cap_learner.create_classification_net(False)
print(f"Current model sparsity: {get_model_sparsity(cls_model)}")

# Derive the pruning masks from the learner's model.
cap_learner.generate_mask_from_model()

# Optimizer: Adam at the ViT-B/16 learning rate, no LR scheduler.
# NOTE(review): MOMENTUM and LINEAR_EPOCHS are not defined in this notebook;
# presumably they are provided by `from utils import *` -- confirm.
optimizer_type = 'adam'
optimizer_cfg = {
    'model': cls_model,
    'lr': LR_VITB16,
    'momentum': MOMENTUM,
    'weight_decay': WEIGHT_DECAY,
}
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler = None
criterion = nn.CrossEntropyLoss()
save_path = f'/dbfs/{model_name}_weights/{model_name}_bacp_magp_cls_c10.pt'
logger = Logger(model_name, learning_type='cls')

# Reuse the pruner held by the BaCP learner.
pruner = cap_learner.get_pruner()
pruning_type = 'magnitude_pruning'

# Settings consumed by `train`.
config = {
    'trainloader': trainloader_c10_cls,
    'testloader': testloader_c10_cls,
    'optimizer': optimizer,
    'scheduler': scheduler,
    'criterion': criterion,
    'epochs': LINEAR_EPOCHS,
    'batch_size': BATCH_SIZE,
    'save_path': save_path,
    'logger': logger,
    'lambda_reg': 0,
    'recover_epochs': 0,
    'pruning_type': pruning_type,
}

# False -> train now; True -> only evaluate the saved checkpoint.
is_trained = False
if not is_trained:
    # NOTE(review): the meaning of the trailing positional True is defined by
    # `train` in supervised_learning -- confirm there.
    losses, train_accuracies, test_accuracies = train(cls_model, config, pruner, True)
    graph_losses_n_accs(losses, train_accuracies, test_accuracies)

# Evaluate the stored checkpoint on the CIFAR-10 test loader.
load_weights(cls_model, save_path)
acc = test(cls_model, testloader_c10_cls)
print(f"\nAccuracy of model is: {acc}")

### Movement Prune

In [0]:
from bacp import BaCPLearner, create_models_for_cap

# BaCP with movement pruning on ViT-B/16: build the three projection models
# from the SupCon checkpoint.
model_name = 'vitb16'
finetuned_weights = f'/dbfs/{model_name}_weights/{model_name}_supcon_c10.pt'
pre_trained_model, current_model, finetuned_model = create_models_for_cap(model_name, finetuned_weights)

# Fine-tuned classification network loaded from the supervised checkpoint.
cls_model = ClassificationNetwork(model_name, CIFAR10_CLASSES, False).to(get_device())
load_weights(cls_model, f'/dbfs/{model_name}_weights/{model_name}_sup_c10.pt')

print(f"Current model sparsity: {get_model_sparsity(current_model)}")

# Optimizer: SGD over `current_model` at the ViT-B/16 learning rate.
# NOTE(review): MOMENTUM, TEMP, BASE_TEMP, TARGET_SPARSITY_LOW,
# TARGET_SPARSITY_MID, BACP_EPOCHS and LAMBDAS are not defined in this
# notebook; presumably they are provided by `from utils import *` -- confirm.
optimizer_type = 'sgd'
optimizer_cfg = {
    'model': current_model,
    'lr': LR_VITB16,
    'momentum': MOMENTUM,
    'weight_decay': WEIGHT_DECAY,
}
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler = None
criterion = nn.CrossEntropyLoss()
save_path = f'/dbfs/{model_name}_weights/{model_name}_bacp_mvmp_c10'
logger = Logger(model_name, learning_type='bacp')

# Pruner: movement pruning, constructed with 5 and TARGET_SPARSITY_LOW.
# NOTE(review): the pruner receives TARGET_SPARSITY_LOW while the config
# below passes TARGET_SPARSITY_MID as 'target_sparsity' -- confirm.
pruner = MovementPrune(5, TARGET_SPARSITY_LOW)

# Settings consumed by BaCPLearner.
config = {
    'model_name': model_name,
    'n_views': 2,
    'optimizer': optimizer,
    'scheduler': scheduler,
    'criterion': criterion,
    'temperature': TEMP,
    'base_temperature': BASE_TEMP,
    'target_sparsity': TARGET_SPARSITY_MID,
    'logger': logger,
    'epochs': BACP_EPOCHS,
    'batch_size': BATCH_SIZE,
    'num_classes': CIFAR10_CLASSES,   # change this to match the dataloader
    'lambdas': LAMBDAS,               # None lambdas => lambdas become learnable parameters
    'save_path': save_path,
    'pruner': pruner,
}

cap_learner = BaCPLearner(current_model, pre_trained_model, finetuned_model, cls_model, config)

# True -> BaCP training is skipped; set False to run it.
is_trained = True
if not is_trained:
    cap_learner.cap_train(trainloader_c10_cl)

In [0]:
# Classification stage after BaCP (movement pruning): build the classifier
# from the BaCP learner. The False argument requests unfrozen parameters
# (per the original cell comment -- confirm against create_classification_net).
model_name = 'vitb16'
cls_model = cap_learner.create_classification_net(False)
print(f"Current model sparsity: {get_model_sparsity(cls_model)}")

# Derive the pruning masks from the learner's model.
cap_learner.generate_mask_from_model()

# Optimizer: Adam with a literal learning rate of 0.0001, no LR scheduler.
# NOTE(review): MOMENTUM and LINEAR_EPOCHS are not defined in this notebook;
# presumably they are provided by `from utils import *` -- confirm.
optimizer_type = 'adam'
optimizer_cfg = {
    'model': cls_model,
    'lr': 0.0001,
    'momentum': MOMENTUM,
    'weight_decay': WEIGHT_DECAY,
}
optimizer = set_optimizer(optimizer_type, optimizer_cfg)
scheduler = None
criterion = nn.CrossEntropyLoss()
save_path = f'/dbfs/{model_name}_weights/{model_name}_bacp_mvmp_cls_c10_v2.pt'

logger = Logger(model_name, learning_type='cls')

# Reuse the pruner held by the BaCP learner.
pruner = cap_learner.get_pruner()
pruning_type = 'movement_pruning'

# Settings consumed by `train`.
config = {
    'trainloader': trainloader_c10_cls,
    'testloader': testloader_c10_cls,
    'optimizer': optimizer,
    'scheduler': scheduler,
    'criterion': criterion,
    'epochs': LINEAR_EPOCHS,
    'batch_size': BATCH_SIZE,
    'save_path': save_path,
    'logger': logger,
    'lambda_reg': 0,
    'recover_epochs': 0,
    'pruning_type': pruning_type,
}

# False -> train now; True -> only evaluate the saved checkpoint.
is_trained = False
if not is_trained:
    # NOTE(review): the meaning of the trailing positional True is defined by
    # `train` in supervised_learning -- confirm there.
    losses, train_accuracies, test_accuracies = train(cls_model, config, pruner, True)
    graph_losses_n_accs(losses, train_accuracies, test_accuracies)

# Evaluate the stored checkpoint on the CIFAR-10 test loader.
load_weights(cls_model, save_path)
acc = test(cls_model, testloader_c10_cls)
print(f"\nAccuracy of model is: {acc}")