## Evaluation of sampling methods over epochs of MNIST

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CelebA,MNIST
import torchvision.datasets as dset
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
import warnings
warnings.filterwarnings("ignore",category=UserWarning)
import timeit

import rpy2.robjects.numpy2ri
import rpy2.robjects as robjects
from rpy2.robjects.packages import importr
robjects.numpy2ri.activate()
base = importr('base')
#utils = importr('utils')
#utils.install_packages('rvinecopulib')
#utils.install_packages('Rcpp')
rvinecop = importr('rvinecopulib')

from sklearn.mixture import GaussianMixture
from sklearn.neighbors import KernelDensity
from sklearn.model_selection import GridSearchCV

from torch import nn
from torch import distributions
from torch.nn.parameter import Parameter

from AE import *
from Sampling import *
from Metric import *
from RealNVP import *



import os  # local to this cell: needed for os.getcwd() below

# Preprocessing applied to every MNIST image: convert to a tensor, add a
# 2-pixel border with Pad(2), then round-trip through a PIL image back to a
# tensor. The models in the evaluation cell are built with image_size=32.
# NOTE(review): the ToPILImage -> ToTensor round trip looks redundant, but it
# is kept as-is so the pixel values fed to the models stay unchanged.
trans0 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Pad(2),
    transforms.ToPILImage(),
    transforms.ToTensor(),
])

# Dataset root is the current working directory. os.getcwd() replaces the
# IPython-only `%pwd` magic so the cell is also valid plain Python.
path = os.getcwd()
dataset_train = MNIST(path, train=True, transform=trans0, download=True)
dataset_test = MNIST(path, train=False, transform=trans0, download=True)




## Evaluation

In [None]:
# Autoencoder and VAE built with the same hyper-parameters; their weights are
# loaded from per-epoch checkpoints in the loop that follows.
model_AE = AE_MNIST(image_size=32, channel_num=1, kernel_num=128, z_size=10)
model_VAE = VAE_MNIST(image_size=32, channel_num=1, kernel_num=128, z_size=10)

# Result array indexed as [epoch, model, score].
score = np.zeros(shape=(10, 8, 6))

# One shuffled batch of n images from each split is drawn once and reused
# for every epoch and every sampling method.
n = 2000
dataloader_train = DataLoader(dataset_train, batch_size=n, shuffle=True)
dataloader_test = DataLoader(dataset_test, batch_size=n, shuffle=True)

first_train_batch = next(iter(dataloader_train))
img_train, attr = first_train_batch
first_test_batch = next(iter(dataloader_test))
img_test, attr = first_test_batch


def _coupling_mlp(dim, final_tanh, hidden=256, repeats=4):
    """Build one scale/translation network for the RealNVP flow.

    Stacks ``repeats`` copies of
    Linear(dim, hidden) -> LeakyReLU -> Linear(hidden, hidden) -> LeakyReLU
    -> Linear(hidden, dim) and, if ``final_tanh`` is true, appends a Tanh.
    """
    layers = []
    for _ in range(repeats):
        layers += [
            nn.Linear(dim, hidden), nn.LeakyReLU(),
            nn.Linear(hidden, hidden), nn.LeakyReLU(),
            nn.Linear(hidden, dim),
        ]
    if final_tanh:
        layers.append(nn.Tanh())
    return nn.Sequential(*layers)


# For every saved checkpoint (epochs 20, 40, ..., 200) draw new latent samples
# with each of the 8 sampling methods, decode them and store six scores per
# (epoch, method) pair in score[i, m, :].
for i, epoch in enumerate(range(20, 201, 20)):

    # Load the AE / VAE weights saved at this epoch (mapped to CPU).
    model_AE.load_state_dict(torch.load('./ae_MNIST_{}.pth'.format(epoch),
                                        map_location=torch.device('cpu')))
    model_VAE.load_state_dict(torch.load('./vae_MNIST_{}.pth'.format(epoch),
                                         map_location=torch.device('cpu')))
    # NOTE(review): the models are never switched to eval(); confirm whether
    # they contain BatchNorm/Dropout layers for which that would matter.

    # Encode the test batch and the training batch; no gradients are needed.
    with torch.no_grad():
        y_test = model_AE.encode(img_test)
        lv = model_AE.encode(img_train).detach().numpy()

    # Start of the timed section reported in the "Finished" message below.
    s1 = timeit.default_timer()

    n_samples = lv.shape[0]
    z_dim = lv.shape[1]

    # Create samples
    for m in range(8):

        # Beta copula
        if m == 0:
            y_sample = sampleing1(lv, lv, n_samples, seed=500)

        # VAE: draw from the standard-normal prior
        elif m == 1:
            y_sample = torch.randn(n_samples, z_dim)

        # Vine copula, trunc_lvl=15
        elif m == 2:
            fixed_noise = np.random.rand(n_samples, z_dim)
            copula_controls = base.list(family_set="tll", trunc_lvl=15, cores=1)
            vine_obj = rvinecop.vine(lv, copula_controls=copula_controls)
            sampled_r = rvinecop.inverse_rosenblatt(fixed_noise, vine_obj)
            y_sample = torch.Tensor(np.asarray(sampled_r)).view(n_samples, -1).to("cpu")

        # Gauss: multivariate normal fitted to the latent codes
        elif m == 3:
            mean = np.mean(lv, axis=0)
            cov = np.cov(lv, rowvar=0)
            y_sample = torch.tensor(np.random.multivariate_normal(mean, cov, n_samples)).float()

        # Independent
        elif m == 4:
            lv_new_indep = indep_sampling(lv, lv, n_samples, seed=500)
            y_sample = torch.tensor(shuffle_along_axis(lv_new_indep, axis=0)).float()

        # GMM
        elif m == 5:
            gm = GaussianMixture(n_components=10, random_state=0).fit(lv)
            y_sample = torch.tensor(gm.sample(n_samples=n_samples)[0]).float()

        # KDE with the bandwidth chosen by 10-fold cross-validation
        elif m == 6:
            grid = GridSearchCV(KernelDensity(),
                                {'bandwidth': np.linspace(0.1, 2.0, 40)},
                                cv=10)
            grid.fit(lv)
            kde = grid.best_estimator_

            lvnew_kde = kde.sample(n_samples=n_samples)
            lvnew_kde = np.array(lvnew_kde, dtype=np.double)
            y_sample = torch.tensor(lvnew_kde).float()

        # Real NVP
        elif m == 7:
            # Alternating binary masks [0, 1, 0, 1, ...] / [1, 0, 1, 0, ...],
            # repeated three times (six mask rows in total).
            pattern = (np.arange(z_dim) % 2).astype(np.float32)
            masks = torch.from_numpy(np.stack([pattern, 1 - pattern] * 3))
            nets = lambda: _coupling_mlp(z_dim, final_tanh=True)
            nett = lambda: _coupling_mlp(z_dim, final_tanh=False)
            prior = distributions.MultivariateNormal(torch.zeros(z_dim), torch.eye(z_dim))
            flow = RealNVP(nets, nett, masks, prior)
            optimizer = torch.optim.Adam(
                [p for p in flow.parameters() if p.requires_grad], lr=1e-4)
            num_samples = 128
            for _ in tqdm(range(2001)):
                # Random mini-batch (with replacement) of latent codes.
                batch_idx = np.random.randint(0, n_samples, size=num_samples)
                x_np = lv[batch_idx].astype(np.float32)
                loss = -flow.log_prob(torch.from_numpy(x_np)).mean()
                optimizer.zero_grad()
                # NOTE(review): retain_graph=True is kept from the original;
                # it is probably unnecessary -- confirm against RealNVP.
                loss.backward(retain_graph=True)
                optimizer.step()

            x = flow.sample(n_samples).detach().numpy()
            # Drop the middle axis of the sampled array.
            x = np.reshape(x, (x.shape[0], x.shape[2]))
            y_sample = torch.tensor(x).float()

        # Decode: only the VAE prior samples (m == 1) go through the VAE
        # decoder; every other method samples in the AE latent space.
        decoder = model_VAE if m == 1 else model_AE

        # Decode and compute scores without tracking gradients.
        with torch.no_grad():
            img_new = decoder.decode(torch.as_tensor(y_sample).float())
            sc1 = compute_score(img_test, img_new)
            sc2 = compute_score(y_test, y_sample)

        # Columns 0-2: scores on decoded images; columns 3-5: scores on latents.
        for j in range(3):
            score[i, m, j] = sc1[j]
            score[i, m, j + 3] = sc2[j]

    print("Finished:", epoch, timeit.default_timer() - s1)

In [None]:
# Persist the [epoch, model, score] array so the plot can be redone later.
torch.save(obj=score, f='./score_MNIST_epoch_all_VGL.pt')

#### Plot

In [None]:
from matplotlib.pyplot import cm

# Plot every score against the training epoch, one subplot per score and one
# curve per sampling method.
score1 = score[:10, :, :]
xaxis = list(range(20, 201, 20))
title = ['EMD PIXEL', 'MMD PIXEL', '1NN PIXEL', 'EMD CONV', 'MMD CONV', '1NN CONV']
# One label per method index m of the evaluation loop, in the same order
# (0: beta copula, 1: VAE, 2: vine copula, 3: Gauss, 4: independent,
# 5: GMM, 6: KDE, 7: RealNVP), so label[j] describes score1[:, j, :].
label = ['EBCAE', 'VAE', 'VCAE_15', 'Gauss', 'Independent', 'GMM', 'KDE', 'RealNVP']
plt.rcParams["font.family"] = "Times New Roman"
fig = plt.figure(figsize=(8.27, 3.8))

# Subplot order places each PIXEL score next to its CONV counterpart.
for k, i in enumerate([0, 3, 1, 4, 2, 5], start=1):
    ax = plt.subplot(2, 3, k)
    colors = cm.Set2(np.linspace(0, 1, len(label)))
    for j, c in enumerate(colors):
        # Dashed line with round markers; the colour comes from `color` only.
        ax.plot(xaxis, score1[:, j, i], linestyle='--', marker='o',
                label=label[j], color=c, markersize=3.5)

    plt.xticks(xaxis[0::2], fontsize=10)
    plt.yticks(fontsize=10)
    plt.xlabel('epochs', fontsize=10)
    ax.set_title(title[i])

# Every subplot shows the same methods, so the handles of the first axis are
# enough for one shared legend below the figure.
lines, labels = fig.axes[0].get_legend_handles_labels()

fig.legend(lines, labels, loc='upper center', bbox_to_anchor=(0.5, 0.05), fancybox=False, markerscale=2
           , shadow=False, framealpha=0, ncol=7, fontsize=10)

plt.tight_layout()
plt.savefig("score_MNIST_epoch_VGL.png", dpi=300, bbox_inches='tight')