## Setup

In [1]:
import yaml
import pathlib
import pickle as pk
from copy import deepcopy

import torch
import numpy as np
import seaborn as sns
import pandas as pd
import scipy.io as sio
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import ListedColormap
import umap
from sklearn.decomposition import PCA
from sklearn.metrics import r2_score
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis as QDA
from sklearn.dummy import DummyClassifier

import utils
from data import MET_Data
from cplAE_MET.models.model_classes import MultiModal
from pca_cca import PCA_CCA, CCA_extended

In [100]:
def get_reconstruction(model, met_data, specimen_ids, in_modal, out_modal, out_exists = False):
    """Reconstruct `out_modal` from `in_modal` for the usable subset of `specimen_ids`.

    Specimens whose `in_modal` data contains any NaN are dropped. When
    `out_exists` is True, specimens whose `out_modal` data contains any NaN are
    dropped as well.

    Parameters
    ----------
    model : torch module called as ``model(x, in_modal, [out_modal])``; element
        ``[1][0]`` of its output is taken as the reconstruction
        (NOTE(review): output layout inferred from this call site -- confirm
        against MultiModal.forward).
    met_data : object exposing ``get_specimens(ids)``, which returns a dict of
        ``"<modal>_dat"`` arrays whose first axis indexes specimens.
    specimen_ids : numpy array of specimen ids (indexable by a boolean mask).
    in_modal, out_modal : modality keys used to build the ``"<modal>_dat"`` names.
    out_exists : if True, also require NaN-free `out_modal` data.

    Returns
    -------
    (recon, modal_data) : the reconstruction as a numpy array, and the
        ``get_specimens`` dict for the retained specimens.
    """
    model.eval()
    all_data = met_data.get_specimens(specimen_ids)
    in_data = all_data[f"{in_modal}_dat"]
    # A specimen is usable only if none of its (flattened) features are NaN.
    valid = ~np.isnan(in_data).reshape([in_data.shape[0], -1]).any(1)
    if out_exists:
        out_data = all_data[f"{out_modal}_dat"]
        valid &= ~np.isnan(out_data).reshape([out_data.shape[0], -1]).any(1)
    modal_data = met_data.get_specimens(specimen_ids[valid])
    model_input = torch.from_numpy(modal_data[f"{in_modal}_dat"]).float()
    # Inference only: do not build an autograd graph for the whole batch.
    with torch.no_grad():
        recon = model(model_input, in_modal, [out_modal])[1][0].detach().numpy()
    return (recon, modal_data)

def get_reconstruction_cca(model, met_data, specimen_ids, in_modal, out_modal, out_exists = False):
    """PCA-CCA counterpart of `get_reconstruction`: map `in_modal` data to `out_modal`.

    The model is called directly on the numpy input block as
    ``model(x, in_modal, [out_modal])``, and element ``[1][0]`` of its result is
    returned as the reconstruction (NOTE(review): output layout inferred from
    this call site -- confirm against PCA_CCA).

    Only specimens with NaN-free `in_modal` data are used. With
    ``out_exists=True``, their `out_modal` data must be NaN-free too.

    Returns ``(recon, modal_data)``, where `modal_data` is the
    ``met_data.get_specimens`` dict for the retained specimens.
    """
    def no_missing(block):
        # One flag per specimen (first axis): True when its features hold no NaN.
        return ~np.isnan(block).reshape([block.shape[0], -1]).any(1)

    requested = met_data.get_specimens(specimen_ids)
    in_block = requested[f"{in_modal}_dat"]
    out_block = requested[f"{out_modal}_dat"]
    keep = no_missing(in_block)
    if out_exists:
        keep = keep & no_missing(out_block)
    modal_data = met_data.get_specimens(specimen_ids[keep])
    outputs = model(modal_data[f"{in_modal}_dat"], in_modal, [out_modal])
    recon = outputs[1][0]
    return (recon, modal_data)

def get_latent(model, met_data, specimen_ids, in_modal):
    """Encode the NaN-free `in_modal` data of `specimen_ids` into the latent space.

    Parameters
    ----------
    model : torch module with ``model.modal_arms[in_modal].encoder``, a callable
        applied to a float tensor of the modality data.
    met_data : object exposing ``get_specimens(ids)``, which returns a dict of
        ``"<modal>_dat"`` arrays whose first axis indexes specimens.
    specimen_ids : numpy array of specimen ids (indexable by a boolean mask).
    in_modal : modality key used to build the ``"<modal>_dat"`` name.

    Returns
    -------
    (latent, modal_data) : the encoder output as a numpy array, and the
        ``get_specimens`` dict for the specimens with no NaN in `in_modal`.
    """
    model.eval()
    in_data = met_data.get_specimens(specimen_ids)[f"{in_modal}_dat"]
    # Keep specimens whose flattened input features contain no NaN.
    valid = ~np.isnan(in_data).reshape([in_data.shape[0], -1]).any(1)
    modal_data = met_data.get_specimens(specimen_ids[valid])
    torch_input = torch.from_numpy(modal_data[f"{in_modal}_dat"]).float()
    # Inference only: do not build an autograd graph.
    with torch.no_grad():
        latent = model.modal_arms[in_modal].encoder(torch_input).detach().numpy()
    return (latent, modal_data)

def get_r2_scores(exp_recon_dict, variance_dict, var_thresh = 1e-2):
    """Variance-normalised R^2 for each reconstruction of one cross-validation fold.

    Parameters
    ----------
    exp_recon_dict : dict mapping "<in>-><out>" (e.g. "E->T") to
        ``(recon, modal_data)``. The target is ``modal_data["<out>_dat"]``.
        Target and reconstruction are compared after flattening every
        non-sample axis.
    variance_dict : dict mapping an output modality name to per-feature
        variances. Once flattened, it must have as many elements as one
        flattened sample of that modality.
    var_thresh : features whose variance is not above this value are ignored,
        so near-constant features do not dominate the average.

    Returns
    -------
    dict with the same keys as `exp_recon_dict`. Each value is
    ``1 - mean_f(MSE_f / Var_f)`` over the retained features f, i.e. the mean
    per-feature R^2 measured against `variance_dict` rather than the test
    fold's own variance. The value is NaN when no feature passes `var_thresh`.
    """
    scores = {}
    for (modal_string, (recon, modal_data)) in exp_recon_dict.items():
        out_modal = modal_string.split("->")[1]
        num_samples = recon.shape[0]
        orig_flat = modal_data[f"{out_modal}_dat"].reshape(num_samples, -1)
        recon_flat = recon.reshape(num_samples, -1)
        mse_per_feature = np.square(orig_flat - recon_flat).mean(0)
        variances = np.asarray(variance_dict[out_modal]).flatten()
        keep = variances > var_thresh
        if not keep.any():
            # Nothing to average over; report NaN explicitly instead of
            # relying on numpy's empty-mean warning.
            scores[modal_string] = np.nan
            continue
        scores[modal_string] = 1 - (mse_per_feature[keep] / variances[keep]).mean()
    return scores

## Models

In [13]:
# Location of the MET dataset, relative to this notebook's working directory.
MET_DATA_PATH = "../data/raw/MET_M120x4_50k_4Apr23.mat"
met_data = MET_Data(MET_DATA_PATH)

In [20]:
# Experiment names: single arms, pairs of arms, and the combined "met" model.
model_types = [
    "t_arm", "e_arm", "m_arm",
    "t_e_arms", "t_m_arms", "e_m_arms",
    "met",
]

In [21]:
# Cross-validation results for every model type, from the "patchseq/" and "all/"
# result directories (presumably patch-seq-only vs. all-data training -- confirm).
patchseq_experiments = dict(
    (name, utils.load_cross_validation("patchseq/" + name)) for name in model_types
)
full_experiments = dict(
    (name, utils.load_cross_validation("all/" + name)) for name in model_types
)

In [14]:
# PCA-CCA baselines for each pair of modalities, keyed by the last path component.
cca_paths = [f"pca-cca/{pair}_cca" for pair in ("t_e", "t_m", "e_m")]
cca_experiments = dict(
    (cca_path.rsplit("/", 1)[-1], utils.load_pca_cca(cca_path)) for cca_path in cca_paths
)

## Within- and Cross-modality Reconstruction

### Setup

In [15]:
# For every PCA-CCA fold, reconstruct each (input -> output) modality pair on the
# fold's test specimens that have M, E and T data.
cca_recons = {}
for (exp_name, exp_dict) in cca_experiments.items():
    cca_recons[exp_name] = []
    for fold in exp_dict["folds"]:
        cca_model = fold["model"]
        met_ids = met_data.query(fold["test_ids"], ["M", "E", "T"])["specimen_id"]
        modalities = cca_model.pca.keys()
        fold_recon = {}
        for (in_modal, out_modal) in [(a, b) for a in modalities for b in modalities]:
            print(f"Generating {exp_name}: {in_modal} -> {out_modal}             ", end = "\r")
            (recon, modal_data) = get_reconstruction_cca(
                cca_model, met_data, met_ids, in_modal, out_modal, out_exists = True)
            fold_recon[f"{in_modal}->{out_modal}"] = (recon, modal_data)
        cca_recons[exp_name].append(fold_recon)
print("Complete                      ")

Complete                               


In [23]:
# For each "patchseq/" experiment fold, load the model for `fold["best"]` and
# reconstruct every modality pair on the fold's MET-complete test specimens.
patchseq_recons = {}
for (exp_name, exp_dict) in patchseq_experiments.items():
    per_fold = patchseq_recons.setdefault(exp_name, [])
    for fold in exp_dict["folds"]:
        fold_recon = {}
        met_ids = met_data.query(fold["test_ids"], ["M", "E", "T"])["specimen_id"]
        model = utils.load_model(exp_dict["config"], fold["best"])
        modalities = model.modal_arms.keys()
        for in_modal in modalities:
            for out_modal in modalities:
                print(f"Generating {exp_name}: {in_modal} -> {out_modal}             ", end = "\r")
                reconstruction = get_reconstruction(
                    model, met_data, met_ids, in_modal, out_modal, out_exists = True)
                (recon, modal_data) = reconstruction
                fold_recon[f"{in_modal}->{out_modal}"] = (recon, modal_data)
        per_fold.append(fold_recon)
print("Complete                      ")

Complete                                


In [22]:
# Same sweep for the "all/" experiments: per fold, load the model for
# `fold["best"]` and reconstruct every modality pair on MET-complete test specimens.
full_recons = {}
for (exp_name, exp_dict) in full_experiments.items():
    full_recons[exp_name] = []
    for fold in exp_dict["folds"]:
        met_ids = met_data.query(fold["test_ids"], ["M", "E", "T"])["specimen_id"]
        model = utils.load_model(exp_dict["config"], fold["best"])
        modalities = model.modal_arms.keys()
        fold_recon = {}
        for (in_modal, out_modal) in [(a, b) for a in modalities for b in modalities]:
            print(f"Generating {exp_name}: {in_modal} -> {out_modal}             ", end = "\r")
            (recon, modal_data) = get_reconstruction(
                model, met_data, met_ids, in_modal, out_modal, out_exists = True)
            fold_recon[f"{in_modal}->{out_modal}"] = (recon, modal_data)
        full_recons[exp_name].append(fold_recon)
print("Complete                      ")

Complete                                


In [81]:
# Patch-seq subset of the dataset; its per-feature variances become the R^2 reference below.
patchseq_met = met_data.query(platforms=["patchseq"])

In [82]:
# Per-feature variance of each modality over all patch-seq cells, ignoring NaNs;
# these are the denominators used by get_r2_scores.
variances = {
    modal: np.nanvar(patchseq_met[modal + "_dat"], axis=0)
    for modal in ("T", "E", "M")
}

In [101]:
# One dict of R^2 scores ("in->out" -> score) per fold, for each patch-seq experiment.
patchseq_r2_scores = {
    exp_name: list(map(lambda fold_dict: get_r2_scores(fold_dict, variances), folds))
    for exp_name, folds in patchseq_recons.items()
}

In [102]:
# One dict of R^2 scores ("in->out" -> score) per fold, for each all-data experiment.
full_r2_scores = {
    exp_name: [get_r2_scores(recon_by_pair, variances) for recon_by_pair in folds]
    for exp_name, folds in full_recons.items()
}

In [103]:
# One dict of R^2 scores ("in->out" -> score) per fold, for each PCA-CCA baseline.
cca_r2_scores = dict(
    (exp_name, [get_r2_scores(recon_by_pair, variances) for recon_by_pair in folds])
    for exp_name, folds in cca_recons.items()
)

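### R<sup>2</sup> Coefficient Table for Patch-seq

Each entry is a fold-averaged, variance-normalised R<sup>2</sup>: 1 − mean over output features of (MSE / variance). The variance is taken over all patch-seq cells, and features with variance ≤ 1e-2 are skipped. Rows are models, columns are `input->output` modality pairs, and empty cells mark pairs a model does not reconstruct.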

In [104]:
# Fold-averaged R^2 per model (rows) and "in->out" modality pair (columns).
# None (rendered as an empty/NaN cell) marks pairs a model does not reconstruct.
modal_pairs = ["TT", "ET", "MT", "EE", "TE", "ME", "MM", "TM", "EM"]
within_r2 = {f"{modal_1}->{modal_2}": {} for (modal_1, modal_2) in modal_pairs}
for scores_by_exp in (cca_r2_scores, patchseq_r2_scores):
    for (exp_name, folds) in scores_by_exp.items():
        for (modal_string, within_dict) in within_r2.items():
            if modal_string in folds[0]:
                score = np.mean([fold[modal_string] for fold in folds])
            else:
                score = None
            within_dict[exp_name] = score
pd.DataFrame(within_r2)

Unnamed: 0,T->T,E->T,M->T,E->E,T->E,M->E,M->M,T->M,E->M
t_e_cca,0.145437,0.124554,,0.373183,0.344749,,,,
t_m_cca,0.153231,,0.14406,,,,0.073518,0.067221,
e_m_cca,,,,0.306293,,0.232611,0.063007,,0.02913
t_arm,0.358143,,,,,,,,
e_arm,,,,0.624935,,,,,
m_arm,,,,,,,0.17815,,
t_e_arms,0.331904,0.259951,,0.597182,0.417483,,,,
t_m_arms,0.265497,,0.147143,,,,0.055401,0.040687,
e_m_arms,,,,0.415159,,0.15592,0.000398,,-0.007571
met,0.216672,0.179421,0.120372,0.414419,0.38301,0.223044,-0.02963,-0.030471,-0.030518


In [54]:
# NOTE(review): this cell repeats the In [104] table cell above verbatim. Its saved
# output (execution count 54) was produced before `variances` (In [82]) and the R^2
# scores (In [101]-[103]) were last computed, which is presumably why its CCA rows
# differ from the In [104] output. Consider deleting this cell.
modal_pairs = ["TT", "ET", "MT", "EE", "TE", "ME", "MM", "TM", "EM"]
within_r2 = {f"{modal_1}->{modal_2}": {} for (modal_1, modal_2) in modal_pairs}
for scores_by_exp in (cca_r2_scores, patchseq_r2_scores):
    for (exp_name, folds) in scores_by_exp.items():
        for (modal_string, within_dict) in within_r2.items():
            if modal_string in folds[0]:
                score = np.mean([fold[modal_string] for fold in folds])
            else:
                score = None
            within_dict[exp_name] = score
pd.DataFrame(within_r2)

Unnamed: 0,T->T,E->T,M->T,E->E,T->E,M->E,M->M,T->M,E->M
t_e_cca,0.146859,0.125522,,0.257949,0.234637,,,,
t_m_cca,0.147487,,0.138798,,,,0.213016,0.212966,
e_m_cca,,,,0.199256,,0.134695,0.216991,,0.201873
t_arm,0.356217,,,,,,,,
e_arm,,,,0.540027,,,,,
m_arm,,,,,,,0.161475,,
t_e_arms,0.331695,0.260052,,0.510501,0.329609,,,,
t_m_arms,0.260944,,0.138994,,,,0.343702,0.323212,
e_m_arms,,,,0.328,,0.079304,0.289481,,0.222242
met,0.215083,0.177697,0.114069,0.326717,0.298446,0.148884,0.301757,0.310927,0.304517


In [31]:
# Same fold-averaged R^2 table, but the model rows come from `full_r2_scores`
# (experiments loaded from "all/...") instead of the patch-seq experiments.
# NOTE(review): the saved output (execution count 31) predates `variances` (In [82])
# and `get_r2_scores` (In [100]); re-run to refresh it.
modal_pairs = ["TT", "ET", "MT", "EE", "TE", "ME", "MM", "TM", "EM"]
within_r2 = {f"{modal_1}->{modal_2}": {} for (modal_1, modal_2) in modal_pairs}
for scores_by_exp in (cca_r2_scores, full_r2_scores):
    for (exp_name, folds) in scores_by_exp.items():
        for (modal_string, within_dict) in within_r2.items():
            if modal_string in folds[0]:
                score = np.mean([fold[modal_string] for fold in folds])
            else:
                score = None
            within_dict[exp_name] = score
pd.DataFrame(within_r2)

Unnamed: 0,T->T,E->T,M->T,E->E,T->E,M->E,M->M,T->M,E->M
t_e_cca,0.146859,0.125522,,0.257949,0.234637,,,,
t_m_cca,0.147487,,0.138798,,,,0.213016,0.212966,
e_m_cca,,,,0.199256,,0.134695,0.216991,,0.201873
t_arm,0.36061,,,,,,,,
e_arm,,,,0.535555,,,,,
m_arm,,,,,,,0.416186,,
t_e_arms,0.330705,0.251205,,0.506575,0.328856,,,,
t_m_arms,0.28721,,0.158505,,,,0.321931,0.30185,
e_m_arms,,,,0.391264,,0.07572,-1.038954,,-0.751023
met,0.295118,0.245474,0.113445,0.396075,0.330677,0.082092,-2.824914,-0.651965,-1.074103


## Latent Spaces

### Setup

In [None]:
# Project each single-modality baseline's test-set latent space onto its first two
# principal components, keeping only specimens with a non-NaN `cluster_id`.
# NOTE(review): `load_all` is neither defined nor imported above, and this cell has
# never been executed (In [None]); it will raise NameError on Restart & Run All.
# TODO: import `load_all` (perhaps it lives in `utils`?) or port this cell to the
# `utils.load_cross_validation` results used elsewhere in the notebook.
baseline_parts = {"X": [], "Y": [], "Modality": [], "Types": []}
for (path, modal) in zip(["t_arm", "e_arm", "m_arm"], ["T", "E", "M"]):
    exp_dict = load_all(f"baselines/{path}")
    (latent, modal_data) = get_latent(exp_dict["model"], exp_dict["data"], exp_dict["experiment"]["test_ids"], modal)
    labeled = ~np.isnan(modal_data["cluster_id"])
    # PCA is fit on all test latents; only the labeled rows are kept for plotting.
    proj_labeled = PCA(n_components = 2).fit_transform(latent)[labeled]
    baseline_parts["X"].append(proj_labeled[:, 0])
    baseline_parts["Y"].append(proj_labeled[:, 1])
    baseline_parts["Modality"].append(np.full([proj_labeled.shape[0]], modal))
    baseline_parts["Types"].append(modal_data["merged_cluster_label_at80"][labeled])
baselines = {key: np.concatenate(arr_list) for (key, arr_list) in baseline_parts.items()}