In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from utils import mnist, plot_graphs, plot_mnist, to_onehot
import numpy as np
import os 

%matplotlib inline

In [2]:
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

In [3]:
# Preprocessing pipeline passed to the `mnist(...)` loader factory: convert the
# image to a tensor, normalize with mean 0.5 / std 0.5, then move it to `device`.
# NOTE(review): with ToTensor output in [0, 1] this normalization yields values
# in [-1, 1] (hence the "tanh" name, presumably) - confirm against the consumers.
mnist_tanh = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
                # `device` is only defined in a later cell; the name is resolved when the
                # lambda is called, so that cell must run before any loader is iterated.
                lambda x: x.to(device)
           ])

In [69]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Hyper-parameters shared by the cells below.
lr = 0.0001        # learning rate handed to Adam inside Classifier
prior_size = 10    # width of the classifier's final linear layer
train_epoch = 200  # NOTE(review): not used by the visible cells, which call start_training(1)
batch_size = 250   # batch size for the MNIST loaders

# `mnist` comes from the local `utils` module; presumably `valid=10000` is the
# number of samples held out for validation - confirm against utils.mnist.
train_loader, valid_loader, test_loader = mnist(batch_size=batch_size, valid=10000, transform=mnist_tanh)

cpu


In [70]:
class FullyConnected(nn.Module):
    """Multi-layer perceptron assembled from a list of layer widths.

    ``sizes[0]`` is the input width and ``sizes[-1]`` the output width. Every
    linear layer except the last is followed by optional dropout and then
    ``activation_fn``; the last linear layer is followed only by ``last_fn``
    when one is supplied.

    Args:
        sizes: Layer widths; at least two entries are required.
        dropout: Falsy disables dropout; any truthy value is forwarded to
            ``nn.Dropout``.
        activation_fn: Module appended after each hidden linear layer (the
            same instance is reused for every hidden layer).
        flatten: When true, ``forward`` reshapes its input to ``(batch, -1)``.
        last_fn: Optional module appended after the final linear layer.
        first_fn: Optional module placed in front of the first linear layer.
        device: Device the assembled module is moved to.
    """

    def __init__(self, sizes, dropout=False, activation_fn=nn.Tanh(), flatten=False,
                 last_fn=None, first_fn=None, device='cpu'):
        super().__init__()
        self.flatten = flatten

        stack = [] if first_fn is None else [first_fn]

        # Hidden layers: consecutive width pairs, excluding the output layer.
        for n_in, n_out in zip(sizes[:-2], sizes[1:-1]):
            stack.append(nn.Linear(n_in, n_out))
            if dropout:
                stack.append(nn.Dropout(dropout))
            stack.append(activation_fn)

        # No dropout or activation function after the last layer.
        stack.append(nn.Linear(sizes[-2], sizes[-1]))
        if last_fn is not None:
            stack.append(last_fn)

        self.model = nn.Sequential(*stack)
        self.to(device)

    def forward(self, x):
        """Run ``x`` through the stack, flattening it per sample first if configured."""
        features = x.view(x.shape[0], -1) if self.flatten else x
        return self.model(features)

In [71]:
class Classifier():
    """Fully connected classifier for 28x28 inputs with its train/eval loops.

    The network, optimizer and loss are built in ``__init__`` from the
    notebook-level settings ``prior_size``, ``lr`` and ``device``. Batches are
    moved onto the device that holds the network's parameters before use, so
    loaders may yield tensors on any device.
    """

    def __init__(self, train_loader, valid_loader, test_loader):
        """Create the network and remember the three data loaders.

        Args:
            train_loader: Iterable of ``(data, label)`` batches; it must
                expose ``.sampler`` because ``len(train_loader.sampler)`` is
                reported as the training-set size in progress lines.
            valid_loader: Sized iterable of ``(data, label)`` batches used by
                ``validation``.
            test_loader: Sized iterable of ``(data, label)`` batches used by
                ``test``.
        """
        self.net = FullyConnected([28*28, 1024, 1024, prior_size], activation_fn=nn.ReLU(), flatten=True, device=device)
        self.optimizer = optim.Adam(self.net.parameters(), lr=lr)
        self.criterion = nn.CrossEntropyLoss()
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader
        self.loss = 0.0      # loss tensor of the most recent training batch
        self.loss_log = []   # validation loss, one float per call to validation()
        self.acc_log = []    # validation accuracy in percent, one per call

    def _to_net_device(self, data, label):
        """Move a batch onto the device that holds the network's parameters."""
        target = next(self.net.parameters()).device
        return data.to(target), label.to(target)

    def _log_progress(self, epoch, seen, train_size, percent):
        """Print one progress line together with the latest training loss."""
        line = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLosses '.format(
            epoch, seen, train_size, percent)
        losses = '{:.4f}'.format(self.loss.item())
        print(line + losses)

    def train(self, epoch):
        """Run one optimization pass over ``train_loader``.

        A progress line is printed every 100 batches and once more after the
        last batch. An empty loader is a no-op.

        Args:
            epoch: Epoch number, used only in the progress lines.
        """
        train_size = len(self.train_loader.sampler)
        n_batches = len(self.train_loader)
        seen = 0        # samples processed so far in this epoch
        batch_idx = -1  # stays -1 when the loader yields nothing
        for batch_idx, (data, label) in enumerate(self.train_loader):
            # Inputs must live on the same device as the model; a no-op when
            # the dataset transform has already moved them.
            data, label = self._to_net_device(data, label)

            self.optimizer.zero_grad()
            output = self.net(data)
            self.loss = self.criterion(output, label)

            self.loss.backward()
            self.optimizer.step()

            if batch_idx % 100 == 0:
                self._log_progress(epoch, seen, train_size, 100. * batch_idx / n_batches)
            seen += len(data)

        if batch_idx >= 0:
            # Final summary line once every batch has been consumed.
            self._log_progress(epoch, seen, train_size, 100. * (batch_idx + 1) / n_batches)

    def _evaluate(self, loader):
        """Score the network on ``loader`` without tracking gradients.

        The network is switched to evaluation mode first.

        Returns:
            Tuple ``(loss, accuracy)``: ``loss`` is the mean of the per-batch
            criterion values as a float, ``accuracy`` is the percentage of
            correctly classified samples.

        Raises:
            ZeroDivisionError: If ``loader`` yields no batches.
        """
        self.net.eval()
        n_batches = len(loader)
        loss = 0.
        total = 0
        correct = 0
        with torch.no_grad():
            for data, label in loader:
                data, label = self._to_net_device(data, label)
                output = self.net(data)
                loss += self.criterion(output, label)

                predicted = output.argmax(dim=1)
                total += label.size(0)
                correct += (predicted == label).sum().item()

            loss /= n_batches

        accuracy = 100 * correct / total
        return loss.item(), accuracy

    def validation(self, epoch):
        """Evaluate on ``valid_loader``, print a report and log the result.

        Args:
            epoch: Unused; kept so existing callers continue to work.
        """
        loss, accuracy = self._evaluate(self.valid_loader)

        report = 'Valid loss: {:.4f} Accuracy: {:.2f} \n'.format(loss, accuracy)
        print(report)

        self.loss_log.append(loss)
        self.acc_log.append(accuracy)

    def start_training(self, train_epoch):
        """Train for ``train_epoch`` epochs, validating after each one."""
        for epoch in range(1, train_epoch + 1):
            self.net.train()
            self.train(epoch)
            self.net.eval()
            self.validation(epoch)

    def test(self):
        """Evaluate on ``test_loader`` and print a report; nothing is logged."""
        loss, accuracy = self._evaluate(self.test_loader)

        report = 'Test loss: {:.4f} Accuracy: {:.2f} \n'.format(loss, accuracy)
        print(report)
     


### Классификатор на исходных данных

In [72]:
# Baseline: train the classifier on the MNIST loaders for a single epoch
# (each epoch ends with a validation pass that prints and logs loss/accuracy).
cls_raw = Classifier(train_loader, valid_loader, test_loader)
cls_raw.start_training(1)

Valid loss: 0.3306 Accuracy: 90.29 



In [73]:
# Validation loss and accuracy (in percent), one entry per completed epoch.
print(cls_raw.loss_log)
print(cls_raw.acc_log)

[0.33062952756881714]
[90.29]


In [74]:
# Report loss and accuracy of the baseline classifier on the held-out test loader.
cls_raw.test()

Test loss: 0.3141 Accuracy: 91.06 



### Классификатор на данных CAAE

In [75]:
class CustomDataset(torch.utils.data.Dataset):
    """Placeholder dataset: random ``1x28x28`` tensors, every label equal to 1.

    Args:
        ptype: Split name stored on ``self.type`` (e.g. ``'train'``).
        path: Accepted for interface compatibility.
            NOTE(review): it is not read - no data is loaded from disk yet.
        size: Number of samples to generate; defaults to 50000.
    """

    def __init__(self, ptype='train', path='', size=50000):
        self.type = ptype
        self.data = torch.randn((size, 1, 28, 28))
        self.labels = torch.ones(size, dtype=torch.long)

    def __len__(self):
        """Return the number of samples (not the length of the tensor's repr)."""
        return len(self.data)

    def getData(self):
        """Return the full data tensor."""
        return self.data

    def sampler(self):
        """Return the full data tensor (same object as ``getData``)."""
        return self.data

    def __getitem__(self, index):
        """Return the ``(data, label)`` pair stored at ``index``."""
        return self.data[index], self.labels[index]

In [76]:
# One dataset object per split.
# NOTE(review): CustomDataset does not read the path argument; each object
# holds freshly generated random tensors rather than data from these locations.
caae_train = CustomDataset('train', './caae_train_data')
caae_valid = CustomDataset('valid', './caae_valid_data')
caae_test = CustomDataset('test', './caae_test_data')


In [77]:
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

# Index every sample of each split; SubsetRandomSampler then visits all of
# them once per epoch in random order.
caae_train_indices = list(range(caae_train.data.shape[0]))
caae_valid_indices = list(range(caae_valid.data.shape[0]))
caae_test_indices = list(range(caae_test.data.shape[0]))

# Reuse the notebook-wide `batch_size` so these loaders stay in sync with the
# MNIST loaders instead of repeating the literal 250.
caae_train_loader = DataLoader(caae_train, batch_size=batch_size, sampler=SubsetRandomSampler(caae_train_indices))
caae_valid_loader = DataLoader(caae_valid, batch_size=batch_size, sampler=SubsetRandomSampler(caae_valid_indices))
caae_test_loader = DataLoader(caae_test, batch_size=batch_size, sampler=SubsetRandomSampler(caae_test_indices))

In [78]:
# Train the same classifier on the CAAE loaders for a single epoch.
# NOTE(review): CustomDataset currently yields random inputs that all carry
# label 1, so a near-perfect score here reflects the placeholder data only.
cls_caae = Classifier(caae_train_loader, caae_valid_loader, caae_test_loader)
cls_caae.start_training(1)

Valid loss: 0.0007 Accuracy: 100.00 



In [None]:
# Report loss and accuracy of the CAAE classifier on its test loader.
cls_caae.test()