## 1- Definition of arguments for function usage

In [None]:
import sys
import torch
import torch.optim as optim
import torchvision
from torchvision import transforms
from utils import *
import argparse
# Replace the process arguments so that parse_args() below sees no
# command-line options and every parameter takes the default declared here.
sys.argv = ['']

# Experiment configuration. All values are exposed as attributes of `args`,
# which is handed to the project helpers used further down.
parser = argparse.ArgumentParser(description='Parameters training')
parser.add_argument('--model_architecture', type=str, default="VGG16", help='....')
parser.add_argument('--method', type=str, default="random", help='....')
parser.add_argument('--dataset', type=str, default="CIFAR10", help='....')
parser.add_argument('--batch_size', type=int, default=8, help='....')
parser.add_argument('--num_epochs', type=int, default=20, help='....')
parser.add_argument('--learning_rate', type=float, default=1e-3, help='....')
parser.add_argument('--optimizer_val', type=str, default="SGD", help='....')
parser.add_argument('--model_type', type=str, default="HOMOGENEA", help='....')
parser.add_argument('--device', type=str, default=None, help='....')
# No `type=`: the default is a ready-made tensor, not something parsed from text.
parser.add_argument('--model_input', default=torch.ones((1, 3, 224, 224)), help='....')
parser.add_argument('--eval_metric', default="accuracy", help='....')
parser.add_argument('--pruning_seed', type=int, default=23, help='....')
# One pruning rate per entry. `type=list` would split a command-line value
# into single characters ('0.5' -> ['0', '.', '5']); `nargs='+'` with
# `type=float` yields a proper list of floats instead.
parser.add_argument(
    '--list_pruning',
    type=float,
    nargs='+',
    default=[0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0],
    help='....',
)

args = parser.parse_args()

# When no device was requested, use CUDA if it is available and the CPU otherwise.
if args.device is None:
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Get the model, load the dataset, and train

In [None]:
# Settings for the baseline run, written straight into the parsed namespace.
# Optional overrides, currently disabled:
#   args.pruned_model_name = "VGG16_DISTRI_1"
#   args.num_epochs = 40
vars(args).update(dataset="CIFAR10", eval_metric="accuracy")

In [None]:
# Build the model through the project helper `get_model`, which is not
# defined in this file (presumably it comes from `from utils import *`).
# NOTE(review): the literal 10 is presumably the number of output classes
# (CIFAR10 is the configured dataset) — confirm against get_model's signature;
# get_dataset() returns a `num_classes` value that could replace it.
model = get_model(10, args)

In [None]:
# Fetch the data for the dataset configured in `args`. The helper returns three
# values, unpacked here as the training loader, the test loader and the class
# count; these names are reused by the training and pruning cells below.
# NOTE(review): get_dataset is not defined in this file — presumably provided
# by `from utils import *`; confirm the return order there.
train_loader, test_loader, num_classes = get_dataset(args)

In [None]:
# Baseline run: train the freshly built model on the loaders obtained above.
train_model(
    args,
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    num_classes=num_classes,
)

## Pruning with multiple seeds

In [None]:
# Named pruning-rate profiles. Each profile is a list of per-entry rates that
# is later assigned to `args.list_pruning`; every profile ends with 0.5, 0.5, 0.
# NOTE(review): the entries presumably map one-to-one onto the network's
# layers — confirm against prune_model.
dict_distri = {
    # Uniform: the same rate everywhere except the final entry.
    "HOMOGENEA": [
        0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
        0.5, 0.5, 0,
    ],
    # Increasing: rates grow from 0.15 up to 0.7.
    "CRECIENTE": [
        0.15, 0.35, 0.4, 0.4, 0.5, 0.5, 0.5, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
        0.5, 0.5, 0,
    ],
    # Decreasing: rates fall from 0.65 down to 0.35.
    "DECRECIENTE": [
        0.65, 0.65, 0.55, 0.55, 0.5, 0.5, 0.5, 0.45, 0.44, 0.43, 0.36, 0.35, 0.35,
        0.5, 0.5, 0,
    ],
    # Low - high - low: the largest rates sit in the middle entries.
    "MENOS_MAS_MENOS": [
        0.3, 0.3, 0.42, 0.42, 0.65, 0.65, 0.65, 0.6, 0.6, 0.6, 0.3, 0.3, 0.3,
        0.5, 0.5, 0,
    ],
    # High - low - high: the smallest rates sit in the middle entries.
    "MAS_MENOS_MAS": [
        0.6, 0.6, 0.53, 0.53, 0.4, 0.4, 0.4, 0.5, 0.5, 0.51, 0.63, 0.64, 0.64,
        0.5, 0.5, 0,
    ],
}

In [None]:
# Seeds for the repeated pruning runs, and the checkpoint name of the
# unpruned baseline that every run starts from.
list_seeds = [23,42,97,112,167]
original_model_name = 'VGG16_CIFAR10_UNPRUNED'

# For every pruning profile and every seed: reload the baseline, prune it,
# then retrain (fine-tune) the pruned model.
for distri, pruning_rates in dict_distri.items():
    args.list_pruning = pruning_rates
    for seed in list_seeds:
        # Reload the checkpoint on each iteration so no run starts from an
        # already-pruned model. `map_location` lets a checkpoint saved on GPU
        # load on a CPU-only machine.
        # NOTE(review): recent PyTorch releases default torch.load to
        # weights_only=True, which rejects fully pickled models — if this
        # checkpoint is a whole model and is trusted, pass weights_only=False
        # there; confirm against the installed version.
        model = torch.load(f'models/{original_model_name}.pth',
                           map_location=args.device)
        # The parser declares `pruning_seed`, while this loop originally set
        # only `seed`. Both are set so the helper receives the current seed
        # whichever attribute it reads.
        # NOTE(review): confirm which attribute prune_model actually uses.
        args.seed = seed
        args.pruning_seed = seed
        args.model_type = f'{distri}_50_PRUNED_SEED_{seed}'
        # Prune the baseline; the return value is discarded, so the helper is
        # presumably expected to modify `model` in place — verify in utils.
        prune_model(model, args)
        args.model_type = f'{distri}_50_PRUNED_FT_SEED_{seed}'
        # Retrain the pruned model.
        train_model(
            args,
            train_loader=train_loader,
            test_loader=test_loader,
            model=model,
            num_classes=num_classes,
        )