In [1]:
from zlib import crc32
# The assignment variants are derived deterministically from the author's
# surname: CRC32 of the lowercased UTF-8 bytes, reduced modulo the number of
# variants and shifted by one so that numbering starts at 1.
theory, practice = (
    crc32(surname.lower().encode("utf-8")) % n_variants + 1
    for surname, n_variants in (("Шульгин", 5), ("Shulgin", 3))
)

# Bare tuple as the last expression so the notebook displays both numbers.
theory, practice

(1, 2)

# Task  
Реализовать пример удаления параметров для логистической регрессии на
MNIST и сравнить качество со случайным удалением параметров (ось X — процент удаленных
параметров):  
2) С использованием вариационного вывода (Graves, 2011);

In [2]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader, TensorDataset
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.distributions import MultivariateNormal

import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
sns.set(style="whitegrid", context="talk", font_scale=1.5)

from copy import deepcopy
from tqdm import tqdm

Dataset loading

In [4]:
# Mini-batch size shared by the training and the test loader.
batch_size = 16

# Convert images to tensors, then normalise the single channel with
# mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Both MNIST splits are stored under ./data/ and downloaded when missing.
trainset = datasets.MNIST('./data/', train=True, download=True, transform=transform)
testset = datasets.MNIST('./data/', train=False, download=True, transform=transform)

# Both loaders reshuffle their split on every pass.
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!


Model class

In [4]:
class LogisticRegression(nn.Module):
    """Multiclass logistic regression with a factorised Gaussian over its weights.

    Every weight has a learnable location ``loc`` and a learnable scale ``var``.
    ``forward`` draws one weight sample per call; ``value`` uses the locations
    only and is therefore deterministic.

    Note: despite its name, ``var`` multiplies the standard-normal noise in
    ``forward``, i.e. it is used as a standard deviation, not as a variance.
    The name is kept because it is the key under which the parameter appears
    in ``state_dict()``.

    Args:
        features: number of input features per example (inputs are flattened
            to this width).
        classes: number of output classes.
    """

    def __init__(self, features=784, classes=10):
        super(LogisticRegression, self).__init__()

        self.features = features
        self.classes = classes
        n_weights = features * classes

        # Standard normal noise over all weights, with event shape (n_weights,).
        # A factorised Normal describes the same distribution as a
        # MultivariateNormal with identity precision, but avoids building and
        # factorising a dense n_weights x n_weights matrix.
        self.distr = torch.distributions.Independent(
            torch.distributions.Normal(torch.zeros(n_weights), torch.ones(n_weights)),
            1,
        )
        self.loc = nn.Parameter(torch.randn(n_weights))
        # Initialised non-negative; nothing constrains the sign during training.
        self.var = nn.Parameter(torch.abs(torch.randn(n_weights)))

    def forward(self, x):
        """Return logits of shape (batch, classes) using one sampled weight vector."""
        x = x.view(-1, self.features)
        # Match the noise to the parameters' own device and dtype instead of
        # relying on a module-level `device` variable.
        z = self.distr.sample().to(self.loc)
        w = self.loc + self.var * z
        return torch.mm(x, w.view(self.features, self.classes))

    def value(self, x):
        """Return logits of shape (batch, classes) using the weight locations only."""
        x = x.view(-1, self.features)
        return torch.mm(x, self.loc.view(self.features, self.classes))

In [5]:
# Everything runs on the CPU; the training loop moves each batch to `device`.
device = 'cpu'

model = LogisticRegression()
model = model.to(device)

# Cross-entropy is summed (not averaged) over the examples of a batch.
criterion = nn.CrossEntropyLoss(reduction='sum')
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

Training

In [None]:
# One pass over the training set, minimising cross-entropy plus a KL penalty.
for data, target in tqdm(train_loader):
    data, target = data.to(device), target.to(device)

    optimizer.zero_grad()
    # model(data) is the stochastic forward pass: weights are sampled as
    # loc + var * z, so `var` plays the role of a standard deviation.
    cross_entropy = criterion(model(data), target)

    # KL( N(loc, var^2) || N(0, 1) ) summed over all weights:
    #   0.5 * (sigma^2 + mu^2 - 1 - log sigma^2)
    # Squaring `var` keeps the penalty consistent with how the forward pass
    # uses it, and keeps the logarithm defined if the optimiser pushes `var`
    # below zero (the sign of the scale does not change the sampled weights'
    # distribution).
    # NOTE(review): the full KL term is added on every mini-batch while the
    # cross-entropy covers a single batch — confirm this weighting is intended.
    scale_sq = model.var ** 2
    kl_divergence = 0.5 * torch.sum(scale_sq + model.loc ** 2 - 1.0 - torch.log(scale_sq))
    loss = cross_entropy + kl_divergence

    loss.backward()
    optimizer.step()

 75%|███████▍  | 2807/3750 [01:35<00:25, 37.53it/s]

In [None]:
def prune_params(params, lam):
    """Zero out weights by signal-to-noise threshold and, for comparison, at random.

    Args:
        params: mapping holding the tensors ``'loc'`` and ``'var'``; it is
            deep-copied and never modified.
        lam: threshold; weights with ``|loc / var| < lam`` are pruned.

    Returns:
        A tuple ``(vb_params, rand_params, pruned)``: a copy of ``params`` with
        the low-ratio locations zeroed, a copy with the same number of
        locations zeroed at randomly permuted positions, and the pruned
        fraction as a 0-dim float tensor.
    """
    ratio = (params['loc'] / params['var']).abs()
    drop = ratio < lam
    pruned = drop.sum().float() / len(drop)

    # Threshold-based pruning.
    vb_params = deepcopy(params)
    vb_params['loc'][drop] = 0

    # Random baseline: permute the mask so exactly as many weights are dropped.
    shuffled = drop.view(-1)[torch.randperm(drop.nelement())].view(drop.size())
    rand_params = deepcopy(params)
    rand_params['loc'][shuffled] = 0

    return vb_params, rand_params, pruned

def calculate_accuracy(loader, model):
    """Return the classification accuracy of ``model.value`` over ``loader``.

    Args:
        loader: iterable of ``(data, target)`` batches.
        model: module exposing ``value(data)`` that returns per-class scores;
            the predicted class is the arg-max over the last dimension.

    Returns:
        Fraction of correctly classified examples as a 0-dim float tensor.
    """
    # Evaluate on whatever device the model lives on rather than on a
    # module-level `device`; a model without parameters leaves batches in place.
    first_param = next(iter(model.parameters()), None)
    correct = 0
    total = 0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        # Iterate over the loader that was passed in, not a global one.
        for data, target in loader:
            if first_param is not None:
                data, target = data.to(first_param.device), target.to(first_param.device)
            output = torch.argmax(model.value(data), dim=-1)
            correct += (output == target).float().sum()
            total += len(output)

    return correct / total

Pruning

In [None]:
# Snapshot of the trained weights; every pruning level starts from this copy.
source_params = deepcopy(model.state_dict())

# 14 pruning thresholds growing exponentially from 0 to (e^15 - 1) / 1e4.
x = (np.exp(np.linspace(0, 15, 14)) - 1) / 1e4
random_accs = []
vb_accs = []
pruned = []

model.eval()
for lam in tqdm(x):
    pruned_params, rand_params, deleted = prune_params(source_params, lam)
    # Store plain floats so the results can be plotted without tensor handling.
    pruned.append(float(deleted))

    # Baseline: the same number of weights removed at random positions.
    model.load_state_dict(rand_params)
    random_accs.append(float(calculate_accuracy(test_loader, model)))

    # Weights removed by the |loc / var| criterion.
    model.load_state_dict(pruned_params)
    vb_accs.append(float(calculate_accuracy(test_loader, model)))

# Put the trained weights back so the model is not left in its most heavily
# pruned state and re-running this cell snapshots the same parameters again.
model.load_state_dict(source_params)

In [None]:
# `pruned` holds fractions in [0, 1]; the axis is labelled in percent, so scale
# the values to match. float() also accepts 0-dim tensors.
pruned_percent = [100.0 * float(share) for share in pruned]

fig, ax = plt.subplots(figsize=(10, 8))
ax.plot(pruned_percent, [float(acc) for acc in vb_accs], label='Variational Bayes')
ax.plot(pruned_percent, [float(acc) for acc in random_accs], label='Random pruning')
ax.set_xlabel("% of pruned params")
ax.set_ylabel("Accuracy")
ax.set_title("Random vs VB")
ax.legend()
plt.show()

Как видно из графика, удаление параметров на основе вариационного вывода даёт гораздо лучшие результаты, чем случайное удаление: средняя точность падает весьма слабо даже при удалении ~70% параметров.