In [1]:
import os.path
import sys
if ".." not in sys.path:
    sys.path.insert(0, "..")

In [26]:
from datasets import OrganoidDataset

dataset = OrganoidDataset(data_dir='/data/PycharmProjects/cytof_benchmark/data/organoids')

import torch

X_val, y_val = dataset.val
X_val_batches = torch.split(torch.Tensor(X_val).to('cuda'), split_size_or_sections=32*1024)

In [3]:
y_val

Unnamed: 0,index,cell_type,day
0,125964,Enterocyte,2
1,573521,Enterocyte,7
2,1112662,Tuft,5
3,1058543,Tuft,2
4,1031398,Stem,7
...,...,...,...
234490,857805,Stem,4
234491,167125,Enterocyte,2
234492,680457,Enteroendocrine,6
234493,139701,Enterocyte,2


In [49]:
from configs.pbt import beta_vae_pbt,dbeta_vae_pbt,wae_pbt,hs_vae_pbt

model_names = ["BetaVAE", "DBetaVAE", "WAE_MMD", "HyperSphericalVAE"]
configs = [beta_vae_pbt.get_config(),dbeta_vae_pbt.get_config(),wae_pbt.get_config(),hs_vae_pbt.get_config()]

dataset_names = ['OrganoidDataset', 'CafDataset', 'ChallengeDataset']
data_dirs = ['/data/PycharmProjects/cytof_benchmark/data/organoids',
             '/data/PycharmProjects/cytof_benchmark/data/caf',
             '/data/PycharmProjects/cytof_benchmark/data/breast_cancer_challenge']
features = [41, 44, 37]

In [5]:
import glob
bench_dir = "/home/egor/Desktop/ray_tune/pbt_bench/"
checkpoint_files = glob.glob(bench_dir + "*/*/model.pth")
checkpoint_files

['/home/egor/Desktop/ray_tune/pbt_bench/HyperSphericalVAE/ChallengeDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/HyperSphericalVAE/CafDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/HyperSphericalVAE/OrganoidDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/DBetaVAE/ChallengeDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/DBetaVAE/CafDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/DBetaVAE/OrganoidDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/BetaVAE/ChallengeDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/BetaVAE/CafDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/BetaVAE/OrganoidDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/WAE_MMD/ChallengeDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/WAE_MMD/CafDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/WAE_MMD/OrganoidDataset/model.pth']

In [6]:
checkpoint_dict = dict()
for checkpoint_file in checkpoint_files:
    dataset_name = checkpoint_file.split('/')[-2]
    model = checkpoint_file.split('/')[-3]
    if dataset_name not in checkpoint_dict:
        checkpoint_dict[dataset_name]=dict()
    checkpoint_dict[dataset_name][model]= checkpoint_file
checkpoint_dict

{'ChallengeDataset': {'HyperSphericalVAE': '/home/egor/Desktop/ray_tune/pbt_bench/HyperSphericalVAE/ChallengeDataset/model.pth',
  'DBetaVAE': '/home/egor/Desktop/ray_tune/pbt_bench/DBetaVAE/ChallengeDataset/model.pth',
  'BetaVAE': '/home/egor/Desktop/ray_tune/pbt_bench/BetaVAE/ChallengeDataset/model.pth',
  'WAE_MMD': '/home/egor/Desktop/ray_tune/pbt_bench/WAE_MMD/ChallengeDataset/model.pth'},
 'CafDataset': {'HyperSphericalVAE': '/home/egor/Desktop/ray_tune/pbt_bench/HyperSphericalVAE/CafDataset/model.pth',
  'DBetaVAE': '/home/egor/Desktop/ray_tune/pbt_bench/DBetaVAE/CafDataset/model.pth',
  'BetaVAE': '/home/egor/Desktop/ray_tune/pbt_bench/BetaVAE/CafDataset/model.pth',
  'WAE_MMD': '/home/egor/Desktop/ray_tune/pbt_bench/WAE_MMD/CafDataset/model.pth'},
 'OrganoidDataset': {'HyperSphericalVAE': '/home/egor/Desktop/ray_tune/pbt_bench/HyperSphericalVAE/OrganoidDataset/model.pth',
  'DBetaVAE': '/home/egor/Desktop/ray_tune/pbt_bench/DBetaVAE/OrganoidDataset/model.pth',
  'BetaVAE': '/

In [51]:
import models
import pandas as pd
import datasets
save_dir ='/data/PycharmProjects/cytof_benchmark/results/mse_data'

for dataset_name,feature_count,data_dir in zip(dataset_names,features,data_dirs):

    dataset_class = getattr(datasets, dataset_name)
    dataset = dataset_class(data_dir=data_dir)
    X_val, y_val = dataset.val
    X_val_batches = torch.split(torch.Tensor(X_val).to('cuda'), split_size_or_sections=32*1024)

    for config, model_name in zip(configs,model_names):
        with config.unlocked():
            config.in_features = feature_count

        model_class = getattr(models, model_name)
        model = model_class(config).to('cuda')

        checkpoint = torch.load(checkpoint_dict[dataset_name][model_name])
        model.load_state_dict(checkpoint['model'])

        decoded_batches = []
        latent_batches = []
        with torch.no_grad():
            for X_batch in X_val_batches:
                decoded_batch = model.forward(X_batch)[0].to('cpu')
                decoded_batches.append(decoded_batch)

                latent_batch = model.latent(X_batch).to('cpu')
                latent_batches.append(latent_batch)

        decoded = torch.cat(decoded_batches)
        latent = torch.cat(latent_batches)

        latent_df = pd.DataFrame(latent.numpy(), columns=["VAE{}".format(i) for i in range(1, latent.shape[1] + 1)])
        mse_df = pd.DataFrame(((decoded-torch.Tensor(X_val))**2).numpy(), columns=list(dataset.variables))

        latent_df.to_csv(os.path.join(save_dir,f'{dataset_name}_{model_name}_latent.csv'))
        mse_df.to_csv(os.path.join(save_dir,f'{dataset_name}_{model_name}_mse.csv'))



In [53]:
import models
import pandas as pd
import datasets
save_dir ='/data/PycharmProjects/cytof_benchmark/results/latent_data'

for dataset_name,feature_count,data_dir in zip(dataset_names,features,data_dirs):

    dataset_class = getattr(datasets, dataset_name)
    dataset = dataset_class(data_dir=data_dir)

    X_train,y_train = dataset.train
    X_val, y_val = dataset.val

    X_train_batches = torch.split(torch.Tensor(X_train).to('cuda'), split_size_or_sections=32*1024)
    X_val_batches = torch.split(torch.Tensor(X_val).to('cuda'), split_size_or_sections=32*1024)

    for config, model_name in zip(configs,model_names):
        with config.unlocked():
            config.in_features = feature_count

        model_class = getattr(models, model_name)
        model = model_class(config).to('cuda')

        checkpoint = torch.load(checkpoint_dict[dataset_name][model_name])
        model.load_state_dict(checkpoint['model'])

        latent_batches = []
        with torch.no_grad():
            for X_batch in X_val_batches:
                latent_batch = model.latent(X_batch).to('cpu')
                latent_batches.append(latent_batch)

        latent = torch.cat(latent_batches)

        latent_df = pd.DataFrame(latent.numpy(), columns=["VAE{}".format(i) for i in range(1, latent.shape[1] + 1)])
        latent_df.to_csv(os.path.join(save_dir,f'{dataset_name}_{model_name}_val_latent.csv'))


        latent_batches = []
        with torch.no_grad():
            for X_batch in X_train_batches:
                latent_batch = model.latent(X_batch).to('cpu')
                latent_batches.append(latent_batch)

        latent = torch.cat(latent_batches)

        latent_df = pd.DataFrame(latent.numpy(), columns=["VAE{}".format(i) for i in range(1, latent.shape[1] + 1)])
        latent_df.to_csv(os.path.join(save_dir,f'{dataset_name}_{model_name}_train_latent.csv'))

