In [None]:
import torch
import scanpy as sc
import random
import numpy as np
import umap
import pandas as pd
from anndata import AnnData
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Multiome mapping/imputing result: fit a 2-D UMAP on the latent vectors of the
# cells whose Modality is 'multiome', then project every cell into that layout.
adata = sc.read('./results/neurips-multiome-mappingandimputing.h5ad')

multiome_idx = adata.obs['Modality'] == 'multiome'
# Fail early with a readable message instead of an opaque error inside UMAP.
assert multiome_idx.any(), "no cells with Modality == 'multiome' to fit the UMAP on"
X_multiome = adata[multiome_idx].obsm['latent']

# Seed the reducer so the layout (and therefore the figures) is reproducible
# across kernel restarts.
umap_model = umap.UMAP(min_dist=0.2, random_state=0)
X_umap_multiome = umap_model.fit_transform(X_multiome)

# Project all cells (reference and non-reference alike) with the fitted model and
# store the coordinates where sc.pl.umap looks for them.
X_all = adata.obsm['latent']
X_all_umap = umap_model.transform(X_all)
adata.obsm['X_umap'] = X_all_umap

palette = {'multiome': 'lightgray', 'rna': 'red', 'atac': 'blue'}

# Legend settings shared by the three panels below.
legend_kwargs = dict(legend_loc='right margin', legend_fontsize=12, legend_fontoutline=2)

for obs_key in ('cell_type', 'batch'):
    sc.pl.umap(adata, color=obs_key, **legend_kwargs)
sc.pl.umap(adata, color='Modality', palette=palette, **legend_kwargs)

In [None]:
# NK result file: neighbour graph + UMAP on the latent space, then two views:
# all cell types, and NK cells highlighted against everything else.
adata = sc.read('./results/neurips-multiome-NK.h5ad')
sc.pp.neighbors(adata, use_rep='latent')
sc.tl.umap(adata, min_dist=0.2)

# Two-level label: 'NK' where cell_type is exactly 'NK', 'Other' elsewhere.
is_nk = adata.obs['cell_type'] == 'NK'
adata.obs['NK_highlight'] = np.where(is_nk, 'NK', 'Other')

sc.pl.umap(adata, color='cell_type', legend_loc=None)

highlight_style = {
    'palette': {'NK': 'red', 'Other': 'lightgray'},
    'legend_loc': 'right margin',
    'legend_fontsize': 12,
    'legend_fontoutline': 2,
}
sc.pl.umap(adata, color='NK_highlight', **highlight_style)

In [None]:
# Lymph prog result file: neighbour graph + UMAP on the latent space, then two
# views: all cell types, and 'Lymph prog' cells highlighted against the rest.
adata = sc.read('./results/neurips-multiome-Lymphprog.h5ad')
sc.pp.neighbors(adata, use_rep='latent')
sc.tl.umap(adata, min_dist=0.2)

# Two-level label: 'Lymph prog' where cell_type matches exactly, 'Other' elsewhere.
is_lymph_prog = adata.obs['cell_type'] == 'Lymph prog'
adata.obs['Lymphprog_highlight'] = np.where(is_lymph_prog, 'Lymph prog', 'Other')

sc.pl.umap(adata, color='cell_type', legend_loc=None)

highlight_style = {
    'palette': {'Lymph prog': 'red', 'Other': 'lightgray'},
    'legend_loc': 'right margin',
    'legend_fontsize': 12,
    'legend_fontoutline': 2,
}
sc.pl.umap(adata, color='Lymphprog_highlight', **highlight_style)

In [None]:
# CITE mapping/imputing result: fit a 2-D UMAP on the latent vectors of the cells
# whose Modality is 'cite', then project every cell into that layout.
adata = sc.read('./results/neurips-cite-mappingandimputing.h5ad')

cite_idx = adata.obs['Modality'] == 'cite'
# Fail early with a readable message instead of an opaque error inside UMAP.
assert cite_idx.any(), "no cells with Modality == 'cite' to fit the UMAP on"
X_cite = adata[cite_idx].obsm['latent']

# Seed the reducer so the layout (and therefore the figures) is reproducible
# across kernel restarts.
umap_model = umap.UMAP(min_dist=0.2, random_state=0)
X_umap_cite = umap_model.fit_transform(X_cite)

# Project all cells with the fitted model and store the coordinates where
# sc.pl.umap looks for them.
X_all = adata.obsm['latent']
X_all_umap = umap_model.transform(X_all)
adata.obsm['X_umap'] = X_all_umap

palette = {
    'cite': 'lightgray',
    'rna': 'red',
    'adt': 'blue'
}

# Legend settings shared by the three panels below.
legend_kwargs = dict(legend_loc='right margin', legend_fontsize=12, legend_fontoutline=2)

for obs_key in ('cell_type', 'batch'):
    sc.pl.umap(adata, color=obs_key, **legend_kwargs)
sc.pl.umap(adata, color='Modality', palette=palette, **legend_kwargs)

In [None]:
# CD8+ T naive result file: neighbour graph + UMAP on the latent space, then two
# views: all cell types, and 'CD8+ T naive' cells highlighted against the rest.
adata = sc.read('./results/neurips-cite-CD8+Tnaive.h5ad')
sc.pp.neighbors(adata, use_rep='latent')
sc.tl.umap(adata, min_dist=0.2)

# Two-level label: 'CD8+ T naive' where cell_type matches exactly, 'Other' elsewhere.
is_cd8_naive = adata.obs['cell_type'] == 'CD8+ T naive'
adata.obs['CD8+ T naive_highlight'] = np.where(is_cd8_naive, 'CD8+ T naive', 'Other')

sc.pl.umap(adata, color='cell_type', legend_loc=None)

highlight_style = {
    'palette': {'CD8+ T naive': 'red', 'Other': 'lightgray'},
    'legend_loc': 'right margin',
    'legend_fontsize': 12,
    'legend_fontoutline': 2,
}
sc.pl.umap(adata, color='CD8+ T naive_highlight', **highlight_style)

In [None]:
# HSC result file: neighbour graph + UMAP on the latent space, then two views:
# all cell types, and HSC cells highlighted against everything else.
adata = sc.read('./results/neurips-cite-HSC.h5ad')
sc.pp.neighbors(adata, use_rep='latent')
sc.tl.umap(adata, min_dist=0.2)

# Two-level label: 'HSC' where cell_type is exactly 'HSC', 'Other' elsewhere.
is_hsc = adata.obs['cell_type'] == 'HSC'
adata.obs['HSC_highlight'] = np.where(is_hsc, 'HSC', 'Other')

sc.pl.umap(adata, color='cell_type', legend_loc=None)

highlight_style = {
    'palette': {'HSC': 'red', 'Other': 'lightgray'},
    'legend_loc': 'right margin',
    'legend_fontsize': 12,
    'legend_fontoutline': 2,
}
sc.pl.umap(adata, color='HSC_highlight', **highlight_style)

In [None]:
# Trimodal result file: neighbour graph + UMAP on the latent space, coloured by
# cell type, batch and modality.
adata = sc.read('./results/trimodal.h5ad')

sc.pp.neighbors(adata, use_rep='latent')
sc.tl.umap(adata, min_dist=0.2)

# One colour per value of the 'Modality' column.
palette = dict(
    rna='red',
    atac='blue',
    adt='green',
    multiome='orange',
    cite='purple',
)

# Legend settings shared by the cell-type and modality panels.
legend_kwargs = {
    'legend_loc': 'right margin',
    'legend_fontsize': 12,
    'legend_fontoutline': 2,
}

sc.pl.umap(adata, color='cell_type', **legend_kwargs)
# The batch panel is drawn without a legend.
sc.pl.umap(adata, color='batch', legend_loc=None)
sc.pl.umap(adata, color='Modality', palette=palette, **legend_kwargs)

In [None]:
# Trimodal mapping/imputing result: fit a 2-D UMAP on the latent vectors of the
# cells whose Modality is 'multiome' or 'cite', then project every cell into it.
adata = sc.read('./results/trimodal-mappingandimputing.h5ad')

idx = adata.obs['Modality'].isin(['multiome', 'cite'])
# Fail early with a readable message instead of an opaque error inside UMAP.
assert idx.any(), "no cells with Modality in {'multiome', 'cite'} to fit the UMAP on"
X_train = adata[idx].obsm['latent']

# Seed the reducer so the layout (and therefore the figures) is reproducible
# across kernel restarts.
umap_model = umap.UMAP(min_dist=0.2, random_state=0)
X_umap_train = umap_model.fit_transform(X_train)

# Project all cells with the fitted model and store the coordinates where
# sc.pl.umap looks for them.
X_all = adata.obsm['latent']
X_all_umap = umap_model.transform(X_all)
adata.obsm['X_umap'] = X_all_umap

palette = {
    'rna': 'red',
    'atac': 'blue',
    'adt': 'green',
    'multiome': 'orange',
    'cite': 'purple'
}

sc.pl.umap(adata, color='cell_type', legend_loc=None)
sc.pl.umap(
    adata,
    color='Modality',
    palette=palette,
    legend_loc='right margin',
    legend_fontsize=12,
    legend_fontoutline=2,
)

In [None]:
# Trimodal mapping/imputing result: colour the latent-space UMAP by the per-feature
# values stored under adata.obsm[<feature>], then by a coarse cell-group label.
adata = sc.read('./results/trimodal-mappingandimputing.h5ad')

# Keys read from adata.obsm below.
interest = ["CD3G", "MS4A1", "FCGR3A", "chr11-118343914-118344801", "CD20", "CD3", "CD16"]

sc.pp.neighbors(adata, use_rep='latent')
sc.tl.umap(adata, min_dist=0.2)

for g in interest:
    # An obs column must be 1-D, whereas obsm entries are documented as 2-D or
    # higher; flatten explicitly (presumably each entry is (n_obs, 1) - TODO confirm).
    # A wider entry yields the wrong length and makes the assignment fail loudly.
    values = adata.obsm[g]
    values = values.toarray() if hasattr(values, "toarray") else np.asarray(values)
    adata.obs[f"{g}_imputed"] = np.ravel(values)

    sc.pl.umap(
        adata,
        color=f"{g}_imputed",
        legend_loc='right margin',
        legend_fontsize=12,
        legend_fontoutline=2,
        cmap='plasma',
    )

cell_type = adata.obs['cell_type']

T_cells = ['CD4+ T CD314+ CD45RA+', 'CD4+ T activated', 'CD4+ T activated integrinB7+', 'CD4+ T naive',
           'CD8+ T', 'CD8+ T CD49f+', 'CD8+ T CD57+ CD45RA+', 'CD8+ T CD57+ CD45RO+', 'CD8+ T CD69+ CD45RA+',
           'CD8+ T CD69+ CD45RO+', 'CD8+ T TIGIT+ CD45RA+', 'CD8+ T TIGIT+ CD45RO+', 'CD8+ T naive',
           'CD8+ T naive CD127+ CD26- CD101-', 'T prog cycling', 'T reg', 'dnT', 'gdT CD158b+', 'gdT TCRVD2+']

NK_cells = ['NK', 'NK CD158e1+']

B_cells = ['B1 B', 'B1 B IGKC+', 'B1 B IGKC-', 'Naive CD20+ B', 'Naive CD20+ B IGKC+',
           'Naive CD20+ B IGKC-', 'Transitional B']

Mono_cells = ['CD16+ Mono']

# Coarse group label -> fine-grained cell types belonging to it. Anything not
# listed keeps the default 'Other'.
group_members = {
    'T': T_cells,
    'NK': NK_cells,
    'B': B_cells,
    'CD16+ Mono': Mono_cells,
}

adata.obs['cell_group'] = 'Other'
for group_name, members in group_members.items():
    adata.obs.loc[cell_type.isin(members), 'cell_group'] = group_name

colors = {'T': 'red', 'NK': 'blue', 'B': 'green', 'CD16+ Mono': 'yellow', 'Other': 'lightgray'}

sc.pl.umap(adata, color='cell_group', palette=colors, legend_loc='on data', legend_fontsize=12, legend_fontoutline=2)

In [None]:
# Visualise three saved tensors alongside the trimodal AnnData:
#   E - plotted as "Joint Representations", Q - plotted as "Joint Query Vectors"
#   (both stored per cell in adata.obsm), and V - plotted as "Joint Value Vectors",
#   projected into the UMAP fitted on E.
adata = sc.read('./results/trimodal.h5ad')
t_dir = "./results"

# map_location="cpu" lets the tensors load on machines without CUDA even if they
# were saved from a GPU.
E = torch.load(f"{t_dir}/trimodal_e.pt", map_location="cpu").detach().cpu().numpy()
V = torch.load(f"{t_dir}/trimodal_v.pt", map_location="cpu").detach().cpu().numpy()
Q = torch.load(f"{t_dir}/trimodal_q.pt", map_location="cpu").detach().cpu().numpy()

# E and Q embeddings are stored in adata.obsm below, so they need one row per
# cell; check before spending time on the UMAP fits.
assert E.shape[0] == adata.n_obs, f"E has {E.shape[0]} rows, adata has {adata.n_obs} cells"
assert Q.shape[0] == adata.n_obs, f"Q has {Q.shape[0]} rows, adata has {adata.n_obs} cells"

# Not used further in this cell; kept in case later cells rely on them.
cell_type = adata.obs['cell_type'].astype(str)
batch = adata.obs['batch'].astype(str)

# Seed both reducers so the layouts are reproducible across kernel restarts.
reducer_E = umap.UMAP(min_dist=0.2, random_state=0)
E_umap = reducer_E.fit_transform(E)
# V is projected with the model fitted on E so both share one coordinate system.
V_umap = reducer_E.transform(V)

reducer_Q = umap.UMAP(min_dist=0.2, random_state=0)
Q_umap = reducer_Q.fit_transform(Q)

adata.obsm['X_umap_E'] = E_umap
adata.obsm['X_umap_Q'] = Q_umap

# Same panel twice, once per embedding; only the basis and title differ.
embedding_panels = (
    ('umap_Q', 'Joint Query Vectors colored by cell type'),
    ('umap_E', 'Joint Representations colored by cell type'),
)
for basis, title in embedding_panels:
    # show=False returns the Axes so the labels are set on that exact panel
    # rather than on whatever pyplot considers the current axes.
    ax = sc.pl.embedding(
        adata,
        basis=basis,
        color='cell_type',
        legend_loc=None,
        title=title,
        show=False
    )
    ax.set_xlabel("UMAP1", fontsize=12)
    ax.set_ylabel("UMAP2", fontsize=12)
    plt.show()

# Overlay: E as a light background cloud, V as large outlined markers on top.
fig, ax = plt.subplots(figsize=(7, 6))

ax.scatter(
    E_umap[:, 0], E_umap[:, 1],
    c="lightgray",
    s=15,
    alpha=0.4,
    linewidth=0,
    label="Joint Representations"
)

ax.scatter(
    V_umap[:, 0], V_umap[:, 1],
    c="crimson",
    s=120,
    edgecolors="black",
    linewidths=0.8,
    label="Joint Value Vectors"
)

ax.set_title("Joint Value Vectors projected onto Joint Representations", fontsize=14, weight="bold")
ax.set_xlabel("UMAP1", fontsize=12)
ax.set_ylabel("UMAP2", fontsize=12)
ax.legend(frameon=False, fontsize=10, loc="best")
sns.despine(ax=ax, trim=True)
fig.tight_layout()
plt.show()