In [32]:
from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import random
import shutil
import time
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchsummary import summary
from torchvision import transforms, datasets

In [25]:
def plot_wb(model, fig_path, ranges=None):

    tmp = list(model.named_parameters())
    layers = []
    for i in range(0, len(tmp), 2):
          w, b = tmp[i], tmp[i + 1]
          if ("conv" in w[0] or "conv" in b[0]) or ("fc" in w[0] or "fc" in b[0]):
            layers.append((w, b))

    num_rows = len(layers)

    fig = plt.figure(figsize=(20, 40))

    i = 1
    for w, b in layers:
        w_flatten = w[1].flatten().detach().cpu().numpy()
        b_flatten = b[1].flatten().detach().cpu().numpy()

        fig.add_subplot(num_rows, 2, i)
        plt.title(w[0])
        plt.hist(w_flatten, bins=100, range=ranges);

        fig.add_subplot(num_rows, 2, i + 1)
        plt.title(b[0])
        plt.hist(b_flatten, bins=100, range=ranges);

        i += 2
    
    fig.tight_layout()
    plt.savefig(fig_path)
    plt.close()

In [26]:
def load_mnist(BATCH_SIZE=64):
    transform = transforms.Compose([transforms.Resize((32,32)),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.5], std=[0.5])])
            
    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
    
    # Clear downloading message.
    clear_output()
    
    # Split dataset into training set and validation set.
    train_dataset, val_dataset = random_split(train_dataset, (55000, 5000))
    
    print("Image Shape: {}".format(train_dataset[0][0].numpy().shape), end = '\n\n')
    print("Training Set:   {} samples".format(len(train_dataset)))
    print("Validation Set:   {} samples".format(len(val_dataset)))
    print("Test Set:       {} samples".format(len(test_dataset)))
    
    # Create iterator.
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=True)
    
    # Delete the data/ folder.
    shutil.rmtree('./data')
    
    return (train_loader, val_loader, test_loader)

In [27]:
train_loader, val_loader, test_loader = load_mnist()

Image Shape: (1, 32, 32)

Training Set:   55000 samples
Validation Set:   5000 samples
Test Set:       10000 samples


In [43]:
class NN(nn.Module):
    
    def __init__(self):
        super(NN, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5,5))
        self.pool1 = nn.AvgPool2d(kernel_size=(2,2), stride=2)
        self.fc1 = nn.Linear(in_features=6*14*14, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=10)  
        
    def forward(self, x):
        x = self.conv1(x)
        x = torch.relu(x)
        x = self.pool1(x)
       
        x = x.view(-1, 6*14*14)
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        return x

class DSDTraining(nn.Module):
    
    def __init__(self, model, sparsity, train_on_sparse = False):
        super(DSDTraining, self).__init__()
        
        self.model = model
        self.sparsity = sparsity
        self.train_on_sparse = train_on_sparse

        # Get only conv/fc layers. Not 1st conv.
        tmp = list(self.model.named_parameters())
        self.layers = []
        for i in range(2, len(tmp), 2):
          w, b = tmp[i], tmp[i + 1]
          if ("conv" in w[0] or "conv" in b[0]) or ("fc" in w[0] or "fc" in b[0]):
            self.layers.append((w[1], b[1]))

        # Init masks
        self.reset_masks()

    def reset_masks(self):
        
        self.masks = []
        for w, b in self.layers:
          mask_w = torch.ones_like(w, dtype=bool)
          mask_b = torch.ones_like(b, dtype=bool)
          self.masks.append((mask_w, mask_b))
        
        return self.masks

    def update_masks(self):

      for i, (w, b) in enumerate(dsd_model.layers):
        q_w = torch.quantile(torch.abs(w), q = self.sparsity)
        mask_w = torch.where(torch.abs(w) < q_w, True, False)
        
        q_b = torch.quantile(torch.abs(b), q = self.sparsity)
        mask_b = torch.where(torch.abs(b) < q_b, True, False)

        self.masks[i] = (mask_w, mask_b)

    def forward(self, x):
        return self.model(x)

model = NN()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
dsd_model = DSDTraining(model, sparsity=0.3)
summary(dsd_model, (1, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 6, 28, 28]             156
         AvgPool2d-2            [-1, 6, 14, 14]               0
            Linear-3                  [-1, 120]         141,240
            Linear-4                   [-1, 10]           1,210
                NN-5                   [-1, 10]               0
Total params: 142,606
Trainable params: 142,606
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.05
Params size (MB): 0.54
Estimated Total Size (MB): 0.59
----------------------------------------------------------------


In [44]:
def set_all_seed(seed_value=42):
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)

set_all_seed(42)

EPOCH_DENSE1 = 3
EPOCH_SPARSE = 2
EPOCH_DENSE2 = 3
EPOCHS = EPOCH_DENSE1 + EPOCH_SPARSE + EPOCH_DENSE2
NB_TRAIN_EXAMPLES = len(train_loader.dataset)

# Initialize
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(dsd_model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

for epoch in range(EPOCHS):

    if (epoch >= EPOCH_DENSE1 and epoch < EPOCH_DENSE1 + EPOCH_SPARSE):
        dsd_model.train_on_sparse = True
    else:
        dsd_model.train_on_sparse = False

    if dsd_model.train_on_sparse:
        dsd_model.update_masks()

    train_loss, correct_train = 0, 0

    for step, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        prediction = dsd_model(inputs)
        loss = criterion(prediction, labels)
        loss.backward()
        
        # Sparse
        if dsd_model.train_on_sparse:
            for (w, b), (mask_w, mask_b) in zip(dsd_model.layers, dsd_model.masks):
                # Values
                w.data[mask_w] = 0
                b.data[mask_b] = 0
                # Grad
                w.grad.data[mask_w] = 0
                b.grad.data[mask_b] = 0

        optimizer.step()

        _, predicted = torch.max(prediction.data, 1)
        correct_train += (predicted == labels).sum().item()
        train_loss += (loss.data.item() * inputs.shape[0])

    train_loss /= NB_TRAIN_EXAMPLES
    train_acc =  correct_train / NB_TRAIN_EXAMPLES

    info = "[Epoch {}/{}]: lr = {:0.6f} | train-loss = {:0.6f} | train-acc = {:0.6f} | train_on_sparse = {}"
    print(info.format(epoch+1, EPOCHS, optimizer.param_groups[0]['lr'], train_loss, train_acc, dsd_model.train_on_sparse))

    scheduler.step()

    if (epoch + 1 == EPOCH_DENSE1):
        plot_wb(dsd_model, "dense1.png")
    elif (epoch + 1 == EPOCH_DENSE1 + EPOCH_SPARSE):
        plot_wb(dsd_model, "sparse.png")
    elif (epoch + 1 == EPOCHS):
        plot_wb(dsd_model, "dense2.png")

[Epoch 1/8]: lr = 0.001000 | train-loss = 2.245783 | train-acc = 0.328200 | train_on_sparse = False
[Epoch 2/8]: lr = 0.001000 | train-loss = 1.841837 | train-acc = 0.587764 | train_on_sparse = False
[Epoch 3/8]: lr = 0.001000 | train-loss = 0.917730 | train-acc = 0.793255 | train_on_sparse = False
[Epoch 4/8]: lr = 0.000100 | train-loss = 0.643202 | train-acc = 0.835418 | train_on_sparse = True
[Epoch 5/8]: lr = 0.000100 | train-loss = 0.618310 | train-acc = 0.839782 | train_on_sparse = True
[Epoch 6/8]: lr = 0.000100 | train-loss = 0.595139 | train-acc = 0.843364 | train_on_sparse = False
[Epoch 7/8]: lr = 0.000010 | train-loss = 0.582308 | train-acc = 0.846055 | train_on_sparse = False
[Epoch 8/8]: lr = 0.000010 | train-loss = 0.580096 | train-acc = 0.846273 | train_on_sparse = False
