# The Lottery Ticket Hypothesis
##### Due to Jonathan Frankle and Michael Carbin (2018)

Let's begin with a feedforward neural network that recognizes handwritten digits.

In [1]:
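# Some setup: `digit` is this notebook's local helper module. Its star import
# is expected to provide torch, torchvision, transforms, nn, F, and the
# train_network / test_network / prune_network helpers used below.
from digit import *
from IPython.display import Image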

# Define the transformations for the data
transform = transforms.Compose([transforms.ToTensor(),
                                 transforms.Normalize((0.5,), (0.5,))])

# Load the MNIST dataset
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)

We'll define a `Net` class with `__init__()` and `forward()` methods. It maps 28x28-pixel images to 10 outputs, one for each of the digits 0-9.

In [2]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

Nothing new so far. Let's initialize a network, and for a secret reason, save the initial state. Now we can train and test the net. 

In [3]:
net = Net()
torch.save(net.state_dict(), 'init.pt')

train_network(net, trainloader, None, 4)
test_network(net, testloader)

[1,  1000] loss: 0.460
[2,  1000] loss: 0.162
[3,  1000] loss: 0.121
[4,  1000] loss: 0.096
Accuracy of the network on the 10000 test images: 96 %


Amazing! The network correctly identifies 96% of the test images after just four epochs.


## Pruning

You may have heard of the concept of "pruning" a neural network. In general, pruning is the removal of select weights or neurons from a network to save on space and computation. For our purposes, just think of it as zeroing out some percentage of the __smallest-magnitude weights__ in a network. In their 1989 paper _"Optimal Brain Damage,"_ Yann LeCun, John Denker, and Sara Solla claim that "By removing unimportant weights from a network, several improvements can be expected: better generalization, fewer training examples required, and improved speed of learning and/or classification."
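The `prune_network` helper used in this notebook comes from the local `digit` module and isn't shown here. As a rough idea of what layer-wise magnitude pruning can look like, here is a minimal sketch. The name `magnitude_prune` and the mask format (a dict of 0/1 tensors) are assumptions for illustration, not the notebook's actual implementation.

```python
import torch
import torch.nn as nn


def magnitude_prune(model: nn.Module, fraction: float) -> dict:
    """Zero the smallest-magnitude `fraction` of each weight matrix in place.

    Returns a dict mapping parameter names to 0/1 masks (1 = weight kept).
    Hypothetical stand-in for the notebook's `prune_network` helper.
    """
    masks = {}
    with torch.no_grad():
        for name, param in model.named_parameters():
            if not name.endswith("weight"):
                continue  # leave biases untouched
            k = int(fraction * param.numel())  # number of weights to zero
            mask = torch.ones(param.numel(), dtype=param.dtype)
            if k > 0:
                # Indices of the k entries with the smallest absolute value
                _, idx = torch.topk(param.abs().flatten(), k, largest=False)
                mask[idx] = 0.0
            mask = mask.view_as(param)
            param.mul_(mask)
            masks[name] = mask
    return masks


torch.manual_seed(0)
layer = nn.Linear(10, 10)  # 100 weights
masks = magnitude_prune(layer, 0.4)
zeros = int((layer.weight == 0).sum())
print(f"{zeros} of {layer.weight.numel()} weights are now zero")
```

Pruning each layer separately, as this sketch does, matches the per-layer "weights to prune" counts printed by the notebook's helper; a global threshold across all layers would be another reasonable design.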

Let's give our own little net some brain damage!

In [4]:
# set bottom 40% of weights to 0
pruning_percent = 0.4
p40_pruning_mask, net = prune_network(net, pruning_percent)

print("Sparsity in model: {:.2f}%".format(100 * float(torch.sum(net.fc1.weight == 0) + torch.sum(net.fc2.weight == 0) + torch.sum(net.fc3.weight == 0)) / float(net.fc1.weight.nelement() + net.fc2.weight.nelement() + net.fc3.weight.nelement())))
test_network(net, testloader)


Pruning layer fc1.weight with 40140 weights to prune.
Pruning layer fc2.weight with 3276 weights to prune.
Pruning layer fc3.weight with 256 weights to prune.
Sparsity in model: 40.00%
Accuracy of the network on the 10000 test images: 96 %


Look at that! Cutting 40% of the weights seems not to have hurt our accuracy. How far can this idea be pushed? 

In [5]:
# set bottom 60% of weights to 0
pruning_percent = 0.6
pruning_mask, net = prune_network(net, pruning_percent)

print("Sparsity in model: {:.2f}%".format(100 * float(torch.sum(net.fc1.weight == 0) + torch.sum(net.fc2.weight == 0) + torch.sum(net.fc3.weight == 0)) / float(net.fc1.weight.nelement() + net.fc2.weight.nelement() + net.fc3.weight.nelement())))
test_network(net, testloader)

Pruning layer fc1.weight with 60211 weights to prune.
Pruning layer fc2.weight with 4915 weights to prune.
Pruning layer fc3.weight with 384 weights to prune.
Sparsity in model: 60.00%
Accuracy of the network on the 10000 test images: 96 %
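The long sparsity printout repeated in these cells could be factored into a small helper. This is only a sketch: `sparsity_percent` is a hypothetical name, and it assumes that every parameter whose name ends in `weight` should be counted while biases are ignored.

```python
import torch
import torch.nn as nn


def sparsity_percent(model: nn.Module) -> float:
    """Percentage of exactly-zero entries across all weight matrices (biases ignored)."""
    zeros, total = 0, 0
    for name, param in model.named_parameters():
        if name.endswith("weight"):
            zeros += int((param == 0).sum())
            total += param.numel()
    return 100.0 * zeros / total


# Tiny demonstration on a stand-in layer: zero out one of four rows by hand.
torch.manual_seed(0)
demo = nn.Linear(8, 4)
with torch.no_grad():
    demo.weight[0].zero_()
print("Sparsity in model: {:.2f}%".format(sparsity_percent(demo)))
```

With a helper like this, each cell's printout would shrink to `print("Sparsity in model: {:.2f}%".format(sparsity_percent(net)))`.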


In [6]:
# set bottom 80% of weights to 0
pruning_percent = 0.8
pruning_mask, net = prune_network(net, pruning_percent)

print("Sparsity in model: {:.2f}%".format(100 * float(torch.sum(net.fc1.weight == 0) + torch.sum(net.fc2.weight == 0) + torch.sum(net.fc3.weight == 0)) / float(net.fc1.weight.nelement() + net.fc2.weight.nelement() + net.fc3.weight.nelement())))
test_network(net, testloader)


Pruning layer fc1.weight with 80281 weights to prune.
Pruning layer fc2.weight with 6553 weights to prune.
Pruning layer fc3.weight with 512 weights to prune.
Sparsity in model: 80.00%
Accuracy of the network on the 10000 test images: 85 %


Ok, it's starting to suffer a little... One more go:

In [7]:
# set bottom 95% of weights to 0
pruning_percent = 0.95
pruning_mask, net = prune_network(net, pruning_percent)

print("Sparsity in model: {:.2f}%".format(100 * float(torch.sum(net.fc1.weight == 0) + torch.sum(net.fc2.weight == 0) + torch.sum(net.fc3.weight == 0)) / float(net.fc1.weight.nelement() + net.fc2.weight.nelement() + net.fc3.weight.nelement())))
test_network(net, testloader)


Pruning layer fc1.weight with 95334 weights to prune.
Pruning layer fc2.weight with 7782 weights to prune.
Pruning layer fc3.weight with 608 weights to prune.
Sparsity in model: 95.00%
Accuracy of the network on the 10000 test images: 21 %


Alas, our experiment with pruning has come to an end. Without any retraining, the practical limit (in this case) seems to lie somewhere between 60% and 80% sparsity, which is still impressive.

In [19]:
Image(url= "https://media.makeameme.org/created/not-so-fast-e35e21f418.jpg")

Instead of just pruning the network, let's keep the 40% pruning mask, reset the surviving weights to their original initialization values, and retrain...

In [8]:
# Reset
net = Net()
net.load_state_dict(torch.load('init.pt'))

# Train with 40% removed
train_network(net, trainloader, p40_pruning_mask, 4)

print("Sparsity in model: {:.2f}%".format(100 * float(torch.sum(net.fc1.weight == 0) + torch.sum(net.fc2.weight == 0) + torch.sum(net.fc3.weight == 0)) / float(net.fc1.weight.nelement() + net.fc2.weight.nelement() + net.fc3.weight.nelement())))
test_network(net, testloader)

[1,  1000] loss: 0.473
[2,  1000] loss: 0.146
[3,  1000] loss: 0.108
[4,  1000] loss: 0.087
Sparsity in model: 40.00%
Accuracy of the network on the 10000 test images: 97 %
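How does training respect the mask? The `train_network` helper isn't shown in this notebook, but one common approach is to re-apply the mask after every optimizer step so that pruned weights stay pinned at zero. Below is a minimal sketch of that idea on a toy layer with random data. `masked_sgd_step`, the mask dict, and the toy data are all assumptions for illustration.

```python
import torch
import torch.nn as nn


def masked_sgd_step(model, masks, inputs, targets, optimizer, criterion):
    """One optimizer step that keeps pruned weights pinned at zero."""
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
    # Re-apply the masks so pruned weights return to exactly zero
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in masks:
                param.mul_(masks[name])
    return loss.item()


torch.manual_seed(0)
model = nn.Linear(6, 3)

# Hypothetical mask: keep every other weight
mask = torch.zeros_like(model.weight)
mask.view(-1)[::2] = 1.0
masks = {"weight": mask}
with torch.no_grad():
    model.weight.mul_(mask)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
x = torch.randn(16, 6)             # toy inputs
y = torch.randint(0, 3, (16,))     # toy labels
for _ in range(5):
    masked_sgd_step(model, masks, x, y, optimizer, criterion)

pruned_still_zero = bool((model.weight[mask == 0] == 0).all())
print("pruned weights still zero:", pruned_still_zero)
```

An alternative design is to multiply the gradients by the mask before the optimizer step; re-applying the mask to the weights is simply the shortest version to write down.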


Let's take another 40% off the remaining weights, for a cumulative 64% pruned.

In [9]:
pruning_percent = 1 - 0.6**2

pruning_mask, net = prune_network(net, pruning_percent)

# Reset
net = Net()
net.load_state_dict(torch.load('init.pt'))

# Train with 64% removed
train_network(net, trainloader, pruning_mask, 4)

print("Sparsity in model: {:.2f}%".format(100 * float(torch.sum(net.fc1.weight == 0) + torch.sum(net.fc2.weight == 0) + torch.sum(net.fc3.weight == 0)) / float(net.fc1.weight.nelement() + net.fc2.weight.nelement() + net.fc3.weight.nelement())))
test_network(net, testloader)

Pruning layer fc1.weight with 64225 weights to prune.
Pruning layer fc2.weight with 5242 weights to prune.
Pruning layer fc3.weight with 409 weights to prune.
[1,  1000] loss: 0.533
[2,  1000] loss: 0.157
[3,  1000] loss: 0.114
[4,  1000] loss: 0.100
Sparsity in model: 64.00%
Accuracy of the network on the 10000 test images: 96 %


### Again!

In [10]:
pruning_percent = 1 - 0.6**3

pruning_mask, net = prune_network(net, pruning_percent)

# Reset
net = Net()
net.load_state_dict(torch.load('init.pt'))

# Train with 78.4% removed
train_network(net, trainloader, pruning_mask, 4)

print("Sparsity in model: {:.2f}%".format(100 * float(torch.sum(net.fc1.weight == 0) + torch.sum(net.fc2.weight == 0) + torch.sum(net.fc3.weight == 0)) / float(net.fc1.weight.nelement() + net.fc2.weight.nelement() + net.fc3.weight.nelement())))
test_network(net, testloader)

Pruning layer fc1.weight with 78675 weights to prune.
Pruning layer fc2.weight with 6422 weights to prune.
Pruning layer fc3.weight with 501 weights to prune.
[1,  1000] loss: 0.596
[2,  1000] loss: 0.190
[3,  1000] loss: 0.138
[4,  1000] loss: 0.114
Sparsity in model: 78.40%
Accuracy of the network on the 10000 test images: 96 %


### Go to 95%!

In [11]:
pruning_percent = 1 - 0.6**6 

pruning_mask, net = prune_network(net, pruning_percent)

# Reset
net = Net()
net.load_state_dict(torch.load('init.pt'))

# Train with 95.33% removed
train_network(net, trainloader, pruning_mask, 4)

print("Sparsity in model: {:.2f}%".format(100 * float(torch.sum(net.fc1.weight == 0) + torch.sum(net.fc2.weight == 0) + torch.sum(net.fc3.weight == 0)) / float(net.fc1.weight.nelement() + net.fc2.weight.nelement() + net.fc3.weight.nelement())))
test_network(net, testloader)

Pruning layer fc1.weight with 95669 weights to prune.
Pruning layer fc2.weight with 7809 weights to prune.
Pruning layer fc3.weight with 610 weights to prune.
[1,  1000] loss: 1.028
[2,  1000] loss: 0.227
[3,  1000] loss: 0.177
[4,  1000] loss: 0.155
Sparsity in model: 95.33%
Accuracy of the network on the 10000 test images: 95 %
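Where do pruning percentages like `1 - 0.6**3` come from? Each round removes 40% of the weights that are still alive, so 60% survive each round and the cumulative sparsity after `n` rounds is `1 - 0.6**n`. A quick check of the values used above (the variable names here are just for illustration):

```python
# Cumulative sparsity after n rounds when each round removes 40% of the
# weights that are still alive (so 60% survive each round).
keep_per_round = 0.6

schedule = {n: 1 - keep_per_round ** n for n in (1, 2, 3, 6)}
for n, sparsity in schedule.items():
    print(f"after {n} round(s): {100 * sparsity:.2f}% pruned")

# after 1 round(s): 40.00% pruned
# after 2 round(s): 64.00% pruned
# after 3 round(s): 78.40% pruned
# after 6 round(s): 95.33% pruned
```

Note that the last cell jumps straight from exponent 3 to exponent 6, so it prunes the 78.4%-sparse network to 95.33% in a single step rather than running rounds 4 and 5 separately.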


## The Lottery Ticket Hypothesis

... is the idea that when a dense neural network is successfully trained, it succeeds by virtue of having found a subnetwork which performs the necessary operations. This subnetwork is usually a fraction of the size of the overall dense neural network, and if the weights outside of the subnetwork are removed, the subnetwork can still perform quite well.

This **sparse** subnetwork (with 5-20% of the weights of the full network) can be formed out of the initialization weights and trained on its own, often more quickly and sometimes to a higher degree of accuracy on the test data than the overall model.

In [3]:
Image(url = "https://external-content.duckduckgo.com/iu/?u=https%3A%2F%2Falberton.info%2Fimages%2Farticles%2Fgraphs%2Fgraphs_complete_sparse.png")

### Initialization matters
When the same sparse networks are used but new initialization weights are chosen, the models take much longer to train. This is because the initialization weights are part of what determines the structure of the subnetwork; they're poised to get to where they need to go in response to feedback from the optimizer.
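To make the comparison concrete, here is a sketch of how the two starting points differ: a "ticket" that pairs the mask with the saved initial weights, and a control that pairs the same mask with fresh random weights. Everything here is illustrative. The toy layer, the random mask (a real mask would come from pruning a trained network), and the `apply_masks` helper are assumptions, and `init_state` merely plays the role of `init.pt`.

```python
import copy
import torch
import torch.nn as nn


def apply_masks(model, masks):
    """Zero out pruned weights in place according to 0/1 masks."""
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in masks:
                param.mul_(masks[name])
    return model


torch.manual_seed(0)
original = nn.Linear(20, 5)
init_state = copy.deepcopy(original.state_dict())  # plays the role of init.pt

# Hypothetical mask keeping roughly 20% of the weights
masks = {"weight": (torch.rand_like(original.weight) < 0.2).float()}

# Winning ticket: same mask, ORIGINAL initial values
ticket = nn.Linear(20, 5)
ticket.load_state_dict(init_state)
apply_masks(ticket, masks)

# Control: same mask, FRESH random initial values
control = apply_masks(nn.Linear(20, 5), masks)

same_structure = torch.equal(ticket.weight == 0, control.weight == 0)
same_values = torch.equal(ticket.weight, control.weight)
print("same sparsity pattern:", same_structure)
print("same starting values:", same_values)
```

Both networks have an identical sparsity pattern; only the values of the surviving weights differ, and that is the variable the reinitialization experiment isolates.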

### There's a catch!

PyTorch's standard layers don't take advantage of this kind of unstructured sparsity, so we represent the pruned network with dense tensors that hold zeros wherever a weight is missing. This takes *the same amount of time and energy* to train and run as the full dense network, so ... while **in theory** the lottery ticket hypothesis implies better time and energy efficiency for neural networks, mainstream tools for building neural networks don't yet make it easy to take advantage of this.
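A quick way to see the storage side of this: zeroing entries of a dense tensor doesn't shrink it. The sketch below uses a random tensor with the same shape as `fc1.weight` and a hypothetical random mask that keeps about 5% of the entries; it only measures memory, not training time.

```python
import torch

torch.manual_seed(0)
dense = torch.randn(128, 784)                    # same shape as fc1.weight
mask = (torch.rand_like(dense) < 0.05).float()   # hypothetical mask keeping ~5%
masked = dense * mask

# A dense tensor stores every entry, zero or not
bytes_dense = dense.numel() * dense.element_size()
bytes_masked = masked.numel() * masked.element_size()
print("dense bytes: ", bytes_dense)
print("masked bytes:", bytes_masked)
print("nonzero entries in masked tensor:", int((masked != 0).sum()))
```

The two byte counts are identical even though only a small fraction of the masked tensor's entries are nonzero.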

In [4]:
Image(url = "https://slj.ma/step.png")

In [28]:
Image(url = "https://slj.ma/acc.png")

The original number of weights in the neural network was __109,184__, and the number of weights in the smallest winning ticket was __1,965__.
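As a sanity check on those numbers, the total can be recomputed from the layer shapes in `Net` (weight matrices only, biases not counted). The dictionary below is just a convenient restatement of the three `nn.Linear` layers.

```python
# Weight-matrix sizes of the 784-128-64-10 network (biases not counted)
layer_shapes = {"fc1": (28 * 28, 128), "fc2": (128, 64), "fc3": (64, 10)}

total_weights = sum(n_in * n_out for n_in, n_out in layer_shapes.values())
ticket_weights = 1965  # size of the smallest winning ticket stated above

print("total weights:", total_weights)  # 109184
print(f"ticket keeps {100 * ticket_weights / total_weights:.1f}% of the weights")
```

So the smallest winning ticket retains a bit under 2% of the original weights.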