## 1. Definition of the arguments used by the functions

In [None]:
import os
# Move the working directory one level up, so that relative paths used later
# in the notebook (e.g. 'models/...') resolve against the parent directory
# -- presumably the repository root; confirm.
# NOTE: the path is relative, so re-running this cell moves up one more level
# each time it is executed.
os.chdir('..')

In [None]:
## 1- Definition of arguments for function usage

import sys
import torch
import torch.optim as optim
import torchvision
from torchvision import transforms
from flexiprune import *
import argparse
# Blank out the command line before parsing, so parse_args() below only ever
# sees the defaults declared here (presumably done so argparse does not choke
# on the notebook kernel's own launch arguments -- confirm before reusing this
# cell as a script).
sys.argv = ['']

parser = argparse.ArgumentParser(description='Parameters for training')

parser.add_argument('--model_architecture', type=str, default="VGG16",
                    help='Specify the architecture of the model (e.g., VGG16, AlexNet, etc.).')

parser.add_argument('--method', type=str, default="weight",
                    help='Specify the training method (e.g., SenpisFaster, random, weight).')

parser.add_argument('--dataset', type=str, default="CIFAR10",
                    help='Specify the dataset for training (e.g., CIFAR10, "Name of custom dataset").')

parser.add_argument('--batch_size', type=int, default=8,
                    help='Set the batch size for training.')

parser.add_argument('--num_epochs', type=int, default=1,
                    help='Specify the number of training epochs.')

parser.add_argument('--learning_rate', type=float, default=1e-3,
                    help='Set the learning rate for the optimizer.')

parser.add_argument('--optimizer_val', type=str, default="SGD",
                    help='Specify the optimizer for training (e.g., SGD, Adam, etc.).')

parser.add_argument('--model_type', type=str, default="UNPRUNED",
                    help='Specify the type of the model (e.g., PRUNED or UNPRUNED).')

parser.add_argument('--device', type=str, default=None,
                    help='Specify the device for training (e.g., "cuda:0" for GPU).')

# No `type=` converter is given here: the tensor default is used as-is, while
# a value supplied on the command line would arrive as a plain string.
parser.add_argument('--model_input', default=torch.ones((1, 3, 224, 224)),
                    help='Input tensor for the model (default is a tensor of ones).')

parser.add_argument('--eval_metric', default="accuracy",
                    help='Specify the evaluation metric (e.g., accuracy, f1).')

parser.add_argument('--seed', type=int, default=23,
                    help='Set the seed for random pruning operations.')

# One float per value, e.g. `--list_pruning 0.5 0.5 0`.
# `type=list` must not be used here: argparse would call list() on the raw
# argument string and produce a list of single characters.
parser.add_argument('--list_pruning', type=float, nargs='+',
                    default=[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0],
                    help='Specify the list of pruning ratios for each layer.')

args = parser.parse_args()

# No device requested: use CUDA when it is available, otherwise the CPU.
if args.device is None:
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Get the model and dataset, then train

In [None]:
# Load the data first so the model can be built from the dataset's own class
# count instead of a hard-coded 10, which only fits 10-class datasets such as
# the default CIFAR10 and not a custom dataset.
# NOTE(review): assumes the first positional argument of get_model is the
# number of classes (the original passed the literal 10) -- confirm.
train_loader, test_loader, num_classes, trainset = get_dataset(args)
model = get_model(num_classes, args)

In [None]:
# Call train_model on the current `model`, handing it both data loaders, the
# class count and the run settings held in `args`.
train_model(
    train_loader=train_loader,
    test_loader=test_loader,
    model=model,
    num_classes=num_classes,
    args=args,
)

## Pruning

In [None]:
# Seeds assigned, one at a time, to args.seed.
list_seeds = [23, 42, 1234]

# Pruning-ratio lists, organised on two levels:
#   outer key "GPD-<n>": one group of distributions (the original note labels
#                        the first group "global PD = 50%");
#   inner key "PD1".."PD5": one named list of ratios inside that group.
# Every list holds 16 values and ends with 0; PD1 is the uniform list of its
# group.
dict_distri = {
    "GPD-50": {
        "PD1": [0.5] * 15 + [0],
        "PD2": [0.15, 0.35, 0.4, 0.4, 0.5, 0.5, 0.5, 0.7,
                0.7, 0.7, 0.7, 0.7, 0.7, 0.5, 0.5, 0],
        "PD3": [0.65, 0.65, 0.55, 0.55, 0.5, 0.5, 0.5, 0.45,
                0.44, 0.43, 0.36, 0.35, 0.35, 0.5, 0.5, 0],
        "PD4": [0.3, 0.3, 0.42, 0.42, 0.65, 0.65, 0.65, 0.6,
                0.6, 0.6, 0.3, 0.3, 0.3, 0.5, 0.5, 0],
        "PD5": [0.6, 0.6, 0.53, 0.53, 0.4, 0.4, 0.4, 0.5,
                0.5, 0.51, 0.63, 0.64, 0.64, 0.5, 0.5, 0],
    },
    "GPD-30": {
        "PD1": [0.3] * 15 + [0],
        "PD2": [0.15, 0.15, 0.2, 0.2, 0.3, 0.3, 0.3, 0.43,
                0.43, 0.44, 0.45, 0.45, 0.46, 0.3, 0.3, 0],
        "PD3": [0.45, 0.45, 0.35, 0.35, 0.3, 0.3, 0.3, 0.25,
                0.25, 0.25, 0.16, 0.16, 0.13, 0.3, 0.3, 0],
        "PD4": [0.15, 0.15, 0.3, 0.3, 0.45, 0.45, 0.45, 0.28,
                0.28, 0.27, 0.2, 0.19, 0.15, 0.3, 0.3, 0],
        "PD5": [0.4, 0.4, 0.2, 0.2, 0.3, 0.3, 0.3, 0.3,
                0.3, 0.3, 0.37, 0.37, 0.36, 0.3, 0.3, 0],
    },
    "GPD-20": {
        "PD1": [0.2] * 15 + [0],
        "PD2": [0.15, 0.15, 0.17, 0.17, 0.2, 0.2, 0.2, 0.22,
                0.22, 0.22, 0.3, 0.3, 0.31, 0.2, 0.2, 0],
        "PD3": [0.35, 0.35, 0.25, 0.25, 0.2, 0.2, 0.2, 0.13,
                0.13, 0.12, 0.1, 0.1, 0.09, 0.2, 0.2, 0],
        "PD4": [0.15, 0.15, 0.15, 0.15, 0.34, 0.34, 0.34, 0.15,
                0.15, 0.15, 0.1, 0.09, 0.09, 0.2, 0.2, 0],
        "PD5": [0.35, 0.35, 0.2, 0.2, 0.11, 0.11, 0.1, 0.19,
                0.2, 0.2, 0.31, 0.31, 0.31, 0.2, 0.2, 0],
    },
}
# Keep the model type of the baseline checkpoint: args.model_type is
# overwritten inside the loop for every pruned / fine-tuned variant.
model_type = args.model_type

# For every distribution of every group, and for every seed: reload the
# baseline model, prune it, then train the pruned model again.
for GPD, distributions in dict_distri.items():
    for distri, pruning_ratios in distributions.items():
        args.list_pruning = pruning_ratios
        for seed in list_seeds:
            args.seed = seed
            # Original (unpruned) model: restore the baseline model type so the
            # checkpoint path below points at the baseline file.
            args.model_type = model_type
            # map_location lets a checkpoint saved on a GPU be loaded when
            # args.device is the CPU (or a different GPU); without it
            # torch.load tries to restore tensors onto the device they were
            # saved from.
            # NOTE(review): the loaded object is used as a full model below;
            # PyTorch >= 2.6 defaults torch.load to weights_only=True, which
            # rejects such files -- pass weights_only=False (trusted files
            # only) when running on those versions.
            model = torch.load(
                f'models/{args.dataset}/{args.model_architecture}_{args.dataset}_{args.model_type}.pth',
                map_location=args.device,
            )
            model.to(args.device)
            # Prune the original model; the model type set here presumably
            # names the pruned output -- confirm in prune_model.
            args.model_type = f'{distri}_{GPD}_PRUNED_SEED_{seed}'
            prune_model(model, num_classes, trainset, args)
            # Retrain the pruned model under its own model type.
            args.model_type = f'{distri}_{GPD}_PRUNED_FT_SEED_{seed}'
            train_model(
                train_loader=train_loader,
                test_loader=test_loader,
                model=model,
                num_classes=num_classes,
                args=args,
            )

