This work intends to replicate some of the experiments from Frankle and Carbin's Lottery Ticket Hypothesis paper.
The authors also have a great framework for these experiments. Check out OpenLTH!
To execute any of the original experiments, run:
# don't forget to install requirements first
> python lth --help
usage: lth [-h] [-p] [--batch_size] [-o] [-lr] [-i] [-r] [-es] [-rw] [-pr]
           [--recover] [-s] [-rs] [-fc] [--prune_global] [--gpu] [--quiet]
           [--random]
           net dataset

Run experiments with Iterative Pruning, identifying Lottery Tickets

positional arguments:
  net                   Network architecture to use. For more info run
                        `utils.models.models`
  dataset               Dataset type (MNIST or CIFAR10)

optional arguments:
  -h, --help            show this help message and exit
  -p , --data           Path to root dataset folder (default: ./datasets/)
  --batch_size          Dataloader's batch size (training) (default: None)
  -o , --optim          Model's optimizer (default: None)
  -lr , --learn_rate    Learning rate (default: None)
  -i , --iter           Training iterations (default: 50000)
  -r , --rounds         Pruning rounds (default: 26)
  -es , --step          Evaluate validation and test set every x steps. To
                        evaluate every epoch, use -1 (default: None)
  -rw , --rewind        Number of iterations to train in the first round before
                        using weights as reference for later rounds. Set rewind
                        between (0, 1) to use it as a percentage. (default: 0)
  -pr , --prune_rate    Pruning rate 0-.99 (default: 0.2)
  --recover             Recover/resume interrupted training procedure (default:
                        None)
  -s , --save           Directory to store the experiments (default:
                        ./experiments/)
  -rs , --seed          Random seed (default: None)
  -fc , --fc_rate       Different pruning rate for Fully Connected layers
                        (default: None)
  --prune_global        Global pruning instead of layer-wise (default: False)
  --gpu                 Allow for GPU usage (default: False)
  --quiet               Verbosity mode (default: False)
  --random              Random initialization (default: False)
# then, for instance
> python lth lenet mnist --rounds 20 --prune_rate 0.2
INFO [round: 0 | epoch: 1] train: 0.974 validation: 0.3793 | sparsity: 0% | duration: 3.8s
INFO [round: 1 | epoch: 1] train: 1.547 validation: 0.4314 | sparsity: 20% | duration: 6.0s
...
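The remaining flags compose in the same way. For instance, a hypothetical run (illustrative command only, not a recorded session) that rewinds the reference weights to 10% of the first round's training, prunes globally instead of layer-wise, and uses the GPU:
> python lth lenet mnist -rw 0.1 --prune_global --gpu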
The code was written in Python 3.7 using PyTorch's pruning module. You can also create a model and run your own pruning experiments:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import lth
# train, validation and test dataloaders
# if data is not there, it will prompt you to download it
dataloader = lth.data.load_MNIST('./datasets/mnist', validation=4500, validation_batch_size=200)
class Custom_Model(nn.Module):
    def __init__(self):
        super(Custom_Model, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(784, 200),
            nn.Linear(200, 100),
            nn.Linear(100, 10)
        )
        # there are some required attributes
        self.optim = torch.optim.Adam(self.parameters(), lr=0.05)
        self.optim.name = 'adam'
        self._initialize_weights()
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten.
        return self.net(x)

    # defining the required initialization method
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


# make your own pruning method
# (named prune_step so it does not shadow the torch.nn.utils.prune import)
def prune_step(net):
    layers = lth.prune.fetch_layers(net)  # fetch all parameters to be pruned
    for (layer, param_type) in layers:
        prune.l1_unstructured(layer, name=param_type, amount=0.11)
    return net


iterations = 1200
rounds = 15

net = Custom_Model()

lth.iterative_pruning(
    net, dataloader, iterations, rounds, prune_step
)
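You can also inspect how sparse a pruned model actually is. The snippet below is a minimal sketch reusing the `Custom_Model` and `prune_step` definitions above; the `global_sparsity` helper is purely illustrative and not part of `lth`. Because `prune.l1_unstructured` reparametrizes each layer as `weight_orig * weight_mask`, counting zero entries in `module.weight` reflects the pruned connections.
# minimal sketch (not part of lth): measure sparsity after one pruning step
def global_sparsity(model):
    zeros, total = 0, 0
    for module in model.modules():
        if isinstance(module, nn.Linear):
            # with the pruning reparametrization attached, module.weight is
            # weight_orig * weight_mask, so its zero entries are the pruned ones
            zeros += (module.weight == 0).sum().item()
            total += module.weight.numel()
    return zeros / total

net = Custom_Model()
prune_step(net)  # one round of L1 unstructured pruning (amount=0.11)
print(f'sparsity: {global_sparsity(net):.2%}')  # roughly 11% after one step
If you later want the pruning to become permanent, i.e. fold the mask into the weight tensor and drop the reparametrization, PyTorch provides `torch.nn.utils.prune.remove(module, 'weight')`.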