In [25]:
import torch
import torch.nn as nn
import keras
import tensorflow_datasets as tfds
import numpy as np
import sys
import matplotlib.pyplot as plt
# Add the parent directory to the import path before importing `condense`
# (presumably the package is loaded from the repository checkout one level
# up rather than from an installed distribution — TODO confirm).
sys.path.append('..')
import condense
# Set the condense logger to INFO so its progress messages show up in the
# cell outputs below.
condense.logger.setLevel("INFO")


ds_train, ds_test = tfds.load('mnist', split=['train', 'test'], shuffle_files=True, as_supervised=True)


def generator(batch_size, data_set):
    """Yield ``(images, labels)`` torch batches from ``data_set`` forever.

    The dataset is batched and cached once, then iterated again every time it
    is exhausted, so callers can keep calling ``next()`` for as many batches
    as they need.

    Args:
        batch_size: Number of examples per batch. The final batch of a pass
            may be smaller if the dataset size is not a multiple of it.
        data_set: A dataset yielding ``(image, label)`` pairs (loaded above
            with ``as_supervised=True``).

    Yields:
        A tuple ``(X, y)`` where ``X`` is a float tensor reshaped to
        ``(batch, 1, 28, 28)`` and ``y`` is a ``LongTensor`` of labels.

    Raises:
        ValueError: If ``data_set`` produces no batches at all.
    """
    batches = data_set.batch(batch_size).cache()
    while True:
        produced_any = False
        # Start a fresh pass each time: a single iterator would raise
        # StopIteration after one pass, which inside a generator surfaces as
        # a RuntimeError (PEP 479).
        for X, y in tfds.as_numpy(batches):
            produced_any = True
            # Use -1 for the batch dimension so a short final batch still
            # reshapes. Assumes single-channel 28x28 images, so a plain
            # reshape is enough and no axis transpose is needed — TODO confirm
            # if this is reused with another dataset.
            yield torch.Tensor(X.reshape(-1, 1, 28, 28)), torch.Tensor(y).type(torch.LongTensor)
        if not produced_any:
            # Without this guard an empty dataset would spin forever.
            raise ValueError('data_set produced no batches')


gen = generator(300, ds_train)
gen_test = generator(250, ds_test)

# Model Definition

In [26]:
class Network(nn.Module):
    """Small convolutional classifier used to demonstrate pruning.

    Two convolution layers are followed by two fully connected layers.
    ``forward`` returns log-probabilities over 10 classes for inputs of shape
    ``(batch, 1, 28, 28)``.
    """

    # Number of optimizer steps taken on each batch drawn from the generator.
    STEPS_PER_BATCH = 20

    def __init__(self):
        super(Network, self).__init__()
        self.layer1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=4, stride=2)
        self.layer2 = nn.Conv2d(in_channels=4, out_channels=2, kernel_size=2)
        # 288 = 2 channels * 12 * 12, the flattened conv output for a
        # 1x28x28 input (28 -> 13 after layer1, 13 -> 12 after layer2).
        self.dense = nn.Linear(288, out_features=50)
        self.output = nn.Linear(50, out_features=10)

    def forward(self, X):
        """Return per-class log-probabilities for a batch of images.

        Args:
            X: Float tensor of shape ``(batch, 1, 28, 28)``.

        Returns:
            Tensor of shape ``(batch, 10)`` holding log-softmax scores.
        """
        # Call the sub-modules themselves rather than their ``.forward`` so
        # that any hooks registered on them are executed.
        X = self.layer1(X)
        X = self.layer2(torch.relu(X))
        # Flatten everything except the batch dimension.
        X = X.view(X.size(0), -1)
        X = self.dense(torch.relu(X))
        X = self.output(torch.relu(X))
        return torch.log_softmax(X, 1)

    def fit(self, d, epochs=20):
        """Run the training loop on batches drawn from ``d``.

        Each "epoch" draws one batch from ``d`` and takes
        ``STEPS_PER_BATCH`` SGD steps on it, then prints the last loss.

        Args:
            d: Iterator yielding ``(X, y)`` batches of inputs and integer
                class labels.
            epochs: Number of batches to draw from ``d``.
        """
        # ``forward`` already returns log-probabilities, so NLLLoss is the
        # matching criterion (CrossEntropyLoss would apply log-softmax again).
        criterion = nn.NLLLoss()
        optim = torch.optim.SGD(self.parameters(), lr=0.01, weight_decay=0.1)

        for _ in range(epochs):
            X, y = next(d)
            for _ in range(self.STEPS_PER_BATCH):
                optim.zero_grad()
                loss = criterion(self(X), y)
                loss.backward()
                optim.step()
            print('Training Loss:', float(loss))

    def train(self, d=True, epochs=20):
        """Train on a batch generator, or switch training mode.

        This method shadows ``nn.Module.train(mode)``, which PyTorch itself
        calls (for example from ``eval()``). To keep both uses working, a
        boolean argument is forwarded to the base class, and anything else is
        treated as a batch iterator and passed to ``fit``.

        Args:
            d: Either a ``bool`` training-mode flag, or an iterator yielding
                ``(X, y)`` batches.
            epochs: Number of batches to draw when ``d`` is an iterator.

        Returns:
            ``self`` when called with a ``bool`` (base-class behaviour),
            otherwise ``None``.
        """
        if isinstance(d, bool):
            return super(Network, self).train(d)
        self.fit(d, epochs)


net = Network()

# Apply `PruningAgent` module to your model

As the output below shows, no pruning has happened yet: all sparsity masks are initialized to `1`, so every parameter reports a sparsity of `0.0`.

In [27]:
# Wrap the network in a pruning agent configured with a constant sparsity
# function of 0.8.
pruned = condense.torch.PruningAgent(
    net,
    condense.optimizer.sparsity_functions.Constant(0.8),
)
# Last expression of the cell: displays the per-parameter sparsity.
f'Parameter Sparsity: {pruned.get_parameter_sparsity()}'

'Parameter Sparsity: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]'

# Start the ticket search process
The ticket search process is agnostic to how you train, so you are free to use whatever training loop you like.
On entering the `TicketSearch` context, all model parameters are saved for later reinitialization.
On exit, the parameters are reset to their original values and the pruning masks are applied.

In [28]:
# Train inside the ticket search context; the log output below shows the
# parameters being stored on entry and reinitialized and masked on exit.
with condense.torch.TicketSearch(pruned):
    pruned.model.train(gen, epochs=10)

# Last expression of the cell: displays the per-parameter sparsity.
f'Parameter Sparsity: {pruned.get_parameter_sparsity()}'

INFO:condense:💾 Storing module parameters for reinitialization
INFO:condense:🔬 Searching for winning ticket


Training Loss: 1.7746738195419312
Training Loss: 0.48766133189201355
Training Loss: 0.2685297429561615
Training Loss: 0.13927403092384338
Training Loss: 0.13883404433727264
Training Loss: 0.1436176598072052
Training Loss: 0.053960900753736496
Training Loss: 0.05116381496191025
Training Loss: 0.09471690654754639


INFO:condense:🎟 Winning ticket found
INFO:condense:⚙️ Generating Mask
INFO:condense:😄 Reinitialized module parameters
INFO:condense:🥷 Ticket masks applied to module parameters


Training Loss: 0.11399843543767929


'Parameter Sparsity: [0.8125, 1.0, 0.8125, 1.0, 0.8000694444444445, 0.82, 0.802, 0.9]'

# Start the actual training on the pruned model

In [None]:
# Train the pruned model and report the resulting sparsity.
pruned.model.train(gen, 10)
print(f'Parameter Sparsity: {pruned.get_parameter_sparsity()}')

# Visualise the magnitude of the output layer's weights. `.cpu()` keeps the
# conversion to NumPy working if the model is not on the CPU.
output_weights = np.abs(pruned.model.output.weight.detach().cpu().numpy())

fig, ax = plt.subplots()
# Colour scale is clipped to [0, 0.1], as in the original plot.
image = ax.imshow(output_weights, vmin=0, vmax=.1)
# nn.Linear stores its weight as (out_features, in_features).
ax.set(
    title='Output layer weight magnitudes',
    xlabel='Input feature',
    ylabel='Output class',
)
fig.colorbar(image, ax=ax, label='|weight|');

Training Loss: 0.957158625125885
Training Loss: 0.6910885572433472
Training Loss: 0.45791131258010864
Training Loss: 0.42609715461730957
Training Loss: 0.32559362053871155
