<a href="https://colab.research.google.com/github/NesterukSergey/hidden-networks/blob/master/mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
from __future__ import print_function
import argparse
import os
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.autograd as autograd

from tqdm import tqdm

In [0]:
class GetSubnet(autograd.Function):
    @staticmethod
    def forward(ctx, scores, k):
        # Get the supermask by sorting the scores and using the top k%
        out = scores.clone()
        _, idx = scores.flatten().sort()
        j = int((1 - k) * scores.numel())

        # flat_out and out access the same memory.
        flat_out = out.flatten()
        flat_out[idx[:j]] = 0
        flat_out[idx[j:]] = 1

        return out

    @staticmethod
    def backward(ctx, g):
        # send the gradient g straight-through on the backward pass.
        return g, None


class SupermaskConv(nn.Conv2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # initialize the scores
        self.scores = nn.Parameter(torch.Tensor(self.weight.size()))
        nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5))

        # NOTE: initialize the weights like this.
        nn.init.kaiming_normal_(self.weight, mode="fan_in", nonlinearity="relu")

        # NOTE: turn the gradient on the weights off
        self.weight.requires_grad = False

    def forward(self, x):
        subnet = GetSubnet.apply(self.scores.abs(), 0.5)
        w = self.weight * subnet
        x = F.conv2d(
            x, w, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        return x

class SupermaskLinear(nn.Linear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # initialize the scores
        self.scores = nn.Parameter(torch.Tensor(self.weight.size()))
        nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5))

        # NOTE: initialize the weights like this.
        nn.init.kaiming_normal_(self.weight, mode="fan_in", nonlinearity="relu")

        # NOTE: turn the gradient on the weights off
        self.weight.requires_grad = False

    def forward(self, x):
        subnet = GetSubnet.apply(self.scores.abs(), 0.5)
        w = self.weight * subnet
        return F.linear(x, w, self.bias)
        return x

# NOTE: not used here but we use NON-AFFINE Normalization!
# So there is no learned parameters for your nomralization layer.
class NonAffineBatchNorm(nn.BatchNorm2d):
    def __init__(self, dim):
        super(NonAffineBatchNorm, self).__init__(dim, affine=False)

In [0]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = SupermaskConv(1, 32, 3, 1, bias=False)
        self.conv2 = SupermaskConv(32, 64, 3, 1, bias=False)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = SupermaskLinear(9216, 128, bias=False)
        self.fc2 = SupermaskLinear(128, 10, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

In [0]:
def train(model, device, train_loader, optimizer, criterion, epoch):
    model.train()
    for batch_idx, (data, target) in tqdm(enumerate(train_loader)):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 1000 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(model, device, criterion, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target)
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [0]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

In [0]:
kwargs = {}
data = '../data'
batch_size = 8
test_batch_size = 8

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(os.path.join(data, 'mnist'), train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(os.path.join(data, 'mnist'), train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=test_batch_size, shuffle=True, **kwargs)

model = Net().to(device)
# NOTE: only pass the parameters where p.requires_grad == True to the optimizer! Important!
optimizer = optim.SGD(
    [p for p in model.parameters() if p.requires_grad],
    lr=0.1,
    momentum=0.9,
    weight_decay=0.0005,
)
criterion = nn.CrossEntropyLoss().to(device)
scheduler = CosineAnnealingLR(optimizer, T_max=14)

In [14]:
for epoch in range(1, 14 + 1):
    train(model, device, train_loader, optimizer, criterion, epoch)
    test(model, device, criterion, test_loader)
    scheduler.step()

16it [00:00, 74.08it/s]



1013it [00:08, 114.78it/s]



2017it [00:17, 123.20it/s]



3015it [00:25, 121.47it/s]



4021it [00:34, 120.04it/s]



5026it [00:42, 123.71it/s]



6021it [00:51, 118.43it/s]



7012it [00:59, 112.87it/s]



7500it [01:03, 117.75it/s]
11it [00:00, 107.42it/s]


Test set: Average loss: 0.0442, Accuracy: 8725/10000 (87%)



1024it [00:08, 121.47it/s]



2021it [00:17, 113.39it/s]



3022it [00:25, 114.67it/s]



4013it [00:34, 113.93it/s]



5013it [00:42, 122.55it/s]



6016it [00:51, 118.23it/s]



7017it [01:00, 116.22it/s]



7500it [01:04, 116.72it/s]
12it [00:00, 119.83it/s]


Test set: Average loss: 0.0432, Accuracy: 8866/10000 (89%)



1020it [00:08, 115.55it/s]



2023it [00:17, 112.77it/s]



3015it [00:26, 114.92it/s]



4023it [00:34, 117.73it/s]



5018it [00:43, 113.48it/s]



6012it [00:51, 111.76it/s]



7016it [01:00, 113.16it/s]



7500it [01:04, 115.88it/s]
11it [00:00, 107.05it/s]


Test set: Average loss: 0.0319, Accuracy: 9245/10000 (92%)



1014it [00:08, 111.51it/s]



2018it [00:17, 117.93it/s]



3017it [00:25, 111.65it/s]



4017it [00:34, 118.74it/s]



5020it [00:43, 124.81it/s]



6023it [00:51, 116.62it/s]



7022it [01:00, 114.24it/s]



7500it [01:04, 116.58it/s]
12it [00:00, 111.08it/s]


Test set: Average loss: 0.0293, Accuracy: 9345/10000 (93%)



1020it [00:08, 110.93it/s]



2013it [00:17, 116.12it/s]



3017it [00:25, 119.11it/s]



4014it [00:34, 115.85it/s]



5017it [00:42, 118.65it/s]



6017it [00:51, 117.81it/s]



7021it [00:59, 124.73it/s]



7500it [01:03, 117.66it/s]
11it [00:00, 108.53it/s]


Test set: Average loss: 0.0312, Accuracy: 9263/10000 (93%)



1018it [00:08, 123.57it/s]



2016it [00:17, 119.36it/s]



3017it [00:25, 116.85it/s]



4014it [00:34, 117.42it/s]



5020it [00:42, 114.79it/s]



6021it [00:51, 115.02it/s]



7021it [00:59, 125.61it/s]



7500it [01:03, 117.43it/s]
13it [00:00, 126.03it/s]


Test set: Average loss: 0.0312, Accuracy: 9256/10000 (93%)



1019it [00:08, 113.63it/s]



2015it [00:17, 116.51it/s]



3021it [00:25, 114.85it/s]



4022it [00:34, 114.05it/s]



5013it [00:42, 114.82it/s]



6017it [00:51, 117.23it/s]



7022it [00:59, 121.68it/s]



7500it [01:03, 117.40it/s]
13it [00:00, 123.30it/s]


Test set: Average loss: 0.0275, Accuracy: 9317/10000 (93%)



1024it [00:08, 116.61it/s]



2017it [00:17, 115.32it/s]



3018it [00:25, 112.38it/s]



4014it [00:34, 117.49it/s]



5018it [00:42, 114.63it/s]



6018it [00:51, 116.44it/s]



7018it [01:00, 116.08it/s]



7500it [01:04, 116.75it/s]
11it [00:00, 108.70it/s]


Test set: Average loss: 0.0316, Accuracy: 9220/10000 (92%)



1013it [00:08, 114.37it/s]



2012it [00:17, 113.40it/s]



3012it [00:25, 119.83it/s]



4020it [00:34, 123.04it/s]



5018it [00:42, 118.47it/s]



6020it [00:50, 115.18it/s]



7023it [00:59, 119.04it/s]



7500it [01:03, 117.84it/s]
11it [00:00, 108.59it/s]


Test set: Average loss: 0.0292, Accuracy: 9292/10000 (93%)



1024it [00:08, 117.25it/s]



2013it [00:17, 115.58it/s]



3020it [00:25, 119.95it/s]



4022it [00:34, 114.07it/s]



5016it [00:42, 121.75it/s]



6023it [00:51, 113.16it/s]



7020it [00:59, 118.40it/s]



7500it [01:03, 117.31it/s]
12it [00:00, 110.80it/s]


Test set: Average loss: 0.0239, Accuracy: 9432/10000 (94%)



1025it [00:08, 120.14it/s]



2012it [00:17, 117.42it/s]



3014it [00:25, 123.26it/s]



4014it [00:34, 124.51it/s]



5016it [00:42, 122.91it/s]



6013it [00:51, 122.46it/s]



7019it [01:00, 112.81it/s]



7500it [01:04, 116.93it/s]
11it [00:00, 102.60it/s]


Test set: Average loss: 0.0169, Accuracy: 9552/10000 (96%)



1016it [00:08, 116.49it/s]



2017it [00:17, 115.08it/s]



3011it [00:25, 114.06it/s]



4015it [00:34, 120.04it/s]



5015it [00:42, 117.77it/s]



6013it [00:51, 120.96it/s]



7023it [00:59, 114.88it/s]



7500it [01:03, 117.25it/s]
11it [00:00, 106.30it/s]


Test set: Average loss: 0.0167, Accuracy: 9593/10000 (96%)



1021it [00:08, 119.43it/s]



2020it [00:17, 126.16it/s]



3014it [00:25, 123.08it/s]



4012it [00:34, 114.47it/s]



5021it [00:42, 122.29it/s]



6019it [00:51, 125.18it/s]



7022it [00:59, 120.03it/s]



7500it [01:03, 117.66it/s]
13it [00:00, 122.62it/s]


Test set: Average loss: 0.0097, Accuracy: 9745/10000 (97%)



1020it [00:08, 118.08it/s]



2021it [00:17, 120.31it/s]



3017it [00:25, 114.05it/s]



4018it [00:34, 122.83it/s]



5021it [00:42, 113.19it/s]



6021it [00:51, 114.63it/s]



7022it [00:59, 115.54it/s]



7500it [01:03, 117.39it/s]



Test set: Average loss: 0.0069, Accuracy: 9811/10000 (98%)

