In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np

embeddings_names = [
    "blood_cell_scrna",
    "lymphoma",
    "cifar_100",
    "mnist",
]
sigs = [
    [(1, 2), (0, 2), (-1, 2), (-1, 2), (-1, 2)],
    [(1, 2), (1, 2)],
    [(1, 2), (1, 2), (1, 2), (1, 2)],
    [(1, 2), (0, 2), (-1, 2)],
]
n_trials = 10
sets = ["train", "test"]
datasets = ["X", "y"]

bad = []
for embedding in embeddings_names:
    for trial in range(n_trials):
        for set_name in sets:
            for dataset in datasets:
                my_data = np.load(f"../data/{embedding}/embeddings/{dataset}_{set_name}_{trial}.npy")
                if np.isnan(my_data).any():
                    bad.append((embedding, trial, set_name, dataset))
                    print(embedding, trial, set_name, dataset)
                # print(my_data.shape)
print(bad)

[]


In [None]:
import embedders
import pandas as pd
from tqdm.notebook import tqdm
import torch

# N_SAMPLES = 100 # Takes ~20 secs
N_SAMPLES = 1_000 # Takes ~5 mins
# N_SAMPLES = float("inf")  # Takes ~1 hour
MAX_DEPTH = 5
N_FEATURES = "d_choose_2"

results = []
my_tqdm = tqdm(total=len(embeddings_names) * n_trials)
for embedding, sig in zip(embeddings_names, sigs):
    pm = embedders.manifolds.ProductManifold(signature=sig, device="cuda")
    for trial in range(n_trials):
        X_train = np.load(f"../data/{embedding}/embeddings/X_train_{trial}.npy")
        y_train = np.load(f"../data/{embedding}/embeddings/y_train_{trial}.npy")
        X_test = np.load(f"../data/{embedding}/embeddings/X_test_{trial}.npy")
        y_test = np.load(f"../data/{embedding}/embeddings/y_test_{trial}.npy")

        # Randomly subsample
        if len(X_train) > N_SAMPLES:
            idx = np.random.choice(X_train.shape[0], N_SAMPLES, replace=False)
            X_train = X_train[idx]
            y_train = y_train[idx]

        if len(X_test) > N_SAMPLES:
            idx = np.random.choice(X_test.shape[0], N_SAMPLES, replace=False)
            X_test = X_test[idx]
            y_test = y_test[idx]
        
        # Make tensors
        X_train = torch.tensor(X_train, dtype=torch.float32, device="cuda")
        y_train = torch.tensor(y_train, dtype=torch.long, device="cuda")
        X_test = torch.tensor(X_test, dtype=torch.float32, device="cuda")
        y_test = torch.tensor(y_test, dtype=torch.long, device="cuda")
        
        # Get A_train and A_test
        D_train = pm.pdist2(X_train)
        max_train_dist = D_train[D_train.isfinite()].max()
        D_train /= max_train_dist
        A_train = embedders.predictors.kappa_gcn.get_A_hat(torch.exp(-D_train))
        A_test = embedders.predictors.kappa_gcn.get_A_hat(torch.exp(-pm.pdist2(X_test) / max_train_dist))

        res = embedders.benchmarks.benchmark(
            X=None,
            y=None,
            X_train=X_train,
            X_test=X_test,
            y_train=y_train,
            y_test=y_test,
            pm=pm,
            A_train=A_train,
            A_test=A_test,
            # models=["sklearn_dt", "product_dt"],
            # max_depth=MAX_DEPTH,
            # batch_size=1,
            # n_features=N_FEATURES,
            device="cuda",
            task="classification",
            score=["accuracy", "f1-micro"],
        )
        res["embedding"] = embedding
        res["trial"] = trial

        results.append(res)
        my_tqdm.update(1)

results = pd.DataFrame(results)

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

Exception in thread Thread-4:
Traceback (most recent call last):
  File "/home/phil/miniconda3/envs/embedders/lib/python3.9/threading.py", line 950, in _bootstrap_inner
    self.run()
  File "/home/phil/miniconda3/envs/embedders/lib/python3.9/site-packages/tqdm/_monitor.py", line 69, in run
    instances = self.get_instances()
  File "/home/phil/miniconda3/envs/embedders/lib/python3.9/site-packages/tqdm/_monitor.py", line 49, in get_instances
    return [i for i in self.tqdm_cls._instances.copy()
  File "/home/phil/miniconda3/envs/embedders/lib/python3.9/_weakrefset.py", line 93, in copy
    return self.__class__(self)
  File "/home/phil/miniconda3/envs/embedders/lib/python3.9/_weakrefset.py", line 51, in __init__
    self.update(data)
  File "/home/phil/miniconda3/envs/embedders/lib/python3.9/_weakrefset.py", line 120, in update
    for element in other:
  File "/home/phil/miniconda3/envs/embedders/lib/python3.9/_weakrefset.py", line 61, in __iter__
    for itemref in self.data:
Runti

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

In [5]:
results.to_csv("embedders/data/results_icml/vae.tsv", sep="\t", index=False)

In [None]:
results.groupby("embedding").mean()

Unnamed: 0_level_0,sklearn_dt_accuracy,sklearn_dt_f1-micro,sklearn_dt_time,sklearn_rf_accuracy,sklearn_rf_f1-micro,sklearn_rf_time,product_dt_accuracy,product_dt_f1-micro,product_dt_time,product_rf_accuracy,...,kappa_gcn_accuracy,kappa_gcn_f1-micro,kappa_gcn_time,product_mlr_accuracy,product_mlr_f1-micro,product_mlr_time,trial,ps_svm_accuracy,ps_svm_f1-micro,ps_svm_time
embedding,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
blood_cell_scrna,0.1754,0.1754,0.009795,0.1746,0.1746,0.143847,0.1442,0.1442,0.030121,0.1575,...,0.1246,0.1246,61.697719,0.1065,0.1065,11.371682,4.5,,,
cifar_100,0.0841,0.0841,0.006801,0.0987,0.0987,0.283782,0.0959,0.0959,0.023256,0.102,...,0.0542,0.0542,81.909793,0.0574,0.0574,11.120485,4.5,,,
lymphoma,0.8043,0.8043,0.004803,0.8086,0.8086,0.508018,0.83,0.83,0.036812,0.8245,...,0.5578,0.5578,45.832055,0.5977,0.5977,6.296381,4.5,,,
mnist,0.2796,0.2796,0.005322,0.3435,0.3435,0.123957,0.2971,0.2971,0.023986,0.3364,...,0.1138,0.1138,45.947609,0.1769,0.1769,7.121199,4.5,0.120333,0.120333,2480.684172
