In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image

from utils_new import mnist, plot_graphs, plot_mnist
import numpy as np
import os 

%matplotlib inline

In [2]:
def to_onehot(x, n, device=None):
    """Convert a 1-D collection of class indices into a one-hot float matrix.

    Args:
        x: Class indices in ``[0, n)``. May be a ``torch.Tensor`` of any
            integer dtype, a ``numpy.ndarray`` or any sequence accepted by
            ``torch.as_tensor``.
        n: Number of classes, i.e. the width of the returned matrix.
        device: Optional device the result is moved to. When ``None`` the
            result stays on the device the indices live on (CPU for
            non-tensor input).

    Returns:
        A float tensor of shape ``(len(x), n)`` with exactly one ``1.`` per row.
    """
    # scatter_ only accepts int64 indices, so normalise every input to long.
    index = torch.as_tensor(x).to(torch.long)
    # Allocate the buffer next to the indices: scatter_ requires both tensors
    # to be on the same device.
    one_hot = torch.zeros((index.shape[0], n), device=index.device)
    one_hot.scatter_(1, index[:, None], 1.)
    if device is not None:
        one_hot = one_hot.to(device)
    return one_hot

In [3]:
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

In [4]:
# Output locations for the images written during evaluation.
# The trailing slashes matter: file names are appended by plain string
# concatenation later on.
import shutil

root_folder = 'FC_GAE_results'
fixed_folder = root_folder + '/Fixed_results/'
recon_folder = root_folder + '/Recon_results/'

# Start every run from an empty results tree. shutil.rmtree replaces the
# shell escape `!rm -r`, which only works inside IPython on POSIX systems.
if os.path.isdir(root_folder):
    shutil.rmtree(root_folder)
os.mkdir(root_folder)
os.mkdir(fixed_folder)
os.mkdir(recon_folder)

In [5]:
# Per-image preprocessing: convert to a tensor, apply Normalize with mean 0.5
# and std 0.5 (i.e. (x - 0.5) / 0.5), then move the tensor to `device`.
# `device` is looked up each time the lambda runs, not when this cell runs,
# so it only has to be defined before the first image is loaded.
mnist_tanh = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
    lambda img: img.to(device),
])

In [6]:
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# Hyper-parameters.
lr = 0.0001
prior_size = 10      # width of the latent code
train_epoch = 1000
batch_size = 250

train_loader, valid_loader, test_loader = mnist(
    batch_size=batch_size,
    valid=10000,
    transform=mnist_tanh,
)

# 10 random latent codes, each occupying 10 consecutive rows -> (100, prior_size).
fixed_z = (
    torch.randn(10, prior_size)
    .repeat(1, 10)
    .reshape(-1, prior_size)
    .to(device)
)
# One-hot labels cycling 0..9 ten times -> (100, 10).
fixed_z_label = to_onehot(torch.arange(10).repeat(10), 10).to(device)

# The first 100 samples of the first test batch serve as a fixed reference set.
fixed_data, fixed_label = next(iter(test_loader))
fixed_data = fixed_data[:100].to(device)
fixed_label = to_onehot(fixed_label[:100], 10).to(device)

cpu


In [7]:
# Pull a single batch from the training loader. In the cells visible here,
# `train` and `test` rebind `data`/`label` as their own loop variables and do
# not read these globals.
data, label = next(iter(train_loader))

In [8]:
class FullyConnected(nn.Module):
    """Multi-layer perceptron assembled from a list of layer widths.

    Every consecutive pair in ``sizes`` becomes an ``nn.Linear``. All layers
    except the last are followed by optional dropout and ``activation_fn``;
    the last linear layer gets neither, only the optional ``last_fn``.

    Args:
        sizes: Layer widths ``[in, hidden..., out]``; at least two entries.
        dropout: Dropout probability for the hidden layers. ``False``/``0``
            disables dropout. A bare ``True`` selects the ``nn.Dropout``
            default of 0.5 (it would otherwise be read as ``p=1`` and zero
            every activation).
        activation_fn: Module applied after each hidden layer. The same
            instance is reused for every hidden layer.
        flatten: When true, inputs are reshaped to ``(batch, -1)`` in
            ``forward`` before entering the network.
        last_fn: Optional module appended after the output layer.
        first_fn: Optional module placed before the first linear layer.
        device: Device the module is moved to after construction.

    Raises:
        ValueError: If ``sizes`` has fewer than two entries.
    """

    def __init__(self, sizes, dropout=False, activation_fn=nn.Tanh(), flatten=False,
                 last_fn=None, first_fn=None, device='cpu'):
        super(FullyConnected, self).__init__()
        if len(sizes) < 2:
            raise ValueError(
                'sizes must contain at least an input and an output width, got %r' % (sizes,))
        self.flatten = flatten
        if dropout is True:
            # nn.Dropout(True) would mean p == 1, i.e. drop everything.
            dropout = 0.5

        layers = []
        if first_fn is not None:
            layers.append(first_fn)
        # Hidden layers: Linear -> (Dropout) -> activation.
        for in_features, out_features in zip(sizes[:-2], sizes[1:-1]):
            layers.append(nn.Linear(in_features, out_features))
            if dropout:
                layers.append(nn.Dropout(dropout))
            layers.append(activation_fn)
        # Output layer: no dropout and no activation here.
        layers.append(nn.Linear(sizes[-2], sizes[-1]))
        if last_fn is not None:
            layers.append(last_fn)
        self.model = nn.Sequential(*layers)
        self.to(device)

    def forward(self, x, y=None):
        """Apply the network to ``x``, optionally concatenated with ``y``.

        Args:
            x: Input batch; reshaped to ``(batch, -1)`` first when the module
                was built with ``flatten=True``.
            y: Optional tensor concatenated to ``x`` along dimension 1.

        Returns:
            The output of the underlying ``nn.Sequential``.
        """
        if self.flatten:
            # reshape (unlike view) also accepts non-contiguous inputs.
            x = x.reshape(x.shape[0], -1)
        if y is not None:
            x = torch.cat([x, y], dim=1)
        return self.model(x)

In [9]:
# Encoder: flattened 28x28 image -> latent code of width prior_size.
Enc = FullyConnected(
    sizes=[28 * 28, 1024, 1024, prior_size],
    activation_fn=nn.LeakyReLU(0.2),
    flatten=True,
    device=device,
)

# Decoder: latent code -> 784 values passed through a final Tanh.
Dec = FullyConnected(
    sizes=[prior_size, 1024, 1024, 28 * 28],
    activation_fn=nn.LeakyReLU(0.2),
    last_fn=nn.Tanh(),
    device=device,
)

# Discriminator: latent code -> one logit, with dropout 0.3 on the hidden layers.
Disc = FullyConnected(
    sizes=[prior_size, 1024, 1024, 1],
    dropout=0.3,
    activation_fn=nn.LeakyReLU(0.2),
    device=device,
)

# One Adam optimizer per network, all sharing the same learning rate.
Enc_optimizer = optim.Adam(Enc.parameters(), lr=lr)
Dec_optimizer = optim.Adam(Dec.parameters(), lr=lr)
Disc_optimizer = optim.Adam(Disc.parameters(), lr=lr)

In [10]:
# Per-epoch loss histories, keyed by loss name. Each key gets its own list.
train_log = {name: [] for name in ('E', 'AE', 'D')}
test_log = {name: [] for name in ('E', 'AE', 'D')}

In [11]:
# Constant target columns of shape (batch_size, 1) for the binary
# cross-entropy losses: zeros and ones.
batch_zeros = torch.zeros(batch_size, 1, device=device)
batch_ones = torch.ones(batch_size, 1, device=device)

In [12]:
def train(epoch, Enc, Dec, Disc, log=None):
    """Run one training epoch of the adversarial autoencoder.

    Two updates are made for every batch of the global ``train_loader``:

    1. Discriminator step: ``Disc`` is trained to output 0 for encoded
       images and 1 for samples drawn from a standard normal prior.
    2. Autoencoder step: ``Enc`` and ``Dec`` are trained on the sum of the
       MSE reconstruction loss and the loss for making ``Disc`` output 1 on
       encoded images.

    Uses the notebook globals ``train_loader``, ``device``, ``prior_size``,
    ``Enc_optimizer``, ``Dec_optimizer`` and ``Disc_optimizer``.

    Args:
        epoch: Epoch number, used only in the progress messages.
        Enc: Encoder network.
        Dec: Decoder network; its output is reshaped to ``(-1, 1, 28, 28)``.
        Disc: Discriminator network producing one logit per sample.
        log: Optional dict with list values under ``'E'``, ``'AE'`` and
            ``'D'``; the losses of the last batch are appended to it.
    """
    train_size = len(train_loader.sampler)
    n_batches = len(train_loader)

    def progress_line(batches_done, samples_done, losses):
        # Shared formatter for the in-loop and end-of-epoch messages.
        header = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLosses '.format(
            epoch, samples_done, train_size, 100. * batches_done / n_batches)
        return header + 'E: {:.4f}, AE: {:.4f}, D: {:.4f}'.format(*losses)

    seen = 0
    last_losses = None
    for batch_idx, (data, _label) in enumerate(train_loader):
        data = data.to(device)  # no-op when the loader already yields device tensors
        n = data.shape[0]       # the final batch may be smaller than the others

        # ---- discriminator step ----
        Disc.zero_grad()
        z = torch.randn((n, prior_size)).to(device)
        # The encoder is not updated in this step, so skip building its graph.
        with torch.no_grad():
            fake_latent = Enc(data)
        fake_pred = Disc(fake_latent)
        true_pred = Disc(z)

        # Targets are created per batch so their shape always matches.
        fake_loss = F.binary_cross_entropy_with_logits(fake_pred, torch.zeros_like(fake_pred))
        true_loss = F.binary_cross_entropy_with_logits(true_pred, torch.ones_like(true_pred))
        Disc_loss = 0.5 * (fake_loss + true_loss)

        Disc_loss.backward()
        Disc_optimizer.step()

        # ---- autoencoder step ----
        Enc.zero_grad()
        Dec.zero_grad()
        Disc.zero_grad()

        latent = Enc(data)
        reconstructed = Dec(latent).view(-1, 1, 28, 28)
        fake_pred = Disc(latent)

        Enc_loss = F.binary_cross_entropy_with_logits(fake_pred, torch.ones_like(fake_pred))
        AE_loss = F.mse_loss(reconstructed, data)
        G_loss = AE_loss + Enc_loss

        G_loss.backward()
        Dec_optimizer.step()
        Enc_optimizer.step()

        last_losses = (Enc_loss, AE_loss, Disc_loss)
        if batch_idx % 100 == 0:
            print(progress_line(batch_idx, seen, [loss.item() for loss in last_losses]))
        seen += n

    if last_losses is None:
        # Empty loader: nothing was trained, so there is nothing to report.
        return

    final = [loss.item() for loss in last_losses]
    print(progress_line(n_batches, seen, final))
    if log is not None:
        log['E'].append(final[0])
        log['AE'].append(final[1])
        log['D'].append(final[2])

In [13]:
def test(Enc, Dec, Disc, loader, epoch, log=None):
    test_size = len(loader)
    E_loss = 0.
    AE_loss = 0.
    D_loss = 0.
    test_loss = {'E': 0., 'AE': 0., 'D': 0.}
    with torch.no_grad():
        for data, label in loader:
            label = to_onehot(label, 10, device)
            z = torch.randn((batch_size, prior_size)).to(device)
            latent = Enc(data)
            reconstructed = Dec(latent).view(-1, 1, 28, 28)
            fake_pred = Disc(latent)
            true_pred = Disc(z)
        
            fake_loss = F.binary_cross_entropy_with_logits(fake_pred, batch_zeros).item()
            true_loss = F.binary_cross_entropy_with_logits(true_pred, batch_ones).item()
            
            D_loss += 0.5*(fake_loss + true_loss)
            E_loss += F.binary_cross_entropy_with_logits(fake_pred, batch_ones).item()
            AE_loss += F.mse_loss(reconstructed, data)
            
        E_loss /= test_size
        D_loss /= test_size
        AE_loss /= test_size

        fixed_gen = Dec(fixed_z).cpu().data.numpy().reshape(100, 1, 28, 28)
        plot_mnist(fixed_gen, (10, 10), False, fixed_folder + '%03d.png' % epoch)
        fixed_reconstruction = Dec(Enc(fixed_data)).cpu().data.numpy().reshape(100, 1, 28, 28)
        plot_mnist(fixed_reconstruction, (10, 10), False, recon_folder + '%03d.png' % epoch)
        
    report = 'Test losses. E: {:.4f}, AE: {:.4f}, D: {:.4f}'.format(E_loss, AE_loss, D_loss)
    print(report)

In [None]:
# Alternate one training epoch with one validation pass. The epoch count comes
# from the `train_epoch` setting instead of a second hard-coded number.
for epoch in range(1, train_epoch + 1):
    # Training mode for all three networks.
    Enc.train()
    Dec.train()
    Disc.train()
    train(epoch, Enc, Dec, Disc, train_log)
    # Evaluation mode for the validation pass.
    Enc.eval()
    Dec.eval()
    Disc.eval()
    test(Enc, Dec, Disc, valid_loader, epoch, test_log)
    

Test losses. E: 0.8521, AE: 0.2751, D: 1.6651
Test losses. E: 0.6697, AE: 0.2675, D: 0.9060
Test losses. E: 0.4575, AE: 0.2564, D: 0.8212
Test losses. E: 0.5427, AE: 0.2475, D: 0.7885
Test losses. E: 0.6095, AE: 0.2427, D: 0.7983
Test losses. E: 0.6595, AE: 0.2474, D: 0.6731
Test losses. E: 0.5748, AE: 0.2299, D: 0.7250
Test losses. E: 0.6857, AE: 0.2259, D: 0.7719
Test losses. E: 0.5814, AE: 0.2230, D: 0.7267
Test losses. E: 0.4892, AE: 0.2087, D: 0.7930
Test losses. E: 0.5805, AE: 0.1866, D: 0.7268
Test losses. E: 0.6610, AE: 0.1887, D: 0.7519
Test losses. E: 0.6841, AE: 0.1738, D: 0.7526
Test losses. E: 0.6213, AE: 0.1637, D: 0.7295
Test losses. E: 0.6450, AE: 0.1509, D: 0.6925
Test losses. E: 0.7121, AE: 0.1523, D: 0.7256
Test losses. E: 2.2350, AE: 0.3299, D: 0.5578
Test losses. E: 0.5932, AE: 0.1533, D: 0.6856
Test losses. E: 0.6788, AE: 0.1413, D: 0.7015
Test losses. E: 0.7395, AE: 0.1582, D: 0.7611
Test losses. E: 0.5273, AE: 0.2389, D: 0.8080
Test losses. E: 0.6379, AE: 0.1507

Test losses. E: 0.6817, AE: 0.1169, D: 0.6974
Test losses. E: 0.6385, AE: 0.1223, D: 0.7255
Test losses. E: 0.6245, AE: 0.1122, D: 0.7058
Test losses. E: 0.7167, AE: 0.1157, D: 0.6883
Test losses. E: 0.6504, AE: 0.1155, D: 0.7010
Test losses. E: 0.7504, AE: 0.1139, D: 0.6976
Test losses. E: 0.6873, AE: 0.1121, D: 0.7048
Test losses. E: 0.6414, AE: 0.1078, D: 0.6970
Test losses. E: 0.7022, AE: 0.1274, D: 0.7399
Test losses. E: 0.7073, AE: 0.1279, D: 0.7208
Test losses. E: 0.6695, AE: 0.1126, D: 0.7066
Test losses. E: 0.6400, AE: 0.1013, D: 0.7046
Test losses. E: 0.7390, AE: 0.1108, D: 0.6919
Test losses. E: 0.7518, AE: 0.1174, D: 0.6878
Test losses. E: 0.6748, AE: 0.0954, D: 0.6957
Test losses. E: 0.6066, AE: 0.1030, D: 0.7091
Test losses. E: 0.7110, AE: 0.0946, D: 0.6948
Test losses. E: 0.9517, AE: 0.2169, D: 0.6242
Test losses. E: 0.6708, AE: 0.1033, D: 0.6896
Test losses. E: 0.7590, AE: 0.1096, D: 0.6958
Test losses. E: 0.6636, AE: 0.1108, D: 0.7102
Test losses. E: 0.6952, AE: 0.1058

Test losses. E: 0.6716, AE: 0.0954, D: 0.6976
Test losses. E: 0.6981, AE: 0.0939, D: 0.6856
Test losses. E: 0.7105, AE: 0.0900, D: 0.6982
Test losses. E: 0.6216, AE: 0.0930, D: 0.7142
Test losses. E: 0.6371, AE: 0.0888, D: 0.7036
Test losses. E: 0.6248, AE: 0.0879, D: 0.6904
Test losses. E: 0.7021, AE: 0.0929, D: 0.7073
Test losses. E: 0.7806, AE: 0.0924, D: 0.6845
Test losses. E: 0.7062, AE: 0.0818, D: 0.6952
Test losses. E: 0.6888, AE: 0.0932, D: 0.6938
Test losses. E: 0.6868, AE: 0.0917, D: 0.7104
Test losses. E: 0.6829, AE: 0.0993, D: 0.6905
Test losses. E: 0.7303, AE: 0.0873, D: 0.7010
Test losses. E: 0.7135, AE: 0.0880, D: 0.6966
Test losses. E: 0.6714, AE: 0.0820, D: 0.6980
Test losses. E: 0.7035, AE: 0.0845, D: 0.7013
Test losses. E: 0.6795, AE: 0.0813, D: 0.6999
Test losses. E: 0.6699, AE: 0.0789, D: 0.6957
Test losses. E: 0.7550, AE: 0.0955, D: 0.7117
Test losses. E: 0.6753, AE: 0.0825, D: 0.6908
Test losses. E: 0.7614, AE: 0.0911, D: 0.7036
