## 1- Define the arguments used by the training and pruning functions

In [None]:
import os
import sys
# Notebook setup: move to the parent directory, then put ../flexiprune
# (resolved relative to that new working directory) on the import path.
# Guarded so that re-running this cell in the same kernel does not climb
# another directory level or append a different path each time.
if not globals().get("_FLEXIPRUNE_PATH_SET", False):
    os.chdir('..')
    _flexiprune_path = os.path.abspath("../flexiprune")
    if _flexiprune_path not in sys.path:
        sys.path.append(_flexiprune_path)
    _FLEXIPRUNE_PATH_SET = True

In [None]:
## 1- Definition of arguments for function usage

import sys
import torch
import torch.optim as optim
import torchvision
from torchvision import transforms
from flexiprune import *
import argparse
# Reset argv so parse_args() sees no command-line options and every argument
# takes its default (presumably to discard the notebook kernel's own launch
# arguments).
sys.argv = ['']

parser = argparse.ArgumentParser(description='Parameters for training')

parser.add_argument('--model_architecture', type=str, default="VGG16",
                    help='Specify the architecture of the model (e.g., VGG16, AlexNet, etc.).')

parser.add_argument('--method', type=str, default="random",
                    help='Specify the training method (e.g., SenpisFaster, random, weight).')

parser.add_argument('--dataset', type=str, default="CIFAR10",
                    help='Specify the dataset for training (e.g., CIFAR10, "Name of custom dataset").')

parser.add_argument('--batch_size', type=int, default=8,
                    help='Set the batch size for training.')

parser.add_argument('--num_epochs', type=int, default=1,
                    help='Specify the number of training epochs.')

parser.add_argument('--learning_rate', type=float, default=1e-3,
                    help='Set the learning rate for the optimizer.')

parser.add_argument('--optimizer_val', type=str, default="SGD",
                    help='Specify the optimizer for training (e.g., SGD, Adam, etc.).')

parser.add_argument('--model_type', type=str, default="UNPRUNED",
                    help='Specify the type of the model (e.g., PRUNED or UNPRUNED).')

# Left as None here and resolved to an actual torch.device after parsing.
parser.add_argument('--device', type=str, default=None,
                    help='Specify the device for training (e.g., "cuda:0" for GPU).')

# NOTE(review): the default is a tensor, but a value given on the command
# line would arrive as a plain string; only the default is usable as-is.
parser.add_argument('--model_input', default=torch.ones((1, 3, 224, 224)),
                    help='Input tensor for the model (default is a tensor of ones).')

parser.add_argument('--eval_metric', default="accuracy",
                    help='Specify the evaluation metric (e.g., accuracy, f1).')

parser.add_argument('--seed', type=int, default=23,
                    help='Set the seed for random pruning operations.')

# nargs='*' collects zero or more values into a list. The previous type=list
# split a single CLI token into characters ('50' -> ['5', '0']).
# NOTE(review): values are parsed as floats; confirm prune_model accepts them.
parser.add_argument('--list_pruning', type=float, nargs='*',
                    default=[],
                    help='Specify the list of pruning ratios for each layer.')

args = parser.parse_args()

# Pick the GPU when available unless a device was given explicitly.
if args.device is None:
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## 2- Get the model and dataset, then train

In [None]:
# Load the dataset first so the model can be built for the dataset's actual
# class count instead of a hard-coded 10, which only fits 10-class datasets
# such as the default CIFAR10.
# NOTE(review): assumes get_model's first argument is the number of classes
# and that get_dataset does not depend on anything get_model sets on args.
train_loader, test_loader, num_classes, trainset = get_dataset(args)
model = get_model(num_classes, args)

In [None]:
# Train the freshly built (unpruned) model on the loaders from get_dataset,
# using the hyper-parameters carried in args.
train_model(model=model,
            num_classes=num_classes,
            train_loader=train_loader,
            test_loader=test_loader,
            args=args)

## 3- Pruning and fine-tuning

In [None]:
# Sweep over pruning settings: for every (GPR, PD) pair compute per-layer
# pruning ratios, then for each seed prune a fresh copy of the unpruned
# checkpoint and fine-tune it.
list_seeds = [23, 42, 1234]

GPR_LIST = [50, 30, 20]
PD_LIST = [1, 2, 3, 4, 5]
model_type = args.model_type
obj_pd = PruningDistributionCalculator(model=args.model_architecture)

# Checkpoint of the original (unpruned) model; the path does not change
# inside the sweep, so build it once.
unpruned_path = f'models/{args.dataset}/{args.model_architecture}_{args.dataset}_{model_type}.pth'

for GPR in GPR_LIST:
    for PD in PD_LIST:
        args.list_pruning = obj_pd.calculate(GPR=GPR, PD=PD)
        print(args.list_pruning)
        for seed in list_seeds:
            args.seed = seed
            # Reload every run: prune_model's return value is unused below,
            # so it presumably modifies `model` in place.
            # map_location lets a checkpoint saved on GPU load on a CPU-only
            # machine.
            # NOTE: weights_only=False unpickles arbitrary objects, so only
            # load trusted checkpoints.
            model = torch.load(unpruned_path, map_location=args.device, weights_only=False)
            model.to(args.device)
            # Prune the original model.
            args.model_type = f'PD{PD}_GPR-{GPR}_PRUNED_SEED_{seed}'
            prune_model(model, num_classes, trainset, args)
            # Retrain (fine-tune) the pruned model.
            args.model_type = f'PD{PD}_GPR-{GPR}_PRUNED_FT_SEED_{seed}'
            train_model(
                train_loader=train_loader,
                test_loader=test_loader,
                model=model,
                num_classes=num_classes,
                args=args)