In [77]:
# Importing Libraries
import argparse
import copy
import os
import sys
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import os
from tensorboardX import SummaryWriter
import torchvision.utils as vutils
import seaborn as sns

import torch.nn.init as init
import pickle
from prune_layer import *


# Custom Libraries
import utils
class argument:
    """Lightweight replacement for parsed command-line options.

    Every keyword argument is stored as an attribute of the same name.
    Accepted values, as listed in the original inline comments:

    * ``prune_type``: ``"lt"`` (default) or ``"reinit"``.
    * ``dataset``: ``"mnist" | "cifar10" | "fashionmnist" | "cifar100"``.
    * ``arch_type``: ``"fc1" | "lenet5" | "alexnet" | "vgg16" | "resnet18" | "densenet121"``.

    NOTE(review): ``resume`` defaults to the *string* ``"store_true"`` (an
    argparse action name), which is truthy -- confirm how callers read it.
    """

    def __init__(self, lr=1.2e-3, batch_size=60, start_iter=0, end_iter=100, print_freq=1,
                 valid_freq=1, resume="store_true", prune_type="lt", gpu="0",
                 dataset="mnist", arch_type="fc1", prune_percent=10, prune_iterations=35):
        # Optimisation and iteration range
        self.lr, self.batch_size = lr, batch_size
        self.start_iter, self.end_iter = start_iter, end_iter
        # Logging / validation cadence
        self.print_freq, self.valid_freq = print_freq, valid_freq
        self.resume = resume
        # Experiment selection
        self.prune_type = prune_type
        self.gpu = gpu
        self.dataset, self.arch_type = dataset, arch_type
        # Pruning schedule
        self.prune_percent, self.prune_iterations = prune_percent, prune_iterations
        

In [78]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): args.gpu is never used to select a device here -- confirm intent.

args = argument(end_iter = 50,arch_type ="lenet5")
# True when weights should be re-randomised instead of rewound to the snapshot.
reinit = args.prune_type == "reinit"

#Data Loader
# Dataset root directory (MNIST is downloaded here on first use).
DATA_ROOT = '~/work/data/Xian'
# (0.1307,), (0.3081,): presumably the usual MNIST mean / std.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
if args.dataset == "mnist":
    traindataset = datasets.MNIST(DATA_ROOT, train=True, download=True, transform=transform)
    testdataset = datasets.MNIST(DATA_ROOT, train=False, transform=transform)
    from archs.mnist import LeNet5, fc1, vgg, resnet, AlexNet
# To support more datasets, add further `elif args.dataset == ...` branches here.
else:
    # exit() inside a notebook kernel kills/restarts the kernel; fail loudly instead.
    raise ValueError(f"Wrong Dataset choice: {args.dataset!r}")

train_loader = torch.utils.data.DataLoader(traindataset, batch_size=args.batch_size, shuffle=True,
                                           num_workers=0, drop_last=False)
# drop_last must be False for evaluation: with drop_last=True the final partial
# batch (40 of 10,000 MNIST test images at batch_size=60) was silently skipped.
test_loader = torch.utils.data.DataLoader(testdataset, batch_size=args.batch_size, shuffle=False,
                                          num_workers=0, drop_last=False)

# Importing Network Architecture
# (`global` is a no-op at module level, so it is not needed here.)
if args.arch_type == "fc1":
    model = fc1.fc1().to(device)
elif args.arch_type == "lenet5":
    model = LeNet5.LeNet5().to(device)
else:
    # exit() inside a notebook kernel kills/restarts the kernel; fail loudly instead.
    raise ValueError(f"Wrong Model choice: {args.arch_type!r}")


# Copying and Saving Initial State
# Deep copy of the untrained parameters, kept so weights can later be rewound to them.
initial_state_dict = copy.deepcopy(model.state_dict())
save_dir = f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/"
utils.checkdir(save_dir)
# NOTE(review): despite the file name, this pickles the whole `model` module,
# not `initial_state_dict` -- confirm what downstream loaders expect.
torch.save(model, f"{save_dir}initial_state_dict_{args.prune_type}.pth.tar")


# Optimizer and Loss
# Pass the configured learning rate: without `lr=` Adam silently used its default (1e-3)
# and args.lr (1.2e-3) was ignored.
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
# CrossEntropyLoss applies log-softmax itself, so the model should output raw logits.
criterion = nn.CrossEntropyLoss() # Default was F.nll_loss

# Layer Looper: list every trainable tensor together with its shape.
for name, param in model.named_parameters():
    print(name, param.shape)
    
    
# Pruning
# NOTE First Pruning Iteration is of No Compression
best_accuracy = 0
ITERATION = args.prune_iterations
# One slot per pruning round (presumably compression level and best accuracy per round).
comp = np.zeros(ITERATION, float)
bestacc = np.zeros(ITERATION, float)
step = 0
# One slot per training iteration within a round.
all_loss = np.zeros(args.end_iter, float)
all_accuracy = np.zeros(args.end_iter, float)


conv1.conv.weight torch.Size([64, 1, 3, 3])
conv1.conv.bias torch.Size([64])
BN2d_1.weight torch.Size([64])
BN2d_1.bias torch.Size([64])
conv2.conv.weight torch.Size([64, 64, 3, 3])
conv2.conv.bias torch.Size([64])
BN2d_2.weight torch.Size([64])
BN2d_2.bias torch.Size([64])
fc1.linear.weight torch.Size([256, 12544])
fc1.linear.bias torch.Size([256])
BN1d_1.weight torch.Size([256])
BN1d_1.bias torch.Size([256])
fc2.linear.weight torch.Size([256, 256])
fc2.linear.bias torch.Size([256])
BN1d_2.weight torch.Size([256])
BN1d_2.bias torch.Size([256])
fc3.linear.weight torch.Size([10, 256])
fc3.linear.bias torch.Size([10])


In [76]:
def prune_percentage_nonzero(q = 10):
    """Ask every prunable layer of the global ``model`` to prune itself by percentage.

    All submodules (root included, recursively) are visited; each one that is a
    ``PrunedConv`` or ``PruneLinear`` (presumably from ``prune_layer``) gets
    ``prune_by_percentage(q=q)`` called on it. The pruning rule itself lives in
    those layer classes.

    Args:
        q: pruning percentage, forwarded unchanged to each layer.
    """
    prunable_types = (PrunedConv, PruneLinear)
    for module in model.modules():
        # Test each type separately, so a module matching both would still be
        # pruned once per matching type.
        for layer_type in prunable_types:
            if isinstance(module, layer_type):
                module.prune_by_percentage(q = q)
            
def mask_weights(mask_data = True):
    """Multiply each pruned layer's mask into its weights or weight gradients, in place.

    Operates on the global ``model``. For every ``PrunedConv`` submodule, ``m.mask``
    is applied to ``m.conv.weight``; for every ``PruneLinear``, to ``m.linear.weight``.
    The mask is presumably 0/1, so masked entries become zero.

    Args:
        mask_data: if True (default), mask the weight values themselves; if False,
            mask the weight gradients instead. Weights whose ``grad`` is None (e.g.
            before the first ``backward()`` or after ``zero_grad(set_to_none=True)``)
            are skipped, since there is nothing to mask.
    """
    def _apply(weight, mask):
        if mask_data:
            weight.data.mul_(mask)
        elif weight.grad is not None:
            weight.grad.mul_(mask)

    for _, m in model.named_modules():
        if isinstance(m, PrunedConv):
            _apply(m.conv.weight, m.mask)
        if isinstance(m, PruneLinear):
            _apply(m.linear.weight, m.mask)
            
def _rewind_layer(layer, mask, state, prefix):
    """Copy ``mask * state[prefix + '.weight']`` and ``state[prefix + '.bias']`` into ``layer``."""
    with torch.no_grad():
        layer.weight.copy_(mask * state[prefix + '.weight'])
        # Restore the real bias parameter. Assigning to `weight.bias` only attaches
        # a stray attribute to the weight tensor and leaves the bias untouched.
        if layer.bias is not None:
            layer.bias.copy_(state[prefix + '.bias'])


def initialize_weights(initial_state_dict):
    """Rewind every pruned layer of the global ``model`` to a saved snapshot.

    For each ``PrunedConv`` / ``PruneLinear`` submodule named ``n``:

    * The weight is reset to ``m.mask * initial_state_dict[n + '.<layer>.weight']``,
      so pruned entries stay masked.
    * The bias, if the layer has one, is reset to ``initial_state_dict[n + '.<layer>.bias']``.

    Values are copied into the existing parameters, so later training does not
    write through to ``initial_state_dict``. Only these pruned layers are
    rewound; other parameters (e.g. the BatchNorm affine weights) keep their
    current values.

    Args:
        initial_state_dict: mapping in ``model.state_dict()`` key format, such
            as the deep copy taken right after the model was built.
    """
    for n, m in model.named_modules():
        if isinstance(m, PrunedConv):
            _rewind_layer(m.conv, m.mask, initial_state_dict, n + '.conv')
        if isinstance(m, PruneLinear):
            _rewind_layer(m.linear, m.mask, initial_state_dict, n + '.linear')
            
def reintilize_weights():
    """Re-randomise the global ``model`` and then re-apply the pruning masks.

    ``weight_init`` is not defined in this notebook section; it presumably comes
    from ``prune_layer`` or ``utils``. It is applied to every submodule through
    ``nn.Module.apply``. Afterwards ``mask_weights`` multiplies the masks back
    into the pruned layers' weights.
    """
    net = model
    net.apply(weight_init)
    mask_weights(mask_data=True)

In [63]:
# Snapshot the freshly built weights only if no snapshot exists yet. Re-running
# this cell after training would otherwise replace the rewind target with
# trained weights.
if "initial_state_dict" not in globals():
    initial_state_dict = copy.deepcopy(model.state_dict())

In [None]:
def _mask_pruned(net, grads):
    """Multiply each pruned layer's mask into its weight (or weight gradient) within ``net``."""
    for _, module in net.named_modules():
        layers = []
        if isinstance(module, PrunedConv):
            layers.append(module.conv)
        if isinstance(module, PruneLinear):
            layers.append(module.linear)
        for layer in layers:
            if not grads:
                layer.weight.data.mul_(module.mask)
            elif layer.weight.grad is not None:
                layer.weight.grad.mul_(module.mask)


# Function for Training
def train(model, train_loader, optimizer, criterion):
    """Train ``model`` for one pass over ``train_loader`` while keeping pruned weights masked.

    Each batch does the following:

    1. Clear gradients and multiply the masks into the weights.
    2. Run the forward pass and compute the loss, then backpropagate.
    3. Multiply the masks into the weight gradients, then take the optimizer step.

    Masking is applied to the ``model`` argument itself (not to a global).

    Args:
        model: network whose pruned layers are ``PrunedConv`` / ``PruneLinear``.
        train_loader: iterable of ``(inputs, targets)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss called as ``criterion(output, targets)``.

    Returns:
        float: loss of the *last* batch (not an epoch average).

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()
    train_loss = None
    for imgs, targets in train_loader:
        optimizer.zero_grad()
        # Zero pruned weights so they do not take part in this forward pass.
        _mask_pruned(model, grads=False)
        imgs, targets = imgs.to(device), targets.to(device)
        output = model(imgs)
        train_loss = criterion(output, targets)
        train_loss.backward()
        # Zero the gradients of pruned weights before the optimizer step.
        _mask_pruned(model, grads=True)
        optimizer.step()
    if train_loss is None:
        raise ValueError("train_loader yielded no batches")
    return train_loss.item()

# Function for Testing
def test(model, test_loader, criterion):

    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).sum().item()
        test_loss /= len(test_loader.dataset)
        accuracy = 100. * correct / len(test_loader.dataset)
    return accuracy