In [None]:
import torch
import scanpy as sc
import random
import numpy as np
import umap
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    The same integer is handed, in turn, to Python's ``random`` module,
    NumPy's global generator, and PyTorch's CPU and CUDA generators. cuDNN is
    then switched to deterministic mode with benchmarking turned off.

    Parameters
    ----------
    seed : int
        Value passed unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # cuDNN flags: deterministic on, benchmark off.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(42)

In [None]:
# Load the multiome mapping/imputation result; the cell below reads the
# 'Modality', 'cell_type' and 'batch' columns of .obs and the 'latent' entry
# of .obsm.
adata = sc.read('./results/neurips-multiome-mappingandimputing.h5ad')

# The UMAP layout is fitted only on cells whose Modality is 'multiome'.
multiome_idx = adata.obs['Modality'] == 'multiome'
X_multiome = adata[multiome_idx].obsm['latent']

# random_state pins the layout: without it umap-learn draws from whatever state
# the global NumPy RNG is in (so the picture depends on which cells ran before)
# and runs a multi-threaded optimisation that is not bit-reproducible.
# 42 is the same value this notebook passes to set_seed().
umap_model = umap.UMAP(min_dist=0.2, random_state=42)
X_umap_multiome = umap_model.fit_transform(X_multiome)

# Every cell (all modalities) is then projected into that fitted layout and the
# coordinates are stored where sc.pl.umap looks for them.
X_all = adata.obsm['latent']
X_all_umap = umap_model.transform(X_all)

adata.obsm['X_umap'] = X_all_umap

palette = {'multiome': 'lightgray', 'rna': 'red', 'atac': 'blue'}

sc.pl.umap(adata, color='cell_type', legend_loc='right margin', legend_fontsize=12, legend_fontoutline=2)
sc.pl.umap(adata, color='batch', legend_loc='right margin', legend_fontsize=12, legend_fontoutline=2)
sc.pl.umap(
    adata,
    color='Modality',
    palette=palette,
    legend_loc='right margin',
    legend_fontsize=12,
    legend_fontoutline=2
)

In [None]:
# NK result file: build the neighbour graph on the 'latent' representation and
# compute a UMAP from it.
adata = sc.read('./results/neurips-multiome-NK.h5ad')
sc.pp.neighbors(adata, use_rep='latent')
sc.tl.umap(adata, min_dist=0.2)

# Two-level label: 'NK' where cell_type equals 'NK', 'Other' everywhere else.
nk_mask = (adata.obs['cell_type'] == 'NK').to_numpy()
adata.obs['NK_highlight'] = np.where(nk_mask, 'NK', 'Other')

# Legend settings shared by all three figures of this cell.
legend_kwargs = dict(legend_loc='right margin', legend_fontsize=12, legend_fontoutline=2)

for obs_key in ('cell_type', 'batch'):
    sc.pl.umap(adata, color=obs_key, **legend_kwargs)

sc.pl.umap(
    adata,
    color='NK_highlight',
    palette={'NK': 'red', 'Other': 'lightgray'},
    **legend_kwargs,
)

In [None]:
# Lymph prog result file: build the neighbour graph on the 'latent'
# representation and compute a UMAP from it.
adata = sc.read('./results/neurips-multiome-Lymphprog.h5ad')
sc.pp.neighbors(adata, use_rep='latent')
sc.tl.umap(adata, min_dist=0.2)

# Two-level label: 'Lymph prog' where cell_type equals 'Lymph prog',
# 'Other' everywhere else.
lymphprog_mask = (adata.obs['cell_type'] == 'Lymph prog').to_numpy()
adata.obs['Lymphprog_highlight'] = np.where(lymphprog_mask, 'Lymph prog', 'Other')

# Legend settings shared by all three figures of this cell.
legend_kwargs = dict(legend_loc='right margin', legend_fontsize=12, legend_fontoutline=2)

for obs_key in ('cell_type', 'batch'):
    sc.pl.umap(adata, color=obs_key, **legend_kwargs)

sc.pl.umap(
    adata,
    color='Lymphprog_highlight',
    palette={'Lymph prog': 'red', 'Other': 'lightgray'},
    **legend_kwargs,
)

In [None]:
# Load the CITE mapping/imputation result; the cell below reads the
# 'Modality', 'cell_type' and 'batch' columns of .obs and the 'latent' entry
# of .obsm.
adata = sc.read('./results/neurips-cite-mappingandimputing.h5ad')

# The UMAP layout is fitted only on cells whose Modality is 'cite'.
cite_idx = adata.obs['Modality'] == 'cite'
X_cite = adata[cite_idx].obsm['latent']

# random_state pins the layout: without it umap-learn draws from whatever state
# the global NumPy RNG is in (so the picture depends on which cells ran before)
# and runs a multi-threaded optimisation that is not bit-reproducible.
# 42 is the same value this notebook passes to set_seed().
umap_model = umap.UMAP(min_dist=0.2, random_state=42)
X_umap_cite = umap_model.fit_transform(X_cite)

# Every cell (all modalities) is then projected into that fitted layout and the
# coordinates are stored where sc.pl.umap looks for them.
X_all = adata.obsm['latent']
X_all_umap = umap_model.transform(X_all)

adata.obsm['X_umap'] = X_all_umap

palette = {
    'cite': 'lightgray',
    'rna': 'red',
    'adt': 'blue'
}

sc.pl.umap(adata, color='cell_type', legend_loc='right margin', legend_fontsize=12, legend_fontoutline=2)
sc.pl.umap(adata, color='batch', legend_loc='right margin', legend_fontsize=12, legend_fontoutline=2)
sc.pl.umap(
    adata,
    color='Modality',
    palette=palette,
    legend_loc='right margin',
    legend_fontsize=12,
    legend_fontoutline=2
)

In [None]:
# CD16+ Mono result file: build the neighbour graph on the 'latent'
# representation and compute a UMAP from it.
adata = sc.read('./results/neurips-cite-CD16+Mono.h5ad')
sc.pp.neighbors(adata, use_rep='latent')
sc.tl.umap(adata, min_dist=0.2)

# Two-level label: 'CD16+ Mono' where cell_type equals 'CD16+ Mono',
# 'Other' everywhere else.
cd16_mono_mask = (adata.obs['cell_type'] == 'CD16+ Mono').to_numpy()
adata.obs['CD16+ Mono_highlight'] = np.where(cd16_mono_mask, 'CD16+ Mono', 'Other')

# Legend settings shared by all three figures of this cell.
legend_kwargs = dict(legend_loc='right margin', legend_fontsize=12, legend_fontoutline=2)

for obs_key in ('cell_type', 'batch'):
    sc.pl.umap(adata, color=obs_key, **legend_kwargs)

sc.pl.umap(
    adata,
    color='CD16+ Mono_highlight',
    palette={'CD16+ Mono': 'red', 'Other': 'lightgray'},
    **legend_kwargs,
)

In [None]:
# HSC result file: build the neighbour graph on the 'latent' representation
# and compute a UMAP from it.
adata = sc.read('./results/neurips-cite-HSC.h5ad')
sc.pp.neighbors(adata, use_rep='latent')
sc.tl.umap(adata, min_dist=0.2)

# Two-level label: 'HSC' where cell_type equals 'HSC', 'Other' everywhere else.
hsc_mask = (adata.obs['cell_type'] == 'HSC').to_numpy()
adata.obs['HSC_highlight'] = np.where(hsc_mask, 'HSC', 'Other')

# Legend settings shared by all three figures of this cell.
legend_kwargs = dict(legend_loc='right margin', legend_fontsize=12, legend_fontoutline=2)

for obs_key in ('cell_type', 'batch'):
    sc.pl.umap(adata, color=obs_key, **legend_kwargs)

sc.pl.umap(
    adata,
    color='HSC_highlight',
    palette={'HSC': 'red', 'Other': 'lightgray'},
    **legend_kwargs,
)

In [None]:
# Trimodal result: neighbour graph on the 'latent' representation, then UMAP.
adata = sc.read('./results/trimodal.h5ad')
sc.pp.neighbors(adata, use_rep='latent')
sc.tl.umap(adata, min_dist=0.2)

# One colour per value expected in the 'Modality' column.
palette = dict(
    rna='red',
    atac='blue',
    adt='green',
    multiome='orange',
    cite='purple',
)

# Legend settings shared by all three figures of this cell.
legend_kwargs = dict(legend_loc='right margin', legend_fontsize=12, legend_fontoutline=2)

for obs_key in ('cell_type', 'batch'):
    sc.pl.umap(adata, color=obs_key, **legend_kwargs)

sc.pl.umap(adata, color='Modality', palette=palette, **legend_kwargs)

In [None]:
# Directories holding the two saved tensors (currently the same folder).
a_dir = "./results"
v_dir = "./results"

# map_location="cpu" makes the load work on machines without CUDA even if the
# tensors were saved from a GPU; without it torch.load tries to restore them
# onto their original device and fails before .cpu() is ever reached.
E = torch.load(f"{a_dir}/trimodal_e.pt", map_location="cpu").detach().cpu().numpy()

V = torch.load(f"{v_dir}/trimodal_v.pt", map_location="cpu").detach().cpu().numpy()

# Seeded UMAP so the layout is the same on every run.
reducer = umap.UMAP(
    n_neighbors=30,
    min_dist=0.3,
    random_state=42
)

# The layout is fitted on E; V is projected into that fitted layout afterwards.
E_umap = reducer.fit_transform(E)

V_umap = reducer.transform(V)

sns.set(style="white", context="notebook")

plt.figure(figsize=(8,6))

# E points: small, pale background layer.
plt.scatter(
    E_umap[:,0], E_umap[:,1],
    s=10,
    c="lightgray",
    alpha=0.4,
    edgecolors="none",
    label="e"
)

# V points: larger outlined markers drawn on top of the E layer.
plt.scatter(
    V_umap[:,0], V_umap[:,1],
    s=60,
    c="crimson",
    alpha=0.9,
    edgecolors="black",
    linewidths=0.5,
    label="v"
)

plt.title("Trimodal VE UMAP", fontsize=16, weight='bold')
plt.xlabel("UMAP1", fontsize=12)
plt.ylabel("UMAP2", fontsize=12)
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)
plt.legend(frameon=True, loc="best", fontsize=10)
plt.grid(False)
plt.tight_layout()
plt.show()

In [None]:
# Load the trimodal mapping/imputation result; the cell below reads the
# 'Modality', 'cell_type' and 'batch' columns of .obs and the 'latent' entry
# of .obsm.
adata = sc.read('./results/trimodal-mappingandimputing.h5ad')

# The UMAP layout is fitted only on cells whose Modality is 'multiome' or 'cite'.
idx = adata.obs['Modality'].isin(['multiome', 'cite'])
X_train = adata[idx].obsm['latent']

# random_state pins the layout: without it umap-learn draws from whatever state
# the global NumPy RNG is in (so the picture depends on which cells ran before)
# and runs a multi-threaded optimisation that is not bit-reproducible.
# 42 is the same value this notebook passes to set_seed().
umap_model = umap.UMAP(min_dist=0.2, random_state=42)
X_umap_train = umap_model.fit_transform(X_train)

# Every cell (all modalities) is then projected into that fitted layout and the
# coordinates are stored where sc.pl.umap looks for them.
X_all = adata.obsm['latent']
X_all_umap = umap_model.transform(X_all)

adata.obsm['X_umap'] = X_all_umap

palette = {
    'rna': 'red',
    'atac': 'blue',
    'adt': 'green',
    'multiome': 'orange',
    'cite': 'purple'
}

sc.pl.umap(adata, color='cell_type', legend_loc='right margin', legend_fontsize=12, legend_fontoutline=2)
sc.pl.umap(adata, color='batch', legend_loc='right margin', legend_fontsize=12, legend_fontoutline=2)
sc.pl.umap(adata, color='Modality', palette=palette, legend_loc='right margin', legend_fontsize=12, legend_fontoutline=2)

In [None]:
# Trimodal mapping/imputation result: neighbour graph on the 'latent'
# representation, then UMAP.
adata = sc.read('./results/trimodal-mappingandimputing.h5ad')
sc.pp.neighbors(adata, use_rep='latent')
sc.tl.umap(adata, min_dist=0.2)

# Names looked up in adata.obsm; each gets its own figure.
interest = ["CD3G", "NCAM1", "MS4A1", "CD20", "CD3", "CD56"]

# obsm key -> name of the .obs column that will hold its values.
imputed_columns = {feature: f"{feature}_imputed" for feature in interest}

# First pass: copy every obsm entry into .obs so sc.pl.umap can colour by it.
# NOTE(review): assumes each adata.obsm[<name>] holds one value per cell in a
# shape pandas accepts as a single column — confirm against the file.
for feature, column in imputed_columns.items():
    adata.obs[column] = adata.obsm[feature]

# Second pass: one UMAP per copied column, drawn in the order of `interest`.
legend_kwargs = dict(legend_loc='right margin', legend_fontsize=12, legend_fontoutline=2)
for column in imputed_columns.values():
    sc.pl.umap(adata, color=column, cmap='plasma', **legend_kwargs)