In [1]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune
import torch.onnx
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
def show(img):
    """Display a (C, H, W) image tensor with matplotlib.

    The tensor is detached and copied to the CPU before conversion, so images
    that live on `device` (e.g. slices of train_data on CUDA) or that require
    grad can be shown; plain `.numpy()` raises for both.
    """
    npimg = img.detach().cpu().numpy()
    # matplotlib expects channels last: (C, H, W) -> (H, W, C)
    plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')

In [2]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed every RNG the notebook draws from (Python `random` drives batch shuffling
# in iter_batch, torch drives weight initialisation) so Restart & Run All gives
# the same run. torch.manual_seed also seeds all CUDA devices; bit-exact GPU
# results may additionally need deterministic cuDNN settings.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# Create a directory if not exists (exist_ok avoids the check-then-create race)
sample_dir = 'samples'
os.makedirs(sample_dir, exist_ok=True)

In [3]:

data_dir = 'data'
# MNIST dataset. The `transform` is only applied when the dataset is indexed;
# this notebook instead reads the raw uint8 tensors and scales them to [0, 1].
dataset = torchvision.datasets.MNIST(root=data_dir,
                                     train=True,
                                     transform=transforms.ToTensor(),
                                     download=True)
# `.data` / `.targets` are the current attribute names; `train_data`,
# `train_labels`, `test_data` and `test_labels` are deprecated aliases that warn.
train_data = (dataset.data / 255.).to(device)
train_labels = dataset.targets.to(device)
testset = torchvision.datasets.MNIST(data_dir, train=False, download=True, transform=transforms.ToTensor())
test_data = (testset.data / 255.).to(device)
test_labels = testset.targets.to(device)

# Images and labels must pair up one-to-one for iter_batch / test to be correct.
assert train_data.shape[0] == train_labels.shape[0]
assert test_data.shape[0] == test_labels.shape[0]



In [4]:
# LeNet-style CNN for single-channel 28x28 digits. Every layer is bias-free, so the
# five weight tensors are all of the trainable parameters. After `prune_weights`,
# torch.nn.utils.prune reparametrises each `weight` as `weight_orig * weight_mask`,
# so pruned entries of `weight` stay zero and receive no gradient.
class Net(nn.Module):
    """Bias-free LeNet-5 variant: conv(1->6) -> conv(6->16) -> conv(16->120) -> fc(120->84) -> fc(84->10)."""

    def __init__(self):
        super().__init__()
        # Creation order matters: it fixes how the layers consume the RNG at init.
        self.conv1 = nn.Conv2d(1, 6, 5, bias=False)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5, bias=False)
        self.conv3 = nn.Conv2d(16, 120, 4, bias=False)
        self.fc1 = nn.Linear(120, 84, bias=False)
        self.fc2 = nn.Linear(84, 10, bias=False)

    def forward(self, x):
        """Return (N, 10) logits for an (N, 1, 28, 28) batch.

        Spatial size: 28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4 -conv4-> 1,
        which is why conv3's 120 channels can be flattened straight into fc1.
        """
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        flat = F.relu(self.conv3(x)).view(-1, 120)
        return self.fc2(F.relu(self.fc1(flat)))

    def prune_weights(self, amount):
        """Globally prune the smallest-magnitude weights across all five layers.

        Args:
            amount: passed to prune.global_unstructured (float = fraction,
                int = absolute count). Per PyTorch's pruning container, on
                repeated calls a float fraction applies to the weights that are
                still unpruned.
        """
        targets = tuple(
            (layer, 'weight')
            for layer in (self.conv1, self.conv2, self.conv3, self.fc1, self.fc2)
        )
        prune.global_unstructured(targets, pruning_method=prune.L1Unstructured, amount=amount)

In [5]:
import random
def iter_batch(L1, L2, batch_size, shuffle = False, drop_last=True):
    """Yield aligned mini-batches ``(L1[a:b, None], L2[a:b])``.

    ``None`` inserts a singleton axis after the batch axis (e.g. a channel axis
    turning (N, H, W) images into (B, 1, H, W)), so L1 must support
    tensor/ndarray-style indexing.

    Args:
        L1: inputs indexed along the first axis.
        L2: targets aligned with L1 along the first axis.
        batch_size: samples per batch (must be > 0).
        shuffle: visit the batches in a random order (uses the `random` module).
        drop_last: if True (the original behavior) a trailing partial batch is
            skipped; if False it is yielded as a smaller final batch.
    """
    n = len(L1)
    n_batches = n // batch_size if drop_last else (n + batch_size - 1) // batch_size
    order = list(range(n_batches))
    if shuffle:
        random.shuffle(order)
    for i in order:
        start = batch_size * i
        stop = start + batch_size
        yield L1[start:stop, None], L2[start:stop]

In [6]:
def test(model, test_data, test_labels):
    correct = 0
    total = 0
    with torch.no_grad():
        for data, labels in iter_batch(test_data, test_labels, 10):
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the %d test images: %f %%' % (
        len(test_data), 100 * correct / total))
    return 100 * correct / total

In [7]:
def n_params(model, layer_names=("conv1", "conv2", "conv3", "fc1", "fc2")):
    """Print, for each layer, how many weights are kept out of the layer's total.

    Args:
        model: object exposing the named layers as attributes (e.g. a `Net`).
        layer_names: attribute names of the layers to report.

    A layer that has been pruned carries a `weight_mask` buffer (1 = kept,
    0 = pruned). A layer without one has not been pruned yet, so all of its
    weights count as kept.
    """
    for name in layer_names:
        layer = getattr(model, name)
        mask = getattr(layer, "weight_mask", None)
        if mask is None:
            mask = torch.ones_like(layer.weight)
        print("Layer %s : " % name, torch.sum(mask).item(), " / ", mask.numel())

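Training with iterative pruning: each round trains the network, then prunes 20% of its remaining weights; a final low-learning-rate phase fine-tunes the sparse network.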

In [8]:
import torch.optim as optim

# Iterative magnitude pruning: train for `n_epochs`, then remove PRUNE_FRACTION of
# the weights that are still unpruned (one global L1 ranking over all layers, see
# Net.prune_weights), and repeat `n_prune` times. Afterwards the sparse network is
# fine-tuned for `final_epochs` with a lower learning rate.
PRUNE_FRACTION = 0.2
batch_size = 64
criterion = nn.CrossEntropyLoss()
net = Net().to(device)
# The optimizer keeps working across pruning rounds: pruning re-registers the same
# Parameter object under the name `weight_orig`.
optimizer = optim.SGD(net.parameters(), lr=1.5e-2, momentum=0.9, weight_decay=2e-3)
n_epochs = 3        # training epochs between consecutive pruning rounds
final_epochs = 15   # fine-tuning epochs after the last pruning round
n_prune = 18        # number of pruning rounds


def _train_one_epoch(model, opt, loss_fn, bs):
    """Run one shuffled pass over train_data/train_labels, stepping `opt` after each batch."""
    for inputs, labels in iter_batch(train_data, train_labels, bs, shuffle=True):
        opt.zero_grad()
        loss = loss_fn(model(inputs), labels)
        loss.backward()
        opt.step()


for i_prune in range(n_prune):
    for epoch in range(n_epochs):
        _train_one_epoch(net, optimizer, criterion, batch_size)
        test(net, test_data, test_labels)
    net.prune_weights(PRUNE_FRACTION)
    # Build the message from the fraction actually applied so the two cannot drift apart.
    print(f"Pruned {PRUNE_FRACTION:.0%} of weights")
    test(net, test_data, test_labels)

# Final training phase with low learning rate: plain SGD (no momentum, no weight
# decay) and smaller batches.
optimizer = optim.SGD(net.parameters(), lr=5e-3)
batch_size = 32
for epoch in range(final_epochs):
    _train_one_epoch(net, optimizer, criterion, batch_size)
    test(net, test_data, test_labels)

Accuracy of the network on the 10000 test images: 97.170000 %
Accuracy of the network on the 10000 test images: 97.940000 %
Accuracy of the network on the 10000 test images: 98.320000 %
Pruned 20% of weights
Accuracy of the network on the 10000 test images: 98.310000 %
Accuracy of the network on the 10000 test images: 98.850000 %
Accuracy of the network on the 10000 test images: 98.680000 %
Accuracy of the network on the 10000 test images: 98.360000 %
Pruned 20% of weights
Accuracy of the network on the 10000 test images: 98.370000 %
Accuracy of the network on the 10000 test images: 98.930000 %
Accuracy of the network on the 10000 test images: 98.680000 %
Accuracy of the network on the 10000 test images: 98.800000 %
Pruned 20% of weights
Accuracy of the network on the 10000 test images: 98.800000 %
Accuracy of the network on the 10000 test images: 98.590000 %
Accuracy of the network on the 10000 test images: 98.930000 %
Accuracy of the network on the 10000 test images: 98.540000 %
Prun

In [9]:
n_params(net)

Layer conv1 :  87.0  /  150
Layer conv2 :  199.0  /  2400
Layer conv3 :  246.0  /  30720
Layer fc1 :  171.0  /  10080
Layer fc2 :  93.0  /  840


In [10]:
# Trace the pruned network with one 1x1x28x28 example and save it as ONNX.
# The dummy input must live on the same device as `net` (cuda when available);
# a CPU input with a CUDA model makes tracing fail with a device mismatch.
dummy_input = torch.randn(1, 1, 28, 28, device=device)
torch.onnx.export(net,                       # model being run
                  dummy_input,               # model input (or a tuple for multiple inputs)
                  "net.onnx",                # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=10,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=['input'],     # the model's input names
                  output_names=['output'],   # the model's output names
                  verbose=True
                  )


graph(%input : Float(1:784, 1:784, 28:28, 28:1, requires_grad=0, device=cpu),
      %37 : Float(6:25, 1:25, 5:5, 5:1, requires_grad=0, device=cpu),
      %39 : Float(16:150, 6:25, 5:5, 5:1, requires_grad=0, device=cpu),
      %41 : Float(120:256, 16:16, 4:4, 4:1, requires_grad=0, device=cpu),
      %44 : Float(120:1, 84:120, requires_grad=0, device=cpu),
      %47 : Float(84:1, 10:84, requires_grad=0, device=cpu)):
  %13 : Float(1:3456, 6:576, 24:24, 24:1, requires_grad=1, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[5, 5], pads=[0, 0, 0, 0], strides=[1, 1]](%input, %37) # /home/timo/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py:419:0
  %14 : Float(1:3456, 6:576, 24:24, 24:1, requires_grad=1, device=cpu) = onnx::Relu(%13) # /home/timo/.local/lib/python3.8/site-packages/torch/nn/functional.py:1136:0
  %15 : Float(1:864, 6:144, 12:12, 12:1, requires_grad=1, device=cpu) = onnx::MaxPool[ceil_mode=0, kernel_shape=[2, 2], pads=[0, 0, 0, 0], strides=[2, 2]](

In [28]:
# Logits for the first test image. `test_data[None, 0:1]` has shape (1, 1, 28, 28).
# Calling `net` reuses Net.forward instead of re-spelling every layer here, so this
# check cannot drift from the model; no_grad skips building an autograd graph.
with torch.no_grad():
    sample_logits = net(test_data[None, 0:1])
print(sample_logits)

tensor([[-4.1751, -0.0403,  0.7811,  2.3847, -1.4277, -2.4714, -7.6809, 15.5948,
         -3.1924,  2.2708]], grad_fn=<MmBackward>)


In [38]:
import pickle

# Dump each layer's effective (pruned) weight tensor as a nested Python list,
# one pickle file per layer: training_data/<layer>_weights.pkl.
weights_dir = 'training_data'
os.makedirs(weights_dir, exist_ok=True)  # open(..., 'wb') fails if the directory is missing
for layer_name in ('conv1', 'conv2', 'conv3', 'fc1', 'fc2'):
    layer = getattr(net, layer_name)
    if hasattr(layer, 'weight_mask'):
        # On a pruned layer `weight` is only refreshed by the forward pre-hook;
        # recompute it from the trained parameter and the mask so the dump
        # never lags behind the latest optimizer step.
        weight = layer.weight_orig * layer.weight_mask
    else:
        weight = layer.weight
    with open(os.path.join(weights_dir, f'{layer_name}_weights.pkl'), 'wb') as pkl_file:
        pickle.dump(weight.detach().cpu().numpy().tolist(), pkl_file)