In [1]:
import torch as t 
import torchvision
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pylab as plt
from torch.nn.utils import clip_grad_value_
%matplotlib inline
import pickle
from torchvision import datasets, transforms
import tqdm
import os
from VarConv import VarConvNet

In [2]:
epoch_num = 50
batch_size = 128
learning_rate = 0.001
num_workers = 4
start_num = 1
lambda_sample_num = 5
path_to_save = 'saved_cifar_new'
init_log_sigma = -5.0 # логарифм дисперсии вариационного распределения при инициализации
prior_sigma = 1.0
lambda_encode = t.log

    
if not os.path.exists(path_to_save):
    os.mkdir(path_to_save)

In [3]:
device = 'cuda' # cuda or cpu
device = t.device(device)
if device == 'cuda':
    t.backends.cudnn.deterministic = True
    t.backends.cudnn.benchmark = False

In [4]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

cifar_trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
train_loader = t.utils.data.DataLoader(cifar_trainset, batch_size=batch_size,shuffle=True, num_workers=num_workers)

cifar_testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
test_loader = t.utils.data.DataLoader(cifar_testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
def test_acc(net): # точность классификации
    acc = []    
    net.eval()
    lamb =  [0.01, 0.1, 1,  10, 100]
    for l in lamb:
        correct = 0
        for x,y in test_loader: 
            x = x.to(device)
            y = y.to(device)  
            out = net(x,l)    
            correct += out.argmax(1).eq(y).sum().cpu().numpy()
        acc.append(correct / len(cifar_testset))
    net.train()
    return acc

In [6]:
def train_batches(net, loss_fn, optimizer, lam, label):
    tq = tqdm.tqdm(train_loader)
    losses = []
    for x,y in tq:            
        x = x.to(device)
        y = y.to(device)          
        optimizer.zero_grad()  
        loss = 0
        if lam is None:
            
            for _ in range(lambda_sample_num):  
                p = t.rand(1).to(device)*4 -2
                lam_param = 10**p[0]                
                #t.rand(1).to(device)[0]*100.0                  
                out = net(x, lambda_encode(lam_param))
                loss = loss + loss_fn(out, y)/lambda_sample_num
                loss += net.KLD(lambda_encode(lam_param))*lam_param/len(cifar_trainset)/lambda_sample_num
                #loss += net.KLD(lam_param)*t.log(lam_param)/len(trainset)/lambda_sample_num
                losses+=[loss.cpu().detach().numpy()]       
        tq.set_description(label+str(np.mean(losses)))
        loss.backward()       
        clip_grad_value_(net.parameters(), 1.0) # для стабильности градиента. С этим можно играться
        optimizer.step()
        
        #lr_scheduler.step()
    acc = test_acc(net)
    print (acc)
    return acc

In [None]:
t.manual_seed(0)
for start in range(start_num): 
    net = VarConvNet(init_log_sigma, prior_sigma)
    net.to(device)
    optim = t.optim.Adam(net.parameters(), lr=0.001)
    #lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=milestones, gamma=0.5)
    loss_fn = nn.CrossEntropyLoss().to(device) 
    for e in range(epoch_num):
        label = 'CIFAR, epoch {}: '.format(e)                
        acc = train_batches(net, loss_fn, optim, None, label)
        with open('acc.log', 'a') as out:
            out.write('{}:{}\n'.format(e, acc))
        t.save(net.state_dict(), os.path.join(path_to_save, 'cifar_epoch_{}.cpk'.format( e)))
    t.save(net.state_dict(), os.path.join(path_to_save, 'cifar_start_{}.cpk'.format( start)))

CIFAR, epoch 0: 14274552000.0: 100%|██████████| 391/391 [02:42<00:00,  2.40it/s]


[0.1, 0.1, 0.1, 0.1, 0.1]


CIFAR, epoch 1: 543.8427: 100%|██████████| 391/391 [02:42<00:00,  2.40it/s] 


[0.1, 0.1, 0.1, 0.1, 0.1]


CIFAR, epoch 2: 858.6892: 100%|██████████| 391/391 [02:41<00:00,  2.42it/s]  
  0%|          | 0/391 [00:00<?, ?it/s]

[0.1, 0.1, 0.1, 0.1, 0.1]


CIFAR, epoch 3: 1647.1409: 100%|██████████| 391/391 [02:46<00:00,  2.35it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.1, 0.1, 0.1, 0.1, 0.1]


CIFAR, epoch 4: 130793416.0: 100%|██████████| 391/391 [02:49<00:00,  2.31it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.1, 0.1, 0.1, 0.1, 0.1]


CIFAR, epoch 5: 1068642.1: 100%|██████████| 391/391 [02:45<00:00,  2.36it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.1, 0.1, 0.0997, 0.0995, 0.0995]


CIFAR, epoch 6: 2596475100.0: 100%|██████████| 391/391 [02:43<00:00,  2.39it/s] 
  0%|          | 0/391 [00:00<?, ?it/s]

[0.1, 0.1, 0.1, 0.1, 0.1]


CIFAR, epoch 7: 93436530.0: 100%|██████████| 391/391 [02:43<00:00,  2.40it/s] 
  0%|          | 0/391 [00:00<?, ?it/s]

[0.1, 0.1, 0.1, 0.1, 0.1]


CIFAR, epoch 8: 710451.5: 100%|██████████| 391/391 [02:42<00:00,  2.40it/s] 


[0.1, 0.1, 0.1, 0.1, 0.1]


CIFAR, epoch 9: 170800910.0:  80%|████████  | 313/391 [02:12<00:32,  2.39it/s]