In [1]:
import scanpy as sc
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import os

In [2]:
def subsample_anndata(adata, label_col="y", n_samples=3390, random_state=42):
    """Draw a class-balanced random subsample of cells.

    For every distinct value of ``adata.obs[label_col]`` keep
    ``min(n_samples, group size)`` cells, sampled without replacement.
    Rows whose label is missing are dropped (pandas ``groupby`` default).

    Cells are selected by *position*, not by obs name, so the function also
    works when ``adata.obs_names`` contains duplicates (the prediction file
    loaded in this notebook triggers anndata's duplicate-obs-names warning).

    Parameters
    ----------
    adata : AnnData
        Object to subsample. Only ``adata.obs``, positional indexing
        (``adata[int_array]``) and ``.copy()`` are used.
    label_col : str
        Column of ``adata.obs`` that defines the groups.
    n_samples : int
        Maximum number of cells kept per group.
    random_state : int
        Seed passed to ``Series.sample`` for every group (42 as before).

    Returns
    -------
    (AnnData, pandas.Index)
        The subsampled copy, and the obs names of the selected cells in the
        same order as the rows of that copy.
    """
    # Group integer row positions (not names) so duplicated obs_names are harmless.
    labels = adata.obs[label_col].reset_index(drop=True)
    positions = pd.Series(np.arange(len(labels)))
    # observed=True skips unused categories of a categorical label column.
    # Series.sample draws the same positions as DataFrame.sample for a group of
    # the same length and seed, so the selection matches the previous version.
    parts = [
        group.sample(min(n_samples, len(group)), random_state=random_state).to_numpy()
        for _, group in positions.groupby(labels, observed=True)
    ]
    # No groups (empty obs or all labels missing) -> empty selection.
    chosen = np.concatenate(parts) if parts else np.array([], dtype=int)

    return adata[chosen].copy(), adata.obs.index[chosen]

In [3]:
# Output folder for figures; the trailing slash matters because file names are
# appended directly (f'{save_dir}PCA.png').
save_dir = '/data/scDisentangle figures/Tabula/'

# Read original data

In [4]:
# .obs column names and condition labels used throughout the notebook.
cov_key = 'cell_type'    # cell-type annotation column
cond_key = 'specie'      # species / condition column
stim_name = 'sapiens'    # condition the predictions target
control_name = 'muris'   # the other condition

adata = sc.read_h5ad('../../Datasets/preprocessed_datasets/tabula.h5ad')

# Read predictions for the OOD cell type (pulmonary alveolar type 2 cell)

In [5]:
# Accumulator for the (commented-out) per-donor concatenation loop below;
# it is reassigned from `pred` a few cells later.
adata_cat = None

In [6]:
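# NOTE: legacy per-donor loop kept for reference from an earlier (myocarditis)
# analysis. It is not used for the Tabula data, which loads a single prediction
# file in the next cell.
# for donor in covs:
#     pred = sc.read_h5ad(f'SCDISENTANGLE/myocarditis/pred_adata/{donor}_1.h5ad')
#     pred = pred[(pred.obs['tissue_org'] == 'Blood') & (pred.obs['donor'] == donor)]
#     if adata_cat is None:
#         adata_cat = pred.copy()
#     else:
#         adata_cat = adata_cat.concatenate(pred)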

In [7]:
pred = sc.read_h5ad(
    '../../Benchmarks/SCDISENTANGLE/Tabula/predictions/tabula/pulmonary alveolar type 2 cell_5.h5ad'
)
# The prediction file contains duplicated obs names (anndata emits
# warn_names_duplicates("obs")). Make them unique so later name-based indexing
# and concatenation refer to exactly one cell per name.
pred.obs_names_make_unique()

  utils.warn_names_duplicates("obs")


In [8]:
# Independent copy: the in-place normalize_total/log1p applied to adata_cat
# further down must leave `pred` untouched.
adata_cat = pred.copy()

  utils.warn_names_duplicates("obs")


# Normalize original data and predictions

In [9]:
# scanpy reported "adata.X seems to be already log-transformed" for this file
# (it checks for the 'log1p' entry in adata.uns). Normalizing and applying log1p
# again would double-transform the data, so only preprocess when that marker is
# absent.
if 'log1p' in adata.uns:
    print("adata is already log-transformed; skipping normalize_total/log1p")
else:
    sc.pp.normalize_total(adata)
    sc.pp.log1p(adata)

adata.X seems to be already log-transformed.


In [10]:
# Library-size normalization followed by log1p, applied in place to the predictions.
# NOTE(review): unlike `adata`, no "already log-transformed" warning was shown
# here. Confirm that real and predicted cells end up on the same scale before
# comparing them.
for preprocess_step in (sc.pp.normalize_total, sc.pp.log1p):
    preprocess_step(adata_cat)

In [13]:
import pandas as pd

In [14]:
# Number of cells per (condition, cell type), using the configured .obs keys.
pd.crosstab(adata.obs[cond_key], adata.obs[cov_key])

cell_type,club cell,B cell,natural killer cell,"CD4-positive, alpha-beta T cell","CD8-positive, alpha-beta T cell",pericyte,basophil,neutrophil,myeloid dendritic cell,plasmacytoid dendritic cell,plasma cell,mature NK T cell,classical monocyte,non-classical monocyte,pulmonary alveolar type 2 cell,endothelial cell of lymphatic vessel,intermediate monocyte,adventitial cell,vein endothelial cell,bronchial smooth muscle cell
specie,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
muris,15,1501,1193,551,870,61,130,552,245,75,49,420,7922,1010,125,41,1749,526,320,2339
sapiens,1747,663,1019,2124,1898,739,1322,371,33,19,148,163,1569,527,11594,315,2785,581,1336,220


In [None]:
# Cell-type composition of the prediction object (counts rather than the full
# per-cell Series).
adata_cat.obs[cov_key].value_counts()

In [None]:
# Restrict the real data to the cell type covered by the prediction file.
# .copy() materialises the subset, so the obs column added later ('y') is
# written to an independent AnnData instead of an implicitly-copied view.
adata = adata[adata.obs[cov_key] == 'pulmonary alveolar type 2 cell'].copy()

In [None]:
# Summary of the filtered real data.
adata

# Predicted sapiens pulmonary alveolar type 2 cells

In [None]:
# Keep only the cells predicted as the target species. .copy() turns the view
# into a real AnnData, so the 'y' column assigned next is written without
# view-modification warnings.
pred_sapiens = adata_cat[adata_cat.obs['specie_pred'] == stim_name].copy()

In [None]:
# Plot/group labels; these strings are the keys of custom_palette.
pred_sapiens.obs['y'] = 'Predicted Sapiens cells'
# Cast to object first: Series.replace on a categorical column (which the
# species column may be) is deprecated since pandas 2.2. Values not in the
# mapping, including missing ones, are kept unchanged, as with the previous
# chained .replace calls.
adata.obs['y'] = adata.obs[cond_key].astype(object).replace({
    stim_name: 'Real Sapiens cells',
    control_name: 'Real Muris cells',
})

In [None]:
# Stack real and predicted cells. The arguments are AnnData.concatenate's
# defaults, spelled out: inner join on genes (var), batch label stored in
# obs['batch'], and '-<batch>' appended to obs names.
# NOTE(review): AnnData.concatenate is deprecated in favour of anndata.concat.
# That function also intersects obs columns under join='inner', so migrating
# needs care.
concat = adata.concatenate(
    pred_sapiens,
    join='inner',
    batch_key='batch',
    index_unique='-',
)

In [None]:
# Cells per label before balancing.
concat.obs.y.value_counts()

In [None]:
# Balanced subsample: at most 3390 cells per 'y' label.
# NOTE(review): concat_subsampled is not used by the PCA/violin cells below,
# which work on the full `concat`. Confirm which one was intended.
concat_subsampled, indices = subsample_anndata(concat, label_col='y', n_samples=3390)

In [None]:
# Cells per label after balancing.
concat_subsampled.obs.y.value_counts()

In [None]:
# PCA on all real + predicted cells (used by the PCA plot below).
sc.pp.pca(concat)
# The neighbour graph / UMAP are not needed for that plot; to compute them:
# sc.pp.neighbors(concat)
# sc.tl.umap(concat)

In [None]:
# Colour per plot label. The key order also fixes the x-axis order of the violin plots.
custom_palette = dict([
    ('Real Sapiens cells', "#264653"),       # dark slate blue-teal
    ('Real Muris cells', "#f4a261"),         # sandy orange
    ('Predicted Sapiens cells', "#d62728"),  # red
])

In [None]:
# PCA embedding coloured by label (real muris, real sapiens, predicted sapiens).
sc.pl.pca(
    concat,
    color='y',
    palette=custom_palette,
    show=False,  # keep the figure open so it can be tightened / saved below
)
plt.tight_layout()
# To export the figure, uncomment:
#plt.savefig(f'{save_dir}PCA.png', dpi=600)
#plt.savefig(f'{save_dir}PCA.pdf', dpi=600)

In [None]:
# Marker gene sets; the comments describe the expected direction of change.
# NOTE(review): none of these names is referenced by the plotting cells below.
ag = ['PDCD1', 'TNFRSF9', 'IFNG', 'TNF']  # increased antigen recognition and activation
stemness = ['TCF7', 'KLF2']  # decreased TFs involved in cell stemness and blood recirculation
mnp = ['CXCL10', 'IL27']  # increased in some CD14 MNP subsets; implicated in irMyocarditis
tgf = 'TGF-B'  # NOTE(review): a plain string, not a list; likely meant a gene symbol such as TGFB1 — confirm

In [None]:
# Inspect the stored DEG ranking. uns['rank_genes_groups_specie'] is indexed by
# condition, then by cell type (presumably written by upstream preprocessing).
(
    adata.uns['rank_genes_groups_specie']
    ['sapiens']
    ['pulmonary alveolar type 2 cell']
)

In [None]:
# The stored DEG ranking for this condition/cell type, as a plain list of gene names.
degs = (
    adata.uns['rank_genes_groups_specie']
    ['sapiens']
    ['pulmonary alveolar type 2 cell']
    .tolist()
)

In [None]:
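# NOTE(review): the remark below does not obviously match this lung dataset.
# It looks carried over from an earlier T/NK-cell analysis; verify it against `degs`.
# MTRNR2L8 is the top DEG and CCL4 is the second top DEG that doesn't start with MTRNR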

In [None]:
from matplotlib.patches import Patch

In [None]:
# Violin plots of the top-5 stored DEGs across real muris, real sapiens and
# predicted sapiens cells.

# Loop-invariant setup: x-axis order follows the palette's key order.
category_order = list(custom_palette.keys())
# NOTE(review): legend_patches is kept for later use; the plots below do not draw a legend.
legend_patches = [Patch(color=custom_palette[cat], label=cat) for cat in category_order]

sns.set(style="white",
        font_scale=1.0,  # keep scale at 1.0 so the explicit sizes below apply
        rc={
            "axes.titlesize": 12,
            "axes.labelsize": 12,
            "xtick.labelsize": 12,
            "ytick.labelsize": 12,
            "legend.fontsize": 12
        })

for gene_name in degs[:5]:
    # Look the gene up by name in `concat` itself. Positions taken from
    # adata.var_names need not match concat, whose genes come from an inner join.
    if gene_name not in concat.var_names:
        print(f"{gene_name!r} not found in concat.var_names; skipping")
        continue
    expr = concat[:, gene_name].X
    # X may be a sparse matrix or a dense array depending on upstream steps.
    expr = expr.toarray() if hasattr(expr, "toarray") else np.asarray(expr)
    concat.obs['value'] = expr.ravel()

    plt.figure(figsize=(10, 6))
    ax = sns.violinplot(
        x="y", y="value", data=concat.obs, inner="quartile",
        order=category_order, palette=custom_palette
    )

    ax.set_title(gene_name, fontsize=12)
    ax.set_xlabel("")
    ax.set_ylabel("Expression", fontsize=12)
    # Restyle the existing ticks. The previous set_yticklabels(ax.get_yticks())
    # froze raw float tick values as fixed text labels.
    ax.tick_params(axis="both", labelsize=12)
    ax.grid(False)

    plt.tight_layout()
    # plt.savefig(f'{save_dir}{gene_name}.png', dpi=400)
    # plt.savefig(f'{save_dir}{gene_name}.pdf', dpi=400)
    # Render this figure now and release it instead of accumulating open figures.
    plt.show()