In [11]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
import torch.nn.utils.prune as prune
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

def show(img):
    """Display an image tensor with matplotlib.

    Args:
        img: Tensor whose axes are permuted with ``(1, 2, 0)`` before being
            handed to ``plt.imshow`` (i.e. a channel-first image is shown
            channel-last).
    """
    # detach() + cpu() let tensors that track gradients or live on the GPU be
    # converted; both calls are no-ops for a plain CPU tensor.
    npimg = img.detach().cpu().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')

In [2]:
# Device configuration: use CUDA when it is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Directory for generated samples. exist_ok=True makes the cell idempotent and
# avoids the check-then-create race of a separate os.path.exists() test.
sample_dir = 'samples'
os.makedirs(sample_dir, exist_ok=True)

In [3]:

data_dir = 'data'
# MNIST training split (downloaded into data_dir when missing).
dataset = torchvision.datasets.MNIST(root=data_dir,
                                     train=True,
                                     transform=transforms.ToTensor(),
                                     download=True)
# The image tensor is read directly from the dataset object (the `transform`
# above is not applied on this path), rescaled by 1/255 and moved to `device`
# in one piece. `.data` / `.targets` are the current attribute names; the
# former `train_data` / `train_labels` / `test_data` / `test_labels` aliases
# are deprecated.
train_data = (dataset.data/255.).to(device)
train_labels = dataset.targets.to(device)
# MNIST test split, prepared the same way.
testset = torchvision.datasets.MNIST(data_dir, train=False, download=True, transform=transforms.ToTensor())
test_data = (testset.data/255.).to(device)
test_labels = testset.targets.to(device)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100.1%

Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


113.5%

Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100.4%

Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


180.4%

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [31]:
#the network uses masks to avoid training some weights. 
class Net(nn.Module):
    """Bias-free convolutional classifier whose weights can be pruned globally.

    Three convolutions (the first two followed by 2x2 max-pooling) feed two
    linear layers producing 10 outputs. The layer sizes are chosen for
    1x28x28 inputs, for which ``conv3`` leaves a 120x1x1 feature map.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, bias=False)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5, bias=False)
        self.conv3 = nn.Conv2d(16, 120, 4, bias=False)
        self.fc1 = nn.Linear(120, 84, bias=False)
        self.fc2 = nn.Linear(84, 10, bias=False)

    def forward(self, x):
        """Return the 10 class scores for a batch of single-channel images.

        Args:
            x: Tensor of shape (batch, 1, 28, 28).

        Returns:
            Tensor of shape (batch, 10) with unnormalised scores.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not reduce to
                1x1 after ``conv3`` (the flattened width then differs from
                the 120 inputs ``fc1`` expects).
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        # Flatten per sample. Unlike view(-1, 120), this keeps the batch axis
        # intact, so a wrongly sized input fails in fc1 instead of silently
        # producing more output rows than there were input samples.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def prune_weights(self, amount):
        """Prune the weights of all five layers by global L1 magnitude.

        Args:
            amount: Quantity to prune, forwarded unchanged to
                ``prune.global_unstructured`` (a float fraction or an int
                count of parameters).
        """
        parameters_to_prune = (
            (self.conv1, 'weight'),
            (self.conv2, 'weight'),
            (self.conv3, 'weight'),
            (self.fc1, 'weight'),
            (self.fc2, 'weight'),
        )

        # Global pruning ranks the weights of all listed layers together, so
        # individual layers can end up with very different sparsities.
        prune.global_unstructured(
            parameters_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=amount,
        )

In [32]:
import random
def iter_batch(L1, L2, batch_size, shuffle = False, drop_last = True):
    """Yield aligned mini-batches sliced from two parallel containers.

    Args:
        L1: Sliceable, batch-first inputs supporting ``L1[a:b, None]``
            (a new axis is inserted at position 1 of every yielded slice).
        L2: Sliceable targets aligned with ``L1`` along the first axis.
        batch_size: Number of samples per batch.
        shuffle: If True, visit the batches in random order. Only the order
            of the batches changes; each batch always contains the same
            contiguous run of samples.
        drop_last: If True (the default, and the original behaviour), the
            trailing ``len(L1) % batch_size`` samples are skipped. If False,
            they are yielded as one final, smaller batch.

    Yields:
        Tuples ``(inputs, targets)`` for each batch.
    """
    n_batches = len(L1) // batch_size
    if not drop_last and len(L1) % batch_size:
        n_batches += 1
    order = list(range(n_batches))
    if shuffle:
        random.shuffle(order)
    for i in order:
        start = batch_size * i
        stop = start + batch_size
        # Slicing past the end is safe: the last batch is simply shorter.
        yield L1[start:stop, None], L2[start:stop]

In [33]:
def test(model, test_data, test_labels, batch_size=10):
    """Measure, print and return the classification accuracy of `model`.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        test_data: Inputs, batched through ``iter_batch``.
        test_labels: Integer class labels aligned with ``test_data``.
        batch_size: Evaluation batch size (defaults to the previously
            hard-coded value of 10).

    Returns:
        Accuracy in percent over the images that were actually evaluated.

    Raises:
        ValueError: If no batch could be evaluated (e.g. fewer samples than
            ``batch_size``), instead of an opaque ZeroDivisionError.
    """
    correct = 0
    total = 0
    # Evaluate in eval mode, then restore whatever mode the caller had so a
    # surrounding training loop is not affected.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for data, labels in iter_batch(test_data, test_labels, batch_size):
                outputs = model(data)
                _, predicted = torch.max(outputs, 1)

                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)

    if total == 0:
        raise ValueError('no complete batch of test samples to evaluate')

    accuracy = 100 * correct / total
    # Report the number of images actually evaluated, which can be smaller
    # than len(test_data) when a trailing partial batch is dropped.
    print('Accuracy of the network on the %d test images: %f %%' % (
        total, accuracy))
    return accuracy

In [34]:
def n_params(model):
    """Print, for each layer, the number of unpruned weights and the total.

    For a layer carrying a ``weight_mask`` the kept count is the sum of the
    mask and the total is the number of mask entries. A layer without a mask
    (never pruned) is reported with all of its weights kept instead of
    raising AttributeError.

    Args:
        model: Object exposing ``conv1``, ``conv2``, ``conv3``, ``fc1`` and
            ``fc2`` layers.
    """
    for name in ("conv1", "conv2", "conv3", "fc1", "fc2"):
        layer = getattr(model, name)
        mask = getattr(layer, "weight_mask", None)
        if mask is None:
            total = layer.weight.numel()
            kept = float(total)
        else:
            total = mask.numel()
            kept = torch.sum(mask).item()
        print("Layer %s : " % name, kept, " / ", total)

Initial training of the network, followed by iterative pruning rounds and a final fine-tuning phase

In [35]:
import torch.optim as optim


def train_one_epoch(model, optimizer, criterion, batch_size):
    """Run one shuffled pass of optimizer updates over the training set."""
    for inputs, labels in iter_batch(train_data, train_labels, batch_size, shuffle=True):
        optimizer.zero_grad()
        output = model(inputs)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()


batch_size = 64
criterion = nn.CrossEntropyLoss()
net = Net().to(device)
optimizer = optim.SGD(net.parameters(), lr=1.5e-2, momentum=0.9, weight_decay=2e-3)
n_epochs = 3        # training epochs before each pruning step
final_epochs = 30   # epochs of the final fine-tuning phase
n_prune = 18        # number of train-then-prune rounds
prune_amount = 0.2  # value passed to net.prune_weights() each round

# Iterative pruning: train for a few epochs, prune, and measure the accuracy
# right after pruning; the test accuracy is also printed after every epoch.
for i_prune in range(n_prune):
    for epoch in range(n_epochs):
        train_one_epoch(net, optimizer, criterion, batch_size)
        test(net, test_data, test_labels)
    net.prune_weights(prune_amount)
    print("Pruned {:.0%} of weights".format(prune_amount))
    test(net, test_data, test_labels)

# Final training phase with a low learning rate: a fresh plain-SGD optimizer
# (no momentum, no weight decay) and a smaller batch size.
optimizer = optim.SGD(net.parameters(), lr=5e-3)
batch_size = 32
for epoch in range(final_epochs):
    train_one_epoch(net, optimizer, criterion, batch_size)
    test(net, test_data, test_labels)

Accuracy of the network on the 10000 test images: 97.090000 %
Accuracy of the network on the 10000 test images: 98.200000 %
Accuracy of the network on the 10000 test images: 97.820000 %
Pruned 20% of weights
Accuracy of the network on the 10000 test images: 97.850000 %
Accuracy of the network on the 10000 test images: 98.580000 %
Accuracy of the network on the 10000 test images: 98.430000 %
Accuracy of the network on the 10000 test images: 98.790000 %
Pruned 20% of weights
Accuracy of the network on the 10000 test images: 98.770000 %
Accuracy of the network on the 10000 test images: 98.270000 %
Accuracy of the network on the 10000 test images: 98.270000 %
Accuracy of the network on the 10000 test images: 98.810000 %
Pruned 20% of weights
Accuracy of the network on the 10000 test images: 98.850000 %
Accuracy of the network on the 10000 test images: 98.980000 %
Accuracy of the network on the 10000 test images: 98.590000 %
Accuracy of the network on the 10000 test images: 98.650000 %
Prun

In [36]:
# Print, for each layer, how many weights survived pruning and the layer's total weight count.
n_params(net)

Layer conv1 :  68.0  /  150
Layer conv2 :  189.0  /  2400
Layer conv3 :  259.0  /  30720
Layer fc1 :  177.0  /  10080
Layer fc2 :  103.0  /  840


In [38]:
import pickle

# Export the current `weight` tensor of every layer as a nested Python list,
# one pickle file per layer: training_data/<layer>_weights.pkl.
weights_dir = 'training_data'
# open(..., 'wb') does not create missing directories, so make sure it exists.
os.makedirs(weights_dir, exist_ok=True)

for layer_name in ('conv1', 'conv2', 'conv3', 'fc1', 'fc2'):
    layer_weights = getattr(net, layer_name).weight.detach().cpu().numpy().tolist()
    # The context manager closes the file even if pickling raises.
    with open('%s/%s_weights.pkl' % (weights_dir, layer_name), 'wb') as pkl_file:
        pickle.dump(layer_weights, pkl_file)