In [None]:
# all required modules
import copy
import itertools

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.utils.prune as prune
from torchvision import datasets, transforms, utils

# Part 1. About the dataset
CIFAR-10 is a dataset of 60,000 32x32 color images (3 color channels: red, green, blue) labeled over 10 categories, split into 50,000 training images and 10,000 test images. The dataset is divided into five training batches and one test batch, each with 10,000 images. The test batch contains exactly 1,000 randomly selected images from each class. The training batches contain the remaining images in random order, so an individual training batch may contain more images from one class than another. Between them, the training batches contain exactly 5,000 images from each class.

In [None]:
# Evaluation pipeline: ToTensor scales pixels to [0, 1], then
# Normalize(mean=0.5, std=0.5) maps every channel to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Extra augmentation for the training data only (random reflect-padded crop and
# random horizontal flip), in order to achieve better generalization.
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    transforms.RandomCrop(32, padding=2, padding_mode='reflect'),
    transforms.RandomHorizontalFlip(),
])

batch_size = 16

trainset = datasets.CIFAR10(root='./data', train=True,
                            download=True, transform=train_transform)
# Reshuffle the training set every epoch; with shuffle=False the optimizer would
# see the mini-batches in the same fixed order on every epoch.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

# The test set is only used for validation below, so a fixed order is fine.
testset = datasets.CIFAR10(root='./data', train=False,
                           download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# functions to show an image


def imshow(img):
    """Show a normalized CHW image tensor (e.g. a ``make_grid`` result) with matplotlib.

    The tensor is mapped back from the Normalize(0.5, 0.5) range via
    ``x / 2 + 0.5``, and its channel axis is moved last because
    ``plt.imshow`` expects height x width x channels.
    """
    unnormalized = img / 2 + 0.5
    hwc_array = np.transpose(unnormalized.numpy(), (1, 2, 0))
    plt.imshow(hwc_array)
    plt.show()


# Display one batch from the training loader as a single image grid.
dataiter = iter(trainloader)
images, labels = next(dataiter)
grid = utils.make_grid(images)
imshow(grid)

In [None]:
from models import ResNet, train_model
from models import device

# Adam settings shared by the dense baseline (values taken as given).
adam_settings = {"betas": (0.851436, 0.999689), "amsgrad": True, "lr": 8e-5}

base_model = ResNet().to(device=device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(base_model.parameters(), **adam_settings)
loaders = dict(train=trainloader, valid=testloader)

### Weight regularization pruning

In [None]:
from weight_regularization import l1_s_norm

# Train the dense baseline with the l1_s_norm penalty; only the elapsed
# time t_1 is kept (it is added to every pruned model's fine-tuning time).
_, t_1 = train_model(
    base_model,
    optimizer,
    criterion,
    loaders,
    N_epochs=10,
    reg=l1_s_norm,
    display=True,
)

In [None]:
import torch.nn.utils.prune as prune
import copy
import numpy as np

# Fractions of the listed weights to remove (95 % ... 99.9 %).
amounts = 1e-2 * np.array([95, 97, 99, 99.5, 99.9])

train_time = []
acc = []

for amount in amounts:

    # Prune a fresh copy so every amount starts from the same trained baseline.
    model = copy.deepcopy(base_model)

    optimizer = torch.optim.Adam(model.parameters(), betas = (0.851436, 0.999689), amsgrad=True, lr = 8e-5)
    criterion = torch.nn.CrossEntropyLoss()

    # (module, parameter name) pairs that are ranked jointly by |w| when pruning globally.
    parameters = (
        (model.conv1[0], 'weight'),
        (model.conv2[0], 'weight'),
        (model.res1[0][0], 'weight'),
        (model.res1[1][0], 'weight'),
        (model.conv3[0], 'weight'),
        (model.conv4[0], 'weight'),
        (model.res2[0][0], 'weight'),
        (model.res2[1][0], 'weight'),
        (model.classifier[2], 'weight'),
    )

    # Zero out the `amount` fraction of smallest-magnitude weights across all listed layers.
    prune.global_unstructured(
        parameters,
        pruning_method=prune.L1Unstructured,
        amount=amount,
    )

    # Fine-tune the pruned copy; t_2 is reported as seconds in the print below.
    accuracy, t_2 = train_model(model, optimizer, criterion, loaders, N_epochs=30)

    print('Amount : {} %, Valid accuracy : {:.5f}, time {:.1f} s'.format(
        amount * 100, accuracy, t_1+t_2) )

    acc.append(accuracy)
    train_time.append(t_1 + t_2)

### Optimal brain damage

In [None]:
# Fresh dense baseline for Optimal Brain Damage. Unlike the weight-regularization
# run, no `reg` argument is passed to train_model here.
base_model = ResNet().to(device=device)
criterion = torch.nn.CrossEntropyLoss()
obd_adam_settings = {"betas": (0.851436, 0.999689), "amsgrad": True, "lr": 8e-5}
optimizer = torch.optim.Adam(base_model.parameters(), **obd_adam_settings)
loaders = dict(train=trainloader, valid=testloader)

# Only the training time is kept; it is added to each pruned model's time later.
_, t_1 = train_model(base_model, optimizer, criterion, loaders, N_epochs = 10)

In [None]:
# Sum of the training loss over the first N batches; its second derivatives
# w.r.t. the weights give the OBD saliencies computed below.
N = 10  # number of training batches used for the saliency estimate

# Build the graph in eval mode so that, if the model has BatchNorm/Dropout layers,
# running statistics are not updated (base_model is deep-copied for pruning later)
# and the loss is not randomized by dropout. The previous mode is restored afterwards.
was_training = base_model.training
base_model.eval()

loss = 0
for x_batch, y_batch in itertools.islice(trainloader, N):
    x_batch = x_batch.to(device)
    y_batch = y_batch.to(device)

    outp = base_model(x_batch)
    loss = loss + criterion(outp, y_batch)

base_model.train(was_training)

In [None]:
def pth_jacobian(y, x, create_graph = False):
    """Jacobian of tensor ``y`` with respect to tensor ``x`` via one reverse pass per element of ``y``.

    Parameters
    ----------
    y : torch.Tensor
        Output tensor; every element must depend on ``x`` through the autograd graph.
    x : torch.Tensor
        Input tensor with ``requires_grad=True``.
    create_graph : bool, default False
        If True, the returned Jacobian is itself differentiable, so it can be
        differentiated again (used below to build Hessian diagonals).

    Returns
    -------
    torch.Tensor
        Tensor of shape ``y.shape + x.shape`` with entry ``[i..., j...]`` equal
        to ``d y[i...] / d x[j...]``.
    """
    flat_y = y.reshape(-1)
    if flat_y.numel() == 0:
        # Nothing to differentiate; torch.stack would fail on an empty list.
        return torch.zeros(y.shape + x.shape, dtype=x.dtype, device=x.device)

    rows = []
    for i in range(flat_y.numel()):
        # Use a fresh one-hot seed for every output element. Mutating a shared seed in place
        # after the call can invalidate tensors saved in the graph built with
        # create_graph=True, which breaks later differentiation of the result.
        seed = torch.zeros_like(flat_y)
        seed[i] = 1.
        grad_x, = torch.autograd.grad(flat_y, x, seed, retain_graph=True, create_graph=create_graph)
        rows.append(grad_x.reshape(x.shape))

    return torch.stack(rows, dim=0).reshape(y.shape + x.shape)

In [None]:
# Layers of the dense baseline whose weights are scored, keyed by a readable name.
parameters = {
    'conv1' : base_model.conv1[0],
    'conv2' : base_model.conv2[0],
    'res1_conv1' : base_model.res1[0][0],
    'res1_conv2' : base_model.res1[1][0],
    'conv3' : base_model.conv3[0],
    'conv4' : base_model.conv4[0],
    'res2_conv1' : base_model.res2[0][0],
    'res2_conv2' : base_model.res2[1][0],
    'fc' : base_model.classifier[2]
}

h_ii = {}  # name -> diagonal of the Hessian of `loss` w.r.t. the layer weight (weight-shaped)
s = {}     # name -> OBD saliency 0.5 * h_kk * w_k**2 (weight-shaped, detached)

# Calculate diagonal elements of the Hessian, one backward pass per weight.
for name, module in parameters.items():
    weights = module.weight
    # First derivative, kept differentiable so it can be differentiated again.
    grad = pth_jacobian(loss, weights, create_graph=True)

    flat_grad = grad.reshape(-1)
    diag = torch.zeros_like(flat_grad)
    for k in range(flat_grad.numel()):
        # Row k of the Hessian; only its diagonal entry k is kept.
        hess_row, = torch.autograd.grad(flat_grad[k], weights, retain_graph=True)
        diag[k] = hess_row.reshape(-1)[k].item()

    h_ii[name] = diag.reshape(weights.shape)
    # Store the saliency (previously computed and discarded, leaving `s` empty);
    # detached so later masking does not hold onto the autograd graph.
    with torch.no_grad():
        s[name] = 1/2 * h_ii[name] * weights ** 2

In [None]:
# Total number of parameters of the dense OBD baseline. This previously read `model`,
# a leftover from the magnitude-pruning loop, which is hidden state.
total_params = sum(p.numel() for p in base_model.parameters())
print(total_params)

In [None]:
# plot amount over threshold plot
import numpy as np

# Candidate saliency thresholds, log-spaced between 1e-6 and 10**-2.6.
s_t_list = np.logspace(-6, -2.6, 15)

# For each threshold, count the weights the pruning mask (`s[name] >= s_th`) would keep.
kept_counts = [
    sum(torch.sum(s_ii >= s_th).item() for s_ii in s.values())
    for s_th in s_t_list
]
# Same counts as a fraction of all base_model parameters.
amount = [count / total_params for count in kept_counts]

fig, ax = plt.subplots()
ax.plot(s_t_list, kept_counts, marker='o')
ax.set_xscale('log')
ax.set_xlabel('Importance threshold')
ax.set_ylabel('Number of parameters')
ax.set_title('Weights with OBD saliency at or above threshold')
fig.savefig('./data/threshold.png', dpi = 140)

In [None]:
import torch.nn.utils.prune as prune
import copy

train_time = []
acc = []

for s_th in s_t_list:

    # Prune a fresh copy; base_model itself must stay dense for the next threshold.
    model = copy.deepcopy(base_model)

    # NOTE(review): these optimizer settings differ from the magnitude-pruning sweep
    # (betas/amsgrad/lr=8e-5); confirm this is intended before comparing the two methods.
    optimizer = torch.optim.Adam(model.parameters(), lr = 5e-3)
    criterion = torch.nn.CrossEntropyLoss()

    # Layers of the *copy* to prune. These previously pointed at base_model, so the
    # trained model was never pruned and masks piled up on the baseline.
    # Keys must match those of the saliency dict `s`.
    parameters = {
      'conv1' : model.conv1[0],
      'conv2' : model.conv2[0],
      'res1_conv1' : model.res1[0][0],
      'res1_conv2' : model.res1[1][0],
      'conv3' : model.conv3[0],
      'conv4' : model.conv4[0],
      'res2_conv1' : model.res2[0][0],
      'res2_conv2' : model.res2[1][0],
      'fc' : model.classifier[2]
    }

    for name, module in parameters.items():
        # Keep only weights whose saliency reaches the threshold.
        mask = s[name] >= s_th
        prune.custom_from_mask(module, 'weight', mask)

    accuracy, t_2 = train_model(model, optimizer, criterion, loaders, N_epochs=30)

    print('s_th : {}, Valid accuracy : {:.5f}, time {:.1f} s'.format(
        s_th, accuracy, t_1+t_2) )

    acc.append(accuracy)
    train_time.append(t_1 + t_2)

print(acc)