In [None]:
import squidpy as sq
import cellcharter as cc
import scanpy as sc
import anndata as ad
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
from lightning.pytorch import seed_everything
# Global reproducibility / runtime configuration.
# seed_everything seeds Python's `random`, NumPy and PyTorch RNGs (via Lightning),
# for reproducibility of the trVAE embedding and GMM fitting below.
seed_everything(42)
# Let scanpy use all available cores. Set it through the public settings object;
# assigning to the private `ScanpyConfig` class overwrites its `n_jobs` property.
sc.settings.n_jobs = -1
plt.style.use('ggplot')

In [None]:
# Reload a previously saved ClusterAutoK run (dmr + scaled input, 3 neighbor layers,
# judging by the directory name) and draw its stability plot at a wider size.
# NOTE(review): this run name is hard-coded and must be kept in sync by hand with the
# `dmr` / `scale` / `layers` flags of the configuration cell, which runs after this one.
autok_dir = "/Users/lukashat/Documents/PhD_Schapiro/Projects/Myeloma_Standal/results/downstream/neighborhoods/cellcharter/autok_dmr_scale_layer3"
autok = cc.tl.ClusterAutoK.load(autok_dir)

ax = cc.pl.autok_stability(autok, return_ax=True)
ax.figure.set_size_inches(12, 6)

In [None]:
# Run configuration: these flags select the preprocessing variant and determine
# where this run's results are written.
layers = 3    # neighborhood layers passed to cc.gr.aggregate_neighbors(n_layers=...)
scale = True  # apply sc.pp.scale to the arcsinh matrix before modelling
dmr = True    # embed cells with the pretrained trVAE and aggregate over X_trvae
n = 9         # number of clusters for cc.tl.Cluster

# Encode every flag in the run name so different variants never share an output
# directory (previously both non-dmr variants were named "..._noscale_...").
dmr_tag = 'dmr' if dmr else 'nodmr'
scale_tag = 'scale' if scale else 'noscale'
save = f'{dmr_tag}_{scale_tag}_layer{layers}'

# trVAE checkpoint directory name; only loaded when dmr is True. Reset to None
# otherwise so a stale value from an earlier run cannot leak through.
if dmr:
    trvae_model = 'trvae_scaled' if scale else 'trvae_nonscaled'
else:
    trvae_model = None

results_root = '/Users/lukashat/Documents/PhD_Schapiro/Projects/Myeloma_Standal/results/downstream/neighborhoods/cellcharter'
save_path = f'{results_root}/{save}/k{n}/'
# exist_ok avoids the check-then-create race and is a no-op if the directory exists.
os.makedirs(save_path, exist_ok=True)



In [None]:
# Load the annotated cells, build the modelling matrix, optionally embed it with the
# pretrained trVAE, then build per-image spatial neighbor graphs and aggregate
# features over `layers` neighborhood layers.
adata = ad.read_h5ad("/Users/lukashat/Documents/PhD_Schapiro/Projects/Myeloma_Standal/results/standard/adatas/cells_annotated_pp_osteocytes_cleaned.h5ad")

# Work on an explicit float32 copy of the arcsinh layer so that scaling X below
# cannot modify layers['arcsinh'].
adata.X = adata.layers['arcsinh'].astype(np.float32, copy=True)
if scale:
    # Scale each variable to zero mean / unit variance (modifies adata in place).
    sc.pp.scale(adata)
    # Re-cast to float32 after scaling. NOTE(review): presumably guards against
    # sc.pp.scale returning a wider dtype — confirm for the scanpy version in use.
    adata.X = adata.X.astype(np.float32, copy=True)

condition_key = 'patient_ID'   # condition/batch covariate passed to trVAE
cell_type_key = 'Phenotype3'   # not referenced in the cells shown here
conditions = adata.obs[condition_key].unique().tolist()

if dmr:
    # Load the trVAE checkpoint chosen in the config cell and map cells into its
    # latent space, conditioning on `condition_key`.
    model = cc.tl.TRVAE.load(
        f'/Users/lukashat/Documents/PhD_Schapiro/Projects/Myeloma_Standal/github/myeloma_standal/src/downstream/advanced_neighborhood/cellcharter/{trvae_model}',
        adata,
        map_location='cpu',
    )
    adata.obsm['X_trvae'] = model.get_latent(adata.X, adata.obs[condition_key])

# Cell centroids as an (n_cells, 2) coordinate array for squidpy.
adata.obsm['spatial'] = adata.obs[['X_centroid', 'Y_centroid']].to_numpy()

# One empty uns['spatial'] entry per image. NOTE(review): presumably needed so that
# each image_ID is treated as its own spatial library — confirm.
unique_image_ids = adata.obs['image_ID'].unique()
adata.uns['spatial'] = {image_id: {} for image_id in unique_image_ids}

# Delaunay neighbor graph built per image (library_key), then drop links whose
# length is above the 95th distance percentile.
sq.gr.spatial_neighbors(adata, library_key='image_ID', coord_type='generic', delaunay=True)
cc.gr.remove_long_links(adata, distance_percentile=95)

# Aggregate over `layers` neighborhood layers; with dmr the trVAE latent is the input
# representation, otherwise cellcharter's default representation is used.
if dmr:
    cc.gr.aggregate_neighbors(adata, n_layers=layers, use_rep='X_trvae')
else:
    cc.gr.aggregate_neighbors(adata, n_layers=layers)

In [None]:
# Fit CellCharter's `Cluster` model with `n` clusters on the neighbor-aggregated
# representation, store per-cell labels in obs and write the AnnData to save_path.
# NOTE(review): 'X_cellcharter' is presumably the obsm key written by
# cc.gr.aggregate_neighbors in the previous cell — confirm.
gmm = cc.tl.Cluster(
    n_clusters=n,
    random_state=42,
    trainer_params={
        'accelerator': 'cpu',
        'devices': 'auto',
        'enable_progress_bar': True,
    },
)
gmm.fit(adata, use_rep='X_cellcharter')
adata.obs[f'cellcharterCN{n}'] = gmm.predict(adata, use_rep='X_cellcharter')

# save_path already ends with '/', so join it rather than appending another '/'.
adata.write_h5ad(os.path.join(save_path, f'adata_cellcharterCN{n}.h5ad'))