In [18]:
import torch as t 
import torchvision
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pylab as plt
from torch.nn.utils import clip_grad_value_
%matplotlib inline
import pickle
from torchvision import datasets, transforms
import tqdm
import os
import sys
sys.path.append('../src/')

In [19]:
import importlib

import var_net
import mnist_utils as utils

# Re-execute the project helper modules found on ../src/ so that edits made to
# them are picked up without restarting the kernel. Reload order is kept:
# mnist_utils first, then var_net. importlib.reload returns the very module
# object it was given, so rebinding the names changes nothing except clarity;
# the trailing bare name echoes the reloaded var_net module as the cell output.
utils = importlib.reload(utils)
var_net = importlib.reload(var_net)
var_net



<module 'var_net' from '../src/var_net.py'>

In [3]:
# Compute device used for training and evaluation: 'cuda' or 'cpu'.
device = t.device('cuda')

# Make cuDNN pick deterministic kernels so repeated runs are reproducible.
# Compare the device *type*: a torch.device never compares equal to the plain
# string 'cuda', so the previous `device == 'cuda'` check was always False and
# these flags were never applied.
if device.type == 'cuda':
    t.backends.cudnn.deterministic = True
    t.backends.cudnn.benchmark = False

In [4]:
batch_size = 256
# Log-sigma of the variational distribution at initialization (the original
# note called it the log-variance; whether VarLayer treats sigma as a std or a
# variance is not visible here — TODO confirm).
init_log_sigma = -3.0
# Prior sigma (originally described as the prior variance; same caveat as above).
prior_sigma = .1
epoch_num = 25  # number of training epochs per run
# Values of lambda swept by the training cell (passed to utils.train_batches_net).
lamb = [0.01, 0.1, 1, 10, 100]

hidden_num = 50  # number of neurons in the hidden layer
acc_delete = []  # NOTE(review): not referenced by any visible cell
start_num = 5  # independent training runs (restarts) per lambda value
path_to_save = 'saved_mnist'  # directory for checkpoints and results

# Idempotent: creates the directory (and any parents) only if missing, without
# the check-then-create race of `if not os.path.exists(...): os.mkdir(...)`.
os.makedirs(path_to_save, exist_ok=True)

In [6]:
# Load MNIST. Each image is converted to a tensor, normalized from [0, 1] to
# [-1, 1] via (x - 0.5) / 0.5, and flattened into a 784-element vector for the
# fully connected VarNet.
def _flatten(x):
    """Flatten an image tensor into a 1-D vector (784 values for MNIST)."""
    return x.view(-1)


# One transform shared by both splits instead of two verbatim copies. A named
# function (rather than a lambda) keeps the dataset picklable if DataLoader
# workers are ever enabled.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
    transforms.Lambda(_flatten),
])

train_data = torchvision.datasets.MNIST('./files/', train=True, download=True,
                                        transform=mnist_transform)
test_data = torchvision.datasets.MNIST('./files/', train=False, download=True,
                                       transform=mnist_transform)

# shuffle=True reshuffles the training set every epoch (previously every epoch
# saw the batches in the same fixed order). The order stays reproducible
# because it is drawn from torch's global RNG, which is seeded before training.
# Pinned host memory only speeds up host->GPU copies, so enable it only when
# CUDA is present (avoids the pin_memory warning on CPU-only machines).
train_loader = t.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True,
                                       pin_memory=t.cuda.is_available())
test_loader = t.utils.data.DataLoader(test_data, batch_size=batch_size)





In [None]:
# Train `start_num` independent VarNets for every lambda in `lamb` and save each
# final state_dict under `path_to_save`.
t.manual_seed(0)  # seed torch's global RNG once so the whole sweep is reproducible
loss_fn = nn.CrossEntropyLoss().to(device)  # stateless, safe to share across runs
for lam in lamb:
    for start in range(start_num):
        net = var_net.VarNet(
            var_net.VarLayer(784, hidden_num,
                             prior_sigma=prior_sigma, init_log_sigma=init_log_sigma),
            var_net.VarLayer(hidden_num, 10,
                             prior_sigma=prior_sigma, init_log_sigma=init_log_sigma,
                             act=lambda x: x),  # identity: output layer emits raw logits
        ).to(device)

        # Named `optimizer`, not `optim`: assigning to `optim` shadowed the
        # `torch.optim` module imported as `optim` in the first cell.
        optimizer = t.optim.Adam(net.parameters(), lr=5e-4)
        for e in range(epoch_num):
            label = 'lambda {}, epoch {}: '.format(lam, e)
            utils.train_batches_net(train_loader, len(train_data), net, device, loss_fn,
                                    optimizer, t.tensor(lam), label, rep=True)
            if e % 5 == 0:
                # Periodic test-set accuracy, printed every 5 epochs.
                print(utils.test_acc_net(net, device, test_loader))

        # NOTE(review): checkpoints are written with the 'rep_net_' prefix, but
        # the evaluation cell loads 'var_net_lam_{}_start_{}.cpk' — confirm which
        # prefix is intended so Restart & Run All does not fail there.
        t.save(net.state_dict(),
               os.path.join(path_to_save, 'rep_net_lam_{}_start_{}.cpk'.format(lam, start)))

lambda 0.01, epoch 0: 3.0935786: 100%|██████████| 235/235 [00:16<00:00, 13.84it/s]
lambda 0.01, epoch 1: 2.2388937:   1%|          | 2/235 [00:00<00:14, 16.61it/s]

0.8345


lambda 0.01, epoch 1: 2.0767634: 100%|██████████| 235/235 [00:16<00:00, 14.21it/s]
lambda 0.01, epoch 2: 1.988475: 100%|██████████| 235/235 [00:16<00:00, 14.63it/s] 
lambda 0.01, epoch 3: 1.9374174: 100%|██████████| 235/235 [00:16<00:00, 14.49it/s]
lambda 0.01, epoch 4: 1.893619: 100%|██████████| 235/235 [00:16<00:00, 14.27it/s] 
lambda 0.01, epoch 5: 1.854912: 100%|██████████| 235/235 [00:16<00:00, 14.12it/s] 
lambda 0.01, epoch 6: 1.8272017:   1%|          | 2/235 [00:00<00:15, 15.34it/s]

0.9359


lambda 0.01, epoch 6: 1.820779: 100%|██████████| 235/235 [00:16<00:00, 14.48it/s] 
lambda 0.01, epoch 7: 1.7884583: 100%|██████████| 235/235 [00:16<00:00, 14.50it/s]
lambda 0.01, epoch 8: 1.7571614: 100%|██████████| 235/235 [00:16<00:00, 13.92it/s]
lambda 0.01, epoch 9: 1.7277081: 100%|██████████| 235/235 [00:16<00:00, 14.48it/s]
lambda 0.01, epoch 10: 1.699405: 100%|██████████| 235/235 [00:15<00:00, 15.02it/s] 
lambda 0.01, epoch 11: 1.6936074:   1%|          | 2/235 [00:00<00:17, 13.25it/s]

0.9533


lambda 0.01, epoch 11: 1.6716992: 100%|██████████| 235/235 [00:15<00:00, 14.76it/s]
lambda 0.01, epoch 12: 1.6441003: 100%|██████████| 235/235 [00:16<00:00, 13.95it/s]
lambda 0.01, epoch 13: 1.6161317: 100%|██████████| 235/235 [00:16<00:00, 14.14it/s]
lambda 0.01, epoch 14: 1.5877055: 100%|██████████| 235/235 [00:16<00:00, 14.42it/s]
lambda 0.01, epoch 15: 1.5584408: 100%|██████████| 235/235 [00:16<00:00, 14.44it/s]
lambda 0.01, epoch 16: 1.5591784:   1%|          | 2/235 [00:00<00:15, 14.63it/s]

0.9609


lambda 0.01, epoch 16: 1.5282521: 100%|██████████| 235/235 [00:16<00:00, 14.06it/s]
lambda 0.01, epoch 17: 1.4968576: 100%|██████████| 235/235 [00:15<00:00, 14.90it/s]
lambda 0.01, epoch 18: 1.4641479: 100%|██████████| 235/235 [00:16<00:00, 14.35it/s]
lambda 0.01, epoch 19: 1.4298816: 100%|██████████| 235/235 [00:16<00:00, 14.18it/s]
lambda 0.01, epoch 20: 1.394397: 100%|██████████| 235/235 [00:16<00:00, 13.88it/s] 
lambda 0.01, epoch 21: 1.3945572:   1%|          | 2/235 [00:00<00:19, 11.72it/s]

0.9658


lambda 0.01, epoch 21: 1.3572217: 100%|██████████| 235/235 [00:16<00:00, 14.22it/s]
lambda 0.01, epoch 22: 1.3183703: 100%|██████████| 235/235 [00:16<00:00, 14.15it/s]
lambda 0.01, epoch 23: 1.277953: 100%|██████████| 235/235 [00:16<00:00, 14.38it/s] 
lambda 0.01, epoch 24: 1.2360145: 100%|██████████| 235/235 [00:15<00:00, 14.92it/s]
lambda 0.01, epoch 0: 3.1298442: 100%|██████████| 235/235 [00:16<00:00, 14.15it/s]
lambda 0.01, epoch 1: 2.2004023:   1%|          | 2/235 [00:00<00:12, 19.09it/s]

0.8489


lambda 0.01, epoch 1: 2.0783312: 100%|██████████| 235/235 [00:16<00:00, 14.50it/s]
lambda 0.01, epoch 2: 1.9848433: 100%|██████████| 235/235 [00:16<00:00, 14.41it/s]
lambda 0.01, epoch 3: 1.9335927: 100%|██████████| 235/235 [00:16<00:00, 14.57it/s]
lambda 0.01, epoch 4: 1.8941554: 100%|██████████| 235/235 [00:16<00:00, 14.66it/s]
lambda 0.01, epoch 5: 1.8593812: 100%|██████████| 235/235 [00:16<00:00, 14.16it/s]
lambda 0.01, epoch 6: 1.8415438:   1%|          | 2/235 [00:00<00:14, 16.05it/s]

0.9338


lambda 0.01, epoch 6: 1.8273203: 100%|██████████| 235/235 [00:15<00:00, 15.14it/s]
lambda 0.01, epoch 7: 1.7967103: 100%|██████████| 235/235 [00:16<00:00, 14.15it/s]
lambda 0.01, epoch 8: 1.7668797: 100%|██████████| 235/235 [00:16<00:00, 14.35it/s]
lambda 0.01, epoch 9: 1.737649: 100%|██████████| 235/235 [00:16<00:00, 14.39it/s] 
lambda 0.01, epoch 10: 1.7089196: 100%|██████████| 235/235 [00:16<00:00, 14.27it/s]
lambda 0.01, epoch 11: 1.6965008:   1%|          | 2/235 [00:00<00:11, 19.81it/s]

0.9472


lambda 0.01, epoch 11: 1.6799262: 100%|██████████| 235/235 [00:16<00:00, 14.45it/s]
lambda 0.01, epoch 12: 1.6508029: 100%|██████████| 235/235 [00:16<00:00, 14.26it/s]
lambda 0.01, epoch 13: 1.6207767: 100%|██████████| 235/235 [00:15<00:00, 14.87it/s]
lambda 0.01, epoch 14: 1.5894728: 100%|██████████| 235/235 [00:16<00:00, 14.19it/s]
lambda 0.01, epoch 15: 1.5571113: 100%|██████████| 235/235 [00:16<00:00, 14.50it/s]
lambda 0.01, epoch 16: 1.5515974:   1%|          | 2/235 [00:00<00:13, 16.74it/s]

0.9553


lambda 0.01, epoch 16: 1.5238476: 100%|██████████| 235/235 [00:16<00:00, 14.23it/s]
lambda 0.01, epoch 17: 1.4893647: 100%|██████████| 235/235 [00:16<00:00, 14.39it/s]
lambda 0.01, epoch 18: 1.4534336: 100%|██████████| 235/235 [00:15<00:00, 14.83it/s]
lambda 0.01, epoch 19: 1.4161633: 100%|██████████| 235/235 [00:16<00:00, 14.31it/s]
lambda 0.01, epoch 20: 1.3772165: 100%|██████████| 235/235 [00:15<00:00, 14.75it/s]
lambda 0.01, epoch 21: 1.3727509:   1%|          | 2/235 [00:00<00:18, 12.74it/s]

0.9587


lambda 0.01, epoch 21: 1.3367704: 100%|██████████| 235/235 [00:16<00:00, 14.00it/s]
lambda 0.01, epoch 22: 1.2946718: 100%|██████████| 235/235 [00:16<00:00, 13.98it/s]
lambda 0.01, epoch 23: 1.2510365: 100%|██████████| 235/235 [00:16<00:00, 14.34it/s]
lambda 0.01, epoch 24: 1.2058861: 100%|██████████| 235/235 [00:16<00:00, 14.44it/s]
lambda 0.01, epoch 0: 3.2010841: 100%|██████████| 235/235 [00:16<00:00, 14.13it/s]
lambda 0.01, epoch 1: 2.2628798:   1%|          | 2/235 [00:00<00:12, 18.32it/s]

0.8368


lambda 0.01, epoch 1: 2.0965922: 100%|██████████| 235/235 [00:15<00:00, 14.72it/s]
lambda 0.01, epoch 2: 2.000088: 100%|██████████| 235/235 [00:16<00:00, 14.40it/s] 
lambda 0.01, epoch 3: 1.9509931: 100%|██████████| 235/235 [00:16<00:00, 14.36it/s]
lambda 0.01, epoch 4: 1.9132193: 100%|██████████| 235/235 [00:17<00:00, 13.81it/s]
lambda 0.01, epoch 5: 1.8797823: 100%|██████████| 235/235 [00:16<00:00, 14.23it/s]
lambda 0.01, epoch 6: 1.8613377:   1%|          | 2/235 [00:00<00:13, 16.72it/s]

0.9283


lambda 0.01, epoch 6: 1.8481175: 100%|██████████| 235/235 [00:16<00:00, 14.42it/s]
lambda 0.01, epoch 7: 1.8172925: 100%|██████████| 235/235 [00:16<00:00, 14.11it/s]
lambda 0.01, epoch 8: 1.7871468: 100%|██████████| 235/235 [00:16<00:00, 14.30it/s]
lambda 0.01, epoch 9: 1.7571745: 100%|██████████| 235/235 [00:15<00:00, 14.84it/s]
lambda 0.01, epoch 10: 1.7270874: 100%|██████████| 235/235 [00:16<00:00, 14.09it/s]
lambda 0.01, epoch 11: 1.7245218:   1%|          | 2/235 [00:00<00:19, 11.78it/s]

0.9449


lambda 0.01, epoch 11: 1.696841: 100%|██████████| 235/235 [00:16<00:00, 14.55it/s] 
lambda 0.01, epoch 12: 1.6662034: 100%|██████████| 235/235 [00:16<00:00, 14.31it/s]
lambda 0.01, epoch 13: 1.635303: 100%|██████████| 235/235 [00:16<00:00, 14.55it/s] 
lambda 0.01, epoch 14: 1.603666: 100%|██████████| 235/235 [00:16<00:00, 14.29it/s] 
lambda 0.01, epoch 15: 1.5710187: 100%|██████████| 235/235 [00:16<00:00, 14.35it/s]
lambda 0.01, epoch 16: 1.55851:   1%|          | 2/235 [00:00<00:14, 16.37it/s]  

0.9557


lambda 0.01, epoch 16: 1.5374793: 100%|██████████| 235/235 [00:15<00:00, 14.72it/s]
lambda 0.01, epoch 17: 1.5028019: 100%|██████████| 235/235 [00:16<00:00, 13.98it/s]
lambda 0.01, epoch 18: 1.4669621: 100%|██████████| 235/235 [00:16<00:00, 14.13it/s]
lambda 0.01, epoch 19: 1.429614: 100%|██████████| 235/235 [00:16<00:00, 14.45it/s] 
lambda 0.01, epoch 20: 1.3909254: 100%|██████████| 235/235 [00:17<00:00, 13.81it/s]
lambda 0.01, epoch 21: 1.3787603:   1%|          | 2/235 [00:00<00:13, 17.08it/s]

0.96


lambda 0.01, epoch 21: 1.3506997: 100%|██████████| 235/235 [00:16<00:00, 14.19it/s]
lambda 0.01, epoch 22: 1.3089917: 100%|██████████| 235/235 [00:16<00:00, 14.19it/s]
lambda 0.01, epoch 23: 1.2658359: 100%|██████████| 235/235 [00:15<00:00, 14.71it/s]
lambda 0.01, epoch 24: 1.2211046: 100%|██████████| 235/235 [00:16<00:00, 14.42it/s]
lambda 0.01, epoch 0: 3.0537035: 100%|██████████| 235/235 [00:16<00:00, 14.17it/s]
lambda 0.01, epoch 1: 2.269013:   1%|          | 2/235 [00:00<00:15, 15.48it/s] 

0.8327


lambda 0.01, epoch 1: 2.087068: 100%|██████████| 235/235 [00:16<00:00, 14.41it/s] 
lambda 0.01, epoch 2: 1.9872317: 100%|██████████| 235/235 [00:16<00:00, 14.39it/s]
lambda 0.01, epoch 3: 1.9317696: 100%|██████████| 235/235 [00:16<00:00, 14.14it/s]
lambda 0.01, epoch 4: 1.8871639: 100%|██████████| 235/235 [00:16<00:00, 14.39it/s]
lambda 0.01, epoch 5: 1.8471907: 100%|██████████| 235/235 [00:16<00:00, 14.65it/s]
lambda 0.01, epoch 6: 1.8156791:   1%|          | 2/235 [00:00<00:14, 15.75it/s]

0.9346


lambda 0.01, epoch 6: 1.8101621: 100%|██████████| 235/235 [00:16<00:00, 14.31it/s]
lambda 0.01, epoch 7: 1.7751112: 100%|██████████| 235/235 [00:16<00:00, 14.22it/s]
lambda 0.01, epoch 8: 1.7417102: 100%|██████████| 235/235 [00:16<00:00, 14.37it/s]
lambda 0.01, epoch 9: 1.7089603: 100%|██████████| 235/235 [00:16<00:00, 14.36it/s]
lambda 0.01, epoch 10: 1.6765815: 100%|██████████| 235/235 [00:16<00:00, 14.22it/s]
lambda 0.01, epoch 11: 1.6574821:   1%|          | 2/235 [00:00<00:18, 12.82it/s]

0.9528


lambda 0.01, epoch 11: 1.6441196: 100%|██████████| 235/235 [00:16<00:00, 14.24it/s]
lambda 0.01, epoch 12: 1.6111811: 100%|██████████| 235/235 [00:16<00:00, 14.59it/s]
lambda 0.01, epoch 13: 1.5777295: 100%|██████████| 235/235 [00:16<00:00, 14.20it/s]
lambda 0.01, epoch 14: 1.5434242: 100%|██████████| 235/235 [00:16<00:00, 14.58it/s]
lambda 0.01, epoch 15: 1.508171: 100%|██████████| 235/235 [00:16<00:00, 14.52it/s] 
lambda 0.01, epoch 16: 1.4921764:   1%|          | 2/235 [00:00<00:16, 13.73it/s]

0.9602


lambda 0.01, epoch 16: 1.472136: 100%|██████████| 235/235 [00:16<00:00, 14.15it/s] 
lambda 0.01, epoch 17: 1.4346788: 100%|██████████| 235/235 [00:16<00:00, 13.92it/s]
lambda 0.01, epoch 18: 1.3960608: 100%|██████████| 235/235 [00:16<00:00, 14.12it/s]
lambda 0.01, epoch 19: 1.3561014: 100%|██████████| 235/235 [00:16<00:00, 14.65it/s]
lambda 0.01, epoch 20: 1.3148725: 100%|██████████| 235/235 [00:16<00:00, 14.40it/s]
lambda 0.01, epoch 21: 1.2992383:   1%|          | 2/235 [00:00<00:15, 14.63it/s]

0.9622


lambda 0.01, epoch 21: 1.2722232: 100%|██████████| 235/235 [00:16<00:00, 14.29it/s]
lambda 0.01, epoch 22: 1.2283667: 100%|██████████| 235/235 [00:16<00:00, 14.25it/s]
lambda 0.01, epoch 23: 1.183279: 100%|██████████| 235/235 [00:16<00:00, 14.23it/s] 
lambda 0.01, epoch 24: 1.1369346: 100%|██████████| 235/235 [00:16<00:00, 14.17it/s]
lambda 0.01, epoch 0: 3.1493034: 100%|██████████| 235/235 [00:16<00:00, 13.93it/s]
lambda 0.01, epoch 1: 2.220579:   1%|          | 2/235 [00:00<00:13, 16.70it/s] 

0.8398


lambda 0.01, epoch 1: 2.0810573: 100%|██████████| 235/235 [00:15<00:00, 14.77it/s]
lambda 0.01, epoch 2: 1.9853176: 100%|██████████| 235/235 [00:16<00:00, 14.07it/s]
lambda 0.01, epoch 3: 1.9344174: 100%|██████████| 235/235 [00:16<00:00, 13.99it/s]
lambda 0.01, epoch 4: 1.8950585: 100%|██████████| 235/235 [00:16<00:00, 14.41it/s]
lambda 0.01, epoch 5: 1.8608328: 100%|██████████| 235/235 [00:16<00:00, 14.53it/s]
lambda 0.01, epoch 6: 1.8296312:   1%|          | 2/235 [00:00<00:15, 14.65it/s]

0.9356


lambda 0.01, epoch 6: 1.8297491: 100%|██████████| 235/235 [00:16<00:00, 14.12it/s]
lambda 0.01, epoch 7: 1.800473: 100%|██████████| 235/235 [00:16<00:00, 14.48it/s] 
lambda 0.01, epoch 8: 1.7723053: 100%|██████████| 235/235 [00:15<00:00, 15.08it/s]
lambda 0.01, epoch 9: 1.7443793: 100%|██████████| 235/235 [00:16<00:00, 14.24it/s]
lambda 0.01, epoch 10: 1.7163376: 100%|██████████| 235/235 [00:16<00:00, 14.44it/s]
lambda 0.01, epoch 11: 1.6968254:   1%|          | 2/235 [00:00<00:13, 17.08it/s]

0.9491


lambda 0.01, epoch 11: 1.6881756: 100%|██████████| 235/235 [00:16<00:00, 14.23it/s]
lambda 0.01, epoch 12: 1.6594009: 100%|██████████| 235/235 [00:16<00:00, 14.50it/s]
lambda 0.01, epoch 13: 1.6298597: 100%|██████████| 235/235 [00:16<00:00, 14.42it/s]
lambda 0.01, epoch 14: 1.5991445: 100%|██████████| 235/235 [00:16<00:00, 13.96it/s]
lambda 0.01, epoch 15: 1.5671825: 100%|██████████| 235/235 [00:16<00:00, 14.34it/s]
lambda 0.01, epoch 16: 1.5501941:   1%|          | 2/235 [00:00<00:11, 19.74it/s]

0.9569


lambda 0.01, epoch 16: 1.533839: 100%|██████████| 235/235 [00:16<00:00, 14.01it/s] 
lambda 0.01, epoch 17: 1.4989972: 100%|██████████| 235/235 [00:16<00:00, 14.31it/s]
lambda 0.01, epoch 18: 1.4624943: 100%|██████████| 235/235 [00:16<00:00, 14.48it/s]
lambda 0.01, epoch 19: 1.4243866: 100%|██████████| 235/235 [00:16<00:00, 14.26it/s]
lambda 0.01, epoch 20: 1.3843585: 100%|██████████| 235/235 [00:16<00:00, 14.39it/s]
lambda 0.01, epoch 21: 1.3668305:   1%|          | 2/235 [00:00<00:13, 16.78it/s]

0.9598


lambda 0.01, epoch 21: 1.3425181: 100%|██████████| 235/235 [00:15<00:00, 14.79it/s]
lambda 0.01, epoch 22: 1.2990117: 100%|██████████| 235/235 [00:15<00:00, 15.04it/s]
lambda 0.01, epoch 23: 1.2537543: 100%|██████████| 235/235 [00:16<00:00, 13.94it/s]
lambda 0.01, epoch 24: 1.206791: 100%|██████████| 235/235 [00:16<00:00, 13.95it/s] 
lambda 0.1, epoch 0: 2.4870646: 100%|██████████| 235/235 [00:16<00:00, 13.95it/s]
lambda 0.1, epoch 1: 1.4689261:   1%|          | 2/235 [00:00<00:17, 13.45it/s]

0.834


lambda 0.1, epoch 1: 1.3567461: 100%|██████████| 235/235 [00:16<00:00, 14.53it/s]
lambda 0.1, epoch 2: 1.2529933: 100%|██████████| 235/235 [00:16<00:00, 14.09it/s]
lambda 0.1, epoch 3: 1.1946988: 100%|██████████| 235/235 [00:16<00:00, 14.49it/s]
lambda 0.1, epoch 4: 1.1511132: 100%|██████████| 235/235 [00:15<00:00, 14.93it/s]
lambda 0.1, epoch 5: 1.1161454: 100%|██████████| 235/235 [00:16<00:00, 14.22it/s]
lambda 0.1, epoch 6: 1.0999607:   1%|          | 2/235 [00:00<00:12, 18.05it/s]

0.9418


lambda 0.1, epoch 6: 1.0865139: 100%|██████████| 235/235 [00:16<00:00, 14.38it/s]
lambda 0.1, epoch 7: 1.0600524: 100%|██████████| 235/235 [00:16<00:00, 14.31it/s]
lambda 0.1, epoch 8: 1.035551: 100%|██████████| 235/235 [00:16<00:00, 14.40it/s] 
lambda 0.1, epoch 9: 1.0122865: 100%|██████████| 235/235 [00:16<00:00, 14.50it/s]
lambda 0.1, epoch 10: 0.9895327: 100%|██████████| 235/235 [00:16<00:00, 14.49it/s] 
lambda 0.1, epoch 11: 0.9814167:   1%|          | 2/235 [00:00<00:14, 16.54it/s]

0.9577


lambda 0.1, epoch 11: 0.9668944: 100%|██████████| 235/235 [00:16<00:00, 14.67it/s] 
lambda 0.1, epoch 12: 0.9441993: 100%|██████████| 235/235 [00:16<00:00, 14.30it/s] 
lambda 0.1, epoch 13: 0.9209024: 100%|██████████| 235/235 [00:16<00:00, 14.08it/s] 
lambda 0.1, epoch 14: 0.8969598: 100%|██████████| 235/235 [00:16<00:00, 14.20it/s] 
lambda 0.1, epoch 15: 0.8721289: 100%|██████████| 235/235 [00:16<00:00, 14.12it/s] 
lambda 0.1, epoch 16: 0.86448777:   1%|          | 2/235 [00:00<00:18, 12.48it/s]

0.9645


lambda 0.1, epoch 16: 0.846188: 100%|██████████| 235/235 [00:17<00:00, 13.75it/s]  
lambda 0.1, epoch 17: 0.8192661: 100%|██████████| 235/235 [00:16<00:00, 14.49it/s] 
lambda 0.1, epoch 18: 0.79136133: 100%|██████████| 235/235 [00:16<00:00, 14.26it/s]
lambda 0.1, epoch 19: 0.7622217: 100%|██████████| 235/235 [00:16<00:00, 14.26it/s] 
lambda 0.1, epoch 20: 0.731871: 100%|██████████| 235/235 [00:16<00:00, 14.58it/s]  
lambda 0.1, epoch 21: 0.7213712:   1%|          | 2/235 [00:00<00:13, 16.85it/s] 

0.9672


lambda 0.1, epoch 21: 0.70023495: 100%|██████████| 235/235 [00:16<00:00, 14.23it/s]
lambda 0.1, epoch 22: 0.6675137: 100%|██████████| 235/235 [00:16<00:00, 13.88it/s] 
lambda 0.1, epoch 23: 0.63357687: 100%|██████████| 235/235 [00:16<00:00, 14.56it/s]
lambda 0.1, epoch 24: 0.59875953: 100%|██████████| 235/235 [00:16<00:00, 14.49it/s]
lambda 0.1, epoch 0: 2.40764: 100%|██████████| 235/235 [00:16<00:00, 13.85it/s]  
lambda 0.1, epoch 1: 1.4798211:   0%|          | 0/235 [00:00<?, ?it/s]

0.8255


lambda 0.1, epoch 1: 1.3537283: 100%|██████████| 235/235 [00:16<00:00, 14.37it/s]
lambda 0.1, epoch 2: 1.2549893: 100%|██████████| 235/235 [00:16<00:00, 13.98it/s]
lambda 0.1, epoch 3: 1.2047163: 100%|██████████| 235/235 [00:16<00:00, 14.25it/s]
lambda 0.1, epoch 4: 1.1681274: 100%|██████████| 235/235 [00:16<00:00, 14.49it/s]
lambda 0.1, epoch 5: 1.1377257: 100%|██████████| 235/235 [00:16<00:00, 14.53it/s]
lambda 0.1, epoch 6: 1.1039879:   1%|          | 2/235 [00:00<00:17, 13.37it/s]

0.9336


lambda 0.1, epoch 6: 1.1101797: 100%|██████████| 235/235 [00:16<00:00, 14.62it/s]
lambda 0.1, epoch 7: 1.0838584: 100%|██████████| 235/235 [00:15<00:00, 15.02it/s]
lambda 0.1, epoch 8: 1.0583202: 100%|██████████| 235/235 [00:16<00:00, 14.26it/s]
lambda 0.1, epoch 9: 1.0332156: 100%|██████████| 235/235 [00:16<00:00, 14.16it/s]
lambda 0.1, epoch 10: 1.0079819: 100%|██████████| 235/235 [00:16<00:00, 14.16it/s]
lambda 0.1, epoch 11: 0.9966023:   1%|          | 2/235 [00:00<00:17, 13.28it/s]

0.9497


lambda 0.1, epoch 11: 0.98175544: 100%|██████████| 235/235 [00:16<00:00, 14.29it/s]
lambda 0.1, epoch 12: 0.9555416: 100%|██████████| 235/235 [00:14<00:00, 15.99it/s] 
lambda 0.1, epoch 13: 0.92899895: 100%|██████████| 235/235 [00:14<00:00, 16.37it/s]
lambda 0.1, epoch 14: 0.90211153: 100%|██████████| 235/235 [00:15<00:00, 15.20it/s]
lambda 0.1, epoch 15: 0.8746856: 100%|██████████| 235/235 [00:15<00:00, 15.41it/s] 
lambda 0.1, epoch 16: 0.8682662:   1%|          | 2/235 [00:00<00:18, 12.29it/s]

0.9575


lambda 0.1, epoch 16: 0.8463854: 100%|██████████| 235/235 [00:15<00:00, 15.29it/s] 
lambda 0.1, epoch 17: 0.81732935: 100%|██████████| 235/235 [00:15<00:00, 15.23it/s]
lambda 0.1, epoch 18: 0.78700536: 100%|██████████| 235/235 [00:15<00:00, 15.23it/s]
lambda 0.1, epoch 19: 0.7557075: 100%|██████████| 235/235 [00:15<00:00, 15.42it/s] 
lambda 0.1, epoch 20: 0.7232344: 100%|██████████| 235/235 [00:14<00:00, 15.75it/s] 
lambda 0.1, epoch 21: 0.7159097:   1%|          | 2/235 [00:00<00:15, 14.93it/s]

0.9618


lambda 0.1, epoch 21: 0.6901959: 100%|██████████| 235/235 [00:15<00:00, 15.10it/s] 
lambda 0.1, epoch 22: 0.65631855: 100%|██████████| 235/235 [00:15<00:00, 14.93it/s]
lambda 0.1, epoch 23: 0.6218598: 100%|██████████| 235/235 [00:15<00:00, 15.61it/s] 
lambda 0.1, epoch 24: 0.5868754: 100%|██████████| 235/235 [00:15<00:00, 15.32it/s] 
lambda 0.1, epoch 0: 2.3335667: 100%|██████████| 235/235 [00:15<00:00, 15.19it/s]
lambda 0.1, epoch 1: 1.4534855:   1%|          | 2/235 [00:00<00:13, 17.50it/s]

0.85


lambda 0.1, epoch 1: 1.3399743: 100%|██████████| 235/235 [00:15<00:00, 15.63it/s]
lambda 0.1, epoch 2: 1.2536644: 100%|██████████| 235/235 [00:15<00:00, 15.13it/s]
lambda 0.1, epoch 3: 1.2060128: 100%|██████████| 235/235 [00:14<00:00, 15.78it/s]
lambda 0.1, epoch 4: 1.1696482: 100%|██████████| 235/235 [00:12<00:00, 19.26it/s]
lambda 0.1, epoch 5: 1.1384759: 100%|██████████| 235/235 [00:14<00:00, 16.38it/s]
lambda 0.1, epoch 6: 1.1195592:   1%|          | 2/235 [00:00<00:14, 16.08it/s]

0.9344


lambda 0.1, epoch 6: 1.109891: 100%|██████████| 235/235 [00:15<00:00, 15.41it/s] 
lambda 0.1, epoch 7: 1.0830224: 100%|██████████| 235/235 [00:15<00:00, 15.37it/s]
lambda 0.1, epoch 8: 1.0576893: 100%|██████████| 235/235 [00:15<00:00, 15.57it/s]
lambda 0.1, epoch 9: 1.0333433: 100%|██████████| 235/235 [00:15<00:00, 15.53it/s]
lambda 0.1, epoch 10: 1.0095785: 100%|██████████| 235/235 [00:15<00:00, 15.54it/s]
lambda 0.1, epoch 11: 1.0048176:   1%|          | 2/235 [00:00<00:15, 15.41it/s]

0.949


lambda 0.1, epoch 11: 0.98571986: 100%|██████████| 235/235 [00:15<00:00, 15.49it/s]
lambda 0.1, epoch 12: 0.96155965: 100%|██████████| 235/235 [00:15<00:00, 15.45it/s]
lambda 0.1, epoch 13: 0.9370007: 100%|██████████| 235/235 [00:13<00:00, 16.91it/s] 
lambda 0.1, epoch 14: 0.91195625: 100%|██████████| 235/235 [00:12<00:00, 19.55it/s]
lambda 0.1, epoch 15: 0.8860227: 100%|██████████| 235/235 [00:12<00:00, 19.44it/s] 
lambda 0.1, epoch 16: 0.8914392:   1%|▏         | 3/235 [00:00<00:11, 20.58it/s]

0.9552


lambda 0.1, epoch 16: 0.85906017: 100%|██████████| 235/235 [00:12<00:00, 19.53it/s]
lambda 0.1, epoch 17: 0.8312026: 100%|██████████| 235/235 [00:12<00:00, 19.28it/s] 
lambda 0.1, epoch 18: 0.8022244: 100%|██████████| 235/235 [00:12<00:00, 19.20it/s] 
lambda 0.1, epoch 19: 0.77208877: 100%|██████████| 235/235 [00:12<00:00, 19.13it/s]
lambda 0.1, epoch 20: 0.740906: 100%|██████████| 235/235 [00:12<00:00, 19.11it/s]  
lambda 0.1, epoch 21: 0.74474496:   1%|          | 2/235 [00:00<00:11, 19.46it/s]

0.9601


lambda 0.1, epoch 21: 0.7085375: 100%|██████████| 235/235 [00:12<00:00, 18.94it/s] 
lambda 0.1, epoch 22: 0.67509: 100%|██████████| 235/235 [00:12<00:00, 18.76it/s]   
lambda 0.1, epoch 23: 0.6405741: 100%|██████████| 235/235 [00:12<00:00, 19.38it/s] 
lambda 0.1, epoch 24: 0.605145: 100%|██████████| 235/235 [00:12<00:00, 19.25it/s]  
lambda 0.1, epoch 0: 2.3189878: 100%|██████████| 235/235 [00:12<00:00, 18.93it/s]
lambda 0.1, epoch 1: 1.4094905:   1%|▏         | 3/235 [00:00<00:12, 18.81it/s]

0.8513


lambda 0.1, epoch 1: 1.3420324: 100%|██████████| 235/235 [00:12<00:00, 19.20it/s]
lambda 0.1, epoch 2: 1.2567482: 100%|██████████| 235/235 [00:12<00:00, 19.36it/s]
lambda 0.1, epoch 3: 1.2097579: 100%|██████████| 235/235 [00:12<00:00, 19.43it/s]
lambda 0.1, epoch 4: 1.1739832: 100%|██████████| 235/235 [00:12<00:00, 18.91it/s]
lambda 0.1, epoch 5: 1.1423249: 100%|██████████| 235/235 [00:12<00:00, 19.09it/s]
lambda 0.1, epoch 6: 1.1045866:   1%|▏         | 3/235 [00:00<00:11, 19.80it/s]

0.9309


lambda 0.1, epoch 6: 1.1126746: 100%|██████████| 235/235 [00:12<00:00, 19.38it/s]
lambda 0.1, epoch 7: 1.0839006: 100%|██████████| 235/235 [00:10<00:00, 22.44it/s]
lambda 0.1, epoch 8: 1.0557791: 100%|██████████| 235/235 [00:09<00:00, 25.00it/s]
lambda 0.1, epoch 9: 1.0284296: 100%|██████████| 235/235 [00:09<00:00, 24.94it/s]
lambda 0.1, epoch 10: 1.0015523: 100%|██████████| 235/235 [00:09<00:00, 25.01it/s]
lambda 0.1, epoch 11: 1.0217464:   1%|▏         | 3/235 [00:00<00:09, 25.22it/s] 

0.9491


lambda 0.1, epoch 11: 0.9747648: 100%|██████████| 235/235 [00:09<00:00, 24.81it/s] 
lambda 0.1, epoch 12: 0.9479595: 100%|██████████| 235/235 [00:09<00:00, 24.41it/s] 
lambda 0.1, epoch 13: 0.92105544: 100%|██████████| 235/235 [00:09<00:00, 24.63it/s]
lambda 0.1, epoch 14: 0.8938301: 100%|██████████| 235/235 [00:09<00:00, 24.69it/s] 
lambda 0.1, epoch 15: 0.8661408: 100%|██████████| 235/235 [00:09<00:00, 24.52it/s] 
lambda 0.1, epoch 16: 0.8631466:   1%|▏         | 3/235 [00:00<00:09, 24.53it/s]

0.9582


lambda 0.1, epoch 16: 0.83790594: 100%|██████████| 235/235 [00:09<00:00, 24.91it/s]
lambda 0.1, epoch 17: 0.80913347: 100%|██████████| 235/235 [00:09<00:00, 25.05it/s]
lambda 0.1, epoch 18: 0.7796554: 100%|██████████| 235/235 [00:09<00:00, 24.99it/s] 
lambda 0.1, epoch 19: 0.74933755: 100%|██████████| 235/235 [00:09<00:00, 24.96it/s]
lambda 0.1, epoch 20: 0.718252: 100%|██████████| 235/235 [00:09<00:00, 23.79it/s]  
lambda 0.1, epoch 21: 0.71529436:   1%|▏         | 3/235 [00:00<00:09, 24.43it/s]

0.9616


lambda 0.1, epoch 21: 0.68636805: 100%|██████████| 235/235 [00:09<00:00, 24.89it/s]
lambda 0.1, epoch 22: 0.65361977: 100%|██████████| 235/235 [00:09<00:00, 25.06it/s]
lambda 0.1, epoch 23: 0.6200047: 100%|██████████| 235/235 [00:09<00:00, 24.93it/s] 
lambda 0.1, epoch 24: 0.5859692: 100%|██████████| 235/235 [00:14<00:00, 16.72it/s] 
lambda 0.1, epoch 0: 2.4629362: 100%|██████████| 235/235 [00:12<00:00, 18.24it/s]
lambda 0.1, epoch 1: 1.5162926:   1%|          | 2/235 [00:00<00:11, 19.51it/s]

0.8139


lambda 0.1, epoch 1: 1.3515037: 100%|██████████| 235/235 [00:13<00:00, 17.35it/s]
lambda 0.1, epoch 2: 1.244113: 100%|██████████| 235/235 [00:13<00:00, 17.16it/s] 
lambda 0.1, epoch 3: 1.1915909: 100%|██████████| 235/235 [00:13<00:00, 17.05it/s]
lambda 0.1, epoch 4: 1.152892: 100%|██████████| 235/235 [00:13<00:00, 17.23it/s] 
lambda 0.1, epoch 5: 1.1205134: 100%|██████████| 235/235 [00:13<00:00, 17.17it/s]
lambda 0.1, epoch 6: 1.0996035:   1%|          | 2/235 [00:00<00:13, 16.83it/s]

0.9364


lambda 0.1, epoch 6: 1.0915351: 100%|██████████| 235/235 [00:12<00:00, 18.44it/s]
lambda 0.1, epoch 7: 1.0646558: 100%|██████████| 235/235 [00:13<00:00, 16.95it/s]
lambda 0.1, epoch 8: 1.039037: 100%|██████████| 235/235 [00:13<00:00, 17.16it/s] 
lambda 0.1, epoch 9: 1.0141991: 100%|██████████| 235/235 [00:13<00:00, 16.97it/s]
lambda 0.1, epoch 10: 0.98950416: 100%|██████████| 235/235 [00:14<00:00, 16.59it/s]
lambda 0.1, epoch 11: 0.980022:   1%|          | 2/235 [00:00<00:12, 19.24it/s] 

0.9504


lambda 0.1, epoch 11: 0.9647619: 100%|██████████| 235/235 [00:13<00:00, 16.97it/s] 
lambda 0.1, epoch 12: 0.93949324: 100%|██████████| 235/235 [00:12<00:00, 18.70it/s]
lambda 0.1, epoch 13: 0.9138163: 100%|██████████| 235/235 [00:13<00:00, 16.99it/s] 
lambda 0.1, epoch 14: 0.8874223: 100%|██████████| 235/235 [00:13<00:00, 17.06it/s] 
lambda 0.1, epoch 15: 0.8601624: 100%|██████████| 235/235 [00:13<00:00, 17.34it/s] 
lambda 0.1, epoch 16: 0.8497776:   0%|          | 1/235 [00:00<00:32,  7.19it/s]

0.9562


lambda 0.1, epoch 16: 0.83215797: 100%|██████████| 235/235 [00:16<00:00, 14.65it/s]
lambda 0.1, epoch 17: 0.80341226: 100%|██████████| 235/235 [00:16<00:00, 14.20it/s]
lambda 0.1, epoch 18: 0.77389073: 100%|██████████| 235/235 [00:15<00:00, 14.72it/s]
lambda 0.1, epoch 19: 0.74362403: 100%|██████████| 235/235 [00:16<00:00, 14.26it/s]
lambda 0.1, epoch 20: 0.7124038: 100%|██████████| 235/235 [00:16<00:00, 14.46it/s] 
lambda 0.1, epoch 21: 0.70649177:   1%|          | 2/235 [00:00<00:14, 15.91it/s]

0.9615


lambda 0.1, epoch 21: 0.68036824: 100%|██████████| 235/235 [00:16<00:00, 14.34it/s]
lambda 0.1, epoch 22: 0.64763796: 100%|██████████| 235/235 [00:16<00:00, 14.11it/s]
lambda 0.1, epoch 23: 0.61430806: 100%|██████████| 235/235 [00:16<00:00, 13.96it/s]
lambda 0.1, epoch 24: 0.5801949: 100%|██████████| 235/235 [00:16<00:00, 14.64it/s] 
lambda 1, epoch 0: 1.8217819: 100%|██████████| 235/235 [00:15<00:00, 15.04it/s]
lambda 1, epoch 1: 1.0074003:   1%|          | 2/235 [00:00<00:13, 17.03it/s] 

0.8204


lambda 1, epoch 1: 0.790225: 100%|██████████| 235/235 [00:15<00:00, 14.86it/s]  
lambda 1, epoch 2: 0.70020324: 100%|██████████| 235/235 [00:16<00:00, 14.00it/s]
lambda 1, epoch 3: 0.6533768: 100%|██████████| 235/235 [00:16<00:00, 14.38it/s] 
lambda 1, epoch 4: 0.6179846: 100%|██████████| 235/235 [00:16<00:00, 14.36it/s] 
lambda 1, epoch 5: 0.5876239: 100%|██████████| 235/235 [00:16<00:00, 14.12it/s] 
lambda 1, epoch 6: 0.55309004:   1%|          | 2/235 [00:00<00:17, 13.06it/s]

0.9305


lambda 1, epoch 6: 0.56088233: 100%|██████████| 235/235 [00:16<00:00, 13.89it/s]
lambda 1, epoch 7: 0.536732: 100%|██████████| 235/235 [00:15<00:00, 14.80it/s]  
lambda 1, epoch 8: 0.5145974: 100%|██████████| 235/235 [00:16<00:00, 14.01it/s] 
lambda 1, epoch 9: 0.50045013:  53%|█████▎    | 125/235 [00:09<00:08, 13.37it/s]

In [None]:
import json  # kept local so this cell stays self-contained

# Evaluate every saved run with utils.delete_10 and store the per-lambda
# results in `results_var.json`.
net = var_net.VarNet(
    var_net.VarLayer(784, hidden_num,
                     prior_sigma=prior_sigma, init_log_sigma=init_log_sigma),
    var_net.VarLayer(hidden_num, 10,
                     prior_sigma=prior_sigma, init_log_sigma=init_log_sigma,
                     act=lambda x: x),
).to(device)

# NOTE(review): the training cell saves 'rep_net_lam_..._start_....cpk', but
# this cell loads 'var_net_...' files. On a fresh Restart & Run All this raises
# FileNotFoundError unless var_net_* checkpoints exist from another run.
# Confirm which prefix is intended.
ckpt_prefix = 'var_net'

lam_results = {}
for lam in lamb:
    lam_results[lam] = []
    for s in range(start_num):
        print(lam, s)
        ckpt_path = os.path.join(path_to_save,
                                 '{}_lam_{}_start_{}.cpk'.format(ckpt_prefix, lam, s))
        # map_location lets checkpoints written on a GPU load when device is 'cpu'.
        net.load_state_dict(t.load(ckpt_path, map_location=device))
        lam_results[lam].append(
            utils.delete_10(net, device, lambda: utils.test_acc_net(net, device, test_loader)))

# Float keys so lookups agree across int/float lambdas (e.g. 1 -> 1.0).
lam_results = {float(k): v for k, v in lam_results.items()}
with open(os.path.join(path_to_save, 'results_var.json'), 'w') as out:
    json.dump(lam_results, out)

In [None]:
# x-axis values: deletion percentage, matching the x-axis label below.
# assumes each entry of lam_results[lam] holds one accuracy per value here
# (10 points, as the name utils.delete_10 suggests) — TODO confirm.
proc = list(range(0, 100, 10))
plt.rcParams['figure.figsize'] = 12, 12
plt.rcParams.update({'font.size': 27})
plt.rc('lines', linewidth=4)

fig, ax = plt.subplots()
for lam in lamb:
    lam = float(lam)  # lam_results keys were converted to float by the evaluation cell
    # presumably shape (start_num, len(proc)): one row per training run
    runs = np.asarray(lam_results[lam])
    # Shaded band = min..max over runs; line = mean over runs.
    ax.fill_between(proc, runs.min(axis=0), runs.max(axis=0), alpha=0.2)
    # Raw string: in a normal literal '\l' is an invalid escape sequence
    # (DeprecationWarning / SyntaxWarning in Python 3.12+).
    ax.plot(proc, runs.mean(axis=0), label=r'$\lambda={:.2f}$'.format(lam))
ax.set_ylabel('Точность классификации', fontsize=27)
ax.set_xlabel('Процент удаления', fontsize=27)
ax.tick_params(axis='both', which='major', labelsize=27)
ax.legend(loc='lower left')
ax.autoscale(enable=True, axis='x', tight=True)
fig.savefig('Var')
