## Index:
1. Definition of arguments for function usage
2. Model creation
   - 2.1. FLOPs verification, number of conv and linear layers
3. Dataset creation
4. Training unpruned model
5. Model pruning
   - 5.1. FLOPs verification, number of conv and linear layers
6. Training the pruned model

### 1. Definition of arguments for function usage


In [1]:
import sys
import torch
import torch.optim as optim
import torchvision
from torchvision import transforms
from utils import *
import argparse
# The Jupyter kernel is started with its own command-line arguments; reset
# sys.argv so that parse_args() below falls back to the declared defaults.
sys.argv = ['']

# Every tunable setting lives on one argparse namespace (`args`), which is the
# object handed to get_model, get_dataset, train_model and prune_model below.
parser = argparse.ArgumentParser(description='Parameters training')
parser.add_argument('--model_architecture', type=str, default="VGG16", help='....')
parser.add_argument('--dataset', type=str, default="CIFAR10", help='....')
parser.add_argument('--batch_size', type=int, default=16, help='....')
parser.add_argument('--num_epochs', type=int, default=1, help='....')
parser.add_argument('--learning_rate', type=float, default=1e-4, help='....')
parser.add_argument('--optimizer_val', type=str, default="SGD", help='....')
# Tag describing the current model state; later cells overwrite it before
# pruning and before fine-tuning.
parser.add_argument('--model_type', type=str, default="UNPRUNED", help='....')
# Left as None here so the best available device can be picked after parsing.
parser.add_argument('--device', type=str, default=None, help='....')
# Dummy input tensor passed to ModelParams for the FLOPs / layer report.
# No `type=` is given, so a value supplied on the command line would arrive as
# a plain string; only the default is usable as-is.
parser.add_argument('--model_input', default=torch.ones((1, 3, 224, 224)), help='....')
parser.add_argument('--pruning_seed', type=int, default=23, help='....')
# 16 pruning ratios -- presumably one per prunable (conv + linear) layer of
# VGG16; TODO confirm against prune_model.
# `type=float, nargs='+'` parses `--list_pruning 0.6 0.5 ...` into a list of
# floats. The previous `type=list` split a single string into characters.
parser.add_argument('--list_pruning', type=float, nargs='+', default = [0.6,0.6,0.53,0.53,0.4,0.4,0.4,0.5,0.5,0.5,0.6,0.6,0.6,0.5,0.5,0], help='....')
args = parser.parse_args()

# Fall back to the GPU when one is visible, otherwise the CPU.
if args.device is None:
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


### 2. Model creation

In [2]:
# Build the network via the project helper, driven by the settings on `args`
# (presumably args.model_architecture selects the network -- TODO confirm in utils).
model = get_model(args)

#### 2.1. FLOPs verification, number of conv and linear layers

In [3]:
# Report on the unpruned model: ModelParams is given the model and the dummy
# input tensor from args.model_input, and get_all_params() returns the FLOPs
# count plus the number of conv and linear layers.
obj_params = ModelParams(model, args.model_input)
flops, conv_layers, linear_layers = obj_params.get_all_params()
# conv_layers + linear_layers is reported as the total number of prunable layers.
print(f"FLOPS: {flops}\n Conv_layers: {conv_layers}\n linear_layers: {linear_layers}\n Total prune layers: {conv_layers+linear_layers}")

Unsupported operator aten::add_ encountered 13 time(s)
Unsupported operator aten::max_pool2d encountered 5 time(s)


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           1,792
       BatchNorm2d-2         [-1, 64, 224, 224]             128
              ReLU-3         [-1, 64, 224, 224]               0
            Conv2d-4         [-1, 64, 224, 224]          36,928
       BatchNorm2d-5         [-1, 64, 224, 224]             128
              ReLU-6         [-1, 64, 224, 224]               0
         MaxPool2d-7         [-1, 64, 112, 112]               0
            Conv2d-8        [-1, 128, 112, 112]          73,856
       BatchNorm2d-9        [-1, 128, 112, 112]             256
             ReLU-10        [-1, 128, 112, 112]               0
           Conv2d-11        [-1, 128, 112, 112]         147,584
      BatchNorm2d-12        [-1, 128, 112, 112]             256
             ReLU-13        [-1, 128, 112, 112]               0
        MaxPool2d-14          [-1, 128,

### 3. Dataset creation

In [4]:
# Build the train and test loaders from the settings on `args`
# (presumably args.dataset and args.batch_size -- TODO confirm in utils).
train_loader,test_loader = get_dataset(args)

Files already downloaded and verified
Files already downloaded and verified


### 4. Training unpruned model

In [5]:
# Train the unpruned model with the settings on `args`.
# NOTE(review): the saved output below reports 2 epochs, while --num_epochs
# defaults to 1 -- the output presumably comes from an earlier run with a
# different setting; re-run the notebook to refresh it.
train_model(args,
            train_loader = train_loader,
            test_loader = test_loader,
            model = model)

Epoch: [1/2]	 || Training Loss: 0.776	 || Val Loss: 0.313	 || Training Acc: 75.55% 	 ||  Val Acc 89.31%
Epoch: [2/2]	 || Training Loss: 0.304	 || Val Loss: 0.247	 || Training Acc: 89.89% 	 ||  Val Acc 91.45%


### 5. Model pruning

In [6]:
# Reload a previously saved model from disk and prune it.
# TODO: set `model_name` to the file stem of the saved model; with the empty
# string the path below resolves to 'models/.pth'.
model_name = ''
# NOTE(review): torch.load is pickle-based and can execute arbitrary code while
# unpickling -- only load files produced by this project.
model = torch.load(f'models/{model_name}.pth')
# Optional overrides of the defaults declared on the parser:
#list_pruning = [0.6,0.6,0.53,0.53,0.4,0.4,0.4,0.5,0.5,0.5,0.6,0.6,0.6,0.5,0.5,0]
#args.list_pruning = list_pruning
#args.pruned_model_name = "VGG16_DISTRI_1"
# The parser declares `--pruning_seed`; there is no `seed` attribute on `args`,
# so the previous `args.seed` raised AttributeError.
args.model_type = f'50_PRUNED_SEED_{args.pruning_seed}'
# The return value is not captured; later cells keep using `model`, so
# prune_model presumably modifies it in place -- TODO confirm in utils.
prune_model(model, args)

#### 5.1. FLOPs verification, number of conv and linear layers

In [7]:
# Same report as for the unpruned model, recomputed after prune_model so the
# FLOPs and layer counts can be compared
# (assumes prune_model changed `model` in place -- TODO confirm).
obj_params = ModelParams(model, args.model_input)
flops, conv_layers, linear_layers = obj_params.get_all_params()
# conv_layers + linear_layers is reported as the total number of prunable layers.
print(f"FLOPS: {flops}\n Conv_layers: {conv_layers}\n linear_layers: {linear_layers}\n Total prune layers: {conv_layers+linear_layers}")

Unsupported operator aten::add encountered 12 time(s)
Unsupported operator aten::max_pool2d encountered 5 time(s)


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 26, 224, 224]             728
       BatchNorm2d-2         [-1, 26, 224, 224]              52
              ReLU-3         [-1, 26, 224, 224]               0
             ConvB-4         [-1, 26, 224, 224]           6,084
       BatchNorm2d-5         [-1, 26, 224, 224]              52
              ReLU-6         [-1, 26, 224, 224]               0
         MaxPool2d-7         [-1, 26, 112, 112]               0
             ConvB-8         [-1, 60, 112, 112]          14,040
       BatchNorm2d-9         [-1, 60, 112, 112]             120
             ReLU-10         [-1, 60, 112, 112]               0
            ConvB-11         [-1, 60, 112, 112]          32,400
      BatchNorm2d-12         [-1, 60, 112, 112]             120
             ReLU-13         [-1, 60, 112, 112]               0
        MaxPool2d-14           [-1, 60,

### 6. Training the pruned model

In [8]:
# Re-tag the run before fine-tuning the pruned model. The parser declares
# `--pruning_seed`; there is no `seed` attribute on `args`, so the previous
# `args.seed` raised AttributeError.
args.model_type = f'50_PRUNED_FT_SEED_{args.pruning_seed}'

# Fine-tune the pruned model with the same loaders used for the unpruned run.
train_model(args,
            train_loader = train_loader,
            test_loader = test_loader,
            model = model)


Epoch: [1/2]	 || Training Loss: 1.623	 || Val Loss: 1.083	 || Training Acc: 41.20% 	 ||  Val Acc 60.87%
Epoch: [2/2]	 || Training Loss: 1.036	 || Val Loss: 0.774	 || Training Acc: 63.16% 	 ||  Val Acc 72.56%
