In [9]:
import torch as t
import torchvision
import torchvision.transforms as transforms
import numpy as np
import torch.nn.functional as F
import matplotlib.pylab as plt
from torch.nn.utils import clip_grad_value_
%matplotlib inline

from torch.autograd import Variable
import torch.nn as nn

import argparse

import torch.optim as optim

from primary_net import PrimaryNetwork

from torchvision import datasets
import tqdm
import os
import json

In [10]:
device = 'cuda' # cuda or cpu
device = t.device(device)
if device == 'cuda':
    t.backends.cudnn.deterministic = True
    t.backends.cudnn.benchmark = False

In [11]:
batch_size = 128
prior_sigma = 1.0 # априорная дисперсия
epoch_num = 25 #количество эпох
lamb = [0.01, 0.1, 1,  10, 100]
start_num = 5

lambda_encode = lambda x : (t.log(x) + 4.6052)/(4.6052+ 4.6052)
lambda_sample_num = 5
path_to_save = 'saved_cifar_2'

if not os.path.exists(path_to_save):
    os.mkdir(path_to_save)
    
learning_rate = 0.002
weight_decay = 0.0005
milestones = [168000, 336000, 400000, 450000, 550000, 600000]
max_iter = 1000000


In [12]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='../data', train=True,
                                        download=True, transform=transform_train)
trainloader = t.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=4)

testset = torchvision.datasets.CIFAR10(root='../data', train=False,
                                       download=True, transform=transform_test)
testloader = t.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=4)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [13]:
lamb = [0.01, 0.1, 1,  10, 100]

In [14]:
def test_acc(net): # точность классификации
    acc = []    
    net.eval()
    lamb =  [0.01, 0.1, 1,  10, 100]
    for l in lamb:
        correct = 0
        for x,y in testloader: 
            x = x.to(device)
            y = y.to(device)  
            out = net(x,l)    
            correct += out.argmax(1).eq(y).sum().cpu().numpy()
            t.cuda.empty_cache()
        acc.append(correct / len(testset))
        t.cuda.empty_cache()
    return acc


In [15]:
def train_batches(net, loss_fn, optimizer, lam, label):
    tq = tqdm.tqdm(trainloader)
    losses = []
    for x,y in tq:            
        x = x.to(device)
        y = y.to(device)          
        optimizer.zero_grad()  
        loss = 0
        if lam is None:
            for _ in range(lambda_sample_num):  
                p = t.rand(1).to(device)*4 -2
                lam_param = 10**p[0]                
                #t.rand(1).to(device)[0]*100.0                  
                out = net(x, lambda_encode(lam_param))
                loss = loss + loss_fn(out, y)/lambda_sample_num
                loss += net.KLD(lambda_encode(lam_param))*lam_param/len(trainset)/lambda_sample_num
                #loss += net.KLD(lam_param)*t.log(lam_param)/len(trainset)/lambda_sample_num
                losses+=[loss.cpu().detach().numpy()]       
        tq.set_description(label+str(np.mean(losses)))
        loss.backward()       
        clip_grad_value_(net.parameters(), 1.0) # для стабильности градиента. С этим можно играться
        optimizer.step()
        #lr_scheduler.step()
    print(test_acc(net))

In [16]:
t.manual_seed(0)
for start in range(start_num):         
    net = PrimaryNetwork(prior_sigma = prior_sigma, device = device)
    net = net.to(device)
    optim = t.optim.Adam(net.parameters(), lr=1e-4)
    #lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=milestones, gamma=0.5)
    loss_fn = nn.CrossEntropyLoss().to(device)            
    for e in range(epoch_num):
        label = 'CIFAR, epoch {}: '.format(e)                
        train_batches(net, loss_fn, optim, None, label)
        t.save(net.state_dict(), os.path.join(path_to_save, 'cifar_epoch_{}.cpk'.format( e)))
    t.save(net.state_dict(), os.path.join(path_to_save, 'cifar_start_{}.cpk'.format( start)))

CIFAR, epoch 0: 1.7295864: 100%|██████████| 391/391 [19:35<00:00,  3.01s/it]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.1, 0.1, 0.1, 0.1, 0.1]


CIFAR, epoch 1: 1.4357768: 100%|██████████| 391/391 [16:31<00:00,  2.54s/it]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.1, 0.1, 0.0954, 0.1, 0.1]


CIFAR, epoch 2: 1.4291575: 100%|██████████| 391/391 [16:29<00:00,  2.53s/it]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.1002, 0.1003, 0.0908, 0.0944, 0.1]


CIFAR, epoch 3: 1.4253792: 100%|██████████| 391/391 [16:32<00:00,  2.54s/it]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.1365, 0.1387, 0.1253, 0.0978, 0.1]


CIFAR, epoch 4: 1.4235057: 100%|██████████| 391/391 [16:31<00:00,  2.54s/it]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.186, 0.1859, 0.1578, 0.1168, 0.1]


CIFAR, epoch 5: 1.4156984: 100%|██████████| 391/391 [16:30<00:00,  2.53s/it]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.1779, 0.1786, 0.185, 0.1575, 0.1]


CIFAR, epoch 6: 1.3722352: 100%|██████████| 391/391 [16:31<00:00,  2.53s/it]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.1909, 0.1939, 0.1875, 0.1, 0.1]


CIFAR, epoch 7: 1.3140529: 100%|██████████| 391/391 [16:31<00:00,  2.54s/it]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.2115, 0.211, 0.203, 0.1705, 0.1]


CIFAR, epoch 8: 1.293697: 100%|██████████| 391/391 [16:42<00:00,  2.57s/it] 
  0%|          | 0/391 [00:00<?, ?it/s]

[0.2204, 0.2206, 0.2083, 0.1675, 0.1]


CIFAR, epoch 9: 1.2840931: 100%|██████████| 391/391 [17:18<00:00,  2.66s/it]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.2258, 0.2234, 0.219, 0.1437, 0.1]


CIFAR, epoch 10: 1.2755562: 100%|██████████| 391/391 [17:03<00:00,  2.62s/it]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.2322, 0.2333, 0.2318, 0.1394, 0.1]


CIFAR, epoch 11: 1.2637212: 100%|██████████| 391/391 [17:01<00:00,  2.61s/it]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.2388, 0.2411, 0.2408, 0.102, 0.1]


CIFAR, epoch 12: 1.2509215: 100%|██████████| 391/391 [16:35<00:00,  2.55s/it]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.2521, 0.2526, 0.2474, 0.1009, 0.1107]


CIFAR, epoch 13: 1.2378803: 100%|██████████| 391/391 [16:45<00:00,  2.57s/it]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.2485, 0.2476, 0.2487, 0.1191, 0.1]


CIFAR, epoch 14: 1.2252408: 100%|██████████| 391/391 [17:11<00:00,  2.64s/it]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.2547, 0.254, 0.2603, 0.1318, 0.1121]


CIFAR, epoch 15: 1.212877: 100%|██████████| 391/391 [17:00<00:00,  2.61s/it] 
  0%|          | 0/391 [00:00<?, ?it/s]

[0.2576, 0.2599, 0.2658, 0.1305, 0.12]


CIFAR, epoch 16: 1.2038865: 100%|██████████| 391/391 [16:41<00:00,  2.56s/it]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.2695, 0.2704, 0.2843, 0.1191, 0.1452]


CIFAR, epoch 17: 1.1946598: 100%|██████████| 391/391 [17:23<00:00,  2.67s/it]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.2686, 0.2716, 0.2799, 0.1219, 0.0968]


CIFAR, epoch 18: 1.1861304: 100%|██████████| 391/391 [17:06<00:00,  2.62s/it]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.2707, 0.27, 0.2844, 0.1137, 0.0965]


CIFAR, epoch 19: 1.1777681: 100%|██████████| 391/391 [16:35<00:00,  2.55s/it]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.2667, 0.2725, 0.2841, 0.1248, 0.1062]


CIFAR, epoch 20: 1.1706154: 100%|██████████| 391/391 [16:40<00:00,  2.56s/it]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.2701, 0.2779, 0.298, 0.1005, 0.1026]


CIFAR, epoch 21: 1.1615381: 100%|██████████| 391/391 [16:36<00:00,  2.55s/it]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.2868, 0.2886, 0.3072, 0.1042, 0.1059]


CIFAR, epoch 22: 1.1555077: 100%|██████████| 391/391 [16:38<00:00,  2.55s/it]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.2781, 0.282, 0.3045, 0.1005, 0.1033]


CIFAR, epoch 23: 1.1454769: 100%|██████████| 391/391 [16:38<00:00,  2.55s/it]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.2812, 0.2862, 0.307, 0.1002, 0.0949]


CIFAR, epoch 24: 1.1364912: 100%|██████████| 391/391 [16:39<00:00,  2.56s/it]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.29, 0.2975, 0.3179, 0.1051, 0.111]


CIFAR, epoch 0: 1.822396:  58%|█████▊    | 227/391 [10:57<07:55,  2.90s/it] 


KeyboardInterrupt: 

In [None]:
t.cuda.empty_cache()