In [31]:
import wandb
import sys
import matplotlib.pyplot as plt
import scprep
import pandas as pd
sys.path.append('../src/')
from evaluate import get_results
from omegaconf import OmegaConf
from main import load_data, make_model
import numpy as np
import os
import glob
import demap

In [2]:
# Initialize wandb (replace 'your_entity' and 'your_project' with your specific details)
wandb.login()
api = wandb.Api()

# Specify your entity, project, and sweep ID
entity = "xingzhis"
project = "dmae"
sweep_id = 'gutpmsw1'

# Fetch the sweep
sweep = api.sweep(f"{entity}/{project}/{sweep_id}")

run_ids = [run.id for run in sweep.runs]

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mxingzhis[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [57]:
from main import load_data, make_model
from data import dataloader_from_pc
from procrustes import Procrustes

from transformations import LogTransform, NonTransform, StandardScaler, MinMaxScaler, PowerTransformer, KernelTransform

from omegaconf import OmegaConf
import numpy as np
import os
import glob
from scipy.spatial.distance import pdist, squareform

def get_results(run):
    cfg = OmegaConf.create(run.config)
    folder_path = "../src/wandb/"
    try:
        folder_list = glob.glob(f"{folder_path}*{run.id}*")
        ckpt_files = glob.glob(f"{folder_list[0]}/files/*.ckpt")
        ckpt_path = ckpt_files[0]
    except:
        print(f"No checkpoint found for run {run.id}")
        return None, None, None
    allloader, _, X, phate_coords, colors, dist, pp = load_data(cfg, load_all=True)
    emb_dim = phate_coords.shape[1]
    data_path = os.path.join(cfg.data.root, cfg.data.name + cfg.data.filetype)
    data = np.load(data_path, allow_pickle=True)
    dist_std = np.std(data['dist'].flatten())
    model = make_model(cfg, X.shape[1], emb_dim, pp, dist_std, from_checkpoint=True, checkpoint_path=ckpt_path)
    model.eval()
    x_all = next(iter(allloader))['x']
    x_pred, z_pred = model(x_all)
    x_pred = x_pred.detach().cpu().numpy()
    z_pred = z_pred.detach().cpu().numpy()
    data_all = data
    data_path_train = os.path.join(cfg.data.root, cfg.data.name + cfg.data.filetype)
    train_mask = data_all['is_train']
    test_mask = ~data_all['is_train']
    procrustes = Procrustes()
    phate_proc_train, z_hat, disparity = procrustes.fit_transform(data_all['phate'][train_mask], z_pred[train_mask])
    zhat_all = procrustes.transform(z_pred)
    dist_pred = squareform(pdist(zhat_all))
    dist_true = squareform(pdist(data_all['phate']))
    test_test_mask = test_mask[:,None] * test_mask[None,:]
    test_train_mask = test_mask[:,None] * train_mask[None,:]
    train_train_mask = train_mask[:,None] * train_mask[None,:]
    test_all_mask = test_mask[:,None] * np.ones_like(test_mask)
    eps = 1e-10
    dist_mape_test_test = (np.abs(dist_true - dist_pred + eps) / (dist_true + eps) * test_test_mask).sum() / test_test_mask.sum()
    dist_mape_test_train = (np.abs(dist_true - dist_pred + eps) / (dist_true + eps) * test_train_mask).sum() / test_train_mask.sum()
    dist_mape_train_train = (np.abs(dist_true - dist_pred + eps) / (dist_true + eps) * train_train_mask).sum() / train_train_mask.sum()
    dist_mape_test_overall = (np.abs(dist_true - dist_pred + eps) / (dist_true + eps) * test_all_mask).sum() / test_all_mask.sum()
    dist_rmse_test_test = np.sqrt(((dist_true - dist_pred)**2 * test_test_mask).sum()/ test_test_mask.sum())
    dist_rmse_test_train = np.sqrt(((dist_true - dist_pred)**2 * test_train_mask).sum() / test_train_mask.sum())
    dist_rmse_train_train = np.sqrt(((dist_true - dist_pred)**2 * train_train_mask).sum() / train_train_mask.sum())
    test_rmse = np.sqrt((data_all['phate'][test_mask] - zhat_all[test_mask])**2).mean()
    res = dict(
        data=cfg.data.name,
        preprocess=cfg.data.preprocess,
        kernel=cfg.data.kernel.type if cfg.data.preprocess == 'kernel' else None,
        sigma=cfg.data.kernel.sigma if cfg.data.preprocess == 'kernel' else 0,
        dist_recon_weight = cfg.model.dist_reconstr_weights,
        model_type = cfg.model.type,
        dist_mape_test_test=dist_mape_test_test,
        dist_mape_test_train=dist_mape_test_train,
        dist_mape_test_overall=dist_mape_test_overall,
        dist_mape_train_train=dist_mape_train_train,
        dist_rmse_test_test=dist_rmse_test_test,
        dist_rmse_test_train=dist_rmse_test_train,
        dist_rmse_train_train=dist_rmse_train_train,
        test_rmse=test_rmse,
        train_mask=train_mask
    )
    plot_data = dict(
        phate_true = data_all['phate'][test_mask],
        phate_pred = zhat_all[test_mask],
        colors = data_all['colors'][test_mask],
        colors_train = data_all['colors'][train_mask],
        dist_true_test_test = dist_true[test_mask][:,test_mask],
        dist_pred_test_test = dist_pred[test_mask][:,test_mask],
        dist_true_test_train = dist_true[test_mask][:,train_mask],
        dist_pred_test_train = dist_pred[test_mask][:,train_mask],
        phate_true_train = data_all['phate'][train_mask],
        phate_pred_train = zhat_all[train_mask],
        dist_true_train_train = dist_true[train_mask][:,train_mask],
        dist_pred_train_train = dist_pred[train_mask][:,train_mask],
    )
    return res, plot_data, cfg

def rename_string(s):
    # Split the string into parts
    parts = s.split('_')
    
    # Replace "noisy" with "true"
    parts[0] = "true"
    
    # Remove the last two numbers before "all"
    new_parts = parts[:-3] + parts[-1:]
    
    # Reassemble the string
    new_s = '_'.join(new_parts)
    
    return new_s

def get_data_config(s):
    # Split the string into parts
    parts = s.split('_')
 
    
    seedmethod = parts[2]+','+parts[1]
    bcv=parts[-3]
    dropout=parts[-2]
    return seedmethod, bcv, dropout

In [37]:
res_list = []
for run in sweep.runs:
    res, plots, cfg = get_results(run)
    res_list.append(
        dict(
            run_id=run.id,
            res=res,
            plots=plots,
            cfg=cfg
        )
    )

In [13]:
res_list[0]['plots'].keys()

dict_keys(['phate_true', 'phate_pred', 'colors', 'colors_train', 'dist_true_test_test', 'dist_pred_test_test', 'dist_true_test_train', 'dist_pred_test_train', 'phate_true_train', 'phate_pred_train', 'dist_true_train_train', 'dist_pred_train_train'])

In [42]:
res_list[0]['res']['train_mask']

0.8

In [70]:
metric_res = []
for i in range(len(res_list)):
    datatrue = np.load("../synthetic_data/" + rename_string(res_list[i]['res']['data']) + '.npz')
    datatrue_train = datatrue['data'][datatrue['is_train']]
    datatrue_test = datatrue['data'][~datatrue['is_train']]
    phate_train = res_list[i]['plots']['phate_true_train']
    phate_test = res_list[i]['plots']['phate_true']
    our_train = res_list[i]['plots']['phate_pred_train']
    our_test = res_list[i]['plots']['phate_pred']
    demap_phate_train = demap.DEMaP(datatrue_train, phate_train)
    demap_our_train = demap.DEMaP(datatrue_train, our_train)
    demap_phate_test = demap.DEMaP(datatrue_test, phate_test)
    demap_our_test = demap.DEMaP(datatrue_test, our_test)
    acc_our_train = 1 - res_list[i]['res']['dist_mape_train_train']
    acc_our_test = 1 - res_list[i]['res']['dist_mape_test_test']
    name = res_list[i]['res']['data']
    seedmethod, bcv, dropout = get_data_config(res_list[i]['res']['data'])
    metric_res.append(dict(
        dataset=seedmethod,
        bcv=bcv,
        dropout=dropout,
        acc_our_train=acc_our_train,
        acc_our_test=acc_our_test,
        demap_phate_train=demap_phate_train,
        demap_our_train=demap_our_train,
        demap_our_test=demap_our_test,
    ))

In [71]:
res_df = pd.DataFrame(metric_res)

In [72]:
res_df.sort_values(['dataset', 'bcv', 'dropout'])

Unnamed: 0,dataset,bcv,dropout,acc_our_train,acc_our_test,demap_phate_train,demap_our_train,demap_our_test
3,"groups,42",0.2,0.5,0.733464,0.684586,0.822904,0.843121,0.754205
6,"groups,42",0.4,0.7,0.724118,0.640231,0.743948,0.759209,0.689044
5,"groups,43",0.2,0.5,0.657552,0.607719,0.602204,0.581684,0.579229
2,"groups,43",0.4,0.7,0.615755,0.57272,0.571864,0.625931,0.681179
4,"paths,42",0.2,0.5,0.899945,0.788487,0.760293,0.764848,0.685775
0,"paths,42",0.4,0.7,0.858103,0.668171,0.544,0.557985,0.517827
7,"paths,43",0.2,0.5,0.880611,0.700362,0.778929,0.782762,0.565475
1,"paths,43",0.4,0.7,0.792614,0.486109,0.567703,0.57848,0.40926


In [73]:
res_df.to_csv("synth_results.csv", index=False)

In [74]:
res_df = res_df.sort_values(['dataset', 'bcv', 'dropout'])
# Round all numeric columns to 3 decimals, excluding strings
rounded_res_df = res_df.select_dtypes(include=['float64']).round(3)
# Re-attach the non-numeric columns to the rounded DataFrame
for col in res_df.select_dtypes(exclude=['float64']).columns:
    rounded_res_df[col] = res_df[col]

# Reorder columns to match original DataFrame
rounded_res_df = rounded_res_df[res_df.columns]
rounded_res_df

Unnamed: 0,dataset,bcv,dropout,acc_our_train,acc_our_test,demap_phate_train,demap_our_train,demap_our_test
3,"groups,42",0.2,0.5,0.733,0.685,0.823,0.843,0.754
6,"groups,42",0.4,0.7,0.724,0.64,0.744,0.759,0.689
5,"groups,43",0.2,0.5,0.658,0.608,0.602,0.582,0.579
2,"groups,43",0.4,0.7,0.616,0.573,0.572,0.626,0.681
4,"paths,42",0.2,0.5,0.9,0.788,0.76,0.765,0.686
0,"paths,42",0.4,0.7,0.858,0.668,0.544,0.558,0.518
7,"paths,43",0.2,0.5,0.881,0.7,0.779,0.783,0.565
1,"paths,43",0.4,0.7,0.793,0.486,0.568,0.578,0.409
