In [4]:
import sys
if ".." not in sys.path:
    sys.path.insert(0, "..")

import glob

latents_dir = "/data/PycharmProjects/cytof_benchmark/results/latent_data/"
latent_files = glob.glob(latents_dir + "*/*/*/val.csv")

dataset_dirs = {
    'OrganoidDataset':'/data/PycharmProjects/cytof_benchmark/data/organoids',
    'CafDataset':'/data/PycharmProjects/cytof_benchmark/data/caf',
    'ChallengeDataset':'/data/PycharmProjects/cytof_benchmark/data/breast_cancer_challenge',
}

latents_list = list()
for latent_file in latent_files:
    dataset = latent_file.split('/')[-2]
    model = latent_file.split('/')[-3]
    dim = latent_file.split('/')[-4]
    latents_list.append((dataset,model,dim,latent_file))
latents_list

[('CafDataset',
  'HyperSphericalVAE',
  'dim5',
  '/data/PycharmProjects/cytof_benchmark/results/latent_data/dim5/HyperSphericalVAE/CafDataset/val.csv'),
 ('ChallengeDataset',
  'HyperSphericalVAE',
  'dim5',
  '/data/PycharmProjects/cytof_benchmark/results/latent_data/dim5/HyperSphericalVAE/ChallengeDataset/val.csv'),
 ('OrganoidDataset',
  'HyperSphericalVAE',
  'dim5',
  '/data/PycharmProjects/cytof_benchmark/results/latent_data/dim5/HyperSphericalVAE/OrganoidDataset/val.csv'),
 ('CafDataset',
  'DBetaVAE',
  'dim5',
  '/data/PycharmProjects/cytof_benchmark/results/latent_data/dim5/DBetaVAE/CafDataset/val.csv'),
 ('ChallengeDataset',
  'DBetaVAE',
  'dim5',
  '/data/PycharmProjects/cytof_benchmark/results/latent_data/dim5/DBetaVAE/ChallengeDataset/val.csv'),
 ('OrganoidDataset',
  'DBetaVAE',
  'dim5',
  '/data/PycharmProjects/cytof_benchmark/results/latent_data/dim5/DBetaVAE/OrganoidDataset/val.csv'),
 ('CafDataset',
  'WAE_MMD',
  'dim5',
  '/data/PycharmProjects/cytof_benchmark/

In [9]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
import datasets

metadata_dict = dict()
save_dir = Path('/data/PycharmProjects/cytof_benchmark/results/latent_plots')

for dataset_name,model_name,dim_name,latent_path in tqdm(latents_list):
    if dataset_name not in metadata_dict:
        dataset_class = getattr(datasets, dataset_name)
        dataset_dir = dataset_dirs[dataset_name]

        dataset = dataset_class(data_dir=dataset_dir)
        metadata_dict[dataset_name] = dataset.val[1]

    y_val = metadata_dict[dataset_name]
    latent_df = pd.read_csv(latent_path,index_col=0)
    latent_dim = int(dim_name[-1])

    for variable_name in y_val.columns:

        sns.set(rc={'figure.figsize':(12,12)})
        dot_plot = sns.scatterplot(x="VAE1", y="VAE2",
                                   data=pd.concat([latent_df, y_val], axis=1).head(100000),
                                   hue=variable_name,
                                   legend=True,
                                   alpha=0.5)
        plot_file = save_dir / dim_name / model_name / dataset_name / (variable_name +'_VAE12.png')
        plot_file.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(plot_file)
        plt.clf()

        # UMAP plot
    if latent_dim > 2 or model_name == 'HyperSphericalVAE':
        import umap
        import umap.plot
        from bokeh.io import save
        umap_points = 10000
        u = pd.DataFrame(umap.UMAP(low_memory=False, n_jobs=8).fit_transform(latent_df[:umap_points]),columns=['UMAP1','UMAP2'])
        for variable_name in y_val.columns:
            sns.set(rc={'figure.figsize':(12,12)})
            dot_plot = sns.scatterplot(x='UMAP1', y='UMAP2',
                                   data=pd.concat([u, y_val], axis=1).head(umap_points),
                                   hue=variable_name,
                                   legend=True,
                                   alpha=0.5)
            plt.savefig(save_dir / dim_name / model_name / dataset_name / (variable_name +'_umap.png'))
            plt.clf()

    # Spherical plots
    if model_name == 'HyperSphericalVAE' and latent_dim == 2:
        import cartopy
        import geopandas
        from shapely.geometry import Point
        import plotly.express as px

        sphere = geopandas.GeoSeries(latent_df.apply(Point, axis=1), crs=cartopy.crs.Geocentric())
        projected = sphere.to_crs(cartopy.crs.Mollweide())
        for variable_name in y_val.columns:
            data = pd.DataFrame({'x': projected.x,
                                'y': projected.y,
                                variable_name: y_val[variable_name]
                                 })

            # 2d projection plot
            sns.scatterplot(x="x", y="y",
                    data=data.head(10000),
                    hue=variable_name,  # color by cluster
                    legend=True,
                    alpha=0.5)
            plt.axis('equal')
            plt.savefig(save_dir / dim_name / model_name / dataset_name / (variable_name +'_projected.png'))
            plt.clf()
            # 3d point cloud plot
            plot_data = pd.concat([latent_df, y_val], axis=1).head(10000)
            fig = px.scatter_3d(plot_data, x='VAE1', y='VAE2', z='VAE3', color=variable_name)
            fig.update_traces(marker_size=2)
            fig.write_html(save_dir / dim_name / model_name / dataset_name / (variable_name +'_point-cloud.html'))

  arr = construct_1d_object_array_from_listlike(values)

The array interface is deprecated and will no longer work in Shapely 2.0. Convert the '.coords' to a numpy array instead.


The array interface is deprecated and will no longer work in Shapely 2.0. Convert the '.coords' to a numpy array instead.

100%|██████████| 36/36 [46:20<00:00, 77.24s/it]   


<Figure size 1200x1200 with 0 Axes>