In [14]:
import torch as t 
import torchvision
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pylab as plt
from torch.nn.utils import clip_grad_value_
%matplotlib inline
import pickle
from torchvision import datasets, transforms
import tqdm
import os
from importlib import reload

import VarConv
reload(VarConv)
VarConvNet = VarConv.VarConvNet

In [147]:
epoch_num = 50
batch_size = 128
learning_rate = 0.001
num_workers = 4
start_num = 1
lambda_sample_num = 5
path_to_save = 'saved_cifar_new'
init_log_sigma = -5.0 # логарифм дисперсии вариационного распределения при инициализации
prior_sigma = 1.0
lambda_encode = lambda x: (x-1e-3)/(10.0-1e-3)
    
if not os.path.exists(path_to_save):
    os.mkdir(path_to_save)

In [5]:
device = 'cuda' # cuda or cpu
device = t.device(device)
if device == 'cuda':
    t.backends.cudnn.deterministic = True
    t.backends.cudnn.benchmark = False

In [6]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

cifar_trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
train_loader = t.utils.data.DataLoader(cifar_trainset, batch_size=batch_size,shuffle=True, num_workers=num_workers)

cifar_testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
test_loader = t.utils.data.DataLoader(cifar_testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

Files already downloaded and verified
Files already downloaded and verified


In [164]:
def test_acc(net): # точность классификации
    acc = []    
    net.eval()
    lamb =  [ 0.001, 0.01, 0.1, 1.0, 10.0]
    for l in lamb:
        correct = 0
        for x,y in test_loader: 
            x = x.to(device)
            y = y.to(device)
            #out = net(x[:,0].view(-1, 32*32)) 
            out = net(x,lambda_encode(t.tensor(l)))    
            correct += out.argmax(1).eq(y).sum().cpu().numpy()
        acc.append(correct / len(cifar_testset))
    net.train()
    return acc

In [167]:
def train_batches(net, loss_fn, optimizer, lam, label, e):
    tq = tqdm.tqdm(train_loader)
    losses = []
    for x,y in tq:            
        x = x.to(device)
        y = y.to(device)          
        optimizer.zero_grad()  
        loss = 0
        if lam is None:
            
            for _ in range(lambda_sample_num):  
                p = t.rand(1).to(device)*4 -3
                #p =  t.rand(1) * 2 - 1
                lam_param = 10**p[0]
                
                #lam_param = t.tensor(1.0)
                #t.rand(1).to(device)[0]*100.0                  
                out = net(x, lambda_encode(lam_param))
                #out = net(x)
                loss = loss + loss_fn(out, y)/lambda_sample_num
                loss += net.KLD(lambda_encode(lam_param))*lam_param/len(cifar_trainset)/lambda_sample_num
                #loss += net.KLD(lam_param)*t.log(lam_param)/len(trainset)/lambda_sample_num
                losses+=[loss.cpu().detach().numpy()]    
                
        tq.set_description(label+str(np.mean(losses)))
       
        loss.backward()       
      
        clip_grad_value_(net.parameters(), 1.0) # для стабильности градиента. С этим можно играться
        #print (losses)
        
        optimizer.step()
        
        #lr_scheduler.step()
    acc = test_acc(net)
    print (acc)
    return acc

In [166]:
import VarConv
reload(VarConv)
VarConvNet = VarConv.VarConvNet
t.manual_seed(0)
for start in range(start_num): 
    
    net = VarConvNet(init_log_sigma, prior_sigma)
    
    net.to(device)
   
    optim = t.optim.Adam(net.parameters())
    
    #lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=milestones, gamma=0.5)
    loss_fn = nn.CrossEntropyLoss().to(device) 
    for e in range(epoch_num):
        label = 'CIFAR, epoch {}: '.format(e)                
        acc = train_batches(net, loss_fn, optim, None, label, e )
        with open('acc.log', 'a') as out:
            out.write('{}:{}\n'.format(e, acc))
        t.save(net.state_dict(), os.path.join(path_to_save, 'cifar_epoch_{}.cpk'.format( e)))
    t.save(net.state_dict(), os.path.join(path_to_save, 'cifar_start_{}.cpk'.format( start)))

CIFAR, epoch 0: 32.96936: 100%|██████████| 391/391 [00:42<00:00,  9.21it/s] 


[0.4928, 0.4928, 0.4934, 0.4935, 0.4827]


CIFAR, epoch 1: 30.74869: 100%|██████████| 391/391 [00:43<00:00,  9.04it/s] 


[0.5853, 0.5851, 0.586, 0.5884, 0.5149]


CIFAR, epoch 2: 25.651733: 100%|██████████| 391/391 [00:43<00:00,  8.93it/s]


[0.6782, 0.6782, 0.6783, 0.6795, 0.6181]


CIFAR, epoch 3: 22.726408: 100%|██████████| 391/391 [00:43<00:00,  8.89it/s]


[0.7303, 0.7305, 0.7308, 0.7282, 0.6755]


CIFAR, epoch 4: 19.142748: 100%|██████████| 391/391 [00:43<00:00,  8.89it/s]


[0.7484, 0.7481, 0.7476, 0.7432, 0.6525]


CIFAR, epoch 5: 16.720543: 100%|██████████| 391/391 [00:44<00:00,  8.89it/s]


[0.7627, 0.7627, 0.7626, 0.761, 0.6156]


CIFAR, epoch 6: 13.504285: 100%|██████████| 391/391 [00:43<00:00,  8.89it/s] 


[0.7764, 0.7765, 0.7766, 0.7755, 0.6745]


CIFAR, epoch 7: 11.649849: 100%|██████████| 391/391 [00:43<00:00,  8.90it/s] 


[0.7635, 0.7635, 0.7643, 0.7635, 0.4992]


CIFAR, epoch 8: 10.956531: 100%|██████████| 391/391 [00:43<00:00,  8.91it/s] 


[0.7799, 0.7797, 0.7797, 0.7774, 0.4795]


CIFAR, epoch 9: 9.085499: 100%|██████████| 391/391 [00:43<00:00,  8.90it/s] 


[0.7726, 0.7724, 0.7718, 0.7736, 0.3975]


CIFAR, epoch 10: 8.72101: 100%|██████████| 391/391 [00:43<00:00,  8.90it/s]  


[0.7995, 0.7997, 0.8007, 0.7956, 0.5568]


CIFAR, epoch 11: 8.427563: 100%|██████████| 391/391 [00:43<00:00,  8.91it/s] 


[0.7983, 0.7985, 0.7979, 0.7953, 0.5668]


CIFAR, epoch 12: 7.8971252: 100%|██████████| 391/391 [00:43<00:00,  8.90it/s]


[0.8059, 0.8061, 0.805, 0.7996, 0.4532]


CIFAR, epoch 13: 7.6077447: 100%|██████████| 391/391 [00:43<00:00,  8.91it/s]


[0.8131, 0.8132, 0.8137, 0.8069, 0.4876]


CIFAR, epoch 14: 7.962111: 100%|██████████| 391/391 [00:43<00:00,  8.92it/s] 


[0.7964, 0.7963, 0.7964, 0.7819, 0.2304]


CIFAR, epoch 15: 7.4142814: 100%|██████████| 391/391 [00:43<00:00,  8.91it/s]


[0.8118, 0.8117, 0.8112, 0.8053, 0.4254]


CIFAR, epoch 16: 7.276651: 100%|██████████| 391/391 [00:43<00:00,  8.91it/s] 


[0.7946, 0.7946, 0.7949, 0.7813, 0.2442]


CIFAR, epoch 17: 7.1815906: 100%|██████████| 391/391 [00:43<00:00,  8.91it/s]


[0.8208, 0.8208, 0.8208, 0.8149, 0.4639]


CIFAR, epoch 18: 6.800963: 100%|██████████| 391/391 [00:43<00:00,  8.91it/s] 


[0.8188, 0.8185, 0.819, 0.8156, 0.5224]


CIFAR, epoch 19: 6.6920366: 100%|██████████| 391/391 [00:43<00:00,  8.91it/s]


[0.8195, 0.8195, 0.8195, 0.8165, 0.4703]


CIFAR, epoch 20: 6.4193344: 100%|██████████| 391/391 [00:43<00:00,  8.91it/s]


[0.8223, 0.8221, 0.8222, 0.8181, 0.4967]


CIFAR, epoch 21: 6.2673116: 100%|██████████| 391/391 [00:43<00:00,  8.91it/s]


[0.8221, 0.8221, 0.8215, 0.8156, 0.4863]


CIFAR, epoch 22: 6.612328: 100%|██████████| 391/391 [00:43<00:00,  8.91it/s] 


[0.8223, 0.8223, 0.8233, 0.8181, 0.5124]


CIFAR, epoch 23: 6.1337385: 100%|██████████| 391/391 [00:43<00:00,  8.91it/s]


[0.8301, 0.8301, 0.8302, 0.8255, 0.4887]


CIFAR, epoch 24: 6.3903594: 100%|██████████| 391/391 [00:43<00:00,  8.91it/s]


[0.8271, 0.8271, 0.8268, 0.82, 0.3242]


CIFAR, epoch 25: 5.974559: 100%|██████████| 391/391 [00:43<00:00,  8.91it/s] 


[0.8316, 0.8316, 0.8314, 0.8263, 0.4641]


CIFAR, epoch 26: 5.950825: 100%|██████████| 391/391 [00:43<00:00,  8.90it/s] 


[0.8302, 0.83, 0.8305, 0.8251, 0.3925]


CIFAR, epoch 27: 5.88078: 100%|██████████| 391/391 [00:43<00:00,  8.91it/s]  


[0.832, 0.8321, 0.8305, 0.8236, 0.5115]


CIFAR, epoch 28: 5.5054264: 100%|██████████| 391/391 [00:43<00:00,  8.90it/s]


[0.8374, 0.8374, 0.8372, 0.8312, 0.4365]


CIFAR, epoch 29: 5.4162936: 100%|██████████| 391/391 [00:43<00:00,  8.91it/s]


[0.8401, 0.8397, 0.8392, 0.8315, 0.4007]


CIFAR, epoch 30: 5.27608: 100%|██████████| 391/391 [00:43<00:00,  8.91it/s]  
  0%|          | 0/391 [00:00<?, ?it/s]

[0.8372, 0.8371, 0.8376, 0.8303, 0.3612]


CIFAR, epoch 31: 5.2516885: 100%|██████████| 391/391 [00:43<00:00,  8.91it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.8401, 0.8396, 0.84, 0.8344, 0.3932]


CIFAR, epoch 32: 4.8745036: 100%|██████████| 391/391 [00:43<00:00,  8.91it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.839, 0.839, 0.8379, 0.8297, 0.4721]


CIFAR, epoch 33: 5.829136: 100%|██████████| 391/391 [00:43<00:00,  8.90it/s] 
  0%|          | 0/391 [00:00<?, ?it/s]

[0.8125, 0.8124, 0.8115, 0.7979, 0.2684]


CIFAR, epoch 34: 7.474493: 100%|██████████| 391/391 [00:43<00:00,  8.93it/s]  
  0%|          | 0/391 [00:00<?, ?it/s]

[0.7648, 0.7647, 0.7631, 0.728, 0.1214]


CIFAR, epoch 35: 5.1755276: 100%|██████████| 391/391 [00:43<00:00,  8.92it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.8109, 0.8106, 0.8084, 0.7912, 0.1255]


CIFAR, epoch 36: 5.5299344: 100%|██████████| 391/391 [00:43<00:00,  8.92it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.8194, 0.8189, 0.8176, 0.805, 0.1183]


CIFAR, epoch 37: 4.9714336: 100%|██████████| 391/391 [00:43<00:00,  8.92it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.824, 0.8238, 0.8225, 0.8133, 0.1591]


CIFAR, epoch 38: 5.0338254: 100%|██████████| 391/391 [00:43<00:00,  8.92it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.8279, 0.8278, 0.8275, 0.8185, 0.248]


CIFAR, epoch 39: 5.294804: 100%|██████████| 391/391 [00:43<00:00,  8.92it/s] 
  0%|          | 0/391 [00:00<?, ?it/s]

[0.8125, 0.8124, 0.8099, 0.7927, 0.222]


CIFAR, epoch 40: 5.161281: 100%|██████████| 391/391 [00:43<00:00,  8.92it/s] 
  0%|          | 0/391 [00:00<?, ?it/s]

[0.8282, 0.828, 0.8275, 0.8178, 0.2991]


CIFAR, epoch 41: 5.0857916: 100%|██████████| 391/391 [00:43<00:00,  8.91it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.8377, 0.8377, 0.8368, 0.8284, 0.2769]


CIFAR, epoch 42: 5.4114423: 100%|██████████| 391/391 [00:43<00:00,  8.92it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.8358, 0.8356, 0.8362, 0.8311, 0.3201]


CIFAR, epoch 43: 5.440544: 100%|██████████| 391/391 [00:43<00:00,  8.92it/s] 
  0%|          | 0/391 [00:00<?, ?it/s]

[0.8394, 0.8394, 0.8388, 0.8322, 0.3374]


CIFAR, epoch 44: 5.0332975: 100%|██████████| 391/391 [00:43<00:00,  8.91it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.8477, 0.8477, 0.8461, 0.8347, 0.3138]


CIFAR, epoch 45: 4.7409844: 100%|██████████| 391/391 [00:43<00:00,  8.91it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.8436, 0.8437, 0.8436, 0.8388, 0.327]


CIFAR, epoch 46: 4.9325247: 100%|██████████| 391/391 [00:43<00:00,  8.91it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.8458, 0.8458, 0.8452, 0.8384, 0.2058]


CIFAR, epoch 47: 5.033571: 100%|██████████| 391/391 [00:43<00:00,  8.91it/s] 
  0%|          | 0/391 [00:00<?, ?it/s]

[0.8408, 0.8412, 0.8412, 0.8324, 0.2964]


CIFAR, epoch 48: 4.6173706: 100%|██████████| 391/391 [00:43<00:00,  8.91it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.8502, 0.8502, 0.8495, 0.8425, 0.3107]


CIFAR, epoch 49: 4.663518: 100%|██████████| 391/391 [00:43<00:00,  8.90it/s] 


[0.8467, 0.8468, 0.8462, 0.836, 0.2864]
