In [1]:
import os
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

def show(img):
    """Display a channels-first (C, H, W) image tensor with matplotlib.

    The tensor is detached and copied to host memory first, so it may live on
    any device and may require grad.

    Args:
        img: 3-D tensor laid out as (channels, height, width).
    """
    npimg = img.detach().cpu().numpy()
    # Move channels last: (C, H, W) -> (H, W, C), the layout imshow expects.
    npimg = np.transpose(npimg, (1, 2, 0))
    if npimg.shape[-1] == 1:
        # imshow documents (H, W), (H, W, 3) and (H, W, 4) inputs; squeeze a
        # single channel (e.g. MNIST digits) down to a plain 2-D array.
        npimg = npimg[..., 0]
    plt.imshow(npimg, interpolation='nearest')

In [2]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed every RNG the notebook relies on: `random` shuffles the batch order in
# iter_batch, torch draws the initial weights. Restart & Run All then
# reproduces the same training run.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# Create the sample output directory if it does not exist yet; exist_ok avoids
# the check-then-create race of os.path.exists + os.makedirs.
sample_dir = 'samples'
os.makedirs(sample_dir, exist_ok=True)

In [4]:

data_dir = 'data'
# MNIST train/test splits. The ToTensor transform only applies when indexing
# the dataset object; below the raw tensors are read directly instead.
dataset = torchvision.datasets.MNIST(root=data_dir,
                                     train=True,
                                     transform=transforms.ToTensor(),
                                     download=True)
testset = torchvision.datasets.MNIST(data_dir, train=False, download=True, transform=transforms.ToTensor())

# `.data` / `.targets` replace the deprecated `train_data` / `train_labels` /
# `test_data` / `test_labels` aliases. The explicit `.float()` guarantees a
# true floating-point division when scaling the pixels to [0, 1].
train_data = (dataset.data.float() / 255.).to(device)
train_labels = dataset.targets.to(device)
test_data = (testset.data.float() / 255.).to(device)
test_labels = testset.targets.to(device)



In [5]:
class Net(nn.Module):
    """LeNet-5-like CNN (bias-free layers) whose weights can be magnitude-pruned.

    Each layer named in ``_PRUNABLE`` has a companion mask buffer
    ``mask_<layer>`` with the same shape as its weight. ``forward`` multiplies
    every weight by its mask before use. Pruned entries are therefore zero in
    every forward pass, even though the optimizer may move them between passes.

    With 1x28x28 inputs the conv stack reduces the spatial size to 1x1:
    28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4 -conv4-> 1.
    This yields the 120 features fc1 expects.
    """

    # Layers whose weights carry a pruning mask, in forward order.
    _PRUNABLE = ("conv1", "conv2", "conv3", "fc1", "fc2")

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, bias=False)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5, bias=False)
        self.conv3 = nn.Conv2d(16, 120, 4, bias=False)
        self.fc1 = nn.Linear(120, 84, bias=False)
        self.fc2 = nn.Linear(84, 10, bias=False)
        # Registered buffers (not plain attributes), so `.to(device)` moves
        # the masks with the model and `state_dict()` saves them.
        for name in self._PRUNABLE:
            self.register_buffer("mask_" + name,
                                 torch.ones_like(getattr(self, name).weight))

    def apply_masks(self):
        """Multiply every prunable weight by its mask, in place, without autograd."""
        with torch.no_grad():
            for name in self._PRUNABLE:
                getattr(self, name).weight.mul_(getattr(self, "mask_" + name))

    def forward(self, x):
        """Return class logits of shape (N, 10) for an (N, 1, 28, 28) batch."""
        # The last optimizer step may have made pruned weights non-zero again.
        self.apply_masks()
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        # (N, 120, 1, 1) -> (N, 120). Unlike view(-1, 120), a wrong input size
        # fails loudly in fc1 instead of silently merging samples.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

    def prune_weights(self, treshold):
        """Keep only weights whose magnitude is strictly above ``treshold``.

        Masks are recomputed from the current weights, so anything already
        zeroed stays pruned. The new masks are applied immediately.

        Args:
            treshold: absolute magnitude cut-off (same for every layer).
        """
        for name in self._PRUNABLE:
            weight = getattr(self, name).weight
            # Float 0/1 mask, matching the dtype of the initial all-ones mask.
            new_mask = (weight.detach().abs() > treshold).to(weight.dtype)
            setattr(self, "mask_" + name, new_mask)
        self.apply_masks()

In [6]:
import random
def iter_batch(L1, L2, batch_size, shuffle = False, drop_last = True):
    """Yield aligned mini-batches ``(L1[chunk, None], L2[chunk])``.

    ``None`` indexing inserts a new axis at position 1 of each ``L1`` batch.
    For example, (N, H, W) images become (B, 1, H, W) batches for a Conv2d.

    Args:
        L1, L2: sliceable sequences (tensors here) of equal length.
        batch_size: number of samples per batch.
        shuffle: visit the batches in random order (uses the ``random``
            module). Batches are fixed contiguous chunks; only their order is
            shuffled, never the samples within them.
        drop_last: if True (the original behaviour), a trailing partial batch
            is discarded. Set False to yield it as well.
    """
    n_full, remainder = divmod(len(L1), batch_size)
    n_batches = n_full + (1 if remainder and not drop_last else 0)
    order = list(range(n_batches))
    if shuffle:
        random.shuffle(order)
    for i in order:
        chunk = slice(batch_size * i, batch_size * (i + 1))
        yield L1[chunk, None], L2[chunk]

In [7]:
def test(model, test_data, test_labels, batch_size=10):
    """Print and return the classification accuracy of ``model``, in percent.

    Batches come from ``iter_batch``, which drops a trailing partial batch.
    The accuracy and the reported image count are therefore based on the
    samples actually scored.

    Args:
        model: callable mapping a batch to per-class scores of shape (B, C).
        test_data, test_labels: aligned evaluation inputs and integer labels.
        batch_size: evaluation batch size (default 10, as before).

    Returns:
        Accuracy as a percentage in [0, 100].

    Raises:
        ValueError: if not a single batch could be evaluated.
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for data, labels in iter_batch(test_data, test_labels, batch_size):
            predicted = model(data).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("no test batches evaluated: test set is empty "
                         "or smaller than batch_size")
    accuracy = 100 * correct / total
    print('Accuracy of the network on the %d test images: %f %%' % (
        total, accuracy))
    return accuracy

In [9]:
def n_params(model):
    """Print, for each prunable layer, how many weights survive pruning.

    Each line reads ``Layer <name> :  <kept>  /  <total>``. ``kept`` is the
    sum of the layer's 0/1 mask ``model.mask_<name>`` and ``total`` is the
    number of entries in that mask.
    """
    for name in ("conv1", "conv2", "conv3", "fc1", "fc2"):
        mask = getattr(model, "mask_" + name)
        # int() so float and integer masks print identically (no "68.0").
        kept = int(mask.sum().item())
        print("Layer " + name + " : ", kept, " / ", mask.numel())

Training with iterative magnitude pruning (train → prune, repeated), followed by fine-tuning of the surviving weights

In [10]:
import torch.optim as optim
batch_size = 64
criterion = nn.CrossEntropyLoss()
net = Net().to(device)

n_epochs = 5        # training epochs per pruning round
n_prune = 5         # number of train -> prune rounds
prune_step = 0.05   # round k (0-based) prunes weights with |w| <= prune_step * (k + 1)
final_epochs = 5    # fine-tuning epochs after the last pruning round
lr_prune = 0.015    # learning rate during the pruning rounds
lr_finetune = 0.005 # learning rate for the final fine-tuning


def train_epoch(model, optimizer, criterion, data, labels, batch_size):
    """Run one pass of gradient descent over ``(data, labels)`` in shuffled mini-batches."""
    for inputs, targets in iter_batch(data, labels, batch_size, shuffle=True):
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()


optimizer = optim.SGD(net.parameters(), lr=lr_prune)
for prune in range(n_prune):
    for epoch in range(n_epochs):
        train_epoch(net, optimizer, criterion, train_data, train_labels, batch_size)
        test(net, test_data, test_labels)
    net.prune_weights(prune_step * (prune + 1))
    print("Pruned")
    test(net, test_data, test_labels)

# Fine-tune the surviving weights with a smaller learning rate.
optimizer = optim.SGD(net.parameters(), lr=lr_finetune)
for epoch in range(final_epochs):
    train_epoch(net, optimizer, criterion, train_data, train_labels, batch_size)
    test(net, test_data, test_labels)

Accuracy of the network on the 10000 test images: 84.350000 %
Accuracy of the network on the 10000 test images: 93.990000 %
Accuracy of the network on the 10000 test images: 96.230000 %
Accuracy of the network on the 10000 test images: 97.320000 %
Accuracy of the network on the 10000 test images: 97.570000 %
Pruned
Accuracy of the network on the 10000 test images: 96.910000 %
Accuracy of the network on the 10000 test images: 97.440000 %
Accuracy of the network on the 10000 test images: 97.880000 %
Accuracy of the network on the 10000 test images: 98.040000 %
Accuracy of the network on the 10000 test images: 97.880000 %
Accuracy of the network on the 10000 test images: 98.160000 %
Pruned
Accuracy of the network on the 10000 test images: 53.390000 %
Accuracy of the network on the 10000 test images: 96.770000 %
Accuracy of the network on the 10000 test images: 97.130000 %
Accuracy of the network on the 10000 test images: 97.470000 %
Accuracy of the network on the 10000 test images: 97.570

In [11]:
# Second fine-tuning pass on the pruned network, at an even smaller learning rate.
optimizer = optim.SGD(net.parameters(), lr=0.001)
for epoch in range(final_epochs):
    batches = iter_batch(train_data, train_labels, batch_size, shuffle=True)
    for inputs, labels in batches:
        optimizer.zero_grad()
        output = net(inputs)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
    # Report test accuracy once per epoch.
    test(net, test_data, test_labels)

Accuracy of the network on the 10000 test images: 96.590000 %
Accuracy of the network on the 10000 test images: 96.650000 %
Accuracy of the network on the 10000 test images: 96.620000 %


KeyboardInterrupt: 

In [12]:
# Surviving (unmasked) weights per layer after pruning and fine-tuning.
n_params(model=net)

Layer conv1 :  68  /  150
Layer conv2 :  128  /  2400
Layer conv3 :  119  /  30720
Layer fc1 :  120  /  10080
Layer fc2 :  143  /  840


In [24]:
# Re-zero the pruned weights. The last optimizer.step() may have moved them
# away from zero, and masks are only re-applied at the start of a forward pass.
with torch.no_grad():
    for layer_name in ("conv1", "conv2", "conv3", "fc1", "fc2"):
        layer_weight = getattr(net, layer_name).weight
        layer_weight.mul_(getattr(net, "mask_" + layer_name))

In [25]:
# Manually replay the network's forward pass on the first test image to
# inspect the fc1 activations that feed the final layer. This does not
# re-apply the pruning masks itself; it relies on the previous cell having
# done so. Inspection only, so no autograd graph is recorded.
with torch.no_grad():
    x = test_data[0:1, None]             # (1, 1, 28, 28)
    x = net.pool(F.relu(net.conv1(x)))
    x = net.pool(F.relu(net.conv2(x)))
    x = F.relu(net.conv3(x))
    x = torch.flatten(x, 1)              # (1, 120)
    x = F.relu(net.fc1(x))
    print(x)
    x = net.fc2(x)                       # logits, (1, 10)

tensor([[ 0.0000,  0.0000,  0.0000,  7.4778,  0.0000,  0.0000,  0.0000,  0.0000,
          7.4172,  0.0000,  0.0000,  0.0000,  0.0000,  2.1894,  0.0000,  0.0000,
          0.0000,  6.7635,  0.0000,  6.0280,  3.2683,  0.0000,  0.0000,  0.0000,
          3.1382,  4.8983,  7.2235,  2.4603,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  4.5126,  0.0000,  4.9029,  4.9477,  0.0000,  0.0000, 14.8545,
          0.0000,  2.9681,  0.0000,  0.0000,  0.0000,  0.0000,  1.6912,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.2355,  3.4194,  0.0000,  7.8199,
          0.0000,  1.0242,  1.9522,  0.4231,  0.0000,  0.0000,  0.0000,  0.0000,
          1.7307,  6.8453,  1.4306,  0.7382,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  3.8030,  0.0000,  5.7349,  8.0474,  0.0000,
          3.5094,  0.0000,  6.5486,  0.0000,  5.8241,  0.7721,  9.2738,  0.0000,
          4.8042,  0.0000,  

In [17]:
import pickle

# Export each layer's weight tensor as nested Python lists, one pickle per
# layer: conv1_weights.pkl, conv2_weights.pkl, conv3_weights.pkl,
# fc1_weights.pkl, fc2_weights.pkl. Pruned entries are only zero if the masks
# were re-applied after the last optimizer step.
for layer_name in ("conv1", "conv2", "conv3", "fc1", "fc2"):
    weights = getattr(net, layer_name).weight.detach().cpu().numpy().tolist()
    # `with` closes the file even if pickling fails part-way.
    with open(layer_name + "_weights.pkl", "wb") as pkl_file:
        pickle.dump(weights, pkl_file)

In [21]:
# Inspect conv2's kernels after pruning; zeroed entries print as 0.0000 / -0.0000.
conv2_kernels = net.conv2.weight
print(conv2_kernels)

Parameter containing:
tensor([[[[-0.0000,  0.0000, -0.3146,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000,  0.0000,  0.4201],
          [-0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000, -0.2201, -0.0000]],

         [[ 0.0000, -0.5751,  0.0000, -0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000, -0.3586, -0.0000],
          [-0.0000,  0.3859,  0.0000,  0.2977, -0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.3387],
          [-0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

         [[-0.0000, -0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.3560,  0.4044,  0.0000],
          [-0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [-0.0000, -0.0000,  0.0000,  0.0000,  0.0000],
          [-0.0000, -0.0000,  0.0000,  0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000, -0.0000, -0.3106],
          [-0.0000,  0.0000,  0.0000, -0.0000,  0.0000],
   