In [None]:
# Fetch the LFW attribute table plus the deep-funneled and original LFW image archives.
# NOTE(review): the archives are only downloaded here, never unpacked in the visible cells —
# presumably load_lfw_dataset reads the .tgz files directly; confirm.
!wget http://www.cs.columbia.edu/CAVE/databases/pubfig/download/lfw_attributes.txt
!wget http://vis-www.cs.umass.edu/lfw/lfw-deepfunneled.tgz
!wget http://vis-www.cs.umass.edu/lfw/lfw.tgz

In [None]:
import os
import numpy as np
from sklearn.model_selection import train_test_split
from lfw_dataset import load_lfw_dataset
%matplotlib inline
import matplotlib.pyplot as plt
#import download_utils
import numpy as np
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

## Utils and parameters

In [None]:
# Device every model and batch is moved to: the first CUDA GPU when one is
# usable, otherwise the CPU. `is_available` has to be *called* -- the bare
# function object is always truthy and would select 'cuda:0' unconditionally.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Length of the latent noise vector drawn by sample_noise_batch for the GAN.
GAN_CODE_SIZE = 256

In [None]:
def init_weights(m):
    """Initialise a single layer in place; intended for ``nn.Module.apply``.

    * ``nn.Conv2d``, ``nn.ConvTranspose2d`` and ``nn.Linear`` weights get
      Xavier-uniform initialisation.
    * ``nn.Linear`` biases are additionally zeroed (skipped when the layer was
      built with ``bias=False``, which previously crashed on ``None``).
    * Convolution biases and every other module type are left untouched.

    Args:
        m: the module to initialise.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        nn.init.xavier_uniform_(m.weight)
    if isinstance(m, nn.Linear) and m.bias is not None:
        nn.init.constant_(m.bias, 0)

def sample_noise_batch(bsize):
    """Draw ``bsize`` standard-normal latent codes of length GAN_CODE_SIZE.

    The noise is generated first and then moved to DEVICE.

    Returns:
        Tensor of shape ``(bsize, GAN_CODE_SIZE)`` on DEVICE.
    """
    latent_codes = torch.randn(bsize, GAN_CODE_SIZE)
    return latent_codes.to(DEVICE)

def sample_data_batch(bsize, data):
    """Sample a random mini-batch of images and return it as a tensor on DEVICE.

    Args:
        bsize: number of images to draw.
        data: array indexed by sample on axis 0 whose last axis is moved to
            position 1 (i.e. channels-last in, channels-first out).

    Returns:
        Float tensor of shape ``(bsize, C, H, W)`` on DEVICE.

    Indices are drawn with ``np.random.choice`` defaults, i.e. with
    replacement, so an image may appear more than once in a batch.
    """
    # The former unused `ind_to_show` draw was removed; it only consumed RNG state.
    idxs = np.random.choice(data.shape[0], size=bsize)
    batch_nchw = data[idxs].transpose((0, 3, 1, 2))
    return torch.Tensor(batch_nchw).to(DEVICE)

def sample_images(nrow, ncol, images, data, sharp=False):
    """Show the first ``nrow * ncol`` images on an ``nrow`` x ``ncol`` grid.

    Args:
        nrow, ncol: grid dimensions.
        images: array of images to display; each one is reshaped to
            ``data.shape[1:]`` before plotting.
        data: reference dataset; supplies the per-image shape and, unless
            ``images`` has zero variance, the value range images are clipped to.
        sharp: when True, images are drawn with ``interpolation="none"``.
    """
    if np.var(images) != 0:
        lower, upper = np.min(data), np.max(data)
        images = images.clip(lower, upper)

    imshow_options = {"cmap": "gray"}
    if sharp:
        imshow_options["interpolation"] = "none"

    for position in range(1, nrow * ncol + 1):
        plt.subplot(nrow, ncol, position)
        picture = images[position - 1].reshape(data.shape[1:])
        plt.imshow(picture, **imshow_options)
    plt.show()

def sample_probas(desc_real_prediction, desc_gen_prediction):
    """Overlay histograms of the discriminator's column-1 probability.

    Both arguments are arrays of log-probabilities with at least two columns;
    they are exponentiated and column 1 is histogrammed over ``[0, 1]`` --
    first for real data (label ``D(x)``), then for generated data
    (label ``D(G(z))``).
    """
    plt.title('Generated vs real data')
    labelled_predictions = (
        ('D(x)', desc_real_prediction),
        ('D(G(z))', desc_gen_prediction),
    )
    for legend_label, log_probs in labelled_predictions:
        column_one_proba = np.exp(log_probs)[:, 1]
        plt.hist(column_one_proba, label=legend_label, alpha=0.5, range=[0, 1])
    plt.legend(loc='best')
    plt.show()

def show_results(test_image_tensor, reconstructed_tensor):
    """Plot an original image next to its reconstruction.

    Args:
        test_image_tensor: channels-first image tensor ``(C, H, W)``.
        reconstructed_tensor: model output for that image; singleton
            dimensions (e.g. the batch axis) are squeezed away before plotting.

    Both tensors are detached and copied to the CPU first, so inputs that live
    on the GPU or still track gradients are accepted (the original image
    previously had to be a plain CPU tensor). Values are shifted by +0.5 and
    clipped to ``[0, 1]`` for display, undoing the -0.5 centring applied to
    the training data.
    """
    image = test_image_tensor.detach().cpu().numpy().transpose((1, 2, 0))
    reconstructed = (
        reconstructed_tensor.detach().cpu().numpy().squeeze().transpose((1, 2, 0))
    )

    panels = (("Original", image), ("Reconstructed", reconstructed))
    for position, (title, picture) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(title)
        plt.imshow(np.clip(picture + 0.5, 0, 1))
    plt.show()

##### Common convolutional block 
class ConvBlock(nn.Module):
    """Down-sampling block: 3x3 convolution -> ELU -> 3x3 max-pool with stride 2.

    Args:
        inp_depth: number of input channels.
        first: when True the block outputs 32 channels; otherwise it doubles
            ``inp_depth``.

    The convolution keeps the spatial size (padding 1); the pooling stage
    roughly halves height and width.
    """

    def __init__(self, inp_depth=32, first=False):
        super().__init__()
        if first:
            out_depth = 32
        else:
            out_depth = 2 * inp_depth
        self.conv = nn.Conv2d(inp_depth, out_depth, kernel_size=3, padding=1)
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.elu = nn.ELU()

    def forward(self, x):
        """Apply convolution, activation and pooling, in that order."""
        activated = self.elu(self.conv(x))
        return self.max_pool(activated)

## Autoencoder

### Model 

In [None]:
# parameters:
# NOTE(review): NROF_ITERS is not referenced by any code in the visible cells --
# confirm whether it is still needed before relying on it.
NROF_ITERS = 4

class TranspConvBlock(nn.Module):
    """Up-sampling block: stride-2 3x3 transposed convolution plus activation.

    Args:
        inp_depth: number of input channels.
        last: when True the block emits 3 channels and applies no
            non-linearity (identity); otherwise it halves ``inp_depth`` and
            applies ELU.

    With ``padding=1`` and ``output_padding=1`` the transposed convolution
    exactly doubles height and width.
    """

    def __init__(self, inp_depth=32, last=False):
        super().__init__()
        if last:
            out_depth = 3
        else:
            out_depth = inp_depth // 2
        self.tr_conv = nn.ConvTranspose2d(
            inp_depth, out_depth, kernel_size=3,
            stride=2, padding=1, output_padding=1,
        )
        if last:
            self.f_act = nn.Identity()
        else:
            self.f_act = nn.ELU()

    def forward(self, x):
        """Up-sample ``x`` and apply the block's activation."""
        upsampled = self.tr_conv(x)
        return self.f_act(upsampled)

class Model(nn.Module):
    """Convolutional autoencoder for 3-channel 32x32 images.

    Encoder: four ``ConvBlock`` stages (3 -> 32 -> 64 -> 128 -> 256 channels,
    each halving the spatial size), flatten, linear layer to ``code_size``.
    Decoder: linear layer back to 1024 units, reshape to ``(256, 2, 2)``, then
    four ``TranspConvBlock`` stages (256 -> 128 -> 64 -> 32 -> 3 channels,
    each doubling the spatial size).

    The hard-coded 1024 (= 256 * 2 * 2) ties the model to 32x32 inputs.

    Args:
        code_size: dimensionality of the bottleneck code.
    """

    def __init__(self, code_size = 32):
        super(Model,self).__init__()
        self._init_encoder(code_size)
        self._init_decoder(code_size)
        self.elu = nn.ELU()

        # `apply` already recurses into every submodule; looping over
        # `self.modules()` and calling `apply` on each one re-initialised
        # nested layers several times.
        self.apply(init_weights)

    def _init_encoder(self, code_size):
        """Create the convolutional encoder and the projection to the code."""
        in_depths = [32, 64, 128]

        blocks = [ConvBlock(3, True)]
        blocks.extend(ConvBlock(depth) for depth in in_depths)
        self.conv_blocks = nn.Sequential(*blocks)

        self.flat = nn.Flatten()
        self.fc_enc = nn.Linear(1024, code_size)

    def _init_decoder(self, code_size):
        """Create the projection from the code and the transposed-conv decoder."""
        in_depths = [256, 128, 64]

        self.fc_dec = nn.Linear(code_size, 1024)
        blocks = [TranspConvBlock(depth) for depth in in_depths]
        blocks.append(TranspConvBlock(32, True))
        self.conv_tr_blocks = nn.Sequential(*blocks)

    def forward(self, x):
        """Encode ``x`` to the bottleneck code and decode it back to an image.

        Args:
            x: batch of shape ``(N, 3, 32, 32)``.

        Returns:
            Reconstruction of shape ``(N, 3, 32, 32)`` (no output activation).
        """
        out = self.conv_blocks(x)
        out = self.flat(out)
        out = self.elu(self.fc_enc(out))
        out = self.elu(self.fc_dec(out))
        # Plain reshape to (N, 256, 2, 2); replaces the named-tensor
        # unflatten + rename(None) round trip, which depended on the
        # prototype named-tensor API.
        out = out.view(out.size(0), 256, 2, 2)
        out = self.conv_tr_blocks(out)
        return out


### Training

In [None]:
# data preparing
# 32x32 images are required here: the autoencoder's 1024-unit flatten size
# corresponds to 256 channels at 2x2 after four halvings of a 32x32 input.
# NOTE(review): load_lfw_dataset is defined outside this notebook; downstream code
# treats X as a channels-last image array with 0-255 values — confirm.
X, attr = load_lfw_dataset(use_raw=True, dimx=32, dimy=32)

In [None]:
# Autoencoder training configuration. The Train class below reads these
# module-level names when its methods run, not when it is defined.
# NOTE: the GAN section later rebinds the same three names, so re-run this cell
# before using the autoencoder trainer again after the GAN cell has executed.
cfg_batch_size = 32  # mini-batch size for the train and test DataLoaders
cfg_ckpt_path = 'AE/checkpoint'  # file the checkpoint is saved to / loaded from
cfg_ckpt_load = False  # when True, train() resumes from cfg_ckpt_path

class Train:
    """Trainer for the convolutional autoencoder ``Model``.

    Splits the images 90/10 into train/test sets, optimises an MSE
    reconstruction loss with Adamax and checkpoints after every epoch.

    Attributes:
        global_step: number of optimizer steps completed so far.
        cur_epoch: number of epochs completed so far.
        epoch_size: number of optimizer steps in one epoch.
    """

    def __init__(self, X):
        """Build datasets, model, optimizer and loss.

        Args:
            X: channels-last image array with values in 0..255
               (it is scaled and transposed in ``_prepare_data``).
        """
        self._prepare_data(X)
        self.model = Model()
        self.model.to(DEVICE)

        self.optim = optim.Adamax(self.model.parameters())
        self.crit = nn.MSELoss()
        self.nrof_epochs = 25
        self.cur_epoch, self.global_step = 0, 0
        # Exact number of batches per epoch (the old `len // batch + 1`
        # over-counted by one whenever the dataset size was a multiple of
        # the batch size).
        self.epoch_size = len(self.dataloader)

    def _prepare_data(self,X):
        """Scale images to [-0.5, 0.5], split 90/10 and wrap them in DataLoaders."""
        X = X.astype('float32') / 255.0 - 0.5
        X_train, X_test = train_test_split(X, test_size=0.1, random_state=42)

        # NHWC -> NCHW, the layout expected by the convolutional layers.
        tr_x_tensor = torch.Tensor(X_train.transpose((0,3,1,2)))
        tst_x_tensor = torch.Tensor(X_test.transpose((0,3,1,2)))
        self.dataset = TensorDataset(tr_x_tensor)
        self.tst_dataset = TensorDataset(tst_x_tensor)
        self.dataloader = DataLoader(self.dataset, batch_size = cfg_batch_size)
        self.tst_dataloader = DataLoader(self.tst_dataset, batch_size = cfg_batch_size)

    def save_model(self):
        """Write step counter, model weights and optimizer state to cfg_ckpt_path."""
        ckpt_dir = os.path.dirname(cfg_ckpt_path)
        # A bare file name has an empty dirname, which os.makedirs rejects.
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)
        torch.save({'step':self.global_step,
                    'model':self.model.state_dict(),
                    'opt':self.optim.state_dict()
                    },
                   cfg_ckpt_path)
        print('Model saved')

    def load_model(self):
        """Restore model, optimizer and counters from cfg_ckpt_path.

        Raises:
            FileNotFoundError: if no checkpoint exists at cfg_ckpt_path.
        """
        # map_location lets a checkpoint written on a GPU be restored on a
        # CPU-only machine (and vice versa).
        ckpt = torch.load(cfg_ckpt_path, map_location=DEVICE)
        # 'step' stores the number of completed optimizer steps.
        self.global_step = ckpt['step']
        self.cur_epoch = self.global_step // self.epoch_size
        self.model.load_state_dict(ckpt['model'])
        self.optim.load_state_dict(ckpt['opt'])
        print('Model loaded')

    def _show_random_reconstruction(self):
        """Plot a random test image next to the model's reconstruction of it."""
        ind = np.random.choice(len(self.tst_dataset))
        tst_image = self.tst_dataset[ind][0]
        # Visualisation only: no autograd graph is needed.
        with torch.no_grad():
            reconstruct = self.model(tst_image.unsqueeze(0).to(DEVICE))
        show_results(tst_image, reconstruct)

    def _mean_batch_loss(self, loader):
        """Return the average of the per-batch reconstruction losses over ``loader``.

        Returns NaN when the loader yields no batches.
        """
        total, nrof_batches = 0.0, 0
        with torch.no_grad():
            for batch in loader:
                inputs = batch[0].to(DEVICE)
                total += self.crit(self.model(inputs), inputs).item()
                nrof_batches += 1
        return total / nrof_batches if nrof_batches else float('nan')

    def train_epoch(self):
        """Run one pass over the training set, previewing a sample every 100 batches."""
        self.model.train()
        for batch_idx, batch in enumerate(self.dataloader):
            # The autoencoder target is the input itself.
            inputs = batch[0].to(DEVICE)
            self.optim.zero_grad()

            loss = self.crit(self.model(inputs), inputs)

            if batch_idx%100==0:
                print('{} step passed'.format(batch_idx))
                self._show_random_reconstruction()

            loss.backward()
            self.optim.step()
            # Previously never incremented, so every checkpoint stored step 0.
            self.global_step += 1

    def train(self):
        """Train for ``nrof_epochs`` epochs, saving a checkpoint after each one.

        Resumes from cfg_ckpt_path first when cfg_ckpt_load is set.
        """
        if cfg_ckpt_load:
            self.load_model()

        first_epoch = self.cur_epoch
        for epoch in range(first_epoch, first_epoch + self.nrof_epochs):
            self.train_epoch()
            self.cur_epoch = epoch + 1
            self.save_model()

    def evaluate(self):
        """Load the checkpoint, print mean train/test batch loss and show a sample.

        The model is left in eval mode afterwards.
        """
        self.load_model()
        self.model.eval()

        # Losses are accumulated over all batches; the earlier code overwrote
        # the running value each iteration and so reported only the last
        # batch's loss divided by the number of batches.
        tr_loss = self._mean_batch_loss(self.dataloader)
        tst_loss = self._mean_batch_loss(self.tst_dataloader)

        self._show_random_reconstruction()

        print("Train loss: {:.6f}\nTest loss: {:.6f}".format(tr_loss, tst_loss))

# Build the autoencoder trainer on the loaded images and run the full training loop
# (a checkpoint is written to cfg_ckpt_path after every epoch).
Tr_model = Train(X)
Tr_model.train()
# Optional: reload the checkpoint, print train/test losses and show one reconstruction.
#Tr_model.evaluate()

## GAN

### Model

In [None]:
class Discriminator(nn.Module):
    """GAN discriminator producing two-class log-probabilities per image.

    Four ``ConvBlock`` stages (3 -> 32 -> 64 -> 128 -> 256 channels) are
    followed by a 2304 -> 256 linear layer with tanh and a 256 -> 2 linear
    layer with log-softmax. The hard-coded 2304 (= 256 * 3 * 3) ties the
    network to 36x36 inputs (36 -> 18 -> 9 -> 5 -> 3 after the four pools).

    In this notebook column 1 of the output is treated as log p(real) and
    column 0 as log p(generated).
    """

    def __init__(self):
        super(Discriminator,self).__init__()
        in_depths = [32, 64, 128]

        blocks = [ConvBlock(3, True)]
        blocks.extend(ConvBlock(depth) for depth in in_depths)
        self.conv_blocks = nn.Sequential(*blocks)

        self.flat = nn.Flatten()
        self.fc_1 = nn.Linear(2304, 256)
        self.tanh = nn.Tanh()
        self.fc_2 = nn.Linear(256, 2)
        # Normalise over the class axis explicitly; relying on the implicit
        # dimension is deprecated and emits a warning on every forward pass.
        self.log_sftmx = nn.LogSoftmax(dim=1)

        # `apply` recurses on its own, so one call initialises every layer
        # exactly once (the old loop over `self.modules()` repeated the work).
        self.apply(init_weights)

    def forward(self, x):
        """Return log-probabilities of shape ``(N, 2)`` for a batch ``(N, 3, 36, 36)``."""
        out = self.conv_blocks(x)
        out = self.flat(out)
        out = self.fc_1(out)
        out = self.tanh(out)
        out = self.fc_2(out)
        return self.log_sftmx(out)


class Generator(nn.Module):
    """GAN generator mapping a latent code to a 3-channel 36x36 image.

    The code is projected to 640 units, reshaped to ``(10, 8, 8)`` and grown
    by un-padded (transposed) convolutions and one 2x nearest-neighbour
    up-sampling: 8 -> 12 -> 16 -> 32 -> 34 -> 36 -> 38, after which the final
    3x3 convolution brings it to 36x36 with 3 channels. ELU follows every
    layer except the last; the output has no activation.

    Args:
        code_size: length of the input noise vector.
    """

    def __init__(self,code_size = 256):
        super(Generator,self).__init__()
        self.fc_1 = nn.Linear(code_size, 640)
        self.elu = nn.ELU()
        self.transp_conv_1 = nn.ConvTranspose2d(10, 64, 5)
        self.transp_conv_2 = nn.ConvTranspose2d(64, 64, 5)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.transp_conv_3 = nn.ConvTranspose2d(64, 32, 3)
        self.transp_conv_4 = nn.ConvTranspose2d(32, 32, 3)
        self.transp_conv_5 = nn.ConvTranspose2d(32, 32, 3)
        self.conv = nn.Conv2d(32, 3, 3)

        # One recursive `apply` is sufficient; iterating `self.modules()` and
        # applying on each module initialised the same layers repeatedly.
        self.apply(init_weights)

    def forward(self, x):
        """Generate images of shape ``(N, 3, 36, 36)`` from codes ``(N, code_size)``."""
        out = self.elu(self.fc_1(x))
        # 640 = 10 * 8 * 8. A plain view replaces the named-tensor
        # unflatten + rename(None) round trip (prototype API).
        out = out.view(out.size(0), 10, 8, 8)
        out = self.elu(self.transp_conv_1(out))
        out = self.elu(self.transp_conv_2(out))
        out = self.upsample(out)
        out = self.elu(self.transp_conv_3(out))
        out = self.elu(self.transp_conv_4(out))
        out = self.elu(self.transp_conv_5(out))
        return self.conv(out)


### Training

In [None]:
# loading data

# 36x36 images are required here: the Discriminator's 2304-unit flatten size and the
# Generator's output size both correspond to 36x36 inputs.
data, attrs = load_lfw_dataset(dimx=36,dimy=36)
# Convert to float32 and divide by 255 (maps to [0, 1] assuming 0-255 pixel values — confirm).
data = np.float32(data)/255.

In [None]:
# GAN training configuration.
# NOTE: these assignments rebind the names used earlier for the autoencoder
# (batch size 32, path 'AE/checkpoint'); code that reads them at call time
# will see the GAN values once this cell has run.
cfg_batch_size = 100  # samples per discriminator / generator update
cfg_ckpt_path = 'GAN/checkpoint'  # file the GAN checkpoint is saved to / loaded from
cfg_ckpt_load = False  # when True, train() restores from cfg_ckpt_path first

class Train:
    """Trainer for the GAN (``Generator`` + ``Discriminator``).

    Each training iteration performs five SGD updates of the discriminator
    followed by one Adamax update of the generator. Note that this class
    reuses the name of the autoencoder trainer defined earlier in the
    notebook and therefore shadows it.
    """

    def __init__(self, data):
        """Create both networks and their optimizers.

        Args:
            data: channels-last image array that real batches are sampled from.
        """
        self.data = data
        self.disc = Discriminator().to(DEVICE)
        self.gen = Generator().to(DEVICE)
        self.ada_optim = optim.Adamax(self.gen.parameters())
        self.sgd_optim = optim.SGD(self.disc.parameters(),lr=1e-3)
        # Not used by the GAN losses below; kept so the attribute still exists.
        self.crit = nn.MSELoss()

    def save_model(self):
        """Write both networks and both optimizer states to cfg_ckpt_path."""
        ckpt_dir = os.path.dirname(cfg_ckpt_path)
        # A bare file name has an empty dirname, which os.makedirs rejects.
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)
        torch.save({
                    'disc':self.disc.state_dict(),
                    'gen':self.gen.state_dict(),
                    'ada_opt':self.ada_optim.state_dict(),
                    'sgd_opt':self.sgd_optim.state_dict()
                    },
                   cfg_ckpt_path)
        print('model saved')

    def load_model(self):
        """Restore both networks and both optimizers from cfg_ckpt_path."""
        # map_location lets a GPU-written checkpoint load on a CPU-only machine.
        ckpt = torch.load(cfg_ckpt_path, map_location=DEVICE)
        self.disc.load_state_dict(ckpt['disc'])
        self.gen.load_state_dict(ckpt['gen'])
        self.ada_optim.load_state_dict(ckpt['ada_opt'])
        self.sgd_optim.load_state_dict(ckpt['sgd_opt'])
        print('model loaded')

    def train_epoch(self):
        """Run five discriminator updates, then one generator update."""
        self.disc.train()
        self.gen.eval()

        for _ in range(5):
            self.sgd_optim.zero_grad()
            real_data = sample_data_batch(cfg_batch_size, self.data)
            noise = sample_noise_batch(cfg_batch_size)
            # The generator is fixed during discriminator updates, so its
            # forward pass needs no autograd graph; this avoids
            # back-propagating into generator weights that are not stepped.
            with torch.no_grad():
                fake_data = self.gen(noise)

            logp_real = self.disc(real_data)
            logp_gen = self.disc(fake_data)

            # Maximise log p(real | x) + log p(generated | G(z)).
            d_loss = -torch.mean(logp_real[:,1] + logp_gen[:,0])
            # L2 penalty on the parameters of the discriminator's last layer.
            last_layer_params = torch.cat([p.view(-1) for p in self.disc.fc_2.parameters()])
            d_loss = d_loss + torch.mean(last_layer_params**2)

            d_loss.backward()
            self.sgd_optim.step()

        self.gen.train()
        self.disc.eval()
        self.ada_optim.zero_grad()

        noise = sample_noise_batch(cfg_batch_size)
        logp_gen = self.disc(self.gen(noise))

        # Generator objective: make the discriminator call its samples real.
        g_loss = -torch.mean(logp_gen[:,1])
        g_loss.backward()
        self.ada_optim.step()

    def _report(self, epoch):
        """Show generated samples and discriminator histograms, then checkpoint."""
        self.disc.eval()
        self.gen.eval()
        print('epoch',epoch)
        # Monitoring only: skip gradient tracking for the large forward passes.
        with torch.no_grad():
            noise = sample_noise_batch(6)
            images = self.gen(noise).cpu().numpy().transpose((0,2,3,1))
            sample_images(2,3,images,self.data,True)

            real_data = sample_data_batch(1000, self.data)
            noise = sample_noise_batch(1000)
            logp_real = self.disc(real_data).cpu().numpy()
            logp_gen = self.disc(self.gen(noise)).cpu().numpy()
        sample_probas(logp_real, logp_gen)
        self.save_model()

    def train(self, nrof_iters=50000, log_every=100):
        """Train the GAN.

        Args:
            nrof_iters: number of ``train_epoch`` iterations to run.
            log_every: period (in iterations) of visual reports and
                checkpoints; 0 or None disables them.
        """
        if cfg_ckpt_load:
            self.load_model()

        for epoch in range(nrof_iters):
            self.train_epoch()

            if log_every and epoch % log_every == 0:
                self._report(epoch)

# Build the GAN trainer on the loaded images and start training.
# Note: this rebinds Tr_model, which previously held the autoencoder trainer.
Tr_model = Train(data)
Tr_model.train()