In [1]:
import wandb
import sys
import matplotlib.pyplot as plt
import scprep
import pandas as pd
sys.path.append('../src/')
from evaluate import get_results
from omegaconf import OmegaConf
import numpy as np
import os
import glob
import demap
from tqdm import tqdm
from evaluation import compute_all_metrics, get_noiseless_name, get_ambient_name
import torch
from model import AEProb, Decoder

class Model():
    """Pair a trained encoder with its decoder on a single torch device.

    The device is CUDA when ``torch.cuda.is_available()`` is true and CPU
    otherwise; both networks are moved there once, at construction time.
    Inputs given to ``encode``/``decode`` are passed through as-is (they are
    not moved to ``self.device`` here).
    """

    def __init__(self, encoder, decoder):
        target = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.device = target
        self.encoder = encoder.to(target)
        self.decoder = decoder.to(target)

    def encode(self, x):
        """Embed ``x`` using the encoder's own ``encode`` method (not its forward)."""
        embedding = self.encoder.encode(x)
        return embedding

    def decode(self, x):
        """Run ``x`` through the decoder's forward pass and return the result."""
        reconstruction = self.decoder(x)
        return reconstruction

    def eval(self):
        """Switch both networks to evaluation mode (returns None)."""
        for network in (self.encoder, self.decoder):
            network.eval()

In [2]:
# Name of the trained run to evaluate: a "sepa_<prob_method>_a<alpha>_knn<knn>_" prefix
# followed by the dataset file name (the next cell parses these pieces apart).
# Alternative dataset noted by the author: noisy_1_groups_17580_3000_1_0.18_0.5_all
string = (
    'sepa_gaussian_jsd_a1.0_knn5_'
    'noisy_3_groups_17580_3000_1_0.25_0.5_all.npz'
)

In [3]:
import re

# Run names look like: sepa_<prob_method>_a<alpha>_knn<knn>_<noisy_path>
# `\w+` also matches underscores, so prob_method may itself contain '_' (e.g. 'gaussian_jsd').
# Because `\w+` is greedy, it takes the longest prefix for which the remainder
# ("_a<digits/dots>_knn<digits>_<anything>") still matches.
pattern = r"sepa_(?P<prob_method>\w+)_a(?P<alpha>[\d.]+)_knn(?P<knn>\d+)_(?P<noisy_path>.+)"

match = re.search(pattern, string)
if match is None:
    # Fail here with context instead of a NameError on prob_method/noisy_path in later cells.
    raise ValueError(
        f"Run name {string!r} does not match the expected pattern {pattern!r}. "
        "Please check the string format."
    )

# Parsed fields (all strings; alpha and knn are not converted to numbers).
prob_method = match.group("prob_method")
alpha = match.group("alpha")
knn = match.group("knn")
noisy_path = match.group("noisy_path")

print(f"prob_method: {prob_method}")
print(f"alpha: {alpha}")
print(f"knn: {knn}")
print(f"noisy_path: {noisy_path}")


prob_method: gaussian_jsd
alpha: 1.0
knn: 5
noisy_path: noisy_3_groups_17580_3000_1_0.25_0.5_all.npz


In [4]:
# Root directory holding one sub-directory per trained run. Defaults to the original
# cluster location; set the DMAE_RESULTS_ROOT environment variable to run elsewhere.
root_path = os.environ.get(
    'DMAE_RESULTS_ROOT', '/gpfs/gibbs/pi/krishnaswamy_smita/dl2282/dmae/results'
)
data_path1 = string  # run directory name under root_path
run_dir = os.path.join(root_path, data_path1)
enc_path = os.path.join(run_dir, 'model.ckpt')
dec_path = os.path.join(run_dir, 'decoder.ckpt')

In [5]:
# Load the saved encoder/decoder state dicts. map_location='cpu' lets checkpoints that
# were saved from a GPU run be opened on a CPU-only machine; load_state_dict later copies
# these tensors into freshly built modules, and Model() moves them to its device.
# NOTE(review): torch.load unpickles arbitrary objects — only use on trusted checkpoints.
encoder_dict = torch.load(enc_path, map_location='cpu')
decoder_dict = torch.load(dec_path, map_location='cpu')

In [6]:
# Dataset name without its '.npz' extension (the data-path cell re-appends it).
# Strip the suffix only when present: blind slicing ([:-4]) would silently chop
# characters off a name that lacks it.
data_name = noisy_path[:-len('.npz')] if noisy_path.endswith('.npz') else noisy_path

In [7]:
# Short alias for the probability method parsed from the run name; passed to AEProb below.
probmtd = prob_method

In [8]:
# Synthetic datasets live next to the repo; each run's noisy data has a matching
# noiseless (.npz) and ambient (.npy) file whose names are derived by helpers
# from the evaluation module.
data_root = '../synthetic_data3/'
noiseless_name = get_noiseless_name(data_name)
ambient_name = get_ambient_name(data_name)

data_path = os.path.join(data_root, data_name + '.npz')
noiseless_path = os.path.join(data_root, noiseless_name + '.npz')
ambient_path = os.path.join(data_root, ambient_name + '.npy')

In [9]:
# Sanity check: show the dataset name derived from the run name.
data_name

'noisy_3_groups_17580_3000_1_0.25_0.5_all'

In [10]:
# Architecture hyper-parameters shared by encoder and decoder. They must match the
# trained checkpoints: load_state_dict (strict by default) raises on missing/unexpected
# keys or shape mismatches.
INPUT_DIM = 100
EMB_DIM = 2
ENCODER_WIDTHS = (256, 128, 64)

# list(...) hands each constructor its own fresh list, so nothing can alias/mutate the other.
encoder = AEProb(
    dim=INPUT_DIM,
    emb_dim=EMB_DIM,
    layer_widths=list(ENCODER_WIDTHS),
    activation_fn=torch.nn.ReLU(),
    prob_method=probmtd,
    dist_reconstr_weights=[1.0, 0.0, 0.0],
)
encoder.load_state_dict(encoder_dict)
# The decoder uses the encoder's hidden widths in reverse order.
decoder = Decoder(
    dim=INPUT_DIM,
    emb_dim=EMB_DIM,
    layer_widths=list(reversed(ENCODER_WIDTHS)),
    activation_fn=torch.nn.ReLU(),
)
decoder.load_state_dict(decoder_dict)

<All keys matched successfully>

In [11]:
# Put encoder and decoder on the same device.
model = Model(encoder, decoder)
# Freshly constructed nn.Modules start in training mode; switch both to eval mode so any
# dropout / batch-norm layers behave deterministically during evaluation.
# Harmless if compute_all_metrics also calls model.eval().
model.eval()

In [12]:
# Score the reloaded model on the noisy dataset, using the noiseless and ambient files
# built in the data-path cell; returns a dict of results (displayed in the next cell).
res_dict = compute_all_metrics(model, data_path, noiseless_path, ambient_path)

In [13]:
# Show the evaluation results.
# NOTE(review): in the saved output, 'seed' is 'groups' and 'method' is '3', which look like
# filename tokens parsed at the wrong positions — verify the name parsing in compute_all_metrics.
res_dict

{'seed': 'groups',
 'method': '3',
 'bcv': '0.25',
 'dropout': '0.5',
 'demap': 0.8075374435383558,
 'accuracy': 0.05195010848583048,
 'recon score': 0.6449547852158134}