In [1]:
import json
import os
import pickle

import matplotlib.pylab as plt
import numpy as np
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import tqdm
from torch.nn.utils import clip_grad_value_
from torchvision import datasets, transforms

%matplotlib inline

In [2]:
# Import the project-local modules from the `src` package, then reload each of
# them with importlib.reload so the module code is re-executed in this kernel
# session (presumably to pick up edits to the source files without a restart).
import importlib 
import src.var_net as var_net
import src.linear_var_hypernet as linear_var_hypernet
import src.lowrank_var_hypernet as lowrank_var_hypernet
import src.utils as utils
# NOTE(review): the reload order is kept as written; if the hypernet modules
# import var_net, reloading var_net first is what makes them see the new code.
importlib.reload(var_net)
importlib.reload(linear_var_hypernet)
importlib.reload(lowrank_var_hypernet)
# Last expression of the cell: its return value (the reloaded module) is what
# the notebook shows as the cell output.
importlib.reload(utils)



<module 'src.utils' from '/home/legin/reps/VarHyperNet/code/src/utils.py'>

In [3]:
# Compute device used everywhere below: 'cuda' / 'cuda:N' or 'cpu' / 'cpu:N'.
device = t.device('cpu:0')
# Test the device *type* rather than comparing the torch.device object with a
# string: a string comparison is not a reliable device check and would also
# miss indexed devices such as 'cuda:0'.
if device.type == 'cuda':
    # Trade cuDNN speed for run-to-run reproducibility.
    t.backends.cudnn.deterministic = True
    t.backends.cudnn.benchmark = False

In [4]:
# Experiment configuration: everything a reader might want to tune.
batch_size = 64
init_log_sigma = -3.0  # log of the variational distribution's variance at initialisation
prior_sigma = 0.1  # prior variance
epoch_num = 5  # number of training epochs
lamb = [0.1, 1,  10, 100]  # lambda values: weights of the KL term used for training and evaluation
hidden_num = 50  # number of neurons in the hidden layer
acc_delete = []
start_num = 1  # number of independent restarts per configuration


lowrank_hidden_num = 50  # number of neurons in the hidden layer of the low-rank layer
lambda_sample_num = 5  # number of lambda values sampled per batch when training a hypernet
path_to_save = 'saved_mnist'  # directory for checkpoints and result files

# Create the output directory if needed; exist_ok avoids the separate
# exists-then-create check (and the race between the two calls).
os.makedirs(path_to_save, exist_ok=True)

In [None]:
# Load the data: MNIST train and test splits, downloaded into ./files/ if missing.
# Each image is converted to a tensor, normalised with mean 0.5 / std 0.5 and
# flattened into a 1-D vector. Both splits share the same pipeline, so it is
# built once instead of being spelled out twice.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
    torchvision.transforms.Lambda(lambda x: x.view(-1)),
])

train_data = torchvision.datasets.MNIST('./files/', train=True, download=True,
                                        transform=mnist_transform)

test_data = torchvision.datasets.MNIST('./files/', train=False, download=True,
                                       transform=mnist_transform)


# Pinned host memory only helps host-to-GPU copies, so it is requested only when
# the configured device is a CUDA device (it was unconditionally on before).
# NOTE(review): the training loader iterates in a fixed order (no shuffle=True);
# left unchanged so existing runs stay comparable - confirm this is intended.
train_loader = t.utils.data.DataLoader(train_data, batch_size=batch_size,
                                       pin_memory=(device.type == 'cuda'))
test_loader = t.utils.data.DataLoader(test_data, batch_size=batch_size)

def test_acc(net):
    """Return the classification accuracy of ``net`` on the MNIST test set.

    The network is switched to eval mode and is left in eval mode afterwards.
    Batches come from the module-level ``test_loader`` and are moved to the
    module-level ``device`` before the forward pass.

    Args:
        net: model called as ``net(x)`` that returns per-class scores with the
            class dimension at index 1.

    Returns:
        float: number of correctly classified test samples divided by
        ``len(test_data)``.
    """
    net.eval()
    correct = 0
    # Evaluation never back-propagates, so skip building the autograd graph.
    with t.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)
            out = net(x)
            correct += out.argmax(1).eq(y).sum().item()
    return correct / len(test_data)




In [None]:
def train_batches(net, loss_fn, optimizer, lam, label):
    """Train ``net`` for one pass over the module-level ``train_loader``.

    The loss of a batch is ``loss_fn(out, y)`` plus the KL term ``net.KLD(...)``
    weighted by lambda and divided by ``len(train_data)``.

    Args:
        net: model to optimise. It must provide ``KLD``; when ``lam`` is None it
            is called as ``net(x, lam_param / 100.0)`` and
            ``net.KLD(lam_param / 100.0)``, otherwise as ``net(x)`` / ``net.KLD()``.
        loss_fn: callable ``loss_fn(out, y)`` returning a scalar data loss.
        optimizer: optimizer over ``net.parameters()``.
        lam: fixed weight of the KL term, or None to average the loss over
            ``lambda_sample_num`` lambda values drawn as ``10 ** u`` with ``u``
            uniform on [-1, 2).
        label: prefix of the progress-bar description.

    Returns:
        None. The progress bar shows the running mean of the per-batch losses.
    """
    # test_acc() switches the model to eval mode; make sure training runs in
    # train mode regardless of what was called before.
    net.train()
    progress = tqdm.tqdm(train_loader)
    losses = []
    for x, y in progress:
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        loss = 0
        if lam is None:
            for _ in range(lambda_sample_num):
                p = t.rand(1).to(device) * 3 - 1
                lam_param = 10 ** p[0]
                out = net(x, lam_param / 100.0)
                loss = loss + loss_fn(out, y) / lambda_sample_num
                loss = loss + net.KLD(lam_param / 100.0) * lam_param / len(train_data) / lambda_sample_num
        else:
            out = net(x)
            loss = loss + loss_fn(out, y)
            # The likelihood should be summed over the whole training set; with
            # mini-batches the KL term is rescaled to the same order instead.
            loss = loss + net.KLD() * lam / len(train_data)
        # Record the loss once per batch. Previously the sampled-lambda branch
        # appended every partial sum, which biased the displayed mean downwards.
        losses.append(loss.item())
        progress.set_description(label + str(np.mean(losses)))
        loss.backward()
        # Clip gradient values for stability; the threshold is a tunable choice.
        clip_grad_value_(net.parameters(), 1.0)
        optimizer.step()

In [None]:
# Train one plain variational network per lambda value (and per restart) and
# save its weights under path_to_save.
t.manual_seed(0)
for lam in lamb:
    for start in range(start_num):
        net = var_net.VarNet(
            var_net.VarLayer(784, hidden_num,
                             prior_sigma=prior_sigma, init_log_sigma=init_log_sigma),
            # Output layer: identity activation.
            var_net.VarLayer(hidden_num, 10, prior_sigma=prior_sigma,
                             init_log_sigma=init_log_sigma, act=lambda x: x),
        )
        net = net.to(device)
        # Named `optimizer` rather than `optim` so the `torch.optim as optim`
        # module alias imported at the top of the notebook is not overwritten.
        optimizer = t.optim.Adam(net.parameters(), lr=5e-4)
        loss_fn = nn.CrossEntropyLoss().to(device)
        for e in range(epoch_num):
            label = 'lambda {}, epoch {}: '.format(lam, e)
            train_batches(net, loss_fn, optimizer, lam, label)
        t.save(net.state_dict(), os.path.join(path_to_save, 'var_net_lam_{}_start_{}.cpk'.format(lam, start)))

In [None]:
# Train the linear-approximation hypernetwork (one per restart). Passing
# lam=None makes train_batches sample lambda values itself for every batch.
t.manual_seed(0)
for start in range(start_num):
    net = var_net.VarNet(
        linear_var_hypernet.VarLayerLinearAppr(784, hidden_num, prior_sigma=prior_sigma,
                                               init_log_sigma=init_log_sigma),
        # Output layer: identity activation.
        linear_var_hypernet.VarLayerLinearAppr(hidden_num, 10, prior_sigma=prior_sigma,
                                               act=lambda x: x, init_log_sigma=init_log_sigma),
    )
    net = net.to(device)
    # Named `optimizer` rather than `optim` so the `torch.optim as optim`
    # module alias imported at the top of the notebook is not overwritten.
    optimizer = t.optim.Adam(net.parameters(), lr=5e-4)
    loss_fn = nn.CrossEntropyLoss().to(device)
    for e in range(epoch_num):
        label = 'linear, epoch {}: '.format(e)
        train_batches(net, loss_fn, optimizer, None, label)
    t.save(net.state_dict(), os.path.join(path_to_save, 'linear_start_{}.cpk'.format(start)))

In [None]:
# Train the low-rank hypernetwork (one per restart). Passing lam=None makes
# train_batches sample lambda values itself for every batch.
t.manual_seed(0)
for start in range(start_num):
    net = var_net.VarNet(
        lowrank_var_hypernet.VarLayerLowRank(784, hidden_num, lowrank_hidden_num,
                                             init_log_sigma=init_log_sigma,
                                             prior_sigma=prior_sigma),
        # Output layer: identity activation.
        lowrank_var_hypernet.VarLayerLowRank(hidden_num, 10, lowrank_hidden_num,
                                             prior_sigma=prior_sigma, act=lambda x: x,
                                             init_log_sigma=init_log_sigma),
    )
    net = net.to(device)
    # Named `optimizer` rather than `optim` so the `torch.optim as optim`
    # module alias imported at the top of the notebook is not overwritten.
    optimizer = t.optim.Adam(net.parameters(), lr=5e-4)
    loss_fn = nn.CrossEntropyLoss().to(device)
    for e in range(epoch_num):
        label = 'lowrank, epoch {}: '.format(e)
        train_batches(net, loss_fn, optimizer, None, label)
    t.save(net.state_dict(), os.path.join(path_to_save, 'lowrank_start_{}.cpk'.format(start)))

In [None]:
# Evaluate the plain variational networks: for every lambda and restart, load
# the saved weights and collect what utils.delete_10 returns (presumably test
# accuracy at increasing deletion levels - confirm against src/utils.py).
net = var_net.VarNet(
    var_net.VarLayer(784, hidden_num,
                     prior_sigma=prior_sigma, init_log_sigma=init_log_sigma),
    var_net.VarLayer(hidden_num, 10, prior_sigma=prior_sigma,
                     init_log_sigma=init_log_sigma, act=lambda x: x),
)
lam_results = {}
for lam in lamb:
    lam_results[lam] = []
    for s in range(start_num):
        print(lam, s)
        checkpoint = os.path.join(path_to_save, 'var_net_lam_{}_start_{}.cpk'.format(lam, s))
        # map_location lets a checkpoint saved on another device load here.
        net.load_state_dict(t.load(checkpoint, map_location=device))
        lam_results[lam].append(utils.delete_10(net, device, lambda: test_acc(net)))
# Results go next to the checkpoints. The original referenced an undefined
# name `saved` here, which raised NameError; the directory is `path_to_save`.
# Note that JSON turns the numeric lambda keys into strings.
with open(os.path.join(path_to_save, 'results_var.json'), 'w') as out:
    json.dump(lam_results, out)

In [None]:
# Plot the collected accuracies for every lambda against the deletion
# percentage (x axis), using the `lam_results` produced by the previous cell.
proc = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]
plt.rcParams['figure.figsize'] = 12, 12
plt.rcParams.update({'font.size': 27})
plt.rc('lines', linewidth=4)


for lam in lamb:
    # Shaded band: min..max over the restarts; solid line: mean over the restarts.
    plt.fill_between(proc, np.min(lam_results[lam], 0), np.max(lam_results[lam], 0), alpha=0.2)
    # Raw string: '\l' is not a valid escape sequence in a normal string literal.
    plt.plot(proc, np.mean(lam_results[lam], 0), label=r'$\lambda={}$'.format(lam))
plt.ylabel('Точность классификации', fontsize=27)
plt.xlabel('Процент удаления', fontsize=27)
plt.tick_params(axis='both', which='major', labelsize=27)
plt.legend(loc='lower left')
plt.autoscale(enable=True, axis='x', tight=True)


In [None]:
# Evaluate the linear-approximation hypernetwork: for every lambda,
# utils.net_copy(hnet, net, lam / 100.0) is applied (presumably writing the
# hypernet's parameters for that lambda into the plain network - confirm against
# src/utils.py) and the plain network is then passed to utils.delete_10.
net = var_net.VarNet(
    var_net.VarLayer(784, hidden_num,
                     prior_sigma=prior_sigma, init_log_sigma=init_log_sigma),
    var_net.VarLayer(hidden_num, 10, prior_sigma=prior_sigma,
                     init_log_sigma=init_log_sigma, act=lambda x: x),
)
hnet = var_net.VarNet(
    linear_var_hypernet.VarLayerLinearAppr(784, hidden_num, prior_sigma=prior_sigma,
                                           init_log_sigma=init_log_sigma),
    linear_var_hypernet.VarLayerLinearAppr(hidden_num, 10, prior_sigma=prior_sigma,
                                           act=lambda x: x, init_log_sigma=init_log_sigma),
)

lam_results = {}
for lam in lamb:
    lam_results[lam] = []
    for s in range(start_num):
        checkpoint = os.path.join(path_to_save, 'linear_start_{}.cpk'.format(s))
        # map_location lets a checkpoint saved on another device load here.
        hnet.load_state_dict(t.load(checkpoint, map_location=device))
        # Same lambda scaling (lam / 100.0) as the one used during training.
        utils.net_copy(hnet, net, lam / 100.0)
        lam_results[lam].append(utils.delete_10(net, device, lambda: test_acc(net)))
# Results go next to the checkpoints. The original referenced an undefined
# name `saved` here, which raised NameError; the directory is `path_to_save`.
# Note that JSON turns the numeric lambda keys into strings.
with open(os.path.join(path_to_save, 'results_linear.json'), 'w') as out:
    json.dump(lam_results, out)

In [None]:
# Plot the collected accuracies for every lambda against the deletion
# percentage (x axis), using the `lam_results` produced by the previous cell.
proc = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]
plt.rcParams['figure.figsize'] = 12, 12
plt.rcParams.update({'font.size': 27})
plt.rc('lines', linewidth=4)


for lam in lamb:
    # Shaded band: min..max over the restarts; solid line: mean over the restarts.
    plt.fill_between(proc, np.min(lam_results[lam], 0), np.max(lam_results[lam], 0), alpha=0.2)
    # Raw string: '\l' is not a valid escape sequence in a normal string literal.
    plt.plot(proc, np.mean(lam_results[lam], 0), label=r'$\lambda={}$'.format(lam))
plt.ylabel('Точность классификации', fontsize=27)
plt.xlabel('Процент удаления', fontsize=27)
plt.tick_params(axis='both', which='major', labelsize=27)
plt.legend(loc='lower left')
plt.autoscale(enable=True, axis='x', tight=True)


In [None]:
# Evaluate the low-rank hypernetwork: for every lambda,
# utils.net_copy(hnet, net, lam / 100.0) is applied (presumably writing the
# hypernet's parameters for that lambda into the plain network - confirm against
# src/utils.py) and the plain network is then passed to utils.delete_10.
net = var_net.VarNet(
    var_net.VarLayer(784, hidden_num,
                     prior_sigma=prior_sigma, init_log_sigma=init_log_sigma),
    var_net.VarLayer(hidden_num, 10, prior_sigma=prior_sigma,
                     init_log_sigma=init_log_sigma, act=lambda x: x),
)
hnet = var_net.VarNet(
    lowrank_var_hypernet.VarLayerLowRank(784, hidden_num, lowrank_hidden_num,
                                         init_log_sigma=init_log_sigma,
                                         prior_sigma=prior_sigma),
    lowrank_var_hypernet.VarLayerLowRank(hidden_num, 10, lowrank_hidden_num,
                                         prior_sigma=prior_sigma, act=lambda x: x,
                                         init_log_sigma=init_log_sigma),
)

lam_results = {}
for lam in lamb:
    lam_results[lam] = []
    for s in range(start_num):
        checkpoint = os.path.join(path_to_save, 'lowrank_start_{}.cpk'.format(s))
        # map_location lets a checkpoint saved on another device load here.
        hnet.load_state_dict(t.load(checkpoint, map_location=device))
        # Same lambda scaling (lam / 100.0) as the one used during training.
        utils.net_copy(hnet, net, lam / 100.0)
        lam_results[lam].append(utils.delete_10(net, device, lambda: test_acc(net)))
# Results go next to the checkpoints. The original referenced an undefined
# name `saved` here, which raised NameError; the directory is `path_to_save`.
# Note that JSON turns the numeric lambda keys into strings.
with open(os.path.join(path_to_save, 'results_lowrank.json'), 'w') as out:
    json.dump(lam_results, out)

In [None]:
# Plot the collected accuracies for every lambda against the deletion
# percentage (x axis), using the `lam_results` produced by the previous cell.
proc = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]
plt.rcParams['figure.figsize'] = 12, 12
plt.rcParams.update({'font.size': 27})
plt.rc('lines', linewidth=4)


for lam in lamb:
    # Shaded band: min..max over the restarts; solid line: mean over the restarts.
    plt.fill_between(proc, np.min(lam_results[lam], 0), np.max(lam_results[lam], 0), alpha=0.2)
    # Raw string: '\l' is not a valid escape sequence in a normal string literal.
    plt.plot(proc, np.mean(lam_results[lam], 0), label=r'$\lambda={}$'.format(lam))
plt.ylabel('Точность классификации', fontsize=27)
plt.xlabel('Процент удаления', fontsize=27)
plt.tick_params(axis='both', which='major', labelsize=27)
plt.legend(loc='lower left')
plt.autoscale(enable=True, axis='x', tight=True)


In [None]:
# Re-draw all three result sets from the JSON files written by the evaluation
# cells, one figure per file.
for mode in ['results_var', 'results_lowrank', 'results_linear']:
    # The files live in `path_to_save`; the original used an undefined name
    # `saved` here, which raised NameError.
    with open(os.path.join(path_to_save, mode + '.json')) as inp:
        lam_results = json.load(inp)
    proc = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]
    plt.rcParams['figure.figsize'] = 8, 8
    plt.rcParams.update({'font.size': 12})
    plt.rc('lines', linewidth=4)
    plt.title(mode)

    for lam in lamb:
        # JSON object keys are strings, so look the results up by str(lam).
        lam = str(lam)
        # Shaded band: min..max along axis 0; solid line: mean along axis 0.
        plt.fill_between(proc, np.min(lam_results[lam], 0), np.max(lam_results[lam], 0), alpha=0.2)
        # Raw string: '\l' is not a valid escape sequence in a normal string literal.
        plt.plot(proc, np.mean(lam_results[lam], 0), label=r'$\lambda={}$'.format(lam))
    plt.ylabel('Точность классификации', fontsize=12)
    plt.xlabel('Процент удаления', fontsize=12)
    plt.tick_params(axis='both', which='major', labelsize=12)
    plt.legend(loc='lower left')
    plt.autoscale(enable=True, axis='x', tight=True)
    plt.show()
