## Additional functions and plotting

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CelebA,MNIST
import torchvision.datasets as dset
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
from AE import *
from Sampling import *
from Metric import *
import warnings
import timeit
warnings.filterwarnings("ignore",category=UserWarning)
from sklearn.mixture import GaussianMixture
from sklearn.neighbors import KernelDensity

import rpy2.robjects.numpy2ri
import rpy2.robjects as robjects
from rpy2.robjects.packages import importr
# rpy2: switch on automatic numpy <-> R array conversion, then bind the R
# packages 'base' and 'rvinecopulib' as Python objects.
robjects.numpy2ri.activate()
base, rvinecop = (importr(pkg) for pkg in ('base', 'rvinecopulib'))

# Preprocessing for CelebA: crop the central 140x140 region, resize to 64x64
# and convert to a tensor.
# `transforms.Resize` replaces `transforms.Scale`, a deprecated alias that later
# torchvision releases removed (its deprecation UserWarning is silenced by the
# warnings filter in the import cell, so the problem is easy to miss).
trans0 = transforms.Compose([
    transforms.CenterCrop(140),
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])


# Load data
path = %pwd
dataset_train = CelebA(path,split="train", transform=trans0, download=False)
dataset_test = CelebA(path,split="test", transform=trans0, download=False)
dataloader_test = DataLoader(dataset_test, batch_size=2000, shuffle=True)
img_test,attr = next(iter(dataloader_test))

In [None]:
# AE: load the pretrained CelebA autoencoder weights onto the CPU.
# The model is only used for inference below, so switch it to evaluation mode.
# Layers such as BatchNorm/Dropout (if AE_Celeba has any) then use their
# trained statistics instead of per-batch behaviour.
model_AE = AE_Celeba()
model_AE.load_state_dict(torch.load('./ae_celebA_200.pth', map_location=torch.device('cpu')))
model_AE.eval();

In [None]:
# VAE: load the pretrained CelebA VAE (64x64 RGB input, 100-dim latent z) onto
# the CPU and switch it to evaluation mode for inference, so layers such as
# BatchNorm/Dropout (if VAE_Celeba has any) behave deterministically.
model_VAE = VAE_Celeba(image_size=64, channel_num=3, kernel_num=128, z_size=100)
model_VAE.load_state_dict(torch.load('./vae_CelebA_200.pth',
                                     map_location=torch.device('cpu')))
model_VAE.eval();


In [None]:
# Draw one random batch of n training images together with their attribute
# labels. The seeded generator keeps this batch (and everything derived from
# it, e.g. the interpolation indices below) reproducible across restarts.
n = 2000
dataloader_train = DataLoader(dataset_train, batch_size=n, shuffle=True,
                              generator=torch.Generator().manual_seed(123))
img, attr = next(iter(dataloader_train))

In [None]:
# Get latent variables of the training batch (as a numpy array `lv`) and the
# AE reconstructions decoded from them.
# Decoding also runs under no_grad: nothing here is trained, and decoding the
# whole batch outside it would keep an unused autograd graph in memory.
with torch.no_grad():
    lv = model_AE.encode(img).detach().numpy()
    img_new_AE = model_AE.decode(torch.tensor(lv).float())

### Linear interpolation in latent space

In [None]:
def interpolate(img, n_steps=10):
    """Linearly interpolate between two latent vectors.

    Parameters
    ----------
    img : array-like of shape (2, d)
        Row 0 is the start point and row 1 the end point (numpy array or
        torch tensor). Despite the name, these are latent codes, not images.
    n_steps : int, default 10
        Number of interpolation points, both endpoints included.

    Returns
    -------
    torch.Tensor of shape (n_steps, d), float32
        Row k is the k-th point on the straight line from img[0] to img[1].

    The latent size d is taken from the input instead of being fixed at 100,
    so other latent sizes work as well.
    """
    start, end = img[0], img[1]
    # One torch.linspace per latent coordinate reproduces exactly the values of
    # a per-coordinate linspace (endpoints are hit exactly).
    columns = [torch.linspace(float(a), float(b), n_steps) for a, b in zip(start, end)]
    if not columns:
        return torch.empty((n_steps, 0))
    return torch.stack(columns, dim=1)

In [None]:
# Examples to interpolate on: ten latent codes from the training batch, read
# as five consecutive (start, end) pairs by the interpolation plot.
lv_interpo = lv[[66, 1, 2, 7, 8, 155, 775, 17, 67, 16]]

In [None]:
# Plot a 5x10 grid: each row decodes the 10-point linear path between one
# (start, end) pair of latent codes in `lv_interpo`.
# n_steps must match the number of points interpolate() returns by default.
n_pairs, n_steps = 5, 10
plt.figure(figsize=(18, 9))
plt.subplots_adjust(wspace=0, hspace=0)
for i in range(n_pairs):
    output = interpolate(lv_interpo[2*i:2*i+2, :])
    # Decode under no_grad: the images are only displayed, so no autograd graph
    # is needed, and the result can be handed straight to imshow.
    with torch.no_grad():
        img_new = model_AE.decode(torch.as_tensor(output, dtype=torch.float32))
    for j in range(n_steps):
        ax = plt.subplot(n_pairs, n_steps, i*n_steps + j + 1)
        ax.imshow(img_new[j].reshape((3, 64, 64)).permute(1, 2, 0))
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

### Combination of marginals and copula of different classes

###### Classes
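CelebA attribute numbers (1-based, as used in the code below): 21: male, 16: eyeglasses, 9: black hair, 10: blond hair, 12: brown hair, 18: grey hair, 32: smiling, 23: mustache,
36: wearing hat. Attribute 11 (blurry) is used to exclude blurry images.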

In [None]:
# Choose input: among the non-blurry images of the training batch, the copula
# sample x comes from images WITHOUT attribute attr_nr1 and the margin sample y
# from images WITH attribute attr_nr2 (1-based CelebA attribute numbers, see
# the class list above). x and y are selected independently, so the two
# attribute numbers may differ.
attr_nr1 = 32
attr_nr2 = 32
not_blurry = attr[:, 11 - 1] == 0  # attribute 11: blurry
l1 = torch.nonzero(not_blurry & (attr[:, attr_nr1 - 1] == 0)).flatten().tolist()
l2 = torch.nonzero(not_blurry & (attr[:, attr_nr2 - 1] == 1)).flatten().tolist()
x = lv[l1, :]  # copula
y = lv[l2, :]  # margins

In [None]:
# Generate new images with the model: draw n_sample latent vectors from
# sampling1 using x (copula) and y (margins), then decode them with the AE.
# Decoding runs under no_grad because the images are only inspected.
n_sample = 10
samples_manip = sampling1(x, y, n_sample, seed=123)
with torch.no_grad():
    img_new_manip = model_AE.decode(torch.tensor(samples_manip).float())

### Finding the closest neighbour


In [None]:
def find_neighbour(lv_samples, n_sample):
    """Find, for every sampled latent vector, its nearest neighbour in `lv`.

    Parameters
    ----------
    lv_samples : array-like or torch.Tensor of shape (m, d)
        Latent vectors to look up, e.g. new samples drawn with sampling1.
    n_sample : int
        Unused; kept so existing calls keep working.

    Returns
    -------
    lv_neighbor : torch.Tensor of shape (m, d)
        Row k is the latent code in the module-level `lv` that is closest to
        lv_samples[k] according to `distance`.
    index : torch.LongTensor of shape (m,)
        Row indices into `lv` (and therefore into the training batch `img`
        that `lv` was encoded from) of those neighbours.
    """
    lv_xy = torch.as_tensor(lv)
    samples = torch.as_tensor(lv_samples, dtype=lv_xy.dtype)
    # NOTE(review): assumes `distance` takes two tensors plus a device string
    # and returns an (m, len(lv)) distance matrix -- confirm against Metric.
    dist = distance(samples, lv_xy, "cpu")
    _, idx = dist.topk(1, dim=1, largest=False)  # smallest distance per row
    index = idx[:, 0]
    lv_neighbor = lv_xy[index, :]
    return lv_neighbor, index

In [None]:
# New samples: draw n_sample latent vectors with sampling1, using the full
# training batch `lv` both as copula input and as margin input, then look up
# each sample's nearest training neighbour in latent space.
n_sample = 4
lv_samples_EBCAE = sampling1(lv, lv, n_sample, seed=123)
lv_neighbor_EBCAE, index = find_neighbour(lv_samples_EBCAE, n_sample)

In [None]:
# Decode the new samples and their nearest neighbours' latent codes with the
# AE, and fetch the matching original training images.
# sampling1's output is converted to a float tensor first (as in the
# manipulation cell), since decode is given tensors everywhere else.
# Decoding runs under no_grad because the images are only plotted.
with torch.no_grad():
    img_new_EBCAE = model_AE.decode(
        torch.as_tensor(lv_samples_EBCAE, dtype=torch.float32))
    img_neighbor_AE_EBCAE = model_AE.decode(
        torch.as_tensor(lv_neighbor_EBCAE, dtype=torch.float32))
img_neighbor_EBCAE = img[index]

In [None]:
# Plot a 3x4 grid: row 1 shows the new samples, row 2 the AE decodings of
# their nearest latent neighbours, row 3 the matching original images.
plt.figure(figsize=(8, 6))
plt.subplots_adjust(wspace=0, hspace=0)
rows = (img_new_EBCAE, img_neighbor_AE_EBCAE, img_neighbor_EBCAE)
for i in range(len(rows) * 4):
    ax = plt.subplot(3, 4, i + 1)
    row_imgs, col = rows[i // 4], i % 4
    with torch.no_grad():
        plt.imshow(row_imgs[col].reshape((3, 64, 64)).permute(1, 2, 0))
    plt.gray()
    for axis in (ax.get_xaxis(), ax.get_yaxis()):
        axis.set_visible(False)