## 1. Command-line arguments and their defaults

In [None]:
import sys
import torch
import torch.optim as optim
import torchvision
from torchvision import transforms
from utils import *
import argparse
# Jupyter starts the kernel with its own command-line flags, which argparse
# would reject; reset argv so parse_args() below only sees the defaults.
sys.argv = ['']

parser = argparse.ArgumentParser(description='Parameters training')
parser.add_argument('--model_architecture', type=str, default="VGG16", help='....')
parser.add_argument('--method', type=str, default="random", help='....')
parser.add_argument('--dataset', type=str, default="CIFAR10", help='....')
parser.add_argument('--batch_size', type=int, default=8, help='....')
parser.add_argument('--num_epochs', type=int, default=40, help='....')
parser.add_argument('--learning_rate', type=float, default=1e-3, help='....')
parser.add_argument('--optimizer_val', type=str, default="SGD", help='....')
parser.add_argument('--model_type', type=str, default="UNPRUNED", help='....')
# None means "pick automatically" (resolved right after parsing).
parser.add_argument('--device', type=str, default=None, help='....')
# NOTE(review): no `type`, so a value given on the command line would arrive as
# a str; only the (1, 3, 224, 224) tensor default is usable as-is.
parser.add_argument('--model_input', default=torch.ones((1, 3, 224, 224)), help='....')
parser.add_argument('--eval_metric', default="accuracy", help='....')
parser.add_argument('--pruning_seed', type=int, default=23, help='....')
# One pruning ratio per layer position. `type=list` would split a command-line
# string into single characters (e.g. "0.6" -> ['0', '.', '6']), so parse one
# float per value instead: --list_pruning 0.6 0.6 0.53 ...
parser.add_argument('--list_pruning', type=float, nargs='+',
                    default=[0.6, 0.6, 0.53, 0.53, 0.4, 0.4, 0.4, 0.5, 0.5, 0.5, 0.6, 0.6, 0.6, 0.5, 0.5, 0],
                    help='....')

args = parser.parse_args()

if args.device is None:
    # Prefer the GPU whenever PyTorch can see one.
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Per-run overrides applied on top of the argparse defaults.
# Other options can be overridden the same way, e.g.:
#   args.model_architecture = "ResNet18"
#   args.num_epochs = 40
vars(args).update(
    method='SenpisFaster',          # available methods: weight, SenpisFaster, random
    dataset="Date_Fruit_7classes",
    eval_metric="f1_score",
)

# Dataset layout flag passed to get_dataset():
#   1 -> the dataset is already divided into train and test folders
#   0 -> all images of the dataset are in a single folder
custom_split = 1



## Load the dataset, build the model and train it

In [None]:
# Fetch the loaders, class count and raw training set in a single call.
train_loader, test_loader, num_classes, trainset = get_dataset(args, custom_split = custom_split)
if args.method != 'SenpisFaster':
    # The raw training set is kept only for SenpisFaster; other methods get None.
    trainset = None

In [None]:
# Build the network for `num_classes` outputs; the architecture is presumably
# chosen from args.model_architecture inside get_model -- TODO confirm.
model = get_model(num_classes, args)

In [None]:
# Train the baseline model. args.model_type is still its "UNPRUNED" default here;
# presumably train_model saves the checkpoint the pruning loop reloads -- TODO confirm.
train_model(args=args,
            model=model,
            num_classes=num_classes,
            train_loader=train_loader,
            test_loader=test_loader)

## Pruning with multiple seeds

In [None]:
# Pruning distributions for the 20% budget.
# Key names are Spanish: HOMOGENEA = homogeneous, CRECIENTE = increasing,
# DECRECIENTE = decreasing, MENOS_MAS_MENOS = low-high-low, MAS_MENOS_MAS = high-low-high.
# Each list holds one ratio per layer position; the last entry is always 0
# (presumably the final classifier layer is never pruned -- TODO confirm).
dict_distri = {
    "HOMOGENEA":       [0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0],
    "CRECIENTE":       [0.15,0.15,0.17,0.17,0.2,0.2,0.2,0.22,0.22,0.22,0.3,0.3,0.31,0.2,0.2,0],
    "DECRECIENTE":     [0.35,0.35,0.25,0.25,0.2,0.2,0.2,0.13,0.13,0.12,0.1,0.1,0.09,0.2,0.2,0],
    "MENOS_MAS_MENOS": [0.15,0.15,0.15,0.15,0.34,0.34,0.34,0.15,0.15,0.15,0.1,0.09,0.09,0.2,0.2,0],
    "MAS_MENOS_MAS":   [0.35,0.35,0.2,0.2,0.11,0.11,0.1,0.19,0.2,0.2,0.31,0.31,0.31,0.2,0.2,0],
}

# Budget label used in the saved model_type names.
base_percentage = 20

In [None]:
# Pruning distributions for the 30% budget.
# Key names are Spanish: HOMOGENEA = homogeneous, CRECIENTE = increasing,
# DECRECIENTE = decreasing, MENOS_MAS_MENOS = low-high-low, MAS_MENOS_MAS = high-low-high.
# Each list holds one ratio per layer position; the last entry is always 0
# (presumably the final classifier layer is never pruned -- TODO confirm).
dict_distri = {
    "HOMOGENEA":       [0.3,0.3,0.3,0.3,0.3,0.3,0.3,0.3,0.3,0.3,0.3,0.3,0.3,0.3,0.3,0],
    "CRECIENTE":       [0.15,0.15,0.2,0.2,0.3,0.3,0.3,0.43,0.43,0.44,0.45,0.45,0.46,0.3,0.3,0],
    "DECRECIENTE":     [0.45,0.45,0.35,0.35,0.3,0.3,0.3,0.25,0.25,0.25,0.16,0.16,0.13,0.3,0.3,0],
    "MENOS_MAS_MENOS": [0.15,0.15,0.3,0.3,0.45,0.45,0.45,0.28,0.28,0.27,0.2,0.19,0.15,0.3,0.3,0],
    "MAS_MENOS_MAS":   [0.4,0.4,0.2,0.2,0.3,0.3,0.3,0.3,0.3,0.3,0.37,0.37,0.36,0.3,0.3,0],
}

# Budget label used in the saved model_type names.
base_percentage = 30

In [None]:
# Pruning distributions for the 50% budget.
# Key names are Spanish: HOMOGENEA = homogeneous, CRECIENTE = increasing,
# DECRECIENTE = decreasing, MENOS_MAS_MENOS = low-high-low, MAS_MENOS_MAS = high-low-high.
# Each list holds one ratio per layer position; the last entry is always 0
# (presumably the final classifier layer is never pruned -- TODO confirm).
dict_distri = {
    "HOMOGENEA":       [0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0],
    "CRECIENTE":       [0.15,0.35,0.4,0.4,0.5,0.5,0.5,0.7,0.7,0.7,0.7,0.7,0.7,0.5,0.5,0],
    "DECRECIENTE":     [0.65,0.65,0.55,0.55,0.5,0.5,0.5,0.45,0.44,0.43,0.36,0.35,0.35,0.5,0.5,0],
    "MENOS_MAS_MENOS": [0.3,0.3,0.42,0.42,0.65,0.65,0.65,0.6,0.6,0.6,0.3,0.3,0.3,0.5,0.5,0],
    "MAS_MENOS_MAS":   [0.6,0.6,0.53,0.53,0.4,0.4,0.4,0.5,0.5,0.51,0.63,0.64,0.64,0.5,0.5,0],
}

# Budget label used in the saved model_type names.
base_percentage = 50

In [None]:
import inspect

# Random pruning is repeated over several seeds; the other methods
# (presumably deterministic) are run once.
if args.method != 'random':
    list_seeds = [23]
else:
    list_seeds = [23,42,97,112,167]

# The unpruned checkpoint does not depend on the distribution or seed, so
# build its name/path once instead of on every iteration.
original_model_name = f'{args.model_architecture}_{args.dataset}_UNPRUNED'
original_model_path = f'models/{args.dataset}/{original_model_name}.pth'

# The checkpoint is loaded and used directly as a model object, i.e. it is a
# whole pickled module rather than a state_dict. Since PyTorch 2.6, torch.load
# defaults to weights_only=True and refuses to unpickle such objects, so opt
# out explicitly where the parameter exists. SECURITY: this unpickles arbitrary
# objects -- only load checkpoints you produced yourself.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
load_kwargs = {'map_location': args.device}
if 'weights_only' in inspect.signature(torch.load).parameters:
    load_kwargs['weights_only'] = False

for type_pruning, pruning_ratios in dict_distri.items():
    # Per-layer ratios consumed by prune_model through args.
    args.list_pruning = pruning_ratios

    for seed in list_seeds:
        # Reload the unpruned model every time so each run prunes a fresh copy.
        model = torch.load(original_model_path, **load_kwargs)
        model.to(args.device)
        args.seed = seed
        args.model_type = f'{type_pruning}_{base_percentage}_PRUNED_SEED_{seed}'
        # prune_model's return value is unused, so it presumably prunes
        # `model` in place -- TODO confirm.
        prune_model(model, num_classes, trainset, args)
        # Fine-tune (retrain) the pruned model under its own model_type tag.
        args.model_type = f'{type_pruning}_{base_percentage}_PRUNED_FT_SEED_{seed}'
        train_model(train_loader = train_loader,
                    test_loader = test_loader,
                    model = model,
                    num_classes = num_classes,
                    args = args)
        print('============================')