In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable

In [None]:
# KEGG pathway enrichment results for the symbiotic gene set.
KEGG_ENRICHMENT_CSV = "../../results/gene_enrichment_analysis/kegg_enrich_symbiotic.csv"
kegg_dataframe = pd.read_csv(KEGG_ENRICHMENT_CSV)

In [None]:
# Keep only pathways whose adjusted p-value passes the significance cutoff.
KEGG_PADJ_CUTOFF = 0.05
filtered_dataframe = kegg_dataframe.loc[kegg_dataframe["p.adjust"].le(KEGG_PADJ_CUTOFF)]

In [None]:
# Pull the KEGG fields used by the figure out as NumPy arrays.
kegg_id_labels, kegg_descriptions_labels, kegg_pvalue, kegg_counts = (
    filtered_dataframe[column].to_numpy()
    for column in ("ID", "Description", "p.adjust", "Count")
)
# BgRatio is a "k/N" string; its numerator k is used as the pathway's cluster size.
kegg_clustersize = filtered_dataframe["BgRatio"].apply(
    lambda bg_ratio: int(bg_ratio.partition("/")[0])
).to_numpy()

In [None]:
# --- GO enrichment: load, keep significant terms, drop redundant terms --------------
GO_ENRICHMENT_TABLE = "../../results/gene_enrichment_analysis/symbiotic_genes/symbiotic_proteins_enrichment.txt"
GO_FDR_CUTOFF = 0.005
# GO terms annotated to exactly the same set of study genes as another term are all
# dropped, except this one, which is deliberately kept (reason not recorded here).
# NOTE(review): when neither member of such a group is this term, *every* member is
# dropped (not just all but one) -- confirm that is intended.
GO_TERM_ALWAYS_KEPT = "GO:1901616"

go_dataframe = pd.read_table(GO_ENRICHMENT_TABLE)
filtered_go_dataframe = go_dataframe[go_dataframe["p_fdr_bh"] <= GO_FDR_CUTOFF]


def _study_gene_key(study_items):
    """Return an order-insensitive key for a comma-separated list of study genes.

    Whitespace around each item is stripped and duplicate items are preserved, so two
    terms share a key exactly when their sorted item lists are equal.
    """
    return tuple(sorted(item.strip() for item in study_items.split(",")))


# Group terms that have study hits by their gene set. This is one pass instead of
# comparing every pair of terms. Terms with study_count == 0 are skipped because their
# study_items field is presumably empty/missing.
terms_by_gene_set = {}
for study_items, go_term, count in zip(
    filtered_go_dataframe["study_items"],
    filtered_go_dataframe["# GO"],
    filtered_go_dataframe["study_count"],
):
    if count > 0:
        terms_by_gene_set.setdefault(_study_gene_key(study_items), []).append(go_term)

# Sorted [term_a, term_b] pairs of distinct GO terms that share an identical gene set.
filtered_terms = []
for terms in terms_by_gene_set.values():
    distinct_terms = sorted(set(terms))
    for i, first in enumerate(distinct_terms):
        for second in distinct_terms[i + 1:]:
            filtered_terms.append([first, second])

# Every term that appears in such a pair is redundant, apart from GO_TERM_ALWAYS_KEPT.
gos_to_remove = sorted({go for pair in filtered_terms for go in pair} - {GO_TERM_ALWAYS_KEPT})
_gos_to_remove_lookup = set(gos_to_remove)  # set for O(1) membership tests

# Terms to keep, in the table's original row order.
gos_to_keep = [
    go for go in filtered_go_dataframe["# GO"].to_list() if go not in _gos_to_remove_lookup
]

In [None]:
# Non-redundant GO terms, most significant (smallest FDR) first.
sorted_filtered_go_dataframe = (
    filtered_go_dataframe
    .loc[lambda frame: frame["# GO"].isin(gos_to_keep)]
    .sort_values("p_fdr_bh")
)

In [None]:
# Numerators of the "k/N" ratio strings: pop_count is the term's size in the background
# population, count is its number of study genes.
# (The column name "count" shadows DataFrame.count for attribute access; use df["count"].)
sorted_filtered_go_dataframe = sorted_filtered_go_dataframe.assign(
    pop_count=lambda frame: frame["ratio_in_pop"].apply(
        lambda ratio: int(ratio.partition("/")[0])
    ),
    count=lambda frame: frame["ratio_in_study"].apply(
        lambda ratio: int(ratio.partition("/")[0])
    ),
)

In [None]:
# Debug check: print every GO term in the unfiltered table whose study genes include
# QUERY_GENE.
QUERY_GENE = "gene:AEP_03479"
for name, item in zip(go_dataframe["name"], go_dataframe["study_items"]):
    # study_items is not a string (presumably NaN) for terms without study genes. Skip
    # those explicitly rather than hiding every possible error behind a bare `except`.
    if not isinstance(item, str):
        continue
    # Compare whole comma-separated IDs, the same way the redundancy check parses them.
    # This avoids reporting e.g. "gene:AEP_034790" as a hit for "gene:AEP_03479".
    if QUERY_GENE in {gene.strip() for gene in item.split(",")}:
        print(name)

In [None]:
# Drop very broad GO terms (more than MAX_GO_POP_COUNT genes in the background
# population) and terms with no study genes. "No study genes" is tested through the
# numeric `count` column, not by matching a ratio string such as "0/274", which would
# hard-code the study-set size.
MAX_GO_POP_COUNT = 350
sorted_filtered_go_dataframe = sorted_filtered_go_dataframe[
    (sorted_filtered_go_dataframe["pop_count"] <= MAX_GO_POP_COUNT)
    & (sorted_filtered_go_dataframe["count"] > 0)
]

In [None]:
# GO terms excluded from the figure by name.
EXCLUDED_GO_NAMES = [
    "nucleobase-containing compound metabolic process",
    "cellular biosynthetic process",
]
sorted_filtered_go_dataframe = sorted_filtered_go_dataframe[
    ~sorted_filtered_go_dataframe["name"].isin(EXCLUDED_GO_NAMES)
]

In [None]:
# Inspect the remaining GO terms. The index labels shown here are the ones hand-picked
# with `.loc[...]` in the next cell.
sorted_filtered_go_dataframe

In [None]:
# Hand-picked GO terms (index labels), listed in the order they should appear from the
# top of the figure downwards. The bar chart draws the first row at the bottom, so the
# selection is reversed.
# NOTE(review): these are row labels of the input table, so they silently depend on its
# row order. Selecting by "# GO" ID would be sturdier.
SELECTED_GO_ROW_LABELS = [24, 16, 0, 20, 22, 5, 4]
sorted_filtered_go_dataframe = sorted_filtered_go_dataframe.loc[SELECTED_GO_ROW_LABELS[::-1]]

In [None]:
# Pull the GO fields used by the figure out as NumPy arrays (same roles as the KEGG ones).
go_id_labels, go_descriptions_labels, go_pvalue, go_counts, go_clustersize = (
    sorted_filtered_go_dataframe[column].to_numpy()
    for column in ("# GO", "name", "p_fdr_bh", "study_count", "pop_count")
)

In [None]:
# Y tick labels are "<numeric ID> <description>". The KEGG ID loses everything up to its
# "ko" prefix, and the GO ID loses its "GO:" prefix.
ylabels_kegg = [
    kegg_id.split("ko")[1] + " " + description
    for kegg_id, description in zip(kegg_id_labels, kegg_descriptions_labels)
]
# A blank label reserves one empty row between the KEGG and GO sections of the plot.
ylabels_kegg += [" "]

ylabels_go = [
    go_id.split(":")[1] + " " + description
    for go_id, description in zip(go_id_labels, go_descriptions_labels)
]

In [None]:
# --- Combined KEGG + GO enrichment figure -------------------------------------------
# Rows are stacked bottom-to-top: KEGG terms, one blank spacer row (label " "), then GO
# terms. The spacer row gets placeholder zeros for count, cluster size and p-value.
n_kegg_terms = len(kegg_counts)  # y position of the spacer row / section divider

counts = np.concatenate([kegg_counts, [0], go_counts])
marker_size = counts
clustersize = np.concatenate([kegg_clustersize, [0], go_clustersize])
# The spacer row is 0/0 -> NaN, which scatter simply skips; silence the expected warning.
with np.errstate(divide="ignore", invalid="ignore"):
    ratio = counts / clustersize
labels = ylabels_kegg + ylabels_go

pvals = np.concatenate([kegg_pvalue, [0], go_pvalue])
# Colours encode p / max(p), a linear 0..1 scale mapped onto RdBu_r.
norm_p_values = pvals / max(pvals)
colors = plt.cm.RdBu_r(norm_p_values)

# Size legend: N_SIZE_LEGEND_ENTRIES counts taken at evenly spaced ranks of the sorted
# non-zero counts (entries repeat when there are few distinct counts).
N_SIZE_LEGEND_ENTRIES = 6
sorted_arr = np.sort(marker_size[marker_size != 0])
indices = np.linspace(0, len(sorted_arr) - 1, N_SIZE_LEGEND_ENTRIES).astype(int)
reduced_arr = sorted_arr[indices]
legend_y = np.arange(len(reduced_arr))

fig, (ax1, ax2, ax3) = plt.subplots(
    1, 3, figsize=(18, 10), sharey=False, gridspec_kw={'width_ratios': [2.5, 2.0, 1]}
)
ax1.sharey(ax2)

# Left panel: number of study genes per term.
ax1.barh(labels, width=counts, color=colors, edgecolor="black")
# Middle panel: count / cluster size. `colors` is already RGBA, so no `cmap` is passed
# (matplotlib would ignore it and emit a warning).
ax2.scatter(y=list(range(len(counts))), x=ratio, s=np.array(marker_size) * 10,
            c=colors, edgecolors="black")

# Dashed divider through the spacer row separating KEGG (below) from GO (above).
ax1.axhline(y=n_kegg_terms, color='red', linestyle='--', linewidth=2)
ax1.set_xticks([0, 5, 10, 20, 30, 40, 50, 60])
ax1.set_xticklabels([0, 5, 10, 20, 30, 40, 50, 60], fontsize=16)
ax1.set_yticks(range(len(labels)))
ax1.set_yticklabels(labels, fontsize=20, fontdict={'fontweight': "bold"})
ax1.set_xlabel("Count", fontsize=20, labelpad=20, fontdict={'fontweight': "bold"})
ax1.grid()

ax2.axhline(y=n_kegg_terms, color='red', linestyle='--', linewidth=2)
ax2.set_xlim(-0.1, 1.1)
ax2.set_xticks([0.0, 0.25, 0.5, 0.75, 1.0])
ax2.set_xticklabels([0.0, 0.25, 0.5, 0.75, 1.0], fontsize=14)
ax2.set_xlabel("Count/ClusterSize", fontsize=20, labelpad=20, fontdict={'fontweight': "bold"})
ax2.tick_params(axis='y', labelleft=False)
ax2.grid()

fig.subplots_adjust(left=0.2, wspace=0.05)

# Colour bar on the same 0..1 normalised scale used for `colors`. Its two ticks are
# labelled with the raw p-values at the ends of that scale.
cbar = fig.colorbar(
    ScalarMappable(norm=plt.Normalize(vmin=0, vmax=1), cmap='RdBu_r'),
    ax=[ax1, ax2], pad=0.005,
)
cbar.set_label('p-values', fontsize=20, labelpad=0)
cbar.set_ticks([min(norm_p_values), max(norm_p_values)])
cbar.set_ticklabels([f'{min(pvals):.2f}', f'{max(pvals):.2f}'])
cbar.ax.tick_params(labelsize=20)
cbar.ax.set_position([0.71, 0.12, 0.03, 0.35])

# Right panel, placed manually above the colour bar: marker-size legend.
ax3.set_position([0.71, 0.52, 0.05, 0.315])
ax3.scatter(x=np.full(len(reduced_arr), 0.5), y=legend_y, s=np.array(reduced_arr) * 10,
            color="grey", edgecolor="black")
ax3.set_yticks(legend_y)
ax3.set_yticklabels(reduced_arr, fontsize=20)
ax3.set_xlim(0, 1)
ax3.set_ylim(-1, len(reduced_arr))
ax3.tick_params(axis='y', labelright=True, labelleft=False)
ax3.tick_params(axis='x', labelbottom=False)
ax3.set_xticks([])
ax3.yaxis.tick_right()
ax3.set_title("Count", pad=10, fontsize=20)

# Section labels, anchored explicitly to ax3's data coordinates. The offsets are
# hand-tuned for the current layout and must be re-tuned if ax3's position/limits or
# the number of plotted terms change.
ax3.text(-21, 2, 'GO', fontsize=20, rotation=90, va='center', fontdict={'fontweight': "bold"})
ax3.text(-21, -6, 'KEGG', fontsize=20, rotation=90, va='center', fontdict={'fontweight': "bold"})

fig.savefig("../../results/gene_enrichment_analysis/symbiotic_genes/symbiotic_kegg_go_combined.jpg",
            dpi=400, bbox_inches='tight')
