In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from IPython import embed
from collections import OrderedDict
import time
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.utils.prune as prune

In [6]:
def get10(batch_size, data_root='/tmp/public_dataset/pytorch', train=True, val=True, **kwargs):
    """Create CIFAR-10 data loaders, downloading the dataset under ``data_root`` if needed.

    Args:
        batch_size: mini-batch size used by every loader that is built.
        data_root: parent directory; the data lives in ``<data_root>/cifar10-data``
            (``~`` is expanded).
        train: if true, build the shuffled training loader with pad-4 random-crop
            and horizontal-flip augmentation.
        val: if true, build the un-shuffled, un-augmented test loader.
        **kwargs: forwarded to ``DataLoader``. ``num_workers`` defaults to 1;
            ``input_size`` is accepted for call-site compatibility but discarded.

    Returns:
        The loader itself when exactly one was requested, otherwise the list
        ``[train_loader, test_loader]`` (an empty list when neither was).
    """
    root = os.path.expanduser(os.path.join(data_root, 'cifar10-data'))
    workers = kwargs.setdefault('num_workers', 1)
    kwargs.pop('input_size', None)
    print("Building CIFAR-10 data loader with {} workers".format(workers))

    loaders = []
    if train:
        # Normalize(0.5, 0.5) maps ToTensor's [0, 1] range to [-1, 1].
        train_tf = transforms.Compose([
            transforms.Pad(4),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        train_set = datasets.CIFAR10(root=root, train=True, download=True, transform=train_tf)
        loaders.append(torch.utils.data.DataLoader(
            train_set, batch_size=batch_size, shuffle=True, **kwargs))
    if val:
        test_tf = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        test_set = datasets.CIFAR10(root=root, train=False, download=True, transform=test_tf)
        loaders.append(torch.utils.data.DataLoader(
            test_set, batch_size=batch_size, shuffle=False, **kwargs))

    if len(loaders) == 1:
        return loaders[0]
    return loaders

In [7]:
# Pretrained checkpoint URLs, keyed by dataset name.
# NOTE(review): these are plain-HTTP URLs, so downloads are not integrity-protected
# in transit; the checkpoints are unpickled when loaded.
model_urls = dict(
    cifar10='http://ml.cs.tsinghua.edu.cn/~chenxi/pytorch-models/cifar10-d875770b.pth',
    cifar100='http://ml.cs.tsinghua.edu.cn/~chenxi/pytorch-models/cifar100-3a55a987.pth',
)

class CIFAR(nn.Module):
    """Convolutional feature stack followed by a single linear classifier.

    Args:
        features: ``nn.Sequential`` feature extractor (anything else fails the assert).
        n_channel: length of the flattened feature vector fed to the classifier.
        num_classes: number of output logits.

    Both sub-networks are printed on construction.
    """

    def __init__(self, features, n_channel, num_classes):
        super().__init__()
        assert isinstance(features, nn.Sequential), type(features)
        self.features = features
        # Kept as a one-element Sequential so its parameters are named
        # ``classifier.0.*`` — presumably what the pretrained checkpoints expect (TODO confirm).
        self.classifier = nn.Sequential(nn.Linear(n_channel, num_classes))
        for part in (self.features, self.classifier):
            print(part)

    def forward(self, x):
        # Flatten everything but the batch dimension before the linear layer.
        flat = self.features(x).view(x.size(0), -1)
        return self.classifier(flat)

def make_layers(cfg, batch_norm=False):
    """Translate a VGG-style config list into an ``nn.Sequential`` feature extractor.

    Each ``cfg`` entry is one of:
      * ``'M'``     -> 2x2 max-pool with stride 2;
      * ``c``       -> 3x3 conv to ``c`` channels, padding 1;
      * ``(c, p)``  -> 3x3 conv to ``c`` channels, padding ``p``.
    Every conv is followed by ReLU, with a non-affine ``BatchNorm2d`` in between
    when ``batch_norm`` is true. The first conv expects 3 input channels.
    """
    modules = []
    channels_in = 3
    for spec in cfg:
        if spec == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        if isinstance(spec, tuple):
            channels_out, pad = spec[0], spec[1]
        else:
            channels_out, pad = spec, 1
        stage = [nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=pad)]
        if batch_norm:
            stage.append(nn.BatchNorm2d(channels_out, affine=False))
        stage.append(nn.ReLU())
        modules.extend(stage)
        channels_in = channels_out
    return nn.Sequential(*modules)

def cifar10(n_channel, pretrained=None):
    """Build the VGG-style CIFAR-10 network, optionally with pretrained weights.

    Args:
        n_channel: base width; the conv stages use n, 2n, 4n and 8n channels.
        pretrained: if truthy, download ``model_urls['cifar10']`` and load it.
            ``None`` or ``False`` leave the weights randomly initialised.

    Returns:
        A ``CIFAR`` module with 10 output logits.
    """
    # For 32x32 inputs: three pools take 32 -> 4, the padding-0 conv takes 4 -> 2 and the
    # last pool 2 -> 1, so the classifier sees exactly 8 * n_channel features.
    cfg = [n_channel, n_channel, 'M',
           2 * n_channel, 2 * n_channel, 'M',
           4 * n_channel, 4 * n_channel, 'M',
           (8 * n_channel, 0), 'M']
    layers = make_layers(cfg, batch_norm=True)
    model = CIFAR(layers, n_channel=8 * n_channel, num_classes=10)
    # Truthiness, not `is not None`: pretrained=False must not download weights.
    if pretrained:
        # NOTE(review): load_url unpickles the file fetched over plain HTTP — only use a
        # trusted source. map_location='cpu' lets a checkpoint saved with GPU tensors load
        # on CPU-only machines; callers move the model to the GPU afterwards if wanted.
        checkpoint = model_zoo.load_url(model_urls['cifar10'], map_location='cpu')
        # The checkpoint may be a pickled whole module or a bare state dict.
        state_dict = checkpoint.state_dict() if isinstance(checkpoint, nn.Module) else checkpoint
        assert isinstance(state_dict, (dict, OrderedDict)), type(state_dict)
        model.load_state_dict(state_dict)
    return model

In [21]:
# Seed torch's RNG so shuffling and random-crop/flip augmentation during fine-tuning
# are reproducible across Restart & Run All.
torch.manual_seed(0)
train_loader, test_loader = get10(batch_size=200, num_workers=1)
# pretrained=True downloads the published CIFAR-10 weights; pass pretrained=None for random init.
model = cifar10(128, pretrained=True)
use_cuda = torch.cuda.is_available()
if use_cuda:
    model.cuda()

Building CIFAR-10 data loader with 1 workers
Files already downloaded and verified
Files already downloaded and verified
Sequential(
  (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
  (2): ReLU()
  (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
  (5): ReLU()
  (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
  (9): ReLU()
  (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
  (12): ReLU()
  (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_m

In [10]:
# Logging / evaluation cadence read by `train`. Learning rate, weight decay, epoch count
# and the LR-decay schedule are set through `train`'s own keyword arguments instead.
log_interval = 100  # print a training-progress line every N batches
test_interval = 5   # evaluate on the test set every N epochs

In [29]:
def prune_model(model, conv_prune=0.3, lin_prune=0.6):
    """Apply L1-magnitude unstructured pruning to every Conv2d and Linear weight, in place.

    Args:
        model: network whose ``Conv2d`` / ``Linear`` submodules are pruned.
        conv_prune: fraction of weights to prune in each ``Conv2d``.
        lin_prune: fraction of weights to prune in each ``Linear``.

    Each pruned module gains a ``weight_mask`` buffer; all buffer names of the model
    are printed afterwards so the masks can be checked. Calling this again stacks a new
    mask on the existing one (per the ``torch.nn.utils.prune`` docs, the new fraction is
    taken from the still-unpruned weights).
    """
    for module in model.modules():
        if isinstance(module, torch.nn.Conv2d):
            amount = conv_prune
        elif isinstance(module, torch.nn.Linear):
            amount = lin_prune
        else:
            continue
        prune.l1_unstructured(module, name='weight', amount=amount)

    buffer_names = dict(model.named_buffers()).keys()
    print(buffer_names)  # masks appear as '<module>.weight_mask'

In [33]:
def _train_eval_pass(model, loader, device):
    """Run one gradient-free pass over ``loader``, print the summary line, return accuracy in %."""
    model.eval()
    total_loss, correct, seen, n_batches = 0.0, 0, 0, 0
    with torch.no_grad():  # inference only: don't build an autograd graph
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            total_loss += F.cross_entropy(output, target).item()
            correct += output.argmax(1).eq(target).sum().item()
            seen += target.size(0)
            n_batches += 1
    if seen == 0:
        raise ValueError('test loader yielded no samples')
    test_loss = total_loss / n_batches  # mean of per-batch mean losses
    acc = 100.0 * correct / seen
    print('\tTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, seen, acc))
    return acc


def train(model, epochs=150, lr=0.001, decreasing_lr='80,120', wd=0,
          train_data=None, test_data=None, log_every=None, test_every=None):
    """Fine-tune ``model`` with Adam and a step learning-rate schedule.

    Args:
        model: network trained in place; batches are moved to the device of its
            first parameter.
        epochs: number of passes over the training data.
        lr: initial Adam learning rate.
        decreasing_lr: comma-separated epoch indices at which the learning rate is
            multiplied by 0.1 (e.g. ``'80,120'``); ``''`` disables decay.
        wd: Adam weight decay.
        train_data: training ``DataLoader``; defaults to the global ``train_loader``.
        test_data: test ``DataLoader``; defaults to the global ``test_loader``.
        log_every: print a training line every N batches; defaults to the global
            ``log_interval``.
        test_every: evaluate every N epochs; defaults to the global ``test_interval``.
            The final epoch is always evaluated.

    Returns:
        Best test accuracy (percent) over the evaluated epochs, 0.0 if none ran.
    """
    # Fall back to notebook-level state only when the caller didn't pass it explicitly.
    if train_data is None:
        train_data = train_loader
    if test_data is None:
        test_data = test_loader
    if log_every is None:
        log_every = log_interval
    if test_every is None:
        test_every = test_interval

    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    # Epochs at which the LR is cut 10x; blank entries are ignored.
    decay_epochs = {int(tok) for tok in decreasing_lr.split(',') if tok.strip()}

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    best_acc = 0.0
    t_begin = time.time()
    for epoch in range(epochs):
        model.train()  # re-enable after the previous epoch's eval pass
        if epoch in decay_epochs:
            optimizer.param_groups[0]['lr'] *= 0.1
        n_batches = 0
        for batch_idx, (data, target) in enumerate(train_data):
            n_batches += 1
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()

            if batch_idx % log_every == 0 and batch_idx > 0:
                batch_acc = output.argmax(1).eq(target).sum().item() / len(data)
                print('Train Epoch: {} [{}/{}] Loss: {:.6f} Acc: {:.4f} lr: {:.2e}'.format(
                    epoch, batch_idx * len(data), len(train_data.dataset),
                    loss.item(), batch_acc, optimizer.param_groups[0]['lr']))

        elapse_time = time.time() - t_begin
        speed_epoch = elapse_time / (epoch + 1)
        speed_batch = speed_epoch / max(n_batches, 1)
        eta = speed_epoch * epochs - elapse_time
        print("Elapsed {:.2f}s, {:.2f} s/epoch, {:.2f} s/batch, eta {:.2f}s".format(
            elapse_time, speed_epoch, speed_batch, eta))

        # Also test after the last epoch so the result reflects the final weights.
        if epoch % test_every == 0 or epoch == epochs - 1:
            best_acc = max(best_acc, _train_eval_pass(model, test_data, device))

    print("Total Elapse: {:.2f}, Best Result: {:.3f}%".format(time.time() - t_begin, best_acc))
    return best_acc

In [26]:
def evaluate(model, loader=None):
    """Print and return the average loss / top-1 accuracy of ``model`` on a test set.

    Args:
        model: classifier returning logits; batches are moved to the device of its
            first parameter, and the model is left in eval mode.
        loader: iterable of ``(data, target)`` batches; defaults to the global
            ``test_loader``.

    Returns:
        Top-1 accuracy in percent, as a float.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if loader is None:
        loader = test_loader
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    total_loss, correct, seen, n_batches = 0.0, 0, 0, 0
    # Inference only: skip building the autograd graph (saves memory and time).
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            total_loss += F.cross_entropy(output, target).item()
            correct += output.argmax(1).eq(target).sum().item()
            seen += target.size(0)
            n_batches += 1
    if seen == 0:
        raise ValueError('evaluate() got a loader with no samples')

    test_loss = total_loss / n_batches  # average over number of mini-batches
    acc = 100.0 * correct / seen
    print('\tTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, seen, acc))
    return acc

In [27]:
# Baseline accuracy of the pretrained model before any pruning.
evaluate(model=model);

	Test set: Average loss: 0.3789, Accuracy: 9379/10000 (94%)


In [30]:
# First pruning round: L1-prune 30% of each conv weight and 60% of the linear weight.
prune_model(model, conv_prune=0.3, lin_prune=0.6)

dict_keys(['features.0.weight_mask', 'features.1.running_mean', 'features.1.running_var', 'features.1.num_batches_tracked', 'features.3.weight_mask', 'features.4.running_mean', 'features.4.running_var', 'features.4.num_batches_tracked', 'features.7.weight_mask', 'features.8.running_mean', 'features.8.running_var', 'features.8.num_batches_tracked', 'features.10.weight_mask', 'features.11.running_mean', 'features.11.running_var', 'features.11.num_batches_tracked', 'features.14.weight_mask', 'features.15.running_mean', 'features.15.running_var', 'features.15.num_batches_tracked', 'features.17.weight_mask', 'features.18.running_mean', 'features.18.running_var', 'features.18.num_batches_tracked', 'features.21.weight_mask', 'features.22.running_mean', 'features.22.running_var', 'features.22.num_batches_tracked', 'classifier.0.weight_mask'])


In [31]:
# Accuracy right after the first pruning round, before any fine-tuning.
evaluate(model=model);

	Test set: Average loss: 0.4045, Accuracy: 9172/10000 (92%)


In [34]:
# Fine-tune the pruned network for 10 epochs at a small learning rate.
train(model, epochs=10, lr=1e-5);

Train Epoch: 0 [20000/50000] Loss: 0.003458 Acc: 1.0000 lr: 1.00e-05
Train Epoch: 0 [40000/50000] Loss: 0.002431 Acc: 1.0000 lr: 1.00e-05
Elapsed 32.81s, 32.81 s/epoch, 0.13 s/batch, ets 295.32s
	Test set: Average loss: 0.3505, Accuracy: 9301/10000 (93%)
Train Epoch: 1 [20000/50000] Loss: 0.001057 Acc: 1.0000 lr: 1.00e-05
Train Epoch: 1 [40000/50000] Loss: 0.001426 Acc: 1.0000 lr: 1.00e-05
Elapsed 71.25s, 35.63 s/epoch, 0.14 s/batch, ets 285.02s
Train Epoch: 2 [20000/50000] Loss: 0.000681 Acc: 1.0000 lr: 1.00e-05
Train Epoch: 2 [40000/50000] Loss: 0.001516 Acc: 1.0000 lr: 1.00e-05
Elapsed 104.33s, 34.78 s/epoch, 0.14 s/batch, ets 243.44s
Train Epoch: 3 [20000/50000] Loss: 0.001103 Acc: 1.0000 lr: 1.00e-05
Train Epoch: 3 [40000/50000] Loss: 0.001624 Acc: 1.0000 lr: 1.00e-05
Elapsed 137.42s, 34.36 s/epoch, 0.14 s/batch, ets 206.13s
Train Epoch: 4 [20000/50000] Loss: 0.002847 Acc: 1.0000 lr: 1.00e-05
Train Epoch: 4 [40000/50000] Loss: 0.000381 Acc: 1.0000 lr: 1.00e-05
Elapsed 170.86s, 34.

In [35]:
# Accuracy after the first fine-tuning round.
evaluate(model=model);

	Test set: Average loss: 0.3371, Accuracy: 9341/10000 (93%)


In [36]:
# Second pruning round (20% conv / 30% linear), applied on top of the existing masks.
prune_model(model, conv_prune=0.2, lin_prune=0.3)

dict_keys(['features.0.weight_mask', 'features.1.running_mean', 'features.1.running_var', 'features.1.num_batches_tracked', 'features.3.weight_mask', 'features.4.running_mean', 'features.4.running_var', 'features.4.num_batches_tracked', 'features.7.weight_mask', 'features.8.running_mean', 'features.8.running_var', 'features.8.num_batches_tracked', 'features.10.weight_mask', 'features.11.running_mean', 'features.11.running_var', 'features.11.num_batches_tracked', 'features.14.weight_mask', 'features.15.running_mean', 'features.15.running_var', 'features.15.num_batches_tracked', 'features.17.weight_mask', 'features.18.running_mean', 'features.18.running_var', 'features.18.num_batches_tracked', 'features.21.weight_mask', 'features.22.running_mean', 'features.22.running_var', 'features.22.num_batches_tracked', 'classifier.0.weight_mask'])


In [37]:
# Accuracy right after the second pruning round.
evaluate(model=model);

	Test set: Average loss: 0.3884, Accuracy: 9070/10000 (91%)


In [38]:
# Fine-tune again after the second pruning round.
train(model, epochs=10, lr=1e-5);

Train Epoch: 0 [20000/50000] Loss: 0.003664 Acc: 1.0000 lr: 1.00e-05
Train Epoch: 0 [40000/50000] Loss: 0.005356 Acc: 1.0000 lr: 1.00e-05
Elapsed 33.06s, 33.06 s/epoch, 0.13 s/batch, ets 297.51s
	Test set: Average loss: 0.3214, Accuracy: 9282/10000 (93%)
Train Epoch: 1 [20000/50000] Loss: 0.007936 Acc: 0.9950 lr: 1.00e-05
Train Epoch: 1 [40000/50000] Loss: 0.001997 Acc: 1.0000 lr: 1.00e-05
Elapsed 73.73s, 36.87 s/epoch, 0.15 s/batch, ets 294.94s
Train Epoch: 2 [20000/50000] Loss: 0.004576 Acc: 1.0000 lr: 1.00e-05
Train Epoch: 2 [40000/50000] Loss: 0.005964 Acc: 0.9950 lr: 1.00e-05
Elapsed 108.01s, 36.00 s/epoch, 0.14 s/batch, ets 252.03s
Train Epoch: 3 [20000/50000] Loss: 0.002531 Acc: 1.0000 lr: 1.00e-05
Train Epoch: 3 [40000/50000] Loss: 0.002713 Acc: 1.0000 lr: 1.00e-05
Elapsed 142.42s, 35.61 s/epoch, 0.14 s/batch, ets 213.63s
Train Epoch: 4 [20000/50000] Loss: 0.006665 Acc: 0.9950 lr: 1.00e-05
Train Epoch: 4 [40000/50000] Loss: 0.002761 Acc: 1.0000 lr: 1.00e-05
Elapsed 176.80s, 35.