In [None]:
import os
import shutil
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import anndata
import scanpy as sc
# Global scanpy settings.
# n_jobs: number of parallel jobs scanpy may use. It is hard-coded, presumably
# for the machine this notebook was written on; TODO adjust to available cores.
# set_figure_params: default figure styling/resolution for the sc.pl.* plots
# saved later in this notebook.
sc.settings.n_jobs = 56
sc.settings.set_figure_params(dpi=180, dpi_save=300, frameon=False, figsize=(4, 4), fontsize=8, facecolor='white')

import ALLCools
from ALLCools.integration.seurat_class import SeuratIntegration

In [None]:
# Parameters
# Root directory holding every input file read and every output file written
# by this notebook (all paths below are joined onto it).
WORKSPACE_DIR = 'integration_workspace'
workspace_path = WORKSPACE_DIR

# Load and preprocess the data

In [None]:
# Load the datasets
# The seq reference is the `_common_genes` file. The MERFISH query is read once
# into `adata_merfish_raw`; that object is kept un-normalized, and a separate
# working copy is made for preprocessing.
seq_h5ad = os.path.join(workspace_path, 'adata_seq_common_genes.h5ad')
merfish_h5ad = os.path.join(workspace_path, 'adata_merfish.h5ad')

adata_seq = sc.read_h5ad(seq_h5ad)
adata_merfish_raw = sc.read_h5ad(merfish_h5ad)
adata_merfish = adata_merfish_raw.copy()

In [None]:
# Normalize and scale the data
# Each modality goes through the same in-place pipeline: total-count
# normalization to 1000, log1p, then per-gene scaling.
# Note: scaling before merging resulted in better co-embedding
for adata_modality in (adata_seq, adata_merfish):
    sc.pp.normalize_total(adata_modality, target_sum=1000)
    sc.pp.log1p(adata_modality)
    sc.pp.scale(adata_modality)

In [None]:
# Merge the datasets
# Stack both modalities into one AnnData. The `modality` obs column records
# each cell's source ('seq' or 'merfish'). index_unique=None keeps the original
# cell names (no batch suffix), so per-modality results can later be aligned
# back to the separate objects by index.
modalities = ['seq', 'merfish']
adata_merge = adata_seq.concatenate(
    adata_merfish,
    batch_key='modality',
    batch_categories=modalities,
    index_unique=None,
)

In [None]:
%%time
# Get the significant PCs
n_pcs = 100
sc.tl.pca(adata_merge, svd_solver='arpack', n_comps=100)

In [None]:
# Split the merged object back into per-modality views, in concatenation order.
# Index 0 is the seq reference and index 1 is the MERFISH query; the
# integration calls below refer to the datasets by these positions.
modality_labels = adata_merge.obs['modality']
adata_list = [adata_merge[modality_labels == name] for name in ('seq', 'merfish')]

# Integration

In [None]:
%%time
# Find the integration anchors
integrator = SeuratIntegration()
integrator.find_anchor(adata_list,
                       k_local=None,
                       key_local='X_pca',
                       k_anchor=5,
                       key_anchor='X',
                       dim_red='cca',
                       max_cc_cells=100000,
                       k_score=30,
                       k_filter=None, #why?
                       scale1=False,
                       scale2=False,
                       n_components=n_pcs,
                       n_features=200,
                       alignments=[[[0], [1]]])

# Label transfer

In [None]:
%%time
cell_type_col = 'integration_partition'

transfer_results = integrator.label_transfer(
    ref=[0],
    qry=[1],
    categorical_key=[cell_type_col],
    key_dist='X_pca',
    kweight=100,
    npc=n_pcs
)

integrator.save_transfer_results_to_adata(adata_merge, transfer_results)

In [None]:
# Assign the transferred labels and the confidence
# `transfer_results[cell_type_col]` has one row per query (MERFISH) cell and
# one column per reference label. The best-scoring column is the transferred
# label, and its score is the confidence.
label_scores = transfer_results[cell_type_col]
transfer_key = cell_type_col + '_transfer'
confidence_key = cell_type_col + '_confidence'

adata_merfish_raw.obs[transfer_key] = label_scores.idxmax(axis=1).astype('category')
adata_merfish_raw.obs[confidence_key] = label_scores.max(axis=1)

# The assignment aligns on cell names. Any MERFISH cell missing from the
# transfer results would silently receive NaN, so verify full coverage.
assert adata_merfish_raw.obs[transfer_key].notna().all(), \
    'some MERFISH cells did not receive a transferred label'

# nunique() skips missing values. np.unique would raise on a label column
# mixing strings and NaN (e.g. reference cells without a label).
n_transfered = adata_merfish_raw.obs[transfer_key].nunique()
n_total = adata_merge.obs[transfer_key].nunique()
print(f'Transfered {n_transfered}/{n_total} cell types.')

In [None]:
# Save the label transfer results for each class
# Write the MERFISH cells of each transferred partition to
# <workspace>/partitions/<partition>/adata_merfish_integrated.h5ad.
# In the directory name, '/' becomes '-' and ' ' becomes '_', so the name is a
# single path component.
partition_path = os.path.join(workspace_path, 'partitions')
transfer_key = cell_type_col + '_transfer'
partitions = np.unique(adata_merfish_raw.obs[transfer_key])

for pn in partitions:
    partition_dir = os.path.join(partition_path, str(pn).replace('/', '-').replace(' ', '_'))
    # Create the directory if needed; otherwise the write fails for a
    # partition that has no directory yet.
    os.makedirs(partition_dir, exist_ok=True)
    # .copy() turns the boolean-mask view into a standalone AnnData before writing.
    adata_subset = adata_merfish_raw[adata_merfish_raw.obs[transfer_key] == pn].copy()
    adata_subset.write_h5ad(os.path.join(partition_dir, 'adata_merfish_integrated.h5ad'),
                            compression='gzip')

In [None]:
# Distribution of the label-transfer confidence (best label score) per MERFISH cell.
# The column name is derived from `cell_type_col` so it stays in sync with the
# cell that creates it.
confidence_key = cell_type_col + '_confidence'
fig, ax = plt.subplots()
ax.hist(adata_merfish_raw.obs[confidence_key], bins=30)
ax.set(title='partition_confidence', xlabel=confidence_key, ylabel='Number of cells')
plt.show()

# Coembedding

In [None]:
%%time
# Correct the PCs using the integration anchors
corrected = integrator.integrate(key_correct='X_pca',
                                 row_normalize=True,
                                 n_components=n_pcs,
                                 k_weight=100,
                                 sd=1,
                                 alignments=[[[0], [1]]])

adata_merge.obsm['X_pca_integrate'] = np.concatenate(corrected)

In [None]:
# Calculate KNN using the integrated PCs
# The graph is built on the anchor-corrected PCs, not the uncorrected X_pca,
# so it reflects the integrated co-embedding of both modalities.
integrated_rep = 'X_pca_integrate'
sc.pp.neighbors(adata_merge, use_rep=integrated_rep)

In [None]:
%%time
# Generate the PAGA plot for the initial arrangement of the UMAP
sc.tl.paga(adata_merge, groups=cell_type_col + '_transfer')
sc.pl.paga(adata_merge, save='_tmp.png')
shutil.move('figures/paga_tmp.png', os.path.join(workspace_path, 'integration_paga_round1.png'))

In [None]:
%%time
# Save the umap
sc.tl.umap(adata_merge, init_pos='paga', min_dist=0.5)
sc.pl.umap(adata_merge, color='modality', save='_tmp.png')
shutil.move('figures/umap_tmp.png', os.path.join(workspace_path, 'integration_umap_round1_modality.png'))
sc.pl.umap(adata_merge, color=cell_type_col + '_transfer', save='_tmp.png', palette='gist_ncar')
shutil.move('figures/umap_tmp.png', os.path.join(workspace_path, 'integration_umap_round1_partitions.png'))

In [None]:
# Save the merged adata
# The saved object includes the integrated PCs, neighbour graph, PAGA and UMAP
# computed above.
merged_h5ad_path = os.path.join(workspace_path, 'adata_merged_round1.h5ad')
adata_merge.write_h5ad(merged_h5ad_path, compression='gzip')

In [None]:
# Export the first two UMAP coordinates of every merged cell, indexed by cell name.
umap_coords = adata_merge.obsm['X_umap']
coembedding_umap_df = pd.DataFrame(
    {'umap_x': umap_coords[:, 0], 'umap_y': umap_coords[:, 1]},
    index=adata_merge.obs.index.copy(),
)
coembedding_umap_df.to_csv(os.path.join(workspace_path, 'coembedding_umap.csv'))