In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import os
import argparse
from tqdm import tqdm
import sys
import time
import math
import copy
import numpy as np
import matplotlib.pyplot as plt
import torchvision.utils as vutils
import seaborn as sns
import torch.nn.init as init
import pickle

### Model

In [None]:
class BasicBlock(nn.Module):
    """ResNet basic residual unit: two 3x3 conv + BatchNorm stages plus a skip connection.

    The skip path is the identity when shape is preserved, otherwise a strided 1x1 conv + BatchNorm
    projection. Output has ``expansion * planes`` channels.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Registration order (conv1, bn1, conv2, bn2, shortcut) fixes named_parameters() order,
        # which the pruning masks are indexed by.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        return F.relu(h + self.shortcut(x))


class Bottleneck(nn.Module):
    """ResNet bottleneck unit: 1x1 reduce -> 3x3 (strided) -> 1x1 expand, plus a skip connection.

    Output has ``expansion * planes`` channels; the skip path is a strided 1x1 conv + BatchNorm
    projection whenever the stride or channel count changes, otherwise the identity.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        width_out = self.expansion * planes

        # Registration order fixes named_parameters() order (used to index pruning masks).
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, width_out, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width_out)

        if stride == 1 and in_planes == width_out:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, width_out, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(width_out),
            )

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        h = F.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        return F.relu(h + self.shortcut(x))


class ResNet(nn.Module):
    """CIFAR-style ResNet: 3x3 stem, four residual stages, 4x4 average pool and a linear classifier.

    ``num_blocks[i]`` is the number of ``block`` units in stage ``i + 1`` (widths 64/128/256/512,
    first strides 1/2/2/2). With the stride-1 stem and three stride-2 stages, a 32x32 input reaches
    the pool as a 4x4 map, which ``avg_pool2d(out, 4)`` reduces to 1x1 before the linear layer.
    """
    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        stage_widths = (64, 128, 256, 512)
        stage_strides = (1, 2, 2, 2)
        for idx, (width, first_stride) in enumerate(zip(stage_widths, stage_strides)):
            # setattr keeps the layer1..layer4 attribute names (and their registration order).
            setattr(self, f"layer{idx + 1}",
                    self._make_layer(block, width, num_blocks[idx], first_stride))
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        # Only the first unit of a stage downsamples. Note: num_blocks <= 0 still yields one unit.
        unit_strides = [stride, *([1] * (num_blocks - 1))]
        units = []
        for unit_stride in unit_strides:
            units.append(block(self.in_planes, planes, unit_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*units)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        pooled = F.avg_pool2d(out, 4)
        return self.linear(pooled.view(pooled.size(0), -1))


def ResNet18():
    """ResNet-18 for 10 classes: ``BasicBlock`` units, two per stage across the four stages."""
    blocks_per_stage = [2] * 4
    return ResNet(BasicBlock, blocks_per_stage)

### Utils

In [None]:
def train(model, train_loader, optimizer, criterion):
    """Run one training epoch while keeping already-pruned weights frozen.

    After each backward pass, the gradient of every parameter whose name contains ``'weight'`` is
    zeroed wherever the current weight magnitude is below ``EPS``. Entries removed by pruning
    therefore receive no gradient from the loss.

    Args:
        model: network to train; batches are moved to the device of its first parameter.
        train_loader: iterable of ``(inputs, targets)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss called as ``criterion(outputs, targets)``.

    Returns:
        Mean training loss over the epoch's batches as a Python float, or ``nan`` when
        ``train_loader`` yields no batches.
    """
    EPS = 1e-6
    # Follow the model rather than guessing "cuda if available": a CPU model on a GPU host
    # would otherwise receive CUDA inputs.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    model.train()

    # Accumulate on-device so we only synchronise with the GPU once per epoch.
    running_loss = torch.zeros((), device=device)
    num_batches = 0
    for imgs, targets in train_loader:
        imgs, targets = imgs.to(device), targets.to(device)
        optimizer.zero_grad()
        batch_loss = criterion(model(imgs), targets)
        batch_loss.backward()

        # Freeze pruned weights: no gradient flows into (near-)zero weight entries.
        with torch.no_grad():
            for name, p in model.named_parameters():
                if 'weight' in name and p.grad is not None:
                    p.grad.masked_fill_(p.abs() < EPS, 0.0)
        optimizer.step()

        running_loss += batch_loss.detach()
        num_batches += 1

    if num_batches == 0:
        return float('nan')
    return running_loss.item() / num_batches


def test(model, test_loader, criterion):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).sum().item()
        test_loss /= len(test_loader.dataset)
        accuracy = 100. * correct / len(test_loader.dataset)
    return accuracy

In [None]:
def compute_nonzeros(model, verbose= False):
    """Report how many parameter entries are non-zero and return the alive percentage.

    Every parameter is counted (conv/linear weights, biases, BatchNorm affine terms). With
    ``verbose`` one line per parameter is printed. A summary line (alive / pruned / total,
    compression rate, pruned percentage) is always printed.

    Returns:
        Percentage of non-zero entries across all parameters, rounded to one decimal
        (0.0 for a model without parameters).
    """
    nonzero = total = 0
    for name, p in model.named_parameters():
        tensor = p.data.cpu().numpy()
        nz_count = np.count_nonzero(tensor)
        total_params = tensor.size  # int element count (np.prod(()) would give 1.0 for 0-d params)
        nonzero += nz_count
        total += total_params
        if verbose:
            alive_pct = 100 * nz_count / total_params if total_params else 0.0
            print(f'{name:20} | nonzeros = {nz_count:7} / {total_params:7} ({alive_pct:6.2f}%) | total_pruned = {total_params - nz_count :7} | shape = {tensor.shape}')
    # Guard the divisions: a fully pruned (or parameter-less) model must not raise ZeroDivisionError.
    compression = total / nonzero if nonzero else float('inf')
    pruned_pct = 100 * (total - nonzero) / total if total else 0.0
    print(f'alive: {nonzero}, pruned : {total - nonzero}, total: {total}, Compression rate : {compression:10.2f}x  ({pruned_pct:6.2f}% pruned)')
    if total == 0:
        return np.float64(0.0)
    return np.round((nonzero/total)*100,1)

In [None]:
def _init_weight_and_bias(m, weight_fn):
    """Initialise ``m.weight`` with ``weight_fn`` and, when present, ``m.bias`` with ``init.normal_`` defaults."""
    weight_fn(m.weight)
    if m.bias is not None:
        init.normal_(m.bias)


def weight_init(m):
    '''Re-initialise a single module in place; meant to be used with ``Module.apply``.

    - Conv1d / ConvTranspose1d: weight ``init.normal_``, bias ``init.normal_``.
    - Conv2d/3d, ConvTranspose2d/3d, Linear: weight ``init.xavier_normal_``, bias ``init.normal_``.
    - BatchNorm1d/2d/3d: weight ~ N(1, 0.02), bias = 0.
    Missing biases (``bias=False``) and non-affine BatchNorm (``affine=False``) are skipped
    instead of raising on ``None``. Other module types are left unchanged.

    Usage:
        model = Model()
        model.apply(weight_init)
    '''
    if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
        _init_weight_and_bias(m, init.normal_)
    elif isinstance(m, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d, nn.Linear)):
        _init_weight_and_bias(m, init.xavier_normal_)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        if m.weight is not None:
            init.normal_(m.weight, mean=1, std=0.02)
        if m.bias is not None:
            init.constant_(m.bias, 0)

In [None]:
def make_mask(model):
    """Create an all-ones pruning mask for every parameter whose name contains ``'weight'``.

    The masks are numpy arrays shaped like the parameters, ordered as in
    ``model.named_parameters()``. The list is stored in the global ``mask`` and also returned.
    """
    global mask
    mask = [
        np.ones_like(param.data.cpu().numpy())
        for name, param in model.named_parameters()
        if 'weight' in name
    ]
    return mask

In [None]:
def original_initialization(mask_current, rewinding_state_dict, target_model=None):
    """Rewind parameters to ``rewinding_state_dict`` while keeping pruned weights at zero.

    For each parameter whose name contains ``'weight'``, the rewound value is multiplied by the
    matching entry of ``mask_current``. These are numpy masks in ``named_parameters()`` weight
    order. Bias parameters are restored unmasked. Other parameters are left unchanged.

    Args:
        mask_current: list of per-weight-tensor numpy masks.
        rewinding_state_dict: state dict holding the values to rewind to.
        target_model: model to rewind; defaults to the notebook-global ``model``.
    """
    net = model if target_model is None else target_model
    count = 0
    for name, param in net.named_parameters():
        weight_dev = param.device
        if "weight" in name:
            rewound = rewinding_state_dict[name].cpu().numpy()
            param.data = torch.from_numpy(mask_current[count] * rewound).to(weight_dev)
            count += 1
        elif "bias" in name:
            # Copy: aliasing the state-dict tensor would let in-place optimizer updates silently
            # overwrite the rewinding checkpoint used by every later round.
            param.data = rewinding_state_dict[name].to(weight_dev, copy=True)

In [None]:
def checkdir(directory):
    """Create ``directory`` (including missing parents) if it does not exist yet."""
    # exist_ok avoids the race between an os.path.exists check and os.makedirs.
    os.makedirs(directory, exist_ok=True)

In [None]:
def prune_by_percentile_fixed_all_layer(percentage, target_model=None, target_mask=None):
    """Magnitude-prune every weight tensor by the same per-tensor percentile.

    For each parameter whose name contains ``'weight'`` (conv, linear and BatchNorm scales alike),
    the ``percentage``-th percentile of the magnitudes of its currently non-zero entries is used as
    the threshold. Entries below it are zeroed in the parameter and in the corresponding mask.
    Tensors with no surviving entries are skipped, because ``np.percentile`` fails on an empty array.

    Args:
        percentage: percentile in [0, 100] applied to every weight tensor.
        target_model: model to prune; defaults to the notebook-global ``model``.
        target_mask: list of per-weight-tensor numpy masks (weight order of
            ``named_parameters()``), updated in place; defaults to the notebook-global ``mask``.
    """
    net = model if target_model is None else target_model
    layer_masks = mask if target_mask is None else target_mask

    count = 0
    for name, param in net.named_parameters():
        if 'weight' not in name:
            continue
        tensor = param.data.cpu().numpy()
        alive = tensor[np.nonzero(tensor)]  # flattened array of surviving values
        if alive.size > 0:
            threshold = np.percentile(np.abs(alive), percentage)
            new_mask = np.where(np.abs(tensor) < threshold, 0, layer_masks[count])
            param.data = torch.from_numpy(tensor * new_mask).to(param.device)
            layer_masks[count] = new_mask
        count += 1

In [None]:
def prune_by_percentile_fixed_per_layer(percentages, target_model=None, target_mask=None):
    """Magnitude-prune each weight tensor by its own percentile from ``percentages``.

    ``percentages[i]`` is used for the i-th parameter whose name contains ``'weight'``. The
    threshold is the percentile of the magnitudes of that tensor's currently non-zero entries.
    Entries below it are zeroed in the parameter and in the corresponding mask. Tensors with no
    surviving entries are skipped, because ``np.percentile`` fails on an empty array.

    Args:
        percentages: sequence of percentiles in [0, 100], one per weight tensor.
        target_model: model to prune; defaults to the notebook-global ``model``.
        target_mask: list of per-weight-tensor numpy masks (weight order of
            ``named_parameters()``), updated in place; defaults to the notebook-global ``mask``.
    """
    net = model if target_model is None else target_model
    layer_masks = mask if target_mask is None else target_mask

    count = 0
    for name, param in net.named_parameters():
        if 'weight' not in name:
            continue
        tensor = param.data.cpu().numpy()
        alive = tensor[np.nonzero(tensor)]  # flattened array of surviving values
        if alive.size > 0:
            threshold = np.percentile(np.abs(alive), percentages[count])
            new_mask = np.where(np.abs(tensor) < threshold, 0, layer_masks[count])
            param.data = torch.from_numpy(tensor * new_mask).to(param.device)
            layer_masks[count] = new_mask
        count += 1

In [None]:
def compute_global_taylor_threshold(percentage, target_model=None):
    """Global percentile threshold for first-order Taylor pruning scores.

    The score of a weight entry is ``|w * dL/dw|``. Scores are pooled across every parameter whose
    name contains ``'weight'``. Only entries whose weight is currently non-zero are included. The
    function returns the ``percentage``-th percentile of the pooled scores.

    Args:
        percentage: percentile in [0, 100].
        target_model: model to score; defaults to the notebook-global ``model``.

    Returns:
        The threshold, or 0.0 if no non-zero weights remain.

    Raises:
        RuntimeError: if a weight parameter has no ``.grad`` (no backward pass has been run).
    """
    net = model if target_model is None else target_model
    scores = []
    for name, param in net.named_parameters():
        if 'weight' not in name:
            continue
        if param.grad is None:
            raise RuntimeError(f"{name} has no gradient; run a backward pass before Taylor pruning")
        weights = param.data.cpu().numpy()
        grads = param.grad.cpu().numpy()
        alive = np.nonzero(weights)
        scores.append(np.abs(weights[alive] * grads[alive]))

    all_scores = np.concatenate(scores) if scores else np.empty(0)
    if all_scores.size == 0:
        return 0.0
    return np.percentile(all_scores, percentage)

In [None]:
def prune_by_percentile_global_taylor(percentage):
    """Prune weights whose first-order Taylor score ``|w * dL/dw|`` falls below a global threshold.

    The threshold comes from ``compute_global_taylor_threshold(percentage)`` and is printed. Every
    parameter whose name contains ``'weight'`` of the global ``model`` is masked against it. The
    global ``mask`` list is updated in place, in ``named_parameters()`` weight order. Requires
    ``.grad`` to be populated.
    """
    global mask
    global model

    threshold = compute_global_taylor_threshold(percentage)
    print(threshold)
    weight_params = [p for n, p in model.named_parameters() if 'weight' in n]
    for idx, param in enumerate(weight_params):
        values = param.data.cpu().numpy()
        grads = param.grad.cpu().numpy()
        keep = np.where(abs(values * grads) < threshold, 0, mask[idx])
        param.data = torch.from_numpy(values * keep).to(param.device)
        mask[idx] = keep

In [None]:
from torch.nn.utils import prune
def l1_channel_pruning(round):
    """Structured L1 pruning of output channels (dim 0) for every Conv2d of the global ResNet18 ``model``.

    ``round`` sets the cumulative fraction of channels removed from each conv:
    ``1 - (1 - 0.1 * 1.05**k) ** round`` for the k-th main-path conv (so later convs lose channels
    faster), ``1 - 0.95**round`` for the stem conv and the layer2/layer3 shortcuts, and
    ``1 - 0.85**round`` for the layer4 shortcut. Every Conv2d must have an entry (KeyError otherwise).
    Pruning is made permanent with ``prune.remove``. Then every weight entry that is now exactly
    zero is marked as pruned in the global ``mask`` list.
    """
    global mask
    global model

    main_path_convs = (
        "layer1.0.conv1", "layer1.0.conv2", "layer1.1.conv1", "layer1.1.conv2",
        "layer2.0.conv1", "layer2.0.conv2", "layer2.1.conv1", "layer2.1.conv2",
        "layer3.0.conv1", "layer3.0.conv2", "layer3.1.conv1", "layer3.1.conv2",
        "layer4.0.conv1", "layer4.0.conv2", "layer4.1.conv1", "layer4.1.conv2",
    )
    config = {
        conv_name: 1 - (1 - 0.1 * (1.05 ** k)) ** round
        for k, conv_name in enumerate(main_path_convs)
    }
    config.update({
        "conv1": 1 - 0.95 ** round,
        "layer2.0.shortcut.0": 1 - 0.95 ** round,
        "layer3.0.shortcut.0": 1 - 0.95 ** round,
        "layer4.0.shortcut.0": 1 - 0.85 ** round,
    })

    conv_layers = [(n, mod) for n, mod in model.named_modules() if isinstance(mod, torch.nn.Conv2d)]
    for conv_name, conv in conv_layers:
        prune.ln_structured(conv, name="weight", amount=config[conv_name], n=1, dim=0)
        prune.remove(conv, 'weight')

    weight_idx = 0
    for name, param in model.named_parameters():
        if 'weight' not in name:
            continue
        current = param.data.cpu().numpy()
        mask[weight_idx] = np.where(abs(current) == 0, 0, mask[weight_idx])
        weight_idx += 1

In [None]:
def l2_channel_pruning(round):
    """Structured L2 pruning of output channels (dim 0) with a uniform cumulative rate ``1 - 0.8**round``.

    The rate applies to every Conv2d of the global ``model`` and to any Linear layer not named
    ``"linear"``. Pruning is made permanent with ``prune.remove``. Then every weight entry that is
    now exactly zero is marked as pruned in the global ``mask`` list.
    """
    global mask
    global model

    amount = 1 - 0.8 ** round
    for module_name, module in model.named_modules():
        is_conv = isinstance(module, torch.nn.Conv2d)
        is_hidden_linear = isinstance(module, torch.nn.Linear) and module_name != "linear"
        if is_conv or is_hidden_linear:
            prune.ln_structured(module, name="weight", amount=amount, n=2, dim=0)
            prune.remove(module, 'weight')

    weight_params = [p for n, p in model.named_parameters() if 'weight' in n]
    for idx, param in enumerate(weight_params):
        current = param.data.cpu().numpy()
        mask[idx] = np.where(abs(current) == 0, 0, mask[idx])

In [None]:
from torch.ao.pruning._experimental.pruner import FPGMPruner, SaliencyPruner
import torch.nn.utils.parametrize as parametrize

def FPGM_pruning(round):
    """One round of FPGM filter pruning over the Conv2d layers of the global ResNet18 ``model``.

    ``FPGMPruner`` is created with sparsity ``1 - 0.8**round`` and applied to the conv weights listed
    in ``config``. The resulting parametrizations are then removed so the pruned values become
    plain ``weight`` tensors. Finally, every weight entry that is now exactly zero is marked as
    pruned in the global ``mask`` list.
    """
    global model
    global mask

    pruner = FPGMPruner(1-0.8**round)
    config = [
        {"tensor_fqn": "conv1.weight"},
        {"tensor_fqn": "layer1.0.conv1.weight"},
        {"tensor_fqn": "layer1.0.conv2.weight"},
        {"tensor_fqn": "layer1.1.conv1.weight"},
        {"tensor_fqn": "layer1.1.conv2.weight"},
        {"tensor_fqn": "layer2.0.conv1.weight"},
        {"tensor_fqn": "layer2.0.conv2.weight"},
        {"tensor_fqn": "layer2.0.shortcut.0.weight"},
        {"tensor_fqn": "layer2.1.conv1.weight"},
        {"tensor_fqn": "layer2.1.conv2.weight"},
        {"tensor_fqn": "layer3.0.conv1.weight"},
        {"tensor_fqn": "layer3.0.conv2.weight"},
        {"tensor_fqn": "layer3.0.shortcut.0.weight"},
        {"tensor_fqn": "layer3.1.conv1.weight"},
        {"tensor_fqn": "layer3.1.conv2.weight"},
        {"tensor_fqn": "layer4.0.conv1.weight"},
        {"tensor_fqn": "layer4.0.conv2.weight"},
        {"tensor_fqn": "layer4.0.shortcut.0.weight"},
        {"tensor_fqn": "layer4.1.conv1.weight"},
        {"tensor_fqn": "layer4.1.conv2.weight"}]

    pruner.prepare(model, config)
    pruner.enable_mask_update = True
    pruner.step()

    # Bake the pruned values into plain parameters. Convs without a "weight" parametrization are
    # skipped explicitly; a bare `except: pass` would also hide genuine failures.
    for module in model.modules():
        if isinstance(module, torch.nn.Conv2d) and parametrize.is_parametrized(module, "weight"):
            parametrize.remove_parametrizations(module, "weight")

    count = 0
    for name, param in model.named_parameters():
        if 'weight' in name:
            tensor = param.data.cpu().numpy()
            mask[count] = np.where(abs(tensor) == 0, 0, mask[count])
            count += 1

In [None]:
def compute_mask_size(mask):
  """Total number of entries over all per-layer mask arrays (0 for an empty list)."""
  return sum(len(layer.flatten()) for layer in mask)

def compute_mask_sum(mask):
  """Sum of all entries over the per-layer mask arrays, i.e. the number of kept weights for 0/1 masks.

  Summation is done in float64. Float32 masks summed in float32 stop being exact beyond 2**24
  (~16.8M), which is close to the size of larger networks.
  """
  return float(sum(np.sum(layer, dtype=np.float64) for layer in mask))

In [None]:
# Smoke test of the iterative L1-channel-pruning loop: one training epoch per round, so the
# per-layer sparsity printed by `compute_nonzeros(model, verbose=True)` can be inspected quickly.
# `train_loader` is only created by the data-preparation cell under "Experiments" further down, so
# on a fresh top-to-bottom run this cell skips itself instead of failing with a NameError.
if 'train_loader' not in globals():
    print("Skipping pruning smoke test: run the data-preparation cell first (it defines train_loader).")
else:
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = ResNet18()
    model = model.to(device)
    model.apply(weight_init)
    # Rewind target: this model's own initial weights. The `initial_state_dict` defined in a later
    # cell belongs to a different ResNet18 instance.
    smoke_initial_state_dict = copy.deepcopy(model.state_dict())
    compute_nonzeros(model)
    make_mask(model)
    iterative_pruning_rounds = 15
    learning_rate = 0.1

    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()

    # One additional round required for the original training
    for round in range(iterative_pruning_rounds+1):
        if round != 0:
            l1_channel_pruning(round)
            original_initialization(mask, smoke_initial_state_dict)
            optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-4)
        print("_________round", round, "_________")

        p_m_current = compute_nonzeros(model, verbose=True)
        print(compute_mask_size(mask), compute_mask_sum(mask)/compute_mask_size(mask))

        loss = train(model, train_loader, optimizer, criterion)
        print("\n" * 4)

### Experiments

In [None]:
# Experiment configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Optimisation
learning_rate = 0.1
batch_size = 128
training_epochs = 160

# Iterative pruning schedule
iterative_pruning_rounds = 15
target_p_m = 10  # target percentage of weights still alive after the last round
# Alternative derived from the target: (1 - np.power(target_p_m / 100, 1 / iterative_pruning_rounds)) * 100
pruning_rate_per_round = 20
# True: random re-initialisation each round; False: rewind to the weights saved during round 0.
reinit = False

In [None]:
# Data: CIFAR-10 with random-crop + horizontal-flip augmentation for training, plain
# normalisation for testing.
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

train_set = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
# shuffle=True: an unshuffled loader feeds SGD the same batches in the same order every epoch.
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True)

test_set = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=batch_size, shuffle=False)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

model = ResNet18().to(device)
model.apply(weight_init)

initial_state_dict = copy.deepcopy(model.state_dict())
# NOTE(review): hard-coded Google Drive (Colab) location; make configurable when running elsewhere.
experiment_path = "/content/drive/MyDrive/R252/Structured_l1_large/"
checkdir(experiment_path+"saves/")
# Saves the whole module object (not only its state dict), despite the file name.
torch.save(model, experiment_path+"saves/initial_state_dict.pth.tar")
make_mask(model)
print(compute_mask_size(mask), compute_mask_sum(mask)/compute_mask_size(mask))
print(compute_mask_size(mask), compute_mask_sum(mask)/compute_mask_size(mask))

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified
11169152 1.0


In [None]:
# Iterative pruning driver. Round 0 trains the dense network. Each later round:
#   1. prunes with `l1_channel_pruning(round)`;
#   2. resets the surviving weights (random re-initialisation when `reinit` is True, otherwise a
#      rewind to `rewinding_state_dict` captured during round 0);
#   3. retrains with a fresh SGD optimizer.
# Checkpoints, plots and raw histories are written under `experiment_path`.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

for name, param in model.named_parameters():
    print(name, param.size())

# Round-0 epoch whose weights become the rewinding target. Clamped so that short runs
# (training_epochs < 10) still define `rewinding_state_dict` before round 1 needs it.
rewind_epoch = min(9, training_epochs - 1)

best_accuracy_per_round_history = np.zeros(iterative_pruning_rounds+1,float)
best_accuracy_per_round = 0.0
best_accuracy_epoch = 0
p_m = np.zeros(iterative_pruning_rounds+1,float)
step = 0
loss_history = np.zeros(training_epochs,float)
accuracy_history = np.zeros(training_epochs+1,float)
_ = compute_nonzeros(model)
# One additional round required for the original training
for round in range(iterative_pruning_rounds+1):
    if round != 0:
        l1_channel_pruning(round)
        if reinit:
            # Fresh random weights, with the current pruning mask re-applied to every weight tensor.
            model.apply(weight_init)
            step = 0
            for name, param in model.named_parameters():
                if 'weight' in name:
                    weight_dev = param.device
                    param.data = torch.from_numpy(param.data.cpu().numpy() * mask[step]).to(weight_dev)
                    step = step + 1
            step = 0
        else:
            original_initialization(mask, rewinding_state_dict)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
    print("_________round", round, "_________")

    p_m_current = compute_nonzeros(model, verbose=False)
    print(compute_mask_size(mask), compute_mask_sum(mask)/compute_mask_size(mask))

    p_m[round] = p_m_current
    pbar = tqdm(range(training_epochs))

    # Accuracy before any training in this round is the baseline the best checkpoint has to beat.
    accuracy = test(model, test_loader, criterion)
    best_accuracy_per_round = accuracy
    accuracy_history[0] = accuracy

    for epoch in pbar:
        # Step decay: 0.1 -> 0.01 at epoch 80 -> 0.001 at epoch 120. The later milestone is tested
        # first; testing `epoch >= 80` first would shadow it and the final decay would never happen.
        if epoch >= 120:
            for g in optimizer.param_groups:
                g['lr'] = 0.001
        elif epoch >= 80:
            for g in optimizer.param_groups:
                g['lr'] = 0.01
        loss = train(model, train_loader, optimizer, criterion)

        accuracy = test(model, test_loader, criterion)

        if accuracy > best_accuracy_per_round:
            best_accuracy_per_round = accuracy
            best_accuracy_epoch = epoch
            checkdir(experiment_path+"saves/")
            torch.save(model, experiment_path+f"saves/round_{round}_best_model_lt.pth.tar")

        if round == 0 and epoch == rewind_epoch:
            checkdir(experiment_path+"saves/")
            torch.save(model, experiment_path+f"saves/rewinding_state_dict_1.pth.tar")
            rewinding_state_dict = copy.deepcopy(model.state_dict())

        if round == 0 and epoch == 15:
            checkdir(experiment_path+"saves/")
            torch.save(model, experiment_path+f"saves/rewinding_state_dict_2.pth.tar")

        loss_history[epoch] = loss
        accuracy_history[epoch+1] = accuracy
        pbar.set_description(
                f'Train Epoch: {epoch}/{training_epochs} Loss: {loss:.6f} Accuracy: {accuracy:.2f}% Best Accuracy: {best_accuracy_per_round:.2f}% (Epoch:{best_accuracy_epoch})')

    checkdir(experiment_path+"saves/")
    torch.save(model, experiment_path+f"saves/round_{round}_final_model_lt.pth.tar")

    _ = compute_nonzeros(model, verbose=False)

    best_accuracy_per_round_history[round] = best_accuracy_per_round

    # Plotting Loss (Training), Accuracy (Testing) per Epoch. The loss is min-max scaled to 0-100 so
    # it shares the accuracy axis; a constant loss history would make that 0/0, so fall back to 1.
    loss_span = np.ptp(loss_history).astype(float)
    if loss_span == 0:
        loss_span = 1.0
    plt.plot(np.arange(1,(training_epochs)+1), 100*(loss_history - np.min(loss_history))/loss_span, c="blue", label="Loss")
    plt.plot(np.arange(1,(training_epochs)+1), accuracy_history[1:], c="red", label="Accuracy")
    plt.title(f"Loss and Accuracy per Epoch (ResNet18, CIFAR10)")
    plt.xlabel("Epochs")
    plt.ylabel("Loss and Accuracy")
    plt.legend()
    plt.grid(color="gray")
    checkdir(experiment_path+"plots/")
    plt.savefig(experiment_path+f"plots/round_{round}_LossVsAccuracy_{p_m_current}.png", dpi=1200)
    plt.close()

    # Dump Plot values
    checkdir(experiment_path+"dumps/")
    loss_history.dump(experiment_path+f"dumps/round_{round}_loss_history_{p_m_current}.dat")
    accuracy_history.dump(experiment_path+f"dumps/round_{round}_accuracy_history_{p_m_current}.dat")
    with open(experiment_path+f"dumps/round_{round}_mask_{p_m_current}.pkl", 'wb') as fp:
        pickle.dump(mask, fp)

    # Reset per-round trackers for the next round.
    best_accuracy_per_round = 0.0
    best_accuracy_epoch = 0
    loss_history = np.zeros(training_epochs,float)
    accuracy_history = np.zeros(training_epochs+1,float)

# Cross-round summary: persist the alive-percentage and best-accuracy arrays, then plot best test
# accuracy per round with each round labelled by its percentage of unpruned weights.
if iterative_pruning_rounds > 1:
    dumps_dir = experiment_path + "dumps/"
    checkdir(dumps_dir)
    p_m.dump(dumps_dir + "compression.dat")
    best_accuracy_per_round_history.dump(dumps_dir + "best_accuracy.dat")

    round_idx = np.arange(iterative_pruning_rounds + 1)
    plt.plot(round_idx, best_accuracy_per_round_history, c="blue", label="Winning tickets")
    plt.title("Test Accuracy vs Unpruned Weights Percentage (ResNet18, CIFAR10)")
    plt.xlabel("Unpruned Weights Percentage")
    plt.ylabel("test accuracy")
    plt.xticks(round_idx, p_m, rotation="vertical")
    plt.ylim(0, 100)
    plt.legend()
    plt.grid(color="gray")

    plots_dir = experiment_path + "plots/"
    checkdir(plots_dir)
    plt.savefig(plots_dir + "AccuracyVsWeights.png", dpi=1200)
    plt.close()
