In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from utils import mnist, plot_graphs, plot_mnist
import numpy as np

%matplotlib inline

In [3]:
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

In [4]:
# Build the train / validation / test loaders with the project helper `mnist`,
# converting every image to a tensor.
# NOTE(review): presumably `valid=10000` is the number of samples held out for
# validation -- confirm against utils.mnist.
train_loader, valid_loader, test_loader = mnist(
    valid=10000,
    transform=transforms.ToTensor(),
)

In [5]:
class ConvLayer(nn.Module):
    """Convolution block: Conv2d, then optional pooling / batch norm / 2D dropout, then the activation.

    Args:
        size: Indexable ``(in_channels, out_channels, kernel_size)``.
        padding: Padding passed to ``nn.Conv2d``.
        pool_layer: Module inserted right after the convolution, or ``None`` to skip pooling.
        bn: If true, add ``nn.BatchNorm2d(out_channels)`` after the (optional) pooling.
        dropout: If true, add ``nn.Dropout2d()`` (default probability) before the activation.
        activation_fn: Module instance appended as the last stage.
        stride: Stride passed to ``nn.Conv2d``.
    """

    def __init__(self, size, padding=1, pool_layer=nn.MaxPool2d(2, stride=2),
                 bn=False, dropout=False, activation_fn=nn.ReLU(), stride=1):
        super().__init__()
        in_channels, out_channels, kernel_size = size[0], size[1], size[2]
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, stride=stride)

        # Optional stages in their fixed order; disabled ones are filtered out below.
        optional_stages = [
            pool_layer,
            nn.BatchNorm2d(out_channels) if bn else None,
            nn.Dropout2d() if dropout else None,
        ]
        stages = [conv]
        stages.extend(stage for stage in optional_stages if stage is not None)
        stages.append(activation_fn)

        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the assembled stages to ``x`` and return the result."""
        return self.model(x)

In [6]:
class FullyConnected(nn.Module):
    """Stack of linear layers described by ``sizes``.

    Every layer except the last is followed by optional ``nn.Dropout()`` and a
    fresh ``activation_fn()`` instance; the last layer is a bare ``nn.Linear``.

    Args:
        sizes: Sequence of layer widths, e.g. ``[in, hidden, ..., out]``.
        dropout: If true, add ``nn.Dropout()`` after every non-final linear layer.
        activation_fn: Activation *class* (not instance); instantiated once per
            non-final layer.
    """

    def __init__(self, sizes, dropout=False, activation_fn=nn.Tanh):
        super().__init__()
        # Consecutive (in_features, out_features) pairs; the last pair is the output layer.
        pairs = list(zip(sizes[:-1], sizes[1:]))
        hidden_pairs, output_pair = pairs[:-1], pairs[-1]

        stack = []
        for n_in, n_out in hidden_pairs:
            stack.append(nn.Linear(n_in, n_out))
            if dropout:
                stack.append(nn.Dropout())
            stack.append(activation_fn())
        # The last layer gets neither dropout nor an activation function.
        stack.append(nn.Linear(output_pair[0], output_pair[1]))

        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the result."""
        return self.model(x)

In [7]:
class Net(nn.Module):
    """CNN classifier that works directly on single-channel images and owns its optimizer.

    Three strided, tanh-activated convolution blocks (no pooling) are followed
    by four ``FullyConnected`` blocks producing 10 logits. ``fc1`` expects the
    last convolution to output 32 channels of 2x2, which is what a 28x28 input
    yields with the kernel sizes and strides used here.

    NOTE(review): every ``FullyConnected`` below receives only two sizes, which
    builds a single ``nn.Linear`` without activation, so ``fc1``..``fc4`` chain
    four affine layers with no nonlinearity in between -- confirm this is intended.

    Args:
        batchnorm: If true, each convolution block includes ``BatchNorm2d``.
        dropout: Accepted for signature compatibility; not used by this class.
        optim_type: One of ``'SGD'``, ``'Adadelta'``, ``'RMSProp'``, ``'Adam'``.
        **optim_params: Forwarded to the optimizer constructor (e.g. ``lr``).

    Raises:
        ValueError: If ``optim_type`` is not a supported optimizer name.
    """

    def __init__(self, batchnorm=False, dropout=False, optim_type='SGD', **optim_params):
        super(Net, self).__init__()

        self._conv1 = ConvLayer([1, 16, 4], padding=0, bn=batchnorm, stride=2,
                                pool_layer=None, activation_fn=nn.Tanh())
        self._conv2 = ConvLayer([16, 32, 4], padding=0, bn=batchnorm, stride=2,
                                pool_layer=None, activation_fn=nn.Tanh())
        self._conv3 = ConvLayer([32, 32, 3], padding=0, bn=batchnorm, stride=2,
                                pool_layer=None, activation_fn=nn.Tanh())

        self.fc1 = FullyConnected([32*2*2, 32])
        self.fc2 = FullyConnected([32, 32])
        self.fc3 = FullyConnected([32, 32])
        self.fc4 = FullyConnected([32, 10])

        # Last loss tensor computed by `loss`; read by the training loop for reporting.
        self._loss = None

        optimizers = {
            'SGD': optim.SGD,
            'Adadelta': optim.Adadelta,
            'RMSProp': optim.RMSprop,
            'Adam': optim.Adam,
        }
        # Fail early: previously an unknown name silently left `self.optim`
        # undefined and the error only surfaced inside the training loop.
        if optim_type not in optimizers:
            raise ValueError('Unsupported optim_type {!r}; expected one of {}'.format(
                optim_type, sorted(optimizers)))
        self.optim = optimizers[optim_type](self.parameters(), **optim_params)

    def conv(self, x):
        """Run the three convolution blocks; returns ``(l3, l2, l1)``, deepest activation first."""
        l1 = self._conv1(x)
        l2 = self._conv2(l1)
        l3 = self._conv3(l2)
        return l3, l2, l1

    def forward(self, x):
        """Return the 10 class logits for the batch ``x``."""
        l3 = self.conv(x)[0]
        # Flatten per sample. Keeping the batch dimension explicit means an
        # unexpected spatial size fails in fc1 with a shape error instead of
        # being silently folded into the batch dimension by view(-1, ...).
        flatten = l3.view(l3.size(0), -1)
        x = self.fc1(flatten)
        x = self.fc2(x)
        x = self.fc3(x)
        h = self.fc4(x)
        return h

    def loss(self, output, target):
        """Compute the cross-entropy loss, store it in ``self._loss`` and return it."""
        self._loss = F.cross_entropy(output, target)
        return self._loss

In [33]:
# Device is pinned to the CPU (automatic CUDA selection is disabled).
device = torch.device('cpu')
print(device)
# Learning rate constant; the cells shown here pass lr=1e-4 literally instead of reading it.
lr = 1e-4
# Width of the encoders' output layer.
prior_size = 10

class FC(nn.Module):
    """Configurable fully connected network.

    Layout of ``self.model``: optional ``first_fn``, then for every non-final
    layer ``Linear`` -> optional ``Dropout(dropout)`` -> ``activation_fn``,
    then the final ``Linear`` and an optional ``last_fn``.

    Args:
        sizes: Sequence of layer widths, e.g. ``[in, hidden, ..., out]``.
        dropout: Falsy to disable; otherwise the value is passed to ``nn.Dropout``.
        activation_fn: Module *instance*; the same instance is reused after
            every non-final linear layer.
        flatten: If true, ``forward`` reshapes ``x`` to ``(batch, -1)`` first.
        last_fn: Optional module appended after the final linear layer.
        first_fn: Optional module placed before the first linear layer.
        device: Device the module is moved to after construction.
    """

    def __init__(self, sizes, dropout=False, activation_fn=nn.Tanh(), flatten=False,
                 last_fn=None, first_fn=None, device='cpu'):
        super().__init__()
        self.flatten = flatten

        stack = [] if first_fn is None else [first_fn]
        # Consecutive (in_features, out_features) pairs; the last pair is the output layer.
        pairs = list(zip(sizes[:-1], sizes[1:]))
        for n_in, n_out in pairs[:-1]:
            stack.append(nn.Linear(n_in, n_out))
            if dropout:
                stack.append(nn.Dropout(dropout))
            stack.append(activation_fn)
        # The last layer gets neither dropout nor an activation function.
        n_in, n_out = pairs[-1]
        stack.append(nn.Linear(n_in, n_out))
        if last_fn is not None:
            stack.append(last_fn)

        self.model = nn.Sequential(*stack)
        self.to(device)

    def forward(self, x, y=None):
        """Run the network on ``x``; when ``y`` is given it is concatenated to ``x`` along dim 1 first."""
        features = x.view(x.shape[0], -1) if self.flatten else x
        if y is not None:
            features = torch.cat([features, y], dim=1)
        return self.model(features)

# Pretrained encoders restored from local checkpoints and switched to eval mode.
# `map_location=device` remaps the stored tensors onto the target device, so a
# checkpoint that was saved from a GPU still loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
caae_Enc = FC([28*28, 1024, 1024, prior_size], activation_fn=nn.LeakyReLU(0.2), flatten=True, device=device)
caae_Enc.load_state_dict(torch.load("./model_Enc_caae", map_location=device))
caae_Enc.eval()

aaec_Enc = FC([28*28, 1024, 1024, prior_size], activation_fn=nn.LeakyReLU(0.2), flatten=True, device=device)
aaec_Enc.load_state_dict(torch.load("./model_Enc_aaec", map_location=device))
aaec_Enc.eval()

cpu


FC(
  (model): Sequential(
    (0): Linear(in_features=784, out_features=1024, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Linear(in_features=1024, out_features=1024, bias=True)
    (3): LeakyReLU(negative_slope=0.2)
    (4): Linear(in_features=1024, out_features=10, bias=True)
  )
)

In [39]:
class CAAE_Clf(nn.Module):
    """Classifier head trained on the codes of a pretrained, frozen encoder.

    The encoder is one of the module-level networks ``caae_Enc`` /
    ``aaec_Enc``. It is not registered as a submodule, so it is not part of
    ``self.parameters()`` and the optimizer created here never updates it.

    NOTE(review): every ``FullyConnected`` below receives only two sizes, which
    builds a single ``nn.Linear`` without activation, so ``fc1``..``fc4`` chain
    four affine layers with no nonlinearity in between -- confirm this is intended.

    Args:
        conditional: If true, encode with ``caae_Enc``; otherwise with ``aaec_Enc``.
        batchnorm: Accepted for signature compatibility; not used by this class.
        dropout: Accepted for signature compatibility; not used by this class.
        optim_type: One of ``'SGD'``, ``'Adadelta'``, ``'RMSProp'``, ``'Adam'``.
        **optim_params: Forwarded to the optimizer constructor (e.g. ``lr``).

    Raises:
        ValueError: If ``optim_type`` is not a supported optimizer name.
    """

    def __init__(self, conditional=True, batchnorm=False, dropout=False, optim_type='SGD', **optim_params):
        super(CAAE_Clf, self).__init__()
        self.conditional = conditional
        self.fc1 = FullyConnected([10, 32])
        self.fc2 = FullyConnected([32, 32])
        self.fc3 = FullyConnected([32, 64])
        self.fc4 = FullyConnected([64, 10])

        # Last loss tensor computed by `loss`; read by the training loop for reporting.
        self._loss = None

        optimizers = {
            'SGD': optim.SGD,
            'Adadelta': optim.Adadelta,
            'RMSProp': optim.RMSprop,
            'Adam': optim.Adam,
        }
        # Fail early: previously an unknown name silently left `self.optim`
        # undefined and the error only surfaced inside the training loop.
        if optim_type not in optimizers:
            raise ValueError('Unsupported optim_type {!r}; expected one of {}'.format(
                optim_type, sorted(optimizers)))
        self.optim = optimizers[optim_type](self.parameters(), **optim_params)

    def forward(self, x):
        """Encode ``x`` with the selected encoder and return the 10 class logits."""
        encoder = caae_Enc if self.conditional else aaec_Enc
        # The encoder is never optimized here, so skip autograd for it: this
        # avoids back-propagating through it and accumulating unused gradients
        # in its parameters on every step. The head's gradients are unaffected.
        with torch.no_grad():
            code = encoder(x)

        x = self.fc1(code)
        x = self.fc2(x)
        x = self.fc3(x)
        h = self.fc4(x)
        return h

    def loss(self, output, target):
        """Compute the cross-entropy loss, store it in ``self._loss`` and return it."""
        self._loss = F.cross_entropy(output, target)
        return self._loss

In [40]:
# The three classifiers compared in this notebook, all trained with Adam at lr=1e-4:
# a CNN on the raw images and two classifier heads, one per pretrained encoder.
models = {
    'Clf_Img': Net(batchnorm=False, dropout=False, optim_type='Adam', lr=1e-4),
    'CAAE_clf': CAAE_Clf(conditional=True, batchnorm=False, dropout=False,
                         optim_type='Adam', lr=1e-4),
    'AAEC_clf': CAAE_Clf(conditional=False, batchnorm=False, dropout=False,
                         optim_type='Adam', lr=1e-4),
}
# One independent history list per model.
train_log = {name: [] for name in models}
test_log = {name: [] for name in models}

In [41]:
def train(epoch, models, log=None, loader=None):
    """Run one training epoch, feeding the same batches to every model in ``models``.

    Args:
        epoch: Epoch number; used only in the progress messages.
        models: Mapping name -> model. Each model must expose ``optim`` (its
            optimizer), ``loss(output, target)`` and ``_loss`` (the last loss
            tensor stored by ``loss``).
        log: Optional mapping name -> list. When given, a 1-tuple holding the
            loss of the last batch of the epoch (as a Python float) is
            appended for every model.
        loader: Batch iterable providing ``sampler`` and ``__len__``. Defaults
            to the module-level ``train_loader``.

    Raises:
        ValueError: If the loader yields no batches.
    """
    if loader is None:
        loader = train_loader
    # Number of samples per epoch as reported by the loader's sampler.
    train_size = len(loader.sampler)
    n_batches = len(loader)

    def report(samples_done, batches_done):
        # Single place that formats and prints a progress line.
        line = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLosses '.format(
            epoch, samples_done, train_size, 100. * batches_done / n_batches)
        losses = ' '.join(['{}: {:.4f}'.format(k, m._loss.item()) for k, m in models.items()])
        print(line + losses)

    samples_seen = 0
    batch_idx = -1  # stays -1 when the loader is empty
    for batch_idx, (data, target) in enumerate(loader):
        for model in models.values():
            model.optim.zero_grad()
            output = model(data)
            loss = model.loss(output, target)
            loss.backward()
            model.optim.step()

        if batch_idx % 200 == 0:
            report(batch_idx * len(data), batch_idx)
        samples_seen += len(data)

    batches_done = batch_idx + 1
    if batches_done == 0:
        raise ValueError('train loader yielded no batches')

    if log is not None:
        for k, m in models.items():
            # Store a plain float: keeping the tensor would retain its autograd
            # graph in the log and differ from the floats logged by `test`.
            log[k].append((m._loss.item(),))
    # Use the real number of processed samples; multiplying the batch count by
    # the size of the last batch is wrong when that batch is partial.
    report(samples_seen, batches_done)

In [42]:
def test(models, loader, log=None):
    test_size = len(loader)
    test_loss = {k: 0. for k in models}
    with torch.no_grad():
        for data, target in loader:
            output = {k: m(data) for k, m in models.items()}
            for k, m in models.items():
                test_loss[k] += m.loss(output[k], target).item() # sum up batch loss
    
    for k in models:
        test_loss[k] /= test_size
    report = 'Test losses: ' + ' '.join(['{}: {:.4f}'.format(k, test_loss[k]) for k in test_loss])
    if log is not None:
        for k in models:
            log[k].append((test_loss[k],))
    print(report)

In [43]:
# Train all models side by side for 30 epochs (epochs are numbered 1..30) and
# evaluate them on the validation loader after every epoch.
for epoch in range(1, 31):
    # Put every model into training mode before the epoch.
    for model in models.values():
        model.train()
    train(epoch, models, train_log)
    # Switch every model to evaluation mode before measuring the validation loss.
    for model in models.values():
        model.eval()
    test(models, valid_loader, test_log)

Test losses: Clf_Img: 0.4613 CAAE_clf: 1.9714 AAEC_clf: 1.3775
Test losses: Clf_Img: 0.3126 CAAE_clf: 1.8719 AAEC_clf: 1.0289
Test losses: Clf_Img: 0.2413 CAAE_clf: 1.8463 AAEC_clf: 0.9634
Test losses: Clf_Img: 0.1963 CAAE_clf: 1.8320 AAEC_clf: 0.9437
Test losses: Clf_Img: 0.1653 CAAE_clf: 1.8205 AAEC_clf: 0.9318
Test losses: Clf_Img: 0.1417 CAAE_clf: 1.8176 AAEC_clf: 0.9204
Test losses: Clf_Img: 0.1274 CAAE_clf: 1.8134 AAEC_clf: 0.9111
Test losses: Clf_Img: 0.1120 CAAE_clf: 1.8116 AAEC_clf: 0.9046
Test losses: Clf_Img: 0.1032 CAAE_clf: 1.8099 AAEC_clf: 0.8993
Test losses: Clf_Img: 0.0951 CAAE_clf: 1.8103 AAEC_clf: 0.8992
Test losses: Clf_Img: 0.0889 CAAE_clf: 1.8097 AAEC_clf: 0.8995
Test losses: Clf_Img: 0.0885 CAAE_clf: 1.8087 AAEC_clf: 0.8989
Test losses: Clf_Img: 0.0820 CAAE_clf: 1.8082 AAEC_clf: 0.8983


Test losses: Clf_Img: 0.0779 CAAE_clf: 1.8093 AAEC_clf: 0.8970
Test losses: Clf_Img: 0.0789 CAAE_clf: 1.8091 AAEC_clf: 0.9002
Test losses: Clf_Img: 0.0724 CAAE_clf: 1.8085 AAEC_clf: 0.8968
Test losses: Clf_Img: 0.0719 CAAE_clf: 1.8069 AAEC_clf: 0.8986
Test losses: Clf_Img: 0.0684 CAAE_clf: 1.8074 AAEC_clf: 0.8983
Test losses: Clf_Img: 0.0669 CAAE_clf: 1.8086 AAEC_clf: 0.8975
Test losses: Clf_Img: 0.0670 CAAE_clf: 1.8076 AAEC_clf: 0.8976
Test losses: Clf_Img: 0.0678 CAAE_clf: 1.8071 AAEC_clf: 0.8973
Test losses: Clf_Img: 0.0658 CAAE_clf: 1.8062 AAEC_clf: 0.8972
Test losses: Clf_Img: 0.0658 CAAE_clf: 1.8065 AAEC_clf: 0.8964
Test losses: Clf_Img: 0.0671 CAAE_clf: 1.8065 AAEC_clf: 0.8964
Test losses: Clf_Img: 0.0665 CAAE_clf: 1.8071 AAEC_clf: 0.8969
Test losses: Clf_Img: 0.0638 CAAE_clf: 1.8066 AAEC_clf: 0.8971


Test losses: Clf_Img: 0.0666 CAAE_clf: 1.8057 AAEC_clf: 0.8969
Test losses: Clf_Img: 0.0659 CAAE_clf: 1.8065 AAEC_clf: 0.8972
Test losses: Clf_Img: 0.0663 CAAE_clf: 1.8076 AAEC_clf: 0.8997
Test losses: Clf_Img: 0.0660 CAAE_clf: 1.8065 AAEC_clf: 0.8971
