In [None]:
import sys
sys.path.insert(0,'/home/cane/Documents/yoseflab/can/resolVI')
from scvi.external import RESOLVI

In [None]:
import pyro

In [None]:
import scanpy as sc
import pandas as pd
import numpy as np
import scvi
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
scvi.settings.seed = 0
sc.set_figure_params(dpi=100, dpi_save=300, format='png', frameon=False, vector_friendly=True, fontsize=14, color_map='viridis', figsize=None)
sc.settings.figdir = 'figure2'
plt.rcParams['pdf.fonttype'] = 'truetype'
plt.rcParams['svg.fonttype'] = 'none'
plt.rcParams['pdf.use14corefonts'] = True 

In [None]:
sys.path.append('.')
import _utils

## Load datasets

In [None]:
adata = sc.read_h5ad('xenium_brain/original_high_lr/complete_adata.h5ad')

In [None]:
cdata = sc.read_h5ad('xenium_brain/original_semisupervised/complete_adata.h5ad')

In [None]:
resolvi = RESOLVI.load('xenium_brain/original_high_lr/resolvae')
resolvi_semisupervised = RESOLVI.load('xenium_brain/original_semisupervised/resolvae')

In [None]:
_ = resolvi.history['elbo_train'].plot()
_ = resolvi_semisupervised.history['elbo_train'].plot()
plt.show()

In [None]:
# Robustness check: downsample every cell to 20 raw counts and embed the
# downsampled data with the already-trained resolVI model.
bdata = adata.copy()
bdata.X = bdata.layers['raw_counts']
sc.pp.downsample_counts(bdata, counts_per_cell=20, random_state=0)
bdata.layers['raw_counts'] = bdata.X

# The latent must be stored under the exact key that compute_umap_embedding
# reads below (the suffix reflects counts_per_cell=20); the previous "_30"
# suffix left the "_20" key undefined.
bdata.obsm["x_resolVI_downsampled_20"] = resolvi.get_latent_representation(adata=bdata)
_utils.compute_umap_embedding(bdata, representation_key="x_resolVI_downsampled_20", n_comps=None, show=True, key='resolvi_latent_downsampled20', n_neighbors=20)

In [None]:
_utils.compute_umap_embedding(adata, representation_key="X_resolVI", n_comps=None, show=True, key='resolvi_latent', n_neighbors=20)
_utils.compute_umap_embedding(adata, representation_key="counts", show=True, key='raw_counts', n_neighbors=20)
_utils.compute_umap_embedding(adata, representation_key="generated_expression", show=True, key='resolvi_generated', n_neighbors=20)
_utils.compute_umap_embedding(adata, representation_key="corrected_counts", show=True, key='resolvi_corrected', n_neighbors=20)

In [None]:
_utils.compute_umap_embedding(cdata, representation_key="corrected_counts", show=True, key='resolvi_corrected', n_neighbors=20)
_utils.compute_umap_embedding(cdata, representation_key="generated_expression", show=True, key='resolvi_generated', n_neighbors=20)

## SCIB metrics

In [None]:
pd.options.display.max_columns = None

In [None]:
import scvi

scvi.model.SCVI.setup_anndata(adata, layer="raw_counts")
vae = scvi.model.SCVI(adata, gene_likelihood="nb", n_layers=2, n_latent=10)
vae.train()
adata.obsm["scVI"] = vae.get_latent_representation()

In [None]:
lvae = scvi.model.SCANVI.from_scvi_model(
    vae,
    adata=adata,
    labels_key="predicted_celltype",
    unlabeled_category="Unknown",
)
lvae.train(max_epochs=20, n_samples_per_label=100)
adata.obsm["scANVI"] = lvae.get_latent_representation()

In [None]:
import scib_metrics
from scib_metrics.benchmark import Benchmarker

In [None]:
adata.obsm['Unintegrated'] = adata.obsm['X_pca_raw_counts']
adata.obsm['resolVI'] = adata.obsm['X_resolVI']
adata.obsm['resolVI Supervised'] = cdata.obsm['X_resolVI']
adata.obsm['Generated Expression'] = adata.obsm['X_pca_resolvi_generated']
adata.obsm['Generated Expression Supervised'] = cdata.obsm['X_pca_resolvi_generated']

In [None]:
adata = sc.read('xenium_brain/full_data_concatenated.h5ad')

In [None]:
import scib_metrics
batch_correction = scib_metrics.benchmark._core.BatchCorrection(pcr_comparison=False)

In [None]:
# Binary split on `true_proportion` > 0.8, stored as the strings 'True'/'False';
# this column is used as the batch key for the scIB benchmark below.
# The former x_centroid-based split was overwritten immediately and never used,
# so it has been dropped.
adata.obs['diffusion_cells'] = (adata.obs['true_proportion'] > 0.8).astype(str)

In [None]:
sc.pl.spatial(adata, spot_size=30, color='diffusion_cells')

In [None]:
bm = Benchmarker(
    adata,
    batch_key="diffusion_cells",
    label_key="predicted_celltype",
    embedding_obsm_keys=["Unintegrated", "Generated Expression", "resolVI", "scVI", "scANVI", "resolVI Supervised", "Generated Expression Supervised"],
    pre_integrated_embedding_obsm_key='Unintegrated',
    n_jobs=12,
)
bm.benchmark()

In [None]:
# Remove the PCR-comparison row from the benchmark results. errors='ignore'
# keeps this cell idempotent: re-running it (or running it when the metric was
# never computed) no longer raises a KeyError.
bm._results = bm._results.drop('pcr_comparison', axis=0, errors='ignore')

In [None]:
bm.get_results(min_max_scale=False).to_csv('xenium_brain/scib_results_all.csv')

In [None]:
from contextlib import contextmanager

@contextmanager
def default_rcparams():
    """Temporarily reset matplotlib rcParams to their defaults.

    On entry the current rcParams are saved and ``plt.rcdefaults()`` is applied.
    On exit the saved values are restored and ``svg.fonttype`` is forced to
    ``'none'``. The restore runs in a ``finally`` clause so that an exception
    raised inside the ``with`` body does not leave the session with reset
    rcParams.
    """
    saved_params = plt.rcParams.copy()  # Snapshot of the current settings
    plt.rcdefaults()  # Reset all rcParams to their defaults
    try:
        yield
    finally:
        plt.rcParams.update(saved_params)  # Restore the snapshot
        plt.rcParams['svg.fonttype'] = 'none'

# Example usage
with default_rcparams():
    bm.plot_results_table(min_max_scale=False, save_dir='figure2')

In [None]:
bm.get_results().to_csv('scib_results_filtered.csv')

In [None]:
sc.pl.umap(adata, color=['diffusion_proportion', 'background_proportion'], ncols=1, size=1, save='semisupervised_proportions.pdf')
sc.pl.umap(adata[adata.obs['true_proportion']>0.8], color=['predicted_celltype'], ncols=1, size=1, save='semisupervised_low_diffusion.pdf')
sc.pl.umap(adata[adata.obs['true_proportion']<0.8], color=['predicted_celltype'], ncols=1, size=1, save='semisupervised_high_diffusion.pdf')

## Prediction

In [None]:
from pynndescent import PyNNDescentTransformer
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline


def knn_label_transfer(adata, obsm_key, labels_key="predicted_celltype", n_neighbors=15):
    """Predict a label for every cell with a kNN classifier fit on one embedding.

    The classifier is trained on all cells whose ``adata.obs[labels_key]`` is
    not NaN, using ``adata.obsm[obsm_key]`` as features, and is then applied
    to every cell in ``adata``.

    Parameters
    ----------
    adata
        AnnData object holding the embedding and the labels.
    obsm_key
        Key in ``adata.obsm`` with the embedding used as features.
    labels_key
        Column in ``adata.obs`` with the training labels.
    n_neighbors
        Number of neighbours computed by the PyNNDescent transformer.

    Returns
    -------
    Array with one predicted label per cell of ``adata``.
    """
    labeled = ~pd.isna(adata.obs[labels_key])
    train_X = adata[labeled].obsm[obsm_key]
    train_Y = adata[labeled].obs[labels_key].to_numpy()
    # Approximate neighbour graph first, then a classifier that consumes the
    # precomputed distances.
    knn = make_pipeline(
        PyNNDescentTransformer(
            n_neighbors=n_neighbors,
            parallel_batch_queries=True,
        ),
        KNeighborsClassifier(metric="precomputed", weights="uniform"),
    )
    knn.fit(train_X, train_Y)
    return knn.predict(adata.obsm[obsm_key])


# One kNN prediction per embedding: (obsm key, obs column that receives the result).
for obsm_key, obs_column in [
    ("scVI", "predicted_celltypes_scvi"),
    ("resolVI", "predicted_celltypes"),
    ("resolVI Supervised", "predicted_celltypes_supervised"),
    ("scANVI", "predicted_celltypes_scanvi"),
]:
    adata.obs[obs_column] = knn_label_transfer(adata, obsm_key)

In [None]:
resolvi_semisupervised = RESOLVI.load('xenium_brain/original_semisupervised/resolvae')

In [None]:
def predict(
    model,
    adata = None,
    indices = None,
    soft: bool = False,
    batch_size: int | None = 500,
    num_samples: int | None = 30
) -> np.ndarray | pd.DataFrame:
    adata = model._validate_anndata(adata)

    if indices is None:
        indices = np.arange(adata.n_obs)

    sampled_prediction = model.sample_posterior_predictive(
        adata=adata,
        indices=indices,
        model=model.module.model_corrected,
        return_sites=['probs_prediction'],
        num_samples=num_samples,
        return_samples=False,
        batch_size=batch_size,
        batch_steps=10
    )
    y_pred = sampled_prediction['post_sample_means']['probs_prediction']

    if not soft:
        y_pred = y_pred.argmax(axis=1)
        predictions = [model._code_to_label[p] for p in y_pred]
        return np.array(predictions)
    else:
        n_labels = len(y_pred[0])
        predictions = pd.DataFrame(
            y_pred,
            columns=model._label_mapping[:n_labels],
            index=adata.obs_names[indices],
        )
        return predictions

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score

# Predictions come back in the order of the model's registered AnnData.
# NOTE(review): assumes that order matches the rows of `adata` - confirm.
adata.obs['celltype_predicted'] = predict(resolvi_semisupervised)
subset = adata[~adata.obs['predicted_celltype'].isin(['Junk', 'Endothelial-Astrocyte', 'Thalamus_Glia', 'Thalamus_Oligodendrocyte'])].copy()

# Create a DataFrame from 'cluster' and 'celltype_predicted' columns
df = pd.DataFrame({
    'cluster': subset.obs['cluster'],
    'celltype_predicted': subset.obs['celltype_predicted']
})

# Per-cluster accuracy = fraction of cells whose prediction equals the cluster
# label. Computed as a grouped mean instead of groupby().apply(), which reads
# the grouping column inside the lambda (deprecated in pandas). Both columns
# are cast to str so categoricals with different categories compare cleanly.
is_correct = df['cluster'].astype(str) == df['celltype_predicted'].astype(str)
accuracy = is_correct.groupby(df['cluster']).mean()

# Plot accuracy for each cluster
plt.figure(figsize=(10, 6))
accuracy.plot(kind='bar')
plt.ylabel('Accuracy')
plt.ylim(0., 1)
plt.title('Accuracy for each label in cluster')
plt.savefig('figure2/accuracy_resolvi_prediction.pdf')
plt.show()

In [None]:
from sklearn.metrics import f1_score

In [None]:
f1_score(df['cluster'], df['celltype_predicted'], average='macro')

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score

subset.obs['celltype_predicted_scanvi'] = lvae.predict(adata=subset)

# Create a DataFrame from 'cluster' and the scANVI prediction column
df2 = pd.DataFrame({
    'cluster': subset.obs['cluster'],
    'celltype_predicted': subset.obs['celltype_predicted_scanvi']
})

# Per-cluster accuracy = fraction of cells whose prediction equals the cluster
# label. Computed as a grouped mean instead of groupby().apply(), which reads
# the grouping column inside the lambda (deprecated in pandas). Both columns
# are cast to str so categoricals with different categories compare cleanly.
is_correct = df2['cluster'].astype(str) == df2['celltype_predicted'].astype(str)
accuracy = is_correct.groupby(df2['cluster']).mean()

# Plot accuracy for each cluster
plt.figure(figsize=(10, 6))
accuracy.plot(kind='bar')
plt.ylabel('Accuracy')
plt.ylim(0., 1)
plt.title('Accuracy for each label in cluster')
plt.savefig('figure2/accuracy_scanvi_prediction.pdf')
plt.show()

In [None]:
adata.write_h5ad('xenium_brain/full_data_concatenated.h5ad')

## Spatial display

In [None]:
adata = sc.read_h5ad('xenium_brain/full_data_concatenated.h5ad')
sc.pl.spatial(adata, color='cluster', spot_size=20, sort_order=False, save='spatial_celltypes.pdf')

In [None]:
with plt.rc_context({"figure.figsize": (8, 8), "figure.dpi": (300)}):
    sc.pl.spatial(adata, color='celltype_predicted', spot_size=30, sort_order=False, palette=sc.plotting.palettes.default_102, )

In [None]:
resolvi = RESOLVI.load('xenium_brain/original_high_lr/resolvae')

In [None]:
samples_res = resolvi.sample_posterior_predictive(
    model=resolvi.module.model_residuals,
    return_sites=['px_rate', 'obs'],
    num_samples=10, return_samples=False, batch_size=2000, batch_steps=20)

In [None]:
samples_corr = resolvi.sample_posterior_predictive(
    model=resolvi.module.model_corrected,
    return_sites=['px_rate', 'obs'],
    num_samples=10, return_samples=False, batch_size=5000, batch_steps=10)
samples_corr = pd.DataFrame(samples_corr).T

In [None]:
adata.layers['generated_expression_small'] = samples_corr.loc['post_sample_q50', 'obs']

In [None]:
sc.pp.normalize_total(adata, layers=['generated_expression_small', 'generated_expression', 'corrected_counts', 'raw_counts'])

In [None]:
import matplotlib as mpl
mpl.rcParams.update(mpl.rcParamsDefault)

In [None]:
with plt.rc_context({"figure.figsize": (8, 8), "figure.dpi": (300)}):
    sc.pl.spatial(adata[adata.obs['celltype_predicted']=='Microglia'],
                  color=['Gfap', 'Slc17a6', 'Slc17a7', 'Trem2', 'Aqp4', 'Pecam1', 'diffusion_proportion'], spot_size=100, layer='raw_counts', vmax='p95', sort_order=False)
    

In [None]:
with plt.rc_context({"figure.figsize": (8, 8), "figure.dpi": (300)}):
    sc.pl.spatial(adata[adata.obs['celltype_predicted']=='Microglia'],
                  color=['Gfap', 'Slc17a6', 'Slc17a7', 'Trem2', 'Aqp4', 'Pecam1', 'diffusion_proportion'], spot_size=100, layer='generated_expression', vmax='p95', sort_order=False)

In [None]:
with plt.rc_context({"figure.figsize": (8, 8), "figure.dpi": (300)}):
    sc.pl.spatial(adata,
                  color=['Slc17a6', 'Trem2', 'diffusion_proportion'], spot_size=30, layer='raw_counts', vmax=[10, 20], sort_order=False, save='spatial_all_slc17a6_trem2.pdf')

In [None]:
with plt.rc_context({"figure.figsize": (8, 8), "figure.dpi": (300)}):
    sc.pl.spatial(adata[adata.obs['celltype_predicted']=='Microglia'],
                  color=['Slc17a6', 'Trem2', 'diffusion_proportion'], spot_size=100, layer='raw_counts', vmax=[10, 20], sort_order=False, save='spatial_microglia_slc17a6_trem2.pdf')

In [None]:
with plt.rc_context({"figure.figsize": (8, 8), "figure.dpi": (300)}):
    sc.pl.spatial(adata,
                  color=['Slc17a6', 'Trem2'], spot_size=30, layer='generated_expression', vmax=[10, 20], sort_order=False, save='spatial_all_slc17a6_trem2_generated.pdf')

In [None]:
with plt.rc_context({"figure.figsize": (8, 8), "figure.dpi": (300)}):
    sc.pl.spatial(adata[adata.obs['celltype_predicted']=='Microglia'],
                  color=['Slc17a6', 'Trem2'], spot_size=100, layer='generated_expression', vmax=[10, 20], sort_order=False, save='spatial_microglia_slc17a6_trem2_generated.pdf')

In [None]:
ax = sc.pl.scatter(adata, x='Slc17a6', y='Trem2', layers='raw_counts', size=30, color='cluster', groups=['Microglia', 'Excitatory Neurons Thalamus'], show=False)
ax.set_xlim(0, 100)
ax.set_ylim(0, 40)
plt.savefig('figure2/scatter_Slc17a6_Trem2_raw.pdf')
plt.show()

In [None]:
adata.layers['generated_expression'][adata.layers['generated_expression']>70] = 70

In [None]:
sc.pl.scatter(adata, x='Slc17a6', y='Trem2', layers='generated_expression', size=30, color='cluster', groups=['Microglia', 'Excitatory Neurons Thalamus'], save='Slc17a6_Trem2_generated.pdf')

In [None]:
adata.write_h5ad('xenium_brain/full_data.h5ad')

## Single cell reference

In [None]:
adata = sc.read_h5ad('xenium_brain/full_data.h5ad')

In [None]:
import scanpy as sc

In [None]:
single_cell_reference = sc.read_loom('xenium_brain/single_cell_reference.loom')

In [None]:
single_cell_reference.var_names_make_unique()
single_cell_reference.obs_names_make_unique()
single_cell_reference = single_cell_reference[:, np.intersect1d(adata.var_names, single_cell_reference.var_names)].copy()
single_cell_reference

In [None]:
# Keep the reference counts both as a layer and as a dense, gene-labelled
# DataFrame in obsm. `.toarray()` replaces the sparse `.A` shorthand, which
# has been removed from recent SciPy releases.
single_cell_reference.layers['counts'] = single_cell_reference.X.copy()
single_cell_reference.obsm['counts'] = pd.DataFrame(single_cell_reference.layers['counts'].toarray(), columns=single_cell_reference.var_names, index=single_cell_reference.obs_names)

In [None]:
_utils.double_positive_pmm(single_cell_reference, single_cell_reference.var_names, layer_key="counts", output_dir='figure2')

In [None]:
single_cell_reference.obsm['counts'].head()

In [None]:
single_cell_reference.obsm['positive_pmm_counts']['celltype'] = single_cell_reference.obs['Class']
per_celltype_positive = single_cell_reference.obsm['positive_pmm_counts'].groupby('celltype').mean()
per_celltype_positive.drop('PeripheralGlia', inplace=True)

In [None]:
per_celltype_positive

In [None]:
celltype_gene_dict = {}

# A gene (column) is kept as a marker when exactly one cell type has a value
# above 0.05 while every other cell type has a value below 0.01.
n_celltypes = len(per_celltype_positive)
for gene in per_celltype_positive.columns:
    fractions = per_celltype_positive[gene]
    is_high = fractions > 0.05
    if is_high.sum() == 1 and (fractions < 0.01).sum() == n_celltypes - 1:
        # The single cell type above the upper threshold owns this gene.
        celltype = per_celltype_positive[is_high].index[0]
        celltype_gene_dict.setdefault(celltype, []).append(gene)

In [None]:
celltype_gene_dict

In [None]:
celltype_gene_dict = {'Vascular': ['Adgrl4', 'Cldn5', 'Emcn', 'Nostrin', 'Pln', 'Slfn5', 'Sox17'],
 'Neurons': ['Bcl11b',
  'Cabp7',
  'Cbln1',
  'Cbln4',
  'Chrm2',
  'Cntnap4',
  'Cpne4',
  'Fibcd1',
  'Gsg1l',
  'Hs3st2',
  'Lamp5',
  'Ndst4',
  'Necab1',
  'Nell1',
  'Neurod6',
  'Nwd2',
  'Plcxd3',
  'Rxfp1',
  'Satb2',
  'Slc17a6',
  'Sncg',
  'Syt2',
  'Syt6'],
 'Immune': ['Cd53', 'Ikzf1', 'Lyz2', 'Siglech', 'Spi1', 'Trem2'],
 'Oligos': ['Sema3d'],
 'Ependymal': ['Spag16', 'Trp73']}

In [None]:
marker_list = sum(celltype_gene_dict.values(), [])

In [None]:
sc.pp.normalize_total(adata, layers=['generated_expression', 'raw_counts', 'corrected_counts'])

# Dense, gene-labelled marker matrices in obsm, one per layer.
# `.toarray()` replaces the sparse `.A` shorthand, which has been removed from
# recent SciPy releases.
marker_view = adata[:, marker_list]
adata.obsm['counts'] = pd.DataFrame(marker_view.layers['raw_counts'].toarray(), columns=marker_list, index=adata.obs_names)
adata.obsm['generated_expression'] = pd.DataFrame(np.array(marker_view.layers['generated_expression'].toarray()), columns=marker_list, index=adata.obs_names)
adata.obsm['corrected_counts'] = pd.DataFrame(np.array(marker_view.layers['corrected_counts'].toarray()), columns=marker_list, index=adata.obs_names)

In [None]:
_utils.cosine_distance_celltype(single_cell_reference, celltype_gene_dict, layer_key="counts", output_dir='figure2')
plt.show()

In [None]:
_utils.cosine_distance_celltype(adata, celltype_gene_dict, layer_key="generated_expression", output_dir='figure2', vmax=0.2)
plt.show()

In [None]:
_utils.double_positive_pmm(adata, marker_list, marker_dict=celltype_gene_dict, layer_key="generated_expression", output_dir='figure2')
plt.show()

In [None]:
_utils.double_positive_pmm(adata[adata.obs['true_proportion']>0.9], marker_list, marker_dict=celltype_gene_dict, layer_key="generated_expression", output_dir='figure2', file_save='_high_wrong')
plt.show()

In [None]:
_utils.double_positive_pmm(single_cell_reference, marker_list, marker_dict=celltype_gene_dict, layer_key='counts', output_dir='figure2', file_save='_single_cell')
plt.show()

In [None]:
def _double_positive_fraction(n_positive):
    """Fraction of cells positive for both genes among cells positive for at least one.

    ``n_positive`` holds, per cell, how many of the two genes are positive
    (0, 1 or 2). Returns -0.01 as a sentinel when no cell is positive at all.
    """
    if np.sum(n_positive) > 0:
        return np.sum(n_positive == 2) / np.sum(n_positive > 0)
    return -0.01


def double_positive_boxplot(adata, gene_pairs, save_key='', show=False):
    """Plot double-positive fractions of gene pairs, binned by ``true_proportion``.

    For every bin of ``adata.obs['true_proportion']`` and every gene pair, the
    fraction of double-positive cells is computed from
    ``adata.obsm['positive_pmm_counts']`` ("Measured") and
    ``adata.obsm['positive_pmm_generated_expression']`` ("Generated"). The two
    distributions are drawn as split violins with overlaid box plots and saved
    to ``figure2/overlapping_{save_key}.pdf``.

    Parameters
    ----------
    adata
        AnnData-like object with the obs column and obsm frames named above.
    gene_pairs
        Iterable of ``(gene_x, gene_y)`` tuples; both genes must be columns of
        the two obsm frames.
    save_key
        Suffix used in the output file name.
    show
        If True, call ``plt.show()`` after saving.
    """
    # Bin edges: 0, 0.5, 0.6, ..., 1.0; each bin is labelled by its upper edge.
    bin_edges = [0] + [i/10 for i in np.arange(5, 11)]
    pair_index = pd.MultiIndex.from_tuples(gene_pairs)
    # Float dtype so the melted 'value' column is numeric rather than object.
    dp_ct_counts = pd.DataFrame(index=pair_index, columns=bin_edges[1:], dtype=float)
    dp_ct_generated = pd.DataFrame(index=pair_index, columns=bin_edges[1:], dtype=float)

    proportion = adata.obs['true_proportion']
    for lower, upper in zip(bin_edges[:-1], bin_edges[1:]):
        # The subset only depends on the bin, so it is taken once per bin
        # instead of once per gene pair.
        # NOTE(review): both bounds are strict, so cells sitting exactly on a
        # bin edge (e.g. true_proportion == 1.0) fall into no bin - confirm
        # this is intended.
        bin_cells = adata[np.logical_and(proportion > lower, proportion < upper)]
        pmm_counts = bin_cells.obsm['positive_pmm_counts']
        pmm_generated = bin_cells.obsm['positive_pmm_generated_expression']
        for gene_x, gene_y in gene_pairs:
            pair = [gene_x, gene_y]
            dp_ct_counts.loc[(gene_x, gene_y), upper] = _double_positive_fraction(pmm_counts[pair].sum(1))
            dp_ct_generated.loc[(gene_x, gene_y), upper] = _double_positive_fraction(pmm_generated[pair].sum(1))

    # Long format: one row per (gene pair, bin) with the bin in 'variable'.
    dp_ct_counts_df = dp_ct_counts.melt()
    dp_ct_generated_df = dp_ct_generated.melt()

    dp_ct_counts_df['source'] = 'Measured'
    dp_ct_generated_df['source'] = 'Generated'

    # Concatenate the dataframes
    df = pd.concat([dp_ct_counts_df, dp_ct_generated_df])

    # Violin colours: red (measured) and blue (generated), alpha=0.2
    palette = {'Measured': (1, 0, 0, 0.2), 'Generated': (0, 0, 1, 0.2)}
    # Box colours: the same translucent grey for both sources
    palette2 = {'Measured': (0.5, 0.5, 0.5, 0.2), 'Generated': (0.5, 0.5, 0.5, 0.2)}

    # Split violin plot of the two sources per bin
    plt.figure(figsize=(12, 8))
    sns.set(style='white')
    violin_parts = sns.violinplot(df, y='value', x='variable', hue='source', palette=palette, split=True, inner=None)
    for pc in violin_parts.collections:
        pc.set_alpha(0.8)

    # Overlay grey box plots on top of the violins
    sns.boxplot(df, y='value', x='variable', hue='source', width=0.6, palette=palette2, fliersize=1.5, gap=0.5)

    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.savefig(f'figure2/overlapping_{save_key}.pdf')

    if show:
        plt.show()

In [None]:
from itertools import combinations, chain

# Get all pairs across all lists
all_genes = list(chain.from_iterable(celltype_gene_dict.values()))

# Get all pairs within each list
#celltype_gene_dict.pop('Vascular')
within_pairs = {key: list(combinations(value, 2)) for key, value in celltype_gene_dict.items()}
within_pairs = sum(within_pairs.values(), [])
across_pairs = list(set(combinations(all_genes, 2)) - set(within_pairs))

In [None]:
within_ct_counts_reference = single_cell_reference.uns['double_positive_counts'].reset_index().melt(id_vars=['index'], var_name='Gene', value_name='Value')

# Set a MultiIndex with both the row and column labels
within_ct_counts_reference.set_index(['index', 'Gene'], inplace=True)
within_ct_counts_reference = within_ct_counts_reference[within_ct_counts_reference.index.get_level_values('Gene') != within_ct_counts_reference.index.get_level_values('index')]
# Display the resulting DataFrame
print(within_ct_counts_reference)

In [None]:
within_ct_counts_reference[within_ct_counts_reference>0.1].dropna(how='all')

In [None]:
subset = within_ct_counts_reference[within_ct_counts_reference>0.2].dropna(how='all')
index_list = subset.index.to_list()
within_pairs_coexpressed = index_list

In [None]:
within_pairs_coexpressed

In [None]:
within_pairs_coexpressed

In [None]:
double_positive_boxplot(adata, within_pairs_coexpressed, save_key='original_true_coexpressed', show=True)
double_positive_boxplot(adata, across_pairs, save_key='original_false_coexpressed', show=True)

In [None]:
ls xenium_brain

## Other segmentations

In [None]:
bdata = sc.read_h5ad(f'xenium_brain/original_high_lr/complete_adata.h5ad')
sc.tl.rank_genes_groups(bdata, groupby='predicted_celltype')
sc.pl.rank_genes_groups_dotplot(bdata, n_genes=3, save='original_segmentation.pdf')

In [None]:
bdata

In [None]:
cdata = sc.read_h5ad(f'xenium_brain/proseg_nucleus/complete_adata.h5ad')
cdata.uns['dendrogram_predicted_celltype'] = bdata.uns['dendrogram_predicted_celltype']
sc.tl.rank_genes_groups(cdata, groupby='predicted_celltype')
sc.pl.rank_genes_groups_dotplot(cdata, n_genes=3, save='proseg_segmentation.pdf')

In [None]:
# Repeat the microglia spatial plots and double-positive analysis for every
# segmentation variant stored under xenium_brain/<name>/complete_adata.h5ad.
microglia_panel = ['Gfap', 'Slc17a6', 'Slc17a7', 'Trem2', 'Aqp4', 'Pecam1', 'diffusion_proportion']
for i in ['original_high_lr', 'proseg_nucleus', 'original_nucleus', 'baysor_prior']:
    print(i)
    bdata = sc.read_h5ad(f'xenium_brain/{i}/complete_adata.h5ad')
    sc.pp.normalize_total(bdata, layers=['generated_expression', 'raw_counts'])
    microglia = bdata[bdata.obs['predicted_celltype']=='Microglia']
    sc.pl.spatial(microglia, save=f'_spatial_{i}_microglia_generated.pdf',
                  color=microglia_panel, spot_size=100, layer='generated_expression', vmax='p95', sort_order=False)
    sc.pl.spatial(microglia, save=f'_spatial_{i}_microglia_raw.pdf',
                  color=microglia_panel, spot_size=100, layer='raw_counts', vmax='p95', sort_order=False)
    # Dense, gene-labelled marker matrices; `.toarray()` replaces the sparse
    # `.A` shorthand, which has been removed from recent SciPy releases.
    marker_view = bdata[:, marker_list]
    bdata.obsm['counts'] = pd.DataFrame(marker_view.layers['raw_counts'].toarray(), columns=marker_list, index=bdata.obs_names)
    bdata.obsm['generated_expression'] = pd.DataFrame(np.array(marker_view.layers['generated_expression'].toarray()), columns=marker_list, index=bdata.obs_names)
    _utils.double_positive_pmm(bdata, marker_list, marker_dict=celltype_gene_dict, layer_key="generated_expression", output_dir='figure2', file_save=f'_{i}')
    double_positive_boxplot(bdata, within_pairs_coexpressed, save_key=f'{i}_true_coexpressed', show=True)
    double_positive_boxplot(bdata, across_pairs, save_key=f'{i}_false_coexpressed', show=True)

In [None]:
celltype_gene_dict_extended = {}

# Relaxed marker selection: a gene (column) is kept when exactly one cell type
# has a value above 0.1 while every other cell type has a value below 0.05.
n_celltypes = len(per_celltype_positive)
for gene in per_celltype_positive.columns:
    fractions = per_celltype_positive[gene]
    is_high = fractions > 0.1
    if is_high.sum() == 1 and (fractions < 0.05).sum() == n_celltypes - 1:
        # The single cell type above the upper threshold owns this gene.
        celltype = per_celltype_positive[is_high].index[0]
        celltype_gene_dict_extended.setdefault(celltype, []).append(gene)

In [None]:
bdata = sc.read_h5ad(f'xenium_brain/proseg_nucleus/complete_adata.h5ad')

In [None]:
import json
with open('figure2/celltype_markers_sc_ref_extended.json', 'w') as fp:
    json.dump(celltype_gene_dict_extended, fp)

In [None]:
celltype_gene_dict = {'Vascular': ['Adgrl4', 'Cldn5', 'Emcn', 'Nostrin', 'Pln', 'Slfn5', 'Sox17'],
 'Neurons': ['Bcl11b',
  'Cabp7',
  'Cbln1',
  'Cbln4',
  'Chrm2',
  'Cntnap4',
  'Cpne4',
  'Fibcd1',
  'Gsg1l',
  'Hs3st2',
  'Lamp5',
  'Ndst4',
  'Necab1',
  'Nell1',
  'Neurod6',
  'Nwd2',
  'Plcxd3',
  'Rxfp1',
  'Satb2',
  'Slc17a6',
  'Sncg',
  'Syt2',
  'Syt6'],
 'Immune': ['Cd53', 'Ikzf1', 'Lyz2', 'Siglech', 'Spi1', 'Trem2'],
 'Oligos': ['Sema3d'],
 'Ependymal': ['Spag16', 'Trp73']}

In [None]:
marker_list = sum(celltype_gene_dict.values(), [])

In [None]:
sc.pp.normalize_total(bdata, layers=['generated_expression', 'counts', 'estimated'])

# Dense, gene-labelled marker matrices; `.toarray()` replaces the sparse `.A`
# shorthand, which has been removed from recent SciPy releases.
marker_view = bdata[:, marker_list]
bdata.obsm['counts'] = pd.DataFrame(marker_view.layers['counts'].toarray(), columns=marker_list, index=bdata.obs_names)
bdata.obsm['estimated'] = pd.DataFrame(np.array(marker_view.layers['estimated'].toarray()), columns=marker_list, index=bdata.obs_names)
# layer_key now names the 'estimated' layer/obsm entry prepared above; the
# previous "estimated_expression" key is not created anywhere in this notebook.
# NOTE(review): confirm against _utils.double_positive_pmm that layer_key is
# looked up under this name.
_utils.double_positive_pmm(bdata, marker_list, marker_dict=celltype_gene_dict, layer_key="estimated", output_dir='figure2', file_save=f'proseg_corrected_rates')

In [None]:
# Dense, gene-labelled marker matrices; `.toarray()` replaces the sparse `.A`
# shorthand, which has been removed from recent SciPy releases.
marker_view = bdata[:, marker_list]
bdata.obsm['counts'] = pd.DataFrame(marker_view.layers['counts'].toarray(), columns=marker_list, index=bdata.obs_names)
bdata.obsm['generated_expression'] = pd.DataFrame(np.array(marker_view.layers['generated_expression'].toarray()), columns=marker_list, index=bdata.obs_names)
_utils.double_positive_pmm(bdata, marker_list, marker_dict=celltype_gene_dict, layer_key="generated_expression", output_dir='figure2', file_save=f'proseg_counts')