In [1]:
import sys
if ".." not in sys.path:
    sys.path.insert(0, "..")

In [8]:
from configs.pbt import beta_vae_pbt,dbeta_vae_pbt,wae_pbt,hs_vae_pbt

model_configs = {
    "BetaVAE":beta_vae_pbt.get_config(),
    "DBetaVAE":dbeta_vae_pbt.get_config(),
    "WAE_MMD":wae_pbt.get_config(),
    "HyperSphericalVAE":hs_vae_pbt.get_config(),
}

dataset_dirs = {
    'OrganoidDataset':'/data/PycharmProjects/cytof_benchmark/data/organoids',
    'CafDataset':'/data/PycharmProjects/cytof_benchmark/data/caf',
    'ChallengeDataset':'/data/PycharmProjects/cytof_benchmark/data/breast_cancer_challenge',
}

In [9]:
import glob
bench_dir = "/home/egor/Desktop/ray_tune/pbt_bench/"
checkpoint_files = glob.glob(bench_dir + "*/*/*/model.pth")

In [10]:
checkpoint_files

['/home/egor/Desktop/ray_tune/pbt_bench/dim5/HyperSphericalVAE/ChallengeDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/dim5/HyperSphericalVAE/CafDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/dim5/HyperSphericalVAE/OrganoidDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/dim5/DBetaVAE/ChallengeDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/dim5/DBetaVAE/CafDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/dim5/DBetaVAE/OrganoidDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/dim5/BetaVAE/ChallengeDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/dim5/BetaVAE/CafDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/dim5/BetaVAE/OrganoidDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/dim5/WAE_MMD/ChallengeDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/dim5/WAE_MMD/CafDataset/model.pth',
 '/home/egor/Desktop/ray_tune/pbt_bench/dim5/WAE_MMD/OrganoidDataset/model.pth',
 '/home/ego

In [11]:
checkpoint_list = list()
for checkpoint_file in checkpoint_files:
    dataset = checkpoint_file.split('/')[-2]
    model = checkpoint_file.split('/')[-3]
    dim = checkpoint_file.split('/')[-4]
    checkpoint_list.append((dataset,model,dim,checkpoint_file))

In [12]:
checkpoint_list

[('ChallengeDataset',
  'HyperSphericalVAE',
  'dim5',
  '/home/egor/Desktop/ray_tune/pbt_bench/dim5/HyperSphericalVAE/ChallengeDataset/model.pth'),
 ('CafDataset',
  'HyperSphericalVAE',
  'dim5',
  '/home/egor/Desktop/ray_tune/pbt_bench/dim5/HyperSphericalVAE/CafDataset/model.pth'),
 ('OrganoidDataset',
  'HyperSphericalVAE',
  'dim5',
  '/home/egor/Desktop/ray_tune/pbt_bench/dim5/HyperSphericalVAE/OrganoidDataset/model.pth'),
 ('ChallengeDataset',
  'DBetaVAE',
  'dim5',
  '/home/egor/Desktop/ray_tune/pbt_bench/dim5/DBetaVAE/ChallengeDataset/model.pth'),
 ('CafDataset',
  'DBetaVAE',
  'dim5',
  '/home/egor/Desktop/ray_tune/pbt_bench/dim5/DBetaVAE/CafDataset/model.pth'),
 ('OrganoidDataset',
  'DBetaVAE',
  'dim5',
  '/home/egor/Desktop/ray_tune/pbt_bench/dim5/DBetaVAE/OrganoidDataset/model.pth'),
 ('ChallengeDataset',
  'BetaVAE',
  'dim5',
  '/home/egor/Desktop/ray_tune/pbt_bench/dim5/BetaVAE/ChallengeDataset/model.pth'),
 ('CafDataset',
  'BetaVAE',
  'dim5',
  '/home/egor/Deskto

In [19]:
from pathlib import Path
from tqdm import tqdm
import torch
import pandas as pd
import datasets
import models


save_dir = Path('/data/PycharmProjects/cytof_benchmark/results')

for dataset_name,model_name,dim_name,model_checkpoint_path in tqdm(checkpoint_list):
    dataset_class = getattr(datasets, dataset_name)
    dataset_dir = dataset_dirs[dataset_name]

    dataset = dataset_class(data_dir=dataset_dir)
    dataset_features = dataset.variables.shape[0]
    latent_dim = int(dim_name[-1])
    if model_name == 'HyperSphericalVAE':
        latent_dim = latent_dim+1

    model_class = getattr(models, model_name)
    config = model_configs[model_name]
    with config.unlocked():
        config.in_features = dataset_features
        config.latent_dim = latent_dim

    model = model_class(config).to('cuda')
    model_checkpoint = torch.load(model_checkpoint_path)
    model.load_state_dict(model_checkpoint['model'])

    splits = ['train','val', 'test']

    for split in splits:
        X,y = getattr(dataset, split)
        X_batches = torch.split(torch.Tensor(X).to('cuda'), split_size_or_sections=32*1024)

        latent_vals = []
        decoded_batches = []
        with torch.no_grad():
            for X_batch in X_batches:
                latent_val_batch = model.latent(X_batch).to('cpu')
                latent_vals.append(latent_val_batch)

                decoded_batch = model.forward(X_batch)[0].to('cpu')
                decoded_batches.append(decoded_batch)

        latent_val = torch.cat(latent_vals)
        decoded = torch.cat(decoded_batches)

        latent_df = pd.DataFrame(latent_val.numpy(), columns=["VAE{}".format(i) for i in range(1, latent_val.shape[1] + 1)])
        latent_df_file = save_dir / 'latent_data' / dim_name / model_name / dataset_name / (split+'.csv')
        latent_df_file.parent.mkdir(parents=True, exist_ok=True)
        latent_df.to_csv(latent_df_file)

        mse_df = pd.DataFrame(((decoded-torch.Tensor(X))**2).numpy(), columns=list(dataset.variables))
        mse_df_file = save_dir / 'mse_data' / dim_name / model_name / dataset_name / (split+'.csv')
        mse_df_file.parent.mkdir(parents=True, exist_ok=True)
        mse_df.to_csv(mse_df_file)

100%|██████████| 36/36 [2:39:30<00:00, 265.86s/it]
