In [1]:
import os
import pickle
import matplotlib.pyplot as plt
import numpy as np
import sys
sys.path.append("/home/yufan") 
import scanpy as sc
import numpy as np
import seaborn as sns
from collections import defaultdict
import pandas as pd

current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir, ".."))
sys.path.insert(0, parent_dir)

from age_map import age_label_mapping_10, age_label_mapping_20
from analysis_utils import sim_gene_age, sim_celltype_age, sim_gene_youngest, sim_tissue_age, plot_clock_tissue, sim_gene_gene, label_to_float, cal_z_score, top_n_genes_by_age, cal_p

In [2]:
adata = sc.read_h5ad("chunk3562_stress_tumor_repair2.h5ad")
adata.obs["age_group_10"] = adata.obs["development_stage"].map(age_label_mapping_10)
adata.obs["age_group_20"] = adata.obs["development_stage"].map(age_label_mapping_20)
adata

  utils.warn_names_duplicates("obs")


AnnData object with n_obs × n_vars = 10000 × 58604
    obs: 'soma_joinid_column', 'dataset_id', 'assay', 'assay_ontology_term_id', 'cell_type', 'cell_type_ontology_term_id', 'development_stage', 'development_stage_ontology_term_id', 'disease', 'disease_ontology_term_id', 'donor_id', 'is_primary_data', 'self_reported_ethnicity', 'self_reported_ethnicity_ontology_term_id', 'sex', 'sex_ontology_term_id', 'suspension_type', 'tissue', 'tissue_ontology_term_id', 'tissue_general', 'tissue_general_ontology_term_id', 'raw_sum', 'nnz', 'raw_mean_nnz', 'raw_variance_nnz', 'n_measured_vars', 'batch', 'GDF15_embedding_status', 'NPY_embedding_status', 'TFPI2_embedding_status', 'MEG3_embedding_status', 'WIF1_embedding_status', 'HACE1_embedding_status', 'SGK1_embedding_status', 'LRRC3B_embedding_status', 'MKRN1_embedding_status', 'THY1_embedding_status', 'LZTFL1_embedding_status', 'DUSP26_embedding_status', 'TMEFF2_embedding_status', 'PAX1_embedding_status', 'CXXC4_embedding_status', 'MXI1_embedding_s

### Figure 2: Age Prediction vs chronological age label

In [3]:
adata_cut = adata[adata.obs["disease"] == "normal"]
ground_truth = adata_cut.obs["development_stage"].apply(label_to_float).to_list()
prediction = adata_cut.obs["prediction_age"].apply(label_to_float).to_list()
y_test, y_pred = zip(*[(x, y) for x, y in zip(ground_truth, prediction) if x != -1.0 and y != -1.0])
y_test = np.array(list(y_test))
y_pred = np.array(list(y_pred))
z_scores = cal_z_score(y_test, y_pred)
df_z_scores = pd.DataFrame.from_dict(z_scores)
df_z_scores.head(5)

Unnamed: 0,ground_truth,prediction,z_score,r_value
0,29.0,29.0,-0.060622,0.960022
1,49.0,49.0,-0.060622,0.960022
2,0.326923,0.326923,-0.060622,0.960022
3,0.326923,0.326923,-0.060622,0.960022
4,0.326923,0.326923,-0.060622,0.960022


In [4]:
plt.figure(figsize=(8, 6))
sns.scatterplot(data=df_z_scores, x="ground_truth", y="prediction", hue="z_score", palette='coolwarm', edgecolor='k', alpha=0.6, legend=False)
plt.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'k--', lw=2)
norm = plt.Normalize(df_z_scores["z_score"].min(), df_z_scores["z_score"].max())
sm = plt.cm.ScalarMappable(cmap="coolwarm", norm=norm)
sm.set_array([])
plt.colorbar(sm, label='z-scored age gap')
plt.xlabel('chronological age',fontsize=18)
plt.ylabel('predicted age',fontsize=18)
plt.title(f'age prediction of healthy samples',fontsize=18)
plt.text(0.05, 0.95, f'r = {df_z_scores["r_value"].values[0]:.2f}', transform=plt.gca().transAxes, 
        fontsize=18, verticalalignment='top', bbox=dict(boxstyle="round,pad=0.3", facecolor="white", edgecolor="black"))
#plt.show()
plt.savefig("paper_plots/age_predictions_z_score/healthy_samples.png", dpi=300, bbox_inches='tight') 
plt.close()

  plt.colorbar(sm, label='z-scored age gap')


In [5]:
tissue = "respiratory airway"
adata_cut = adata[(adata.obs["disease"] == "normal") & (adata.obs["tissue"] == tissue)]
ground_truth = adata_cut.obs["development_stage"].apply(label_to_float).to_list()
prediction = adata_cut.obs["prediction_age"].apply(label_to_float).to_list()
y_test, y_pred = zip(*[(x, y) for x, y in zip(ground_truth, prediction) if x != -1.0 and y != -1.0])
y_test = np.array(list(y_test))
y_pred = np.array(list(y_pred))
z_scores = cal_z_score(y_test, y_pred)
df_z_scores = pd.DataFrame.from_dict(z_scores)
df_z_scores.head(5)

Unnamed: 0,ground_truth,prediction,z_score,r_value
0,33.0,33.0,-0.225251,0.982275
1,29.0,33.0,0.912858,0.982275
2,15.0,15.0,-0.225251,0.982275
3,4.0,4.0,-0.225251,0.982275
4,27.0,29.0,0.343804,0.982275


In [6]:
plt.figure(figsize=(8, 6))
sns.scatterplot(data=df_z_scores, x="ground_truth", y="prediction", hue="z_score", palette='coolwarm', edgecolor='k', alpha=0.6, legend=False)
plt.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'k--', lw=2)
norm = plt.Normalize(df_z_scores["z_score"].min(), df_z_scores["z_score"].max())
sm = plt.cm.ScalarMappable(cmap="coolwarm", norm=norm)
sm.set_array([])
plt.colorbar(sm, label='z-scored age gap')
plt.xlabel('chronological age',fontsize=18)
plt.ylabel('predicted age',fontsize=18)
plt.title(f'age prediction of healthy {tissue} samples',fontsize=18)
plt.text(0.05, 0.95, f'r = {df_z_scores["r_value"].values[0]:.2f}', transform=plt.gca().transAxes, 
        fontsize=18, verticalalignment='top', bbox=dict(boxstyle="round,pad=0.3", facecolor="white", edgecolor="black"))
#plt.show()
plt.savefig(f"paper_plots/age_predictions_z_score/{tissue}_healthy_samples.png", dpi=300, bbox_inches='tight') 
plt.close()

  plt.colorbar(sm, label='z-scored age gap')


In [7]:
tissue = "breast"
adata_cut = adata[(adata.obs["disease"] == "normal") & (adata.obs["tissue"] == tissue)]
ground_truth = adata_cut.obs["development_stage"].apply(label_to_float).to_list()
prediction = adata_cut.obs["prediction_age"].apply(label_to_float).to_list()
y_test, y_pred = zip(*[(x, y) for x, y in zip(ground_truth, prediction) if x != -1.0 and y != -1.0])
y_test = np.array(list(y_test))
y_pred = np.array(list(y_pred))
z_scores = cal_z_score(y_test, y_pred)
df_z_scores = pd.DataFrame.from_dict(z_scores)
df_z_scores.head(5)

Unnamed: 0,ground_truth,prediction,z_score,r_value
0,27.0,27.0,-0.112226,0.624188
1,27.0,36.0,0.709122,0.624188
2,23.0,33.0,0.800383,0.624188
3,29.0,29.0,-0.112226,0.624188
4,36.0,36.0,-0.112226,0.624188


In [8]:
plt.figure(figsize=(8, 6))
sns.scatterplot(data=df_z_scores, x="ground_truth", y="prediction", hue="z_score", palette='coolwarm', edgecolor='k', alpha=0.6, legend=False)
plt.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'k--', lw=2)
norm = plt.Normalize(df_z_scores["z_score"].min(), df_z_scores["z_score"].max())
sm = plt.cm.ScalarMappable(cmap="coolwarm", norm=norm)
sm.set_array([])
plt.colorbar(sm, label='z-scored age gap')
plt.xlabel('chronological age',fontsize=18)
plt.ylabel('predicted age',fontsize=18)
plt.title(f'age prediction of healthy {tissue} samples',fontsize=18)
plt.text(0.05, 0.95, f'r = {df_z_scores["r_value"].values[0]:.2f}', transform=plt.gca().transAxes, 
        fontsize=18, verticalalignment='top', bbox=dict(boxstyle="round,pad=0.3", facecolor="white", edgecolor="black"))
#plt.show()
plt.savefig(f"paper_plots/age_predictions_z_score/{tissue}_healthy_samples.png", dpi=300, bbox_inches='tight') 
plt.close()

  plt.colorbar(sm, label='z-scored age gap')


### Fig 3: cell-type-specific age predictions

In [9]:
adata_cut = adata[adata.obs["disease"] == "normal"]
data = pd.DataFrame({
    'Predicted Age': adata_cut.obs["prediction_age"].apply(label_to_float).to_list(),
    'Ground Truth Age': adata_cut.obs["development_stage"].apply(label_to_float).to_list(),
    'Cell Type': adata_cut.obs["cell_type"].to_list()
})
data = data[(data['Predicted Age'] != -1) & (data['Ground Truth Age'] != -1)]
ref_cell_types = ['astrocyte','regulatory T cell','oligodendrocyte precursor cell',
                    'central memory CD4-positive, alpha-beta T cell',
                    'B cell','macrophage']
filtered_data = data[data['Cell Type'].isin(ref_cell_types)]

In [10]:
for cell_type in ref_cell_types:
    subset = filtered_data[filtered_data["Cell Type"] == cell_type]
    plt.figure(figsize=(6, 4))
    sns.regplot(
        x="Ground Truth Age",
        y="Predicted Age",
        data=subset,
        ci=None,
        scatter_kws={"s": 20, "color": "blue"},
        line_kws={"color": "black"},
    )
    
    correlation = cal_z_score(subset["Ground Truth Age"], subset["Predicted Age"])["r_value"]
    if pd.notna(correlation):
        plt.text(
            0.05,
            0.9,
            f'R = {correlation:.2f}',
            transform=plt.gca().transAxes,
            fontsize=12,
            verticalalignment='top',
            bbox=dict(boxstyle="round,pad=0.3", facecolor="white", edgecolor="black", alpha=0.8),
            color="black",
        )
        formatted_title = cell_type.replace("-positive", "+").replace("alpha-beta", "")
    plt.title(f"{formatted_title} age predictions", fontsize=16)
    plt.xlabel("chronological age")
    plt.ylabel("predicted age")
    
    plt.savefig(f"paper_plots/cell_type_age_predictions/{cell_type}_healthy_samples.png", dpi=300, bbox_inches='tight')
    plt.close()
    #plt.show()

### Figure 4: Similarity between cell type token and age token

In [11]:
adata_cut = adata[adata.obs["disease"] == "normal"]
cell_type_embeddings = adata_cut.obsm["cell_type_embeddings"]
age_embeddings = adata_cut.obsm["development_stage_embeddings"]
cell_types = adata_cut.obs["cell_type"].tolist()

celltype_age_sims = sim_celltype_age(cell_type_embeddings, age_embeddings, cell_types)

In [12]:
top10 = sorted(celltype_age_sims.items(), key=lambda x: np.mean(x[1]), reverse=True)[:10]
plot_data = {
    'cell type': [],
    'cosine similarity': []
}
for cell_type, cosine_similarity_values in top10:
    plot_data['cell type'].extend([cell_type] * len(cosine_similarity_values))
    plot_data['cosine similarity'].extend(cosine_similarity_values)
df = pd.DataFrame(plot_data)
#df.to_csv("mean_top10age-celltype.csv")

plt.figure(figsize=(10, 6))
sns.boxplot(x='cell type', y='cosine similarity', data=df, palette="magma_r")
plt.xticks(rotation=45, ha='right')
plt.ylim(0.1, 0.95)
plt.title('top 10 most age-related cell types',fontsize=20)
plt.xlabel('', fontsize=16)
plt.ylabel('cosine similarity', fontsize=16)
plt.tight_layout()
#plt.show()
plt.savefig("paper_plots/mean_top10_most age-related cell types.png", dpi=300, bbox_inches='tight')
plt.close()


Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.boxplot(x='cell type', y='cosine similarity', data=df, palette="magma_r")


In [13]:
top10 = sorted(celltype_age_sims.items(), key=lambda x: np.mean(x[1]))[:10]
plot_data = {
    'cell type': [],
    'cosine similarity': []
}
for cell_type, cosine_similarity_values in top10:
    plot_data['cell type'].extend([cell_type] * len(cosine_similarity_values))
    plot_data['cosine similarity'].extend(cosine_similarity_values)
df = pd.DataFrame(plot_data)
#df.to_csv("mean_last10age-celltype.csv")

plt.figure(figsize=(10, 6))
sns.boxplot(x='cell type', y='cosine similarity', data=df, palette="magma")
plt.xticks(rotation=45, ha='right')
plt.ylim(0.1, 0.95)
plt.xlabel('', fontsize=16)
plt.ylabel('cosine similarity', fontsize=16)
plt.title('top 10 least age-related cell types',fontsize=20)
plt.tight_layout()
#plt.show()
plt.savefig("paper_plots/mean_top10_least age-related cell types.png", dpi=300, bbox_inches='tight')
plt.close()


Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.boxplot(x='cell type', y='cosine similarity', data=df, palette="magma")


In [14]:
adata_cut = adata[adata.obs["disease"] == "normal"]
tissue_embeddings = adata_cut.obsm["tissue_embeddings"]
age_embeddings = adata_cut.obsm["development_stage_embeddings"]
tissues = adata_cut.obs["tissue_general"].tolist()

tissue_age_sims = sim_tissue_age(tissue_embeddings, age_embeddings, tissues)

In [15]:
top10 = sorted(tissue_age_sims.items(), key=lambda x: np.mean(x[1]), reverse=True)[:10]
plot_data = {
    'tissue': [],
    'cosine similarity': []
}
for tissue, cosine_similarity_values in top10:
    plot_data['tissue'].extend([tissue] * len(cosine_similarity_values))
    plot_data['cosine similarity'].extend(cosine_similarity_values)
df = pd.DataFrame(plot_data)
#df.to_csv("mean_top10age-tissue.csv")

plt.figure(figsize=(10, 6))
sns.boxplot(x='tissue', y='cosine similarity', data=df, palette="magma_r")
plt.xticks(rotation=45, ha='right')
plt.ylim(0.1, 0.95)
plt.title('top 10 most age-related tissues',fontsize=20)
plt.xlabel('', fontsize=16)
plt.ylabel('cosine similarity', fontsize=16)
plt.tight_layout()
#plt.show()
plt.savefig("paper_plots/mean_top10_most age-related general tissues.png", dpi=300, bbox_inches='tight')
plt.close()


Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.boxplot(x='tissue', y='cosine similarity', data=df, palette="magma_r")


In [16]:
top10 = sorted(tissue_age_sims.items(), key=lambda x: np.mean(x[1]))[:10]
plot_data = {
    'tissue': [],
    'cosine similarity': []
}
for tissue, cosine_similarity_values in top10:
    plot_data['tissue'].extend([tissue] * len(cosine_similarity_values))
    plot_data['cosine similarity'].extend(cosine_similarity_values)
df = pd.DataFrame(plot_data)
#df.to_csv("mean_last10age-tissue.csv")

plt.figure(figsize=(10, 6))
sns.boxplot(x='tissue', y='cosine similarity', data=df, palette="magma")
plt.xticks(rotation=45, ha='right')
plt.ylim(0.1, 0.95)
plt.xlabel('', fontsize=16)
plt.ylabel('cosine similarity', fontsize=16)
plt.title('top 10 least age-related tissues',fontsize=20)
plt.tight_layout()
#plt.show()
plt.savefig("paper_plots/mean_top10_least age-related general tissues.png", dpi=300, bbox_inches='tight')
plt.close()


Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.boxplot(x='tissue', y='cosine similarity', data=df, palette="magma")


### Fig 5: Dynamics of changes in aging using z-scored age gap

In [17]:
adata_cut = adata[adata.obs["disease"] == "normal"]
ground_truth = [label_to_float(label) for label in adata_cut.obs["development_stage"]]
predictions = [label_to_float(label) for label in adata_cut.obs["prediction_age"]]

df = pd.DataFrame({
    'Ground Truth': ground_truth,
    'Prediction': predictions
})

df = df[(df['Prediction'] != -1) & (df["Ground Truth"] != -1)]

df['Age Gap'] = df['Prediction'] - df['Ground Truth']
df = df[df['Age Gap'] != 0]
res = cal_z_score(df['Ground Truth'], df['Prediction'], cal_r=False)
df['Z-Score'] = res["z_score"]
df.head(5)

Unnamed: 0,Ground Truth,Prediction,Age Gap,Z-Score
8,59.0,74.0,15.0,0.903564
12,87.0,80.0,-7.0,-0.578158
24,67.0,77.0,10.0,0.566809
27,63.0,28.0,-35.0,-2.463987
28,0.230769,0.288462,0.057692,-0.102816


In [18]:
bins = range(10, 100, 5)  # 0-5, 5-10, ..., 95-100
labels = [f"{i}-{i+5}" for i in bins[:-1]]
df = df[df["Ground Truth"] > 10]
df['Age Gap Range'] = pd.cut(df['Ground Truth'], bins=bins, labels=labels, right=False)

plt.figure(figsize=(12, 8))
sns.violinplot(x='Age Gap Range', y='Z-Score', data=df, scale='width', inner='quartile',palette="viridis_r")
plt.axhline(0, color='red', linestyle='--', linewidth=1)
plt.title('Z-score distribution by 5-year age gap ranges: healthy samples (excluding perfect predictions)',fontsize=20)
plt.xlabel('chronological age',fontsize=16)
plt.ylabel('z-score of age gap',fontsize=16)
plt.tight_layout()
#plt.show()
plt.savefig("plots/z_score_5year_gap_healthy_exclude_agegap0.png", dpi=300, bbox_inches='tight')
plt.close()


Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.violinplot(x='Age Gap Range', y='Z-Score', data=df, scale='width', inner='quartile',palette="viridis_r")

The `scale` parameter has been renamed and will be removed in v0.15.0. Pass `density_norm='width'` for the same effect.
  sns.violinplot(x='Age Gap Range', y='Z-Score', data=df, scale='width', inner='quartile',palette="viridis_r")


### Fig 6: Tissue-dependent and cell-type-specific aging dynamics

### blood lack:[3562-3570], brain lack [might not be complete: 3562, 3563, 3573, lack: 3574]

In [19]:
ref_chunks = range(3562,3575)
for_these_tissues = ['liver', 'heart','lung', 'kidney', 'breast']#,'blood',"brain"] # 'bone marrow' is not included because all are <1 year old
age_count = {}
for tissue in for_these_tissues:
    age_count[tissue] = {}
    for chunk in ref_chunks:
        validation_data = sc.read_h5ad(f"/storage_bizon/sabrant_rocket_2tb/farhan/cellxgene/primary/cxg_chunk{chunk}.h5ad")
        validation_data.obs["age_group"] = validation_data.obs["development_stage"].map(age_label_mapping_10)
        filtered_data = validation_data[((validation_data.obs['tissue_general'] == tissue)) 
                                        & (validation_data.obs["disease"] == "normal")
                                        & (validation_data.obs["age_group"] != "unknown")]
        chunk_age_counts = filtered_data.obs["age_group"].value_counts()
        for age_group, sample_count in chunk_age_counts.items():
            if age_group in age_count[tissue]:
                age_count[tissue][age_group] += sample_count
            else:
                age_count[tissue][age_group] = sample_count

  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


In [20]:
sim_genes_by_age_by_tissue = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
for chunk in ref_chunks:
    for tissue in for_these_tissues:
        file_path = f"/home/yufan/perturbgene/notebooks/general_tissue_{tissue}_genes_by_age/{tissue}_chunk{chunk}_by_age.pkl"
        with open(file_path, 'rb') as f:
            data = pickle.load(f)
            if data is {}:
                print(f"sth wrong with {file_path}")
            for gene, age_dict in data.items():
                for age, values in age_dict.items():
                    age_range = age_label_mapping_10[age]
                    if age_range != "unknown":
                        sim_genes_by_age_by_tissue[tissue][age_range][gene].extend(values)

In [21]:
top3genes_by_age_by_tissue = top_n_genes_by_age(sim_genes_by_age_by_tissue,age_count)
df = pd.DataFrame([
    {'Tissue': tissue, 'Age': age, "Gene": gene, 'Mean Cosine Similarites': values}
    for tissue, age_dict in top3genes_by_age_by_tissue.items()
    for age, gene_values in age_dict.items()
    for gene,values in gene_values.items()
])
#df.to_csv("tissue_top3genes.csv", index=False)
df.head(5)

Unnamed: 0,Tissue,Age,Gene,Mean Cosine Similarites
0,liver,>80,MZT1P2,0.672035
1,liver,>80,EEF1A1P33,0.656566
2,liver,>80,RP3-477O4.5,0.65452
3,liver,60-70,RPS12,0.402485
4,liver,60-70,CH17-258A22.4,0.401862


In [22]:
for tissue in for_these_tissues:
    fig = plot_clock_tissue(tissue, top3genes_by_age_by_tissue)
    #fig.show()
    fig.savefig(f'paper_plots/clock_expressed_in1%_samples/{tissue}_clock.png', dpi=300, bbox_inches='tight')
    plt.close(fig)

### Fig 7: Temporal drift in gene embedding space during aging: gene variance to youngest group

In [23]:
gene_functions = {
    "ATR": "DNA Repair", #same
    "TDP2": "DNA Repair", #up
    "SAFB2": "Stress Response", #up
    "HERPUD1": "Stress Response", #down
    "MCTP1": "Oxidative Stress", #up
    "HEBP2": "Oxidative Stress", #down
    "SESN1": "Oxidative Stress", #same
    "APC": "Tumor Suppressor", #up
    "MKRN1": "Tumor Suppressor",  #same
    "BANP": "Tumor Suppressor", #down
    "IFITM3":"Immune Response"
}
function_colors = {
    "DNA Repair": "purple",
    "Stress Response": "teal",
    "Oxidative Stress": "orange",
    "Tumor Suppressor": "red",
    "Immune Response":"blue"
}
age_order = ["<20", "20-40", "40-60", "60-80", ">80"]

In [24]:
gene2youngest_each_group = {}
ref_age = "<20"
do_for_these_genes = gene_functions.keys()
for gene in do_for_these_genes:
    adata_sub = adata[(adata.obs[f'{gene}_embedding_status'] == "present") & (adata.obs["disease"] == "normal")]
    adata_sub = adata_sub[~adata_sub.obs["age_group_20"].isin(["unknown"])]
    gene_embeddings = adata_sub.obsm[f'{gene}_embeddings'] 
    age_groups = adata_sub.obs['age_group_20'].to_list()
    gene2youngest_each_group[gene] = sim_gene_youngest(gene_embeddings, age_group_list=age_groups, ref_age = ref_age, age_order=age_order)

In [25]:
gene2youngest_each_group_p = {}
for gene, values in gene2youngest_each_group.items():
    ref_sims = gene2youngest_each_group[gene][ref_age]
    gene2youngest_each_group_p[gene] = {}
    for age_group, sims in values.items():
        if age_group != ref_age:
            gene2youngest_each_group_p[gene][age_group] = cal_p(values[age_group], ref_sims)
        else:
            gene2youngest_each_group_p[gene][ref_age] = ""

In [26]:
for gene, age_values in gene2youngest_each_group.items():
    data = []
    mean = np.mean(age_values[ref_age])
    for age, sims in age_values.items():
        for sim in sims:
            data.append({"age group": age, "cosine similarity": sim-mean})
    df = pd.DataFrame(data)
    df['age group'] = pd.Categorical(df['age group'], categories=age_order, ordered=True)

    plt.figure(figsize=(8, 6))
    sns.lineplot(data=df, x="age group", y="cosine similarity", marker='o')
    plt.title(f"cosine similarity compare to <20 for {gene}", fontsize=20)
    plt.xlabel('age group', fontsize=16)
    plt.ylabel('adjusted cosine similarity', fontsize=16)
    plt.ylim(-0.5,0.1)
    ylim = plt.gca().get_ylim()
    
    for age, p_star in gene2youngest_each_group_p[gene].items():
        annotation = p_star 
        plt.text(
            age,  
            ylim[1] - 0.05 * (ylim[1] - ylim[0]),  
            annotation,
            ha="center", va="top", fontsize=10, color="black"
        )
    plt.tight_layout()
    #plt.show()
    plt.savefig(f'paper_plots/compare_gene2youngest/{gene}_withP.png', dpi=300, bbox_inches='tight')
    plt.close()

### Figure 7c: similarity between genes: check check

In [27]:
gene2gene_each_group = {}
age_order_10 = ["10-20", "20-30", "30-40", "40-50","50-60", "60-70","70-80",">80"]
ref_gene = "IGHA1"
do_for_these_genes = ["IGHG3", #up
                      "ARHGAP11A", #down
                      "BCL2L12" #down
                      ]
adata_ref = adata[(~adata.obs["age_group_10"].isin(["unknown"])) & (adata.obs["disease"] == "normal")]
for gene in do_for_these_genes:
    geneA_adata = adata_ref[adata_ref.obs[f'{ref_gene}_embedding_status'] == "present"]
    geneB_adata= adata_ref[adata_ref.obs[f'{gene}_embedding_status'] == "present"]
    geneA_embeddings = geneA_adata.obsm[f"{ref_gene}_embeddings"]
    geneB_embeddings= geneB_adata.obsm[f'{gene}_embeddings']
    geneA_age_groups = geneA_adata.obs['age_group_10'].to_list()
    geneB_age_groups = geneB_adata.obs['age_group_10'].to_list()
    gene2gene_each_group[gene] = sim_gene_gene(geneA_embeddings, geneB_embeddings, age_group_list=geneA_age_groups, ref_age_group_list=geneB_age_groups, age_order=age_order_10)

In [28]:
gene2gene_each_group_p = {}
for gene, values in gene2gene_each_group.items():
    gene2gene_each_group_p[gene] = {}
    for age in age_order_10:
        if age == "10-20":
            gene2gene_each_group_p[gene]["10-20"] = ''
        else:
            previous=age_order_10[age_order_10.index(age)-1]
            sims_age_now = values[age]
            sims_age_previous = values[previous]
            if (sims_age_now is not None) and (sims_age_previous is not None):
                gene2gene_each_group_p[gene][age] = cal_p(values[age], values[previous])
            else:
                gene2gene_each_group_p[gene][age] = ""

In [29]:
for gene, age_values in gene2gene_each_group.items():
    data = []
    for age, sims in age_values.items():
        if sims is not None:
            for sim in sims:
                data.append({"age group": age, "cosine similarity": sim})
    df = pd.DataFrame(data)
    df['age group'] = pd.Categorical(df['age group'], categories=age_order_10, ordered=True)

    plt.figure(figsize=(8, 6))
    sns.lineplot(data=df, x="age group", y="cosine similarity", marker='o')
    plt.title(f"cosine similarity between {ref_gene} and {gene}", fontsize=20)
    plt.xlabel('age group', fontsize=16)
    plt.ylabel('cosine similarity', fontsize=16)
    ylim = plt.gca().get_ylim()
    
    for age, p_star in gene2gene_each_group_p[gene].items():
        annotation = p_star 
        plt.text(
            age,  
            ylim[1] - 0.05 * (ylim[1] - ylim[0]),  
            annotation,
            ha="center", va="top", fontsize=10, color="black"
        )
    plt.tight_layout()
    #plt.show()
    plt.savefig(f'paper_plots/compare_gene2gene/{ref_gene}_{gene}_withP.png', dpi=300, bbox_inches='tight')
    plt.close() #this function still need to be checked

### Figure 8: conservative genes while aging (gene-age, pairweise, by age group)

In [30]:
gene2age_each_group = {}
age_order = ["<20", "20-40", "40-60", "60-80", ">80"]
do_for_these_genes = gene_functions.keys()
adata_ref = adata[(~adata.obs["age_group_20"].isin(["unknown"])) & (adata.obs["disease"] == "normal")]
for gene in do_for_these_genes:
    adata_gene = adata_ref[adata_ref.obs[f"{gene}_embedding_status"] == "present"]
    gene_embeddings = adata_gene.obsm[f"{gene}_embeddings"]
    age_embeddings= adata_gene.obsm['development_stage_embeddings']
    age_groups = adata_gene.obs['age_group_20'].to_list()
    gene2age_each_group[gene] = sim_gene_age(gene_embeddings, age_embeddings, age_group_list=age_groups, age_order=age_order)

In [31]:
gene2age_each_group_p = {}
for gene, values in gene2age_each_group.items():
    gene2age_each_group_p[gene] = {}
    for age in age_order:
        if age == "<20":
            gene2age_each_group_p[gene]["<20"] = ''
        else:
            previous=age_order[age_order.index(age)-1]
            sims_age_now = values[age]
            sims_age_previous = values[previous]
            if (sims_age_now is not None) and (sims_age_previous is not None):
                gene2age_each_group_p[gene][age] = cal_p(sims_age_now, sims_age_previous)
            else:
                gene2age_each_group_p[gene][age] = ""

In [32]:
for gene, age_values in gene2age_each_group.items():
    data = []
    for age, sims in age_values.items():
        if sims is not None:
            for sim in sims:
                data.append({"age group": age, "cosine similarity": sim})
    df = pd.DataFrame(data)
    df['age group'] = pd.Categorical(df['age group'], categories=age_order, ordered=True)
    mean_values = df.groupby('age group')['cosine similarity'].mean().reset_index()

    plt.figure(figsize=(8, 6))
    sns.boxplot(data=df, x="age group", y="cosine similarity")
    plt.plot(
        mean_values['age group'], 
        mean_values['cosine similarity'], 
        marker='o', color='red', linestyle='-', label='Mean Cosine Similarity'
    )
    plt.title(f"cosine similarity between {gene} token and age token", fontsize=18)
    plt.xlabel('age group', fontsize=16)
    plt.ylabel('cosine similarity', fontsize=16)
    plt.ylim(0,0.95)
    ylim = plt.gca().get_ylim()
    
    for age, p_star in gene2age_each_group_p[gene].items():
        annotation = p_star 
        plt.text(
            age,  
            ylim[1] - 0.05 * (ylim[1] - ylim[0]),  
            annotation,
            ha="center", va="top", fontsize=16, color="black"
        )
    
    plt.tight_layout()
    #plt.show()
    plt.savefig(f'paper_plots/compare_gene2age/{gene}_withP.png', dpi=300, bbox_inches='tight')
    plt.close() 

### Fig 9: gene-age in healthy vs. diseased

In [33]:
gene2age_each_group_diseased = {}
age_order = ["<20", "20-40", "40-60", "60-80", ">80"]
do_for_these_genes = gene_functions.keys()
adata_ref = adata[(~adata.obs["age_group_20"].isin(["unknown"])) & (adata.obs["disease"] != "normal")]
for gene in do_for_these_genes:
    adata_gene = adata_ref[adata_ref.obs[f"{gene}_embedding_status"] == "present"]
    gene_embeddings = adata_gene.obsm[f"{gene}_embeddings"]
    age_embeddings= adata_gene.obsm['development_stage_embeddings']
    age_groups = adata_gene.obs['age_group_20'].to_list()
    gene2age_each_group_diseased[gene] = sim_gene_age(gene_embeddings, age_embeddings, age_group_list=age_groups, age_order=age_order)

In [34]:
gene2age_each_group_diseased_p = {}
for gene, values in gene2age_each_group_diseased.items():
    gene2age_each_group_diseased_p[gene] = {}
    for age in age_order:
        sims_age_now = values[age]
        sims_age_previous = gene2age_each_group[gene][age]
        if (sims_age_now is not None) and (sims_age_previous is not None):
            gene2age_each_group_diseased_p[gene][age] = cal_p(sims_age_now, sims_age_previous)
        else:
            gene2age_each_group_diseased_p[gene][age] = ""

  res = hypotest_fun_out(*samples, **kwds)


In [35]:
for gene, age_values in gene2age_each_group.items():
    data = []
    for age, sims in age_values.items():
        if sims is not None:
            for sim in sims:
                data.append({"age group": age, "cosine similarity": sim})
    df = pd.DataFrame(data)
    df['age group'] = pd.Categorical(df['age group'], categories=age_order, ordered=True)
    df['disease'] = 'healthy'
    mean_values = df.groupby('age group')['cosine similarity'].mean().reset_index()

    diseased_data = []
    for age, sims in gene2age_each_group_diseased[gene].items():
        if sims is not None:
            for sim in sims:
                diseased_data.append({"age group": age, "cosine similarity": sim})
    diseased_df = pd.DataFrame(diseased_data)
    diseased_df['age group'] = pd.Categorical(diseased_df['age group'], categories=age_order, ordered=True)
    diseased_df["disease"] = "diseased"
    diseased_mean_values = diseased_df.groupby('age group')['cosine similarity'].mean().reset_index()
    combined_df = pd.concat([df,diseased_df])

    mean_values['disease'] = 'healthy'
    diseased_mean_values['disease'] = 'diseased'
    combined_mean_values = pd.concat([mean_values, diseased_mean_values])
    combined_mean_values['age group'] = pd.Categorical(
        combined_mean_values['age group'], 
        categories=["<20", "20-40", "40-60", "60-80", ">80"], 
        ordered=True
    )

    plt.figure(figsize=(8, 6))
    boxplot = sns.boxplot(data=combined_df, x="age group", y="cosine similarity", hue="disease", palette='pastel')

    sns.lineplot(
        data=combined_mean_values,
        x="age group",
        y="cosine similarity",
        hue="disease",
        style="disease",
        markers=["o", "s"], 
        palette={"healthy": "#0047AB", "diseased": "lightcoral"},
        linewidth=2
    )
    sns.move_legend(boxplot, "upper left", bbox_to_anchor=(1, 1))

    plt.title(f"cosine similarity between {gene} token and age token \n in healthy vs. diseased", fontsize=18)
    plt.xlabel('age group', fontsize=16)
    plt.ylabel('cosine similarity', fontsize=16)
    plt.ylim(0,0.95)
    ylim = plt.gca().get_ylim()
    
    
    for age, p_star in gene2age_each_group_diseased_p[gene].items():
        annotation = p_star 
        plt.text(
            age,  
            ylim[1] - 0.05 * (ylim[1] - ylim[0]),  
            annotation,
            ha="center", va="top", fontsize=16, color="black"
        )
    
    plt.tight_layout()
    #plt.show()
    plt.savefig(f'paper_plots/compare_gene2age_with_diseased/{gene}_withP.png', dpi=300, bbox_inches='tight')
    plt.close() 