In [1]:
import sys
if ".." not in sys.path:
    sys.path.insert(0, "..")

In [32]:
from datasets import OrganoidDataset

dataset = OrganoidDataset(data_dir='/data/PycharmProjects/cytof_benchmark/data/organoids')

import torch

X_val, y_val = dataset.val
X_val_batches = torch.split(torch.Tensor(X_val).to('cuda'), split_size_or_sections=32*1024)

In [14]:
y_val

Unnamed: 0,index,cell_type,day
0,125964,Enterocyte,2
1,573521,Enterocyte,7
2,1112662,Tuft,5
3,1058543,Tuft,2
4,1031398,Stem,7
...,...,...,...
234490,857805,Stem,4
234491,167125,Enterocyte,2
234492,680457,Enteroendocrine,6
234493,139701,Enterocyte,2


In [4]:
from configs.pbt import beta_vae_pbt,dbeta_vae_pbt,wae_pbt,hs_vae_pbt

In [5]:
model_names = ["BetaVAE", "DBetaVAE", "WAE_MMD", "HyperSphericalVAE"]
configs = [beta_vae_pbt.get_config(),dbeta_vae_pbt.get_config(),wae_pbt.get_config(),hs_vae_pbt.get_config()]

dataset_names = ['OrganoidDataset', 'CafDataset', 'ChallengeDataset']
data_dirs = ['data/organoids', 'data/caf', 'data/breast_cancer_challenge']
features = [41, 44, 37]

In [6]:
import glob
bench_dir = "/home/egor/Desktop/ray_tune/pbt_bench/"
checkpoint_files = glob.glob(bench_dir + "*/*/model.pth")

In [7]:
checkpoint_files

['/home/egor/Desktop/ray_tune/pbt_bench/HyperSphericalVAE/ChallengeDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/HyperSphericalVAE/CafDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/HyperSphericalVAE/OrganoidDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/DBetaVAE/ChallengeDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/DBetaVAE/CafDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/DBetaVAE/OrganoidDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/BetaVAE/ChallengeDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/BetaVAE/CafDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/BetaVAE/OrganoidDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/WAE_MMD/ChallengeDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/WAE_MMD/CafDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/WAE_MMD/OrganoidDataset/model.pth']

In [8]:
checkpoint_dict = dict()
for checkpoint_file in checkpoint_files:
    dataset = checkpoint_file.split('/')[-2]
    model = checkpoint_file.split('/')[-3]
    if dataset not in checkpoint_dict:
        checkpoint_dict[dataset]=dict()
    checkpoint_dict[dataset][model]= checkpoint_file

In [9]:
checkpoint_dict

{'ChallengeDataset': {'HyperSphericalVAE': '/home/egor/Desktop/ray_tune/pbt_bench/HyperSphericalVAE/ChallengeDataset/model.pth',
  'DBetaVAE': '/home/egor/Desktop/ray_tune/pbt_bench/DBetaVAE/ChallengeDataset/model.pth',
  'BetaVAE': '/home/egor/Desktop/ray_tune/pbt_bench/BetaVAE/ChallengeDataset/model.pth',
  'WAE_MMD': '/home/egor/Desktop/ray_tune/pbt_bench/WAE_MMD/ChallengeDataset/model.pth'},
 'CafDataset': {'HyperSphericalVAE': '/home/egor/Desktop/ray_tune/pbt_bench/HyperSphericalVAE/CafDataset/model.pth',
  'DBetaVAE': '/home/egor/Desktop/ray_tune/pbt_bench/DBetaVAE/CafDataset/model.pth',
  'BetaVAE': '/home/egor/Desktop/ray_tune/pbt_bench/BetaVAE/CafDataset/model.pth',
  'WAE_MMD': '/home/egor/Desktop/ray_tune/pbt_bench/WAE_MMD/CafDataset/model.pth'},
 'OrganoidDataset': {'HyperSphericalVAE': '/home/egor/Desktop/ray_tune/pbt_bench/HyperSphericalVAE/OrganoidDataset/model.pth',
  'DBetaVAE': '/home/egor/Desktop/ray_tune/pbt_bench/DBetaVAE/OrganoidDataset/model.pth',
  'BetaVAE': '/

In [29]:
organoid_models = checkpoint_dict['OrganoidDataset']

In [34]:
organoid_models

{'HyperSphericalVAE': '/home/egor/Desktop/ray_tune/pbt_bench/HyperSphericalVAE/OrganoidDataset/model.pth',
 'DBetaVAE': '/home/egor/Desktop/ray_tune/pbt_bench/DBetaVAE/OrganoidDataset/model.pth',
 'BetaVAE': '/home/egor/Desktop/ray_tune/pbt_bench/BetaVAE/OrganoidDataset/model.pth',
 'WAE_MMD': '/home/egor/Desktop/ray_tune/pbt_bench/WAE_MMD/OrganoidDataset/model.pth'}

In [38]:
import models
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import os

save_dir ='/data/PycharmProjects/cytof_benchmark/results/latents'

for config, model_name in zip(configs,model_names):
    with config.unlocked():
        config.in_features = 41

    model_class = getattr(models, model_name)
    model = model_class(config).to('cuda')

    plot_name = 'OrganoidDataset/{}_dim{}'.format(model_name,config.latent_dim)

    checkpoint = torch.load(organoid_models[model_name])
    model.load_state_dict(checkpoint['model'])
    latent_vals = []
    with torch.no_grad():
        for X_batch in X_val_batches:
            latent_val_batch = model.latent(X_batch).to('cpu')
            latent_vals.append(latent_val_batch)

    latent_val = torch.cat(latent_vals)
    latent_df = pd.DataFrame(latent_val.numpy(), columns=["VAE{}".format(i) for i in range(1, latent_val.shape[1] + 1)])

    sns.set(rc={'figure.figsize':(12,12)})
    dot_plot = sns.scatterplot(x="VAE1", y="VAE2",
                            data=pd.concat([latent_df, y_val], axis=1).head(100000),
                          hue='cell_type',
                          legend=True,
                          alpha=0.5)
    plt.savefig(os.path.join(save_dir, plot_name+'_cell_type_latent.png'))
    plt.clf()

    sns.set(rc={'figure.figsize':(12,12)})
    dot_plot = sns.scatterplot(x="VAE1", y="VAE2",
                            data=pd.concat([latent_df, y_val], axis=1).head(100000),
                          hue='day',
                          legend=True,
                          alpha=0.5)
    plt.savefig(os.path.join(save_dir, plot_name+'_day_latent.png'))
    plt.clf()

        # UMAP plot
    if config.latent_dim > 2:
        import umap
        import umap.plot
        from bokeh.io import save

        mapper = umap.UMAP(low_memory=False,n_jobs=64).fit(latent_df)
        umap_points = 10000
        mapper.embedding_ = mapper.embedding_[:umap_points]

        p = umap.plot.interactive(mapper,
                                  labels=y_val['cell_type'].head(umap_points),
                                  hover_data=y_val.head(umap_points),
                                  point_size=2,
                                  interactive_text_search=True,
                                  interactive_text_search_columns=['cell_type', "day"])
        save(p, filename=os.path.join(save_dir, plot_name+'latent_umap.html'))

    # Spherical plots
    if config.model == 'HyperSphericalVAE' and config.latent_dim == 3:
        import cartopy
        import geopandas
        from shapely.geometry import Point
        import plotly.express as px

        sphere = geopandas.GeoSeries(latent_df.apply(Point, axis=1), crs=cartopy.crs.Geocentric())
        projected = sphere.to_crs(cartopy.crs.Mollweide())
        data = pd.DataFrame({'x': projected.x,
                             'y': projected.y,
                             'cell_type': y_val['cell_type'],
                             'day': y_val['day']})
        # 2d projection plot
        sns.scatterplot(x="x", y="y",
                   data=data.head(10000),
                   hue='cell_type',  # color by cluster
                   legend=True,
                   alpha=0.5)
        plt.axis('equal')
        plt.savefig(os.path.join(save_dir, plot_name+'latent_projection.png'))
        plt.clf()
        # 3d point cloud plot
        plot_data = pd.concat([latent_df, y_val], axis=1).head(10000)
        fig = px.scatter_3d(plot_data, x='VAE1', y='VAE2', z='VAE3', color='cell_type')
        fig.update_traces(marker_size=2)
        fig.write_html(os.path.join(save_dir, plot_name+'latent_point_cloud.html'))


<class 'models.hyperspherical_vae_extra.distributions.hyperspherical_uniform.HypersphericalUniform'> does not define `arg_constraints`. Please set `arg_constraints = {}` or initialize the distribution with `validate_args=False` to turn off validation.


save() called but no resources were supplied and output_file(...) was never called, defaulting to resources.CDN


save() called but no title was supplied and output_file(...) was never called, using default title 'Bokeh Plot'


The array interface is deprecated and will no longer work in Shapely 2.0. Convert the '.coords' to a numpy array instead.



<Figure size 1200x1200 with 0 Axes>

In [39]:
from datasets import CafDataset

dataset = CafDataset(data_dir='/data/PycharmProjects/cytof_benchmark/data/caf')

X_val, y_val = dataset.val
X_val_batches = torch.split(torch.Tensor(X_val).to('cuda'), split_size_or_sections=32*1024)
y_val

Unnamed: 0,index,Date,Patient,Culture,Treatment,Concentration,Replicate,Cell_type,pPKCa,Plate,Batch
0,2840618,20210518,5,PDO,S,2,C,PDOs,48.466751,SLV,1
1,6904274,20210330,27,PDOF,VS,3,A,PDOs,,SLV,1
2,5478019,20210420,75,PDO,S,1,A,PDOs,,SLV,1
3,7230165,20210525,216,PDOF,VS,2,C,Fibs,8.220644,SLV,1
4,7643022,20210524,109,PDOF,VS,2,A,Fibs,27.408182,SLV,1
...,...,...,...,...,...,...,...,...,...,...,...
4951201,5232280,20210608,11,F,S,1,C,Fibs,56.341122,SLV,1
4951202,17554307,20210524,109,PDOF,F,3,B,PDOs,70.352081,CSF,1
4951203,8748653,20210518,5,PDOF,L,1,A,Fibs,4.410349,SLV,1
4951204,17634951,20210524,109,F,F,4,B,Fibs,24.214802,CSF,1


In [40]:
caf_models = checkpoint_dict['CafDataset']
caf_models

{'HyperSphericalVAE': '/home/egor/Desktop/ray_tune/pbt_bench/HyperSphericalVAE/CafDataset/model.pth',
 'DBetaVAE': '/home/egor/Desktop/ray_tune/pbt_bench/DBetaVAE/CafDataset/model.pth',
 'BetaVAE': '/home/egor/Desktop/ray_tune/pbt_bench/BetaVAE/CafDataset/model.pth',
 'WAE_MMD': '/home/egor/Desktop/ray_tune/pbt_bench/WAE_MMD/CafDataset/model.pth'}

In [43]:
for config, model_name in zip(configs,model_names):
    with config.unlocked():
        config.in_features = 44
    model_class = getattr(models, model_name)
    model = model_class(config).to('cuda')

    plot_name = 'CAFDataset/{}_dim{}'.format(model_name,config.latent_dim)

    checkpoint = torch.load(caf_models[model_name])
    model.load_state_dict(checkpoint['model'])
    latent_vals = []
    with torch.no_grad():
        for X_batch in X_val_batches:
            latent_val_batch = model.latent(X_batch).to('cpu')
            latent_vals.append(latent_val_batch)

    latent_val = torch.cat(latent_vals)
    latent_df = pd.DataFrame(latent_val.numpy(), columns=["VAE{}".format(i) for i in range(1, latent_val.shape[1] + 1)])

    sns.set(rc={'figure.figsize':(12,12)})
    dot_plot = sns.scatterplot(x="VAE1", y="VAE2",
                            data=pd.concat([latent_df, y_val], axis=1).head(100000),
                          hue='Culture',
                          legend=True,
                          alpha=0.5)
    plt.savefig(os.path.join(save_dir, plot_name+'_culture_latent.png'))
    plt.clf()

    sns.set(rc={'figure.figsize':(12,12)})
    dot_plot = sns.scatterplot(x="VAE1", y="VAE2",
                            data=pd.concat([latent_df, y_val], axis=1).head(100000),
                          hue='Treatment',
                          legend=True,
                          alpha=0.5)
    plt.savefig(os.path.join(save_dir, plot_name+'_treatment_latent.png'))
    plt.clf()

    # UMAP plot
    if config.latent_dim > 2:
        import umap
        import umap.plot
        from bokeh.io import save

        mapper = umap.UMAP(low_memory=False,n_jobs=64).fit(latent_df)
        umap_points = 10000
        mapper.embedding_ = mapper.embedding_[:umap_points]

        p = umap.plot.interactive(mapper,
                                  labels=y_val['Treatment'].head(umap_points),
                                  hover_data=y_val.head(umap_points),
                                  point_size=2,
                                  interactive_text_search=True,
                                  interactive_text_search_columns=['Treatment', "Culture"])
        save(p, filename=os.path.join(save_dir, plot_name+'latent_umap.html'))

    # Spherical plots
    if config.model == 'HyperSphericalVAE' and config.latent_dim == 3:
        import cartopy
        import geopandas
        from shapely.geometry import Point
        import plotly.express as px

        sphere = geopandas.GeoSeries(latent_df.apply(Point, axis=1), crs=cartopy.crs.Geocentric())
        projected = sphere.to_crs(cartopy.crs.Mollweide())
        data = pd.DataFrame({'x': projected.x,
                             'y': projected.y,
                             'Treatment': y_val['Treatment'],
                             'Culture': y_val['Culture']})
        # 2d projection plot
        sns.scatterplot(x="x", y="y",
                   data=data.head(10000),
                   hue='Treatment',  # color by cluster
                   legend=True,
                   alpha=0.5)
        plt.axis('equal')
        plt.savefig(os.path.join(save_dir, plot_name+'latent_projection.png'))
        plt.clf()
        # 3d point cloud plot
        plot_data = pd.concat([latent_df, y_val], axis=1).head(10000)
        fig = px.scatter_3d(plot_data, x='VAE1', y='VAE2', z='VAE3', color='Treatment')
        fig.update_traces(marker_size=2)
        fig.write_html(os.path.join(save_dir, plot_name+'latent_point_cloud.html'))


<class 'models.hyperspherical_vae_extra.distributions.hyperspherical_uniform.HypersphericalUniform'> does not define `arg_constraints`. Please set `arg_constraints = {}` or initialize the distribution with `validate_args=False` to turn off validation.


Exited at iteration 20 with accuracies 
[0.01159898 0.01145096 0.01146064]
not reaching the requested tolerance 1e-08.


save() called but no resources were supplied and output_file(...) was never called, defaulting to resources.CDN


save() called but no title was supplied and output_file(...) was never called, using default title 'Bokeh Plot'


The array interface is deprecated and will no longer work in Shapely 2.0. Convert the '.coords' to a numpy array instead.



<Figure size 1200x1200 with 0 Axes>

In [44]:
from datasets import ChallengeDataset

dataset = ChallengeDataset(data_dir='/data/PycharmProjects/cytof_benchmark/data/breast_cancer_challenge')

X_val, y_val = dataset.val
X_val_batches = torch.split(torch.Tensor(X_val).to('cuda'), split_size_or_sections=32*1024)
y_val

Unnamed: 0,index,treatment,cell_line,time,cellID,fileID
0,855587,full,MCF10F,0.0,28353,2000
1,364135,iPKC,HBL100,60.0,3115,1531
2,356676,iPI3K,OCUBM,0.0,4744,2464
3,248118,iMEK,HBL100,40.0,4960,1513
4,142322,iMEK,EVSAT,0.0,898,1892
...,...,...,...,...,...,...
2204463,224723,iMEK,184A1,13.0,2845,2676
2204464,121261,EGF,CAL851,60.0,36,1451
2204465,350172,iPKC,CAL51,0.0,6718,1217
2204466,350589,iPKC,MCF10F,17.0,2249,395


In [45]:
challenge_models = checkpoint_dict['ChallengeDataset']
challenge_models

{'HyperSphericalVAE': '/home/egor/Desktop/ray_tune/pbt_bench/HyperSphericalVAE/ChallengeDataset/model.pth',
 'DBetaVAE': '/home/egor/Desktop/ray_tune/pbt_bench/DBetaVAE/ChallengeDataset/model.pth',
 'BetaVAE': '/home/egor/Desktop/ray_tune/pbt_bench/BetaVAE/ChallengeDataset/model.pth',
 'WAE_MMD': '/home/egor/Desktop/ray_tune/pbt_bench/WAE_MMD/ChallengeDataset/model.pth'}

In [47]:
for config, model_name in zip(configs,model_names):
    with config.unlocked():
        config.in_features = 37
    model_class = getattr(models, model_name)
    model = model_class(config).to('cuda')

    plot_name = 'ChallengeDataset/{}_dim{}'.format(model_name,config.latent_dim)

    checkpoint = torch.load(challenge_models[model_name])
    model.load_state_dict(checkpoint['model'])
    latent_vals = []
    with torch.no_grad():
        for X_batch in X_val_batches:
            latent_val_batch = model.latent(X_batch).to('cpu')
            latent_vals.append(latent_val_batch)

    latent_val = torch.cat(latent_vals)
    latent_df = pd.DataFrame(latent_val.numpy(), columns=["VAE{}".format(i) for i in range(1, latent_val.shape[1] + 1)])

    sns.set(rc={'figure.figsize':(12,12)})
    dot_plot = sns.scatterplot(x="VAE1", y="VAE2",
                            data=pd.concat([latent_df, y_val], axis=1).head(100000),
                          hue='cell_line',
                          legend=True,
                          alpha=0.5)
    plt.savefig(os.path.join(save_dir, plot_name+'_cell_line_latent.png'))
    plt.clf()

    sns.set(rc={'figure.figsize':(12,12)})
    dot_plot = sns.scatterplot(x="VAE1", y="VAE2",
                            data=pd.concat([latent_df, y_val], axis=1).head(100000),
                          hue='treatment',
                          legend=True,
                          alpha=0.5)
    plt.savefig(os.path.join(save_dir, plot_name+'_treatment_latent.png'))
    plt.clf()

    # UMAP plot
    if config.latent_dim > 2:
        import umap
        import umap.plot
        from bokeh.io import save

        mapper = umap.UMAP(low_memory=False,n_jobs=64).fit(latent_df)
        umap_points = 10000
        mapper.embedding_ = mapper.embedding_[:umap_points]

        p = umap.plot.interactive(mapper,
                                  labels=y_val['cell_line'].head(umap_points),
                                  hover_data=y_val.head(umap_points),
                                  point_size=2,
                                  interactive_text_search=True,
                                  interactive_text_search_columns=['cell_line', "treatment"])
        save(p, filename=os.path.join(save_dir, plot_name+'latent_umap.html'))

    # Spherical plots
    if config.model == 'HyperSphericalVAE' and config.latent_dim == 3:
        import cartopy
        import geopandas
        from shapely.geometry import Point
        import plotly.express as px

        sphere = geopandas.GeoSeries(latent_df.apply(Point, axis=1), crs=cartopy.crs.Geocentric())
        projected = sphere.to_crs(cartopy.crs.Mollweide())
        data = pd.DataFrame({'x': projected.x,
                             'y': projected.y,
                             'treatment': y_val['treatment'],
                             'cell_line': y_val['cell_line']})
        # 2d projection plot
        sns.scatterplot(x="x", y="y",
                   data=data.head(10000),
                   hue='cell_line',  # color by cluster
                   legend=True,
                   alpha=0.5)
        plt.axis('equal')
        plt.savefig(os.path.join(save_dir, plot_name+'latent_projection.png'))
        plt.clf()
        # 3d point cloud plot
        plot_data = pd.concat([latent_df, y_val], axis=1).head(10000)
        fig = px.scatter_3d(plot_data, x='VAE1', y='VAE2', z='VAE3', color='cell_line')
        fig.update_traces(marker_size=2)
        fig.write_html(os.path.join(save_dir, plot_name+'latent_point_cloud.html'))


<class 'models.hyperspherical_vae_extra.distributions.hyperspherical_uniform.HypersphericalUniform'> does not define `arg_constraints`. Please set `arg_constraints = {}` or initialize the distribution with `validate_args=False` to turn off validation.


Exited at iteration 20 with accuracies 
[0.01132264 0.0113911  0.01136983]
not reaching the requested tolerance 1e-08.


save() called but no resources were supplied and output_file(...) was never called, defaulting to resources.CDN


save() called but no title was supplied and output_file(...) was never called, using default title 'Bokeh Plot'


The array interface is deprecated and will no longer work in Shapely 2.0. Convert the '.coords' to a numpy array instead.



<Figure size 1200x1200 with 0 Axes>