#  Lottery Notebook

### imports :

In [1]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import numpy as np
import torch.nn.utils.prune as prune
import matplotlib.pyplot as plt
from pruning import PruningTool
import copy
import time
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print('using', device)

using cuda:0


## Model :

In [2]:
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(784, 400)
        self.fc2 = nn.Linear(400, 200)
        self.fc3 = nn.Linear(200,100)
        self.fc41 = nn.Linear(100, 32)
        self.fc41.transfer_trim = True
        self.fc42 = nn.Linear(100, 32)
        self.fc5 = nn.Linear(32, 100)
        self.fc6 = nn.Linear(100,200)
        self.fc7 = nn.Linear(200,400)
        self.fc8 = nn.Linear(400, 784)
        self.fc8.unprunable = True

    def encode(self, x):
        h1 = F.relu(self.fc3(self.fc2(self.fc1(x))))
        return self.fc41(h1), self.fc42(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc7(self.fc6(self.fc5(z))))
        return torch.sigmoid(self.fc8(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar



# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar, beta):
    """ Compute the loss function between recon_x (output of the VAE) 
    and x (input of the VAE)
    """
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    KLD = 0.5 * torch.sum(logvar.exp() + mu.pow(2) - 1 - logvar)
    return BCE + beta*KLD

epochs = 10
batch_size = 64
log_interval = 100

mnist_trainset = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
                   batch_size=batch_size, shuffle=True)

mnist_testset = test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
                  batch_size=batch_size, shuffle=True)


## Trimming


In [3]:
from pruning import Trimming

model = VAE().to(device)
print(model.parameters)
optimizer = optim.Adam(model.parameters(), lr=7e-4)
rewind_state_dict = copy.deepcopy(model.state_dict()) 
step = 0

plt_loss = []
prstat = PruningTool()
trimmer = Trimming(model)

first_training = True
final_trimming_value = 0.8
nb_iter_trimming = 10
trimming_amount = 1 - (1-final_trimming_value)**(1/nb_iter_trimming)

print('trimming amount :', trimming_amount)
for k in range(nb_iter_trimming + 1):
    epochs = 20 if k == nb_iter_trimming else 10
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for batch_idx, (data, _) in enumerate(mnist_trainset):
            data = data.to(device)
            optimizer.zero_grad()
            recon_batch, mu, logvar = model(data)
            beta = 1
            loss = loss_function(recon_batch, data, mu, logvar, beta)
            loss.backward()
            train_loss += loss.item()
            optimizer.step()
            step += 1
            
            if  first_training and batch_idx == 500 :
                trimmer.save_rewind_state(model)
                print('rewind state saved')
                first_training = False
            
            #if batch_idx % log_interval == 0:
            #    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            #        epoch, batch_idx * len(data), len(mnist_trainset.dataset),
            #        100. * batch_idx / len(mnist_trainset),
            #        loss.item() / len(data)))
            #    plt_loss.append(loss.item() / len(data))  # For ploting loss
        print('====> Epoch: {} Average loss: {:.4f} \n'.format(epoch, train_loss / len(mnist_trainset.dataset)))
    if k < nb_iter_trimming:
        trimmer.trim_locally(model, trimming_amount)
        print(round((1 - (1-final_trimming_value)**((k+1)/nb_iter_trimming))*100,3),'% of weights trimmed')
        
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:
            print(name)
            print(module.weight.shape)
            print(module.bias.shape)
            
print(model)

<bound method Module.parameters of VAE(
  (fc1): Linear(in_features=784, out_features=400, bias=True)
  (fc2): Linear(in_features=400, out_features=200, bias=True)
  (fc3): Linear(in_features=200, out_features=100, bias=True)
  (fc41): Linear(in_features=100, out_features=32, bias=True)
  (fc42): Linear(in_features=100, out_features=32, bias=True)
  (fc5): Linear(in_features=32, out_features=100, bias=True)
  (fc6): Linear(in_features=100, out_features=200, bias=True)
  (fc7): Linear(in_features=200, out_features=400, bias=True)
  (fc8): Linear(in_features=400, out_features=784, bias=True)
)>
trimming amount : 0.1486600774792154
rewind state saved
====> Epoch: 0 Average loss: 161.0567 

14.866 % of weights trimmed
fc1
torch.Size([341, 784])
torch.Size([341])
fc2
torch.Size([171, 341])
torch.Size([171])
fc3
torch.Size([86, 171])
torch.Size([86])
fc41
torch.Size([28, 86])
torch.Size([28])
fc42
torch.Size([28, 86])
torch.Size([28])
fc5
torch.Size([86, 28])
torch.Size([86])
fc6
torch.Size(

KeyboardInterrupt: 

In [None]:
for name, module in model.named_modules():
    if len(list(module.children())) == 0:
        print(name)
        print(module.weight.shape)
        print(module.bias.shape)