#  Lottery Notebook

### imports :

In [1]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import numpy as np
import torch.nn.utils.prune as prune
import matplotlib.pyplot as plt
from pruning import Pruning_tool
import copy
import time
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print('using', device)

using cuda:0


## Model :

In [2]:
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(784, 400)
        self.fc2 = nn.Linear(400, 200)
        self.fc3 = nn.Linear(200,100)
        self.fc41 = nn.Linear(100, 32)
        self.fc42 = nn.Linear(100, 32)
        self.fc5 = nn.Linear(32, 100)
        self.fc6 = nn.Linear(100,200)
        self.fc7 = nn.Linear(200,400)
        self.fc8 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc3(self.fc2(self.fc1(x))))
        return self.fc41(h1), self.fc42(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc7(self.fc6(self.fc5(z))))
        return torch.sigmoid(self.fc8(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar



# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar, beta):
    """ Compute the loss function between recon_x (output of the VAE) 
    and x (input of the VAE)
    """
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    KLD = 0.5 * torch.sum(logvar.exp() + mu.pow(2) - 1 - logvar)
    return BCE + beta*KLD

epochs = 10
batch_size = 64
log_interval = 100

mnist_trainset = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
                   batch_size=batch_size, shuffle=True)

mnist_testset = test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
                  batch_size=batch_size, shuffle=True)

model = VAE().to(device)
print(model.parameters)
optimizer = optim.Adam(model.parameters(), lr=7e-4)
rewind_state_dict = copy.deepcopy(model.state_dict()) 
step = 0

<bound method Module.parameters of VAE(
  (fc1): Linear(in_features=784, out_features=400, bias=True)
  (fc2): Linear(in_features=400, out_features=200, bias=True)
  (fc3): Linear(in_features=200, out_features=100, bias=True)
  (fc41): Linear(in_features=100, out_features=32, bias=True)
  (fc42): Linear(in_features=100, out_features=32, bias=True)
  (fc5): Linear(in_features=32, out_features=100, bias=True)
  (fc6): Linear(in_features=100, out_features=200, bias=True)
  (fc7): Linear(in_features=200, out_features=400, bias=True)
  (fc8): Linear(in_features=400, out_features=784, bias=True)
)>


## Training with global Iterative Magnitude Pruning


In [3]:
from pruning import Pruning

plt_loss = []
prstat = Pruning_tool()
first_training = True
pruning_amount = np.linspace(0, 0.99, 10)
final_pa = pruning_amount[-1]

for pa in pruning_amount:
    print(round(pa*100,3),'% of weights pruned')
    # global pruning
    cutting_value = prstat.compute_global_criterion(model, pa)
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            pr = Pruning(module)
            pr.set_mask_globally(cutting_value)
            module.register_forward_pre_hook(pr)

    for name, param in model.named_parameters(): 
        param.data = rewind_state_dict[name].clone()
        
    epochs = 20 if pa == final_pa else 10
    for epoch in range(epochs):

        model.train()
        train_loss = 0
        for batch_idx, (data, _) in enumerate(mnist_trainset):
            data = data.to(device)
            optimizer.zero_grad()
            recon_batch, mu, logvar = model(data)
            beta = 1
            loss = loss_function(recon_batch, data, mu, logvar, beta)
            loss.backward()
            train_loss += loss.item()
            optimizer.step()
            step += 1
            
            if  first_training and batch_idx == 500 :
                rewind_state_dict = copy.deepcopy(model.state_dict())
                first_training = False
                print('rewind state saved')
            
            #if batch_idx % log_interval == 0:
            #    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            #        epoch, batch_idx * len(data), len(mnist_trainset.dataset),
            #        100. * batch_idx / len(mnist_trainset),
            #        loss.item() / len(data)))
            #    plt_loss.append(loss.item() / len(data))  # For ploting loss
        print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(mnist_trainset.dataset)))
    _, _, _ = model(data)
    prstat.stats_pruning(model)

0.0 % of weights pruned
rewind state saved
====> Epoch: 0 Average loss: 163.6184
====> Epoch: 1 Average loss: 126.1361
====> Epoch: 2 Average loss: 118.5564
====> Epoch: 3 Average loss: 114.9881
====> Epoch: 4 Average loss: 112.5811
====> Epoch: 5 Average loss: 110.8894
====> Epoch: 6 Average loss: 109.8315
====> Epoch: 7 Average loss: 109.0097
====> Epoch: 8 Average loss: 108.3887
====> Epoch: 9 Average loss: 107.9239
Model :
Linear(in_features=784, out_features=400, bias=True)
Sparsity in Layer 0: 0.00%
Linear(in_features=400, out_features=200, bias=True)
Sparsity in Layer 1: 0.00%
Linear(in_features=200, out_features=100, bias=True)
Sparsity in Layer 2: 0.00%
Linear(in_features=100, out_features=32, bias=True)
Sparsity in Layer 3: 0.00%
Linear(in_features=100, out_features=32, bias=True)
Sparsity in Layer 4: 0.00%
Linear(in_features=32, out_features=100, bias=True)
Sparsity in Layer 5: 0.00%
Linear(in_features=100, out_features=200, bias=True)
Sparsity in Layer 6: 0.00%
Linear(in_fe

In [7]:
for epoch in range(31,50):

        model.train()
        train_loss = 0
        for batch_idx, (data, _) in enumerate(mnist_trainset):
            data = data.to(device)
            optimizer.zero_grad()
            recon_batch, mu, logvar = model(data)
            beta = 1
            loss = loss_function(recon_batch, data, mu, logvar, beta)
            loss.backward()
            train_loss += loss.item()
            optimizer.step()
            step += 1
            
            if  first_training and batch_idx == 500 :
                rewind_state_dict = copy.deepcopy(model.state_dict())
                first_training = False
                print('rewind state saved')
            
            #if batch_idx % log_interval == 0:
            #    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            #        epoch, batch_idx * len(data), len(mnist_trainset.dataset),
            #        100. * batch_idx / len(mnist_trainset),
            #        loss.item() / len(data)))
            #    plt_loss.append(loss.item() / len(data))  # For ploting loss
        print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(mnist_trainset.dataset)))
_, _, _ = model(data)
prstat.stats_pruning(model)

====> Epoch: 31 Average loss: 167.0520
====> Epoch: 32 Average loss: 166.7728
====> Epoch: 33 Average loss: 166.4616
====> Epoch: 34 Average loss: 166.0388
====> Epoch: 35 Average loss: 165.6754
====> Epoch: 36 Average loss: 165.3309
====> Epoch: 37 Average loss: 164.9875
====> Epoch: 38 Average loss: 164.7251
====> Epoch: 39 Average loss: 164.5643
====> Epoch: 40 Average loss: 164.4707
====> Epoch: 41 Average loss: 164.3246
====> Epoch: 42 Average loss: 164.2883
====> Epoch: 43 Average loss: 164.2213
====> Epoch: 44 Average loss: 164.1321
====> Epoch: 45 Average loss: 164.0421
====> Epoch: 46 Average loss: 164.0200
====> Epoch: 47 Average loss: 163.9758
====> Epoch: 48 Average loss: 163.9093
====> Epoch: 49 Average loss: 163.8431
Model :
Linear(in_features=784, out_features=400, bias=True)
Sparsity in Layer 0: 99.75%
Linear(in_features=400, out_features=200, bias=True)
Sparsity in Layer 1: 99.27%
Linear(in_features=200, out_features=100, bias=True)
Sparsity in Layer 2: 98.17%
Linear(i