In [1]:
import wandb
import sys
import matplotlib.pyplot as plt
import scprep
import pandas as pd
sys.path.append('../src/')
from evaluate import get_results
from omegaconf import OmegaConf
import numpy as np
import os
import glob
import demap
from tqdm import tqdm
from evaluation import compute_all_metrics, get_noiseless_name, get_ambient_name
import torch
from model import AEProb, Decoder

class Model():
    def __init__(self, encoder, decoder):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.encoder = encoder.to(self.device)
        self.decoder = decoder.to(self.device)
    def encode(self, x):
        return self.encoder.encode(x)
    def decode(self, x):
        return self.decoder(x)
    def eval(self):
        self.encoder.eval()
        self.decoder.eval()

In [3]:
root_path = '../affinity_matching_results_xingzhi/results/'
data_path1 = 'sepa_gaussian_noisy_42_groups_17580_2000_3_0.2_0.2_all1.00_bw1_knn5'

enc_path = os.path.join(root_path, data_path1, 'model.ckpt')
dec_path = os.path.join(root_path, data_path1, 'decoder.ckpt')

In [4]:
encoder_dict = torch.load(enc_path)
decoder_dict = torch.load(dec_path)

In [5]:
data_name = data_path1[14:-13]

In [6]:
probmtd = data_path1.split('_')[1]

In [7]:
data_root = '../synthetic_data2/'
data_path = os.path.join(data_root, data_name + '.npz')
noiseless_path = os.path.join(data_root, get_noiseless_name(data_name) + '.npz')
ambient_path = os.path.join(data_root, get_ambient_name(data_name) + '.npy')

In [21]:
encoder = AEProb(dim=100, emb_dim=2, layer_widths=[256, 128, 64], activation_fn=torch.nn.ReLU(), prob_method=probmtd, dist_reconstr_weights=[1.0,0.0,0.], )
encoder.load_state_dict(encoder_dict)
decoder = Decoder(dim=100, emb_dim=2, layer_widths=[256, 128, 64][::-1], activation_fn=torch.nn.ReLU())
decoder.load_state_dict(decoder_dict)

<All keys matched successfully>

In [22]:
model = Model(encoder, decoder)

In [23]:
res_dict = compute_all_metrics(model, data_path, noiseless_path, ambient_path)

In [24]:
res_dict

{'seedmethod': 'groups,42',
 'bcv': '0.2',
 'dropout': '0.2',
 'demap': 0.7505392434684071,
 'accuracy': 0.0416292262671728,
 'recon score': 0.7032789476825084}

In [25]:
root_path = '../affinity_matching_results_xingzhi/results/'
os.listdir(root_path)

['sepa_tstudent_noisy_43_groups_17580_2000_3_0.4_0.5_all1.00_bw1_knn5',
 'sepa_tstudent_noisy_44_groups_17580_2000_3_0.4_0.2_all1.00_bw1_knn5',
 'sepa_tstudent_noisy_46_groups_17580_2000_3_0.6_0.7_all1.00_bw1_knn5',
 'sepa_tstudent_noisy_45_groups_17580_2000_3_0.6_0.7_all1.00_bw1_knn5',
 'sepa_gaussian_noisy_46_groups_17580_2000_3_0.2_0.7_all1.00_bw1_knn5',
 'sepa_tstudent_noisy_42_paths_17580_2000_3_0.4_0.7_all1.00_bw1_knn5',
 'sepa_gaussian_noisy_46_paths_17580_2000_3_0.2_0.2_all1.00_bw1_knn5',
 'sepa_tstudent_noisy_43_paths_17580_2000_3_0.4_0.2_all1.00_bw1_knn5',
 'sepa_gaussian_noisy_46_paths_17580_2000_3_0.4_0.7_all1.00_bw1_knn5',
 'sepa_tstudent_noisy_46_groups_17580_2000_3_0.6_0.5_all1.00_bw1_knn5',
 'sepa_gaussian_noisy_42_paths_17580_2000_3_0.6_0.2_all1.00_bw1_knn5',
 'sepa_gaussian_noisy_44_groups_17580_2000_3_0.4_0.5_all1.00_bw1_knn5',
 'sepa_gaussian_noisy_42_groups_17580_2000_3_0.2_0.7_all1.00_bw1_knn5',
 'sepa_gaussian_noisy_45_groups_17580_2000_3_0.6_0.2_all1.00_bw1_knn5