In [190]:
import argparse
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np

class Layer():
    def __init__(self, input_dim, output_dim, idx):
        self.layer = nn.Linear(input_dim, output_dim)
        self.activations = None
        self.idx = idx

    def forward(self, x):
        self.activations = self.layer(x)
        return self.activations
    
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = Layer(320, 150, 0)
        self.fc2 = Layer(150, 150, 1)
        self.fc3 = Layer(150, 150, 2)
        self.fc4 = Layer(150, 150, 3)
        self.fc5 = Layer(150, 10, 4)
        
        self.prunable_layers = [self.fc1,self.fc2,self.fc3,self.fc4,self.fc5] 

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1.forward(x))
        x = F.relu(self.fc2.forward(x))
        x = F.relu(self.fc3.forward(x))
        x = F.relu(self.fc4.forward(x))
        x = self.fc5.forward(x)
        return F.log_softmax(x, dim=1)
    
    def remove_nodes(self, to_remove):
        for layer_index, nodes_to_remove in to_remove.items():
            n_remove = len(to_remove)
            layer = self.prunable_layers[layer_index]
            layer.layer.out_features -= n_remove
            
            # delete layer_index row in layer, and column in next layer
            np_weights = layer.layer.weight.detach().numpy()
            np_weights = np.delete(np_weights, nodes_to_remove, axis=0)
            layer.layer.weight = Parameter(torch.Tensor(np_weights))
            if(layer.layer.bias is not None):
                layer_weights = layer.layer.bias.detach().numpy()
                layer_weights = np.delete(layer_weights, nodes_to_remove)
                layer.layer.bias = Parameter(torch.Tensor(layer_weights))
            
            next_layer = self.prunable_layers[layer_index + 1]
            next_layer.layer.in_features -= 1
            np_weights = next_layer.layer.weight.detach().numpy()
            np_weights = np.delete(np_weights, nodes_to_remove, axis=1)
            next_layer.layer.weight = Parameter(torch.Tensor(np_weights))

        '''assert layer_index < 4
        layer = self.prunable_layers[layer_index]
        assert(node < self.prunable_layers[layer_index].layer.out_features )
        layer.layer.out_features -= 1
        # delete layer_index row in layer, and column in next layer
        np_weights = layer.layer.weight.detach().numpy()
        np_weights = np.delete(np_weights, layer_index, axis=0)
        layer.layer.weight = Parameter(torch.Tensor(np_weights))
        if(layer.layer.bias is not None):
            layer_weights = layer.layer.bias.detach().numpy()
            layer_weights = np.delete(layer_weights, layer_index)
            layer.layer.bias = Parameter(torch.Tensor(layer_weights))
        next_layer = self.prunable_layers[layer_index + 1]
        next_layer.layer.in_features -= 1
        np_weights = next_layer.layer.weight.detach().numpy()
        np_weights = np.delete(np_weights, layer_index, axis=1)
        next_layer.layer.weight = Parameter(torch.Tensor(np_weights))
        
        print(layer.layer.weight.shape)
        print(next_layer.layer.weight.shape)'''

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [191]:
torch.manual_seed(1)

device = torch.device("cpu")

kwargs = {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=64, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=64, shuffle=True, **kwargs)

model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

In [192]:
for epoch in range(10):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)

val_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=10000, shuffle=True, **kwargs)
test(model, device, val_loader)


Test set: Average loss: 2.3002, Accuracy: 1235/10000 (12%)




Test set: Average loss: 2.2967, Accuracy: 1073/10000 (11%)


Test set: Average loss: 2.2914, Accuracy: 957/10000 (10%)




Test set: Average loss: 2.2795, Accuracy: 1641/10000 (16%)


Test set: Average loss: 2.2622, Accuracy: 2160/10000 (22%)




Test set: Average loss: 2.2399, Accuracy: 2530/10000 (25%)




Test set: Average loss: 2.2139, Accuracy: 2576/10000 (26%)


Test set: Average loss: 2.1836, Accuracy: 2883/10000 (29%)




Test set: Average loss: 2.1518, Accuracy: 3246/10000 (32%)


Test set: Average loss: 2.1189, Accuracy: 3465/10000 (35%)


Test set: Average loss: 2.1189, Accuracy: 3465/10000 (35%)



In [195]:
activ_list = []
n_layers = len(model.prunable_layers) - 1
for layer_num in range(n_layers - 1):
    layer = model.prunable_layers[layer_num]
    activs = np.var(layer.activations.detach().numpy(), axis=0)
    activ_list.append(activs)
    
max_nodes = max(len(row) for row in activ_list)
activations = np.empty((n_layers, max_nodes))
for layer_num, activs in enumerate(activ_list):
    activations[layer_num,:len(activs)] += activs

print(activations.ravel())
# https://stackoverflow.com/questions/30577375/have-numpy-argsort-return-an-array-of-2d-indices
sorted = np.dstack(np.unravel_index(np.argsort(activations.ravel()), activations.shape))[0]
indices_to_remove = sorted[:50, :]
to_remove = {}
print(indices_to_remove, sorted)
for i in indices_to_remove:
    layer_num = i[0]
    node_num = i[1]
    if layer_num in to_remove:
        to_remove[layer_num].append(node_num)
    else:
        to_remove[layer_num] = [node_num]
model.remove_nodes(to_remove)

[2.71762665e+002 1.48338211e+002 1.09639091e+002 1.62244736e+002
 1.59837540e+002 1.48485992e+002 2.20620159e+161 1.96266954e+243
 2.07958837e+262 2.14435303e+002 9.85330622e+165 6.67105512e+135
 1.11566383e+002 1.01883614e+002 1.07956940e+002 1.02436569e+002
 1.04195221e+002 1.14269356e+002 7.74549484e+001 1.16329964e+002
 1.76577759e+002 3.54014587e+002 2.11535327e+257 1.46971008e+002
 9.89803615e+164 8.89115793e+179 1.83573525e+223 2.37697104e+137
 9.02193423e+217 9.10611877e+001 2.03292755e+002 6.65878601e+001
 1.29315989e+161 8.90567240e+252 1.78410919e+002 6.09079069e+247
 1.52447601e+002 1.34668121e+002 4.63461863e+228 1.98453918e+002
 2.44475250e+002 1.59993698e+002 1.61410559e+132 2.11161559e+257
 2.83837555e+002 2.88067937e+214 9.89803615e+164 1.86718964e+002
 8.34413757e+001 1.62074188e+002 1.99692490e+002 2.11680463e+257
 2.05326996e+002 1.58298477e+002 1.00385544e+218 1.73470917e+002
 2.13412571e+257 1.15026436e+002 1.59346725e+002 1.36068726e+002
 4.82594598e+276 9.898036



In [196]:
for layer in model.prunable_layers:
    print(layer.layer.weight.shape)
for epoch in range(5):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)

val_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=10000, shuffle=True, **kwargs)
test(model, device, val_loader)

torch.Size([150, 320])
torch.Size([150, 150])
torch.Size([150, 150])
torch.Size([65, 150])
torch.Size([10, 65])

Test set: Average loss: 2.1521, Accuracy: 2963/10000 (30%)




Test set: Average loss: 2.1292, Accuracy: 3213/10000 (32%)


Test set: Average loss: 2.1075, Accuracy: 3397/10000 (34%)




Test set: Average loss: 2.0888, Accuracy: 3551/10000 (36%)


Test set: Average loss: 2.0677, Accuracy: 3733/10000 (37%)


Test set: Average loss: 2.0677, Accuracy: 3733/10000 (37%)

