# imports

In [None]:
def wrap_labels(label, max_length=20, ignore=()):
    """Insert ``<br>`` line breaks into a label that is longer than ``max_length``.

    The label is split at the last space found within its first
    ``max_length`` characters; the remainder is wrapped recursively while it
    is still longer than ``max_length``.

    Args:
        label: Text to wrap.
        max_length: Labels of at most this many characters are returned as is.
        ignore: Substrings; a too-long label containing any of them is
            returned unwrapped.

    Returns:
        The label with ``<br>`` replacing the chosen spaces, or the original
        label when it is short enough, matches ``ignore``, or has no space
        (other than at index 0) within its first ``max_length`` characters.
    """
    if len(label) <= max_length:
        return label
    if any(token in label for token in ignore):
        return label
    last_space = label.rfind(' ', 0, max_length)
    if last_space <= 0:
        # No usable break point in the first max_length characters.
        return label
    head, tail = label[:last_space], label[last_space + 1:]
    if len(tail) > max_length:
        # The whole label already passed the ignore check, so the tail needs none.
        tail = wrap_labels(tail, max_length=max_length)
    return head + '<br>' + tail

def remove_rare_values(df, column_names, min_count=3):
    """Drop rows whose value in any listed column occurs fewer than ``min_count`` times.

    Columns are processed in order and the counts for each column are taken
    from the frame as already filtered by the preceding columns.

    Args:
        df: Input DataFrame; it is not modified.
        column_names: Columns to filter on, in order.
        min_count: Minimum number of occurrences a value needs to be kept
            (default 3, the previously hard-coded threshold).

    Returns:
        The filtered DataFrame.
    """
    for column in column_names:
        counts = df[column].value_counts()
        rare_values = counts[counts < min_count].index
        df = df[~df[column].isin(rare_values)]
    return df

# check completeness

In [None]:
import os
# Directory holding one result sub-directory per processed sample.
result_path="./results/Axolotls/"
samples = [
    name
    for name in os.listdir(result_path)
    if os.path.isdir(os.path.join(result_path, name))
]

# Print every sample whose result directory does not contain 'adata.h5ad'.
for sample in samples:
    result_dir = os.path.join(result_path,sample)
    files = os.listdir(result_dir)
    sample = sample[:-5]  # strip the last five characters of the directory name
    chip = sample.split(".")[0]  # only referenced by the commented-out commands below
    if 'adata.h5ad' in files:
        continue
    print(sample)
    # Re-run commands, kept for reference:
    #if sample[0].isdigit():
    #    !python ./CRUST.py ./data/Axolotls/regeneration/{chip}/{sample}.csv ./results/Axolotls/ Axolotls
    #else:
    #    !python ./CRUST.py ./data/Axolotls/development/{chip}/{sample}.csv ./results/Axolotls/ Axolotls

# dataview

## gem check

In [None]:
import pandas as pd  # this cell uses pd, but no earlier cell in view imports pandas

# CSV file read for each developmental stage label.
gem_paths = {
    'Stage44': "./data/Axolotls/development/Stage44_telencephalon_rep2_FP200000239BL_E4_scgem.csv",
    'Stage54': "./data/Axolotls/development/Stage54_telencephalon_rep2_DP8400015649BRD6_2_scgem.csv",
    'Stage57': "./data/Axolotls/development/Stage57_telencephalon_rep2_DP8400015649BRD5_1_scgem.csv",
    'Juv': "./data/Axolotls/development/Injury_control_FP200000239BL_E3_scgem.csv",
    'Adult': "./data/Axolotls/development/Adult_telencephalon_rep2_DP8400015234BLA3_1_scgem.csv",
    'Meta': "./data/Axolotls/development/Meta_telencephalon_rep1_DP8400015234BLB2_1_scgem.csv",
}

# One DataFrame per stage, in the insertion order above.
gems = {stage: pd.read_csv(path) for stage, path in gem_paths.items()}

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Overlay, for every entry of gems, a histogram (with KDE) of the per-gene total MIDCounts.
sns.set(style="whitegrid")

# One hue per sample.
palette = sns.color_palette("husl", len(gems))

plt.figure(figsize=(12, 8))
for color, (sample_name, gem_data) in zip(palette, gems.items()):
    # Sum MIDCounts over all rows sharing the same geneID.
    gem_total_counts = gem_data.groupby('geneID')['MIDCounts'].sum()
    sns.histplot(
        gem_total_counts,
        bins=1000,
        kde=True,
        label=sample_name,
        color=color,
        alpha=0.6,
    )

# Only totals between 0 and 10000 are shown.
plt.xlim(0, 10000)

plt.title('Distribution of Total Counts per Gene for Each Sample', fontsize=16)
plt.xlabel('Total Counts', fontsize=14)
plt.ylabel('Frequency', fontsize=14)
plt.legend(title='Sample', fontsize=12)

plt.show()

## development

In [None]:
import scanpy as sc
# Read the development and regeneration datasets, then make their observation names unique.
adata_dev = sc.read_h5ad("./data/Axolotls/Development.h5ad")
adata_reg = sc.read_h5ad("./data/Axolotls/Regeneration.h5ad")
for _adata in (adata_dev, adata_reg):
    _adata.obs_names_make_unique()

In [None]:
# Store the spatial coordinates as obs columns: x is column 0 as is, y is
# column 1 subtracted from a per-dataset constant (4500 / 13000).
for _adata, _y_offset in ((adata_dev, 4500), (adata_reg, 13000)):
    _adata.obs['x'] = _adata.obsm['spatial'][:, 0]
    _adata.obs['y'] = _y_offset - _adata.obsm['spatial'][:, 1]

In [None]:
# Scatter plot of adata_dev coloured by 'Annotation', with the legend in the right margin.
sc.set_figure_params(dpi=300, figsize=[15,3], fontsize=12)
sc.pl.scatter(
    adata_dev,
    x="x",
    y="y",
    color='Annotation',
    size=9,
    legend_loc='right margin',
    legend_fontsize='small',
    show=False,
    frameon=False,
    title='',
)

# Optional legend tweaks, kept for reference:
#legend = ax.get_legend()
#legend.set_bbox_to_anchor((1, 0.5)) # set legend position
#plt.setp(legend.get_texts(), fontsize='x-small') # set legend font size
#legend._legend_box.align = "left" # set legend alignment
#legend.set_title('Legend Title') # set legend title
#legend._legend_box.ncol = 0 # set number of columns
ax = plt.gca()

# Hide the four frame lines.
for side in ("top", "right", "bottom", "left"):
    ax.spines[side].set_visible(False)

# Hide both axes.
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Scatter plot of adata_dev restricted to the 'VLMC' annotation group, saved as Fig5a_u.pdf.
sc.set_figure_params(dpi=300, figsize=[15,3], fontsize=12)
sc.pl.scatter(
    adata_dev,
    x="x",
    y="y",
    color='Annotation',
    size=9,
    legend_loc='upper right',
    legend_fontsize='large',
    show=False,
    frameon=False,
    title='',
    groups=['VLMC'],
    palette=sns.color_palette('viridis'),
)

ax = plt.gca()

# Hide the four frame lines and both axes.
for side in ("top", "right", "bottom", "left"):
    ax.spines[side].set_visible(False)
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)

# Stage names written at fixed x positions, at y = -200.
stage_names = ['St.44', 'St.54', 'St.57', 'Juv.', 'Adult', 'Meta.']
stage_positions = [760, 3900, 7700, 13700, 20400, 27000]
for label, x in zip(stage_names, stage_positions):
    ax.text(x, -200, label, ha='center', va='top', fontsize=14)

# Caption for the row of stage names.
ax.annotate('Developmental stages', xy=(4230, 400), xytext=(4230, 4000),
            fontsize=16, color='black', ha='center')

plt.savefig("Fig5a_u.pdf", format="pdf", bbox_inches="tight")

In [None]:
# Normalise the 'CP  ' annotation (two trailing spaces) to 'CP' before plotting.
adata_dev.obs['Annotation'] = adata_dev.obs['Annotation'].replace({'CP  ':'CP'})

# Scatter plot of adata_dev restricted to the 'CP' annotation group, saved as Fig5a_d.pdf.
sc.set_figure_params(dpi=300, figsize=[15,3], fontsize=12)
sc.pl.scatter(
    adata_dev,
    x="x",
    y="y",
    color='Annotation',
    size=9,
    legend_loc='upper right',
    legend_fontsize='large',
    show=False,
    frameon=False,
    title='',
    groups=['CP'],
    palette=sns.color_palette('viridis'),
)

ax = plt.gca()

# Hide the four frame lines and both axes.
for side in ("top", "right", "bottom", "left"):
    ax.spines[side].set_visible(False)
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)

# Stage names written at fixed x positions, at y = -200.
stage_names = ['St.44', 'St.54', 'St.57', 'Juv.', 'Adult', 'Meta.']
stage_positions = [760, 3900, 7700, 13700, 20400, 27000]
for label, x in zip(stage_names, stage_positions):
    ax.text(x, -200, label, ha='center', va='top', fontsize=14)

plt.savefig("Fig5a_d.pdf", format="pdf", bbox_inches="tight")

# downstream analysis

In [None]:
import os
import scanpy as sc
# Load every result directory that contains 'adata.h5ad', keyed by directory name.
result_path="./results/Axolotls/"
samples = [
    name
    for name in os.listdir(result_path)
    if os.path.isdir(os.path.join(result_path, name))
]
adata={}
Xs={}
for sample in samples:
    result_dir = os.path.join(result_path,sample)
    files = os.listdir(result_dir)
    chip = sample.split(".")[0]
    if 'adata.h5ad' not in files:
        continue
    adata[sample] = sc.read_h5ad(result_dir+"/adata.h5ad")
    # The 'X' entry of .uns is kept separately for quick access.
    Xs[sample] = adata[sample].uns['X']

In [None]:
# NOTE(review): this loop was left without a body, which is a SyntaxError; `pass`
# keeps the cell runnable. `sample` holds a single directory name at this point,
# so the loop walks its characters — `samples` may have been intended; confirm.
for i in sample:
    pass

## 1.netplot

In [None]:
import os
from functools import reduce
import scanpy as sc
import numpy as np
from CRUST import normalizeX, numpy_svd_rmsd_rot

def mirror(X):
    """Return a copy of ``X`` whose column at index 2 is negated.

    ``X`` itself is left untouched.
    """
    reflected = np.array(X, copy=True)
    np.negative(reflected[:, 2], out=reflected[:, 2])
    return reflected

# adata.uns['X'] tables and their row indices, keyed by "<chip> <cell type>".
Xs={}
GeneLists={}
# Cell-type token of every result directory, including directories not loaded below.
ctypes=[]

result_path="./results/Axolotls/"
samples = [ name for name in sorted(os.listdir(result_path)) if os.path.isdir(os.path.join(result_path, name)) ]
for sample in samples:
    result_dir = os.path.join(result_path,sample)
    files = os.listdir(result_dir)
    info = sample.split(".")
    chip = info[0]
    # Last dot-separated field of the directory name, minus its last five characters.
    ctype = info[-1][:-5]
    ctypes.append(ctype)
    sample = chip+' '+ctype
    # Only directories that hold a result file and whose name starts with a letter are loaded.
    if 'adata.h5ad' in files and sample[0].isalpha():
        print(sample)
        adata = sc.read_h5ad(result_dir+"/adata.h5ad")
        Xs[sample] = adata.uns['X']
        GeneLists[sample] = adata.uns['X'].index

# Build distance matrix
utypes = np.unique(ctypes)
keys = list(Xs.keys())
N = len(keys)
D_axo_dev = np.zeros((N,N))

# Fill the lower triangle (diagonal included) and mirror it into the upper one.
for n, key_n in enumerate(keys):
    for m, key_m in enumerate(keys[:n+1]):
        # np.intersect1d returns sorted unique values, so this keeps the first
        # 100 shared genes in sorted order.
        intersected_values = np.intersect1d(GeneLists[key_n], GeneLists[key_m])[:100]
        # np.isin replaces np.in1d, which is deprecated in NumPy 2.
        boolean_arrays_n = np.isin(GeneLists[key_n], intersected_values)
        boolean_arrays_m = np.isin(GeneLists[key_m], intersected_values)
        normalized_x_n = normalizeX(Xs[key_n][boolean_arrays_n], method='mean')
        normalized_x_m = normalizeX(Xs[key_m][boolean_arrays_m], method='mean')
        # The first return value is used as the distance; the smaller of the
        # direct and the mirrored comparison is kept.
        d1, _, _ = numpy_svd_rmsd_rot(normalized_x_n, normalized_x_m)
        d2, _, _ = numpy_svd_rmsd_rot(mirror(normalized_x_n), normalized_x_m)
        D_axo_dev[n,m] = D_axo_dev[m,n] = min(d1,d2)

In [None]:
godsnot_102 = [
    #"#000000",  # remove the black, as often, we have black colored annotation
    "#FFFF00",
    "#1CE6FF",
    "#FF34FF",
    "#FF4A46",
    "#008941",
    "#006FA6",
    "#A30059",
    "#FFDBE5",
    "#7A4900",
    "#0000A6",
    "#63FFAC",
    "#B79762",
    "#004D43",
    "#8FB0FF",
    "#997D87",
    "#5A0007",
    "#809693",
    "#6A3A4C",
    "#1B4400",
    "#4FC601",
    "#3B5DFF",
    "#4A3B53",
    "#FF2F80",
    "#61615A",
    "#BA0900",
    "#6B7900",
    "#00C2A0",
    "#FFAA92",
    "#FF90C9",
    "#B903AA",
    "#D16100",
    "#DDEFFF",
    "#000035",
    "#7B4F4B",
    "#A1C299",
    "#300018",
    "#0AA6D8",
    "#013349",
    "#00846F",
    "#372101",
    "#FFB500",
    "#C2FFED",
    "#A079BF",
    "#CC0744",
    "#C0B9B2",
    "#C2FF99",
    "#001E09",
    "#00489C",
    "#6F0062",
    "#0CBD66",
    "#EEC3FF",
    "#456D75",
    "#B77B68",
    "#7A87A1",
    "#788D66",
    "#885578",
    "#FAD09F",
    "#FF8A9A",
    "#D157A0",
    "#BEC459",
    "#456648",
    "#0086ED",
    "#886F4C",
    "#34362D",
    "#B4A8BD",
    "#00A6AA",
    "#452C2C",
    "#636375",
    "#A3C8C9",
    "#FF913F",
    "#938A81",
    "#575329",
    "#00FECF",
    "#B05B6F",
    "#8CD0FF",
    "#3B9700",
    "#04F757",
    "#C8A1A1",
    "#1E6E00",
    "#7900D7",
    "#A77500",
    "#6367A9",
    "#A05837",
    "#6B002C",
    "#772600",
    "#D790FF",
    "#9B9700",
    "#549E79",
    "#FFF69F",
    "#201625",
    "#72418F",
    "#BC23FF",
    "#99ADC0",
    "#3A2465",
    "#922329",
    "#5B4534",
    "#FDE8DC",
    "#404E55",
    "#0089A3",
    "#CB7E98",
    "#A4E804",
    "#324E72",
]


In [None]:
#color_map_ct =  dict(zip(utypes, godsnot_102))

color_map_ct =  {
'CP': '#1CE6FF',
'IMN': '#FF34FF',
'Immature_CMPN': '#FF4A46',
'CMPN': '#FFFF00',
'Immature_MSN': '#006FA6',
'MCG': '#63FFAC',
'MSN': '#B79762',
'Oligo': '#004D43',
'VLMC': '#997D87',
'WSN': '#5A0007',
'dNBL1': '#1B4400',
'dNBL2': '#4FC601',
'dNBL3': '#3B5DFF',
'dNBL4': '#4A3B53',
'dNBL5': '#FF2F80',
'obNBL': '#B903AA',
'tlNBL': '#372101',
'Immature_dpEX': '#FFDBE5',
'Immature_mpEX': '#7A4900',
'Immature_nptxEX': '#0000A6',
'dpEX': '#61615A',
'mpEX': '#BA0900',
'nptxEX': '#00C2A0',
'mpIN': '#6B7900',
'cckIN': '#809693',
'npyIN': '#FFAA92',
'Immature_DMIN': '#008941',
'Immature_cckIN': '#A30059',
'ntng1IN': '#FF90C9',
'scgnIN': '#0AA6D8',
'sstIN': '#00846F',
'rIPC1': '#D16100',
'rIPC2': '#DDEFFF',
'rIPC3': '#000035',
'rIPC4': '#7B4F4B',
'dEGC': '#6A3A4C',
'reaEGC': '#A1C299',
'ribEGC': '#300018',
'sfrpEGC': '#013349',
'wntEGC': '#FFB500',
'Unknown': '#8FB0FF',
}

In [None]:
color_map_ct =  {
'CP': '#1CE6FF',
'IMN': '#FF34FF',
'Immature_CMPN': '#FF4A46',
'CMPN': '#FF4A46',
'Immature_MSN': '#006FA6',
'MSN': '#006FA6',
'MCG': '#63FFAC',
'Oligo': '#004D43',
'VLMC': '#997D87',
'WSN': '#5A0007',
'dNBL1': '#1B4400',
'dNBL2': '#1B4400',
'dNBL3': '#1B4400',
'dNBL4': '#1B4400',
'dNBL5': '#1B4400',
'obNBL': '#B903AA',
'tlNBL': '#372101',

'Immature_dpEX': '#FFDBE5',
'dpEX': '#FFDBE5',
'Immature_mpEX': '#7A4900',
'mpEX': '#7A4900',
'Immature_nptxEX': '#0000A6',
'nptxEX': '#0000A6',
'Immature_cckIN': '#A30059',
'cckIN': '#A30059',

'mpIN': '#6B7900',
'npyIN': '#FFAA92',
'Immature_DMIN': '#008941',
'ntng1IN': '#FF90C9',
'scgnIN': '#0AA6D8',
'sstIN': '#00846F',

'rIPC1': '#D16100',
'rIPC2': '#D16100',
'rIPC3': '#D16100',
'rIPC4': '#D16100',
'dEGC': '#FFB500',
'reaEGC': '#FFB500',
'ribEGC': '#FFB500',
'sfrpEGC': '#FFB500',
'wntEGC': '#FFB500',
'Unknown': '#8FB0FF',
}

In [None]:
from pyvis.network import Network
from IPython.display import HTML
import plotly.graph_objs as go
import itertools
import networkx as nx
import plotly.io as pio
import kaleido, random
# Fix both random generators so repeated runs are seeded identically.
random.seed(997)
np.random.seed(997)

# create an empty graph
G = nx.Graph()

# add nodes: one per row of the distance matrix, labelled with its sample key
N = D_axo_dev.shape[0]
for i in range(N):
    G.add_node(i, label=keys[i])

# add edges: link every node to the nodes behind the 4 smallest entries of its
# distance row, skipping the node itself
edges = []

# np.argpartition requires kth < N, so clamp it for graphs with 4 or fewer nodes.
kth = min(4, N - 1)
for i in range(N):
    arr = D_axo_dev[i,:]
    min_indices = np.argpartition(arr, kth)[:4]
    for j in min_indices:
        if not i==j:
            # Smaller distances yield larger edge weights.
            edges.append((i,int(j),{'weight':float(1/arr[j]*100-60)}))

G.add_edges_from(edges)

# Draw pyvis
net = Network(width=2000, height=2000, notebook=True, cdn_resources='remote')
net.from_nx(G)

# set node color based on label: the second space-separated token of the label
# is looked up in color_map_ct
for node in net.nodes:
    node['group'] = node['label'].split(" ")[1]
    node['color'] = color_map_ct[node['group']]

# show
net.show_buttons()
net.set_options("""
const options = {
  "configure": {
    "enabled": true,
    "filter": ["nodes","edges","physics"]
  },
  "nodes": {
    "borderWidth": 3,
    "opacity": 1,
    "font": {
      "size": 15,
      "strokeWidth": 5
    },
    "size": 0
  },
  "edges": {
    "color": {
      "opacity": 0.7
    },
    "selfReferenceSize": 0,
    "selfReference": {
      "size": 0,
      "angle": 0.7853981633974483
    },
    "smooth": {
      "forceDirection": "vertical"
    }
  },
  "physics": {
    "minVelocity": 0.75
  }
}
""")

net.show('networkplot_Axo_dev_neighbor3.html')

## 2.sub-conformation by clustering on variance matrix

### CP

#### Preprocess

In [None]:
import os
from functools import reduce
import scanpy as sc
import numpy as np
from tqdm import tqdm
# Result objects, their uns['X'] tables and gene indices for every 'CP' sample,
# keyed by "<chip>_<cell type>".
Xs={}
GeneLists={}
adatas={}

result_path="./results/Axolotls/"
samples = [
    name
    for name in sorted(os.listdir(result_path))
    if os.path.isdir(os.path.join(result_path, name))
]
for sample in samples:
    # Directories whose name does not start with a letter are skipped.
    if not sample[0].isalpha():
        continue
    result_dir = os.path.join(result_path,sample)
    files = os.listdir(result_dir)
    info = sample.split(".")
    chip = info[0]
    # Second dot-separated field of the directory name, minus its last five characters.
    ctype = info[1][:-5]
    sample = chip+'_'+ctype
    if ctype!='CP' or 'adata.h5ad' not in files:
        continue
    adata = sc.read_h5ad(result_dir+"/adata.h5ad")
    adatas[sample] = adata
    Xs[sample] = adata.uns['X']
    GeneLists[sample] = adata.uns['X'].index

# Genes present in every loaded sample, and how many there are.
intersected_genes = reduce(lambda shared, genes: shared.intersection(genes), GeneLists.values())
M = intersected_genes.size
print(M)

In [None]:
# Build gene-gene distance matrix
# GG_dev[n, i, j] is the Euclidean distance between the coordinate rows of
# shared genes i and j in sample keys[n].
from scipy.spatial.distance import pdist, squareform

keys = ['Stage44_CP',
        'Stage54_CP',
        'Stage57_CP',
        'Control_Juv_CP',
        'Adult_CP',
        'Meta_CP']
N = len(keys)
GG_dev = np.zeros((N,M,M))
for n in tqdm(range(N)):
    # Row position of the first occurrence of each gene in this sample.
    first_row = {}
    for row, gene in enumerate(GeneLists[keys[n]]):
        first_row.setdefault(gene, row)
    index_k_intersection = [first_row[gene] for gene in intersected_genes]
    # With fewer than two genes there are no pairs; the matrix stays all zeros.
    if M > 1:
        coords = Xs[keys[n]].iloc[index_k_intersection, :].to_numpy(dtype=float)
        # pdist yields the condensed pairwise distances; squareform expands
        # them to a symmetric M x M matrix with a zero diagonal.
        GG_dev[n] = squareform(pdist(coords, metric='euclidean'))

In [None]:
import statistics
# Sample variance (n - 1 denominator, as statistics.variance uses) of every
# gene-gene distance across the samples stacked along axis 0 of GG_dev,
# computed in one vectorised call instead of M*M pure-Python calls.
V_dev_CP = np.var(GG_dev, axis=0, ddof=1)

In [None]:
# translate gene
import pandas as pd
# Map Axolotl gene IDs to the 'hs_gene' column for every annotation row whose
# hs_gene is not '-'.
anno = pd.read_csv("./Axo_Summary_Gene_Annotation_0325.txt", sep='\t')
annotated = anno[anno['hs_gene'] != '-']
annodict = dict(zip(annotated['Axolotl_ID'], annotated['hs_gene']))

# Genes without a mapping keep their original ID.
intersected_symbols = [annodict.get(gene, gene) for gene in intersected_genes]

In [None]:
# Element-wise square root of the variance matrix, i.e. the standard deviation
# of each gene-gene distance.
V_dev_sqrt_CP = np.sqrt(V_dev_CP)

#### Heatmap

In [None]:
# hierarchy
import seaborn as sns
import pandas as pd
import matplotlib.patches as patches
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import fcluster

def _outline_clusters(grid, linkage, threshold, edgecolor, min_size=3):
    """Outline each flat cluster as a square on the diagonal of a clustermap.

    The linkage is cut at ``threshold`` (distance criterion), the labels are
    put in the row order of ``grid`` and every run of equal consecutive labels
    with at least ``min_size`` rows gets a rectangle. Runs are read directly
    from the label sequence, so the result does not depend on which label the
    first row carries, and the last run needs no wrap-around to be emitted.

    Args:
        grid: seaborn ClusterGrid whose heatmap receives the rectangles.
        linkage: Linkage matrix to cut.
        threshold: Distance threshold passed to ``fcluster``.
        edgecolor: Colour of the rectangle outlines.
        min_size: Smallest run length that is outlined.

    Returns:
        The flat cluster labels in the original (unordered) row order.
    """
    flat = fcluster(linkage, threshold, criterion='distance')
    ordered = flat[grid.dendrogram_row.reordered_ind]
    if ordered.size == 0:
        return flat
    # A new run starts at index 0 and wherever the label differs from its predecessor.
    starts = np.concatenate(([0], np.flatnonzero(np.diff(ordered)) + 1))
    sizes = np.diff(np.concatenate((starts, [ordered.size])))
    for start, size in zip(starts, sizes):
        if size >= min_size:
            grid.ax_heatmap.add_patch(patches.Rectangle(
                (int(start), int(start)), int(size), int(size),
                fill=False, edgecolor=edgecolor, lw=1.5, clip_on=False))
    return flat


# First clustermap is built only for its row linkage; its figure is closed at once.
clustergrid = sns.clustermap(V_dev_sqrt_CP, xticklabels=intersected_symbols, yticklabels=intersected_symbols, method='ward')
plt.close()

# set
sns.set(rc={"figure.dpi":300,'figure.figsize':(10, 10)},font_scale=1)

# The figure that is saved: no tick labels, colour scale capped at 10.
ax = sns.clustermap(V_dev_sqrt_CP, xticklabels=[], yticklabels=[], cmap="viridis", method='ward', cbar_kws={'ticks': [0,5,10],'orientation':'horizontal'}, vmax=10)
ax.ax_cbar.set_title("Standard Deviation", fontdict={'fontsize': 15})
ax.ax_cbar.set_position((0.03, .89, .20, .05))
# Remove the dendrograms
#ax.ax_row_dendrogram.set_visible(False)
#ax.ax_col_dendrogram.set_visible(False)

linkage = clustergrid.dendrogram_row.linkage
row_order = ax.dendrogram_row.reordered_ind

## subconformation 1: clusters at distance threshold 45, outlined in white
threshold = 45
clusters = _outline_clusters(ax, linkage, threshold, 'white')

## subconformation 2: clusters at distance threshold 20, outlined in orange
threshold = 20
clusters = _outline_clusters(ax, linkage, threshold, 'orange')

plt.savefig("figures/Fig5f.pdf", format="pdf", bbox_inches="tight")
# mask
#mask = np.tril(np.ones_like(V_dev_sqrt_CP),k=-1)
#values = ax.ax_heatmap.collections[0].get_array().reshape(V_dev_sqrt_CP.shape)
#new_values = np.ma.array(values, mask=mask)
#ax.ax_heatmap.collections[0].set_array(new_values)

#### Sunburst

In [None]:
import pandas as pd
# Cluster the standard-deviation matrix with seaborn (Ward linkage, Euclidean
# metric); the figure itself is closed at once, only the linkages are kept.
df = pd.DataFrame(V_dev_sqrt_CP)
sns_cluster = sns.clustermap(df, method='ward', metric='euclidean')
plt.close()
row_linkage = sns_cluster.dendrogram_row.linkage
col_linkage = sns_cluster.dendrogram_col.linkage

# Flat cluster labels from the row linkage: 'class1' at distance 20, 'class2' at distance 45.
col_name = 'class'
for suffix, cutoff in (('1', 20), ('2', 45)):
    df[col_name + suffix] = fcluster(row_linkage, t=cutoff, criterion='distance')

# Every row counts once.
df['weight'] = 1

##### Enrichment

In [None]:
import gseapy as gp
# For each gene-set library, run Enrichr on the genes of every cluster that has
# at least 3 members and write the terms with adjusted p-value < 0.05 to one CSV.
for gs in ['KEGG_2021_Human','GO_Molecular_Function_2023','GO_Biological_Process_2023']:
    enr = pd.DataFrame(columns=['Class','Sub-conformation','Gene_set', 'Term', 'Overlap', 'P-value', 'Adjusted P-value', 'Old P-value', 'Old Adjusted P-value', 'Odds Ratio', 'Combined Score', 'Genes'])

    # Enrichr API
    for clusterclass in ['class1','class2']:
        unique_values, counts = np.unique(df[clusterclass], return_counts=True)
        for value, count in zip(unique_values, counts):
            if count>=3:
                try:
                    genelist = np.array(intersected_symbols)[df[clusterclass]==value].tolist()
                    # Drop entries whose name still contains 'AMEX60DD'.
                    filterlist = list(filter(lambda x: 'AMEX60DD' not in x, genelist))
                    enr_up = gp.enrichr(gene_list=filterlist,
                                        organism='Human',
                                        gene_sets=gs,
                                        outdir=None, 
                                        cutoff=0.05,
                                        no_plot=True)

                    # trim (go:...): keep only the text before " (GO"
                    enr_up.res2d.Term = enr_up.res2d.Term.str.split(r" \(GO").str[0]
                    enr_up.res2d['Class'] = clusterclass
                    enr_up.res2d['Sub-conformation'] = value
                    enr = pd.concat([enr, enr_up.res2d[enr_up.res2d['Adjusted P-value']<0.05]], ignore_index=True)
                    # dotplot
                    # gp.dotplot(enr_up.res2d, title="Cluster "+str(value), size=10, cmap = plt.cm.viridis_r)
                # A tuple is required here: `except A or B or C` evaluates to
                # `except A` and would let the other two exception types through.
                except (ValueError, FileNotFoundError, TypeError) as e:
                    print(value)
                    print(e)
                    continue

    # NOTE(review): the file is written to the working directory, while a later
    # cell reads the same file names from 'subconfiguration/' — confirm the
    # files are moved in between.
    enr.to_csv('AXO_dev_CP_subconformation_hierachy_clustering_'+gs+'.pjust.05.csv')

In [None]:
def _top_term_labels(enr, df, clusterclass):
    """Map each cluster id in ``df[clusterclass]`` to "<id> <first enriched term>".

    Only rows of ``enr`` whose 'Class' equals ``clusterclass`` are considered;
    the term is taken from the first such row for the cluster. Clusters without
    any row are labelled "<id> unknown". Other problems (for example a missing
    column) are raised instead of being silently turned into "unknown".
    """
    enr_class = enr[enr["Class"]==clusterclass]
    mapping = {}
    for cluster in np.unique(df[clusterclass]):
        hits = enr_class[enr_class['Sub-conformation']==cluster]
        enr_class_top1_label = hits.iloc[0]['Term'] if len(hits) else 'unknown'
        mapping[cluster] = str(cluster)+' '+enr_class_top1_label
    return mapping


# Label columns start as copies of the cluster ids.
df['KEGG_label1']=df['GO_BP_label1']=df['GO_MF_label1']=df['class1']
df['KEGG_label2']=df['GO_BP_label2']=df['GO_MF_label2']=df['class2']

# Enrichment table read for each label-column prefix (GO_BP, GO_MF, KEGG).
enrichment_files = {
    'GO_BP': 'subconfiguration/AXO_dev_CP_subconformation_hierachy_clustering_GO_Biological_Process_2023.pjust.05.csv',
    'GO_MF': 'subconfiguration/AXO_dev_CP_subconformation_hierachy_clustering_GO_Molecular_Function_2023.pjust.05.csv',
    'KEGG': 'subconfiguration/AXO_dev_CP_subconformation_hierachy_clustering_KEGG_2021_Human.pjust.05.csv',
}
for prefix, path in enrichment_files.items():
    enr = pd.read_csv(path)
    for clusterclass, label in zip(['class1', 'class2'], [prefix+'_label1', prefix+'_label2']):
        df[label] = df[clusterclass].map(_top_term_labels(enr, df, clusterclass))

In [None]:
import pandas as pd
import numpy as np
# Collect the first 'Term' of every (class, sub-conformation) pair found in the
# GO Molecular Function enrichment table, class1 first, clusters in sorted order.
labels=[]
enr = pd.read_csv('subconfiguration/AXO_dev_CP_subconformation_hierachy_clustering_GO_Molecular_Function_2023.pjust.05.csv', index_col=0)
for clusterclass in ['class1', 'class2']:
    enr_class = enr[enr["Class"]==clusterclass]
    # Every cluster id comes from enr_class itself, so it has at least one row;
    # no exception handling is needed and real errors are no longer hidden.
    for cluster in np.unique(enr_class['Sub-conformation']):
        enr_class_top1_label = enr_class[enr_class['Sub-conformation']==cluster].iloc[0]['Term']
        labels.append(enr_class_top1_label)

In [None]:
labels

##### draw and save

In [None]:
import plotly.io as pio
import plotly.graph_objs as go
from cytocraft.plot import get_cmap
# Register two named layout templates ("viridis" and "google") for later reference.
# The "viridis" colorway comes from get_cmap('viridis', 8)
# (presumably 8 colours sampled from the viridis colormap — confirm in cytocraft.plot).
pio.templates["viridis"] = go.layout.Template(
    layout_colorway=get_cmap('viridis', 8)
)
pio.templates["google"] = go.layout.Template(
    layout_colorway=['#4285F4', '#DB4437', '#F4B400', '#0F9D58',
                     '#185ABC', '#B31412', '#EA8600', '#137333',
                     '#d2e3fc', '#ceead6']
)
# Use the "viridis" template as the default (the "google" one is only registered).
pio.templates.default = "viridis"

# Drop rows whose class1 or class2 cluster has fewer than 3 members.
df=remove_rare_values(df, column_names=['class1','class2'])

In [None]:
# List the distinct KEGG_label1 values and display them as an identity dict
# (presumably the starting point for the hand-edited mapping in the following cell — confirm).
labels=df['KEGG_label1'].unique()
dict(zip(labels,labels))

In [None]:
df['KEGG_label1'] = df['KEGG_label1'].map({'8 Ribosome': '8 Ribosome',
 '4 unknown': '4 unknown',
 '1 unknown': '1 unknown',
 '5 Oxidative phosphorylation': '5 Oxidative phosphorylation',
 '12 Ribosome': '12 Ribosome',
 '10 unknown': '10 unknown',
 '7 Ribosome': '7 Ribosome',
 '2 Endocytosis': '2 Endocytosis',
 '19 unknown': '18 unknown',
 '6 Ribosome': '6 Ribosome',
 '23 Synaptic vesicle cycle': '22 Synaptic vesicle cycle',
 '9 Ribosome': '9 Ribosome',
 '16 unknown': '15 unknown',
 '11 unknown': '11 unknown',
 '3 unknown': '3 unknown',
 '14 unknown': '13 unknown',
 '18 Amino sugar and nucleotide sugar metabolism': '17 Amino sugar and nucleotide sugar metabolism',
 '17 unknown': '16 unknown',
 '15 Cell adhesion molecules': '14 Cell adhesion molecules',
 '20 Type I diabetes mellitus': '19 Type I diabetes mellitus',
 '21 unknown': '20 unknown',
 '22 unknown': '21 unknown'})

df['GO_MF_label1'] = df['GO_MF_label1'].map({'8 RNA Binding': '8 RNA Binding',
 '4 MHC Class II Protein Binding': '4 MHC Class II Protein Binding',
 '1 Protein Serine/Threonine Phosphatase Activity': '1 Protein Serine/Threonine Phosphatase Activity',
 '5 Oxidoreduction-Driven Active Transmembrane Transporter Activity': '5 Oxidoreduction-Driven Active Transmembrane Transporter Activity',
 '12 RNA Binding': '12 RNA Binding',
 '10 RNA Binding': '10 RNA Binding',
 '7 RNA Binding': '7 RNA Binding',
 '2 Purine Ribonucleoside Triphosphate Binding': '2 Purine Ribonucleoside Triphosphate Binding',
 '19 Calcium Channel Inhibitor Activity': '18 Calcium Channel Inhibitor Activity',
 '6 RNA Binding': '6 RNA Binding',
 '23 Acetyltransferase Activator Activity': '22 Acetyltransferase Activator Activity',
 "9 5'-3' Exonuclease Activity": "9 5'-3' Exonuclease Activity",
 '16 Metal-Dependent Deubiquitinase Activity': '15 Metal-Dependent Deubiquitinase Activity',
 '11 mRNA Binding': '11 mRNA Binding',
 '3 RNA Binding': '3 RNA Binding',
 '14 unknown': '13 unknown',
 '18 Cytidylyltransferase Activity': '17 Cytidylyltransferase Activity',
 '17 Aminophospholipid Flippase Activity': '16 Aminophospholipid Flippase Activity',
 '15 Axon Guidance Receptor Activity': '14 Axon Guidance Receptor Activity',
 '20 Transmembrane Receptor Protein Phosphatase Activity': '19 Transmembrane Receptor Protein Phosphatase Activity',
 '21 Transferrin Receptor Binding': '20 Transferrin Receptor Binding',
 '22 Clathrin Light Chain Binding': '21 Clathrin Light Chain Binding'})

df['GO_BP_label1'] = df['GO_BP_label1'].map({'8 Cytoplasmic Translation': '8 Cytoplasmic Translation',
 '4 Humoral Immune Response Mediated By Circulating Immunoglobulin': '4 Humoral Immune Response Mediated By Circulating Immunoglobulin',
 '1 Peptidyl-Serine Dephosphorylation': '1 Peptidyl-Serine Dephosphorylation',
 '5 Negative Regulation Of Protein Polyubiquitination': '5 Negative Regulation Of Protein Polyubiquitination',
 '12 Macromolecule Biosynthetic Process': '12 Macromolecule Biosynthetic Process',
 '10 Cytoplasmic Translation': '10 Cytoplasmic Translation',
 '7 Macromolecule Biosynthetic Process': '7 Macromolecule Biosynthetic Process',
 '2 Vesicle-Mediated Transport': '2 Vesicle-Mediated Transport',
 '19 Negative Regulation Of Heart Rate': '18 Negative Regulation Of Heart Rate',
 '6 Cytoplasmic Translation': '6 Cytoplasmic Translation',
 '23 Synaptic Vesicle Docking': '22 Synaptic Vesicle Docking',
 '9 Cytoplasmic Translation': '9 Cytoplasmic Translation',
 '16 Regulation Of Wnt Signaling Pathway': '15 Regulation Of Wnt Signaling Pathway',
 '11 Cytoplasmic Translation': '11 Cytoplasmic Translation',
 '3 unknown': '3 unknown',
 '14 Regulation Of Dendrite Development': '13 Regulation Of Dendrite Development',
 '18 Nucleotide-Sugar Biosynthetic Process': '17 Nucleotide-Sugar Biosynthetic Process',
 '17 DNA-templated DNA Replication Maintenance Of Fidelity': '16 DNA-templated DNA Replication Maintenance Of Fidelity',
 '15 Positive Regulation Of Axon Extension': '14 Positive Regulation Of Axon Extension',
 '20 Regulation Of Protein Depolymerization': '19 Regulation Of Protein Depolymerization',
 '21 Positive Regulation Of Dendritic Spine Morphogenesis': '20 Positive Regulation Of Dendritic Spine Morphogenesis',
 '22 Regulation Of Hyaluronan Biosynthetic Process': '21 Regulation Of Hyaluronan Biosynthetic Process'})

In [None]:
# KEGG
import plotly.express as px
from plotly.io import write_image
# Wrap long labels at 28 characters; labels containing one of the listed
# substrings are left on a single line.
for _column, _keep in (('KEGG_label1', ['14']), ('KEGG_label2', ['5','7'])):
    df[_column] = df[_column].apply(lambda x, _keep=_keep: wrap_labels(x,28,ignore=_keep))

# Sunburst with KEGG_label2 as the first path level and KEGG_label1 below it, sized by 'weight'.
fig = px.sunburst(df, path=['KEGG_label2','KEGG_label1'], values='weight')
fig.update_layout(
    autosize=False,
    width=600,
    height=600,
    margin={'l': 0, 'r': 0, 't': 0, 'b': 0},
)
fig.update_traces(insidetextorientation='radial')
fig.show()

# Export the figure as a PDF.
write_image(fig, 'figures/Sunburst_KEGG_AXO_dev_CP.pdf')

In [None]:
# GO MF
import plotly.express as px
from plotly.io import write_image
# Wrap long labels at 32 characters; for GO_MF_label1, labels containing one
# of the listed substrings are left on a single line.
_mf_keep = ['15','18','19','21','22','23']
df['GO_MF_label1'] = df['GO_MF_label1'].apply(lambda x: wrap_labels(x,32,ignore=_mf_keep))
df['GO_MF_label2'] = df['GO_MF_label2'].apply(lambda x: wrap_labels(x,32))

# Sunburst with GO_MF_label2 as the first path level and GO_MF_label1 below it, sized by 'weight'.
fig = px.sunburst(df, path=['GO_MF_label2','GO_MF_label1'], values='weight')
fig.update_layout(
    autosize=False,
    width=600,
    height=600,
    margin={'l': 0, 'r': 0, 't': 0, 'b': 0},
)
fig.update_traces(insidetextorientation='radial')
fig.show()

# Export the figure as a PDF.
write_image(fig, 'figures/Sunburst_GO_MF_AXO_dev_CP.pdf')

In [None]:
# GO BP
import plotly.express as px
from plotly.io import write_image
# Wrap long labels at 32 characters; labels containing one of the listed
# substrings are left on a single line.
for _column, _keep in (('GO_BP_label1', ['15','14']), ('GO_BP_label2', ['6'])):
    df[_column] = df[_column].apply(lambda x, _keep=_keep: wrap_labels(x,32,ignore=_keep))

# Sunburst with GO_BP_label2 as the first path level and GO_BP_label1 below it, sized by 'weight'.
fig = px.sunburst(df, path=['GO_BP_label2','GO_BP_label1'], values='weight')
fig.update_layout(
    autosize=False,
    width=600,
    height=600,
    margin={'l': 0, 'r': 0, 't': 0, 'b': 0},
)
fig.update_traces(insidetextorientation='radial')
fig.show()

# Export the figure as a PDF.
write_image(fig, 'figures/Sunburst_GO_BP_AXO_dev_CP.pdf')

### VLMC

#### Preprocess

In [None]:
import os
from functools import reduce
import scanpy as sc
import numpy as np
from tqdm import tqdm
# Result objects, their uns['X'] tables and gene indices for every 'VLMC'
# sample, keyed by "<chip>_<cell type>".
Xs={}
GeneLists={}
adatas={}

result_path="./results/Axolotls/"
samples = [
    name
    for name in sorted(os.listdir(result_path))
    if os.path.isdir(os.path.join(result_path, name))
]
for sample in samples:
    # Directories whose name does not start with a letter are skipped.
    if not sample[0].isalpha():
        continue
    result_dir = os.path.join(result_path,sample)
    files = os.listdir(result_dir)
    info = sample.split(".")
    chip = info[0]
    # Second dot-separated field of the directory name, minus its last five characters.
    ctype = info[1][:-5]
    sample = chip+'_'+ctype
    if ctype!='VLMC' or 'adata.h5ad' not in files:
        continue
    adata = sc.read_h5ad(result_dir+"/adata.h5ad")
    adatas[sample] = adata
    Xs[sample] = adata.uns['X']
    GeneLists[sample] = adata.uns['X'].index

# Genes present in every loaded sample, and how many there are.
intersected_genes = reduce(lambda shared, genes: shared.intersection(genes), GeneLists.values())
M = intersected_genes.size
print(M)

In [None]:
# Build gene-gene distance matrix
# GG_dev[n, i, j] is the Euclidean distance between the coordinate rows of
# shared genes i and j in sample keys[n].
from scipy.spatial.distance import pdist, squareform

keys = ['Stage44_VLMC',
        'Stage54_VLMC',
        'Stage57_VLMC',
        'Control_Juv_VLMC',
        'Adult_VLMC',
        'Meta_VLMC']
N = len(keys)
GG_dev = np.zeros((N,M,M))
for n in tqdm(range(N)):
    # Row position of the first occurrence of each gene in this sample.
    first_row = {}
    for row, gene in enumerate(GeneLists[keys[n]]):
        first_row.setdefault(gene, row)
    index_k_intersection = [first_row[gene] for gene in intersected_genes]
    # With fewer than two genes there are no pairs; the matrix stays all zeros.
    if M > 1:
        coords = Xs[keys[n]].iloc[index_k_intersection, :].to_numpy(dtype=float)
        # pdist yields the condensed pairwise distances; squareform expands
        # them to a symmetric M x M matrix with a zero diagonal.
        GG_dev[n] = squareform(pdist(coords, metric='euclidean'))

In [None]:
import statistics
# Sample variance (n - 1 denominator, as statistics.variance uses) of every
# gene-gene distance across the samples stacked along axis 0 of GG_dev,
# computed in one vectorised call instead of M*M pure-Python calls.
V_dev_VLMC = np.var(GG_dev, axis=0, ddof=1)

In [None]:
# translate gene
import pandas as pd
# Map Axolotl gene IDs to the 'hs_gene' column for every annotation row whose
# hs_gene is not '-'.
anno = pd.read_csv("./Axo_Summary_Gene_Annotation_0325.txt", sep='\t')
annotated = anno[anno['hs_gene'] != '-']
annodict = dict(zip(annotated['Axolotl_ID'], annotated['hs_gene']))

# Genes without a mapping keep their original ID.
intersected_symbols = [annodict.get(gene, gene) for gene in intersected_genes]

In [None]:
# Element-wise square root of the variance matrix, i.e. the standard deviation
# of each gene-gene distance.
V_dev_sqrt_VLMC = np.sqrt(V_dev_VLMC)

#### Heatmap

In [None]:
# hierarchy
import seaborn as sns
import pandas as pd
import matplotlib.patches as patches
import matplotlib.pyplot as plt
from itertools import groupby
from scipy.cluster.hierarchy import fcluster


def _outline_clusters(heatmap_ax, ordered_labels, edgecolor, min_size=3):
    """Outline each run of equal cluster labels as a square on the heatmap diagonal.

    ordered_labels are the flat cluster ids listed in the heatmap's row order.
    Runs shorter than min_size rows are left unmarked. Unlike a manual
    run-length loop this makes no assumption about the id of the first row
    and also outlines the final (or only) run.
    """
    start = 0
    for _, run in groupby(ordered_labels):
        size = sum(1 for _ in run)
        if size >= min_size:
            heatmap_ax.add_patch(patches.Rectangle((start, start), size, size, fill=False, edgecolor=edgecolor, lw=1.5, clip_on=False))
        start += size


# First clustermap: labelled, closed immediately; only its linkage is used below.
clustergrid = sns.clustermap(V_dev_sqrt_VLMC, xticklabels=intersected_symbols, yticklabels=intersected_symbols, method='ward')
plt.close()

# set
sns.set(rc={"figure.dpi":300,'figure.figsize':(10, 10)},font_scale=1)

# Second clustermap: the figure that is saved (no tick labels, colour scale capped at 10).
ax = sns.clustermap(V_dev_sqrt_VLMC, xticklabels=[], yticklabels=[], cmap="viridis", method='ward', cbar_kws={'ticks': [0,5,10],'orientation':'horizontal'}, vmax=10)
ax.ax_cbar.set_title("Standard Deviation", fontdict={'fontsize': 15})
ax.ax_cbar.set_position((0.03, .89, .20, .05))
# Remove the dendrograms
#ax.ax_row_dendrogram.set_visible(False)
#ax.ax_col_dendrogram.set_visible(False)

# Cut the row dendrogram at two heights and outline the resulting clusters:
# subconformation 1 (threshold 60) in white, subconformation 2 (threshold 30) in orange.
linkage = clustergrid.dendrogram_row.linkage
row_order = ax.dendrogram_row.reordered_ind
for threshold, edgecolor in ((60, 'white'), (30, 'orange')):
    clusters = fcluster(linkage, threshold, criterion='distance')
    _outline_clusters(ax.ax_heatmap, clusters[row_order], edgecolor)

plt.savefig("figures/Fig5f_l.pdf", format="pdf", bbox_inches="tight")
# mask
#mask = np.tril(np.ones_like(V_dev_sqrt_VLMC),k=-1)
#values = ax.ax_heatmap.collections[0].get_array().reshape(V_dev_sqrt_VLMC.shape)
#new_values = np.ma.array(values, mask=mask)
#ax.ax_heatmap.collections[0].set_array(new_values)

In [None]:
# Python Imaging Library (Pillow)
from PIL import Image

# Open the exported clustermap image
Original_Image = Image.open("./Clustermap_subconf_Axo_dev_sqrt_VLMC.png")

# Rotate by 180 degrees
rotated_image1 = Original_Image.rotate(180)

# Alternative syntax: transpose with a fixed 90-degree rotation.
# The constant lives on Image.Transpose (Pillow >= 9.1); the bare name
# `Transpose` was never imported and raised NameError.
rotated_image2 = Original_Image.transpose(Image.Transpose.ROTATE_90)

# Rotate by 60 degrees
rotated_image3 = Original_Image.rotate(60)

rotated_image1.show()
rotated_image2.show()
rotated_image3.show()


In [None]:
import cv2  # OpenCV
import imutils 

# Read the clustermap image with OpenCV.
# NOTE(review): this reads a .jpg, while the PIL cell in this file opens the
# .png of the same base name - confirm which file actually exists.
image = cv2.imread("./Clustermap_subconf_Axo_dev_sqrt_VLMC.jpg") 
#Rotated_image = imutils.rotate(image, angle=45) 
#Rotated1_image = imutils.rotate(image, angle=90) 
  
# display the image rotated by 45 degrees
#cv2.imshow("Rotated", Rotated_image) 
  
# display the image rotated by 90 degrees
#cv2.imshow("Rotated", Rotated1_image)
# keep the windows open until any key is pressed
#cv2.waitKey(0) 

#### Sunburst

In [None]:
# Cluster the standard-deviation matrix once and keep both dendrogram linkages
df = pd.DataFrame(V_dev_sqrt_VLMC)
col_name = 'class'
sns_cluster = sns.clustermap(df, method='ward', metric='euclidean')
plt.close()
row_linkage, col_linkage = (
    sns_cluster.dendrogram_row.linkage,
    sns_cluster.dendrogram_col.linkage,
)

# Flat cluster ids at two cut heights: class1 (t=30) and class2 (t=60)
for level, cutoff in enumerate((30, 60), start=1):
    df[f'{col_name}{level}'] = fcluster(row_linkage, t=cutoff, criterion='distance')
# Every row counts once when the sunburst sums 'weight'
df['weight'] = 1

##### Enrichment

In [None]:
import gseapy as gp
# One result table (and one CSV) per Enrichr library.
for gs in ['KEGG_2021_Human','GO_Molecular_Function_2023','GO_Biological_Process_2023']:
    enr = pd.DataFrame(columns=['Class','Sub-conformation','Gene_set', 'Term', 'Overlap', 'P-value', 'Adjusted P-value', 'Old P-value', 'Old Adjusted P-value', 'Odds Ratio', 'Combined Score', 'Genes'])

    # Enrichr API
    for clusterclass in ['class1','class2']:
        unique_values, counts = np.unique(df[clusterclass], return_counts=True)
        for value, count in zip(unique_values, counts):
            # Clusters with fewer than 3 genes are not tested.
            if count < 3:
                continue
            try:
                genelist = np.array(intersected_symbols)[df[clusterclass]==value].tolist()
                # Drop IDs that were not translated to a human symbol.
                filterlist = list(filter(lambda x: 'AMEX60DD' not in x, genelist))
                enr_up = gp.enrichr(gene_list=filterlist,
                                    organism='Human',
                                    gene_sets=gs,
                                    outdir=None,
                                    cutoff=0.05,
                                    no_plot=True)

                # trim the trailing " (GO:...)" from term names; raw string so
                # that "\(" is a regex escape rather than an invalid str escape
                enr_up.res2d.Term = enr_up.res2d.Term.str.split(r" \(GO").str[0]
                enr_up.res2d['Class'] = clusterclass
                enr_up.res2d['Sub-conformation'] = value
                enr = pd.concat([enr, enr_up.res2d[enr_up.res2d['Adjusted P-value']<0.05]], ignore_index=True)
                # dotplot
                # gp.dotplot(enr_up.res2d, title="Cluster "+str(value), size=10, cmap = plt.cm.viridis_r)
            except (ValueError, FileNotFoundError, TypeError) as e:
                # `except A or B or C` only caught A; a tuple catches all three.
                print(value)
                print(e)
                continue

    # NOTE(review): written to the working directory, while the label cell in
    # this file reads the same file names from 'subconfiguration/' - confirm.
    enr.to_csv('AXO_dev_VLMC_subconformation_hierachy_clustering_'+gs+'.pjust.05.csv')

In [None]:
# write df
# Create the six label columns first so their column order is unchanged.
df['KEGG_label1']=df['GO_BP_label1']=df['GO_MF_label1']=df['class1']
df['KEGG_label2']=df['GO_BP_label2']=df['GO_MF_label2']=df['class2']

# (label-column prefix, Enrichr library name used in the result file name)
label_sources = [
    ('GO_BP', 'GO_Biological_Process_2023'),
    ('GO_MF', 'GO_Molecular_Function_2023'),
    ('KEGG', 'KEGG_2021_Human'),
]
for prefix, library in label_sources:
    enr = pd.read_csv('subconfiguration/AXO_dev_VLMC_subconformation_hierachy_clustering_' + library + '.pjust.05.csv')
    for level, clusterclass in (('1', 'class1'), ('2', 'class2')):
        enr_class = enr[enr["Class"]==clusterclass]
        names = {}
        for cluster in np.unique(df[clusterclass]):
            hits = enr_class[enr_class['Sub-conformation']==cluster]
            # First listed term for the cluster, or 'unknown' when nothing was
            # enriched. Only the "no rows" case is handled; other errors (e.g.
            # a missing column) are no longer hidden by a bare except.
            term = hits.iloc[0]['Term'] if len(hits) else 'unknown'
            names[cluster] = str(cluster)+' '+term
        df[prefix + '_label' + level] = df[clusterclass].map(names)

##### draw and save

In [None]:
import plotly.io as pio
import plotly.graph_objs as go
from cytocraft.plot import get_cmap
# Register two named layout templates that differ only in their colorway.
# "viridis" takes its colorway from get_cmap('viridis', 8)
# (presumably 8 colours sampled from viridis - confirm against cytocraft).
pio.templates["viridis"] = go.layout.Template(
    layout_colorway=get_cmap('viridis', 8)
)
pio.templates["google"] = go.layout.Template(
    layout_colorway=['#4285F4', '#DB4437', '#F4B400', '#0F9D58',
                     '#185ABC', '#B31412', '#EA8600', '#137333',
                     '#d2e3fc', '#ceead6']
)
# Use the viridis template as the default (the "google" one is only registered).
pio.templates.default = "viridis"

In [None]:
# Drop rows whose class1 or class2 value occurs fewer than 3 times.
df=remove_rare_values(df, column_names=['class1','class2'])

In [None]:
# Show the unique KEGG_label1 values as an identity dict
# (presumably the starting point for the hand-edited relabelling dicts - confirm).
labels=df['KEGG_label1'].unique()
dict(zip(labels,labels))

In [None]:
# Labels have the form "<cluster id> <term>". The three hand-written lookup
# tables that used to live here all encoded the same change: cluster ids
# 20, 21, 22, 24, 25, 26 become 19, 20, 21, 22, 23, 24 and every other label
# is kept as it is (presumably closing the gaps left by removed rare clusters).
# Renumbering by id instead of by full label keeps working when an enrichment
# term changes; with the old `.map(dict)` any label missing from the table
# silently became NaN.
# NOTE: like the lookup tables, this is not idempotent - run it once per
# freshly built df.
_CLUSTER_RENUMBER = {'20': '19', '21': '20', '22': '21',
                     '24': '22', '25': '23', '26': '24'}


def _renumber_label(label):
    """Return `label` with its leading cluster id remapped via _CLUSTER_RENUMBER."""
    cluster_id, sep, term = str(label).partition(' ')
    return _CLUSTER_RENUMBER.get(cluster_id, cluster_id) + sep + term


for _column in ('KEGG_label1', 'GO_MF_label1', 'GO_BP_label1'):
    # na_action='ignore' leaves missing labels missing instead of turning them into 'nan'.
    df[_column] = df[_column].map(_renumber_label, na_action='ignore')

In [None]:
# KEGG sunburst: outer ring = class1 labels, inner ring = class2 labels
import plotly.express as px
from plotly.io import write_image
# Wrap long labels with <br>; labels containing an `ignore` substring stay on one line.
for column, max_length, skip in (('KEGG_label1', 35, ['18']), ('KEGG_label2', 32, [])):
    df[column] = df[column].apply(lambda x: wrap_labels(x, max_length, ignore=skip))
fig = px.sunburst(df, path=['KEGG_label2','KEGG_label1'], values='weight')
fig.update_layout(autosize=False, width=600, height=600, margin=dict(l=0, r=0, t=0, b=0))
fig.update_traces(insidetextorientation='radial')
fig.show()
# export the figure as a PDF
write_image(fig, 'figures/Sunburst_KEGG_AXO_dev_VLMC.pdf', scale=3)

In [None]:
# GO Molecular Function sunburst: outer ring = class1 labels, inner ring = class2 labels
import plotly.express as px
from plotly.io import write_image
# Wrap long labels with <br>; labels containing an `ignore` substring stay on one line.
for column, max_length, skip in (('GO_MF_label1', 32, ['18','21','25', '20']), ('GO_MF_label2', 35, ['10'])):
    df[column] = df[column].apply(lambda x: wrap_labels(x, max_length, ignore=skip))
fig = px.sunburst(df, path=['GO_MF_label2','GO_MF_label1'], values='weight')
# NOTE(review): the title is 'CP' although this figure is saved as a VLMC plot - confirm.
fig.update_layout(autosize=True, width=600, height=600, margin=dict(l=0, r=0, t=0, b=0), title='CP')
fig.update_traces(insidetextorientation='radial')
fig.show()
# export the figure as a PDF
write_image(fig, 'figures/Sunburst_GO_MF_AXO_dev_VLMC.pdf')

In [None]:
# GO Biological Process sunburst: outer ring = class1 labels, inner ring = class2 labels
import plotly.express as px
from plotly.io import write_image

# Wrap long labels with <br>; labels containing an `ignore` substring stay on one line.
for column, max_length, skip in (('GO_BP_label1', 35, ['23']), ('GO_BP_label2', 30, ['6'])):
    df[column] = df[column].apply(lambda x: wrap_labels(x, max_length, ignore=skip))
# NOTE(review): color_discrete_map is given a palette name rather than a mapping - confirm intent.
fig = px.sunburst(df, path=['GO_BP_label2','GO_BP_label1'], values='weight', color_discrete_map='Set2')
fig.update_layout(autosize=False, width=600, height=600, margin=dict(l=0, r=0, t=0, b=0))
fig.update_traces(insidetextorientation='radial')
fig.show()
# export the figure as a PDF
write_image(fig, 'figures/Sunburst_GO_BP_AXO_dev_VLMC.pdf')

## 3. Similarity matrix

In [None]:
sns.set(rc={"figure.dpi":300,'figure.figsize':(5, 2)}, font_scale=1)
# Three axes: VLMC heatmap, CP heatmap, and one narrow axis for the colour bar.
fig, axes = plt.subplots(1, 3, gridspec_kw={'width_ratios':[1,1,0.1]})

# Only one colour bar is drawn, so both panels must share one colour scale;
# otherwise each heatmap is normalised to its own range and the bar is only
# valid for the CP panel.
shared_vmin = min(np.min(D_Axo_dev_VLMC), np.min(D_Axo_dev_CP))
shared_vmax = max(np.max(D_Axo_dev_VLMC), np.max(D_Axo_dev_CP))

# VLMC heatmap
sns.heatmap(D_Axo_dev_VLMC, xticklabels=labels_dev, yticklabels=labels_dev, cmap="viridis", cbar=False, ax=axes[0], vmin=shared_vmin, vmax=shared_vmax)
axes[0].set_title('VLMC', fontdict={'fontsize': 10, 'fontweight': 'bold'})

# CP heatmap (its colour bar goes into the third axis)
sns.heatmap(D_Axo_dev_CP, xticklabels=labels_dev, yticklabels=labels_dev, cmap="viridis", cbar_ax=axes[2], ax=axes[1], vmin=shared_vmin, vmax=shared_vmax)
axes[1].set_title('CP', fontdict={'fontsize': 10, 'fontweight': 'bold'})

axes[2].set_ylabel('RMSD', rotation=90, va="bottom", loc="center", size=10, labelpad=10)
plt.tight_layout()
plt.savefig('figures/Fig5c_combined.pdf', bbox_inches='tight', dpi=300)
plt.show()


### CP

In [None]:
import os
from functools import reduce
import numpy as np
import scanpy as sc
import seaborn as sns
from cytocraft.craft import *
import matplotlib.pyplot as plt

# Per-sample containers, keyed "<chip>_<celltype>"
adatas, GeneLists, GeneCounts = {}, {}, {}

result_path="./results/Axolotls/"
samples = sorted(name for name in os.listdir(result_path) if os.path.isdir(os.path.join(result_path, name)))
for sample in samples:
    # Only directories whose name starts with a letter are processed.
    if not sample[0].isalpha():
        continue
    result_dir = os.path.join(result_path, sample)
    files = os.listdir(result_dir)
    # Directory name = "<chip>.<celltype>_..."
    info = sample.split(".")
    chip, ctype = info[0], info[1].split('_')[0]
    sample = f"{chip}_{ctype}"
    if ctype != 'CP' or 'adata.h5ad' not in files:
        continue
    adata = sc.read_h5ad(result_dir + "/adata.h5ad")
    adatas[sample] = adata
    GeneLists[sample] = adata.uns['X'].index.values
    GeneCounts[sample] = adata.n_vars

# Sample order used for the distance matrix
keys = ['Stage44_CP',
        'Stage54_CP',
        'Stage57_CP',
        'Control_Juv_CP',
        'Adult_CP',
        'Meta_CP']


#### Fig5c heatmap

In [None]:
# Pairwise distance matrix between the CP samples, in `keys` order
# (RMSD_distance_matrix comes from cytocraft; presumably RMSD over 100 genes - confirm).
D_Axo_dev_CP = RMSD_distance_matrix(adatas, order=keys, ngene=100)

# Tick labels, in the same order as `keys`.
labels_dev = ['St.44',
        'St.54',
        'St.57',
        'Juv.',
        'Adult',
        'Meta.']
sns.set(rc={"figure.dpi":300,'figure.figsize':(1.8, 1.5)},font_scale=1)
# Heatmap axis plus a narrow axis for the colour bar.
f,(ax,axcb) = plt.subplots(1,2, gridspec_kw={'width_ratios':[1,0.1]})
# The colour scale is capped at the VLMC maximum.
# NOTE: D_Axo_dev_VLMC is assigned in the VLMC section further down this file,
# so that section has to be run before this cell.
g = sns.heatmap(D_Axo_dev_CP, xticklabels=labels_dev, yticklabels=labels_dev, cmap="viridis", cbar_ax=axcb, ax=ax, vmax=D_Axo_dev_VLMC.max())
g.set_title('CP', fontdict={'fontsize': 10, 'fontweight': 'bold'})
axcb.set_ylabel('RMSD', rotation=90, va="bottom", loc="center", size=10, labelpad=10)
plt.savefig('figures/Fig5c_r.pdf', bbox_inches='tight', dpi=300)
plt.show()

#### Fig.5b Lineplot

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import pearsonr

# Plot label -> sample (chip) name, listed in plotting order. Every value is
# looked up by sample key, so the table no longer depends on the insertion
# order of GeneLists/GeneCounts matching a hand-ordered file list.
stage_chips = {
    'St.44': 'Stage44',
    'St.54': 'Stage54',
    'St.57': 'Stage57',
    'Juv.': 'Control_Juv',
    'Adult': 'Adult',
    'Meta.': 'Meta',
}

records = []
for stage, chip in stage_chips.items():
    key = f'{chip}_CP'
    table = pd.read_csv(f'./data/Axolotls/development/{chip}/{chip}.CP.csv', sep='\t')
    records.append({
        'Stage': stage,
        'Genes': GeneCounts[key],
        # number of genes in the sample's coordinate table
        'Count': len(GeneLists[key]),
        # NOTE(review): this is the number of non-null MIDCounts rows, not the
        # sum of MIDCounts - confirm it matches the "Total transcript count" label.
        'Transcripts': table['MIDCounts'].count(),
    })
df = pd.DataFrame(records).set_index('Stage')

transcription_center_count = df['Count']
transcript_count = df['Transcripts']

correlation, p_value = pearsonr(transcription_center_count, transcript_count)

# Draw the line plot: transcripts on the left axis, centers on the right axis
sns.set(rc={"figure.dpi":300,'figure.figsize':(3, 2.2)},font_scale=1)
sns.set_theme(style='white', palette='Set2')
ax=sns.lineplot(data=df, x='Stage', y='Transcripts', color='#023fa5')
ax.set_ylabel('Total transcript count', color='#023fa5')
plt.xticks(rotation=90, fontsize=12)
plt.yticks(fontsize=12)

ax2 = ax.twinx()
sns.lineplot(data=df, x='Stage', y='Count', color='#8e063b', ax=ax2)
ax2.set_ylabel('Transcription center count', color='#8e063b')
ax.set_xlabel('')
plt.title('CP')

ax.set_ylim(0,2500000)
ax2.set_ylim(0,10000)
ax.text(0.05, 0.98, f'Pearson r={correlation:.4f}\n$P={p_value:.4f}$', transform=ax.transAxes, fontsize=11, color='black', va='top')

plt.savefig('figures/Fig5b_d.pdf', bbox_inches='tight')
plt.show()

#### MDS

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn import manifold

# 2-D MDS of the CP distance matrix (used as precomputed dissimilarities),
# stored as an x/y table indexed by stage label.
mds = manifold.MDS(n_components=2, dissimilarity="precomputed", random_state=6)
results = mds.fit(D_Axo_dev_CP)
coords = results.embedding_
mds_CP=pd.DataFrame(coords, columns=['x','y'], index=labels_dev)

# Same for the VLMC distance matrix. `mds`, `results` and `coords` are rebound
# here, so after this cell they refer to the VLMC fit; `mds` is the estimator
# object, not a table.
mds = manifold.MDS(n_components=2, dissimilarity="precomputed", random_state=6)
results = mds.fit(D_Axo_dev_VLMC)
coords = results.embedding_
mds_VLMC=pd.DataFrame(coords, columns=['x','y'], index=labels_dev)

In [None]:
# Display the CP embedding table.
mds_CP

In [None]:
# Display the VLMC embedding table.
mds_VLMC

In [None]:
sns.set(rc={"figure.dpi":300,'figure.figsize':(2, 2)}, font_scale=1)
sns.set_theme(style='white', palette='pastel')

# Plot the CP embedding table. `mds` is the MDS estimator object (no x/y
# columns, no index), so it cannot be used as the data source here.
ax = sns.scatterplot(data=mds_CP, x='x', y='y', s=5)

# Label every point with its stage name.
for i in mds_CP.index.values:
    plt.text(mds_CP.loc[i]['x'],mds_CP.loc[i]['y'], i, fontdict=dict(color='black', alpha=0.5, size=6))

ax.set_ylabel('y')
ax.set_xlabel('x')
ax.set_yticks([])
ax.set_xticks([])

In [None]:
sns.set(rc={"figure.dpi":300,'figure.figsize':(2, 2)}, font_scale=1)
sns.set_theme(style='white', palette='pastel')

# Scatter of the VLMC embedding, one point per stage
ax = sns.scatterplot(data=mds_VLMC, x='x', y='y', s=5)

# Write the stage name next to each point
label_style = dict(color='black', alpha=0.5, size=6)
for stage, point in mds_VLMC.iterrows():
    plt.text(point['x'], point['y'], stage, fontdict=label_style)

# Axis names only; tick marks are removed
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_yticks([])
ax.set_xticks([])

#### netplot

In [None]:
import random
import plotly.graph_objs as go
import plotly.io as pio
import itertools
import networkx as nx
# Seed both generators so the layout is reproducible.
random.seed(996)
np.random.seed(996)

# create an empty graph with one node per sample
G = nx.Graph()
N = D_Axo_dev_CP.shape[0]
G.add_nodes_from(range(N))

# known edges:
#for i,j in [(3,4),(4,5),(5,1),(1,0),(0,2)]:
#    edges.append((i,j,{'weight':1/D_Axo_dev_CP[i,j]}))

# Connect every sample to the samples among its 3 smallest distances
# (its own entry is skipped); the edge weight is the inverse distance.
edges = []
for i in range(N):
    arr = D_Axo_dev_CP[i,:]
    min_indices = np.argpartition(arr, 3)[:3]
    for j in min_indices:
        if i != j:
            edges.append((i,j,{'weight':1/arr[j]}))

G.add_edges_from(edges)

#pos = nx.shell_layout(G)
pos = nx.spring_layout(G, k=0.5, iterations=100)
nx.set_node_attributes(G, pos, 'pos')

# edges: one trace per edge so that each can have its own line width
edge_trace = []
for u, v in G.edges():
    x0, y0 = pos[u]
    x1, y1 = pos[v]
    edge_trace.append(go.Scatter(
        x=[x0, x1, None],
        y=[y0, y1, None],
        # NOTE(review): hand-tuned scaling; it goes negative for weights below ~0.63 - confirm for new data
        line=dict(width=G.edges[u, v]['weight'] * 150 - 95, color='#888'),
        hoverinfo='none',
        mode='lines'))

# nodes
node_adjacencies = [len(neighbours) for _, neighbours in G.adjacency()]
node_trace = go.Scatter(
    x=[pos[node][0] for node in G.nodes()],
    y=[pos[node][1] for node in G.nodes()],
    text=labels_dev,
    mode='markers+text',
    hoverinfo='text',
    marker=dict(
        showscale=True,
        colorscale='YlOrRd',
        reversescale=False,
        # NOTE(review): three extra values are appended beyond the N nodes,
        # presumably to pin the colour range - confirm.
        color=node_adjacencies +[4,5]+[1],
        size=50,
        colorbar=dict(
            thickness=7,
            # nested title form; the flat `titleside` attribute is deprecated
            title=dict(text='Node Connections', side='right'),
            xanchor='left',
            tickmode='array',
            tickvals=[1, 2, 3, 4, 5],
        ),
        line=dict(width=0)))

fig = go.Figure(data = edge_trace+ [node_trace],
            layout=go.Layout(
                # nested title form; the flat `titlefont` attribute is deprecated
                title=dict(text="", font=dict(size=16)),
                showlegend=False,
                hovermode='closest',
                margin=dict(b=0, l=0, r=0, t=0),
                plot_bgcolor='#fff',
                annotations=[dict(
                    text="Development - CP",
                    showarrow=False,
                    xref="paper", yref="paper",
                    x=0.98, y=0.98, font=dict(size=16))],
                xaxis=dict(showgrid=False, zeroline=False,
                        showticklabels=False, mirror=True),
                yaxis=dict(showgrid=False, zeroline=False,
                showticklabels=False, mirror=True)))

fig.update_layout(
    autosize=False,
    width=400,
    height=300)

fig.show()
# save at 1.3 x 0.8 inches worth of pixels at 300 dpi, rendered at 4x scale
pio.write_image(fig, "networkplot_axo_dev_CP.png", width=1.3*300, height=0.8*300, scale=4)

In [None]:
# Display the result of get_cmap('Blues', 10)
# (presumably 10 colours from the Blues map, used to pick the stage colours - confirm).
get_cmap('Blues',10)

In [None]:
from cytocraft.plot import *
# One colour per stage label (light to dark blue in the order listed).
color_map_sample = {
'St.44': '#b8d5ea',
'St.54': '#97c6df',
'St.57': '#4e9acb',
'Juv.': '#3282be',
'Adult': '#1a68ae',
'Meta.': '#084f99',
}
# Draw the CP sample network from the distance matrix with cytocraft's
# plot_network and write it to an HTML file.
plot_network(D_Axo_dev_CP, labels=labels_dev, cmap=color_map_sample, solver='forceAtlas2Based', corder=0, html='network_Axo_dev_CP_neighbor2.html', edge_scale=10, edge_adjust=0, N_neighbor=2)

### VLMC

In [None]:
def RMSD_distance_matrix_v1(
    Confs, GeneLists, order=None, ngene=100, compare_method="pair", norm_method=None
):
    """Return a symmetric matrix of pairwise RMSD values between conformations.

    Parameters
    ----------
    Confs : mapping of sample key -> conformation; rows are selected from it
        with a boolean mask over the sample's gene list.
    GeneLists : mapping of sample key -> gene names of that sample.
    order : optional sequence of keys fixing the row/column order; defaults
        to the key order of `Confs`.
    ngene : maximum number of shared genes used per comparison.
    compare_method : "pair" intersects the gene lists of each pair of samples;
        "complete" uses the genes shared by every entry of `GeneLists`.
    norm_method : if truthy, both conformations are passed through
        `normalizeF(..., method=norm_method)` before comparison.

    Returns
    -------
    An (N, N) array, or None (after printing a warning) when a key is missing
    from `GeneLists`.

    Raises
    ------
    ValueError : if `compare_method` is neither "pair" nor "complete".
    """
    # `is None` rather than `== None`, so array-like `order` values are fine.
    keys = list(Confs.keys()) if order is None else order
    for key in keys:
        if key not in GeneLists:
            print(f"Warning: {key} in Confs but not in GeneLists")
            return None
    if compare_method not in ("pair", "complete"):
        # previously failed later with a NameError on intersected_values
        raise ValueError(
            f"compare_method must be 'pair' or 'complete', got {compare_method!r}"
        )

    from functools import reduce
    from tqdm import tqdm

    # calculate the distance matrix
    N = len(keys)
    DM = np.zeros((N, N))
    if compare_method == "complete":
        intersected_values = reduce(np.intersect1d, GeneLists.values())
        intersected_values = intersected_values[:ngene]
        if len(intersected_values) < ngene:
            print(
                f"Warning: {len(intersected_values)} common genes between samples are less than {ngene}"
            )

    for n, key_n in enumerate(tqdm(keys)):
        for m, key_m in enumerate(keys[: n + 1]):
            if compare_method == "pair":
                intersected_values = np.intersect1d(GeneLists[key_n], GeneLists[key_m])
                intersected_values = intersected_values[:ngene]
                if len(intersected_values) < ngene:
                    print(
                        f"Warning: {len(intersected_values)} common genes between {key_n} and {key_m} are less than {ngene}"
                    )
            # NOTE(review): the masks keep each sample's own row order, so the
            # rows of the two conformations only correspond gene-by-gene if
            # both gene lists are duplicate-free and in the same relative
            # order - confirm upstream.
            boolean_arrays_n = np.isin(GeneLists[key_n], intersected_values)
            boolean_arrays_m = np.isin(GeneLists[key_m], intersected_values)
            Conf_n = Confs[key_n][boolean_arrays_n]
            Conf_m = Confs[key_m][boolean_arrays_m]
            if norm_method:
                Conf_n = normalizeF(Conf_n, method=norm_method)
                Conf_m = normalizeF(Conf_m, method=norm_method)
            # Compare both the conformation and its mirror image; keep the smaller value.
            d1, _, _ = numpy_svd_rmsd_rot(Conf_n, Conf_m)
            d2, _, _ = numpy_svd_rmsd_rot(mirror(Conf_n), Conf_m)
            DM[n, m] = DM[m, n] = min(d1, d2)
    return DM

In [None]:
import os
from functools import reduce
import numpy as np
import scanpy as sc
import seaborn as sns
from cytocraft.craft import *
import matplotlib.pyplot as plt

# Per-sample containers, keyed "<chip>_<celltype>"
adatas, GeneLists, GeneCounts = {}, {}, {}

result_path="./results/Axolotls/"
samples = sorted(name for name in os.listdir(result_path) if os.path.isdir(os.path.join(result_path, name)))
for sample in samples:
    # Only directories whose name starts with a letter are processed.
    if not sample[0].isalpha():
        continue
    result_dir = os.path.join(result_path, sample)
    files = os.listdir(result_dir)
    # Directory name = "<chip>.<celltype>_..."
    info = sample.split(".")
    chip, ctype = info[0], info[1].split('_')[0]
    sample = f"{chip}_{ctype}"
    if ctype != 'VLMC' or 'adata.h5ad' not in files:
        continue
    adata = sc.read_h5ad(result_dir + "/adata.h5ad")
    adatas[sample] = adata
    GeneLists[sample] = adata.uns['X'].index.values
    GeneCounts[sample] = adata.n_vars

# Sample order used for the distance matrix
keys = ['Stage44_VLMC',
        'Stage54_VLMC',
        'Stage57_VLMC',
        'Control_Juv_VLMC',
        'Adult_VLMC',
        'Meta_VLMC']

#### Fig5c heatmap

In [None]:
# Display the largest VLMC distance.
# NOTE: D_Axo_dev_VLMC is assigned in the next cell, so this only works once that cell has run.
D_Axo_dev_VLMC.max()

In [None]:
# Pairwise distance matrix between the VLMC samples, in `keys` order
# (RMSD_distance_matrix comes from cytocraft; presumably RMSD over 100 genes - confirm).
D_Axo_dev_VLMC = RMSD_distance_matrix(adatas, order=keys, ngene=100)

# Tick labels, in the same order as `keys`.
labels_dev = ['St.44',
        'St.54',
        'St.57',
        'Juv.',
        'Adult',
        'Meta.']  
   
sns.set(rc={"figure.dpi":300,'figure.figsize':(1.8, 1.5)},font_scale=1)
# Heatmap axis plus a narrow axis for the colour bar.
f,(ax,axcb) = plt.subplots(1,2, gridspec_kw={'width_ratios':[1,0.1]})
g = sns.heatmap(D_Axo_dev_VLMC, xticklabels=labels_dev, yticklabels=labels_dev, cmap="viridis", cbar_ax=axcb, ax=ax)
g.set_title('VLMC', fontdict={'fontsize': 10, 'fontweight': 'bold'})
axcb.set_ylabel('RMSD', rotation=90, va="bottom", loc="center", size=10, labelpad=10)
plt.savefig('figures/Fig5c_l.pdf', bbox_inches='tight', dpi=300)
plt.show()

#### Fig.5b lineplot

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import pearsonr

# Plot label -> sample (chip) name, listed in plotting order. Every value is
# looked up by sample key, so the table no longer depends on the insertion
# order of GeneLists/GeneCounts matching a hand-ordered file list.
stage_chips = {
    'St.44': 'Stage44',
    'St.54': 'Stage54',
    'St.57': 'Stage57',
    'Juv.': 'Control_Juv',
    'Adult': 'Adult',
    'Meta.': 'Meta',
}

records = []
for stage, chip in stage_chips.items():
    key = f'{chip}_VLMC'
    table = pd.read_csv(f'./data/Axolotls/development/{chip}/{chip}.VLMC.csv', sep='\t')
    records.append({
        'Stage': stage,
        'Genes': GeneCounts[key],
        # number of genes in the sample's coordinate table
        'Count': len(GeneLists[key]),
        # NOTE(review): this is the number of non-null MIDCounts rows, not the
        # sum of MIDCounts - confirm it matches the "Total transcript count" label.
        'Transcripts': table['MIDCounts'].count(),
    })
df = pd.DataFrame(records).set_index('Stage')

transcription_center_count = df['Count']
transcript_count = df['Transcripts']

correlation, p_value = pearsonr(transcription_center_count, transcript_count)

# Draw the line plot: transcripts on the left axis, centers on the right axis
sns.set(rc={"figure.dpi":300,'figure.figsize':(3, 2.2)},font_scale=1)
sns.set_theme(style='white', palette='Set2')
ax=sns.lineplot(data=df, x='Stage', y='Transcripts', color='#023fa5')
ax.set_ylabel('Total transcript count', color='#023fa5')
plt.xticks(rotation=90, fontsize=12)
plt.yticks(fontsize=12)

ax2 = ax.twinx()
sns.lineplot(data=df, x='Stage', y='Count', color='#8e063b', ax=ax2)
ax2.set_ylabel('Transcription Center Count', color='#8e063b')
ax.set_xlabel('')

ax.set_ylim(0,2500000)
ax2.set_ylim(0,10000)
ax.text(0.05, 0.98, f'Pearson r={correlation:.4f}\n$P={p_value:.4f}$', transform=ax.transAxes, fontsize=11, color='black', va='top')

plt.title('VLMC')
plt.savefig('figures/Fig5b_u.pdf', bbox_inches='tight')
#ax.set_ylim(0,10000)
plt.show()

In [None]:
from scipy.stats import spearmanr

# Rank correlation between center count and transcript count
transcription_center_count, transcript_count = df['Count'], df['Transcripts']
correlation, p_value = spearmanr(transcription_center_count, transcript_count)

(correlation, p_value)

In [None]:
from scipy.stats import pearsonr

# Linear correlation between center count and transcript count
transcription_center_count, transcript_count = df['Count'], df['Transcripts']
correlation, p_value = pearsonr(transcription_center_count, transcript_count)

(correlation, p_value)

#### netplot

In [None]:
from cytocraft.plot import *
# One colour per stage label (light to dark blue in the order listed).
color_map_sample = {
'St.44': '#b8d5ea',
'St.54': '#97c6df',
'St.57': '#4e9acb',
'Juv.': '#3282be',
'Adult': '#1a68ae',
'Meta.': '#084f99',
}
# Draw the VLMC sample network from the distance matrix with cytocraft's
# plot_network and write it to an HTML file.
plot_network(D_Axo_dev_VLMC, labels=labels_dev, cmap=color_map_sample, solver='forceAtlas2Based', corder=0, html='network_Axo_dev_VLMC_neighbor2.html', edge_scale=10, edge_adjust=0, N_neighbor=2)

## 4. Gene pairwise distance pattern

### CP

In [None]:
import os
from functools import reduce
import numpy as np
import scanpy as sc
import seaborn as sns
import pandas as pd

Xs={}
GeneLists={}

# Load the coordinate table of every CP sample, keyed "<chip>_<celltype>".
result_path="./results/Axolotls/"
samples = [ name for name in sorted(os.listdir(result_path)) if os.path.isdir(os.path.join(result_path, name)) ]
for sample in samples:
    if sample[0].isalpha():
        result_dir = os.path.join(result_path,sample)
        files = os.listdir(result_dir)
        info = sample.split(".")
        chip = info[0]
        ctype = info[1].split('_')[0]
        sample = chip+'_'+ctype
        if 'adata.h5ad' in files and ctype=='CP':
            print(sample)
            adata = sc.read_h5ad(os.path.join(result_dir, "adata.h5ad"))
            Xs[sample] = adata.uns['X']
            GeneLists[sample] = adata.uns['X'].index

# Genes shared by every sample (np.intersect1d returns them sorted).
intersected_genes = reduce(np.intersect1d, GeneLists.values())
# Not used below any more; kept in case other cells rely on it.
boolean_arrays = [np.isin(a, intersected_genes) for a in GeneLists.values()]

from scipy.spatial.distance import pdist, squareform

pattern_frames = []
for s, coords in Xs.items():
    print(s)
    # Select rows by gene label so they are in intersected_genes order. A
    # boolean mask keeps the sample's own row order, which mislabels the
    # distances whenever that order differs from the sorted gene list.
    distances = pdist(coords.loc[intersected_genes].values, metric='euclidean')
    df_dist = pd.DataFrame(squareform(distances), index=intersected_genes, columns=intersected_genes).stack()
    df_dist = df_dist.reset_index()
    df_dist.columns = ['gene1', 'gene2', 'distance']
    # Keys are "<chip>_<celltype>" and the chip itself may contain "_"
    # (Control_Juv), so split on the last underscore only.
    # NOTE: the stage column for that sample is now 'Control_Juv' (was 'Control').
    stage, celltype = s.rsplit('_', 1)
    df_dist['stage'] = stage
    df_dist['celltype'] = celltype
    df_dist['genepair'] = df_dist['gene1'] + "," + df_dist['gene2']
    pattern_frames.append(df_dist)

# DataFrame.append was removed in pandas 2.0; concatenate once instead.
D_pattern = pd.concat(pattern_frames, ignore_index=True)
matrix_pattern = D_pattern.pivot_table(index='genepair', columns='stage', values='distance')
matrix_pattern.to_csv('Axo_dev_CP_pattern.mx.csv')

#### enrichment

In [None]:
import pandas as pd
import numpy as np

# Load the CP Mfuzz table and keep only rows whose 'max' value exceeds 0.95.
Mfuzz_tumor = pd.read_csv('Mfuzz_AXO_dev_CP.xls', sep='\t')
Mfuzz_tumor_sub = Mfuzz_tumor.loc[Mfuzz_tumor['max'] > 0.95]

# For each cluster, the set of individual gene names obtained by splitting
# every comma-separated entry of the 'gene' column.
cluster_names = np.unique(Mfuzz_tumor_sub['Cluster'])
unique_gs_cluster = {
    label: {
        gene
        for pair in Mfuzz_tumor_sub.loc[Mfuzz_tumor_sub['Cluster'] == label, 'gene']
        for gene in pair.split(',')
    }
    for label in cluster_names
}

In [None]:
# translate gene
# Map Axolotl_ID -> hs_gene for every annotation row whose hs_gene is not '-'.
anno = pd.read_csv("./Axo_Summary_Gene_Annotation_0325.txt", sep='\t')
annodict = {
    axolotl_id: human_name
    for axolotl_id, human_name in zip(anno['Axolotl_ID'], anno['hs_gene'])
    if human_name != '-'
}

# Identifiers without a mapping are kept unchanged.
new_unique_gs_cluster = {
    cluster: {annodict.get(gene, gene) for gene in unique_gs_cluster[cluster]}
    for cluster in cluster_names
}

In [None]:
# enrichment
import gseapy as gp

# Accumulates the significant (adjusted p < 0.05) terms of every cluster and library.
enr = pd.DataFrame(columns=['Cluster', 'Gene_set', 'Term', 'Overlap', 'P-value', 'Adjusted P-value', 'Old P-value', 'Old Adjusted P-value', 'Odds Ratio', 'Combined Score', 'Genes'])

# Enrichr API: one query per cluster and gene-set library.
for key, value in new_unique_gs_cluster.items():
    try:
        # Drop identifiers still containing 'AMEX60DD' (presumably axolotl IDs
        # that had no human symbol). The filter is the same for every library,
        # so it is computed once per cluster.
        filterlist = [gene for gene in value if 'AMEX60DD' not in gene]
        for gs in ['KEGG_2021_Human','GO_Molecular_Function_2023','GO_Biological_Process_2023']:
            enr_up = gp.enrichr(gene_list=filterlist,
                                organism='Human',
                                gene_sets=gs,
                                outdir=None,
                                cutoff=0.05,
                                no_plot= True)

            # trim the trailing " (GO:...)" accession; raw string avoids the
            # invalid "\(" escape warning.
            enr_up.res2d.Term = enr_up.res2d.Term.str.split(r" \(GO").str[0]
            enr_up.res2d['Cluster'] = key
            enr = pd.concat([enr, enr_up.res2d[enr_up.res2d['Adjusted P-value']<0.05]], ignore_index=True)
            # dotplot
            # gp.dotplot(enr_up.res2d, title="Cluster "+str(value), size=10, cmap = plt.cm.viridis_r)
    # `except A or B or C` evaluates to `except A`; a tuple is needed to
    # catch all three exception types.
    except (ValueError, FileNotFoundError, TypeError) as e:
        # Report the offending gene set and move on to the next cluster.
        print(value)
        print(e)
        continue


In [None]:
# KEGG_2021_Human
# Draws one horizontal bar panel per entry of `cluster_names`, showing the first
# five KEGG_2021_Human rows of `enr` for that cluster (x = Combined Score), with
# bars coloured by -log10(adjusted p-value) on a fixed 0-30 colour scale.
import matplotlib as mpl
import seaborn as sns
import matplotlib.pyplot as plt

# NOTE(review): plt.subplots returns a single Axes (not an array) when only one
# row is requested, so `axs[i]` below presumably needs >= 2 clusters - confirm.
f,axs = plt.subplots(len(cluster_names),1)
# NOTE(review): these rc settings are applied after the figure already exists;
# check whether the dpi/figsize values are meant to affect `f`.
sns.set(rc={"figure.dpi":300,'figure.figsize':(1, 3)},font_scale=1)
sns.set_theme(style='white', palette='pastel')
plt.subplots_adjust(left=0, bottom=-.1, right=1, top=1.1)
cmap = mpl.colormaps['viridis_r'] # create a color map object
enr_gs = enr[enr["Gene_set"]=='KEGG_2021_Human']

for i,cluster in enumerate(cluster_names):
    # First five rows in their existing order (no sorting is done here).
    enr_top5_gs_cluster = enr_gs[enr_gs["Cluster"]==cluster][:5]
    log_adj_pvalue = -np.log10(enr_top5_gs_cluster["Adjusted P-value"])
    # The same Normalize is rebuilt every pass; the last binding feeds the colorbar below.
    norm = mpl.colors.Normalize(vmin=0, vmax=30)
    colors = [cmap(norm(value)) for value in log_adj_pvalue]
    # NOTE(review): `fill` is given a column name; seaborn appears to document it
    # as a boolean flag, so this is presumably just truthy - confirm intent.
    g = sns.barplot(data=enr_top5_gs_cluster, x="Combined Score", y=enr_top5_gs_cluster["Term"], hue=enr_top5_gs_cluster["Term"], fill="Adjusted P-value", ax=axs[i], palette=colors, legend=False)
    g.set_xlabel("")
    axs[i].yaxis.set_label_position("right")
    g.set_ylabel(cluster,rotation=270,labelpad=12)
    axs[i].tick_params(axis='x', which='major', pad=-4)

# `g` is the Axes from the final loop pass, so only the last panel keeps an x label.
g.set_xlabel("Combined Score")
mappable = mpl.cm.ScalarMappable(cmap=cmap, norm=norm) # create a scalar mappable object
cbar = f.colorbar(mappable, ax=axs, orientation='horizontal',label='-log10(Adjusted P-value)') # add a colorbar to the figure
cbar.ax.set_position([-0.5, -0.4, 1.5, 0.5]) # xmin, ymin, dx, and dy
plt.show()
#f.savefig('Mfuzz_KEGG_Axo_dev_CP.png')

In [None]:
# GO_Molecular_Function_2023
# Draws one horizontal bar panel per entry of `cluster_names`, showing the first
# five GO_Molecular_Function_2023 rows of `enr` for that cluster (x = Combined
# Score), with bars coloured by -log10(adjusted p-value) on a fixed 0-30 scale.
import matplotlib as mpl
import seaborn as sns
import matplotlib.pyplot as plt

# NOTE(review): plt.subplots returns a single Axes (not an array) when only one
# row is requested, so `axs[i]` below presumably needs >= 2 clusters - confirm.
f,axs = plt.subplots(len(cluster_names),1)
# NOTE(review): these rc settings are applied after the figure already exists;
# check whether the dpi/figsize values are meant to affect `f`.
sns.set(rc={"figure.dpi":300,'figure.figsize':(1, 3)},font_scale=1)
sns.set_theme(style='white', palette='pastel')
plt.subplots_adjust(left=0, bottom=-.1, right=1, top=1.1)
cmap = mpl.colormaps['viridis_r'] # create a color map object
enr_gs = enr[enr["Gene_set"]=='GO_Molecular_Function_2023']

for i,cluster in enumerate(cluster_names):
    # First five rows in their existing order (no sorting is done here).
    enr_top5_gs_cluster = enr_gs[enr_gs["Cluster"]==cluster][:5]
    log_adj_pvalue = -np.log10(enr_top5_gs_cluster["Adjusted P-value"])
    # The same Normalize is rebuilt every pass; the last binding feeds the colorbar below.
    norm = mpl.colors.Normalize(vmin=0, vmax=30)
    colors = [cmap(norm(value)) for value in log_adj_pvalue]
    # NOTE(review): `fill` is given a column name; seaborn appears to document it
    # as a boolean flag, so this is presumably just truthy - confirm intent.
    g = sns.barplot(data=enr_top5_gs_cluster, x="Combined Score", y=enr_top5_gs_cluster["Term"], hue=enr_top5_gs_cluster["Term"], fill="Adjusted P-value", ax=axs[i], palette=colors, legend=False)
    g.set_xlabel("")
    axs[i].yaxis.set_label_position("right")
    g.set_ylabel(cluster,rotation=270,labelpad=12)
    axs[i].tick_params(axis='x', which='major', pad=-4)

# `g` is the Axes from the final loop pass, so only the last panel keeps an x label.
g.set_xlabel("Combined Score")
mappable = mpl.cm.ScalarMappable(cmap=cmap, norm=norm) # create a scalar mappable object
cbar = f.colorbar(mappable, ax=axs, orientation='horizontal',label='-log10(Adjusted P-value)') # add a colorbar to the figure
cbar.ax.set_position([-0.5, -0.4, 1.5, 0.5]) # xmin, ymin, dx, and dy
plt.show()

In [None]:
# GO_Biological_Process_2023
# Draws one horizontal bar panel per entry of `cluster_names`, showing the first
# five GO_Biological_Process_2023 rows of `enr` for that cluster (x = Combined
# Score), with bars coloured by -log10(adjusted p-value) on a fixed 0-30 scale.
import matplotlib as mpl
import seaborn as sns
import matplotlib.pyplot as plt

# NOTE(review): plt.subplots returns a single Axes (not an array) when only one
# row is requested, so `axs[i]` below presumably needs >= 2 clusters - confirm.
f,axs = plt.subplots(len(cluster_names),1)
# NOTE(review): these rc settings are applied after the figure already exists;
# check whether the dpi/figsize values are meant to affect `f`.
sns.set(rc={"figure.dpi":300,'figure.figsize':(1, 3)},font_scale=1)
sns.set_theme(style='white', palette='pastel')
plt.subplots_adjust(left=0, bottom=-.1, right=1, top=1.1)
cmap = mpl.colormaps['viridis_r'] # create a color map object
enr_gs = enr[enr["Gene_set"]=='GO_Biological_Process_2023']

for i,cluster in enumerate(cluster_names):
    # First five rows in their existing order (no sorting is done here).
    enr_top5_gs_cluster = enr_gs[enr_gs["Cluster"]==cluster][:5]
    log_adj_pvalue = -np.log10(enr_top5_gs_cluster["Adjusted P-value"])
    # The same Normalize is rebuilt every pass; the last binding feeds the colorbar below.
    norm = mpl.colors.Normalize(vmin=0, vmax=30)
    colors = [cmap(norm(value)) for value in log_adj_pvalue]
    # NOTE(review): `fill` is given a column name; seaborn appears to document it
    # as a boolean flag, so this is presumably just truthy - confirm intent.
    g = sns.barplot(data=enr_top5_gs_cluster, x="Combined Score", y=enr_top5_gs_cluster["Term"], hue=enr_top5_gs_cluster["Term"], fill="Adjusted P-value", ax=axs[i], palette=colors, legend=False)
    g.set_xlabel("")
    axs[i].yaxis.set_label_position("right")
    g.set_ylabel(cluster,rotation=270,labelpad=12)
    axs[i].tick_params(axis='x', which='major', pad=-4)

# `g` is the Axes from the final loop pass, so only the last panel keeps an x label.
g.set_xlabel("Combined Score")
mappable = mpl.cm.ScalarMappable(cmap=cmap, norm=norm) # create a scalar mappable object
cbar = f.colorbar(mappable, ax=axs, orientation='horizontal',label='-log10(Adjusted P-value)') # add a colorbar to the figure
cbar.ax.set_position([-0.5, -0.4, 1.5, 0.5]) # xmin, ymin, dx, and dy
plt.show()

### VLMC

In [None]:

import os
from functools import reduce
import numpy as np
import scanpy as sc
import seaborn as sns
import pandas as pd
from scipy.spatial.distance import pdist, squareform

# Build, for every VLMC result directory, the table of pairwise Euclidean
# distances between genes and pivot it into a (gene pair x stage) matrix.
Xs={}
GeneLists={}

result_path="./results/Axolotls/"
samples = [ name for name in sorted(os.listdir(result_path)) if os.path.isdir(os.path.join(result_path, name)) ]
for sample in samples:
    # Only directories whose name starts with a letter are considered here.
    if sample[0].isalpha():
        result_dir = os.path.join(result_path,sample)
        files = os.listdir(result_dir)
        # Directory names are split as "<chip>.<celltype>_..." to get both parts.
        info = sample.split(".")
        chip = info[0]
        ctype = info[1].split('_')[0]
        sample = chip+'_'+ctype
        if 'adata.h5ad' in files and ctype=='VLMC':
            print(sample)
            adata = sc.read_h5ad(os.path.join(result_dir, "adata.h5ad"))
            Xs[sample] = adata.uns['X']
            GeneLists[sample] = adata.uns['X'].index

# Genes present in every loaded sample (np.intersect1d returns them sorted).
intersected_genes = reduce(np.intersect1d, GeneLists.values())
# Kept for any later cell that reads it; np.isin replaces the deprecated np.in1d.
boolean_arrays = [np.isin(a, intersected_genes) for a in GeneLists.values()]

pattern_frames = []
for s, gene_table in Xs.items():
    print(s)
    # Select rows by label so their order matches `intersected_genes`, which is
    # used as the row/column labels below. A boolean mask would keep the
    # table's own row order and mislabel distances if the index is not sorted.
    shared_rows = gene_table.loc[intersected_genes]
    distances = pdist(shared_rows.values, metric='euclidean')
    df_dist = pd.DataFrame(squareform(distances), index=intersected_genes, columns=intersected_genes).stack()
    df_dist = df_dist.reset_index()
    df_dist.columns = ['gene1', 'gene2', 'distance']
    inf = s.split('_')
    df_dist['stage'] = inf[0]
    df_dist['celltype'] = inf[1]
    df_dist['genepair'] = df_dist['gene1'] + "," + df_dist['gene2']
    pattern_frames.append(df_dist)

# DataFrame.append was removed in pandas 2.0; concatenate once instead.
D_pattern = pd.concat(pattern_frames, ignore_index=True)
matrix_pattern = D_pattern.pivot_table(index='genepair', columns='stage', values='distance')
matrix_pattern.to_csv('Axo_dev_VLMC_pattern.mx.csv')
# enrichment input: Mfuzz membership table for the VLMC pattern matrix
Mfuzz_VLMC=pd.read_csv('Mfuzz_membership_AXO_dev_VLMC.xls', sep='\t')


#### enrichment

In [None]:
import pandas as pd
import numpy as np

# Load the VLMC Mfuzz table and keep only rows whose 'max' value exceeds 0.95.
Mfuzz_tumor = pd.read_csv('Mfuzz_AXO_dev_VLMC.xls', sep='\t')
Mfuzz_tumor_sub = Mfuzz_tumor.loc[Mfuzz_tumor['max'] > 0.95]

# For each cluster, the set of individual gene names obtained by splitting
# every comma-separated entry of the 'gene' column.
cluster_names = np.unique(Mfuzz_tumor_sub['Cluster'])
unique_gs_cluster = {
    label: {
        gene
        for pair in Mfuzz_tumor_sub.loc[Mfuzz_tumor_sub['Cluster'] == label, 'gene']
        for gene in pair.split(',')
    }
    for label in cluster_names
}

In [None]:
# translate gene
# Map Axolotl_ID -> hs_gene for every annotation row whose hs_gene is not '-'.
anno = pd.read_csv("./Axo_Summary_Gene_Annotation_0325.txt", sep='\t')
annodict = {
    axolotl_id: human_name
    for axolotl_id, human_name in zip(anno['Axolotl_ID'], anno['hs_gene'])
    if human_name != '-'
}

# Identifiers without a mapping are kept unchanged.
new_unique_gs_cluster = {
    cluster: {annodict.get(gene, gene) for gene in unique_gs_cluster[cluster]}
    for cluster in cluster_names
}

In [None]:
# enrichment
import gseapy as gp

# Accumulates the significant (adjusted p < 0.05) terms of every cluster and library.
enr = pd.DataFrame(columns=['Cluster', 'Gene_set', 'Term', 'Overlap', 'P-value', 'Adjusted P-value', 'Old P-value', 'Old Adjusted P-value', 'Odds Ratio', 'Combined Score', 'Genes'])

# Enrichr API: one query per cluster and gene-set library.
for key, value in new_unique_gs_cluster.items():
    try:
        # Drop identifiers still containing 'AMEX60DD' (presumably axolotl IDs
        # that had no human symbol). The filter is the same for every library,
        # so it is computed once per cluster.
        filterlist = [gene for gene in value if 'AMEX60DD' not in gene]
        for gs in ['KEGG_2021_Human','GO_Molecular_Function_2023','GO_Biological_Process_2023']:
            enr_up = gp.enrichr(gene_list=filterlist,
                                organism='Human',
                                gene_sets=gs,
                                outdir=None,
                                cutoff=0.05,
                                no_plot= True)

            # trim the trailing " (GO:...)" accession; raw string avoids the
            # invalid "\(" escape warning.
            enr_up.res2d.Term = enr_up.res2d.Term.str.split(r" \(GO").str[0]
            enr_up.res2d['Cluster'] = key
            enr = pd.concat([enr, enr_up.res2d[enr_up.res2d['Adjusted P-value']<0.05]], ignore_index=True)
            # dotplot
            # gp.dotplot(enr_up.res2d, title="Cluster "+str(value), size=10, cmap = plt.cm.viridis_r)
    # `except A or B or C` evaluates to `except A`; a tuple is needed to
    # catch all three exception types.
    except (ValueError, FileNotFoundError, TypeError) as e:
        # Report the offending gene set and move on to the next cluster.
        print(value)
        print(e)
        continue


In [None]:
# KEGG_2021_Human
# Draws one horizontal bar panel per entry of `cluster_names`, showing the first
# five KEGG_2021_Human rows of `enr` for that cluster (x = Combined Score), with
# bars coloured by -log10(adjusted p-value) on a fixed 0-30 colour scale.
import matplotlib as mpl
import seaborn as sns
import matplotlib.pyplot as plt

# NOTE(review): plt.subplots returns a single Axes (not an array) when only one
# row is requested, so `axs[i]` below presumably needs >= 2 clusters - confirm.
f,axs = plt.subplots(len(cluster_names),1)
# NOTE(review): these rc settings are applied after the figure already exists;
# check whether the dpi/figsize values are meant to affect `f`.
sns.set(rc={"figure.dpi":300,'figure.figsize':(1, 3)},font_scale=1)
sns.set_theme(style='white', palette='pastel')
plt.subplots_adjust(left=0, bottom=-.1, right=1, top=1.1)
cmap = mpl.colormaps['viridis_r'] # create a color map object
enr_gs = enr[enr["Gene_set"]=='KEGG_2021_Human']

for i,cluster in enumerate(cluster_names):
    # First five rows in their existing order (no sorting is done here).
    enr_top5_gs_cluster = enr_gs[enr_gs["Cluster"]==cluster][:5]
    log_adj_pvalue = -np.log10(enr_top5_gs_cluster["Adjusted P-value"])
    # The same Normalize is rebuilt every pass; the last binding feeds the colorbar below.
    norm = mpl.colors.Normalize(vmin=0, vmax=30)
    colors = [cmap(norm(value)) for value in log_adj_pvalue]
    # NOTE(review): `fill` is given a column name; seaborn appears to document it
    # as a boolean flag, so this is presumably just truthy - confirm intent.
    g = sns.barplot(data=enr_top5_gs_cluster, x="Combined Score", y=enr_top5_gs_cluster["Term"], hue=enr_top5_gs_cluster["Term"], fill="Adjusted P-value", ax=axs[i], palette=colors, legend=False)
    g.set_xlabel("")
    axs[i].yaxis.set_label_position("right")
    g.set_ylabel(cluster,rotation=270,labelpad=12)
    axs[i].tick_params(axis='x', which='major', pad=-4)

# `g` is the Axes from the final loop pass, so only the last panel keeps an x label.
g.set_xlabel("Combined Score")
mappable = mpl.cm.ScalarMappable(cmap=cmap, norm=norm) # create a scalar mappable object
cbar = f.colorbar(mappable, ax=axs, orientation='horizontal',label='-log10(Adjusted P-value)') # add a colorbar to the figure
cbar.ax.set_position([-0.5, -0.4, 1.5, 0.5]) # xmin, ymin, dx, and dy
plt.show()
# NOTE(review): the file name below says "CP" although this cell sits in the
# VLMC section - rename before re-enabling so the CP figure is not overwritten.
#f.savefig('Mfuzz_KEGG_Axo_dev_CP.png')

In [None]:
# GO_Molecular_Function_2023
# Draws one horizontal bar panel per entry of `cluster_names`, showing the first
# five GO_Molecular_Function_2023 rows of `enr` for that cluster (x = Combined
# Score), with bars coloured by -log10(adjusted p-value) on a fixed 0-30 scale.
import matplotlib as mpl
import seaborn as sns
import matplotlib.pyplot as plt

# NOTE(review): plt.subplots returns a single Axes (not an array) when only one
# row is requested, so `axs[i]` below presumably needs >= 2 clusters - confirm.
f,axs = plt.subplots(len(cluster_names),1)
# NOTE(review): these rc settings are applied after the figure already exists;
# check whether the dpi/figsize values are meant to affect `f`.
sns.set(rc={"figure.dpi":300,'figure.figsize':(1, 3)},font_scale=1)
sns.set_theme(style='white', palette='pastel')
plt.subplots_adjust(left=0, bottom=-.1, right=1, top=1.1)
cmap = mpl.colormaps['viridis_r'] # create a color map object
enr_gs = enr[enr["Gene_set"]=='GO_Molecular_Function_2023']

for i,cluster in enumerate(cluster_names):
    # First five rows in their existing order (no sorting is done here).
    enr_top5_gs_cluster = enr_gs[enr_gs["Cluster"]==cluster][:5]
    log_adj_pvalue = -np.log10(enr_top5_gs_cluster["Adjusted P-value"])
    # The same Normalize is rebuilt every pass; the last binding feeds the colorbar below.
    norm = mpl.colors.Normalize(vmin=0, vmax=30)
    colors = [cmap(norm(value)) for value in log_adj_pvalue]
    # NOTE(review): `fill` is given a column name; seaborn appears to document it
    # as a boolean flag, so this is presumably just truthy - confirm intent.
    g = sns.barplot(data=enr_top5_gs_cluster, x="Combined Score", y=enr_top5_gs_cluster["Term"], hue=enr_top5_gs_cluster["Term"], fill="Adjusted P-value", ax=axs[i], palette=colors, legend=False)
    g.set_xlabel("")
    axs[i].yaxis.set_label_position("right")
    g.set_ylabel(cluster,rotation=270,labelpad=12)
    axs[i].tick_params(axis='x', which='major', pad=-4)

# `g` is the Axes from the final loop pass, so only the last panel keeps an x label.
g.set_xlabel("Combined Score")
mappable = mpl.cm.ScalarMappable(cmap=cmap, norm=norm) # create a scalar mappable object
cbar = f.colorbar(mappable, ax=axs, orientation='horizontal',label='-log10(Adjusted P-value)') # add a colorbar to the figure
cbar.ax.set_position([-0.5, -0.4, 1.5, 0.5]) # xmin, ymin, dx, and dy
plt.show()

In [None]:
# GO_Biological_Process_2023
# Draws one horizontal bar panel per entry of `cluster_names`, showing the first
# five GO_Biological_Process_2023 rows of `enr` for that cluster (x = Combined
# Score), with bars coloured by -log10(adjusted p-value) on a fixed 0-30 scale.
import matplotlib as mpl
import seaborn as sns
import matplotlib.pyplot as plt

# NOTE(review): plt.subplots returns a single Axes (not an array) when only one
# row is requested, so `axs[i]` below presumably needs >= 2 clusters - confirm.
f,axs = plt.subplots(len(cluster_names),1)
# NOTE(review): these rc settings are applied after the figure already exists;
# check whether the dpi/figsize values are meant to affect `f`.
sns.set(rc={"figure.dpi":300,'figure.figsize':(1, 3)},font_scale=1)
sns.set_theme(style='white', palette='pastel')
plt.subplots_adjust(left=0, bottom=-.1, right=1, top=1.1)
cmap = mpl.colormaps['viridis_r'] # create a color map object
enr_gs = enr[enr["Gene_set"]=='GO_Biological_Process_2023']

for i,cluster in enumerate(cluster_names):
    # First five rows in their existing order (no sorting is done here).
    enr_top5_gs_cluster = enr_gs[enr_gs["Cluster"]==cluster][:5]
    log_adj_pvalue = -np.log10(enr_top5_gs_cluster["Adjusted P-value"])
    # The same Normalize is rebuilt every pass; the last binding feeds the colorbar below.
    norm = mpl.colors.Normalize(vmin=0, vmax=30)
    colors = [cmap(norm(value)) for value in log_adj_pvalue]
    # NOTE(review): `fill` is given a column name; seaborn appears to document it
    # as a boolean flag, so this is presumably just truthy - confirm intent.
    g = sns.barplot(data=enr_top5_gs_cluster, x="Combined Score", y=enr_top5_gs_cluster["Term"], hue=enr_top5_gs_cluster["Term"], fill="Adjusted P-value", ax=axs[i], palette=colors, legend=False)
    g.set_xlabel("")
    axs[i].yaxis.set_label_position("right")
    g.set_ylabel(cluster,rotation=270,labelpad=12)
    axs[i].tick_params(axis='x', which='major', pad=-4)

# `g` is the Axes from the final loop pass, so only the last panel keeps an x label.
g.set_xlabel("Combined Score")
mappable = mpl.cm.ScalarMappable(cmap=cmap, norm=norm) # create a scalar mappable object
cbar = f.colorbar(mappable, ax=axs, orientation='horizontal',label='-log10(Adjusted P-value)') # add a colorbar to the figure
cbar.ax.set_position([-0.5, -0.4, 1.5, 0.5]) # xmin, ymin, dx, and dy
plt.show()

### overlap

In [None]:
import pandas as pd
import numpy as np
Mfuzz_tumor=pd.read_csv('Mfuzz_AXO_dev_CP.xls', sep='\t')

# Set of 'gene' entries (gene pairs) assigned to each CP pattern.
Mfuzz_genepairs_CP={}
cluster_names = np.unique(Mfuzz_tumor['Pattern'])
for pattern in cluster_names:
    # Named `genepairs` rather than `gp`, so the notebook-wide `gseapy` alias
    # `gp` used by the enrichment cells is not overwritten by a set.
    genepairs = set(Mfuzz_tumor.loc[Mfuzz_tumor['Pattern'] == pattern, 'gene'])
    Mfuzz_genepairs_CP[pattern] = genepairs
    print(len(genepairs))


In [None]:
import pandas as pd
import numpy as np
Mfuzz_tumor=pd.read_csv('Mfuzz_AXO_dev_VLMC.xls', sep='\t')


# Set of 'gene' entries (gene pairs) assigned to each VLMC pattern.
Mfuzz_genepairs_VLMC={}
cluster_names = np.unique(Mfuzz_tumor['Pattern'])
for pattern in cluster_names:
    # Named `genepairs` rather than `gp`, so the notebook-wide `gseapy` alias
    # `gp` used by the enrichment cells is not overwritten by a set.
    genepairs = set(Mfuzz_tumor.loc[Mfuzz_tumor['Pattern'] == pattern, 'gene'])
    Mfuzz_genepairs_VLMC[pattern] = genepairs
    print(len(genepairs))

In [None]:
# chi-square test
from scipy.stats import chi2_contingency
from scipy.stats import chi2

# Define the patterns to compare as (CP pattern, VLMC pattern)
patterns_to_compare = [
    ('Pattern 3', 'Pattern 1'),
    ('Pattern 2', 'Pattern 2'),
    ('Pattern 1', 'Pattern 3')
]

# Background for the 2x2 tables: gene pairs that were assigned a pattern in
# BOTH cell types, since only those can fall in any of the four cells.
# The previous code derived "neither" from the union of the two compared sets,
# which is always exactly overlap + only_cp + only_vlmc, so "neither" was 0
# for every comparison and the test had no real background.
# NOTE(review): confirm that the shared set of clustered gene pairs is the
# intended background for this comparison.
all_cp_pairs = set().union(*Mfuzz_genepairs_CP.values())
all_vlmc_pairs = set().union(*Mfuzz_genepairs_VLMC.values())
background = all_cp_pairs & all_vlmc_pairs

# Perform chi-square test for each pair of patterns
for cp_pattern, vlmc_pattern in patterns_to_compare:
    # Restrict both pattern sets to the shared background.
    cp_pairs = Mfuzz_genepairs_CP[cp_pattern] & background
    vlmc_pairs = Mfuzz_genepairs_VLMC[vlmc_pattern] & background

    # Create a contingency table
    overlap = len(cp_pairs & vlmc_pairs)
    only_cp = len(cp_pairs) - overlap
    only_vlmc = len(vlmc_pairs) - overlap
    neither = len(background) - overlap - only_cp - only_vlmc

    contingency_table = [[overlap, only_cp], [only_vlmc, neither]]

    print(contingency_table)
    # Perform chi-square test
    chi2_stat, p_val, dof, ex = chi2_contingency(contingency_table)

    print(f"Chi-square test for CP {cp_pattern} and VLMC {vlmc_pattern}:")
    print(f"Chi2 Stat: {chi2_stat}, P-value: {p_val}\n")

In [None]:
import openchord as ocd

# overlaps[r][c] = number of gene pairs shared by CP pattern r+1 and VLMC pattern c+1
pattern_labels = ['Pattern 1', 'Pattern 2', 'Pattern 3']
overlaps = [
    [
        len(Mfuzz_genepairs_CP[cp_label].intersection(Mfuzz_genepairs_VLMC[vlmc_label]))
        for vlmc_label in pattern_labels
    ]
    for cp_label in pattern_labels
]

# The input should be a symmetric matrix showing the relation (overlap) between
# each gene list: CP rows link only to VLMC columns and vice versa, so both
# diagonal blocks are zero and the lower-left block is the transpose of `overlaps`.
zero_block = [0] * len(pattern_labels)
matrix = (
    [zero_block + row for row in overlaps]
    + [list(column) + zero_block for column in zip(*overlaps)]
)
print(matrix)
names = ['CP Pattern 1', 'CP Pattern 2', 'CP Pattern 3', 'VLMC Pattern 1', 'VLMC Pattern 2', 'VLMC Pattern 3']
fig = ocd.Chord(matrix, names)
fig.colormap = ["#FF6B6B", "#F9844A", "#F9C74F", "#90BE6D", "#43AA8B", "#4D908E", "#577590", "#277DA1"]
fig.radius = 80
fig.padding = 80
# The original assigned a bare variable `bg_transparancy=1`, which never reached
# the figure; set it on `fig` like the other options.
# NOTE(review): assumes openchord's Chord exposes `bg_transparancy` (library spelling) - confirm.
fig.bg_transparancy = 1

fig.save_svg("figures/chord_patterns_vlmc&cp.svg")

## 5. Corr

In [None]:
import scanpy as sc

# adatas:    "<chip> <celltype>" -> AnnData, for CP and VLMC results only
# celltypes: chip -> every cell type seen for that chip (loaded or not)
adatas={}
celltypes = {}

result_path="./results/Axolotls/"
samples = sorted(
    entry for entry in os.listdir(result_path)
    if os.path.isdir(os.path.join(result_path, entry))
)
for sample in samples:
    # Skip directories whose name does not start with a letter.
    if not sample[0].isalpha():
        continue
    result_dir = os.path.join(result_path, sample)
    files = os.listdir(result_dir)
    # Directory names are split as "<chip>.<celltype>_...".
    info = sample.split(".")
    chip, ctype = info[0], info[1].split('_')[0]
    sample = f"{chip} {ctype}"
    celltypes.setdefault(chip, []).append(ctype)
    has_result = 'adata.h5ad' in files
    if has_result and ('CP' in sample or 'VLMC' in sample):
        adatas[sample] = sc.read_h5ad(os.path.join(result_dir+"/adata.h5ad"))
        #GeneLists[sample] = adatas[sample].uns['X'].index
        #Xs[sample] = adatas[sample].uns['X']

In [None]:
# Show the 'X' table of the entry named by `sample`, i.e. whatever value the
# loading loop left in that variable; a KeyError here means that last entry
# was not one of the loaded CP/VLMC samples.
adatas[sample].uns['X']

In [None]:
# Read the gene GTF as a headerless tab-separated table; columns become col_0, col_1, ...
gtf = pd.read_csv("./notebook_yf/ReST3D/main/src/cytocraft/gtf/AmexT_v47-AmexG_v6.0-DD.gene.gtf", sep='\t', comment='#', header=None)
gtf = gtf.rename(columns=lambda x: 'col_' + str(x))

# get position of genes
gene_chr={}
gene_length={}
gene_length_cbrt={}
gene_rows = gtf[gtf['col_2'] == 'gene']
for attributes, chrom, start, end in zip(gene_rows['col_8'], gene_rows['col_0'], gene_rows['col_3'], gene_rows['col_4']):
    # The key is the quoted value of the second ';'-separated attribute.
    gene_key = attributes.split(';')[1].split()[1].strip('"')
    gene_chr[gene_key] = chrom
    # Length is taken as end minus start; its cube root is stored alongside.
    gene_length[gene_key] = int(end) - int(start)
    gene_length_cbrt[gene_key] = np.cbrt(gene_length[gene_key])

In [None]:
def gene_gene_distance_matrix(F):
    """Return the symmetric matrix of Euclidean distances between the rows of F.

    Parameters
    ----------
    F : pandas.DataFrame
        One row per gene; every column is a numeric coordinate.

    Returns
    -------
    numpy.ndarray
        Array of shape (N, N), N = number of rows of F, where entry [n, m] is
        the Euclidean distance between row n and row m (zero on the diagonal).
    """
    # Local import keeps the function usable regardless of which notebook
    # cells have already been run.
    from scipy.spatial.distance import pdist, squareform

    F_values = np.asarray(F.values, dtype=float)
    N = F_values.shape[0]
    # pdist/squareform need at least two rows to produce an (N, N) result.
    if N < 2:
        return np.zeros((N, N))
    # Compiled pairwise distances replace the former Python-level double loop.
    return squareform(pdist(F_values, metric='euclidean'))

In [None]:
from tqdm import tqdm

# Distance cut-offs used when counting neighbouring genes.
thresholds = [0.2, 0.3, 0.5, 1, 2]

# Empty result table; rows are added by the neighbour-counting cell.
result_columns = ['sample', 'celltype', 'gene', 'thresh', 'num_neighbors', 'length']
combined_df = pd.DataFrame(columns=result_columns)

# One gene-gene distance matrix per loaded sample, computed up front.
distance_matrices = {}
for sample_key in tqdm(adatas.keys()):
    distance_matrices[sample_key] = gene_gene_distance_matrix(adatas[sample_key].uns['X'])

In [None]:
from tqdm import tqdm
import pandas as pd
import numpy as np


#thresholds = [0.2, 0.3, 0.5, 1, 2]
# Initialize an empty dataframe to store the results
#combined_df = pd.DataFrame(columns=['sample', 'celltype', 'gene', 'thresh', 'num_neighbors', 'length'])

# Precompute the distance matrices for all cell types
#distance_matrices = {cell_type: gene_gene_distance_matrix(adatas[cell_type].uns['X']) for cell_type in tqdm(adatas.keys())}

# Per-(sample, threshold) result frames are collected here and concatenated
# once at the end; concatenating inside the loop re-copies the whole table on
# every pass.
neighbor_frames = []

# Iterate over each loaded sample
for sample in tqdm(adatas.keys()):
    # Get the gene list and distance matrix for the current sample
    gene_list = adatas[sample].uns['X'].index
    distance_matrix = distance_matrices[sample]

    # Keep only genes for which a (cube-root) length is available
    filtered_gene_list = [gene for gene in gene_list if gene in gene_length_cbrt]
    filtered_gene_indices = [gene_list.get_loc(gene) for gene in filtered_gene_list]

    # Both of these are independent of the threshold, so build them once per sample.
    filtered_distances = distance_matrix[filtered_gene_indices][:, filtered_gene_indices]
    filtered_lengths = [gene_length_cbrt[gene] for gene in filtered_gene_list]

    # Calculate the number of neighbors for each gene using different thresholds
    for threshold in thresholds:
        # Subtract 1 to exclude the gene itself (distance 0 to itself).
        num_neighbors = np.sum(filtered_distances < threshold, axis=1) - 1

        neighbor_frames.append(pd.DataFrame({
            'sample': sample,
            #'celltype': ct,
            'gene': filtered_gene_list,
            'thresh': [threshold] * len(filtered_gene_list),
            'num_neighbors': num_neighbors,
            'length': filtered_lengths
        }))

# Append everything to the combined dataframe in a single concatenation
combined_df = pd.concat([combined_df, *neighbor_frames], ignore_index=True)

# Print the combined dataframe
print(combined_df)

In [None]:
# Display the accumulated per-gene neighbour-count table.
combined_df

In [None]:
# Cut-off values that are matched against combined_df['thresh'] further down.
thresholds = [0.2, 0.3, 0.5, 1, 2]
#combined_df['celltype'] = combined_df['celltype'].replace({
#    'Cardiomyocyte': 'Cardiomyocyte', 'Chondrocyte': 'Chondrocyte', 'Choroid_plexus': 'Choroid plexus', 
#    'Dorsal_midbrain_neuron': 'Dorsal Mb neuron', 'Endothelial_cell': 'Endothelial cell', 
#    'Epithelial_cell': 'Epithelial cell', 'Erythrocyte': 'Erythrocyte', 'Facial_fibroblast': 'Facial fibroblast', 
#    'Fibroblast': 'Fibroblast', 'Forebrain_neuron': 'Fb neuron', 'Forebrain_radial_glia_cell': 'Fb radial glia cell', 
#    'Ganglion': 'Ganglion', 'Hepatocyte': 'Hepatocyte', 'Immune_cell': 'Immune cell', 'Keratinocyte': 'Keratinocyte', 
#    'Limb_fibroblast': 'Limb fibroblast', 'Macrophage': 'Macrophage', 'Meninges_cell': 'Meninges cell', 
#    'Mid-hindbrain_and_spinal_cord_neuron': 'Mb/Hb/SpC neuron', 'Myoblast': 'Myoblast', 
#    'Olfactory_epithelial_cell': 'Olfactory epithelial cell', 'Radial_glia_cell': 'Radial glia cell', 
#    'Smooth_muscle_cell': 'Smooth muscle cell', 'Spinal_cord_neuron': 'SpC neuron', 'Thalamus_neuron': 'Thalamus neuron'
#})
#
#palette = {
#    'Limb fibroblast': '#809693', 'Keratinocyte': '#A30059', 'Endothelial cell': '#006FA6', 
#    'Facial fibroblast': '#63FFAC', 'Meninges cell': '#4FC601', 'Mb/Hb/SpC neuron': '#3B5DFF', 
#    'SpC neuron': '#FFAA92', 'Myoblast': '#FF2F80', 'Choroid plexus': '#FF34FF', 'Cardiomyocyte': '#8B0000', 
#    'Chondrocyte': '#1CE6FF', 'Dorsal Mb neuron': '#FF92CC', 'Fibroblast': '#FF4A46', 'Thalamus neuron': '#B79762', 
#    'Erythrocyte': '#FFE4E1', 'Smooth muscle cell': '#00C2A0', 'Macrophage': '#1B4400', 'Fb neuron': '#004D43', 
#    'Hepatocyte': '#997D87', 'Ganglion': '#800080', 'Olfactory epithelial cell': '#008941', 'Immune cell': '#5A0007', 
#    'Radial glia cell': '#6B7900', 'Epithelial cell': '#0000A6', 'Fb radial glia cell': '#8FB0FF'
#}

import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import spearmanr

# For every threshold: Spearman correlation, per sample, between the stored
# gene length value and the number of neighbouring genes.
correlations = [
    combined_df[combined_df['thresh'] == cutoff]
    .groupby('sample')
    .apply(lambda grp: spearmanr(grp['length'], grp['num_neighbors'].astype(float))[0])
    for cutoff in thresholds
]

# Rows = thresholds, columns = samples
correlation_df = pd.concat(correlations, axis=1).T

# Only drop a column level when the columns really are multi-level
has_multilevel_columns = correlation_df.columns.nlevels > 1
if has_multilevel_columns:
    correlation_df.columns = correlation_df.columns.droplevel(0)

# Each threshold is halved to give the value plotted on the x axis
correlation_df['Threshold'] = np.asarray(thresholds) * 0.5

# Long format for seaborn: one row per (threshold, sample) pair
correlation_df_melted = correlation_df.melt(
    id_vars='Threshold',
    var_name='Celltype',
    value_name='Correlation',
)

# Correlation as a function of the radius, one line per sample
plt.figure(figsize=(4, 4))
sns.set_theme(style='ticks')
sns.lineplot(
    data=correlation_df_melted,
    x='Threshold',
    y='Correlation',
    hue='Celltype',
    marker='o',  #palette=palette,
    legend=True,
)
plt.axhline(0, color='black', linestyle='--')  # Highlight the y=0 axis
plt.xlabel('Radius (μm)')
plt.ylabel('Correlation')
#plt.ylabel('Correlation between Cube Root of Gene Length\nand Number of Proximal Genes')
plt.title('')
plt.grid(True)
plt.show()
