#  Lottery Notebook

### imports :

In [34]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import numpy as np
import torch.nn.utils.prune as prune
import matplotlib.pyplot as plt
from pruning import Pruning_tool
import copy
import time
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print('using', device)

using cuda:0


## Model :

In [42]:
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(784, 400)
        self.fc2 = nn.Linear(400, 200)
        self.fc3 = nn.Linear(200,100)
        self.fc41 = nn.Linear(100, 32)
        self.fc42 = nn.Linear(100, 32)
        self.fc5 = nn.Linear(32, 100)
        self.fc6 = nn.Linear(100,200)
        self.fc7 = nn.Linear(200,400)
        self.fc8 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc3(self.fc2(self.fc1(x))))
        return self.fc41(h1), self.fc42(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc7(self.fc6(self.fc5(z))))
        return torch.sigmoid(self.fc8(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [43]:
epochs = 10
batch_size = 64
log_interval = 100

mnist_trainset = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
                   batch_size=batch_size, shuffle=True)

mnist_testset = test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
                  batch_size=batch_size, shuffle=True)

model = VAE().to(device)
print(model.parameters)
optimizer = optim.Adam(model.parameters(), lr=7e-4)
rewind_state_dict = copy.deepcopy(model.state_dict()) 
step = 0

<bound method Module.parameters of VAE(
  (fc1): Linear(in_features=784, out_features=400, bias=True)
  (fc2): Linear(in_features=400, out_features=200, bias=True)
  (fc3): Linear(in_features=200, out_features=100, bias=True)
  (fc41): Linear(in_features=100, out_features=32, bias=True)
  (fc42): Linear(in_features=100, out_features=32, bias=True)
  (fc5): Linear(in_features=32, out_features=100, bias=True)
  (fc6): Linear(in_features=100, out_features=200, bias=True)
  (fc7): Linear(in_features=200, out_features=400, bias=True)
  (fc8): Linear(in_features=400, out_features=784, bias=True)
)>


In [44]:
def kl_anneal_function(anneal_function, step, k, x0):
    """ Beta update function
        
        Parameters
        ----------
        anneal_function : string
            What type of update (logisitc or linear)
        step : int
            Which step of the training
        k : float
            Coefficient of the logistic function
        x0 : float
            Delay of the logistic function or slope of the linear function
        Returns
        -------
        beta : float
            Weight of the KL divergence in the loss function 
        """
    if anneal_function == 'logistic':
        return float(1/(1+np.exp(-k*(step-x0))))
    elif anneal_function == 'linear':
        return min(1, step/x0)



# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar, beta):
    """ Compute the loss function between recon_x (output of the VAE) 
    and x (input of the VAE)
    """
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    KLD = 0.5 * torch.sum(logvar.exp() + mu.pow(2) - 1 - logvar)
    return BCE + beta*KLD

In [47]:
plt_loss = []
for name, param in model.named_parameters(): 
    param.data = rewind_state_dict[name].clone()
for epoch in range(15): # epochs + 1):

    model.train()
    train_loss = 0

    for batch_idx, (data, _) in enumerate(mnist_trainset):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)

        beta = kl_anneal_function('linear', step, 1, 10*len(mnist_trainset))
            
        loss = loss_function(recon_batch, data, mu, logvar, beta)
        
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        step += 1
        #if batch_idx == 500:
        #    rewind_state_dict = copy.deepcopy(model.state_dict())
        #    print('rewind state saved')

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(mnist_trainset.dataset),
                100. * batch_idx / len(mnist_trainset),
                loss.item() / len(data)))
            plt_loss.append(loss.item() / len(data))  # For ploting loss

    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(mnist_trainset.dataset)))


====> Epoch: 0 Average loss: 180.5791
====> Epoch: 1 Average loss: 134.5248
====> Epoch: 2 Average loss: 121.3537
====> Epoch: 3 Average loss: 116.1898
====> Epoch: 4 Average loss: 114.4453
====> Epoch: 5 Average loss: 113.4486
====> Epoch: 6 Average loss: 112.6476
====> Epoch: 7 Average loss: 112.1565
====> Epoch: 8 Average loss: 111.8000
====> Epoch: 9 Average loss: 111.4733
====> Epoch: 10 Average loss: 111.2019
====> Epoch: 11 Average loss: 111.0004
====> Epoch: 12 Average loss: 110.8052
====> Epoch: 13 Average loss: 110.6306


====> Epoch: 14 Average loss: 110.5062


## Pruning


In [22]:
from pruning import Pruning
from pruning import PruningBack


#local pruning
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        pr = Pruning(module)
        prb = PruningBack(module)
        pr.set_mask(0.7)
        prb.set_mask(0.7)

        module.register_forward_pre_hook(pr)
        #module.register_backward_hook(prb)

for batch_idx, (data, _) in enumerate(mnist_trainset):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        break

prstat = Pruning_tool()
prstat.stats_pruning(model)

Model :
Linear(in_features=784, out_features=400, bias=True)
Sparsity in Layer 0: 50.00%
Linear(in_features=400, out_features=40, bias=True)
Sparsity in Layer 1: 50.01%
Linear(in_features=400, out_features=40, bias=True)
Sparsity in Layer 2: 50.01%
Linear(in_features=40, out_features=400, bias=True)
Sparsity in Layer 3: 50.01%
Linear(in_features=400, out_features=784, bias=True)
Sparsity in Layer 4: 50.00%
Global Sparsity : 50.00%


In [46]:
from pruning import Pruning
from pruning import PruningBack
plt_loss = []
prstat = Pruning_tool()

pruning_amount = np.arange(10)/10

for pa in pruning_amount:
    print(pa*100,'% of weights pruned')
    #local pruning
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            pr = Pruning(module)
            prb = PruningBack(module)
            pr.set_mask(pa)
            prb.set_mask(pa)

            module.register_forward_pre_hook(pr)
            #module.register_backward_hook(prb)

    for name, param in model.named_parameters(): 
        param.data = rewind_state_dict[name].clone()
    for epoch in range(0, 3):

        model.train()
        train_loss = 0
        for batch_idx, (data, _) in enumerate(mnist_trainset):
            data = data.to(device)
            optimizer.zero_grad()
            recon_batch, mu, logvar = model(data)

            beta = kl_anneal_function('linear', step, 1, 10*len(mnist_trainset))
            #print(beta)
            loss = loss_function(recon_batch, data, mu, logvar, beta)
            loss.backward()
            train_loss += loss.item()
            optimizer.step()
            step += 1

            if batch_idx % log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(mnist_trainset.dataset),
                    100. * batch_idx / len(mnist_trainset),
                    loss.item() / len(data)))
                plt_loss.append(loss.item() / len(data))  # For ploting loss
        print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(mnist_trainset.dataset)))
    _, _, _ = model(data)
    prstat.stats_pruning(model)

prstat.stats_pruning(model)   

0.0 % of weights pruned
====> Epoch: 0 Average loss: 113.9240
====> Epoch: 1 Average loss: 100.4699
====> Epoch: 2 Average loss: 97.6759
Model :
Linear(in_features=784, out_features=400, bias=True)
Sparsity in Layer 0: 0.00%
Linear(in_features=400, out_features=200, bias=True)
Sparsity in Layer 1: 0.00%
Linear(in_features=200, out_features=100, bias=True)
Sparsity in Layer 2: 0.01%
Linear(in_features=100, out_features=32, bias=True)
Sparsity in Layer 3: 0.03%
Linear(in_features=100, out_features=32, bias=True)
Sparsity in Layer 4: 0.03%
Linear(in_features=32, out_features=100, bias=True)
Sparsity in Layer 5: 0.03%
Linear(in_features=100, out_features=200, bias=True)
Sparsity in Layer 6: 0.01%
Linear(in_features=200, out_features=400, bias=True)
Sparsity in Layer 7: 0.00%
Linear(in_features=400, out_features=784, bias=True)
Sparsity in Layer 8: 0.00%
Global Sparsity : 0.00%
10.0 % of weights pruned
====> Epoch: 0 Average loss: 120.6815
====> Epoch: 1 Average loss: 106.9662
====> Epoch: 

====> Epoch: 1 Average loss: 113.5518
====> Epoch: 2 Average loss: 110.4400
Model :
Linear(in_features=784, out_features=400, bias=True)
Sparsity in Layer 0: 30.00%
Linear(in_features=400, out_features=200, bias=True)
Sparsity in Layer 1: 30.00%
Linear(in_features=200, out_features=100, bias=True)
Sparsity in Layer 2: 30.00%
Linear(in_features=100, out_features=32, bias=True)
Sparsity in Layer 3: 30.03%
Linear(in_features=100, out_features=32, bias=True)
Sparsity in Layer 4: 30.03%
Linear(in_features=32, out_features=100, bias=True)
Sparsity in Layer 5: 30.03%
Linear(in_features=100, out_features=200, bias=True)
Sparsity in Layer 6: 30.00%
Linear(in_features=200, out_features=400, bias=True)
Sparsity in Layer 7: 30.00%
Linear(in_features=400, out_features=784, bias=True)
Sparsity in Layer 8: 30.00%
Global Sparsity : 30.00%
40.0 % of weights pruned
====> Epoch: 0 Average loss: 125.2949
====> Epoch: 1 Average loss: 112.1610
====> Epoch: 2 Average loss: 109.6155
Model :
Linear(in_features

====> Epoch: 0 Average loss: 127.7825
====> Epoch: 1 Average loss: 110.5847
====> Epoch: 2 Average loss: 108.4515
Model :
Linear(in_features=784, out_features=400, bias=True)
Sparsity in Layer 0: 70.00%
Linear(in_features=400, out_features=200, bias=True)
Sparsity in Layer 1: 70.00%
Linear(in_features=200, out_features=100, bias=True)
Sparsity in Layer 2: 70.00%
Linear(in_features=100, out_features=32, bias=True)
Sparsity in Layer 3: 70.03%
Linear(in_features=100, out_features=32, bias=True)
Sparsity in Layer 4: 70.03%
Linear(in_features=32, out_features=100, bias=True)
Sparsity in Layer 5: 70.03%
Linear(in_features=100, out_features=200, bias=True)
Sparsity in Layer 6: 70.00%
Linear(in_features=200, out_features=400, bias=True)
Sparsity in Layer 7: 70.00%
Linear(in_features=400, out_features=784, bias=True)
Sparsity in Layer 8: 70.00%
Global Sparsity : 70.00%
80.0 % of weights pruned
====> Epoch: 0 Average loss: 137.5300
====> Epoch: 1 Average loss: 113.0305
====> Epoch: 2 Average los