## Setup

In [1]:
import yaml
import pathlib
import pickle as pk
from copy import deepcopy

import torch
import numpy as np
import seaborn as sns
import pandas as pd
import scipy.io as sio
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import ListedColormap
import umap
from sklearn.decomposition import PCA
from sklearn.metrics import r2_score
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis as QDA
from sklearn.dummy import DummyClassifier

import utils
from data import MET_Data
from losses import ReconstructionLoss
from pca_cca import PCA_CCA, CCA_extended

In [2]:
# Data format assumed for each modality when an experiment's config has no
# explicit "formats" entry.
old_format_maps = {
    "T": ["logcpm"],
    "E": ["pca-ipfx"],
    "M": ["arbors"],
}

def get_latent(encoder, met_data, specimen_ids, formats):
    """Encode the requested specimens and return their latent representation.

    Args:
        encoder: Callable taking a dict ``{format: float tensor}`` and returning a tensor.
        met_data: Data object whose ``query(specimen_ids, formats = [...])`` returns a
            dict of NumPy arrays keyed by format name, plus "specimen_id".
        specimen_ids: Specimens to look up.
        formats: Names of the data formats the encoder consumes.

    Returns:
        ``(latent, specimen_ids)``: the encoder output as a NumPy array and the IDs
        of the specimens that the query returned.
    """
    # Materialise once: `formats` is iterated twice below, which would silently
    # yield nothing the second time for a one-shot iterable.
    formats = list(formats)
    data = met_data.query(specimen_ids, formats = [tuple(formats)])
    torch_input = {form: torch.from_numpy(data[form]).float() for form in formats}
    latent = encoder(torch_input).detach().numpy()
    # Query results expose their IDs under "specimen_id" (singular), as every
    # other reader of query results in this notebook does.
    return (latent, data["specimen_id"])

def get_experiment_reconstructions(exp_dict, met_data, specimen_ids, valid_modalities = None, exp_name = None):
    """Run every encoder/decoder pairing of an experiment on the given specimens.

    Args:
        exp_dict: Loaded experiment with a "config" dict (holding "modalities" and,
            for newer experiments, "formats") and "folds" (a list, or a dict keyed
            by fold label). Each fold provides its per-modality models under "best"
            or, failing that, under ``fold["model"].models``.
        met_data: Data object whose ``query(specimen_ids, formats = [...])`` returns a
            dict of NumPy arrays that includes "specimen_id".
        specimen_ids: Specimens to reconstruct.
        valid_modalities: Modalities to restrict the pairings to. ``None`` (the
            default) means every modality of the experiment.
        exp_name: Label used in the progress message. When omitted, the module-level
            ``exp_name`` is used if one exists.

    Returns:
        A list with one dict per fold, mapping "<in>-><out>" to
        ``(recons, specimen_ids)`` where ``recons`` is ``{format: ndarray}``.
        Pairings for which the query returns no data are omitted.
    """
    config = exp_dict["config"]
    if valid_modalities is None:
        valid_modalities = config["modalities"]
    allowed = set(valid_modalities)
    # Keep the config's own ordering so the pairing order is the same on every run
    # (iterating a set of strings is not stable across interpreter sessions).
    modalities = [modal for modal in dict.fromkeys(config["modalities"]) if modal in allowed]
    if exp_name is None:
        # Backward compatibility: callers historically relied on a global `exp_name`.
        exp_name = globals().get("exp_name", "experiment")
    folds = exp_dict["folds"]
    if isinstance(folds, list):
        folds = dict(enumerate(folds, 1))
    fold_recons = []
    for (i, fold) in folds.items():
        fold_recon = {}
        # NOTE(review): all of `specimen_ids` are reconstructed, including any that
        # were in this fold's training set; confirm that this is intended.
        model = fold["best"] if "best" in fold else fold["model"].models
        for in_modal in modalities:
            for out_modal in modalities:
                if "formats" not in config:
                    # Old-style models take and return a bare tensor, so wrap them
                    # to present the dict-in / dict-out interface used below.
                    (in_formats, out_formats) = (old_format_maps[in_modal], old_format_maps[out_modal])
                    encoder = lambda x_forms: model[in_modal]["enc"](x_forms[in_formats[0]])
                    decoder = lambda z: {out_formats[0]: model[out_modal]["dec"](z)}
                else:
                    (in_formats, out_formats) = (config["formats"][in_modal], config["formats"][out_modal])
                    (encoder, decoder) = (model[in_modal]["enc"], model[out_modal]["dec"])
                valid_data = met_data.query(specimen_ids, formats = [tuple(in_formats + out_formats)])
                if len(valid_data[in_formats[0]]):
                    print(f"Generating {exp_name} - {i}: {in_modal} -> {out_modal}             ", end = "\r")
                    input_data = {form: torch.from_numpy(valid_data[form]).float() for form in in_formats}
                    raw_recons = decoder(encoder(input_data))
                    recons = {form: tensor.detach().numpy() for (form, tensor) in raw_recons.items()}
                    fold_recon[f"{in_modal}->{out_modal}"] = (recons, valid_data["specimen_id"])
        fold_recons.append(fold_recon)
    return fold_recons

def get_reconstruction_scores(exp_recon_dict, loss_func, met_data):
    """Score each reconstruction of one fold as ``1 - loss``.

    Args:
        exp_recon_dict: Maps a pairing label to ``(recon_forms, specimen_ids)``,
            where ``recon_forms`` is ``{format: ndarray}``.
        loss_func: Object whose ``loss(originals, reconstructions)`` compares two
            ``{format: tensor}`` dicts.
        met_data: Data object whose ``get_specimens(ids)`` returns the original
            arrays keyed by format.

    Returns:
        Dict mapping every pairing label to ``1 - loss``.
    """
    def as_tensors(arrays, forms):
        # Only the formats that were reconstructed are converted.
        return {form: torch.from_numpy(arrays[form]) for form in forms}

    scores = {}
    for (modal_string, (recon_forms, recon_specimen_ids)) in exp_recon_dict.items():
        originals = met_data.get_specimens(recon_specimen_ids)
        reference = as_tensors(originals, recon_forms)
        reconstructed = as_tensors(recon_forms, recon_forms)
        scores[modal_string] = 1 - loss_func.loss(reference, reconstructed)
    return scores

def load_experiment(base_dir, exp_name, checkpoints = False):
    """Load one saved experiment from ``base_dir / exp_name``.

    Experiments whose name contains "cca" go through ``utils.load_pca_cca``;
    everything else goes through ``utils.load_jit_folds``, with ``checkpoints``
    forwarded as ``get_checkpoints``.
    """
    exp_path = pathlib.Path(base_dir) / exp_name
    if "cca" not in exp_name:
        return utils.load_jit_folds(exp_path, get_checkpoints = checkpoints)
    return utils.load_pca_cca(exp_path)

## Models

In [3]:
# Dataset object queried by every cell below; the path is relative to the
# notebook's working directory.
met_data = MET_Data("../data/raw/MET_full_data.npz")

In [21]:
# Experiment groups to load: each group names the directory holding its saved
# experiments ("dir") and the experiment sub-directories to read ("exps").
exp_info = dict(
    cca = dict(
        dir = "../../archive/2-24/patchseq-cca",
        exps = ["t_e_cca", "t_m_cca", "e_m_cca"],
    ),
    patch = dict(
        dir = "../../archive/2-24/patchseq-mse/",
        exps = ["t_arm", "e_arm", "m_arm", "t_e_arms", "t_m_arms", "e_m_arms", "met"],
    ),
    full = dict(
        dir = "../../archive/2-24/all-mse/",
        exps = ["m_arm", "t_m_arms", "e_m_arms", "met"],
    ),
    # Disabled groups, kept for reference:
    # grad = dict(
    #     dir = "../data/grad_stop/",
    #     exps = ["met_all", "met_patchseq", "met_patchseq_control"]),
    # complete = dict(
    #     dir = "../data/full/",
    #     exps = ["t_m_arms", "met", "met_high_trimodal"]),
    # smartseq = dict(
    #     dir = "../data/smartseq/",
    #     exps = ["t_arm", "t_e_arms", "t_m_arms", "met"]),
    # binary = dict(
    #     dir = "../data/binary/",
    #     exps = ["t_arm_both", "t_arm_patch", "t_arm_smart"]),
)

In [22]:
# Load every configured experiment, nested as group -> "<experiment>-<group>".
# The group suffix keeps names unique when two groups share an experiment name.
all_experiments = {
    group_name: {
        f"{arm}-{group_name}": load_experiment(group_info["dir"], arm)
        for arm in group_info["exps"]
    }
    for (group_name, group_info) in exp_info.items()
}

## Reconstructions

In [23]:
# Score cache filled by the scoring loop below, nested as
# target -> group -> experiment name -> list of per-fold score dicts.
# Re-running this cell empties it, which forces every experiment to be rescored.
reconstruction_accs = {}

In [24]:
# Evaluation targets: which specimens to select ("query"), which modalities to
# reconstruct ("modalities"), and the loss settings used for scoring ("loss").
# The loss settings are the same for every target, but each target gets its own
# freshly built copy because the dict expression is evaluated once per entry.
targets = {
    target_name: {
        "query": target_query,
        "modalities": target_modalities,
        "loss": {
            "formats": {"T": ["logcpm"], "E": ["pca-ipfx"], "M": ["arbors"]},
            "losses": {"logcpm": "feature_r2", "pca-ipfx": "feature_r2", "arbors": "sample_r2"},
            "transform": None},
    }
    for (target_name, target_query, target_modalities) in [
        ("patchseq", {"formats": [("logcpm", "pca-ipfx", "arbors")]}, ["T", "E", "M"]),
        ("EM", {"platforms": ["EM"]}, ["M"]),
        ("smartseq", {"platforms": ["smartseq"]}, ["T"]),
    ]
}

# targets = {
#     "patchseq": {
#         "query": {"formats": [("logcpm", "pca-ipfx", "arbors")]},
#         "modalities": ["T", "E", "M"],
#         "loss": {
#             "formats": {"T": ["logcpm"], "E": ["pca-ipfx"], "M": ["arbors"]},
#             "losses": {"logcpm": "feature_r2", "pca-ipfx": "feature_r2", "arbors": "sample_r2"},
#             "transform": None}
#     },
#     "EM": {
#         "query": {"platforms": ["EM"]},
#         "modalities": ["M"],
#         "loss": {
#             "formats": {"M": ["arbors"]},
#             "losses": {"arbors": "sample_r2"},
#             "transform": None}
#     },
#     "smartseq": {
#         "query": {"platforms": ["smartseq"]},
#         "modalities": ["T"],
#         "loss": {
#             "formats": {"T": ["logcpm"]},
#             "losses": {"logcpm": "feature_r2"},
#             "transform": None}
#     }
# }

In [25]:
# Score every loaded experiment against every evaluation target. Results are
# stored in reconstruction_accs[target][group][exp_name] as one score dict per
# fold; entries already present are skipped, so re-running this cell only
# computes what is missing.
for (target, target_info) in targets.items():
    print(f"\nRunning target {target}:")
    # Specimens selected by this target's query.
    specimen_ids = met_data.query(**target_info["query"])["specimen_id"]
    # Fixed settings first, then the target's own loss settings layered on top.
    loss_config = {"encoder_cross_grad": False, "device": "cpu", **target_info["loss"]}
    loss_func = ReconstructionLoss(loss_config, met_data, specimen_ids)
    target_accs = reconstruction_accs.setdefault(target, {})
    for (group, experiments) in all_experiments.items():
        group_accs = target_accs.setdefault(group, {})
        # NOTE: the progress message printed by get_experiment_reconstructions
        # relies on this loop variable being named `exp_name` at module level.
        for (exp_name, exp_dict) in experiments.items():
            if exp_name not in group_accs:
                folds = get_experiment_reconstructions(exp_dict, met_data, specimen_ids, target_info["modalities"])
                group_accs[exp_name] = [get_reconstruction_scores(fold_dict, loss_func, met_data) for fold_dict in folds]
# Trailing spaces overwrite what is left of the last "\r"-terminated progress line.
print("Complete                                                       ")


Running target patchseq:
Generating met-full - 10: T -> T                   
Running target EM:
Generating met-full - 10: M -> M                   
Running target smartseq:
Completeng met-full - 10: T -> T                   


In [26]:
def _summarize_r2(target_accs, modal_pairs, columns = None):
    """Build mean, standard-deviation and "mean±std" tables for one evaluation target.

    Args:
        target_accs: ``{group: {exp_name: [fold_scores, ...]}}`` where every
            ``fold_scores`` maps "<in>-><out>" labels to a score.
        modal_pairs: Two-character strings such as "TE" (input modality followed
            by output modality), one per table column.
        columns: Optional list of "<in>-><out>" columns to keep; ``None`` keeps all.

    Returns:
        ``(mean_frame, stdv_frame, formatted_frame)``: string-typed frames indexed
        by experiment name. Missing means are shown as "--", missing deviations
        as "", and ``formatted_frame`` joins the two as "mean±std".
    """
    # Flatten the groups; experiment names already carry their group suffix.
    combined = {name: folds for group in target_accs.values() for (name, folds) in group.items()}
    pair_labels = [f"{modal_1}->{modal_2}" for (modal_1, modal_2) in modal_pairs]
    mean_r2 = {label: {} for label in pair_labels}
    stdv_r2 = {label: {} for label in pair_labels}
    for (name, folds) in combined.items():
        for label in pair_labels:
            # A pairing counts as available when the first fold has it; an empty
            # fold list is treated as "not available" rather than raising.
            available = len(folds) > 0 and label in folds[0]
            fold_scores = [fold[label] for fold in folds] if available else []
            mean_r2[label][name] = np.mean(fold_scores) if available else None
            # A spread is only reported when there is more than one fold.
            stdv_r2[label][name] = np.std(fold_scores) if (available and len(folds) > 1) else None
    (mean_frame, stdv_frame) = (pd.DataFrame(mean_r2), pd.DataFrame(stdv_r2))
    if columns is not None:
        (mean_frame, stdv_frame) = (mean_frame[columns], stdv_frame[columns])
    mean_frame = mean_frame.round(2).astype("string").fillna("--")
    stdv_frame = stdv_frame.round(2).astype("string").fillna("")
    # Cells without a deviation become a bare "±"; replace() matches whole cell
    # values, so it blanks exactly those cells and leaves "±0.01" etc. untouched.
    formatted = mean_frame + ("±" + stdv_frame).replace("±", "")
    return (mean_frame, stdv_frame, formatted)

modal_pairs = ["TT", "ET", "MT", "EE", "TE", "ME", "MM", "TM", "EM"]

# Patch-seq specimens: every input -> output pairing.
(patch_mean_frame, patch_stdv_frame, r2_frame) = _summarize_r2(reconstruction_accs["patchseq"], modal_pairs)

# EM specimens: only M -> M, appended as the last column.
(EM_mean_frame, EM_stdv_frame, comb_frame) = _summarize_r2(reconstruction_accs["EM"], modal_pairs, ["M->M"])
r2_frame["M->M (EM)"] = comb_frame["M->M"]

# smartseq specimens: only T -> T, placed right after the patch-seq "->T" columns.
(smart_mean_frame, smart_stdv_frame, smart_frame) = _summarize_r2(reconstruction_accs["smartseq"], modal_pairs, ["T->T"])
r2_frame.insert(3, "T->T (Tasic)", smart_frame["T->T"])

In [27]:
# Combined score table: the patch-seq pairings plus "T->T (Tasic)" from the
# smartseq target and "M->M (EM)" from the EM target.
r2_frame

Unnamed: 0,T->T,E->T,M->T,T->T (Tasic),E->E,T->E,M->E,M->M,T->M,E->M,M->M (EM)
t_e_cca-cca,0.16±0.0,0.14±0.0,--,-0.29±0.0,0.26±0.0,0.24±0.0,--,--,--,--,--
t_m_cca-cca,0.15±0.0,--,0.14±0.0,-0.31±0.0,--,--,--,0.24±0.0,0.22±0.0,--,0.11±0.0
e_m_cca-cca,--,--,--,--,0.2±0.0,--,0.15±0.0,0.23±0.0,--,0.15±0.0,0.03±0.02
t_arm-patch,0.4±0.01,--,--,0.03±0.01,--,--,--,--,--,--,--
e_arm-patch,--,--,--,--,0.63±0.02,--,--,--,--,--,--
m_arm-patch,--,--,--,--,--,--,--,0.75±0.02,--,--,0.63±0.03
t_e_arms-patch,0.36±0.01,0.32±0.01,--,-0.07±0.01,0.58±0.02,0.51±0.02,--,--,--,--,--
t_m_arms-patch,0.23±0.03,--,0.19±0.03,-0.17±0.04,--,--,--,0.63±0.04,0.5±0.05,--,0.53±0.05
e_m_arms-patch,--,--,--,--,0.36±0.06,--,0.27±0.05,0.57±0.08,--,0.41±0.09,0.35±0.07
met-patch,0.24±0.03,0.23±0.03,0.2±0.03,-0.16±0.02,0.34±0.01,0.32±0.01,0.26±0.01,0.41±0.04,0.35±0.03,0.3±0.02,0.24±0.05


In [16]:
# NOTE(review): the saved output of this cell carries execution count 16, which
# is lower than the counts of the cells above (21-27), and its "T->T (Tasic)" and
# "M->M (EM)" columns repeat the patch-seq values. It looks like a stale
# duplicate of the previous cell; re-run or remove it.
r2_frame

Unnamed: 0,T->T,E->T,M->T,T->T (Tasic),E->E,T->E,M->E,M->M,T->M,E->M,M->M (EM)
t_e_cca-cca,0.16±0.0,0.14±0.0,--,0.16±0.0,0.26±0.0,0.24±0.0,--,--,--,--,--
t_m_cca-cca,0.15±0.0,--,0.14±0.0,0.15±0.0,--,--,--,0.24±0.0,0.22±0.0,--,0.24±0.0
e_m_cca-cca,--,--,--,--,0.2±0.0,--,0.15±0.0,0.23±0.0,--,0.15±0.0,0.23±0.0
t_arm-patch,0.4±0.01,--,--,0.4±0.01,--,--,--,--,--,--,--
e_arm-patch,--,--,--,--,0.63±0.02,--,--,--,--,--,--
m_arm-patch,--,--,--,--,--,--,--,0.75±0.02,--,--,0.75±0.02
t_e_arms-patch,0.36±0.01,0.32±0.01,--,0.36±0.01,0.58±0.02,0.51±0.02,--,--,--,--,--
t_m_arms-patch,0.23±0.03,--,0.19±0.03,0.23±0.03,--,--,--,0.63±0.04,0.5±0.05,--,0.63±0.04
e_m_arms-patch,--,--,--,--,0.36±0.06,--,0.27±0.05,0.57±0.08,--,0.41±0.09,0.57±0.08
met-patch,0.24±0.03,0.23±0.03,0.2±0.03,0.24±0.03,0.34±0.01,0.32±0.01,0.26±0.01,0.41±0.04,0.35±0.03,0.3±0.02,0.41±0.04
