In [1]:
import torch
from torch import nn
import numpy as np
import pandas as pd
import os
import sys
import time
import cv2
from PIL import Image
from collections import defaultdict
from datetime import datetime
from IPython.display import clear_output

import matplotlib.pyplot as plt

from torch.distributions.multivariate_normal import MultivariateNormal
from torch.utils.data import DataLoader, Dataset

from sklearn.model_selection import train_test_split
from tqdm import tqdm

# Path initialization
sys.path.insert(0, "../vae_architectures")
from VAE import VAE1

In [3]:
df_split = pd.read_csv("../data/dataframes/train_val_test_split.csv")
imgs_train = df_split[df_split.phase == "train"].img_names.values[0]
imgs_train = [name for name in imgs_train.split("'") if ".jpg" in name]
imgs_val = df_split[df_split.phase == "val"].img_names.values[0]
imgs_val = [name for name in imgs_val.split("'") if ".jpg" in name]
imgs_test = df_split[df_split.phase == "test"].img_names.values[0]
imgs_test = [name for name in imgs_test.split("'") if ".jpg" in name]

In [4]:
### Data

ROOT = "../data/img_align_celeba"
H, W = 256, 256

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 

class FaceData(Dataset):
    def __init__(self, img_names):

        self.img_name = img_names         

    def __len__(self):
        return len(self.img_name)

    def __getitem__(self, index):
        img_path = os.path.join(ROOT, self.img_name[index])

        img = Image.open(img_path)
        img = img.resize((256, 256), Image.ANTIALIAS)
        img = img.crop((32, 32, 224, 224))
        img = img.resize((64, 64), Image.ANTIALIAS)
        img = np.array(img)

        return img.transpose(2, 0, 1) / 255.


img_names = {"train": imgs_train, "val": imgs_val, "test": imgs_test}
datasets = {phase: FaceData(img_names[phase]) for phase in ["train", "val", "test"]}
dataloaders = {phase: DataLoader(datasets[phase], batch_size=64, shuffle=True, num_workers=2) 
           for phase in ["train", "val", "test"]}

In [13]:
def plot_results(model, X_batch):
    r"""Plots and saves the results of model performance.
    """

    model.eval()
    X_rec = model.reconstruct_x(X_batch.float().to(device))
    X_gen = model.generate_x(N=1, device=device)

    fig, ax = plt.subplots(1, 3, figsize=(16, 5))

    ax[0].imshow(X_batch[0].cpu().numpy().transpose(1, 2, 0))
    ax[0].set_title("Real", fontsize=14)
    ax[1].imshow(torch.sigmoid(X_rec[0].cpu()).detach().numpy().transpose(1, 2, 0))
    ax[1].set_title("Rec", fontsize=14)
    ax[2].imshow(torch.sigmoid(X_gen[0].cpu()).detach().numpy().transpose(1, 2, 0))
    ax[2].set_title("Gen", fontsize=14)
    
    plt.show()

### Model

In [12]:
### Model 

model = VAE1(hid_dim=128, KOF=32, p=0.04)
model.load_state_dict(torch.load("/home/edlazareva/latent-subspaces/train_vae/weights/VAE1_Liza_20_10_2019_15_22/weight.epoch_40_itr_317_loss_val_0.03641778684237547.pth", 
                                 map_location=device))
model = model.to(device)
model.eval()

VAE1(
  (encoder): Sequential(
    (block01): Conv_block(
      (conv): Conv2d(3, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (activation): LeakyReLU(negative_slope=0.2)
      (dropout): Dropout2d(p=0.04)
      (batch_norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (block02): Conv_block(
      (conv): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (activation): LeakyReLU(negative_slope=0.2)
      (dropout): Dropout2d(p=0.04)
      (batch_norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (block03): Conv_block(
      (conv): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (activation): LeakyReLU(negative_slope=0.2)
      (dropout): Dropout2d(p=0.04)
      (batch_norm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (block04): Conv_block(
      (conv): Conv2d(128, 256, kernel_size=(4, 

### Performance

In [27]:
phase  = "val"

for i, X_batch in enumerate(dataloaders[phase]):
    plot_results(model, X_batch)
    break