In [None]:
import sys, os
import numpy as np
import matplotlib.pyplot as plt
sys.path.append('../')

import torch, torchvision
import torch.nn.functional as F
from torchvision.models import vit_b_16, ViT_B_16_Weights
from sklearn.svm import LinearSVC
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.mixture import GaussianMixture
from sklearn.metrics import homogeneity_score, silhouette_score

from loader.MNIST_dataset import MNIST
from models import load_pretrained
from loader import get_dataloader
from geometry import get_pullbacked_Riemannian_metric 
from utils.utils import label_to_color, figure_to_array, PD_metric_to_ellipse

# Run on the first GPU when available, otherwise fall back to CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Fix RNG seeds so the random per-class subsampling (torch.randperm) in the plots is reproducible.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

## 1. Select Trained Model

In [None]:
# Alternative checkpoints, kept for reference. Uncommenting a VAE loader also
# provides the `vae` model used in the interpolation section.
# # MNIST digits 0, 1
# vae, cfg = load_pretrained(
#     identifier='MNIST/DIM_2/vae_mnist_01',
#     config_file='mnist_vae_z2.yml',
#     ckpt_file='model_best.pkl',
#     root='../results'
# )
# irvae, cfg = load_pretrained(
#     identifier='MNIST/DIM_2/irvae_mnist_01',
#     config_file='mnist_irvae_z2.yml',
#     ckpt_file='model_best.pkl',
#     root='../results'
# )

# # MNIST digits 0, 1, 5
# vae, cfg = load_pretrained(
#     identifier='MNIST/DIM_2/vae_mnist_015',
#     config_file='mnist_vae_z2.yml',
#     ckpt_file='model_best.pkl',
#     root='../results'
# )
# irvae, cfg = load_pretrained(
#     identifier='MNIST/DIM_2/irvae_mnist_015',
#     config_file='mnist_irvae_z2.yml',
#     ckpt_file='model_best.pkl',
#     root='../results'
# )

# # MNIST digits 0, 1, 3, 6, 7
# vae, cfg = load_pretrained(
#     identifier='MNIST/DIM_2/vae_mnist_01367',
#     config_file='mnist_vae_z2.yml',
#     ckpt_file='model_best.pkl',
#     root='../results'
# )
# irvae, cfg = load_pretrained(
#     identifier='MNIST/DIM_2/irvae_mnist_01367',
#     config_file='mnist_irvae_z2.yml',
#     ckpt_file='model_best.pkl',
#     root='../results'
# )

# MNIST digits 0, 1: the IRVAE and its pretrained-feature variant (identifiers contain `_01_isoreg_100`).
irvae, irvae_cfg = load_pretrained(
    identifier='MNIST/DIM_2/irvae_mnist_01_isoreg_100',
    config_file='mnist_irvae_z2.yml',
    ckpt_file='model_best.pkl',
    root='../results'
)

irvae_pretrain, irvae_pretrain_cfg = load_pretrained(
    identifier='MNIST/DIM_2/irvae_mnist_pretrain_01_isoreg_100',
    config_file='mnist_irvae_z2_pretrain.yml',
    ckpt_file='model_best.pkl',
    root='../results'
)

# Later cells read data settings from `cfg`. As before, this is the config of the last model loaded.
cfg = irvae_pretrain_cfg

irvae.to(device);
irvae_pretrain.to(device);

type(irvae_pretrain).__name__

## 2. Get Data Loader

In [None]:
# Build the validation loader, with the dataset root resolved relative to this notebook.
data_cfg = cfg['data']
val_data_cfg = data_cfg['validation']
val_data_cfg['root'] = '../dataset'
dl = get_dataloader(val_data_cfg)

## 3. Data Encoding with VAE and IRVAE

In [None]:
# num_points_for_each_class = 200
num_points_for_each_class = 2     # samples per class encoded by plot_irave_result
num_G_plots_for_each_class = 2    # per-class codes whose pullback metric sets the ellipse scale
label_unique = torch.unique(dl.dataset.targets)
print(label_unique)

# get model
from models.modules import Net

net = Net()
# map_location lets a checkpoint saved on another device load onto `device`.
net.load_state_dict(torch.load('../models/saved_model/simple_linear.pt', map_location=device))
net.to(device)
net.eval()  # used for inference only: disable any train-mode layers


class pretrained_enhanced_Net(torch.nn.Module):
    """Decoder followed by a network head: ``forward(z) = head(decoder(z))``.

    Passed to ``plot_irave_result`` as the "decoder", so the pullback metric is
    measured in the head's output space rather than in pixel space.

    Args:
        decoder: callable mapping latent codes to images; defaults to
            ``irvae_pretrain.decode``.
        head: module applied to the decoded images; defaults to ``net``
            (loaded from ``simple_linear.pt``, presumably a digit classifier — confirm).
    """

    def __init__(self, decoder=None, head=None):
        super().__init__()
        self.f1 = irvae_pretrain.decode if decoder is None else decoder
        self.f2 = net if head is None else head

    def forward(self, x):
        return self.f2(self.f1(x))


pre_trained_enhanced_net = pretrained_enhanced_Net()

## 4. Visualize Encodings

### 4.1 Plotting and Scoring Helpers

In [None]:
def _decision_boundary_points(classifier, x_grid, y_grid):
    """Locate where ``classifier``'s predicted label changes along each grid row.

    The grid is the Cartesian product of ``x_grid`` and ``y_grid``. Each row (fixed y)
    is scanned left to right, and a point is reported whenever its prediction differs
    from its left neighbour's. All grid points are classified in one ``predict`` call.

    Args:
        classifier: fitted estimator whose ``predict`` accepts an (N, 2) array.
        x_grid: 1-D array of x coordinates (scan direction).
        y_grid: 1-D array of y coordinates (one scan per value).

    Returns:
        (M, 2) float array of ``[x, y]`` boundary points, ordered by y, then x.
    """
    if len(x_grid) == 0 or len(y_grid) == 0:
        return np.empty((0, 2))
    xx, yy = np.meshgrid(x_grid, y_grid)  # both shaped (len(y_grid), len(x_grid))
    pred = classifier.predict(np.column_stack([xx.ravel(), yy.ravel()])).reshape(xx.shape)
    # A change between columns c and c+1 places the boundary point at x_grid[c + 1].
    rows, cols = np.nonzero(pred[:, 1:] != pred[:, :-1])
    return np.column_stack([x_grid[cols + 1], y_grid[rows]])


def plot_irave_result(encoder_network, decoder_network, number_of_oval=1, grid_step=0.01):
    """Plot 2-D latent codes with an LDA decision boundary and pullback-metric ellipses.

    The first ``num_points_for_each_class`` validation samples of each label in
    ``label_unique`` are encoded with ``encoder_network``. A Linear Discriminant
    Analysis classifier is fit on the codes. Its decision boundary is traced on a
    grid spanning the codes' bounding box, and the pullback Riemannian metric of
    ``decoder_network`` is drawn as ellipses at evenly spaced boundary points.

    Reads notebook globals ``dl``, ``label_unique``, ``num_points_for_each_class``,
    ``num_G_plots_for_each_class`` and ``device``.

    Args:
        encoder_network: callable mapping an image batch to (N, 2) latent codes.
        decoder_network: module whose pullback metric is drawn; its class name
            selects the plot title.
        number_of_oval: boundary points are sampled every
            ``len(boundary) // number_of_oval`` steps. Use 0 to draw no ellipses.
        grid_step: spacing of the grid used to trace the decision boundary.

    Returns:
        ``(homogeneity, silhouette)``: homogeneity of the LDA predictions w.r.t. the
        true labels, and the Euclidean silhouette score of the codes under the true labels.
    """
    decoder_name = type(decoder_network).__name__
    if decoder_name == 'IsotropicGaussian':
        title = 'Isometric Representation (IRVAE)'
    elif decoder_name == 'pretrained_enhanced_Net':
        title = 'Isometric Representation (IRVAE Pretrained)'
    else:
        title = 'Isometric Representation (IRVAE VIT Pretrained)'

    # Encode each class. The metric at a random subset of codes is used only to set the ellipse scale.
    z_, label_, G_ = [], [], []
    for label in label_unique:
        class_data = dl.dataset.data[dl.dataset.targets == label][:num_points_for_each_class]
        class_z = encoder_network(class_data.to(device))
        z_sampled = class_z[torch.randperm(len(class_z))[:num_G_plots_for_each_class]]
        G_.append(get_pullbacked_Riemannian_metric(decoder_network, z_sampled))
        z_.append(class_z)
        label_.append(label.repeat(class_z.size(0)))

    z = torch.cat(z_, dim=0).detach().cpu().numpy()
    labels = torch.cat(label_, dim=0).detach().cpu().numpy()
    colors = label_to_color(labels)
    G_sampled = torch.cat(G_, dim=0).detach().cpu()
    z_min, z_max = np.min(z, axis=0), np.max(z, axis=0)

    # Supervised linear classifier whose decision regions define the plotted boundary.
    classifier = LinearDiscriminantAnalysis()
    classifier.fit(z, labels)
    homo_score = homogeneity_score(labels, classifier.predict(z))
    # Sub-sampling all points (the former sample_size=N) is only a permutation, so it is omitted.
    sil_score = silhouette_score(z, labels, metric='euclidean')

    # NOTE(review): np.minimum(max, min) is just the per-axis minimum and may be negative.
    # It is kept so ellipse sizes match earlier figures; confirm the intended scale.
    z_scale = np.minimum(z_max, z_min)
    eig_mean = torch.linalg.svdvals(G_sampled).mean().item()
    scale = 0.1 * z_scale * np.sqrt(eig_mean)

    x_grid = np.arange(z_min[0], z_max[0], grid_step)
    y_grid = np.arange(z_min[1], z_max[1], grid_step)
    boundary_coord = _decision_boundary_points(classifier, x_grid, y_grid)

    plt.rc('font', size=12)
    size = 2
    alpha = 0.3
    _, ax = plt.subplots()

    if len(boundary_coord) > 0:
        ax.scatter(boundary_coord[:, 0], boundary_coord[:, 1], c='k', s=size)
        if number_of_oval > 0:
            # Evaluate the metric only where an ellipse is drawn. This assumes
            # get_pullbacked_Riemannian_metric treats each point independently.
            step = max(1, len(boundary_coord) // number_of_oval)
            oval_centers = boundary_coord[::step]
            oval_G = get_pullbacked_Riemannian_metric(
                decoder_network, torch.tensor(oval_centers, dtype=torch.float32).to(device)
            ).detach().cpu().numpy()
            for center, G in zip(oval_centers, oval_G):
                ax.add_artist(PD_metric_to_ellipse(np.linalg.inv(G), center, scale, fc='k', alpha=alpha))

    for label in label_unique:
        label = label.item()
        mask = labels == label
        ax.scatter(z[mask, 0], z[mask, 1], c=colors[mask] / 255, label=label, s=size)
    ax.legend()
    ax.axis('equal')
    ax.set_title(title)
    plt.show()

    return homo_score, sil_score


In [None]:
def print_score(score_dict):
    """Print ``score_dict`` as a one-row, pipe-separated table ending with a newline.

    Each column is as wide as the longer of its key and its value formatted with
    three decimals, so short keys no longer misalign the value row.

    Args:
        score_dict: mapping from a string column name to a number.
    """
    widths = {key: max(len(key), len(f"{value:.3f}")) for key, value in score_dict.items()}
    print("".join(f"{key:<{widths[key]}}|" for key in score_dict))
    print("".join("-" * widths[key] + "+" for key in score_dict))
    print("".join(f"{score_dict[key]:{widths[key]}.3f}|" for key in score_dict))

### 4.2 IRVAE

In [None]:
irvae_homo_score, irvae_sil_score = plot_irave_result(irvae.encode, irvae.decoder)
print_score({'homogeneity': irvae_homo_score, 'silhouette': irvae_sil_score})

### 4.3 IRVAE (Pretrained)

In [None]:
pretrained_irvae_homo_score, pretrained_irvae_sil_score = plot_irave_result(
    irvae_pretrain.encode, pre_trained_enhanced_net, number_of_oval=1
)
print_score({'homogeneity': pretrained_irvae_homo_score, 'silhouette': pretrained_irvae_sil_score})

### 4.4 IRVAE (ViT Pretrained)

In [None]:
class VIT_pretrained_decoder(torch.nn.Module):
    """IRVAE decoder followed by an ImageNet-pretrained ViT-B/16.

    Latent codes are decoded with ``irvae_pretrain.decode``, resized to 224x224,
    replicated to three channels, optionally normalised with ImageNet channel
    statistics, and passed through the ViT. The pullback metric of this module is
    therefore measured in the ViT's output space.

    Args:
        normalize: apply ImageNet mean/std normalisation before the ViT, the
            preprocessing the ``IMAGENET1K_V1`` weights expect. Pass ``False``
            to reproduce the earlier, un-normalised inputs.
    """

    # Channel statistics used by torchvision's ImageNet weight transforms.
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self, normalize=True):
        super(VIT_pretrained_decoder, self).__init__()
        self.f1 = irvae_pretrain.decode
        self.vit = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
        self.vit.to(device=device)
        # Frozen feature extractor. The metric presumably differentiates only w.r.t. the
        # latent input, so parameter gradients are not needed.
        self.vit.eval()
        self.vit.requires_grad_(False)
        self.normalize = normalize

    def forward(self, x):
        x = self.f1(x)
        # Bilinear interpolation requires a 4-D (N, C, H, W) batch.
        x = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
        # Replicate to 3 channels for the RGB ViT. Assumes a single-channel decoder output — TODO confirm.
        x = x.repeat(1, 3, 1, 1)
        if self.normalize:
            # Assumes decoded pixels lie in [0, 1] (like torchvision's ToTensor) — TODO confirm.
            mean = x.new_tensor(self.IMAGENET_MEAN).view(1, 3, 1, 1)
            std = x.new_tensor(self.IMAGENET_STD).view(1, 3, 1, 1)
            x = (x - mean) / std
        return self.vit(x)

vit_pretrained_decoder = VIT_pretrained_decoder()


In [None]:
irvae_homo_score_VIT, irvae_sil_score_VIT = plot_irave_result(
    irvae_pretrain.encode, vit_pretrained_decoder, number_of_oval=1
)

print_score({'homogeneity': irvae_homo_score_VIT, 'silhouette': irvae_sil_score_VIT})

### 4.5 Traditional Manifold Learning (Isomap)

In [None]:
from sklearn.manifold import Isomap

# Flatten each validation image into one row and embed the whole set in 2-D with Isomap.
images = dl.dataset.data
flat_images = images.view(images.shape[0], -1)
isomap_embedding = Isomap(n_components=2).fit_transform(flat_images)
X_transformed = torch.tensor(isomap_embedding, dtype=torch.float)

In [None]:
# Separate name so this cell does not overwrite the global `num_points_for_each_class`,
# which plot_irave_result reads on every call.
isomap_points_per_class = 200
z_ = []
label_ = []
for label in label_unique:
    class_z = X_transformed[dl.dataset.targets == label][:isomap_points_per_class]
    z_.append(class_z)
    label_.append(label.repeat(class_z.size(0)))

isomap_z_ = torch.cat(z_, dim=0).detach().cpu().numpy()
isomap_label_ = torch.cat(label_, dim=0).detach().cpu().numpy()
isomap_color_ = label_to_color(isomap_label_)

In [None]:
# `size` was only a local inside plot_irave_result. Define the marker size here (same value, 2).
marker_size = 2
fig, ax = plt.subplots()
for label in label_unique:
    label = label.item()
    mask = isomap_label_ == label
    ax.scatter(isomap_z_[mask, 0], isomap_z_[mask, 1], c=isomap_color_[mask] / 255, label=label, s=marker_size)
ax.legend()
ax.axis('equal')
ax.set_title('Manifold Learning (Isomap)')
# plt.savefig(f'../figure/Isomap{index}.png')
plt.show()

## 5. Interpolations

In [None]:
# Endpoint digits for the latent-space interpolation below.
label1, label2 = 0, 1

In [None]:
X, y = dl.dataset.data, dl.dataset.targets
# First validation sample of each chosen digit, kept as a batch of one.
data1, data2 = (X[y == digit][:1].to(device) for digit in (label1, label2))

In [None]:
# `vae` is only loaded by commented-out code in section 1. If it is missing, load the
# VAE for digits 0, 1, the same `_01` digit subset as `irvae`.
if 'vae' not in globals():
    vae, vae_cfg = load_pretrained(
        identifier='MNIST/DIM_2/vae_mnist_01',
        config_file='mnist_vae_z2.yml',
        ckpt_file='model_best.pkl',
        root='../results'
    )
vae.to(device)

# Only the codes are needed, so skip building an autograd graph.
with torch.no_grad():
    z1_irvae = irvae.encode(data1)
    z2_irvae = irvae.encode(data2)
    z1_vae = vae.encode(data1)
    z2_vae = vae.encode(data2)

In [None]:
def linear_interpolants(z_start, z_end, num_steps=20):
    """Return ``num_steps`` evenly spaced points from ``z_start`` to ``z_end`` (inclusive).

    Args:
        z_start, z_end: codes with a leading batch dimension of 1, e.g. shape (1, d).
        num_steps: number of interpolants.

    Returns:
        Tensor of shape (num_steps, d), starting at ``z_start`` and ending at ``z_end``.
    """
    t = torch.linspace(0, 1, num_steps, dtype=z_start.dtype, device=z_start.device).unsqueeze(1)
    return z_start + (z_end - z_start) * t


linterp_irvae = linear_interpolants(z1_irvae, z2_irvae)
linterp_vae = linear_interpolants(z1_vae, z2_vae)
x_interp_irvae = irvae.decode(linterp_irvae).detach().cpu()
x_interp_vae = vae.decode(linterp_vae).detach().cpu()

In [None]:
def encode_by_class(encode, points_per_class=200):
    """Encode the first ``points_per_class`` validation samples of every label in ``label_unique``.

    Args:
        encode: callable mapping an image batch on ``device`` to latent codes.
        points_per_class: number of samples taken per label.

    Returns:
        ``(codes, labels, colors)`` as numpy arrays: codes stacked class by class,
        their labels, and the matching ``label_to_color`` colours.
    """
    codes, labels = [], []
    with torch.no_grad():
        for label in label_unique:
            class_data = dl.dataset.data[dl.dataset.targets == label][:points_per_class]
            class_z = encode(class_data.to(device))
            codes.append(class_z)
            labels.append(label.repeat(class_z.size(0)))
    codes = torch.cat(codes, dim=0).cpu().numpy()
    labels = torch.cat(labels, dim=0).cpu().numpy()
    return codes, labels, label_to_color(labels)


def scatter_by_label(ax, z, labels, colors, size):
    """Scatter 2-D codes ``z`` on ``ax`` with one legend entry per label in ``label_unique``."""
    for label in label_unique:
        label = label.item()
        mask = labels == label
        ax.scatter(z[mask, 0], z[mask, 1], c=colors[mask] / 255, label=label, s=size)


# plot_irave_result keeps its codes local, so encode both models here (200 per class, as for Isomap).
vae_z_, vae_label_, vae_color_ = encode_by_class(vae.encode)
irvae_z_, irvae_label_, irvae_color_ = encode_by_class(irvae.encode)

f = plt.figure(3, figsize=(10,10))
plt.rc('font', size=12)
index = cfg['data']['training']['digits'].split('_')[1]
size = 5

ax1 = f.add_subplot(2, 2, 1)
ax2 = f.add_subplot(2, 2, 2)
ax3 = f.add_subplot(2, 2, 3)
ax4 = f.add_subplot(2, 2, 4)

scatter_by_label(ax1, vae_z_, vae_label_, vae_color_, size)
ax1.plot(linterp_vae[:, 0].detach().cpu(), linterp_vae[:, 1].detach().cpu(), linewidth=3, color='k')
ax1.set_aspect('equal')
ax1.set_title('VAE')

scatter_by_label(ax2, irvae_z_, irvae_label_, irvae_color_, size)
ax2.plot(linterp_irvae[:, 0].detach().cpu(), linterp_irvae[:, 1].detach().cpu(), '--', linewidth=3, color='k')
ax2.set_aspect('equal')
ax2.set_title('IRVAE')

# Bottom row: decoded interpolants for each model, 10 per grid row.
for ax, images in ((ax3, x_interp_vae), (ax4, x_interp_irvae)):
    ax.imshow(
        torchvision.utils.make_grid(images, nrow=10, value_range=(0, 1), pad_value=1).permute(1, 2, 0)
    )
    ax.axis('off')

f.supxlabel('Generated images from linear interpolants (from upper-left to lower-right)', y=0.385)
f.tight_layout(pad=2, h_pad=0)
plt.suptitle('Latent Space Linear Interpolations', fontsize=20, y=1)
os.makedirs('../figure', exist_ok=True)  # savefig does not create missing directories
plt.savefig(f'../figure/LSLI{index}.png', bbox_inches='tight')
plt.show()