In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import pandas as pd
import torchvision

from sklearn import manifold
import plotly.express as px
import pandas as pd
from sklearn.decomposition import PCA

from src import _PATH_DATA, _PATH_MODELS, _PROJECT_ROOT
from src.data.make_dataset import CelebADataModule
from src.models.models import  CViTVAE, ConvCVAE

In [2]:
# Build the CelebA data module used throughout the notebook (batch size 500;
# num_workers=0 presumably means loading in the main process) and run its
# setup step, which is presumably required before test_dataloader() is called.
celeb = CelebADataModule(num_workers=0, batch_size=500)
celeb.setup()

In [3]:
# Load the two trained conditional VAEs and switch them to eval mode.
# Forward-slash paths work on every OS (Windows accepts them too); backslash
# paths only resolve on Windows and "\m"/"\C" are invalid string escapes.
model = CViTVAE.load_from_checkpoint("../models/CViTVAE2022-04-29-1735/CViTVAE-epoch=174.ckpt")
model_2 = ConvCVAE.load_from_checkpoint("../models/ConvCVAE2022-04-30-1854/ConvCVAE-epoch=349.ckpt")
model.eval()
model_2.eval()
test_loader = celeb.test_dataloader()
# One-shot iterator over the test set: it is exhausted after a single pass,
# so loops that must survive re-running should iterate `test_loader` instead.
iterator = iter(test_loader)

In [4]:
# Encode the whole test set with both models.
#   z      : encodings from the ViT-VAE (`model`)
#   z2     : encodings from the ConvCVAE (`model_2`)
#   labels : class index per image (argmax of the condition vector)
# Iterating `test_loader` directly (rather than a one-shot iterator) keeps this
# cell idempotent: re-running it re-reads the data instead of silently
# producing empty lists.
z = []
z2 = []
labels_all = []
with torch.no_grad():  # inference only; no autograd graph needed
    for img, batch_labels in test_loader:
        cond = batch_labels.to(dtype=torch.float)
        z.append(model.encoding(img, cond))
        # Must use model_2 here; encoding with `model` made z2 a copy of z.
        z2.append(model_2.encoding(img, cond))
        labels_all.append(batch_labels.argmax(dim=1))

z = torch.cat(z)
z2 = torch.cat(z2)
labels = torch.cat(labels_all)

In [5]:
# Cumulative explained variance of a PCA fitted on the ViT-VAE test-set
# encodings `z`; the figure is also exported to ../outputs.
pca = PCA()
z_ = pca.fit_transform(z)
cum_var = np.cumsum(pca.explained_variance_ratio_)

# Plot against 1..n so the k-th point is the variance captured by the first
# k components (plotting the bare array would start the axis at 0).
fig = px.line(
    x=np.arange(1, len(cum_var) + 1),
    y=cum_var,
    width=800,
    labels={"x": "Number of Principal Components", "y": "Cumulative Explained Variance"},
    title="ViT-VAE",
)
fig.update_layout(showlegend=False)
fig.show()
fig.write_image("../outputs/PCA_ViTVAE_encode.png")
fig.write_image("../outputs/PCA_ViTVAE_encode.svg")

In [15]:
# Draw latent samples from the ViT-VAE for each of the six one-hot conditions
# and inspect their PCA spectrum.
# Rows of `z` come in blocks of 2000 per condition, ordered by one-hot index
# 0..5 (mblond, mbrown, mblack, wblond, wbrown, wblack); the t-SNE cell
# relies on this order.
model = CViTVAE.load_from_checkpoint("../models/CViTVAE2022-04-29-1735/CViTVAE-epoch=174.ckpt")
model.eval()

# NOTE(review): assumes sample_latent_space draws from torch's global RNG — TODO confirm.
torch.manual_seed(0)
n_per_condition = 2000
conditions = torch.eye(6, dtype=torch.float)  # row i is the one-hot vector for condition i
with torch.no_grad():
    z = torch.cat([model.sample_latent_space(n_per_condition, cond) for cond in conditions])

pca = PCA()
z_ = pca.fit_transform(z)
cum_var = np.cumsum(pca.explained_variance_ratio_)

# Plot against 1..n so the k-th point is the variance captured by the first
# k components (plotting the bare array would start the axis at 0).
fig = px.line(
    x=np.arange(1, len(cum_var) + 1),
    y=cum_var,
    width=800,
    labels={"x": "Number of Principal Components", "y": "Cumulative Explained Variance"},
    title="ViT-VAE",
)
fig.update_layout(showlegend=False)
fig.show()
# fig.write_image("../outputs/PCA_ViTVAE_sample.png")
# fig.write_image("../outputs/PCA_ViTVAE_sample.svg")

In [14]:
# 2-D t-SNE of the first 200 PCA components of the ViT-VAE samples.
# `z_` must come from the ViT-VAE sampling cell: 2000 rows per condition,
# in one-hot order mblond, mbrown, mblack, wblond, wbrown, wblack.
condition_names = ["mblond", "mbrown", "mblack", "wblond", "wbrown", "wblack"]
n_per_condition = 2000
assert len(z_) == n_per_condition * len(condition_names), (
    "z_ does not hold the 6 x 2000 sampled latents; run the sampling cell first"
)

tsne = manifold.TSNE(n_components=2, init='random', random_state=0, learning_rate="auto")
x_tsne = tsne.fit_transform(z_[:, :200])

df = pd.DataFrame(x_tsne, columns=["x", "y"])
# String labels give a discrete colour legend; integer labels would be drawn
# on a continuous colour scale, which is misleading for categorical groups.
df["label"] = np.repeat(condition_names, n_per_condition)

fig = px.scatter(df, x="x", y="y", color="label", width=800, height=800, title="ViT-VAE")
fig.show()
# Save alongside the other figures (previously written to the CWD).
fig.write_image("../outputs/T-SNE_ViTVAE_sample.png")
fig.write_image("../outputs/T-SNE_ViTVAE_sample.svg")

In [7]:
# Draw latent samples from the convolutional CVAE for each of the six one-hot
# conditions and inspect their PCA spectrum.
# Rows of `z` come in blocks of 2000 per condition, ordered by one-hot index
# 0..5 (mblond, mbrown, mblack, wblond, wbrown, wblack); the t-SNE cell
# relies on this order.
model = ConvCVAE.load_from_checkpoint("../models/ConvCVAE2022-04-30-1854/ConvCVAE-epoch=349.ckpt")
model.eval()

# NOTE(review): assumes sample_latent_space draws from torch's global RNG — TODO confirm.
torch.manual_seed(0)
n_per_condition = 2000
conditions = torch.eye(6, dtype=torch.float)  # row i is the one-hot vector for condition i
with torch.no_grad():
    z = torch.cat([model.sample_latent_space(n_per_condition, cond) for cond in conditions])

pca = PCA()
z_ = pca.fit_transform(z)
cum_var = np.cumsum(pca.explained_variance_ratio_)

# Plot against 1..n so the k-th point is the variance captured by the first
# k components (plotting the bare array would start the axis at 0).
fig = px.line(
    x=np.arange(1, len(cum_var) + 1),
    y=cum_var,
    width=800,
    labels={"x": "Number of Principal Components", "y": "Cumulative Explained Variance"},
    title="VAE",
)
fig.update_layout(showlegend=False)
fig.show()
# fig.write_image("../outputs/PCA_VAE_sample.png")
# fig.write_image("../outputs/PCA_VAE_sample.svg")

In [12]:
# 2-D t-SNE of the first 100 PCA components of the ConvCVAE samples.
# `z_` must come from the ConvCVAE sampling cell: 2000 rows per condition,
# in one-hot order mblond, mbrown, mblack, wblond, wbrown, wblack.
condition_names = ["mblond", "mbrown", "mblack", "wblond", "wbrown", "wblack"]
n_per_condition = 2000
assert len(z_) == n_per_condition * len(condition_names), (
    "z_ does not hold the 6 x 2000 sampled latents; run the sampling cell first"
)

tsne = manifold.TSNE(n_components=2, init='random', random_state=0, learning_rate="auto")
x_tsne = tsne.fit_transform(z_[:, :100])

df = pd.DataFrame(x_tsne, columns=["x", "y"])
# String labels give a discrete colour legend; integer labels would be drawn
# on a continuous colour scale, which is misleading for categorical groups.
df["label"] = np.repeat(condition_names, n_per_condition)

fig = px.scatter(df, x="x", y="y", color="label", width=800, height=800, title="VAE")
fig.show()
fig.write_image("../outputs/T-SNE_VAE_sample.png")
fig.write_image("../outputs/T-SNE_VAE_sample.svg")