<a href="https://colab.research.google.com/github/ashwinvaswani/IKD_DAFL/blob/master/notebooks/DSN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Importing Libraries

In [0]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchvision.transforms as transforms
import torchvision.datasets as datasets

import random
import time

import numpy as np
import copy

In [0]:
!nvidia-smi

Fri Mar 20 08:17:53 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.64.00    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   42C    P0    28W / 250W |      0MiB / 16280MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|  No ru

# LeNet

In [0]:
class LeNet5(nn.Module):

    def __init__(self,params):
        super(LeNet5, self).__init__()

        self.conv1 = nn.Conv2d(1, params[0], kernel_size=(5, 5))
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        self.conv2 = nn.Conv2d(params[0], params[1], kernel_size=(5, 5))
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        self.conv3 = nn.Conv2d(params[1], params[2], kernel_size=(5, 5))
        self.relu3 = nn.ReLU()
        self.fc1 = nn.Linear(params[2], params[3])
        self.relu4 = nn.ReLU()
        self.fc2 = nn.Linear(params[3], 10)

    def forward(self, img, out_feature=False):
        output = self.conv1(img)
        output = self.relu1(output)
        output = self.maxpool1(output)
        output = self.conv2(output)
        output = self.relu2(output)
        output = self.maxpool2(output)
        output = self.conv3(output)
        output = self.relu3(output)
        feature = output.view(-1, 120)
        output = self.fc1(feature)
        output = self.relu4(output)
        output = self.fc2(output)
        if out_feature == False:
            return output
        else:
            return output,feature

In [0]:
original_parameters = [6,16,120,84]
net = LeNet5(original_parameters)

In [0]:
net

LeNet5(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (relu1): ReLU()
  (maxpool1): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (relu2): ReLU()
  (maxpool2): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(16, 120, kernel_size=(5, 5), stride=(1, 1))
  (relu3): ReLU()
  (fc1): Linear(in_features=120, out_features=84, bias=True)
  (relu4): ReLU()
  (fc2): Linear(in_features=84, out_features=10, bias=True)
)

In [0]:
original_parameters = [6,16,120,84]
for depth in range(3):
  if depth == 0:
    print("Original network: ")
    net = LeNet5(original_parameters)
    print(net)
  else:
    print("At depth " + str(depth) + ": ")
    original_parameters = [int(i/2) for i in original_parameters]
    print(original_parameters)
    net = LeNet5(original_parameters)
    print(net)

Original network: 
LeNet5(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (relu1): ReLU()
  (maxpool1): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (relu2): ReLU()
  (maxpool2): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(16, 120, kernel_size=(5, 5), stride=(1, 1))
  (relu3): ReLU()
  (fc1): Linear(in_features=120, out_features=84, bias=True)
  (relu4): ReLU()
  (fc2): Linear(in_features=84, out_features=10, bias=True)
)
At depth 1: 
[3, 8, 60, 42]
LeNet5(
  (conv1): Conv2d(1, 3, kernel_size=(5, 5), stride=(1, 1))
  (relu1): ReLU()
  (maxpool1): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(3, 8, kernel_size=(5, 5), stride=(1, 1))
  (relu2): ReLU()
  (maxpool2): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(8, 60, kernel

# VGG

In [0]:
class VGG(nn.Module):

    def __init__(self, features, params, num_classes=10, init_weights=True):
        super(VGG, self).__init__()
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(params[0] * 7 * 7, params[1]),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(params[1], params[2]),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(params[2], num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

In [0]:
def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)

def _vgg(arch, cfg,params, batch_norm, pretrained, progress, **kwargs):
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfg, batch_norm=batch_norm),params = params, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model

In [0]:
def vgg16(features,params,pretrained=False,progress=True, **kwargs):
    return _vgg('vgg16', features, params, False, pretrained, progress, **kwargs)

In [0]:
def get_output_nodes(params):
  for i in reversed(params):
    if type(i) == int:
      last_number = i
      break

  return last_number

# ResNet


In [0]:
class BasicBlock(nn.Module):
    expansion = 1
 
    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
 
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )
 
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

In [0]:
class Bottleneck(nn.Module):
    expansion = 4
 
    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)
 
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )
 
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out
 

In [0]:
class ResNet(nn.Module):
    def __init__(self, block, num_blocks,params, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = params[0]
 
        self.conv1 = nn.Conv2d(3, params[0], kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(params[0])
        self.layer1 = self._make_layer(block, params[1], num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, params[2], num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, params[3], num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, params[4], num_blocks[3], stride=2)
        self.linear = nn.Linear(params[4]*block.expansion, num_classes)
 
    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)
 
    def forward(self, x, out_feature=False):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        feature = out.view(out.size(0), -1)
        out = self.linear(feature)
        if out_feature == False:
            return out
        else:
            return out,feature

In [0]:
def ResNet18(parameters,num_classes=10):
    return ResNet(BasicBlock, [2,2,2,2], parameters, num_classes)
 
def ResNet34(parameters,num_classes=10):
    return ResNet(BasicBlock, [3,4,6,3], parameters, num_classes)
 
def ResNet50(parameters,num_classes=10):
    return ResNet(Bottleneck, [3,4,6,3], parameters,num_classes)
 
def ResNet101(parameters,num_classes=10):
    return ResNet(Bottleneck, [3,4,23,3], parameters,num_classes)
 
def ResNet152(parameters,num_classes=10):
    return ResNet(Bottleneck, [3,8,36,3], parameters,num_classes)

In [0]:
original_params = [64,64,128,256,512]
net = ResNet34(parameters = original_params)

In [0]:
net

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=

In [0]:
original_params = [64,64,128,256,512]
for depth in range(3):
  if depth == 0:
    print("Original network: ")
    net = ResNet34(parameters = original_params)
    print(net)
  else:
    print("At depth " + str(depth) + ": ")
    original_params = [int(i/2) for i in original_params]
    print(original_params)
    net = ResNet34(parameters = original_params)
    print(net)
    optimizer_S = torch.optim.SGD(net.parameters(), lr=args.lr_S, momentum=0.9, weight_decay=5e-4)
    accr_best, cnt = utils.train_model(net, teacher, generator, data_test_loader, device, criterion, optimizer_G, optimizer_S, args.lr_G, args.lr_S, args.oh, args.ie, args.a, args.batch_size, args.img_size, args.latent_dim, args.n_epochs, args.dataset, cnt)
    print()
    print("Best accuracy currently : {}".format(accr_best))
    print("\n############################################################################\n")


Original network: 
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, mo

In [0]:
original_params = [64,64,128,256,512]

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

for depth in range(4):

  if depth == 0:
    # continue
    print("Original network: ")
    net = ResNet34(parameters = original_params)
    print(net)
    # optimizer = optim.SGD(net.parameters(),lr=0.001)
    # criterion = nn.CrossEntropyLoss()
    # criterion = criterion.to(device)
    # net, hist = train_model(net, dataloaders_dict, criterion, optimizer, num_epochs=EPOCHS)
    print("\n############################################################################\n")

  # For odd case :
  elif depth % 2 == 1:
    # continue
    print("At depth " + str(depth) + ": ")
    print("Dividing filters by 2: ")
    original_params = [int(i/2) for i in original_params]
    print(original_params)
    net = ResNet34(parameters = original_params)
    print(net)
  
    # optimizer = optim.SGD(net.parameters(),lr=0.001)
    # criterion = nn.CrossEntropyLoss()
    # criterion = criterion.to(device)
    # net, hist = train_model(net, dataloaders_dict, criterion, optimizer, num_epochs=EPOCHS)
    print("\n############################################################################\n")

  # For even case :
  elif depth % 2 == 0:
    print("At depth " + str(depth) + ": ")
    print("Removing last layer: ")
    net = ResNet34(parameters = original_params)
    second_last_layer = [net.features[i] for i in range(len(net.features)-1)][-1]
    if str(second_last_layer) ==  'ReLU(inplace=True)':
      net.features = nn.Sequential(*[net.features[i] for i in range(len(net.features) -2)])
    else:
      net.features = nn.Sequential(*[net.features[i] for i in range(len(net.features) -1)])
    original_parameters = original_parameters[:-1]
    # print(net)
    optimizer = optim.SGD(net.parameters(),lr=0.001)
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(device)
    net, hist = train_model(net, dataloaders_dict, criterion, optimizer, num_epochs=EPOCHS)
    print("\n############################################################################\n")

# Testing on Cifar-10

In [0]:
SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [0]:
train_data = datasets.CIFAR10(root = 'data', 
                              train = True, 
                              download = True)

means = train_data.data.mean(axis = (0,1,2)) / 255
stds = train_data.data.std(axis = (0,1,2)) / 255

print(f'Calculated means: {means}')
print(f'Calculated stds: {stds}')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting data/cifar-10-python.tar.gz to data
Calculated means: [0.49139968 0.48215841 0.44653091]
Calculated stds: [0.24703223 0.24348513 0.26158784]


In [0]:
train_transforms = transforms.Compose([
                           transforms.RandomHorizontalFlip(),
                           transforms.RandomRotation(10),
                           transforms.ToTensor(),
                           transforms.Normalize(mean = means, 
                                                std = stds)
                       ])

test_transforms = transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize(mean = means, 
                                                std = stds)
                       ])

In [0]:
train_data = datasets.CIFAR10('data', 
                              train = True, 
                              download = True, 
                              transform = train_transforms)

test_data = datasets.CIFAR10('data', 
                             train = False, 
                             download = True, 
                             transform = test_transforms)

Files already downloaded and verified
Files already downloaded and verified


In [0]:
n_train_examples = int(len(train_data)*0.9)
n_valid_examples = len(train_data) - n_train_examples

train_data, valid_data = torch.utils.data.random_split(train_data, 
                                                       [n_train_examples, n_valid_examples])

In [0]:
print(f'Number of training examples: {len(train_data)}')
print(f'Number of validation examples: {len(valid_data)}')
print(f'Number of testing examples: {len(test_data)}')

Number of training examples: 45000
Number of validation examples: 5000
Number of testing examples: 10000


In [0]:
BATCH_SIZE = 64

train_iterator = torch.utils.data.DataLoader(train_data, 
                                             shuffle = True, 
                                             batch_size = BATCH_SIZE)

valid_iterator = torch.utils.data.DataLoader(valid_data, 
                                             batch_size = BATCH_SIZE)

test_iterator = torch.utils.data.DataLoader(test_data, 
                                            batch_size = BATCH_SIZE)

# Building model and testing

In [0]:
def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False):
    since = time.time()
    model = model.to(device)
    val_acc_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    # Get model outputs and calculate loss
                    # Special case for inception because in training it has an auxiliary output. In train
                    #   mode we calculate the loss by summing the final output and the auxiliary output
                    #   but in testing we only consider the final output.
                    if is_inception and phase == 'train':
                        # From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4*loss2
                    else:
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'val':
                val_acc_history.append(epoch_acc)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_acc_history

In [0]:
dataloaders_dict = {'train':train_iterator,'val':valid_iterator}

# Number of classes in the dataset
num_classes = 10

# Batch size for training (change depending on how much memory you have)
batch_size = 64

# Number of epochs to train for
EPOCHS = 50

In [0]:
original_parameters = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']
classifier_parameters = [512,4096,4096]

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

for depth in range(7):
  last_number = get_output_nodes(original_parameters)
  classifier_params = [last_number,4096,4096]
  if depth == 0:
    continue
    print("Original network: ")
    net = vgg16(original_parameters,classifier_parameters)
    # print(net)
    optimizer = optim.SGD(net.parameters(),lr=0.001)
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(device)
    net, hist = train_model(net, dataloaders_dict, criterion, optimizer, num_epochs=EPOCHS)
    print("\n############################################################################\n")

  # For odd case :
  elif depth % 2 == 1:
    continue
    print("At depth " + str(depth) + ": ")
    print("Dividing filters by 2: ")
    original_parameters = [int(i/2) if type(i) == int else i for i in original_parameters]
    classifier_parameters = [int(i/2) if type(i) == int else i for i in classifier_parameters]
    print(original_parameters)
    net = vgg16(original_parameters,classifier_parameters)
    last_layer = [net.features[i] for i in range(len(net.features))][-1]
    if str(last_layer) ==  'ReLU(inplace=True)':
      net.features = nn.Sequential(*[net.features[i] for i in range(len(net.features) -1)])
    # print(net)
    optimizer = optim.SGD(net.parameters(),lr=0.001)
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(device)
    net, hist = train_model(net, dataloaders_dict, criterion, optimizer, num_epochs=EPOCHS)
    print("\n############################################################################\n")

  # For even case :
  elif depth % 2 == 0:
    print("At depth " + str(depth) + ": ")
    print("Removing last layer: ")
    net = vgg16(original_parameters,classifier_parameters)
    second_last_layer = [net.features[i] for i in range(len(net.features)-1)][-1]
    if str(second_last_layer) ==  'ReLU(inplace=True)':
      net.features = nn.Sequential(*[net.features[i] for i in range(len(net.features) -2)])
    else:
      net.features = nn.Sequential(*[net.features[i] for i in range(len(net.features) -1)])
    original_parameters = original_parameters[:-1]
    # print(net)
    optimizer = optim.SGD(net.parameters(),lr=0.001)
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(device)
    net, hist = train_model(net, dataloaders_dict, criterion, optimizer, num_epochs=EPOCHS)
    print("\n############################################################################\n")


At depth 2: 
Removing last layer: 
Epoch 0/49
----------
train Loss: 2.3009 Acc: 0.1033
val Loss: 2.2994 Acc: 0.1154

Epoch 1/49
----------
train Loss: 2.2976 Acc: 0.1058
val Loss: 2.2947 Acc: 0.1134

Epoch 2/49
----------
train Loss: 2.2906 Acc: 0.1117
val Loss: 2.2829 Acc: 0.1404

Epoch 3/49
----------
train Loss: 2.2730 Acc: 0.1263
val Loss: 2.2529 Acc: 0.1572

Epoch 4/49
----------
train Loss: 2.2299 Acc: 0.1624
val Loss: 2.1864 Acc: 0.1728

Epoch 5/49
----------
train Loss: 2.1599 Acc: 0.1997
val Loss: 2.0906 Acc: 0.2474

Epoch 6/49
----------
train Loss: 2.0417 Acc: 0.2495
val Loss: 1.9644 Acc: 0.2830

Epoch 7/49
----------
train Loss: 1.9443 Acc: 0.2786
val Loss: 1.8785 Acc: 0.3072

Epoch 8/49
----------
train Loss: 1.8666 Acc: 0.3039
val Loss: 1.9401 Acc: 0.3010

Epoch 9/49
----------
train Loss: 1.7989 Acc: 0.3220
val Loss: 1.7281 Acc: 0.3532

Epoch 10/49
----------
train Loss: 1.7394 Acc: 0.3484
val Loss: 2.1445 Acc: 0.2804

Epoch 11/49
----------
train Loss: 1.6838 Acc: 0.37

KeyboardInterrupt: ignored