In [None]:
import pandas as pd
import numpy as np
from os.path import isfile
import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable
import textwrap

In [None]:
# Input locations: per-comparison differential expression tables and the sample sheet
data_path = "../../results/rsem/"
samples_path = "../../data/rsem/samples.csv"
# Output location for every goatools input/output file written by this notebook
result_path = "../../results/gene_enrichment_analysis/goatools/"

# Sample sheet; its `treatment` column drives the pairwise comparisons below
samples_dataframe = pd.read_csv(filepath_or_buffer=samples_path)

# Preparation of input files for goatools

In [None]:
# Collect the result table of every ordered treatment pair (A_vs_B) that exists on disk.
# NOTE: despite its name, `dataframes` holds file paths, not DataFrames.
treatments = samples_dataframe.treatment.unique()
candidate_files = (
    data_path + first_treatment + "_vs_" + second_treatment + ".CSV"
    for first_treatment in treatments
    for second_treatment in treatments
    if first_treatment != second_treatment
)
dataframes = [candidate for candidate in candidate_files if isfile(candidate)]
dataframes

In [None]:
# Genome annotation; gene ids get a "gene:" prefix so they can be matched against
# the prefixed ids used further down (go_gene_ids / population file).
curvibacter_genes_df = pd.read_csv(filepath_or_buffer="../../results/curvibacter_genome_annotation.csv")
curvibacter_genes_df["gene_id"] = curvibacter_genes_df["gene_id"].map(lambda locus: "gene:" + locus)
curvibacter_genes_df.head()

In [None]:
# Writing population and associations files for goatools.
#   populations.txt  : one gene id per line (only genes with a GO annotation)
#   associations.txt : "<gene_id>\t<GO1>;<GO2>;..." (id2gos format)
#
# The GO string of each gene is looked up once through a dictionary instead of
# filtering the whole frame per gene (which was quadratic). For duplicated
# gene ids the first row wins, exactly as the previous `.values[0]` lookup did.
first_go_by_gene = (
    curvibacter_genes_df
    .drop_duplicates(subset="gene_id", keep="first")
    .set_index("gene_id")["GO"]
    .to_dict()
)

go_gene_ids = []
with open(result_path + "associations.txt", "w") as associations_file, \
        open(result_path + "populations.txt", "w") as populations_file:
    for gene_id in curvibacter_genes_df["gene_id"]:
        go_annotation = first_go_by_gene[gene_id]
        # Skip genes without GO terms: the explicit "unknown" marker, or a
        # missing value (NaN), which has no `.split` and used to crash here.
        if not isinstance(go_annotation, str) or go_annotation == "unknown":
            continue

        go_gene_ids.append(gene_id)
        populations_file.write(gene_id + "\n")
        # GO ids are comma separated in the annotation, semicolon separated for goatools
        associations_file.write(gene_id + "\t" + ";".join(go_annotation.split(",")) + "\n")

In [None]:
# Writing goatools study files: one per comparison and direction (down/up regulated).
# A study file is only written (and registered in go_files) when more than 5 genes remain.
go_files = []
for comparison_csv in dataframes:
    print("[+] Working with: {}".format(comparison_csv))
    log2folddf = pd.read_csv(comparison_csv)
    log2folddf.columns = ["gene_id","baseMean","log2FoldChange","lfcSE","stat","pvalue","padj"]

    # Significant genes only, restricted to genes that carry a GO annotation
    log2folddf = log2folddf[log2folddf["padj"] <= 0.05]
    has_go_annotation = log2folddf["gene_id"].isin(go_gene_ids)
    downregulated_genes = log2folddf[(log2folddf["log2FoldChange"] <= -1.0) & has_go_annotation]
    upregulated_genes = log2folddf[(log2folddf["log2FoldChange"] >= 1.0) & has_go_annotation]

    print("\t[*] Length of downregulated genes: {}".format(len(downregulated_genes)))
    print("\t[*] Length of upregulated genes: {}".format(len(upregulated_genes)))

    # Comparison name = file name without directory and ".CSV" suffix
    sample = comparison_csv.split("/")[-1].split(".CSV")[0]
    sample_up = result_path + sample + "_upregulated_genes_goatools.txt"
    sample_down = result_path + sample + "_downregulated_genes_goatools.txt"

    # Down before up, so go_files keeps the same order as before
    for regulated_genes, study_file in ((downregulated_genes, sample_down),
                                        (upregulated_genes, sample_up)):
        if len(regulated_genes) <= 5:
            continue
        go_files.append(study_file)
        with open(study_file, "w") as study_handle:
            study_handle.writelines(gene_id + "\n" for gene_id in regulated_genes["gene_id"])

In [None]:
# Run find_enrichment.py once for every study file registered in go_files and
# collect the paths of the expected output tables in goafiles.
goafiles = []
for samplefile in go_files:
    print("[+] Working with {}".format(samplefile))
    # Output table: <study file name without ".txt">_output.table inside result_path
    outfile = result_path + samplefile.split("/")[-1].split(".txt")[0] + "_output.table"
    
    # IPython shell escape: $samplefile and $outfile are interpolated from the Python
    # variables above. stdout is redirected to /dev/null, so the tool's regular output
    # is not shown in the notebook.
    !find_enrichment.py $samplefile ../../results/gene_enrichment_analysis/goatools/populations.txt ../../results/gene_enrichment_analysis/goatools/associations.txt --annofmt id2gos --alpha 0.05 --pval 0.05 --obo ../../results/gene_enrichment_analysis/goatools/go-basic.obo --method fdr_bh --outfile $outfile --obsolete replace > /dev/null
    
    # The path is recorded unconditionally; the plotting loop checks isfile() before reading it
    goafiles.append(outfile)
    print("[*] DONE")

In [None]:
def plot_goa(goafile_enriched:pd.DataFrame, savep:str, filename:str, figsize:tuple=(20, 16)):
    """Plot a goatools enrichment table and save it as ``savep + filename + ".jpg"``.

    Left panel: horizontal bars with the study count per GO term.
    Right panel: study count divided by the population count per GO term, with
    marker size proportional to the study ratio. Both panels are coloured by
    ``p_fdr_bh`` relative to the largest p-value in the table.

    Parameters
    ----------
    goafile_enriched : pd.DataFrame
        Enrichment table with the columns ``ratio_in_study`` and ``ratio_in_pop``
        (strings of the form "a/b"), ``name``, ``study_count`` and ``p_fdr_bh``.
        The frame is modified in place: the columns ``ratio_stud``, ``ratio_pop``
        and ``amount_in_pop`` are added (later cells read ``amount_in_pop``).
    savep : str
        Output directory, including the trailing path separator.
    filename : str
        Output file name without extension.
    figsize : tuple, optional
        Figure size in inches. Defaults to (20, 16), the size that was always used.

    Returns
    -------
    None
        Nothing is plotted or saved when the table has no rows.
    """
    print("[*] Producing plot for {}".format(filename))

    # Derived columns are written into the caller's frame on purpose (see docstring).
    goafile_enriched["ratio_stud"] = goafile_enriched.ratio_in_study.apply(lambda x: int(x.split("/")[0])/int(x.split("/")[1]))
    goafile_enriched["ratio_pop"] = goafile_enriched.ratio_in_pop.apply(lambda x: int(x.split("/")[0])/int(x.split("/")[1]))
    goafile_enriched["amount_in_pop"] = goafile_enriched.ratio_in_pop.apply(lambda x: int(x.split("/")[0]))

    # An empty table used to crash on max() below; there is nothing to draw.
    if goafile_enriched.empty:
        print("[!] No enriched GO terms for {} - nothing to plot".format(filename))
        return

    # Wrap long GO term names and remember their positions so that their tick
    # labels can be rendered with a smaller font.
    wrapped_positions = set()
    categories = []
    for position, term_name in enumerate(list(goafile_enriched["name"])):
        if len(term_name) >= 30:
            term_name = textwrap.fill(term_name, width=30)
            wrapped_positions.add(position)
        categories.append(term_name)

    values = list(goafile_enriched.study_count)
    scatter_values = np.array(goafile_enriched.study_count) / np.array(goafile_enriched.amount_in_pop)

    # Colours: p-values scaled to [0, 1] by the largest p-value of the table.
    p_values = np.asarray(goafile_enriched.p_fdr_bh, dtype=float)
    largest_p = p_values.max()
    if largest_p > 0:
        norm_p_values = p_values / largest_p
    else:
        # Largest p-value is not positive: avoid dividing by it
        norm_p_values = np.zeros_like(p_values)
    colors = plt.cm.RdBu_r(norm_p_values)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize, sharey=True)

    # Bar plot of the study counts
    ax1.barh(categories, values, color=colors, edgecolor="black")
    ax1.set_xlabel('Count', fontsize=15, labelpad=20)
    ax1.tick_params(axis='y', labelsize=15)
    ax1.tick_params(axis='x', labelsize=15)

    # Smaller font for the wrapped (multi-line) labels
    for position, label in enumerate(ax1.get_yticklabels()):
        if position in wrapped_positions:
            label.set_fontsize(10)

    # `colors` already holds explicit RGBA values, so no cmap is passed here
    # (matplotlib ignores it in that case and emits a warning).
    ax2.scatter(scatter_values, categories, c=colors,
                label='Gene Ratio (compared to Study)', s=list(goafile_enriched.ratio_stud*1000), edgecolor="black")
    ax2.set_xlabel('Count in Study / Count in Pop', fontsize=15, labelpad=20)
    ax2.tick_params(axis='x', labelsize=15)

    # First row of the table at the top (y axis is shared between both panels)
    ax1.invert_yaxis()
    plt.subplots_adjust(left=0.2, wspace=0.1)

    # Colour bar on the normalised scale, labelled with the real p-value range
    cbar = fig.colorbar(ScalarMappable(cmap='RdBu_r'), ax=[ax1, ax2], pad = 0.005)
    cbar.set_label('p-values',fontsize=15, labelpad=20)
    cbar.set_ticks([norm_p_values.min(), norm_p_values.max()])
    cbar.set_ticklabels([f'{p_values.min():.4f}', f'{p_values.max():.4f}'])
    cbar.ax.tick_params(labelsize=12)
    cbar.ax.set_position([0.85, 0.15, 0.03, 0.7])

    # Save and close this figure explicitly instead of relying on the current figure
    fig.savefig(savep + filename + ".jpg", dpi=400)
    plt.close(fig)
    print("[*] DONE")

In [None]:
# Plot every enrichment table that was actually produced; missing outputs are skipped.
for goafile in goafiles:
    if not isfile(goafile):
        continue
    table_name = goafile.split("/")[-1]
    goafigure = table_name.split("_goatools_output.table")[0]
    dataframe = pd.read_table(goafile)
    plot_goa(goafile_enriched=dataframe, savep=result_path, filename=goafigure)

# Goatools on symbiotic protein ids

In [None]:
# Translation table between identifiers (tab separated despite the .csv suffix);
# the cells below use its protein_id and old_locus_tag columns.
translation_table = pd.read_csv("../../data/curvibacter_annotation_files/translation_table_corrected.csv", sep="\t")
translation_table.head()

In [None]:
# Differential expression table of the liquid mono culture vs. metatranscriptome comparison.
# Alternative comparison that was used at some point:
#   '../../results/rsem/liquid_mono_culture_kiel_vs_hydra_mono_culture_kiel.csv'
log2Fold_df = pd.read_csv('../../results/rsem/liquid_mono_culture_orgint_vs_metatranscriptome.CSV').set_axis(
    ["old_locus_tag","baseMean","log2FoldChange","lfcSE","stat","pvalue","padj"],
    axis="columns",
)
log2Fold_df.head()

In [None]:
# Keep only significantly changed genes (adjusted p-value of at most 0.05)
is_significant = log2Fold_df["padj"].le(0.05)
log2Fold_df = log2Fold_df[is_significant]
print(len(log2Fold_df))

In [None]:
# Symbiotic protein ids (one id per line, no header) joined with the translation table;
# locus tags get the "gene:" prefix used by the population file.
symbiotic_protein_ids = pd.read_csv("../../data/symbiotic_wps/caep_symbiotic_wps.txt", header=None).set_axis(
    ["protein_id"], axis="columns"
)
merged_table = pd.merge(translation_table, symbiotic_protein_ids, on="protein_id")
merged_table["old_locus_tag"] = merged_table["old_locus_tag"].map(lambda locus_tag: "gene:" + locus_tag)

In [None]:
# Curvibacter-specific protein ids (one id per line, no header) joined with the
# translation table; locus tags get the "gene:" prefix used by the population file.
curvibacter_specific_protein_ids = pd.read_csv("../../data/symbiotic_wps/caep_specific_wps.txt", header=None).set_axis(
    ["protein_id"], axis="columns"
)
specific_merged_table = pd.merge(translation_table, curvibacter_specific_protein_ids, on="protein_id")
specific_merged_table["old_locus_tag"] = specific_merged_table["old_locus_tag"].map(lambda locus_tag: "gene:" + locus_tag)

In [None]:
# Gene ids of the goatools population file, one per line, surrounding whitespace removed
pop_ids = []
with open("../../results/gene_enrichment_analysis/goatools/populations.txt","r") as populationfile:
    pop_ids.extend(population_line.strip() for population_line in populationfile)

In [None]:
# Study gene list of the Curvibacter-specific proteins, restricted to genes that
# are part of the goatools population.
with open("../../data/symbiotic_wps/go_gene_id_aep_specific_goatools.txt", "w") as goatools:
    for gene in specific_merged_table[specific_merged_table["old_locus_tag"].isin(pop_ids)]["old_locus_tag"]:
        goatools.write(gene.strip()+"\n")

# The enrichment run on all symbiotic proteins reads go_gene_id_goatools.txt, but no
# cell of this notebook wrote that file, so a fresh checkout failed at that step.
# It is created here from merged_table in the same way as the list above.
# An existing file is left untouched because its provenance is not visible here.
# NOTE(review): drop the isfile() guard once it is confirmed that the file should
# always be regenerated from merged_table.
symbiotic_study_path = "../../data/symbiotic_wps/go_gene_id_goatools.txt"
if not isfile(symbiotic_study_path):
    with open(symbiotic_study_path, "w") as goatools:
        for gene in merged_table[merged_table["old_locus_tag"].isin(pop_ids)]["old_locus_tag"]:
            goatools.write(gene.strip()+"\n")

In [None]:
# Enrichment of all symbiotic proteins.
# Positional arguments as passed here: study list (go_gene_id_goatools.txt, must exist
# before this cell runs), population file, associations file.
# Result table: symbiotic_genes/symbiotic_proteins_enrichment.txt; stdout is discarded.
!find_enrichment.py ../../data/symbiotic_wps/go_gene_id_goatools.txt ../../results/gene_enrichment_analysis/goatools/populations.txt ../../results/gene_enrichment_analysis/goatools/associations.txt --annofmt id2gos --alpha 0.05 --pval 0.05 --obo ../../results/gene_enrichment_analysis/goatools/go-basic.obo --method fdr_bh --outfile ../../results/gene_enrichment_analysis/symbiotic_genes/symbiotic_proteins_enrichment.txt --obsolete replace > /dev/null

In [None]:
# Enrichment of the Curvibacter-specific proteins.
# Positional arguments as passed here: study list (go_gene_id_aep_specific_goatools.txt),
# population file, associations file.
# Result table: symbiotic_genes/caep_specific_proteins_enrichment.txt; stdout is discarded.
!find_enrichment.py ../../data/symbiotic_wps/go_gene_id_aep_specific_goatools.txt ../../results/gene_enrichment_analysis/goatools/populations.txt ../../results/gene_enrichment_analysis/goatools/associations.txt --annofmt id2gos --alpha 0.05 --pval 0.05 --obo ../../results/gene_enrichment_analysis/goatools/go-basic.obo --method fdr_bh --outfile ../../results/gene_enrichment_analysis/symbiotic_genes/caep_specific_proteins_enrichment.txt --obsolete replace > /dev/null

In [None]:
# Plot the enrichment result of all symbiotic proteins
dataframe = pd.read_csv("../../results/gene_enrichment_analysis/symbiotic_genes/symbiotic_proteins_enrichment.txt", sep="\t")
plot_goa(
    goafile_enriched=dataframe,
    savep="../../results/gene_enrichment_analysis/symbiotic_genes/",
    filename="symbiotic_proteins",
)

In [None]:
# Plot the enrichment result of the Curvibacter-specific proteins
dataframe = pd.read_csv("../../results/gene_enrichment_analysis/symbiotic_genes/caep_specific_proteins_enrichment.txt", sep="\t")
plot_goa(
    goafile_enriched=dataframe,
    savep="../../results/gene_enrichment_analysis/symbiotic_genes/",
    filename="caep_specific_proteins",
)

# Up/Down regulated symbiotic proteins

In [None]:
# Split the significantly changed symbiotic genes by direction of regulation.
# NOTE(review): the signs are the opposite of the up/down split used for the per-comparison
# study files (there: up = log2FoldChange >= 1). Presumably intentional, because the names
# here are meant "on host" for the liquid-vs-metatranscriptome contrast - confirm.
log2FoldChange_table = pd.merge(merged_table, log2Fold_df, on="old_locus_tag")
fold_change = log2FoldChange_table.log2FoldChange
upregulated = log2FoldChange_table[fold_change <= -1.0]
downregulated = log2FoldChange_table[fold_change >= 1.0]
unregulated = log2FoldChange_table[fold_change.abs() < 1.0]

In [None]:
# Study gene list: host-upregulated symbiotic genes that are part of the goatools population
with open("../../data/symbiotic_wps/go_gene_id_upregulated_on_host_goatools.txt", "w") as goatools:
    in_population = upregulated["old_locus_tag"].isin(pop_ids)
    goatools.writelines(gene.strip() + "\n" for gene in upregulated.loc[in_population, "old_locus_tag"])

In [None]:
# Enrichment of the host-upregulated symbiotic genes.
# Positional arguments as passed here: study list (go_gene_id_upregulated_on_host_goatools.txt),
# population file, associations file.
# Result table: symbiotic_genes/symbiotic_proteins_upregulated_on_host_enrichment.txt; stdout is discarded.
!find_enrichment.py ../../data/symbiotic_wps/go_gene_id_upregulated_on_host_goatools.txt ../../results/gene_enrichment_analysis/goatools/populations.txt ../../results/gene_enrichment_analysis/goatools/associations.txt --annofmt id2gos --alpha 0.05 --pval 0.05 --obo ../../results/gene_enrichment_analysis/goatools/go-basic.obo --method fdr_bh --outfile ../../results/gene_enrichment_analysis/symbiotic_genes/symbiotic_proteins_upregulated_on_host_enrichment.txt --obsolete replace > /dev/null

In [None]:
# Plot the enrichment result of the host-upregulated symbiotic genes.
# `dataframe` is reused by the following cells (study_items, amount_in_pop, ...).
dataframe = pd.read_csv("../../results/gene_enrichment_analysis/symbiotic_genes/symbiotic_proteins_upregulated_on_host_enrichment.txt", sep="\t")
plot_goa(
    goafile_enriched=dataframe,
    savep="../../results/gene_enrichment_analysis/symbiotic_genes/",
    filename="symbiotic_upregulated_proteins",
)

In [None]:
# Unique gene ids listed in the study_items column of the enrichment table
# (comma separated per GO term), in order of first appearance.
significant_proteins = list(dict.fromkeys(
    protein.strip()
    for item in dataframe.study_items
    for protein in item.split(",")
))
# Fold-change rows of exactly those genes
is_enriched_gene = log2FoldChange_table["old_locus_tag"].isin(significant_proteins)
enriched_upregulated = log2FoldChange_table[is_enriched_gene]

In [None]:
# Genome annotation with the first column as index; wp_number is renamed to protein_id
# so that it can be merged with the fold-change table.
annotations = (
    pd.read_csv("../../results/curvibacter_genome_annotation.csv", index_col=0)
    .rename(columns={"wp_number":"protein_id"})
)
annotations.head()

In [None]:
# Annotation of the enriched, host-upregulated genes, ordered by fold change
(
    annotations
    .merge(enriched_upregulated, on="protein_id")
    [["protein_id","description", "log2FoldChange", "GO","GO_process"]]
    .sort_values(by="log2FoldChange")
)

In [None]:
# Per-GO-term inputs for the combined figure, taken from the enrichment table in `dataframe`.
# `amount_in_pop` is not part of the file itself: it is added to the frame by plot_goa().
counts = list(dataframe["study_count"])
pvals = list(dataframe["p_fdr_bh"])
counts_pop = list(dataframe["amount_in_pop"])
labels = [go + " " + label for go, label in zip(dataframe["# GO"], dataframe["name"])]
# Share of the term's population genes that are found in the study set
ratio = [count / count_pop for count, count_pop in zip(counts, counts_pop)]

# Marker sizes follow the ratio; colours follow the p-value relative to the largest one
marker_size = ratio
pcolors = pvals
norm_p_values = np.array(pcolors) / max(pcolors)
colors = plt.cm.RdBu_r(norm_p_values)

# Descending ratios (and the corresponding marker areas) for the size legend
marker_legend = sorted(marker_size, reverse=True)
sorted_markersize = sorted(np.array(marker_legend)*10000, reverse=True)
sorted_markers = sorted(marker_legend, reverse=True)

In [None]:
# Combined figure of the enriched GO terms:
#   ax1 - bar plot of the study counts
#   ax2 - count / population size per term, marker area proportional to that ratio
#   ax3 - size legend for the markers of ax2
# Positions are derived from the number of terms instead of being hard-coded for five.
term_count = len(labels)
term_positions = list(range(term_count))

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 10), sharey=False, gridspec_kw={'width_ratios': [2.5, 2.0, 1]})

ax1.sharey(ax2)

ax1.barh(labels, width=counts, color=colors, edgecolor="black")
# `colors` already holds explicit RGBA values, so no cmap is passed (it would be ignored with a warning)
ax2.scatter(y=term_positions, x=ratio, s=np.array(marker_size)*10000, c=colors, edgecolors="black")

ax1.set_xlim(0,25)
ax1.set_xticks([0,5,10,15,20])
ax1.set_xticklabels([0,5,10,15,20], fontsize=20)
ax1.set_yticks(term_positions)
ax1.set_yticklabels(labels, fontsize=20, fontdict={'fontweight':"bold"})
ax1.set_xlabel("Count", fontsize=20, labelpad=20,  fontdict={'fontweight':"bold"})
ax1.grid()

ax2.set_xlim(-0.2,1.2)
ax2.set_xticks([0.0,0.25,0.5,0.75,1.0])
ax2.set_xticklabels([0.0,0.25,0.5,0.75,1.0], fontsize=20)
ax2.set_xlabel("Count/ClusterSize", fontsize=20, labelpad=20,  fontdict={'fontweight':"bold"})
ax2.tick_params(axis='y', labelleft=False)
ax2.grid()

plt.subplots_adjust(left=0.2, wspace=0.05)

# Colour bar on the normalised [0, 1] scale, labelled with the real p-value range
cbar = fig.colorbar(ScalarMappable(cmap='RdBu_r'), ax=[ax1, ax2], pad = 0.005)
cbar.set_label('p-values',fontsize=20, labelpad=0)
cbar.set_ticks([0, 1])
cbar.set_ticklabels([f'{0:.2f}', f'{max(pvals):.2f}'])
cbar.ax.tick_params(labelsize=20)

cbar.ax.set_position([0.71, 0.12, 0.03, 0.35])

# Size legend: every marker is labelled with the population count of the SAME term.
# Sizes and labels are sorted together (largest ratio at the bottom); previously the
# sizes were sorted but the labels were not, which mislabelled the markers whenever
# the table was not already ordered by ratio.
legend_entries = sorted(zip(ratio, counts_pop), reverse=True)
legend_sizes = [entry_ratio * 10000 for entry_ratio, _ in legend_entries]
legend_labels = [entry_pop for _, entry_pop in legend_entries]
legend_positions = ([0.5] + list(range(2, term_count + 1)))[:term_count]

ax3.set_position([0.71, 0.52, 0.05, 0.315])
ax3.scatter(x=[0.5] * term_count, y=legend_positions, s=legend_sizes, color="grey", edgecolor="black")
ax3.set_yticks(legend_positions)
ax3.set_yticklabels(legend_labels, fontsize=20)
ax3.set_xlim(0,1)
ax3.set_ylim(-1, term_count + 1)
ax3.tick_params(axis='y', labelright=True, labelleft=False)
ax3.tick_params(axis='x', labelbottom=False)
ax3.set_xticks([])
ax3.yaxis.tick_right()
ax3.set_title("Count", pad=10, fontsize=20)
plt.savefig("../../results/gene_enrichment_analysis/up_symbiotic_combined.jpg", dpi=400, bbox_inches='tight')

In [None]:
# Hand-entered summary of the enrichment result, grouped into three broader categories.
# NOTE(review): these numbers are not computed in this notebook; the averaged entries
# look like three GO terms merged into one category - confirm the provenance.
new_labels = ['GO:0008643 carbohydrate transport','GO:0051179 localization and transport','GO:0008152 metabolic process']
new_counts = [5,15,5]
new_counts_pop = [20,int((258+261+264)/3),818]
new_ratio = [0.25, (0.05813953488372093 +0.05747126436781609+0.056818181818181816)/3,0.006112469437652812]
new_pvals = [0.0441145001987955,0.0441765749751197,0.0441145001987955]

# Marker sizes follow the ratio; colours follow the p-value relative to the largest one
new_marker_size = new_ratio
new_pcolors = new_pvals
new_norm_p_values = np.array(new_pcolors) / max(new_pcolors)
new_colors = plt.cm.RdBu_r(new_norm_p_values)

# Descending ratios (and the corresponding marker areas) for the size legend
new_marker_legend = sorted(new_marker_size, reverse=True)
new_sorted_markersize = sorted(np.array(new_marker_legend)*10000, reverse=True)
new_sorted_markers = sorted(new_marker_legend, reverse=True)

In [None]:
# Combined figure of the summarised GO categories:
#   ax1 - bar plot of the study counts
#   ax2 - count / population size per category, marker area proportional to that ratio
#   ax3 - size legend for the markers of ax2
# Positions are derived from the number of categories instead of being hard-coded for three.
new_term_count = len(new_labels)
new_term_positions = list(range(new_term_count))

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 10), sharey=False, gridspec_kw={'width_ratios': [2.5, 2.0, 1]})

ax1.sharey(ax2)

ax1.barh(new_labels, width=new_counts, color=new_colors, edgecolor="black")
# `new_colors` already holds explicit RGBA values, so no cmap is passed (it would be ignored with a warning)
ax2.scatter(y=new_term_positions, x=new_ratio, s=np.array(new_marker_size)*10000, c=new_colors, edgecolors="black")

ax1.set_xlim(0,25)
ax1.set_xticks([0,5,10,15,20,25])
ax1.set_xticklabels([0,5,10,15,20,25], fontsize=20)
ax1.set_yticks(new_term_positions)
ax1.set_yticklabels(new_labels, fontsize=20, fontdict={'fontweight':"bold"})
ax1.set_xlabel("Count", fontsize=20, labelpad=20,  fontdict={'fontweight':"bold"})
ax1.grid()

ax2.set_xlim(-0.2,1.2)
ax2.set_xticks([0.0,0.25,0.5,0.75,1.0])
ax2.set_xticklabels([0.0,0.25,0.5,0.75,1.0], fontsize=20)
ax2.set_xlabel("Count/ClusterSize", fontsize=20, labelpad=20,  fontdict={'fontweight':"bold"})
ax2.tick_params(axis='y', labelleft=False)
ax2.grid()

plt.subplots_adjust(left=0.2, wspace=0.05)

# Colour bar on the normalised [0, 1] scale, labelled with the real p-value range
cbar = fig.colorbar(ScalarMappable(cmap='RdBu_r'), ax=[ax1, ax2], pad = 0.005)
cbar.set_label('p-values',fontsize=20, labelpad=0)
cbar.set_ticks([0, 1])
cbar.set_ticklabels([f'{0:.2f}', f'{max(new_pvals):.2f}'])
cbar.ax.tick_params(labelsize=20)

cbar.ax.set_position([0.71, 0.12, 0.03, 0.35])

# Size legend: every marker is labelled with the population count of the SAME category.
# Sizes and labels are sorted together (largest ratio at the bottom) so that they stay
# paired even if the hand-entered lists are not ordered by ratio.
new_legend_entries = sorted(zip(new_ratio, new_counts_pop), reverse=True)
new_legend_sizes = [entry_ratio * 10000 for entry_ratio, _ in new_legend_entries]
new_legend_labels = [entry_pop for _, entry_pop in new_legend_entries]
new_legend_positions = ([0.5] + list(range(2, new_term_count + 1)))[:new_term_count]

ax3.set_position([0.71, 0.52, 0.05, 0.315])
ax3.scatter(x=[0.5] * new_term_count, y=new_legend_positions, s=new_legend_sizes, color="grey", edgecolor="black")
ax3.set_yticks(new_legend_positions)
ax3.set_yticklabels(new_legend_labels, fontsize=20)
ax3.set_xlim(0,1)
ax3.set_ylim(-1, new_term_count + 1)
ax3.tick_params(axis='y', labelright=True, labelleft=False)
ax3.tick_params(axis='x', labelbottom=False)
ax3.set_xticks([])
ax3.yaxis.tick_right()
ax3.set_title("Count", pad=10, fontsize=20)
plt.savefig("../../results/gene_enrichment_analysis/up_symbiotic_comprehensive_combined.jpg", dpi=400, bbox_inches='tight')

In [None]:
# Study gene list: host-downregulated symbiotic genes that are part of the goatools population
with open("../../data/symbiotic_wps/go_gene_id_downregulated_on_host_goatools.txt", "w") as goatools:
    in_population = downregulated["old_locus_tag"].isin(pop_ids)
    goatools.writelines(gene.strip() + "\n" for gene in downregulated.loc[in_population, "old_locus_tag"])

In [None]:
# Enrichment of the host-downregulated symbiotic genes.
# Positional arguments as passed here: study list (go_gene_id_downregulated_on_host_goatools.txt),
# population file, associations file.
# Result table: symbiotic_genes/symbiotic_proteins_downregulated_on_host_enrichment.txt.
# Unlike the other enrichment runs, stdout is not redirected here.
!find_enrichment.py ../../data/symbiotic_wps/go_gene_id_downregulated_on_host_goatools.txt ../../results/gene_enrichment_analysis/goatools/populations.txt ../../results/gene_enrichment_analysis/goatools/associations.txt --annofmt id2gos --alpha 0.05 --pval 0.05 --obo ../../results/gene_enrichment_analysis/goatools/go-basic.obo --method fdr_bh --outfile ../../results/gene_enrichment_analysis/symbiotic_genes/symbiotic_proteins_downregulated_on_host_enrichment.txt --obsolete replace

In [None]:
#dataframe = pd.read_table("../../results/gene_enrichment_analysis/symbiotic_genes/symbiotic_proteins_downregulated_on_host_enrichment.txt")
#plot_goa(dataframe, "../../results/gene_enrichment_analysis/symbiotic_genes/", "symbiotic_downregulated_proteins")

In [None]:
# Study gene list: symbiotic genes with |log2FoldChange| < 1 that are part of the goatools population
with open("../../data/symbiotic_wps/go_gene_id_unregulated_on_host_goatools.txt", "w") as goatools:
    in_population = unregulated["old_locus_tag"].isin(pop_ids)
    goatools.writelines(gene.strip() + "\n" for gene in unregulated.loc[in_population, "old_locus_tag"])

In [None]:
# Enrichment of the symbiotic genes without a strong expression change.
# Positional arguments as passed here: study list (go_gene_id_unregulated_on_host_goatools.txt),
# population file, associations file.
# Result table: symbiotic_genes/symbiotic_proteins_unregulated_on_host_enrichment.txt.
# Unlike the other enrichment runs, stdout is not redirected here.
!find_enrichment.py ../../data/symbiotic_wps/go_gene_id_unregulated_on_host_goatools.txt ../../results/gene_enrichment_analysis/goatools/populations.txt ../../results/gene_enrichment_analysis/goatools/associations.txt --annofmt id2gos --alpha 0.05 --pval 0.05 --obo ../../results/gene_enrichment_analysis/goatools/go-basic.obo --method fdr_bh --outfile ../../results/gene_enrichment_analysis/symbiotic_genes/symbiotic_proteins_unregulated_on_host_enrichment.txt --obsolete replace

In [None]:
#dataframe = pd.read_table("../../results/gene_enrichment_analysis/symbiotic_genes/symbiotic_proteins_unregulated_on_host_enrichment.txt")
#plot_goa(dataframe, "../../results/gene_enrichment_analysis/symbiotic_genes/", "symbiotic_unregulated_proteins")

In [None]:
# Export the symbiotic proteins joined with the translation table (row index included)
merged_table.to_csv(
    "../../results/gene_enrichment_analysis/symbiotic_genes/merged_translation_table.csv",
    index=True,
)