## Setup

In [23]:
import yaml
import pathlib

import torch
import numpy as np
import pandas as pd

import utils
from data import MET_Data, get_transformation_function
from losses import ReconstructionLoss
from pca_cca import PCA_CCA, CCA_extended # Not used explicitly, but needed to unpickle PCA-CCA models

In [24]:
# Data format assumed for each modality when an experiment's config has no "formats" entry.
old_format_maps = {"T": ["logcpm"], "E": ["pca-ipfx"], "M": ["arbors"]}

def get_latent(encoder, met_data, specimen_ids, formats):
    """Encode the requested specimens and return their latent representation.

    Parameters
    ----------
    encoder : callable
        Called with a dict mapping each format name to a float32 tensor; must
        return a tensor.
    met_data : object
        Must provide ``query(specimen_ids, formats=[...])`` returning a mapping
        with one array per requested format plus the matching specimen ids.
    specimen_ids : sequence
        Specimen ids forwarded to ``met_data.query``.
    formats : iterable of str
        Data formats fed to the encoder.

    Returns
    -------
    tuple
        ``(latent, ids)`` where ``latent`` is a NumPy array and ``ids`` are the
        specimen ids reported by the query.
    """
    data = met_data.query(specimen_ids, formats = [tuple(formats)])
    torch_input = {form: torch.from_numpy(data[form]).float() for form in formats}
    latent = encoder(torch_input).detach().numpy()
    # Other query results in this notebook expose the ids under "specimen_id";
    # the previously used "specimen_ids" key is kept only as a fallback.
    id_key = "specimen_id" if "specimen_id" in data else "specimen_ids"
    return (latent, data[id_key])

def get_experiment_reconstructions(exp_dict, met_data, specimen_ids, trans_funcs, valid_modalities = None, exp_name = None):
    """Run every encoder/decoder pairing of an experiment on its held-out specimens.

    Parameters
    ----------
    exp_dict : dict
        Loaded experiment with a "config" entry (containing "modalities" and
        optionally "formats") and a "folds" entry (list or dict of folds). Each
        fold provides "train_ids" and either "best" or "model".
    met_data : object
        Must provide ``query(specimen_ids, formats=[...])``.
    specimen_ids : numpy.ndarray
        Candidate specimen ids; ids found in a fold's "train_ids" are excluded
        for that fold (comparison ignores surrounding whitespace).
    trans_funcs : dict
        Optional per-format transformation applied to the inputs before encoding.
    valid_modalities : iterable of str, optional
        Modalities to evaluate. ``None`` (the default) means all modalities of
        the experiment; anything not in the experiment's config is ignored.
    exp_name : str, optional
        Label used in the progress message. When omitted, a module-level
        ``exp_name`` is used if one exists (the previous behaviour).

    Returns
    -------
    list of dict
        One dict per fold mapping ``"<in>-><out>"`` to
        ``(reconstructions by format, specimen ids)``.
    """
    config = exp_dict["config"]
    # The default used to be [], which the `is None` test never caught, so omitting
    # the argument silently selected no modalities at all.
    if valid_modalities is None:
        valid_modalities = config["modalities"]
    allowed = set(valid_modalities)
    # Follow the config order (deduplicated) so that results are ordered deterministically.
    modalities = [modal for modal in dict.fromkeys(config["modalities"]) if modal in allowed]
    if exp_name is None:
        exp_name = globals().get("exp_name", "experiment")
    fold_recons = []
    folds = exp_dict["folds"]
    folds = dict(enumerate(folds, 1)) if isinstance(folds, list) else folds
    stripped_ids = np.char.strip(specimen_ids)
    for (i, fold) in folds.items():
        fold_recon = {}
        non_train_ids = specimen_ids[~np.isin(stripped_ids, np.char.strip(fold["train_ids"]))]
        model = fold["best"] if "best" in fold else fold["model"].models
        for in_modal in modalities:
            for out_modal in modalities:
                if "formats" not in config:
                    # Old-style experiments: one format per modality, and the arms take and
                    # return bare tensors, so they are wrapped to use the dict interface.
                    (in_formats, out_formats) = (old_format_maps[in_modal], old_format_maps[out_modal])
                    encoder = lambda x_forms: model[in_modal]["enc"](x_forms[in_formats[0]])
                    decoder = lambda z: {out_formats[0]: model[out_modal]["dec"](z)}
                else:
                    (in_formats, out_formats) = (config["formats"][in_modal], config["formats"][out_modal])
                    (encoder, decoder) = (model[in_modal]["enc"], model[out_modal]["dec"])
                valid_data = met_data.query(non_train_ids, formats = [tuple(in_formats + out_formats)])
                # Skip pairings for which no held-out specimen has the required formats.
                if len(valid_data[in_formats[0]]):
                    print(f"Generating {exp_name} - {i}: {in_modal} -> {out_modal}             ", end = "\r")
                    transformed = {form: trans_funcs.get(form, lambda x: x)(valid_data[form]) for form in in_formats}
                    input_data = {form: torch.from_numpy(array).float() for (form, array) in transformed.items()}
                    raw_recons = decoder(encoder(input_data))
                    recons = {form: tensor.detach().numpy() for (form, tensor) in raw_recons.items()}
                    fold_recon[f"{in_modal}->{out_modal}"] = (recons, valid_data["specimen_id"])
        fold_recons.append(fold_recon)
    return fold_recons

def get_reconstruction_scores(exp_recon_dict, loss_func, met_data, trans_funcs, display_funcs = None):
    """Score the reconstructions of one fold against the original data.

    Parameters
    ----------
    exp_recon_dict : dict
        Maps ``"<in>-><out>"`` to ``(reconstructions by format, specimen ids)``.
    loss_func : object
        Must provide ``loss(originals, reconstructions)`` taking two dicts of tensors.
    met_data : object
        Must provide ``get_specimens(specimen_ids)`` returning arrays by format.
    trans_funcs : dict
        Optional per-format transformation applied to the original data.
    display_funcs : dict, optional
        Optional post-processing of the loss, looked up by the last character
        of the modal string (the output modality).

    Returns
    -------
    dict
        Maps each modal string to its (possibly post-processed) loss.
    """
    def unchanged(value):
        return value

    scores = {}
    for modal_string in exp_recon_dict:
        (recon_forms, recon_specimen_ids) = exp_recon_dict[modal_string]
        data = met_data.get_specimens(recon_specimen_ids)
        tensor_orig = {}
        for form in recon_forms:
            transform = trans_funcs.get(form, unchanged)
            tensor_orig[form] = torch.from_numpy(transform(data[form]))
        tensor_recon = {}
        for (form, arr) in recon_forms.items():
            tensor_recon[form] = torch.from_numpy(arr)
        loss = loss_func.loss(tensor_orig, tensor_recon)
        if display_funcs:
            display_func = display_funcs.get(modal_string[-1], unchanged)
        else:
            display_func = unchanged
        scores[modal_string] = display_func(loss)
    return scores

def load_experiment(base_dir, exp_name, checkpoints = False):
    """Load the experiment stored at ``base_dir / exp_name``.

    Experiments whose name contains "cca" are loaded with
    ``utils.load_pca_cca``; all others with ``utils.load_jit_folds``, to which
    ``checkpoints`` is forwarded as ``get_checkpoints``.
    """
    exp_path = pathlib.Path(base_dir) / exp_name
    if "cca" not in exp_name:
        return utils.load_jit_folds(exp_path, get_checkpoints = checkpoints)
    return utils.load_pca_cca(exp_path)

def build_table(column_info, groups, num_decimals):
    """Build a "mean±std" table of scores with one row per experiment.

    Reads the module-level ``scores`` dictionary.

    Parameters
    ----------
    column_info : iterable of (modal_string, target, col_label)
        Each triple defines one column: the score of ``modal_string`` under
        ``target``, shown under the heading ``col_label``.
    groups : list of str
        Experiment groups to include. Rows are ordered by the position in this
        list of the text after the last "-" in the experiment name.
    num_decimals : int
        Number of decimals the means and standard deviations are rounded to.

    Returns
    -------
    pandas.DataFrame
        String table; cells without a score hold "--", and rows consisting
        only of "--" are dropped.
    """
    stat_columns = {"mean": {}, "std": {}}
    for (modal_string, target, col_label) in column_info:
        # Gather the metrics of every experiment belonging to a requested group.
        selected = {}
        for (group, group_dict) in scores[target].items():
            if group in groups:
                selected.update(group_dict)
        for (stat, columns_for_stat) in stat_columns.items():
            columns_for_stat[col_label] = {name: metrics[stat].get(modal_string)
                                           for (name, metrics) in selected.items()}
    mean_frame = pd.DataFrame(stat_columns["mean"]).round(num_decimals).astype("string").fillna("--")
    std_frame = pd.DataFrame(stat_columns["std"]).round(num_decimals).astype("string").fillna("")
    # A cell with no standard deviation would read just "±"; blank it out instead.
    std_suffix = ("±" + std_frame).replace("±", "")
    combined_frame = mean_frame + std_suffix
    has_any_score = ~(combined_frame == "--").all(axis = 1)
    combined_frame = combined_frame[has_any_score]

    def group_rank(exp_name):
        return groups.index(exp_name.split("-")[-1])

    return combined_frame.sort_index(key = lambda series: series.map(group_rank))

## Models

In [25]:
# Dataset object that every query in the cells below is run against.
met_data = MET_Data("../data/raw/MET_full_data.npz")

In [26]:
# Experiment groups to load. For each group, "dir" is the directory holding its
# saved experiments and "exps" lists the experiment names to load from it.
# Note that the "ivscc" and "dual" groups read from the same directory.
exp_info = {
    "cca": {
        "dir": "../../archive/2-24/patchseq-cca",
        "exps": ["t_e_cca", "t_m_cca", "e_m_cca"]},
    "patch": {
        "dir": "../../archive/2-24/patchseq-mse/",
        "exps": ["t_arm", "e_arm", "m_arm", "t_e_arms", "t_m_arms", "e_m_arms", "met"]},
    "full": {
        "dir": "../../archive/2-24/all-mse/",
        "exps": ["m_arm", "t_m_arms", "e_m_arms", "met"]},
    "grad": {
        "dir": "../data/grad_stop/",
        "exps": ["met_patchseq_control", "met_patchseq", "met_full"]},
    "complete": {
        "dir": "../data/full/",
        "exps": ["t_m_arms", "met"]},
    "smartseq": {
        "dir": "../data/smartseq/",
        "exps": ["t_arm", "t_e_arms", "t_m_arms", "met"]},
    "binary": {
        "dir": "../data/binary",
        "exps": ["t_arm_patch", "t_arm_smart", "t_arm_both", "met_patch", "t_e_arms_patch", "t_m_arms_patch"]},
    "ivscc": {
        "dir": "../data/ivscc",
        "exps": ["m_arm_ivscc", "met_ivscc", "m_arm_full", "met_full"]},
    "dual": {
        "dir": "../data/ivscc",
        "exps": ["m_arm_dual", "met_dual"]}
}

In [27]:
# Load every experiment listed in exp_info, keyed as
# {group: {"<experiment>-<group>": loaded experiment}}.
all_experiments = {}
for (group, info) in exp_info.items():
    all_experiments[group] = {}
    for exp in info["exps"]:
        all_experiments[group][f"{exp}-{group}"] = load_experiment(info["dir"], exp)

## Reconstructions

### Setup

In [28]:
# Per-fold scores, filled in below as {target: {group: {experiment: [fold scores, ...]}}}.
reconstruction_accs = {}

In [29]:
# Evaluation targets. Each entry describes one way of scoring the experiments:
#   "query"      - keyword arguments for met_data.query selecting the evaluation specimens
#   "modalities" - modalities whose encoder/decoder pairings are evaluated
#   "loss"       - settings merged into the ReconstructionLoss configuration; its
#                  "transform" entry is also turned into per-format transformation functions
#   "groups"     - experiment groups evaluated on this target when none are specified
#   "display"    - optional post-processing of the loss, keyed by output modality
#                  (None leaves the loss as is)
targets = {
    "patchseq": {
        "query": {"formats": [("logcpm", "pca-ipfx", "arbors")]},
        "modalities": ["T", "E", "M"],
        "loss": {
            "formats": {"T": ["logcpm"], "E": ["pca-ipfx"], "M": ["arbors"]},
            "losses": {"logcpm": "feature_r2", "pca-ipfx": "feature_r2", "arbors": "sample_r2"},
            "transform": {}},
        "groups": {"cca", "patch", "full", "grad", "complete", "smartseq"},
        "display": {"T": lambda x: 1 - x, "E": lambda x: 1 - x, "M": lambda x: 1 - x}
    },
    "EM": {
        "query": {"platforms": ["EM"]},
        "modalities": ["M"],
        "loss": {
            "formats": {"M": ["arbors"]},
            "losses": {"arbors": "sample_r2"},
            "transform": {}},
        "groups": {"cca", "patch", "full", "grad", "complete", "smartseq"},
        "display": {"M": lambda x: 1 - x}
    },
    "smartseq": {
        "query": {"platforms": ["smartseq"]},
        "modalities": ["T"],
        "loss": {
            "formats": {"T": ["logcpm"]},
            "losses": {"logcpm": "feature_r2"},
            "transform": {}},
        "groups": {"cca", "patch", "full", "grad", "complete", "smartseq"},
        "display": {"T": lambda x: 1 - x}
    },
    "patch_binary": {
        "query": {"platforms": ["patchseq"], "formats": [("logcpm",)]},
        "modalities": ["T"],
        "loss": {
            "formats": {"T": ["logcpm"]},
            "losses": {"logcpm": "bce"},
            "transform": {"logcpm": {"binarize": 0.1}}},
        "groups": {"binary"},
        "display": None
    },
    "smart_binary": {
        "query": {"platforms": ["smartseq"], "formats": [("logcpm",)]},
        "modalities": ["T"],
        "loss": {
            "formats": {"T": ["logcpm"]},
            "losses": {"logcpm": "bce"},
            "transform": {"logcpm": {"binarize": 0.1}}},
        "groups": {"binary"},
        "display": None
    },
    "multimodal_binary": {
        "query": {"formats": [("logcpm", "pca-ipfx", "arbors")]},
        "modalities": ["T", "E", "M"],
        "loss": {
            "formats": {"T": ["logcpm"], "E": ["pca-ipfx"], "M": ["arbors"]},
            "losses": {"logcpm": "bce", "pca-ipfx": "feature_r2", "arbors": "sample_r2"},
            "transform": {"logcpm": {"binarize": 0.1}}},
        "groups": {"binary"},
        "display": {"E": lambda x: 1 - x, "M": lambda x: 1 - x}
    },
    "ivscc-patchseq": {
        "query": {"formats": [("logcpm", "pca-ipfx", "arbors", "ivscc")]},
        "modalities": ["T", "E", "M"],
        "loss": {
            "formats": {"T": ["logcpm"], "E": ["pca-ipfx"], "M": ["ivscc"]},
            "losses": {"logcpm": "feature_r2", "pca-ipfx": "feature_r2", "ivscc": "sample_r2"},
            "transform": {}},
        "groups": {"ivscc"},
        "display": {"T": lambda x: 1 - x, "E": lambda x: 1 - x, "M": lambda x: 1 - x}
    },
    "ivscc-EM": {
        "query": {"platforms": ["EM"], "formats": [("arbors", "ivscc",)]},
        "modalities": ["M"],
        "loss": {
            "formats": {"M": ["ivscc"]},
            "losses": {"ivscc": "sample_r2"},
            "transform": {}},
        "groups": {"ivscc"},
        "display": {"M": lambda x: 1 - x}
    },
    "ivscc-dual-patchseq": {
        "query": {"formats": [("logcpm", "pca-ipfx", "arbors", "ivscc")]},
        "modalities": ["T", "E", "M"],
        "loss": {
            "formats": {"T": ["logcpm"], "E": ["pca-ipfx"], "M": ["arbors", "ivscc"]},
            "losses": {"logcpm": "feature_r2", "pca-ipfx": "feature_r2", "arbors": "sample_r2", "ivscc": "sample_r2"},
            "transform": {}},
        "groups": {"dual"},
        "display": {"T": lambda x: 1 - x, "E": lambda x: 1 - x, "M": lambda x: 1 - x}
    },
    "ivscc-dual-EM": {
        "query": {"platforms": ["EM"], "formats": [("arbors", "ivscc")]},
        "modalities": ["M"],
        "loss": {
            "formats": {"M": ["arbors", "ivscc"]},
            "losses": {"arbors": "sample_r2", "ivscc": "sample_r2"},
            "transform": {}},
        "groups": {"dual"},
        "display": {"M": lambda x: 1 - x}
    },
    "patchseq-with-ivscc": {
        "query": {"formats": [("logcpm", "pca-ipfx", "arbors", "ivscc")]},
        "modalities": ["T", "E", "M"],
        "loss": {
            "formats": {"T": ["logcpm"], "E": ["pca-ipfx"], "M": ["arbors"]},
            "losses": {"logcpm": "feature_r2", "pca-ipfx": "feature_r2", "arbors": "sample_r2"},
            "transform": {}},
        "groups": {"patch"},
        "display": {"T": lambda x: 1 - x, "E": lambda x: 1 - x, "M": lambda x: 1 - x}
    }
}

In [30]:
# Optional list of (target, group, experiment) tuples restricting what is run.
# A None group or experiment selects every entry of that category for the given
# target; leaving the list empty runs every target with its default groups.
specified = []

In [31]:
# Expand `specified` into the full run plan: {target: {group: [experiment names]}}.
# Anything left unspecified (or given as None) falls back to "all of them".
plan = {}
spec_targets = [entry[0] for entry in specified] if specified else list(targets.keys())
for target in spec_targets:
    plan[target] = {}
    spec_groups = [entry[1] for entry in specified
                   if entry[0] == target and entry[1] is not None]
    spec_groups = spec_groups or targets[target]["groups"]
    for group in spec_groups:
        spec_experiments = [entry[2] for entry in specified
                            if entry[0] == target and entry[1] == group and entry[2] is not None]
        spec_experiments = spec_experiments or list(all_experiments[group].keys())
        plan[target][group] = spec_experiments
plan

{'patchseq': {'full': ['m_arm-full',
   't_m_arms-full',
   'e_m_arms-full',
   'met-full'],
  'patch': ['t_arm-patch',
   'e_arm-patch',
   'm_arm-patch',
   't_e_arms-patch',
   't_m_arms-patch',
   'e_m_arms-patch',
   'met-patch'],
  'cca': ['t_e_cca-cca', 't_m_cca-cca', 'e_m_cca-cca'],
  'complete': ['t_m_arms-complete', 'met-complete'],
  'grad': ['met_patchseq_control-grad', 'met_patchseq-grad', 'met_full-grad'],
  'smartseq': ['t_arm-smartseq',
   't_e_arms-smartseq',
   't_m_arms-smartseq',
   'met-smartseq']},
 'EM': {'full': ['m_arm-full', 't_m_arms-full', 'e_m_arms-full', 'met-full'],
  'patch': ['t_arm-patch',
   'e_arm-patch',
   'm_arm-patch',
   't_e_arms-patch',
   't_m_arms-patch',
   'e_m_arms-patch',
   'met-patch'],
  'cca': ['t_e_cca-cca', 't_m_cca-cca', 'e_m_cca-cca'],
  'complete': ['t_m_arms-complete', 'met-complete'],
  'grad': ['met_patchseq_control-grad', 'met_patchseq-grad', 'met_full-grad'],
  'smartseq': ['t_arm-smartseq',
   't_e_arms-smartseq',
   't_m_

In [32]:
# Score every planned experiment on every target and store the per-fold results
# in reconstruction_accs.
for (target, groups_dict) in plan.items():
    print(f"\nRunning target {target}:")
    target_info = targets[target]
    # The specimens matching the target's query form the evaluation set.
    specimen_ids = met_data.query(**target_info["query"])["specimen_id"]
    loss_config = {"encoder_cross_grad": False, "device": "cpu", **target_info["loss"]}
    loss_func = ReconstructionLoss(loss_config, met_data, specimen_ids)
    trans_funcs = {form: get_transformation_function(params) for (form, params) in target_info["loss"]["transform"].items()}
    # setdefault keeps results already stored for targets/groups that are not re-run.
    target_accs = reconstruction_accs.setdefault(target, {})
    for (group, experiments) in groups_dict.items():
        group_accs = target_accs.setdefault(group, {})
        # NOTE: get_experiment_reconstructions picks up the loop variable `exp_name`
        # as a global for its progress message, so renaming it affects that message.
        for exp_name in experiments:
            exp_dict = all_experiments[group][exp_name]
            folds = get_experiment_reconstructions(exp_dict, met_data, specimen_ids, trans_funcs, target_info["modalities"])
            # One score dictionary per fold.
            group_accs[exp_name] = [get_reconstruction_scores(fold_dict, loss_func, met_data, trans_funcs, target_info["display"]) 
                                    for fold_dict in folds]
print("\nComplete                                                       ")


Running target patchseq:
Generating met-smartseq - 10: E -> E                          
Running target EM:
Generating met-smartseq - 10: M -> M                          
Running target smartseq:
Generating met-smartseq - 10: T -> T                          
Running target patch_binary:
Generating t_m_arms_patch-binary - 7: T -> T             
Running target smart_binary:
Generating t_m_arms_patch-binary - 7: T -> T             
Running target multimodal_binary:
Generating t_m_arms_patch-binary - 7: T -> T             
Running target ivscc-patchseq:
Generating met_full-ivscc - 10: E -> E                
Running target ivscc-EM:


  variances[form] = torch.from_numpy(np.nanvar(data, 0)).to(device, dtype)


Generating met_full-ivscc - 10: M -> M                
Running target ivscc-dual-patchseq:
Generating met_dual-dual - 10: E -> E               
Running target ivscc-dual-EM:
Generating met_dual-dual - 10: M -> M               
Running target patchseq-with-ivscc:
Generating met-patch - 10: E -> E                  
Complete                                                       


In [33]:
def _fold_statistics(folds):
    """Return the mean and standard deviation across folds of every modal string's score."""
    modal_strings = [key for fold_dict in folds for key in fold_dict]
    per_string = {string: [fold[string] for fold in folds] for string in modal_strings}
    return {"mean": {string: np.mean(values) for (string, values) in per_string.items()},
            "std": {string: np.std(values) for (string, values) in per_string.items()}}

# Summarise the per-fold scores as {target: {group: {experiment: {"mean": ..., "std": ...}}}}.
scores = {}
for (target, target_dict) in reconstruction_accs.items():
    scores[target] = {}
    for (group, group_dict) in target_dict.items():
        scores[target][group] = {exp_name: _fold_statistics(folds)
                                 for (exp_name, folds) in group_dict.items()}

### Models trained on Patch-seq and EM data

In [34]:
# Each column is a (modal string, scoring target, column label) triple.
columns = [
    ("T->T", "patchseq", "T->T"),
    ("E->T", "patchseq", "E->T"),
    ("M->T", "patchseq", "M->T"),
    ("T->T", "smartseq", "T->T (Tasic)"),
    ("E->E", "patchseq", "E->E"),
    ("T->E", "patchseq", "T->E"),
    ("M->E", "patchseq", "M->E"),
    ("M->M", "patchseq", "M->M"),
    ("T->M", "patchseq", "T->M"),
    ("E->M", "patchseq", "E->M"),
    ("M->M", "EM", "M->M (EM)"),
]
build_table(columns, ["cca", "patch", "full", "complete"], 2)

Unnamed: 0,T->T,E->T,M->T,T->T (Tasic),E->E,T->E,M->E,M->M,T->M,E->M,M->M (EM)
t_e_cca-cca,0.16±0.02,0.14±0.02,--,-0.29±0.0,0.26±0.07,0.23±0.07,--,--,--,--,--
t_m_cca-cca,0.15±0.01,--,0.14±0.02,-0.31±0.0,--,--,--,0.24±0.03,0.22±0.03,--,0.11±0.0
e_m_cca-cca,--,--,--,--,0.2±0.07,--,0.14±0.08,0.22±0.03,--,0.13±0.04,0.03±0.02
t_arm-patch,0.37±0.01,--,--,0.03±0.01,--,--,--,--,--,--,--
e_arm-patch,--,--,--,--,0.55±0.04,--,--,--,--,--,--
m_arm-patch,--,--,--,--,--,--,--,0.71±0.02,--,--,0.63±0.03
t_e_arms-patch,0.34±0.01,0.27±0.01,--,-0.07±0.01,0.52±0.04,0.33±0.07,--,--,--,--,--
t_m_arms-patch,0.24±0.05,--,0.16±0.04,-0.17±0.04,--,--,--,0.62±0.04,0.28±0.06,--,0.53±0.05
e_m_arms-patch,--,--,--,--,0.34±0.06,--,0.1±0.11,0.51±0.04,--,0.09±0.1,0.35±0.07
met-patch,0.24±0.02,0.2±0.02,0.13±0.02,-0.16±0.02,0.34±0.06,0.31±0.07,0.16±0.08,0.37±0.05,0.3±0.04,0.22±0.04,0.24±0.05


### Impact of Gradient Freezing

In [35]:
# Each column is a (modal string, scoring target, column label) triple.
columns = [
    ("T->T", "patchseq", "T->T"),
    ("E->T", "patchseq", "E->T"),
    ("M->T", "patchseq", "M->T"),
    ("T->T", "smartseq", "T->T (Tasic)"),
    ("E->E", "patchseq", "E->E"),
    ("T->E", "patchseq", "T->E"),
    ("M->E", "patchseq", "M->E"),
    ("M->M", "patchseq", "M->M"),
    ("T->M", "patchseq", "T->M"),
    ("E->M", "patchseq", "E->M"),
    ("M->M", "EM", "M->M (EM)"),
]
build_table(columns, ["grad"], 2)

Unnamed: 0,T->T,E->T,M->T,T->T (Tasic),E->E,T->E,M->E,M->M,T->M,E->M,M->M (EM)
met_patchseq_control-grad,0.23±0.05,0.19±0.04,0.13±0.04,-0.16±0.03,0.32±0.07,0.3±0.07,0.16±0.08,0.36±0.08,0.28±0.07,0.19±0.05,0.25±0.09
met_patchseq-grad,0.28±0.03,0.23±0.02,0.15±0.02,-0.13±0.03,0.35±0.07,0.31±0.06,0.14±0.08,0.43±0.05,0.29±0.06,0.19±0.06,0.34±0.07
met_full-grad,0.3±0.01,0.24±0.01,0.15±0.02,-0.13±0.03,0.35±0.06,0.31±0.07,0.12±0.08,0.41±0.05,0.29±0.04,0.22±0.06,0.64±0.03


### Binarizing the Transcriptomic Data

In [36]:
# Each column is a (modal string, scoring target, column label) triple.
columns = [
    ("T->T", "patch_binary", "T->T"),
    ("T->T", "smart_binary", "T->T (Tasic)"),
]
build_table(columns, ["binary"], 3)

Unnamed: 0,T->T,T->T (Tasic)
t_arm_patch-binary,0.396±0.002,0.497±0.003
t_arm_smart-binary,0.692±0.035,0.329±0.001
t_arm_both-binary,0.397±0.001,0.329±0.001
met_patch-binary,0.436±0.002,0.515±0.003
t_e_arms_patch-binary,0.422±0.003,0.505±0.002
t_m_arms_patch-binary,0.443±0.005,0.528±0.005


In [37]:
# Every pairing is scored on the same target and labelled by its own modal string,
# giving (modal string, scoring target, column label) triples.
columns = [(pairing, "multimodal_binary", pairing)
           for pairing in ("T->T", "E->T", "M->T",
                           "E->E", "T->E", "M->E",
                           "M->M", "T->M", "E->M")]
build_table(columns, ["binary"], 2)

Unnamed: 0,T->T,E->T,M->T,E->E,T->E,M->E,M->M,T->M,E->M
t_arm_patch-binary,0.38±0.0,--,--,--,--,--,--,--,--
t_arm_smart-binary,0.65±0.03,--,--,--,--,--,--,--,--
t_arm_both-binary,0.38±0.0,--,--,--,--,--,--,--,--
met_patch-binary,0.42±0.0,0.43±0.0,0.45±0.01,0.28±0.04,0.26±0.04,0.12±0.03,0.37±0.04,0.28±0.04,0.21±0.04
t_e_arms_patch-binary,0.41±0.0,0.43±0.0,--,0.49±0.04,0.32±0.07,--,--,--,--
t_m_arms_patch-binary,0.43±0.01,--,0.45±0.01,--,--,--,0.52±0.17,0.24±0.1,--


### Handcrafted Morphology Features

In [38]:
# Each column is a (modal string, scoring target, column label) triple.
columns = [
    ("M->T", "ivscc-patchseq", "morph->T"),
    ("M->T", "patchseq-with-ivscc", "arbors->T"),
    ("M->T", "ivscc-dual-patchseq", "dual->T"),
    ("M->E", "ivscc-patchseq", "morph->E"),
    ("M->E", "patchseq-with-ivscc", "arbors->E"),
    ("M->E", "ivscc-dual-patchseq", "dual->E"),
    ("M->M", "ivscc-patchseq", "morph->morph"),
    ("M->M", "ivscc-EM", "morph->morph (EM)"),
]
build_table(columns, ["patch", "ivscc", "dual"], 2)

Unnamed: 0,morph->T,arbors->T,dual->T,morph->E,arbors->E,dual->E,morph->morph,morph->morph (EM)
t_m_arms-patch,--,0.15±0.03,--,--,--,--,--,--
e_m_arms-patch,--,--,--,--,0.08±0.13,--,--,--
met-patch,--,0.13±0.03,--,--,0.15±0.08,--,--,--
m_arm_ivscc-ivscc,--,--,--,--,--,--,0.58±0.03,-0.49±0.24
met_ivscc-ivscc,0.21±0.02,--,--,0.21±0.07,--,--,0.5±0.04,-1.35±0.51
m_arm_full-ivscc,--,--,--,--,--,--,0.48±0.03,0.68±0.01
met_full-ivscc,0.22±0.02,--,--,0.23±0.07,--,--,0.42±0.04,0.53±0.03
met_dual-dual,--,--,0.21±0.02,--,--,0.21±0.08,--,--
