In [0]:
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.nn.utils.prune as prune

import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets


from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import os
import argparse
import copy
import sys

In [0]:
# Useful functions.

# Function to reinit the net's weights to the original weights + apply the 
# current mask.

# TODO: see if you need to make weight_orig requires_grad=False.
def reinit_and_apply_mask(net):
    # For each layer, set the current weights as weight_init, leave the mask untouched.
    modules_names = net._modules.keys()
    for module_name in modules_names:
        
        # Check if module_name has been pruned.
        pruned = False
        for name, param in net._modules[module_name].named_parameters():
            if 'weight_orig' in name:
                pruned = True

        # If the module has been pruned, copy the weights from weight_orig 
        # to module weights and apply the mask.
        net._modules[module_name].weight = net._modules[module_name].weight_orig.detach().requires_grad_()
        net._modules[module_name].weight = net._modules[module_name].weight * net._modules[module_name].weight_mask


# Function for weight initialisation (taken from 
# https://github.com/rahulvigneswaran/Lottery-Ticket-Hypothesis-in-Pytorch/blob/master/main.py)
def weight_init(m):
    '''
    Usage:
        model = Model()
        model.apply(weight_init)
    '''
    if isinstance(m, nn.Conv1d):
        init.normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.Conv2d):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.Conv3d):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.ConvTranspose1d):
        init.normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.ConvTranspose2d):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.ConvTranspose3d):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.BatchNorm1d):
        init.normal_(m.weight.data, mean=1, std=0.02)
        init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm2d):
        init.normal_(m.weight.data, mean=1, std=0.02)
        init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm3d):
        init.normal_(m.weight.data, mean=1, std=0.02)
        init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        init.xavier_normal_(m.weight.data)
        init.normal_(m.bias.data)
    elif isinstance(m, nn.LSTM):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)
    elif isinstance(m, nn.LSTMCell):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)
    elif isinstance(m, nn.GRU):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)
    elif isinstance(m, nn.GRUCell):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)

# Code for training one epoch (one pass through the dataset).
def train_one_epoch(model, train_loader, optimizer, criterion):
    model.train() #Sets nn.Module in train mode (has effects only on some models)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    for batch_idx, (imgs, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        imgs, targets = imgs.to(device), targets.to(device)

        output = model(imgs)
        
        train_loss = criterion(output, targets)
        train_loss.backward()

        # In original, gradient of the pruned nodes were made 0.
        optimizer.step()
    
    # train_loss is a tensor with one value;
    # tensor.item() returns the value held by a tensor with one value;
    return train_loss.item()

# Not sure if the loss calculation is correct here.
def calculate_accuracy_and_loss(model, loader, criterion):
    # Put the model in evaluation mode.
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    nr_batches = len(loader)
    total_loss = 0
    accuracy = 0

    correct = 0
    total = 0

    with torch.no_grad():
        for batch_idx, (imgs, targets) in enumerate(loader):
            imgs, targets = imgs.to(device), targets.to(device)

            output = model(imgs)

            total_loss += (1/nr_batches) * criterion(output, targets)

            _, predicted = torch.max(output.data, 1)
            total += targets.shape[0]
            correct += (predicted == targets).sum().item()

    accuracy = correct/total * 100
    return accuracy, total_loss

In [0]:
# Test run #1:

# All the variables + initialisations we need for an experiment.
reinit = True
dataset = "mnist"
batch_size = 32
experiment_name = "weight_magnitude_pruning_test"
# Not sure if we need a validation set.
nr_valid_elems = 10000
test_freq = 1
# model_type = "LeNetTrash"
# Will we experiment with different optimizers/losses/arguments to those?
# If yes, we need to prove a re-initializer.
#optimizer = 
#criterion = 
prune_iterations = 5
epochs = 5
#criterion = 
# Need to have a parameter for the pruning type + python file with the pruning.
# WHat if the pruning changes parameters iteration after iteration?
# Then modify the script accordingly.
# What if the pruning method changes iteration after iteration?
# Load a list of pruning methods from the file instead? - to be seen.

# All the experiment data must be saved in a folder experiment_name.
# The data must be saved in a different folde (outside any experiment folder).

writer = SummaryWriter()
sns.set_style('darkgrid')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# TODO: Add more dataset-dependent data loaders.
if dataset == 'mnist':
    # Transforms which will be applied to the data.
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.1307,), (0.3081, ))
                                    ])
    # 0.1307, 0.3081 represent the mean + std of the mnist dataset.
    
    # Split the train dataset into a train + valid datasets.
    # Must set the values of the samples in each split (here, 50000, 10000).
    dataset = datasets.MNIST(root=os.getcwd() + '/data', train=True, download=True, transform=transform)
    train_set, valid_set = torch.utils.data.random_split(dataset, [50000, 10000])

    # Load the test dataset.
    test_set = datasets.MNIST(root=os.getcwd() + '/data', train=False, transform=transform)

    # TODO: Also, import the relevant models (to be tested).
    # from models.mnist import fc1


# Transformations will not be applied until you call a DataLoader on it.
# make valid_loader = test_loader if the option for the split is 0.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)
valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=True)

# Load model (placeholder, see if it makes sense to load models given how PruningMethods differ).
#if model_type == "fc1":
    #model = fc1.fc1().to(device)
# Create a neural net.
class LeNetTrash(nn.Module):
    def __init__(self):
        super(LeNetTrash, self).__init__()
        self.fc1 = nn.Linear(784, 300)
        self.fc2 = nn.Linear(300, 100)
        self.fc3 = nn.Linear(100, 10)
    
    def forward(self, input):
        x = input.flatten(start_dim=1, end_dim=-1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNetTrash().to(device)

# Initialise weights.
model.apply(weight_init)

# Save initial model for reference.
os.makedirs(os.getcwd() + '/' + experiment_name, exist_ok=True)
torch.save(model, os.getcwd()+ '/' + experiment_name + '/model_init.pth')

# Main training loop. We can add bells and whistles afterwards, depending on 
# what we want to save.

# Placeholder for loading pruning methods.
# TODO: refactor after testing.
def prune_model(model):
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name='weight', amount=0.1)

# TODO: remove these two?
optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()
writer = SummaryWriter(os.getcwd() + '/' + experiment_name)

best_accuracies = np.zeros((prune_iterations, ))
best_accuracy = 0
for prune_iteration in range(prune_iterations):
    print('\n\nStarting pruning iteration: ' + str(prune_iteration) + '\n')

    #os.makedirs(os.getcwd() + '/' + experiment_name + '/' + 'prune_iter_' + str(prune_iteration), exist_ok=True)

    if prune_iteration != 0:
        # do the pruning here.
        # Before pruning (or after, we can save the mask here, if interested).
        prune_model(model)
        if reinit:
            # do reinitialization of the net + optimizer here.
            reinit_and_apply_mask(model)
            optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-4)
    
    for epoch in range(epochs):
        train_one_epoch(model, train_loader, optimizer, criterion)

        if epoch % test_freq == 0:
            experiment_PATH = os.getcwd() + '/' + experiment_name
            # save whatever you're interested in here
            # need a function to save a model with a name
            # for starters, save best model for each pruning iteration

            train_acc, train_loss = calculate_accuracy_and_loss(model, train_loader, criterion)
            valid_acc, valid_loss = calculate_accuracy_and_loss(model, valid_loader, criterion)

            writer.add_scalar('Accuracy/train_' + str(prune_iteration), train_acc, epoch)
            writer.add_scalar('Loss/train_' + str(prune_iteration), train_loss, epoch)
            writer.add_scalar('Accuracy/valid_' + str(prune_iteration), valid_acc, epoch)
            writer.add_scalar('Loss/valid_' + str(prune_iteration), valid_loss, epoch)

            # Maybe save the best models here, now we're interested only in best accs/losses.
            if(valid_acc > best_accuracy):
                best_accuracy = valid_acc
                #PATH = get_best_model_PATH(experiment_PATH, best_model_filename, 0)
                #remove_checkpoint(PATH)
                #save_checkpoint(PATH, epoch, net, optimizer)

            print('Epoch: ' + str(epoch) +  ', Train loss: {:.4f}, Train Acc: {:.2f}, Valid loss: {:.4f}, Valid Acc: {:.2f}'.format(train_loss, train_acc, valid_loss, valid_acc))
    
    best_accuracies[prune_iteration] = best_accuracy

    # Here, the model has finished training. 
    # Again, save whatever information seems legit to save, dependent on the pruning run.
    # Save a list of best accuracies for each pruning iteration.
    # Cool, that's all I think.

  "type " + obj.__name__ + ". It won't be checked "




Starting pruning iteration: 0

Epoch: 0, Train loss: 0.1221, Train Acc: 96.02, Valid loss: 0.1459, Valid Acc: 95.55
Epoch: 1, Train loss: 0.0642, Train Acc: 97.89, Valid loss: 0.0981, Valid Acc: 96.96
Epoch: 2, Train loss: 0.0634, Train Acc: 97.96, Valid loss: 0.1051, Valid Acc: 96.86
Epoch: 3, Train loss: 0.0388, Train Acc: 98.83, Valid loss: 0.0795, Valid Acc: 97.60
Epoch: 4, Train loss: 0.0400, Train Acc: 98.66, Valid loss: 0.0924, Valid Acc: 97.38


Starting pruning iteration: 1

Epoch: 0, Train loss: 0.0412, Train Acc: 98.53, Valid loss: 0.0983, Valid Acc: 97.06
Epoch: 1, Train loss: 0.0361, Train Acc: 98.73, Valid loss: 0.1011, Valid Acc: 97.37
Epoch: 2, Train loss: 0.0305, Train Acc: 98.94, Valid loss: 0.0859, Valid Acc: 97.35
Epoch: 3, Train loss: 0.0213, Train Acc: 99.32, Valid loss: 0.0863, Valid Acc: 97.59
Epoch: 4, Train loss: 0.0209, Train Acc: 99.31, Valid loss: 0.0780, Valid Acc: 97.72


Starting pruning iteration: 2

Epoch: 0, Train loss: 0.0398, Train Acc: 98.60, Val

In [0]:
if __name__=="__main__":
    
    #from gooey import Gooey
    #@Gooey      
    
    # Arguement Parser
    parser = argparse.ArgumentParser()
    parser.add_argument("--lr",default= 1.2e-3, type=float, help="Learning rate")
    parser.add_argument("--batch_size", default=60, type=int)
    parser.add_argument("--start_iter", default=0, type=int)
    parser.add_argument("--end_iter", default=100, type=int)
    parser.add_argument("--print_freq", default=1, type=int)
    parser.add_argument("--valid_freq", default=1, type=int)
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--prune_type", default="lt", type=str, help="lt | reinit")
    parser.add_argument("--gpu", default="0", type=str)
    parser.add_argument("--dataset", default="mnist", type=str, help="mnist | cifar10 | fashionmnist | cifar100")
    parser.add_argument("--arch_type", default="fc1", type=str, help="fc1 | lenet5 | alexnet | vgg16 | resnet18 | densenet121")
    parser.add_argument("--prune_percent", default=10, type=int, help="Pruning percent")
    parser.add_argument("--prune_iterations", default=35, type=int, help="Pruning iterations count")

    
    args = parser.parse_args()


    os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
    os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu
    
    
    #FIXME resample
    resample = False

    # Looping Entire process
    #for i in range(0, 5):
    main(args, ITE=1)