In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.cm import ScalarMappable


In [None]:
# KEGG enrichment results for the up- and down-regulated gene sets. Columns used
# later in the notebook: ID, Description, GeneRatio, BgRatio, p.adjust
# (presumably clusterProfiler-style output — confirm against the R script).
enrichment_dir = "../../results/gene_enrichment_analysis"
upregulated = pd.read_csv(f"{enrichment_dir}/kegg_enrich_upregulated.csv")
downregulated = pd.read_csv(f"{enrichment_dir}/kegg_enrich_downregulated.csv")

In [None]:
# Peek at the first rows of the raw up-regulated enrichment table.
upregulated.head(n=5)

In [None]:
# Keep only terms whose adjusted p-value (`p.adjust` column) is at or below the cutoff.
padj_cutoff = 0.05
sig_upregulated = upregulated.loc[upregulated["p.adjust"].le(padj_cutoff)]
sig_downregulated = downregulated.loc[downregulated["p.adjust"].le(padj_cutoff)]

In [None]:
def _ratio_numerators(ratio_strings):
    """Return the integer before the "/" of every "k/n" string, as a list of ints."""
    return [int(text.split("/")[0]) for text in ratio_strings]


# --- Down-regulated terms ----------------------------------------------------
# counts_*: GeneRatio numerator (hits); counts_pop_*: BgRatio numerator
# (presumably the term's size in the background set — plotted as "ClusterSize").
counts_downregulated = _ratio_numerators(sig_downregulated["GeneRatio"])
counts_pop_downregulated = _ratio_numerators(sig_downregulated["BgRatio"])
ratio_downregulated = [hits / pop for hits, pop in zip(counts_downregulated, counts_pop_downregulated)]
labels_downregulated = sig_downregulated["Description"].tolist()
padjust_downregulated = sig_downregulated["p.adjust"].tolist()
id_downregulated = sig_downregulated["ID"].tolist()

# Flip the sign so down-regulated bars point left of zero in the bar chart.
# The ratio above is deliberately computed from the positive counts.
counts_downregulated = [-hits for hits in counts_downregulated]

# --- Up-regulated terms ------------------------------------------------------
counts_upregulated = _ratio_numerators(sig_upregulated["GeneRatio"])
counts_pop_upregulated = _ratio_numerators(sig_upregulated["BgRatio"])
ratio_upregulated = [hits / pop for hits, pop in zip(counts_upregulated, counts_pop_upregulated)]
labels_upregulated = sig_upregulated["Description"].tolist()
padjust_upregulated = sig_upregulated["p.adjust"].tolist()
id_upregulated = sig_upregulated["ID"].tolist()

# --- Combined plotting inputs, down-regulated rows first ---------------------
ids = id_downregulated + id_upregulated
counts = counts_downregulated + counts_upregulated
ratio = ratio_downregulated + ratio_upregulated
descriptions = labels_downregulated + labels_upregulated
# Row label = term ID with every "ko" substring removed, a space, then the description.
labels = [term_id.replace("ko", "") + " " + desc for term_id, desc in zip(ids, descriptions)]
pvals = padjust_downregulated + padjust_upregulated
marker_size = counts_pop_downregulated + counts_pop_upregulated


In [None]:
# Colour each term by its adjusted p-value: divide by the largest p-value so the
# values fall in [0, 1], then look them up in the reversed RdBu colormap.
pcolors = pvals
largest_p = max(pcolors)
norm_p_values = np.divide(np.array(pcolors), largest_p)
colors = plt.cm.RdBu_r(norm_p_values)

In [None]:
# Marker-size legend entries: every ClusterSize value, largest first.
marker_legend = sorted(marker_size, reverse=True)

In [None]:
# Descending legend tick labels, and the matching marker areas (label x 10, the
# same scaling the first figure applies to its dot-plot markers).
sorted_markers = sorted(marker_legend, reverse=True)
sorted_markersize = sorted(np.multiply(marker_legend, 10), reverse=True)

In [None]:
# Combined up/down KEGG enrichment figure (ClusterSize variant):
#   ax1 - diverging bar chart of hit counts (down-regulated negative/left,
#         up-regulated positive/right), coloured by adjusted p-value;
#   ax2 - dot plot of Count/ClusterSize, marker area = ClusterSize * 10;
#   ax3 - size legend for ClusterSize; plus a colour bar for the p-values.
# Row positions come from the number of terms instead of being hard-coded to
# six, so the cell works for any number of significant terms.
n_terms = len(labels)
y_pos = np.arange(n_terms)

# Symmetric bar-chart range: +/-50 as before, widened in steps of 10 so the
# longest bar always keeps at least 10 units of headroom. Ticks every 10.
max_count = max((abs(c) for c in counts), default=0)
count_lim = max(50, 10 * int(np.ceil(max_count / 10)) + 10)
count_ticks = np.arange(-(count_lim - 10), count_lim - 9, 10)

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 10), sharey=False, gridspec_kw={'width_ratios': [2.5, 2.0, 1]})
ax1.sharey(ax2)

# Numeric row positions (not the label strings) keep both panels on one numeric
# y-axis and stop duplicate labels from being merged into a single category.
ax1.barh(y_pos, width=counts, color=colors, edgecolor="black")
# `colors` is already RGBA, so no `cmap` is passed (it would be ignored with a warning).
ax2.scatter(y=y_pos, x=ratio, s=np.array(marker_size)*10, c=colors, edgecolors="black")

ax1.set_xlim(-count_lim, count_lim)
ax1.set_xticks(count_ticks)
# Show magnitudes on both sides of zero; the sign only encodes direction.
ax1.set_xticklabels([int(abs(t)) for t in count_ticks], fontsize=20)
ax1.set_yticks(y_pos)
ax1.set_yticklabels(labels, fontsize=20, fontdict={'fontweight': "bold"})
ax1.set_xlabel("Count", fontsize=20, labelpad=20, fontdict={'fontweight': "bold"})
ax1.grid()

ax2.set_xlim(-0.2, 1.2)
ax2.set_xticks([0.0, 0.25, 0.5, 0.75, 1.0])
ax2.set_xticklabels([0.0, 0.25, 0.5, 0.75, 1.0], fontsize=20)
ax2.set_xlabel("Count/ClusterSize", fontsize=20, labelpad=20, fontdict={'fontweight': "bold"})
ax2.tick_params(axis='y', labelleft=False)
ax2.grid()

plt.subplots_adjust(left=0.2, wspace=0.05)

# The bar colours were produced from p / max(p), i.e. on a 0..1 scale, so the
# colour bar uses the same explicit 0..1 normalisation.
cbar = fig.colorbar(ScalarMappable(norm=plt.Normalize(0, 1), cmap='RdBu_r'), ax=[ax1, ax2], pad=0.005)
cbar.set_label('adjusted p-value', fontsize=20, labelpad=0)
cbar.set_ticks([min(norm_p_values), max(norm_p_values)])
# Adjusted p-values are often far below 0.01, which "%.2f" would print as 0.00;
# scientific notation keeps them readable.
cbar.set_ticklabels([f'{min(pvals):.2e}', f'{max(pvals):.2e}'])
cbar.ax.tick_params(labelsize=20)
cbar.ax.set_position([0.71, 0.12, 0.03, 0.35])

# Size legend: one grey marker per ClusterSize value, largest at the bottom.
legend_pos = np.arange(len(sorted_markers))
ax3.set_position([0.71, 0.52, 0.05, 0.315])
ax3.scatter(x=[0.5] * len(legend_pos), y=legend_pos, s=sorted_markersize, color="grey", edgecolor="black")
ax3.set_yticks(legend_pos)
ax3.set_yticklabels(sorted_markers, fontsize=20)
ax3.set_xlim(0, 1)
ax3.set_ylim(-1, len(legend_pos))
ax3.tick_params(axis='y', labelright=True, labelleft=False)
ax3.tick_params(axis='x', labelbottom=False)
ax3.set_xticks([])
ax3.yaxis.tick_right()
ax3.set_title("ClusterSize", pad=10, fontsize=20)

plt.savefig("../../results/gene_enrichment_analysis/up_down_combined.jpg", dpi=400, bbox_inches='tight')
#plt.tight_layout()

In [None]:
# Combined up/down KEGG enrichment figure (Count variant):
#   ax1 - diverging bar chart of hit counts (down-regulated negative/left,
#         up-regulated positive/right), coloured by adjusted p-value;
#   ax2 - dot plot of Count/ClusterSize, marker area = |count| * 40;
#   ax3 - size legend for the counts; plus a colour bar for the p-values.
# Row positions come from the number of terms instead of being hard-coded to
# six, so the cell works for any number of significant terms.
n_terms = len(labels)
y_pos = np.arange(n_terms)
abs_counts = np.abs(np.array(counts))

# Symmetric bar-chart range: +/-50 as before, widened in steps of 10 so the
# longest bar always keeps at least 10 units of headroom. Ticks every 10.
max_count = abs_counts.max() if abs_counts.size else 0
count_lim = max(50, 10 * int(np.ceil(max_count / 10)) + 10)
count_ticks = np.arange(-(count_lim - 10), count_lim - 9, 10)

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 10), sharey=False, gridspec_kw={'width_ratios': [2.5, 2.0, 1]})
ax1.sharey(ax2)

# Numeric row positions (not the label strings) keep both panels on one numeric
# y-axis and stop duplicate labels from being merged into a single category.
ax1.barh(y_pos, width=counts, color=colors, edgecolor="black")
# `colors` is already RGBA, so no `cmap` is passed (it would be ignored with a warning).
ax2.scatter(y=y_pos, x=ratio, s=abs_counts*40, c=colors, edgecolors="black")

ax1.set_xlim(-count_lim, count_lim)
ax1.set_xticks(count_ticks)
# Show magnitudes on both sides of zero; the sign only encodes direction.
ax1.set_xticklabels([int(abs(t)) for t in count_ticks], fontsize=20)
ax1.set_yticks(y_pos)
ax1.set_yticklabels(labels, fontsize=20, fontdict={'fontweight': "bold"})
ax1.set_xlabel("Count", fontsize=20, labelpad=20, fontdict={'fontweight': "bold"})
ax1.grid()

ax2.set_xlim(-0.2, 1.2)
ax2.set_xticks([0.0, 0.25, 0.5, 0.75, 1.0])
ax2.set_xticklabels([0.0, 0.25, 0.5, 0.75, 1.0], fontsize=20)
ax2.set_xlabel("Count/ClusterSize", fontsize=20, labelpad=20, fontdict={'fontweight': "bold"})
ax2.tick_params(axis='y', labelleft=False)
ax2.grid()

plt.subplots_adjust(left=0.2, wspace=0.05)

# The bar colours were produced from p / max(p), i.e. on a 0..1 scale, so the
# colour bar uses the same explicit 0..1 normalisation.
cbar = fig.colorbar(ScalarMappable(norm=plt.Normalize(0, 1), cmap='RdBu_r'), ax=[ax1, ax2], pad=0.005)
cbar.set_label('adjusted p-value', fontsize=20, labelpad=0)
cbar.set_ticks([min(norm_p_values), max(norm_p_values)])
# Adjusted p-values are often far below 0.01, which "%.2f" would print as 0.00;
# scientific notation keeps them readable.
cbar.set_ticklabels([f'{min(pvals):.2e}', f'{max(pvals):.2e}'])
cbar.ax.tick_params(labelsize=20)
cbar.ax.set_position([0.71, 0.12, 0.03, 0.3])

# Size legend: one grey marker per term count, largest at the bottom, using the
# same |count| * 40 area scaling as the dot plot.
legend_counts = sorted(abs_counts, reverse=True)
legend_pos = np.arange(len(legend_counts))
ax3.set_position([0.71, 0.45, 0.05, 0.4])
ax3.scatter(x=[0.5] * len(legend_pos), y=legend_pos, s=[c * 40 for c in legend_counts], color="grey", edgecolor="black")
ax3.set_yticks(legend_pos)
ax3.set_yticklabels(legend_counts, fontsize=20)
ax3.set_xlim(0, 1)
ax3.set_ylim(-1, len(legend_pos))
ax3.tick_params(axis='y', labelright=True, labelleft=False)
ax3.tick_params(axis='x', labelbottom=False)
ax3.set_xticks([])
ax3.yaxis.tick_right()
ax3.set_title("Count", pad=10, fontsize=20)

plt.savefig("../../results/gene_enrichment_analysis/up_down_combined_new.jpg", dpi=400, bbox_inches='tight')
#plt.tight_layout()