In [1]:
%load_ext autoreload
%autoreload 2

## Basic setup

In [None]:
# --- Libraries ---------------------------------------------------------------
import Concord as ccd
import scanpy as sc
import torch
import warnings
from pathlib import Path

# Suppress all warnings from here on (installed before matplotlib is imported).
warnings.filterwarnings('ignore')

import matplotlib as mpl

# Font settings for exported vector figures: keep SVG text as text and
# use fonttype 42 for PDF output.
mpl.rcParams.update({
    'svg.fonttype': 'none',
    'pdf.fonttype': 42,
})

# --- Input data --------------------------------------------------------------
data_dir = Path('../data/celegans_data/')
# Path kept for reference; it is not the file that is read below.
data_path = data_dir / "celegans_global_adata.h5ad"

# Inputs used previously (left here as a record):
#   sc.read(data_path)
#   sc.read(data_dir/'celegans_Dec17-2043.h5ad')
adata = sc.read(data_dir / 'adata_pknn_Jan23-1836.h5ad')

In [None]:
# Display the names of all entries currently stored in adata.obsm.
adata.obsm.keys()

In [None]:
# Keep the obsm keys that belong to the sknn300 / clr0.5 / aug0.3 configuration,
# leaving out anything whose name contains 'UMAP'.
concord_keys = list(
    filter(
        lambda name: 'sknn300_clr0.5_aug0.3' in name and 'UMAP' not in name,
        adata.obsm.keys(),
    )
)
concord_keys

In [None]:
import time
from pathlib import Path

proj_name = "concord_celegans"

# Output folder, stamped with month and day; created if it does not exist yet.
save_dir = Path(f"../save/dev_{proj_name}-{time.strftime('%b%d')}/")
save_dir.mkdir(parents=True, exist_ok=True)

# Choose the compute device by availability instead of hard-coding 'mps',
# so the notebook also runs on machines without Apple's Metal backend.
# MPS is tried first, so machines that have it behave as before.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Timestamp appended to the files written by later cells.
file_suffix = f"{time.strftime('%b%d-%H%M')}"
seed = 0

In [None]:
# Keep a copy of the matrix as loaded (presumably raw counts -- confirm) before
# it is normalised and log-transformed in place.
# NOTE: this cell is not idempotent; re-running it normalises X a second time.
adata.layers["counts"] = adata.X.copy()
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)

# Keep only cells whose 'to.filter' flag is the string 'FALSE'.
# Slicing an AnnData yields a view of the parent object; later cells assign to
# adata.obsm, so materialise an independent object here rather than relying on
# an implicit copy-on-write.
adata = adata[adata.obs['to.filter'] == 'FALSE'].copy()

## Run Concord

In [None]:
# Features passed to Concord as its input feature set (top 10000, 'seurat_v3' flavor).
feature_list = ccd.ul.select_features(
    adata,
    n_top_features=10000,
    flavor='seurat_v3',
)

# Keyword arguments shared by the Concord run(s) below.
concord_args = dict(
    adata=adata,
    input_feature=feature_list,
    batch_size=64,
    latent_dim=300,
    encoder_dims=[1000],
    decoder_dims=[1000],
    augmentation_mask_prob=0.3,
    clr_temperature=0.5,
    p_intra_knn=0.3,
    sampler_knn=300,
    min_p_intra_domain=0.95,
    n_epochs=10,
    domain_key='batch',
    verbose=False,
    inplace=False,
    seed=seed,
    device=device,
    save_dir=save_dir,
)

### Concord, no decoder

In [None]:
# Key for the resulting embedding; it spells out the main hyper-parameters.
output_key = (
    f'Concord_pknn{concord_args["p_intra_knn"]}'
    f'_sknn{concord_args["sampler_knn"]}'
    f'_clr{concord_args["clr_temperature"]}'
    f'_aug{concord_args["augmentation_mask_prob"]}'
    f'_bs{concord_args["batch_size"]}'
    f'_latent{concord_args["latent_dim"]}'
)

# Concord object configured without a decoder.
cur_ccd = ccd.Concord(use_decoder=False, **concord_args)

In [None]:
# Encode the data; the result is requested under `output_key`.
cur_ccd.encode_adata(
    input_layer_key='X_log1p',
    output_key=output_key,
)

# Save the latent embedding to a file so that it can be loaded later.
ccd.ul.save_obsm_to_hdf5(
    cur_ccd.adata,
    save_dir / f"obsm_{file_suffix}.h5",
)

# Concord was created with inplace=False, so attach its obsm to our adata by hand.
adata.obsm = cur_ccd.adata.obsm

In [None]:
# 2-D UMAP of the freshly computed embedding, stored under '<basis>_UMAP'.
basis = output_key
ccd.ul.run_umap(
    adata,
    source_key=basis,
    result_key=f'{basis}_UMAP',
    n_components=2,
    n_neighbors=30,
    min_dist=0.1,
    metric='euclidean',
    random_state=seed,
)

In [None]:

# Draw the UMAP coloured by three obs columns and save it as a PDF.
show_basis = basis + '_UMAP'
show_cols = ['cell.type', 'plot.cell.type',  'raw.embryo.time']
pal = {
    'cell.type': 'tab20',
    'plot.cell.type': 'tab20',
    'raw.embryo.time': 'BlueGreenRed',
}
ccd.pl.plot_embedding(
    adata,
    show_basis,
    show_cols,
    figsize=(11, 4),
    dpi=600,
    ncols=3,
    font_size=3,
    point_size=1,
    legend_loc='on data',
    pal=pal,
    save_path=save_dir / f"{show_basis}_{file_suffix}.pdf",
)

In [None]:
# Initialise the Concord data loader with the sampler enabled.
cur_ccd.init_dataloader(
    input_layer_key="X_log1p",
    use_sampler=True,
)

In [None]:
import numpy as np

# Loader to inspect and the layout of the items it yields.
dataloader = cur_ccd.loader[0][0]
data_structure = cur_ccd.data_structure

# Batches to capture (the first ten) and the per-batch field to pull out.
batch_indices = np.arange(10)
attribute = 'idx'

hl_indices = []
found_indices = set()

for batch_idx, batch in enumerate(dataloader):
    if batch_idx not in batch_indices:
        continue
    # Pick the requested field out of the batch and move it to the CPU.
    hl_indices.append(batch[data_structure.index(attribute)].cpu())
    found_indices.add(batch_idx)
    # Stop iterating once every requested batch has been seen.
    if len(found_indices) == len(batch_indices):
        break

hl_indices

In [None]:
# Show where the cells of one captured batch fall on the UMAP.
batch_idx = 3  # position within hl_indices of the batch to highlight
show_basis = f'{basis}_UMAP'
show_cols = ['cell.type', 'plot.cell.type',  'raw.embryo.time']
pal = {
    'cell.type': 'tab20',
    'plot.cell.type': 'tab20',
    'raw.embryo.time': 'BlueGreenRed',
}
ccd.pl.plot_embedding(
    adata,
    show_basis,
    show_cols,
    figsize=(11, 4),
    dpi=600,
    ncols=3,
    font_size=3,
    point_size=1,
    legend_loc=None,
    highlight_indices=hl_indices[batch_idx].cpu().numpy(),
    highlight_color='black',
    highlight_size=2,
    pal=pal,
    save_path=save_dir / f"{show_basis}_{file_suffix}_batchhl_{batch_idx}.pdf",
)

In [None]:
# Overview figure: every Concord latent key in `concord_keys`, drawn from `adata`.
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib import font_manager, rcParams

# Fresh timestamp for the files written by this cell.
file_suffix = f"{time.strftime('%b%d-%H%M')}"

# rc overrides (Arial font) that apply only inside the rc_context below.
custom_rc = {'font.family': 'Arial'}

# Columns used for colouring, and the palette assigned to each.
show_cols = ['raw.embryo.time', 'cell.type', 'plot.cell.type']
pal = {
    'cell.type': 'Set1',
    'plot.cell.type': 'tab20',
    'raw.embryo.time': 'BlueGreenRed',
}

# Only UMAP is drawn here; the full list used before was
# ['', 'PAGA', 'KNN', 'PCA', 'UMAP'].
basis_types = ['UMAP']

# Text / marker appearance and figure geometry.
font_size = 5
point_size = .5
alpha = 0.8
figsize = (10, 1.6)
ncols = len(concord_keys)
nrows = int(np.ceil(len(concord_keys) / ncols))

# Graph-related options forwarded to plot_all_embeddings.
k = 15
edges_color = 'grey'
edges_width = 0.05
layout = 'kk'
threshold = 0.1
node_size_scale = 0.1
edge_width_scale = 0.1

with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_all_embeddings(
        adata, concord_keys,
        color_bys=show_cols, basis_types=basis_types, pal=pal,
        k=k, edges_color=edges_color, edges_width=edges_width,
        layout=layout, threshold=threshold,
        node_size_scale=node_size_scale, edge_width_scale=edge_width_scale,
        font_size=font_size, point_size=point_size, alpha=alpha,
        figsize=figsize, ncols=ncols,
        seed=1, leiden_key='leiden', legend_loc=None,
        save_dir=save_dir, file_suffix=file_suffix+'_subset', save_format='pdf',
    )

### Subgroup analysis

In [None]:
# Restrict the data to the two ciliated-neuron cell types.
subset_name = 'Neuron'
selected_lins = ['Ciliated_non_amphid_neuron', 'Ciliated_amphid_neuron']

# Slicing an AnnData yields a view of the parent; later cells write UMAP
# results into adata_subset.obsm, so take an explicit copy instead of relying
# on an implicit copy-on-write.
adata_subset = adata[adata.obs['cell.type'].isin(selected_lins)].copy()
print(adata_subset.shape)

In [None]:
# Plot expression pattern of the top VEGs (features returned by select_features) in the subset
n_veg = 3000
top_vegs = ccd.ul.select_features(adata_subset, n_top_features=n_veg, flavor='seurat_v3')

# Downsample adata_subset for plotting.
# The draw is seeded so that the figure is reproducible, and the sample size is
# capped at the number of available cells (sampling without replacement would
# otherwise raise when the subset has fewer than ds_cellnum cells).
ds_cellnum = 1000
rng = np.random.default_rng(seed)
n_keep = min(ds_cellnum, adata_subset.n_obs)
keep_pos = rng.choice(adata_subset.n_obs, size=n_keep, replace=False)
adata_subset_ds = adata_subset[keep_pos]

ccd.pl.heatmap_with_annotations(adata_subset_ds[:, top_vegs], val = 'X', obs_keys=['cell.type', 'plot.cell.type', 'embryo.time'],
                                pal=pal,
                                yticklabels=False, figsize=(5, 5),
                                save_path=save_dir / f"heatmap_{subset_name}_veg{n_veg}_{file_suffix}.pdf")

In [None]:
# Plot expression pattern of the top VEGs (features returned by select_features) across all cells
n_veg = 10000
top_vegs = ccd.ul.select_features(adata, n_top_features=n_veg, flavor='seurat_v3')

# Downsample adata for plotting.
# The draw is seeded so that the figure is reproducible, and the sample size is
# capped at the number of available cells (sampling without replacement would
# otherwise raise when adata has fewer than ds_cellnum cells).
ds_cellnum = 3000
rng = np.random.default_rng(seed)
n_keep = min(ds_cellnum, adata.n_obs)
keep_pos = rng.choice(adata.n_obs, size=n_keep, replace=False)
adata_ds = adata[keep_pos]

ccd.pl.heatmap_with_annotations(adata_ds[:, top_vegs], val = 'X', obs_keys=['cell.type', 'plot.cell.type', 'embryo.time'],
                                pal=pal,
                                yticklabels=False, figsize=(5, 5),
                                save_path=save_dir / f"heatmap_global_veg{n_veg}_{file_suffix}.pdf")

In [None]:
# Run a subset-specific UMAP for every Concord latent embedding.
for basis in concord_keys:
    # Skip keys that are absent from the subset or that are UMAP projections themselves.
    if basis not in adata_subset.obsm or 'UMAP' in basis:
        continue
    # Announce the run only after the checks, so skipped keys are not reported as running.
    print("Running UMAP for", basis)
    ccd.ul.run_umap(
        adata_subset,
        source_key=basis,
        result_key=f'{basis}_UMAP_{subset_name}',
        n_components=2,
        n_neighbors=30,
        min_dist=0.1,
        metric='euclidean',
        random_state=seed,
    )


In [None]:
# Same overview figure as for the full data, this time drawn from `adata_subset`
# and coloured by embryo time only.
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib import font_manager, rcParams

# Fresh timestamp for the files written from here on.
file_suffix = f"{time.strftime('%b%d-%H%M')}"

# Arial is applied only within the rc_context below.
custom_rc = {'font.family': 'Arial'}

show_cols = ['raw.embryo.time']
pal = {
    'cell.type': 'tab20',
    'plot.cell.type': 'tab20',
    'raw.embryo.time': 'BlueGreenRed',
}
# Restricted to UMAP; the full list used before was ['', 'PAGA', 'KNN', 'PCA', 'UMAP'].
basis_types = ['UMAP']

# Appearance and geometry.
font_size, point_size, alpha = 5, 0.5, 0.8
figsize = (10, 1.6)
ncols = len(concord_keys)
nrows = int(np.ceil(len(concord_keys) / ncols))

# Graph-related options forwarded to plot_all_embeddings.
k = 15
edges_color, edges_width = 'grey', 0.05
layout = 'kk'
threshold = 0.1
node_size_scale, edge_width_scale = 0.1, 0.1

with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_all_embeddings(
        adata_subset,
        concord_keys,
        color_bys=show_cols,
        basis_types=basis_types,
        pal=pal,
        k=k,
        edges_color=edges_color,
        edges_width=edges_width,
        layout=layout,
        threshold=threshold,
        node_size_scale=node_size_scale,
        edge_width_scale=edge_width_scale,
        font_size=font_size,
        point_size=point_size,
        alpha=alpha,
        figsize=figsize,
        ncols=ncols,
        seed=1,
        legend_loc='on data',
        leiden_key='leiden',
        save_dir=save_dir,
        file_suffix=file_suffix,
        save_format='pdf',
    )


In [None]:
# Run a subset-specific UMAP for every Concord latent embedding.
# The call is seeded (random_state=seed), so a result that is already stored
# under the target key is reused instead of being recomputed.
for basis in concord_keys:
    result_key = f'{basis}_UMAP_{subset_name}'
    if basis not in adata_subset.obsm:
        # Nothing to project for this key.
        continue
    if result_key in adata_subset.obsm:
        print("UMAP already present for", basis)
        continue
    # Announce the run only after the checks, so skipped keys are not reported as running.
    print("Running UMAP for", basis)
    ccd.ul.run_umap(
        adata_subset,
        source_key=basis,
        result_key=result_key,
        n_components=2,
        n_neighbors=30,
        min_dist=0.1,
        metric='euclidean',
        random_state=seed,
    )


In [None]:
# Draw the subset-specific UMAPs (stored under the 'UMAP_Neuron' suffix) with larger points.
# All other options are taken from the variables already defined in the session.
basis_types = ['UMAP_Neuron']
point_size = 1
figsize = (10, 1.6)

with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_all_embeddings(
        adata_subset, concord_keys,
        color_bys=show_cols, basis_types=basis_types, pal=pal,
        k=k, edges_color=edges_color, edges_width=edges_width,
        layout=layout, threshold=threshold,
        node_size_scale=node_size_scale, edge_width_scale=edge_width_scale,
        font_size=font_size, point_size=point_size, alpha=alpha,
        figsize=figsize, ncols=ncols,
        seed=1, legend_loc='on data', leiden_key='leiden',
        save_dir=save_dir, file_suffix=file_suffix, save_format='pdf',
    )

In [None]:
from io import BytesIO

import matplotlib.image as mpimg

ncols = len(concord_keys)

# Size of each individual clustermap. Kept under its own name so that the
# session-wide `figsize` is not overwritten and re-running this cell produces
# the same outer figure every time.
heatmap_figsize = (10, 10)

# One panel per latent key. squeeze=False keeps `axes` a 2-D array even when
# there is a single key, so the panels can always be iterated.
# NOTE(review): the outer figure height comes from the existing `figsize`
# (set by an earlier cell) -- confirm that this is the intended size.
fig, axes = plt.subplots(
    1, ncols,
    figsize=(figsize[0] * ncols, figsize[1]),
    squeeze=False,
)

for ax, key in zip(axes.ravel(), concord_keys):
    g = ccd.pl.heatmap_with_annotations(
        adata_subset,
        key,
        obs_keys=show_cols,
        cmap='viridis',
        cluster_rows=True,
        cluster_cols=True,
        value_annot=False,
        vmax = 5,
        figsize=heatmap_figsize,
        show=False
    )

    # Render the clustermap to an in-memory PNG and read it back as an image;
    # the buffer is released as soon as the image has been decoded.
    with BytesIO() as buf:
        g.figure.savefig(buf, format='png', dpi=600)
        buf.seek(0)
        img = mpimg.imread(buf)

    # Close the clustermap figure to free memory.
    plt.close(g.figure)

    # Show the rendered image in its panel.
    ax.imshow(img)
    ax.axis('off')
    ax.set_title(f'{key}')

# Save through the figure object itself rather than pyplot's "current figure",
# which is ambiguous after the clustermap figures were created and closed above.
fig.savefig(save_dir/f'heatmap_all_{file_suffix}_subset.pdf', dpi=600, bbox_inches='tight')


In [None]:
# Set up the Concord data loader again, with the sampler enabled.
cur_ccd.init_dataloader(
    input_layer_key="X_log1p",
    use_sampler=True,
)

In [None]:
# Display the current value of `data_structure`.
# NOTE(review): this name is only (re)assigned from cur_ccd.data_structure in the
# next cell, so what is shown here is whatever an earlier cell bound to it.
data_structure

In [None]:
import numpy as np


def collect_batch_attribute(loader, structure, attribute, wanted_batches):
    """Collect one field from selected batches of a data loader.

    Parameters
    ----------
    loader : iterable
        Yields batches; each batch is indexed by position.
    structure : sequence
        Names of the fields in a batch, in order; ``structure.index(attribute)``
        gives the position of the requested field.
    attribute : str
        Name of the field to extract.
    wanted_batches : iterable of int
        Zero-based positions of the batches to collect.

    Returns
    -------
    list
        The requested field of each wanted batch, in iteration order, after
        calling ``.cpu()`` on it. Batches beyond the end of the loader are
        silently missing from the result.

    Raises
    ------
    ValueError
        If ``attribute`` is not in ``structure`` (only when at least one batch
        is requested).
    """
    wanted = {int(i) for i in wanted_batches}
    if not wanted:
        return []

    position = structure.index(attribute)
    last_wanted = max(wanted)
    collected = []
    for batch_idx, batch in enumerate(loader):
        if batch_idx in wanted:
            collected.append(batch[position].cpu())
        # Nothing of interest lies beyond the largest requested position.
        if batch_idx >= last_wanted:
            break
    return collected


dataloader = cur_ccd.loader[0][0]
data_structure = cur_ccd.data_structure
attribute = 'idx'
# Capture the first ten batches (positions 0-9).
batch_indices = np.arange(10)

hl_indices = collect_batch_attribute(dataloader, data_structure, attribute, batch_indices)
hl_indices

In [None]:
# Highlight the cells of the first captured batch on the UMAP of one chosen embedding.
# NOTE(review): the embedding is picked by a hard-coded position, so this needs
# at least five entries in concord_keys -- confirm the intended key.
basis =  concord_keys[4]
show_basis = f'{basis}_UMAP'
show_cols = ['cell.type', 'plot.cell.type',  'raw.embryo.time']
pal = {
    'cell.type': 'tab20',
    'plot.cell.type': 'tab20',
    'raw.embryo.time': 'BlueGreenRed',
}
ccd.pl.plot_embedding(
    adata,
    show_basis,
    show_cols,
    figsize=(11, 4),
    dpi=600,
    ncols=3,
    font_size=3,
    point_size=1,
    legend_loc=None,
    highlight_indices=hl_indices[0].cpu().numpy(),
    highlight_color='black',
    highlight_size=8,
    pal=pal,
    save_path=save_dir / f"{show_basis}_{file_suffix}.pdf",
)