In [None]:
import os
import anndata as ad
import torch
import numpy
import argparse
import pandas as pd
import scanpy as sc
from tqdm.auto import tqdm

# Show the working directory: every relative path below (datasets/, pretrain/, run/) resolves against it.
os.getcwd()

In [None]:
# Task to analyse: 'GEX2ATAC' (multiome) or 'GEX2ADT' (CITE-seq).
TASK = 'GEX2ATAC'
DATASET_PATH = "datasets"
PRETRAIN_PATH = "pretrain"
# Derived from TASK so that switching tasks does not silently keep reading the GEX2ATAC
# predictions. Assumes predictions are saved as pretrain/default<TASK>.h5ad (the naming of
# the existing GEX2ATAC file) — confirm for GEX2ADT.
PREDICTION_PATH = os.path.join(PRETRAIN_PATH, "default" + TASK + ".h5ad")
OUT_NAME = ""

# test_path is a file-name *prefix*; the split files are "<prefix>train_mod1.h5ad", etc.
if TASK == 'GEX2ADT':
    is_multiome = False
    test_path = os.path.join(DATASET_PATH, "openproblems_bmmc_cite_phase2_rna/openproblems_bmmc_cite_phase2_rna"
                                           ".censor_dataset.output_")
    completedata_path = os.path.join(DATASET_PATH, "post_competition/openproblems_bmmc_cite_complete.h5ad")
elif TASK == 'GEX2ATAC':
    is_multiome = True
    test_path = os.path.join(DATASET_PATH, "openproblems_bmmc_multiome_phase2_rna"
                                           "/openproblems_bmmc_multiome_phase2_rna.censor_dataset.output_")
    completedata_path = os.path.join(DATASET_PATH, "post_competition/openproblems_bmmc_multiome_complete.h5ad")
else:
    raise ValueError('Unknown task: ' + TASK)

# Pretrained artefacts (LSI transformers, fold weights) live in a per-task directory.
pretrain_path = os.path.join(PRETRAIN_PATH, TASK)

par = {
    "input_train_mod1": f"{test_path}train_mod1.h5ad",
    "input_train_mod2": f"{test_path}train_mod2.h5ad",
    "input_test_mod1": f"{test_path}test_mod1.h5ad",
    "input_test_mod2": f"{test_path}test_mod2.h5ad",
    "input_complete": completedata_path,
    "input_test_sol": f"{test_path}test_sol.h5ad",
    "input_test_prediction": PREDICTION_PATH,
    "input_pretrain": pretrain_path,
    "output": os.path.join(PRETRAIN_PATH, OUT_NAME + TASK),
}

In [None]:
# Load the train mod1 (GEX) split here so this cell runs on a fresh kernel instead of
# depending on a variable left over from another session.
input_train_mod1 = ad.read_h5ad(par["input_train_mod1"])
sc.pp.neighbors(input_train_mod1)
sc.tl.umap(input_train_mod1)

In [None]:
# Train mod1 UMAP coloured by batch (colouring by 'cell_type' is left disabled here).
sc.pl.umap(adata=input_train_mod1, color='batch')

In [None]:
# Complete (post-competition) dataset with precomputed per-modality UMAP coordinates in .obsm.
input_complete = ad.read_h5ad(filename=par["input_complete"])

In [None]:
# Display the AnnData summary of the complete dataset.
input_complete

In [None]:
# Copy of the complete dataset whose UMAP basis is the precomputed ATAC UMAP.
input_atac_umap = input_complete.copy()
# "X_umap" is the canonical key sc.pl.umap reads (older scanpy accepts only this key, and could
# otherwise pick up a different UMAP already stored under X_umap). This matches how the GEX
# copy is built. The "umap" key is kept for cells that read it directly.
input_atac_umap.obsm["X_umap"] = input_complete.obsm["ATAC_umap"]
input_atac_umap.obsm["umap"] = input_complete.obsm["ATAC_umap"]

In [None]:
# ATAC UMAP coloured first by cell type, then by batch.
for color_key in ('cell_type', 'batch'):
    sc.pl.umap(input_atac_umap, color=color_key)

In [None]:
# Copy of the complete dataset whose default UMAP basis is the precomputed GEX UMAP.
gex_umap_coords = input_complete.obsm["GEX_X_umap"]
input_gex_umap = input_complete.copy()
input_gex_umap.obsm["X_umap"] = gex_umap_coords

In [None]:
# GEX UMAP coloured first by cell type, then by batch.
for color_key in ('cell_type', 'batch'):
    sc.pl.umap(input_gex_umap, color=color_key)

In [None]:
# Inspect the ATAC UMAP coordinates stored on the copy.
input_atac_umap.obsm["umap"]

In [None]:
# Inspect the GEX UMAP coordinates stored on the copy.
input_gex_umap.obsm["X_umap"]

# UMAP on train data

In [None]:
import argparse
import os
import pickle
import sys

import anndata as ad
import numpy as np
import pandas as pd
import scipy.sparse
import torch
from distutils.util import strtobool

sys.path.append(".")
from resources.data import ModalityMatchingDataset
from resources.models import Modality_CLIP, Encoder
from resources.postprocessing import OT_matching, MWB_matching
from resources.hyperparameters import *
from resources.preprocessing import harmony
from evaluate import evaluate

In [None]:
# Load the train/test splits. The train splits are used by later cells (train UMAPs, LSI
# transform, embedding inference), so they must be loaded for Restart & Run All to work.
input_train_mod1 = ad.read_h5ad(par["input_train_mod1"])
input_train_mod2 = ad.read_h5ad(par["input_train_mod2"])
input_test_mod1 = ad.read_h5ad(par["input_test_mod1"])
input_test_mod2 = ad.read_h5ad(par["input_test_mod2"])
input_complete = ad.read_h5ad(par["input_complete"])

In [None]:
# Inspect the cell identifiers of the train mod2 split.
input_train_mod2.obs_names

In [None]:
# Test split with cell-type annotations, read from the GLUE-processed PBMC files
# (instead of slicing input_complete by the test obs names).
test_withcelltype_mod1, test_withcelltype_mod2 = (
    ad.read_h5ad(h5ad_path)
    for h5ad_path in (
        "datasets/PBMC/glue_processed/test_mod1.h5ad",
        "datasets/PBMC/glue_processed/test_mod2.h5ad",
    )
)

In [None]:
# Load the hard matching matrix from run/hard_X.npy (np.load opens the file in binary mode itself).
hard_X = np.load("run/hard_X.npy")

In [None]:
# Column index of the largest entry in each row of hard_X — presumably the predicted mod2
# match for each mod1 cell (rows = mod1, columns = mod2; confirm against the matching code).
mod2_permutation = np.argmax(hard_X, axis=1)

In [None]:
# Inspect the per-row argmax indices derived from hard_X.
mod2_permutation

In [None]:
# Reorder the mod2 expression rows by the predicted matching: row i becomes row
# mod2_permutation[i] of the original matrix. Only .X is reordered; obs keeps its original
# row order (assumed to line up with the mod1 test cells — confirm both files share cell order).
# Re-running the cell would permute twice, so guard against it explicitly.
if test_withcelltype_mod2.uns.get("predicted_match_applied", False):
    raise RuntimeError(
        "The predicted permutation was already applied to test_withcelltype_mod2; "
        "re-run the cell that loads it before running this cell again."
    )
if len(mod2_permutation) != test_withcelltype_mod2.n_obs:
    raise ValueError(
        f"Permutation has {len(mod2_permutation)} entries but mod2 has "
        f"{test_withcelltype_mod2.n_obs} cells."
    )
mod2_X = test_withcelltype_mod2.X
# Accept both sparse and dense storage of .X.
mod2_X = mod2_X.toarray() if scipy.sparse.issparse(mod2_X) else np.asarray(mod2_X)
test_withcelltype_mod2.X = mod2_X[mod2_permutation]
test_withcelltype_mod2.uns["predicted_match_applied"] = True

In [None]:
# Cell names of mod2: the permutation reorders .X only, so these stay in their loaded order.
test_withcelltype_mod2.obs_names

In [None]:
# Sanity check: total signal in row 1 after the permutation; the trailing note records the
# value observed without permutation, so a different number means the rows were reordered.
test_withcelltype_mod2.X[1].sum() # 15072.235 with no permutation

In [None]:
# List the obs column names of mod2. obs_keys is a method: without the call the cell only
# displays a bound-method repr.
test_withcelltype_mod2.obs_keys()

In [None]:
# Join the GEX and ATAC feature matrices side by side (axis=1 = along variables) for the same
# cells; merge/uns_merge='first' keep the first value seen for each annotation.
# NOTE(review): `keys` only takes effect together with `label` or `index_unique`, neither of
# which is set here.
test_withcelltype = ad.concat(
    [test_withcelltype_mod1, test_withcelltype_mod2],
    axis=1,
    keys=['GEX', 'ATAC'],
    merge='first',
    uns_merge='first',
)
test_withcelltype

In [None]:
# Cell-type annotation carried over into the concatenated object.
test_withcelltype.obs["cell_type"]

In [None]:
# Feature names of the concatenated object (GEX features followed by ATAC features).
test_withcelltype.var_names

In [None]:
# Number of cells per cell type in the test set.
test_withcelltype.obs["cell_type"].to_frame().value_counts()

In [None]:
# Per-feature modality labels, to confirm both modalities are present after concatenation.
test_withcelltype.var["feature_types"]

In [None]:
def umap(adata, save_celltype='umap_celltype.pdf', save_batch='umap_batch.pdf'):
    """Compute a kNN graph and UMAP for ``adata``, then plot it coloured by cell type and batch.

    Parameters
    ----------
    adata : AnnData
        Modified by scanpy (neighbors/UMAP results are stored on it). Must have
        ``obs['cell_type']`` and ``obs['batch']``.
    save_celltype, save_batch : str
        Passed as scanpy's ``save=`` argument for the cell-type and batch plots respectively.
    """
    sc.pp.neighbors(adata)
    sc.tl.umap(adata, random_state=0)
    plots = (('cell_type', save_celltype), ('batch', save_batch))
    for color_key, out_file in plots:
        sc.pl.umap(adata, color=color_key, edges=False, save=out_file)

In [None]:
# UMAP of the raw concatenated GEX+ATAC features under the predicted matching.
umap(test_withcelltype, save_celltype="pbmc_predicted_match.pdf")

In [None]:
# UMAP of embedding concatenations
fold = 0
# Naming scheme used for default runs:
# emb_mod12 = torch.load("pretrain/defaultPredictedMatch" + TASK + "emb_mod12_" + str(fold) + ".pt")
# map_location="cpu" lets a tensor saved from a GPU run load on CPU-only machines.
emb_mod12 = torch.load("pretrain/pbmc1NoEGEX2ATACemb_mod12_fold0_predmatch.pt", map_location="cpu")
# Store a plain numpy array in .obsm (AnnData/scanpy do not reliably handle torch tensors);
# as_tensor also accepts the case where a numpy array was saved.
emb_mod12 = torch.as_tensor(emb_mod12).detach().cpu().numpy()
test_withcelltype.obsm["X_pca"] = emb_mod12


In [None]:
# UMAP of the loaded embeddings (stored in obsm["X_pca"], which sc.pp.neighbors uses).
# NOTE(review): the batch-plot filename says "randmatch" while the embedding file loaded above
# is the "predmatch" one — confirm the intended output name.
umap(test_withcelltype, save_batch="randmatch_test_atacgex_emb_batch_umap.pdf", save_celltype="pbmc_emb_predmatch_noentropy.pdf")

In [None]:
# UMAP of the train mod1 (GEX) features, coloured by batch.
for compute_step in (sc.pp.neighbors, sc.tl.umap):
    compute_step(input_train_mod1)
sc.pl.umap(input_train_mod1, color='batch')

In [None]:
# UMAP of the train mod2 features, coloured by batch.
for compute_step in (sc.pp.neighbors, sc.tl.umap):
    compute_step(input_train_mod2)
sc.pl.umap(input_train_mod2, color='batch')

In [None]:
# Load the fitted LSI transformers and project train/test data. GEX is always LSI-reduced;
# mod2 is LSI-reduced for ATAC (multiome) and used as raw features otherwise (ADT).
# NOTE(review): pickle.load can execute arbitrary code — only load transformer files you produced.
with open(par["input_pretrain"] + "/lsi_GEX_transformer.pickle", "rb") as gex_fh:
    lsi_transformer_gex = pickle.load(gex_fh)

if is_multiome:
    with open(par["input_pretrain"] + "/lsi_ATAC_transformer.pickle", "rb") as atac_fh:
        lsi_transformer_atac = pickle.load(atac_fh)

gex_train = lsi_transformer_gex.transform(input_train_mod1)
gex_test = lsi_transformer_gex.transform(input_test_mod1)

if is_multiome:
    mod2_train = lsi_transformer_atac.transform(input_train_mod2)
    mod2_test = lsi_transformer_atac.transform(input_test_mod2)
else:
    mod2_train = input_train_mod2.to_df()
    mod2_test = input_test_mod2.to_df()

In [None]:
# Show which task is configured (selects the argparse sub-command below).
TASK

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)


def _str_to_bool(text):
    """Parse a command-line truth value the way distutils' strtobool did.

    Accepts y/yes/t/true/on/1 and n/no/f/false/off/0 (case-insensitive). A local helper is used
    because distutils was removed in Python 3.12.
    Raises argparse.ArgumentTypeError for anything else.
    """
    lowered = text.lower()
    if lowered in ("y", "yes", "t", "true", "on", "1"):
        return True
    if lowered in ("n", "no", "f", "false", "off", "0"):
        return False
    raise argparse.ArgumentTypeError("invalid truth value %r" % (text,))


def _arg_type(default):
    """Return the argparse ``type`` converter matching the type of ``default``.

    Booleans need a real parser: ``type=bool`` would turn the string "False" into True.
    """
    return _str_to_bool if isinstance(default, bool) else type(default)


# Define argument parsers
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest='TASK')

# Common args
for key, value in defaults_common.items():
    parser.add_argument("--" + key, default=value, type=_arg_type(value))

# GEX2ADT args
parser_GEX2ADT = subparsers.add_parser('GEX2ADT', help='train GEX2ADT model')
for key, value in defaults_GEX2ADT.items():
    parser_GEX2ADT.add_argument("--" + key, default=value, type=_arg_type(value))

# GEX2ATAC args
parser_GEX2ATAC = subparsers.add_parser('GEX2ATAC', help='train GEX2ATAC model')
for key, value in defaults_GEX2ATAC.items():
    parser_GEX2ATAC.add_argument("--" + key, default=value, type=_arg_type(value))

# Parse args: only the task sub-command is given, so every option takes its default.
args, unknown_args = parser.parse_known_args([TASK])
print("args:", args, "unknown_args:", unknown_args)

In [None]:
# Load the pretrained CLIP model for one fold and embed the train and test sets.
fold = 0
weight_file = os.path.join(par["input_pretrain"], str(fold), "model.best.pth")
if not os.path.exists(weight_file):
    # Fail here, at the real cause, rather than letting later cells hit a NameError on the
    # embedding variables.
    raise FileNotFoundError("Pretrained weights not found: " + weight_file)
print("Loading weights from " + weight_file)
weight = torch.load(weight_file, map_location="cpu")

# Define modality encoders
if is_multiome:
    model = Modality_CLIP(
        Encoder=Encoder,
        layers_dims=(
            [args.LAYERS_DIM_ATAC],
            [args.LAYERS_DIM_GEX0, args.LAYERS_DIM_GEX1],
        ),
        dropout_rates=(
            [args.DROPOUT_RATES_ATAC],
            [args.DROPOUT_RATES_GEX0, args.DROPOUT_RATES_GEX1],
        ),
        dim_mod1=args.N_LSI_COMPONENTS_ATAC,
        dim_mod2=args.N_LSI_COMPONENTS_GEX,
        output_dim=args.EMBEDDING_DIM,
        T=args.LOG_T,
        noise_amount=args.SFA_NOISE,
    ).to(device)
else:
    model = Modality_CLIP(
        Encoder=Encoder,
        layers_dims=(
            [args.LAYERS_DIM_ADT0, args.LAYERS_DIM_ADT1],
            [args.LAYERS_DIM_GEX0, args.LAYERS_DIM_GEX1],
        ),
        dropout_rates=(
            [args.DROPOUT_RATES_ADT0, args.DROPOUT_RATES_ADT1],
            [args.DROPOUT_RATES_GEX0, args.DROPOUT_RATES_GEX1],
        ),
        dim_mod1=args.N_LSI_COMPONENTS_ADT,
        dim_mod2=args.N_LSI_COMPONENTS_GEX,
        output_dim=args.EMBEDDING_DIM,
        T=args.LOG_T,
        noise_amount=args.SFA_NOISE,
    ).to(device)

# Load pretrained weights and switch to inference mode
model.load_state_dict(weight)
model.eval()

# Load torch datasets (shuffle=False keeps embeddings in the same row order as the inputs)
dataset_train = ModalityMatchingDataset(pd.DataFrame(gex_train), pd.DataFrame(mod2_train))
dataset_test = ModalityMatchingDataset(pd.DataFrame(gex_test), pd.DataFrame(mod2_test))
data_train = torch.utils.data.DataLoader(dataset_train, 32, shuffle=False)
data_test = torch.utils.data.DataLoader(dataset_test, 32, shuffle=False)


def _embed_loader(clip_model, loader):
    """Embed every batch of ``loader`` with ``clip_model``.

    Returns (mod1 embeddings, mod2 embeddings) as CPU tensors, batches concatenated in loader
    order. Runs under torch.no_grad() since this is inference only.
    """
    emb_mod1, emb_mod2 = [], []
    with torch.no_grad():
        for batch in loader:
            x1 = batch["features_first"].float()
            x2 = batch["features_second"].float()
            # The model applies the GEX encoder to the second argument, here x1
            _, features_mod2, features_mod1 = clip_model(x2.to(device), x1.to(device))
            emb_mod1.append(features_mod1.cpu())
            emb_mod2.append(features_mod2.cpu())
    return torch.cat(emb_mod1), torch.cat(emb_mod2)


# Predict on train set (with a progress bar) and on test set
all_emb_mod1_train, all_emb_mod2_train = _embed_loader(model, tqdm(data_train))
all_emb_mod1, all_emb_mod2 = _embed_loader(model, data_test)


In [None]:
# Shape of the train mod2 embeddings (the per-batch outputs were already concatenated in the
# inference cell).
all_emb_mod2_train.shape

In [None]:
# Shape of the train mod1 embeddings.
all_emb_mod1_train.shape

In [None]:
# Display the train mod1 AnnData (includes any UMAP results added by earlier cells).
input_train_mod1

In [None]:
# Fresh copy of train mod1 from disk, so UMAP results computed earlier are not carried over.
input_train_mod1_emb = ad.read_h5ad(filename=par["input_train_mod1"])

In [None]:
# Store the train mod1 embeddings as the representation sc.pp.neighbors will use. They are
# converted to numpy because AnnData/scanpy do not reliably handle torch tensors in .obsm
# (the tensors are already detached CPU tensors).
input_train_mod1_emb.obsm["X_pca"] = np.asarray(all_emb_mod1_train)

In [None]:
# UMAP of the train mod1 embeddings (held in obsm["X_pca"]), coloured by batch.
emb_adata = input_train_mod1_emb
sc.pp.neighbors(emb_adata)
sc.tl.umap(emb_adata)
sc.pl.umap(emb_adata, color='batch')

In [None]:
# Fresh copy of train mod2 from disk, so UMAP results computed earlier are not carried over.
input_train_mod2_emb = ad.read_h5ad(par["input_train_mod2"])
# numpy instead of a torch tensor, which AnnData/scanpy do not reliably handle in .obsm.
input_train_mod2_emb.obsm["X_pca"] = np.asarray(all_emb_mod2_train)
# UMAP of the train mod2 embeddings (scanpy's neighbors uses obsm["X_pca"] when present)
sc.pp.neighbors(input_train_mod2_emb)
sc.tl.umap(input_train_mod2_emb)
sc.pl.umap(input_train_mod2_emb, color='batch')

In [None]:
# Fresh copies from disk, so UMAP results computed earlier are not carried over.
input_train_mod1_emb = ad.read_h5ad(par["input_train_mod1"])
input_train_mod2_emb = ad.read_h5ad(par["input_train_mod2"])
# numpy instead of torch tensors, which AnnData/scanpy do not reliably handle in .obsm.
input_train_mod1_emb.obsm["X_pca"] = np.asarray(all_emb_mod1_train)
input_train_mod2_emb.obsm["X_pca"] = np.asarray(all_emb_mod2_train)
# Stack mod1 rows first, then mod2 rows, so both modalities share one UMAP space.
# The two splits presumably describe the same (paired) cells, so obs names would collide:
# index_unique makes them unique, and label/keys record each row's modality in obs["modality"].
input_train_mod1mod2_emb = ad.concat(
    (input_train_mod1_emb, input_train_mod2_emb),
    join='outer',
    label="modality",
    keys=["GEX", "ATAC" if is_multiome else "ADT"],
    index_unique="-",
)


In [None]:
# Shape of the stacked embeddings: rows should equal n_obs(mod1) + n_obs(mod2).
input_train_mod1mod2_emb.obsm["X_pca"].shape

In [None]:
# UMAP of the joint train embedding space, then plotted separately per modality.
sc.pp.neighbors(input_train_mod1mod2_emb)
sc.tl.umap(input_train_mod1mod2_emb)
# mod1 rows come first in the concatenation, so the split point is the mod1 cell count.
# Deriving it from the data replaces a hard-coded 42492 that only fits one dataset.
n_mod1_cells = input_train_mod1_emb.n_obs
sc.pl.umap(input_train_mod1mod2_emb[:n_mod1_cells], color='batch')
sc.pl.umap(input_train_mod1mod2_emb[n_mod1_cells:], color='batch')

In [None]:
# Joint UMAP with both modalities overlaid, coloured by batch.
sc.pl.umap(input_train_mod1mod2_emb, color='batch')

In [None]:
# obsm["X_pca"] holds the CLIP embeddings here (not principal components), so this plots
# their first two dimensions, coloured by batch.
sc.pl.pca(input_train_mod1mod2_emb, color='batch')

In [None]:
# First two embedding dimensions (stored in obsm["X_pca"]) plotted per modality. mod1 rows come
# first in the concatenation, so the split point is the mod1 cell count rather than a
# hard-coded 42492.
n_mod1_cells = input_train_mod1_emb.n_obs
sc.pl.pca(input_train_mod1mod2_emb[:n_mod1_cells], color='batch')
sc.pl.pca(input_train_mod1mod2_emb[n_mod1_cells:], color='batch')