In [None]:
# !curl https://datasets.cellxgene.cziscience.com/0401c761-2112-4f10-ae7d-6d5e04b5e1a4.h5ad -o sc_reference_liver.h5ad

In [None]:
import sys
sys.path.insert(0,'/home/cane/Documents/yoseflab/can/resolVI')
from scvi.external import RESOLVI

In [None]:
import scanpy as sc
import pandas as pd
import numpy as np
import scvi
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
sys.path.insert(0, '.')
import _utils

In [None]:
import os

# Global seed and plotting style for every figure produced by this notebook.
scvi.settings.seed = 0
sc.set_figure_params(dpi=100, dpi_save=300, format='png', frameon=False, vector_friendly=True, fontsize=14, color_map='viridis', figsize=None)
sc.settings.figdir = 'figure3_vizgen/'
# Later cells write straight into this directory (plt.savefig, json.dump, os.mkdir of
# sub-folders), none of which create a missing parent directory, so create it up front.
os.makedirs(sc.settings.figdir, exist_ok=True)

# Benchmarking

In [None]:
# Segmentation name -> AnnData holding the resolVI outputs for that segmentation.
nanostring = dict()

In [None]:
# Root directory with one sub-folder per segmentation (plus a `<name>_semisupervised` folder each).
path = '/external_data/other/resolvi_final_other_files/liver_cancer_vizgen/'
segmentations = [
    'baysor',
    'original',
    'proseg',
    'cellpose_triplez',
    'cellpose_singlez',
    'cellpose_nuclei',
]

In [None]:
# Load the unsupervised resolVI output for every segmentation and attach the latent space and
# generated expression of the matching semisupervised run to the same AnnData.
for i in segmentations:
    nanostring[i] = sc.read_h5ad(f'{path}{i}/complete_adata.h5ad')
    ad = sc.read_h5ad(f'{path}{i}_semisupervised/complete_adata.h5ad')
    # obsm / layers assignment only validates shapes, not cell identity, so make sure both runs
    # list the same cells in the same order before copying arrays across.
    if not ad.obs_names.equals(nanostring[i].obs_names):
        raise ValueError(f'{i}: semisupervised and unsupervised AnnData differ in cell names/order')
    nanostring[i].obsm['X_resolvi_semisupervised'] = ad.obsm['X_resolvi']
    nanostring[i].layers['generated_expression_semisupervised'] = ad.layers['generated_expression']

In [None]:
# Total raw transcript count per cell for every loaded segmentation. Iterates over all entries
# instead of relying on a loop variable left over from another cell.
for seg_adata in nanostring.values():
    # `.sum(1)` on a sparse matrix returns an (n, 1) matrix; flatten it to a 1-D column.
    seg_adata.obs['total_counts'] = np.asarray(seg_adata.layers['raw_counts'].sum(1)).ravel()

In [None]:
import os

In [None]:
# Reset the segmentation -> AnnData mapping before building the filtered objects.
nanostring = dict()

In [None]:
# Minimum estimated number of "true" transcripts (total_counts * true_proportion) a cell needs to be kept.
MIN_TRUE_COUNTS = 20

for key in segmentations:
    filtered_file = f'{path}{key}/complete_adata_filtered.h5ad'
    # Reuse the cached filtered object (with its UMAP embeddings) when it exists.
    if os.path.exists(filtered_file):
        nanostring[key] = sc.read_h5ad(filtered_file)
        continue
    print(key)
    adata = nanostring.get(key)
    if adata is None:
        # `nanostring` may have been reset since the raw objects were loaded; reload the
        # unfiltered output and the semisupervised run here instead of failing with KeyError.
        adata = sc.read_h5ad(f'{path}{key}/complete_adata.h5ad')
        semisupervised = sc.read_h5ad(f'{path}{key}_semisupervised/complete_adata.h5ad')
        if not semisupervised.obs_names.equals(adata.obs_names):
            raise ValueError(f'{key}: semisupervised and unsupervised AnnData differ in cell names/order')
        adata.obsm['X_resolvi_semisupervised'] = semisupervised.obsm['X_resolvi']
        adata.layers['generated_expression_semisupervised'] = semisupervised.layers['generated_expression']
        del semisupervised
    # `.sum(1)` on a sparse matrix returns an (n, 1) matrix; flatten it to a 1-D column.
    adata.obs['total_counts'] = np.asarray(adata.layers['raw_counts'].sum(1)).ravel()
    adata.obs['true_counts'] = adata.obs['total_counts'] * adata.obs['true_proportion']
    adata = adata[adata.obs['true_counts'] > MIN_TRUE_COUNTS].copy()
    nanostring[key] = adata
    # One UMAP per representation; keys become obsm['X_umap_<key>'] used by the plotting cells.
    _utils.compute_umap_embedding(adata, representation_key="X_resolvi_semisupervised", n_comps=None, show=True, key='resolvi_latent_semisupervised', n_neighbors=20, extra_save=key)
    _utils.compute_umap_embedding(adata, representation_key="X_resolVI", n_comps=None, show=True, key='resolvi_latent', n_neighbors=20, extra_save=key)
    _utils.compute_umap_embedding(adata, representation_key="raw_counts", show=True, key='raw_counts', n_neighbors=20, extra_save=key)
    _utils.compute_umap_embedding(adata, representation_key="raw_counts", show=True, key='raw_counts_harmony', n_neighbors=20, extra_save=key, batch_key='patient')
    _utils.compute_umap_embedding(adata, representation_key="generated_expression", show=True, key='resolvi_generated', n_neighbors=20, extra_save=key)
    _utils.compute_umap_embedding(adata, representation_key="corrected_counts", show=True, key='resolvi_corrected', n_neighbors=20, extra_save=key)
    adata.write_h5ad(filtered_file)

In [None]:
def plot_umap_embedding(adata, key, ax, color='cluster'):
    """Plot a previously computed UMAP on ``ax``.

    Copies ``adata.obsm['X_umap_<key>']`` into ``adata.obsm['X_umap']`` (overwriting that slot,
    which is the one ``sc.pl.umap`` draws) and plots it coloured by the ``color`` column.
    Nothing is returned; the plot is drawn on ``ax``.
    """
    stored_embedding = adata.obsm[f'X_umap_{key}']
    adata.obsm['X_umap'] = stored_embedding
    sc.pl.umap(adata, ax=ax, color=color, show=False, frameon=False)

In [None]:
# One row per segmentation, one column per (stored UMAP, colouring). Legends are removed from
# all but the two Harmony panels, so each colour legend is shown once per row.
umap_panels = [
    # (stored UMAP key, obs column to colour by, remove legend?)
    ('resolvi_latent_semisupervised', 'cluster', True),
    ('resolvi_latent_semisupervised', 'patient', True),
    ('resolvi_latent', 'cluster', True),
    ('resolvi_latent', 'patient', True),
    ('raw_counts', 'cluster', True),
    ('raw_counts', 'patient', True),
    ('raw_counts_harmony', 'cluster', False),
    ('raw_counts_harmony', 'patient', False),
]
n_rows = len(nanostring)
# squeeze=False keeps `axs` 2-D even with a single segmentation; height scales with row count.
fig, axs = plt.subplots(n_rows, len(umap_panels), figsize=(30, 20 * n_rows / 6), squeeze=False)

for row, key in enumerate(nanostring):
    print(key)
    for col, (umap_key, color, drop_legend) in enumerate(umap_panels):
        ax = axs[row, col]
        plot_umap_embedding(nanostring[key], key=umap_key, ax=ax, color=color)
        legend = ax.get_legend()
        # A panel may have no legend (e.g. continuous colouring); don't crash on None.
        if drop_legend and legend is not None:
            legend.remove()
    axs[row, 0].set_title(key)

#plt.tight_layout()
os.makedirs('figure3_vizgen', exist_ok=True)
plt.savefig('figure3_vizgen/umap_comparison.pdf')
plt.show()

In [None]:
from contextlib import contextmanager
from scib_metrics.benchmark import Benchmarker

@contextmanager
def default_rcparams():
    """Temporarily reset matplotlib rcParams to matplotlib's built-in defaults.

    The rcParams active on entry are restored on exit, including when the body of the
    ``with`` block raises.
    """
    # plt.rc_context snapshots rcParams on entry and restores them in a `finally`, so an
    # exception inside the `with` body no longer leaves the defaults in place.
    with plt.rc_context():
        plt.rcdefaults()
        yield

# Segmentations excluded from the scib benchmark.
# (Alternative criterion: skip keys whose figure3_vizgen/scib_results_{key}.csv already exists.)
benchmark_skip = {'baysor'}

for key in nanostring:
    if key in benchmark_skip:
        continue
    print(key)
    # Compare both resolVI latents with PCA of raw counts and its batch-corrected version
    # (presumably Harmony, per the key name); uncorrected raw-count PCA is the pre-integration baseline.
    bm = Benchmarker(
        nanostring[key],
        batch_key="patient",
        label_key="cluster",
        embedding_obsm_keys=["X_resolVI", "X_resolvi_semisupervised", "X_pca_raw_counts_harmony", "X_pca_raw_counts"],
        pre_integrated_embedding_obsm_key='X_pca_raw_counts',
        n_jobs=12,
    )
    bm.benchmark()
    # Drop the 'pcr_comparison' row from the results.
    # NOTE(review): `_results` is a private Benchmarker attribute and may change between scib-metrics versions.
    bm._results = bm._results.drop('pcr_comparison', axis=0)
    out_dir = f'figure3_vizgen/{key}/'
    # exist_ok instead of a bare `except: pass`, so genuine OS errors (e.g. permissions) still surface.
    os.makedirs(out_dir, exist_ok=True)
    bm.plot_results_table(min_max_scale=False, save_dir=out_dir)
    bm.get_results(min_max_scale=False).to_csv(f'figure3_vizgen/scib_results_{key}.csv')

In [None]:
import json

In [None]:
# Single-cell liver reference (CELLxGENE download) used to derive cell-type marker genes.
sc_reference = sc.read(filename='sc_reference_liver.h5ad')

In [None]:
# Use the matrix stored in .raw (presumably raw counts) as the working matrix. Copy it so later
# in-place operations on .X (e.g. normalisation) cannot also modify .raw through a shared object.
sc_reference.X = sc_reference.raw.X.copy()

In [None]:
# Keep only healthy ('normal') tissue and show how many cells carry each author label.
is_normal = sc_reference.obs['disease'] == 'normal'
sc_reference = sc_reference[is_normal]
sc_reference.obs['author_cell_type'].value_counts()

In [None]:
# Collapse the author annotations into coarse cell types; doublets, proliferating and mixed
# labels are marked 'low quality' and removed below.
COARSE_CELLTYPE_MAP = {
    'P-Hepato': 'Hepatocyte',
    'C-Hepato': 'Hepatocyte',
    'P-Hepato2': 'Hepatocyte',
    'C-Hepato2': 'Hepatocyte',
    'cvLSEC': 'Endothelial',
    'Hepato-Doublet': 'low quality',
    'Chol': 'Cholangiocyte',
    'Stellate': 'Fibroblast',
    'cvLSEC-Doublet': 'low quality',
    'ppLSEC': 'Endothelial',
    'Stellate-Doublet': 'low quality',
    'Prolif': 'low quality',
    'aStellate': 'Fibroblast',
    'Monocyte': 'Myeloid',
    'I-Hepato': 'Hepatocyte',
    'Kupffer': 'Myeloid',
    'Kupffer-Doublet': 'low quality',
    'CD4T': 'Lympho',
    'Chol-Doublet': 'low quality',
    'lrNK': 'Lympho',
    'cvEndo': 'Endothelial',
    'Tcell-Doublet': 'low quality',
    'Fibroblast': 'Fibroblast',
    'CholMucus': 'Cholangiocyte',
    'VSMC': 'Fibroblast',
    'AntiB': 'Bcell',
    'cvLSEC--Mac': 'low quality',
    'Chol--Stellate-Doublet': 'low quality',
    'Prolif-Mac': 'low quality',
    'Chol--Kupffer-Doublet': 'low quality',
}
author_labels = sc_reference.obs['author_cell_type'].astype(str)
sc_reference.obs['coarse_ct'] = author_labels.map(COARSE_CELLTYPE_MAP)
# Labels missing from the map become NaN: they pass the 'low quality' filter below but are
# dropped by pandas groupby (dropna=True default) when averaging per cell type, so report them.
unmapped_labels = sorted(set(author_labels) - set(COARSE_CELLTYPE_MAP))
if unmapped_labels:
    print(f'author_cell_type labels without a coarse mapping (coarse_ct is NaN): {unmapped_labels}')
sc_reference = sc_reference[~(sc_reference.obs['coarse_ct'] == 'low quality')]

In [None]:
# Index genes by symbol so they match the spatial panel, then keep only the shared genes.
sc_reference.var_names = sc_reference.var['feature_name'].astype(str)
sc_reference.var_names_make_unique()
sc_reference.obs_names_make_unique()
shared_genes = np.intersect1d(nanostring['baysor'].var_names, sc_reference.var['feature_name'])
sc_reference = sc_reference[:, shared_genes].copy()

In [None]:
# Store a normalised copy of the expression in a 'counts' layer and mirror it as a gene-named
# DataFrame in obsm (presumably what _utils.double_positive_pmm reads for layer_key="counts").
sc_reference.layers['counts'] = sc_reference.X.copy()
# NOTE(review): `layers=` is a deprecated scanpy argument; confirm how the installed scanpy
# version normalises .X and the listed layers.
sc.pp.normalize_total(sc_reference, layers=['counts'], target_sum=1e4)
counts_matrix = sc_reference.layers['counts']
# `.A` is deprecated/removed on scipy sparse matrices and does not exist on dense ndarrays.
counts_dense = counts_matrix.toarray() if hasattr(counts_matrix, 'toarray') else np.asarray(counts_matrix)
sc_reference.obsm['counts'] = pd.DataFrame(counts_dense, columns=sc_reference.var_names, index=sc_reference.obs_names)

In [None]:
# Positive-population (PMM) calls for every gene of the reference; written to this notebook's
# figure directory like all other outputs (previously 'figure3', inconsistent with the rest).
_utils.double_positive_pmm(sc_reference, sc_reference.var_names, layer_key="counts", output_dir='figure3_vizgen')

In [None]:
# Attach the coarse reference label and average the per-gene positive calls within each cell type.
positive_calls = sc_reference.obsm['positive_pmm_counts']
positive_calls['celltype'] = sc_reference.obs['coarse_ct']
per_celltype_positive = positive_calls.groupby('celltype').mean()

In [None]:
# A gene is a marker of cell type T when its mean positive fraction exceeds MARKER_POSITIVE_MIN
# in exactly one cell type (T) and is below MARKER_NEGATIVE_MAX in every other cell type.
MARKER_POSITIVE_MIN = 0.2
MARKER_NEGATIVE_MAX = 0.1

celltype_gene_dict = {}
for gene in per_celltype_positive.columns:
    frac = per_celltype_positive[gene]
    is_positive = frac > MARKER_POSITIVE_MIN
    if is_positive.sum() == 1 and (frac < MARKER_NEGATIVE_MAX).sum() == len(frac) - 1:
        celltype = frac.index[is_positive][0]
        celltype_gene_dict.setdefault(celltype, []).append(gene)

In [None]:
# Manual curation: drop SELL from the B-cell markers. Guarded so re-running this cell is
# idempotent instead of raising ValueError once SELL is already gone.
# (Alternatives considered: also dropping CD27 from 'Bcell', or the 'Cholangiocyte' / 'Hepatocyte' entries.)
if 'SELL' in celltype_gene_dict.get('Bcell', []):
    celltype_gene_dict['Bcell'].remove('SELL')
celltype_gene_dict

In [None]:
# Use the reference-derived markers (alternative: json.load a curated
# liver_nanostring/celltype_markers.json).
marker_dict = celltype_gene_dict
# Flat, order-preserving, de-duplicated list of all marker genes.
marker_list = list(dict.fromkeys(gene for genes in marker_dict.values() for gene in genes))
marker_dict

In [None]:
import json

# Persist the reference-derived marker dictionary so it can be reloaded without recomputation.
marker_json_path = 'figure3_vizgen/celltype_markers_sc_ref.json'
with open(marker_json_path, 'w') as handle:
    json.dump(marker_dict, handle)

In [None]:
import json

# Reload the marker dictionary from disk so downstream cells use exactly what was saved.
with open('figure3_vizgen/celltype_markers_sc_ref.json', 'r') as handle:
    marker_dict = json.load(handle)

In [None]:
# For each segmentation: normalise measured and generated expression, store the marker-gene
# slice of each as a gene-named DataFrame in obsm, then plot per-cell-type cosine distances.
expression_layers = ['counts', 'generated_expression', 'generated_expression_semisupervised']
for key in nanostring:
    adata = nanostring[key]
    # NOTE(review): `layers=` is a deprecated scanpy argument; confirm how the installed scanpy
    # version normalises .X and the listed layers.
    sc.pp.normalize_total(adata, target_sum=1e4, layers=expression_layers)
    marker_view = adata[:, marker_list]
    for layer_name in expression_layers:
        layer = marker_view.layers[layer_name]
        # `.A` is deprecated/removed on scipy sparse matrices and absent on dense arrays.
        dense = layer.toarray() if hasattr(layer, 'toarray') else np.asarray(layer)
        adata.obsm[layer_name] = pd.DataFrame(dense, columns=marker_list, index=adata.obs_names)
    for layer_key, save_suffix in (('generated_expression', ''), ('generated_expression_semisupervised', '_semisupervised')):
        _utils.cosine_distance_celltype(
            adata, marker_dict, layer_key=layer_key, output_dir='figure3_vizgen', extra_save=key + save_suffix, vmax=0.3)
        plt.show()

In [None]:
# Positive-population (PMM) calls on the marker genes for both generated-expression variants.
pmm_variants = (
    ('generated_expression', ''),
    ('generated_expression_semisupervised', '_semisupervised'),
)
for key in nanostring:
    for layer_key, save_suffix in pmm_variants:
        _utils.double_positive_pmm(
            nanostring[key],
            marker_list,
            marker_dict=marker_dict,
            layer_key=layer_key,
            output_dir='figure3_vizgen',
            file_save=key + save_suffix,
        )
    plt.show()

In [None]:
# The 'estimated' layer is only analysed for the proseg segmentation: normalise it and store the
# marker-gene slice as a gene-named DataFrame in obsm for the PMM call below.
key = 'proseg'
sc.pp.normalize_total(nanostring[key], target_sum=1e4, layers=['estimated'])
estimated_markers = nanostring[key][:, marker_list].layers['estimated']
# `.A` is deprecated/removed on scipy sparse matrices and absent on dense arrays.
estimated_markers = estimated_markers.toarray() if hasattr(estimated_markers, 'toarray') else np.asarray(estimated_markers)
nanostring[key].obsm['estimated_expression'] = pd.DataFrame(estimated_markers, columns=marker_list, index=nanostring[key].obs_names)

In [None]:
# PMM calls on the proseg 'estimated' expression (`key` is 'proseg' here).
_utils.double_positive_pmm(
    nanostring[key],
    marker_list,
    marker_dict=marker_dict,
    layer_key="estimated_expression",
    output_dir='figure3_vizgen',
    file_save=key + '_estimated',
)

In [None]:
from itertools import combinations, chain

# Every marker gene, in marker-dictionary order.
all_genes = list(chain.from_iterable(celltype_gene_dict.values()))

# Pairs of markers that belong to the same cell type.
within_pairs = [pair for genes in celltype_gene_dict.values() for pair in combinations(genes, 2)]
# Pairs of markers from different cell types. Built with an order-preserving filter rather than a
# set difference so the pair order is reproducible (string hashing is randomised per process).
within_pair_set = set(within_pairs)
across_pairs = [pair for pair in dict.fromkeys(combinations(all_genes, 2)) if pair not in within_pair_set]

In [None]:
def double_positive_boxplot(adata_dict, gene_pairs, save_key='', show=False, save_dir='figure3_vizgen'):
    """Compare marker co-positivity in measured vs. resolVI-generated expression.

    For every dataset in ``adata_dict`` and every ``(gene_x, gene_y)`` in ``gene_pairs``, compute the
    fraction of cells positive for both genes among cells positive for at least one, using the
    per-cell positive calls in ``obsm['positive_pmm_counts']`` (measured) and
    ``obsm['positive_pmm_generated_expression']`` (generated). Pairs without any positive cell are
    recorded as -0.01 (just below zero). Per-dataset distributions are drawn as split violins with
    overlaid boxplots and saved to ``{save_dir}/overlapping_{save_key}.pdf``.

    Parameters
    ----------
    adata_dict : dict
        Dataset name -> AnnData-like object; names become the x-axis categories.
    gene_pairs : sequence of (str, str)
        Gene pairs to evaluate. If empty, nothing is computed, plotted or saved.
    save_key : str
        Suffix of the output file name.
    show : bool
        If True, display the figure with ``plt.show()``; otherwise close it after saving.
    save_dir : str
        Output directory (defaults to this notebook's figure directory).

    Returns
    -------
    None
    """
    if len(gene_pairs) == 0:
        # pd.MultiIndex.from_tuples cannot infer the number of levels from an empty list.
        print(f'double_positive_boxplot: no gene pairs given, skipping {save_key!r}')
        return None

    names = list(adata_dict.keys())
    index = pd.MultiIndex.from_tuples(gene_pairs)
    # Float dtype so the melted 'value' column is numeric for seaborn.
    dp_ct_counts = pd.DataFrame(index=index, columns=names, dtype=float)
    dp_ct_generated = pd.DataFrame(index=index, columns=names, dtype=float)

    def _double_positive_fraction(n_positive):
        # n_positive: per-cell number of the two genes called positive (presumably 0, 1 or 2).
        if np.sum(n_positive) > 0:
            return np.sum(n_positive == 2) / np.sum(n_positive > 0)
        return -0.01

    for name in names:
        # Look the obsm tables up once per dataset rather than once per gene pair.
        measured_calls = adata_dict[name].obsm['positive_pmm_counts']
        generated_calls = adata_dict[name].obsm['positive_pmm_generated_expression']
        for gene_x, gene_y in gene_pairs:
            pair = [gene_x, gene_y]
            dp_ct_counts.loc[(gene_x, gene_y), name] = _double_positive_fraction(measured_calls[pair].sum(1))
            dp_ct_generated.loc[(gene_x, gene_y), name] = _double_positive_fraction(generated_calls[pair].sum(1))

    # Long format: one row per (dataset, pair) with columns 'variable' (dataset) and 'value'.
    measured_df = dp_ct_counts.melt()
    generated_df = dp_ct_generated.melt()
    measured_df['source'] = 'Measured'
    generated_df['source'] = 'Generated'
    df = pd.concat([measured_df, generated_df])

    # Violins: red = measured, blue = generated. Boxes: grey for both sources.
    violin_palette = {'Measured': (1, 0, 0, 0.2), 'Generated': (0, 0, 1, 0.2)}
    box_palette = {'Measured': (0.5, 0.5, 0.5, 0.2), 'Generated': (0.5, 0.5, 0.5, 0.2)}

    fig = plt.figure(figsize=(12, 8))
    # NOTE: sns.set changes the global matplotlib style for the rest of the session.
    sns.set(style='white')
    violin_ax = sns.violinplot(df, y='value', x='variable', hue='source', palette=violin_palette, split=True, inner=None)
    for pc in violin_ax.collections:
        pc.set_alpha(0.8)

    # Narrower grey boxplots drawn on top of the split violins.
    sns.boxplot(df, y='value', x='variable', hue='source', width=0.6, palette=box_palette, fliersize=1.5, gap=0.5)

    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.savefig(f'{save_dir}/overlapping_{save_key}.pdf')

    if show:
        plt.show()
    else:
        # Close so repeated calls do not accumulate open figures.
        plt.close(fig)
    return None

In [None]:
# Co-positivity of same-cell-type marker pairs vs. different-cell-type pairs across all
# segmentations. File suffixes are spelled out instead of depending on a `key` variable left
# over from an earlier cell (which held 'proseg' at this point).
double_positive_boxplot(nanostring, within_pairs, save_key='proseg_all', show=True)
double_positive_boxplot(nanostring, across_pairs, save_key='proseg', show=True)

In [None]:
# Reload a single filtered segmentation (index 3 = 'cellpose_triplez') for marker inspection.
path = '/external_data/other/resolvi_final_other_files/liver_cancer_vizgen/'
segmentations = ['baysor', 'original', 'proseg', 'cellpose_triplez', 'cellpose_singlez', 'cellpose_nuclei']
selected_segmentation = segmentations[3]
sub = sc.read_h5ad(path + selected_segmentation + '/complete_adata_filtered.h5ad')

In [None]:
# Dot plot of the stored rank_genes_groups result on the 'counts' layer, each gene scaled to [0, 1].
# NOTE(review): requires sub.uns['rank_genes_groups'] to exist in the cached file — confirm.
sc.pl.rank_genes_groups_dotplot(sub, standard_scale='var', layer='counts')