In [2]:
import os
import sys
import random

import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.cuda
import torch.utils
import torch.random
import torch.optim
import warnings
warnings.filterwarnings("ignore")

import torchvision.transforms.transforms as tvff
import torchvision.datasets as tvd
import torchvision.utils as tvu

device = torch.device('cuda:7')

In [3]:
torch.cuda.is_available()

True

In [4]:
data = tvd.MNIST(root='./data', train=True, download=True, transform=None)
xxx, yyy = data.train_data, data.train_labels

xxx.unsqueeze_(1)
yyy.unsqueeze_(1)

nb_digits = 10
yyy_onehot = torch.FloatTensor(yyy.shape[0], nb_digits)
yyy_onehot.zero_()
yyy_onehot.scatter_(1, yyy, 1)
yyy = yyy_onehot

xxx, yyy = xxx.type(torch.float32) / 255, yyy.type(torch.float32)

#x, y = x.to(device), y.to(device)

In [5]:
batch_size = 6000
load_index = 0

def get_data():
    global batch_size, load_index, xxx, yyy, device
    if load_index + batch_size > len(xxx):
        load_index = 0
        #perm = torch.randperm(len(xxx))
        #xxx, yyy = xxx[perm], yyy[perm]
    
    data_xx, data_yy = xxx[load_index:load_index+batch_size], yyy[load_index:load_index+batch_size]
    load_index += batch_size
    
    data_xx, data_yy = data_xx.to(device), data_yy.to(device)
    return data_xx, data_yy

In [18]:
u = lambda x: torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
s = torch.nn.functional.sigmoid
r = torch.nn.functional.relu
p = torch.nn.MaxPool2d(2)
pad1 = torch.nn.ZeroPad2d(1)
N_FEATURES = 9

class Encoder(torch.nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        
        self.ec1 = torch.nn.Conv2d(1, 4, kernel_size=5, padding=2)
        self.ec2 = torch.nn.Conv2d(4, 4, kernel_size=5, padding=2)
        
        self.ec3 = torch.nn.Conv2d(4, 8, kernel_size=5, padding=2)
        self.ec4 = torch.nn.Conv2d(8, 16, kernel_size=5, padding=2)
        
        self.ec5 = torch.nn.Conv2d(16, 32, kernel_size=5, padding=2)
        self.ec6 = torch.nn.Conv2d(32, 32, kernel_size=5, padding=2)
        
        self.fc1 = torch.nn.Linear(32 * 3 * 3 + 10, 32 * 3)
        self.fc2 = torch.nn.Linear(32 * 3, 32 * 3)
        self.fc_mu = torch.nn.Linear(32 * 3, N_FEATURES)
        self.fc_logvar = torch.nn.Linear(32 * 3, N_FEATURES)
        
    def forward(self, x, label):
        global u, s, r, p, pad1
        
        x = r(self.ec1(x))
        x = r(self.ec2(x))
        x = p(x)

        x = r(self.ec3(x))
        x = r(self.ec4(x))
        x = p(x)
        
        x = r(self.ec5(x))
        x = r(self.ec6(x))
        x = p(x)
        
        x = x.reshape(-1, 32 * 3 * 3)
        x = r(self.fc1(torch.cat((x, label), dim=1)))
        x = r(self.fc2(x))
        
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        
        return mu, logvar
    

class Decoder(torch.nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.d_in = torch.nn.Conv2d(1, 16, kernel_size=5, padding=2)
        
        self.dc1 = torch.nn.Conv2d(16, 16, kernel_size=5, padding=2)
        self.dc2 = torch.nn.Conv2d(16, 8, kernel_size=5, padding=2)                                 
                                         
        self.dc3 = torch.nn.Conv2d(8, 8, kernel_size=5, padding=2)
        self.dc4 = torch.nn.Conv2d(8, 4, kernel_size=5, padding=2)     
        
        self.dc5 = torch.nn.Conv2d(4, 4, kernel_size=5, padding=2)
        self.dc6 = torch.nn.Conv2d(4, 1, kernel_size=5, padding=2)
        
        self.fc3 = torch.nn.Linear(10 + N_FEATURES, 32 * 3)
        self.fc4 = torch.nn.Linear(32 * 3, 3 * 3 * 16)
        
    def forward(self, z, label):
        global u, s, r, p, pad1
        z = torch.cat((z, label), dim=1)
        z = r(self.fc3(z))
        z = r(self.fc4(z))
        z = z.reshape(-1, 16, 3, 3)
        
        z = u(z)
        z = r(self.dc1(z))
        z = r(self.dc2(z))
        
        z = u(z)
        z = pad1(z)
        z = r(self.dc3(z))
        z = r(self.dc4(z))
        
        z = u(z)
        z = r(self.dc5(z))
        z = s(self.dc6(z))
        
        return z
        
class Discriminator(torch.nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.c0 = torch.nn.Conv2d(1, 8, kernel_size=7, stride=1, padding=3)
        self.c1 = torch.nn.Conv2d(8, 8, kernel_size=7, stride=2, padding=3) # 14
        self.c2 = torch.nn.Conv2d(8, 16, kernel_size=7, stride=2, padding=3) # 7
        self.c3 = torch.nn.Conv2d(16, 32, kernel_size=3, padding=1) # 3
        self.fc1 = torch.nn.Linear(32 * 3 * 3, 16 * 3)
        self.fc2 = torch.nn.Linear(16 * 3, 1)

    def forward(self, x):
        x = self.c0(x)
        x = self.c1(x)
        x = self.c2(x)
        
        x_features = s(x.reshape(-1, 16 * 7 * 7))
        
        x = p(r(self.c3(x)))
        #x_features = s(x.reshape(-1, 16 * 3 * 3))
        
        x = r(self.fc1(x.reshape(-1, 32 * 3 * 3)))
        x = r(self.fc1(x.reshape(-1, 32 * 3 * 3)))
        x = s(self.fc2(x))
        return x_features, x

def reparametrize(mu, logvar):
    std = torch.exp(0.5*logvar)
    eps = torch.randn_like(std)
    return mu + eps*std
        

In [19]:
enc = Encoder().to(device)
dec = Decoder().to(device)
dis = Discriminator().to(device)

enc_optimizer = torch.optim.Adam(enc.parameters(), 0.0001)
dec_optimizer = torch.optim.Adam(dec.parameters(), 0.0001)
#dis_optimizer = torch.optim.Adam(dis.parameters(), 0.0001)

In [20]:
def train_encoder(x_data, y_data):
    global enc, dec, dis, enc_optimizer, dec_optimizer, dis_optimizer
    enc_optimizer.zero_grad()
    dec_optimizer.zero_grad()
    dis_optimizer.zero_grad()

    mu, logvar = enc(x_data, y_data)
    code = reparametrize(mu, logvar)
    x_recovered = dec(code, y_data)
    
    x_features, dis_result = dis(x_data)
    x_recovered_features, dis_recovered_result = dis(x_recovered)
    
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    BCE = (x_recovered_features - x_features).pow(2).sum() 
    
    #torch.nn.functional.binary_cross_entropy(x_recovered, x_data, reduction='sum')
    
    loss_enc = KLD + BCE
    loss_enc.backward()
    
    enc_optimizer.step()
    
    return loss_enc.item()
    
def train_decoder(x_data, y_data):
    global enc, dec, dis, enc_optimizer, dec_optimizer, dis_optimizer
    enc_optimizer.zero_grad()
    dec_optimizer.zero_grad()
    dis_optimizer.zero_grad()

    mu, logvar = enc(x_data, y_data)
    code = reparametrize(mu, logvar)
    
    
    x_recovered = dec(code, y_data)
    x_features, x_dis = dis(x_data)
    
    x_recovered_features, x_recovered_dis = dis(x_recovered)
    
    #KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    BCE = (x_recovered_features - x_features).pow(2).sum()
    
    x_sampled = dec(mu, y_data)
    x_sampled_features, x_sampled_dis = dis(x_sampled)
    
    GAN = torch.nn.functional.binary_cross_entropy(x_dis, torch.zeros_like(x_dis)) + \
        torch.nn.functional.binary_cross_entropy(x_recovered_dis, torch.ones_like(x_recovered_dis)) + \
        torch.nn.functional.binary_cross_entropy(x_sampled_dis, torch.ones_like(x_sampled_dis))
        
    GAN = GAN.sum()
    
    loss_dec = BCE - GAN
    loss_dec.backward()
    dec_optimizer.step()
    
    return loss_dec.item()
    
def train_discriminator(x_data, y_data):
    global enc, dec, enc_optimizer, dec_optimizer
    enc_optimizer.zero_grad()
    dec_optimizer.zero_grad()
    dis_optimizer.zero_grad()

    mu, logvar = enc(x_data, y_data)
    code = reparametrize(mu, logvar)
    x_recovered = dec(code, y_data)
    
    x_features, x_dis = dis(x_data)
    x_recovered_features, x_recovered_dis = dis(x_recovered)
    
    x_sampled = dec(mu, y_data)
    x_sampled_features, x_sampled_dis = dis(x_sampled)
    
    GAN = torch.nn.functional.binary_cross_entropy(x_dis, torch.zeros_like(x_dis)) + \
        torch.nn.functional.binary_cross_entropy(x_recovered_dis, torch.ones_like(x_recovered_dis)) + \
        torch.nn.functional.binary_cross_entropy(x_sampled_dis, torch.ones_like(x_sampled_dis))
        
    GAN = GAN.sum()
    
    loss_dis = GAN + (x_features.pow(2).sum(dim=-1) - 1).abs().sum() + \
            (x_sampled_features.pow(2).sum(dim=-1) - 1).abs().sum() + \
            (x_recovered_features.pow(2).sum(dim=-1) - 1).abs().sum()
    
    loss_dis.backward()
    dis_optimizer.step()
    
    return loss_dis.item()


In [21]:

def train(x_data, y_data):
    global enc, dec, dis, enc_optimizer, dec_optimizer, dis_optimizer
    enc_optimizer.zero_grad()
    dec_optimizer.zero_grad()
    dis_optimizer.zero_grad()

    mu, logvar = enc(x_data, y_data)
    code = reparametrize(mu, logvar)
       
    x_features, x_dis = dis(x_data)
    
    x_recovered = dec(code, y_data)
    x_recovered_features, x_recovered_dis = dis(x_recovered)
    
    x_sampled = dec(mu, y_data)
    x_sampled_features, x_sampled_dis = dis(x_sampled)
    
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    BCE = (x_recovered_features - x_features).pow(2).sum()
    GAN = torch.nn.functional.binary_cross_entropy(x_dis, torch.zeros_like(x_dis), reduction='sum')
    
    loss_enc = KLD + BCE
    loss_enc.backward(retain_graph=True)
    enc_optimizer.step()
    
    loss_dec = 20 * BCE - GAN
    loss_dec.backward(retain_graph=True)
    dec_optimizer.step()
    
    loss_dis = GAN
    loss_dis.backward(retain_graph=True)
    dis_optimizer.step()
    
    
    return loss_enc.item(), loss_dec.item(), loss_dis.item()

In [None]:
enc = torch.load('enc.hht').to(device)
dec = torch.load('dec.hht').to(device)
#dis = torch.load('dis.hht').to(device)
enc_optimizer = torch.optim.Adam(enc.parameters(), 0.0001)
dec_optimizer = torch.optim.Adam(dec.parameters(), 0.0001)
dis_optimizer = torch.optim.Adam(dis.parameters(), 0.0001, weight_decay=0.001)

for _ in range(5000):
    le = []
    lc = [] 
    ld = []
    
    for __ in range(15):
        #for i in range(6):
        #    data_x, data_y = get_data()
        #    l = train_encoder(data_x, data_y)
        #    le.append(l)

        #for i in range(6):
        #    data_x, data_y = get_data()
        #    l = train_decoder(data_x, data_y)
        #    ld.append(l)

        #for i in range(1):
        #    data_x, data_y = get_data()
        #    l = train_discriminator(data_x, data_y)
        #    lc.append(l)
            
        for i in range(5):
            data_x, data_y = get_data()
            le1, ld1, lc1 = train(data_x, data_y)
            le.append(le1)
            ld.append(ld1)
            lc.append(lc1)
        
    print('Enc:', sum(le))   
    print('Dec:', sum(ld))    
    print('Dis:', sum(lc))
    print('-' * 30)
    
    torch.save(enc, 'enc.hht')
    torch.save(dec, 'dec.hht')
    torch.save(dis, 'dis.hht')

Enc: 14021.179809570312
Dec: 227188.51348876953
Dis: 53234.84069824219
------------------------------
Enc: 3469.7980041503906
Dec: 16406.199279785156
Dis: 52989.652770996094
------------------------------
Enc: 1755.807071685791
Dec: -17628.837432861328
Dis: 52744.79150390625
------------------------------
Enc: 1078.4487962722778
Dec: -30962.611206054688
Dis: 52531.0
------------------------------
Enc: 733.2361268997192
Dec: -37695.15899658203
Dis: 52359.29296875
------------------------------
Enc: 530.6094303131104
Dec: -41618.39685058594
Dis: 52230.36785888672
------------------------------
Enc: 401.9226984977722
Dec: -44100.33923339844
Dis: 52138.68212890625
------------------------------
Enc: 314.45177483558655
Dec: -45787.81524658203
Dis: 52076.81555175781
------------------------------
Enc: 252.65333461761475
Dec: -46984.27746582031
Dis: 52037.209411621094
------------------------------
Enc: 206.98565459251404
Dec: -47873.682189941406
Dis: 52013.23858642578
-----------------------

In [None]:
a, b = get_data()
a.shape, b.shape
pred = dec(enc(a[:100], b[:100])[0], b[:100])

In [None]:
i = np.random.randint(0, 99)
plt.imshow(a[i, 0, :, :].detach().cpu())
plt.show()
plt.imshow(pred[i, 0, :, :].detach().cpu())
plt.show()