In [None]:
import os
import pickle
from collections import defaultdict

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import sklearn.metrics as metrics
from matplotlib.pyplot import figure

import helper as helper

In [None]:
# Load the curated (KEGG term, GO term) matches used as ground truth.
# Each non-blank line must contain at least two tab-separated fields; only the first two are used.
# Alternative dataset: data/matched_pairs/matched_list_bto_doid.txt
matched_pairs_path = 'data/matched_pairs/matched_pair_kegg_go.txt'
matched_pairs = []
with open(matched_pairs_path, 'r', encoding='utf-8') as f:
    for line_no, line in enumerate(f, start=1):
        stripped = line.strip()
        if not stripped:
            # Blank lines (e.g. a trailing empty line at EOF) carry no pair; without this
            # check `tokens[1]` below would raise IndexError.
            continue
        tokens = stripped.split('\t')
        if len(tokens) < 2:
            raise ValueError(
                f"{matched_pairs_path}, line {line_no}: expected 2 tab-separated fields, "
                f"got {len(tokens)}: {stripped!r}"
            )
        matched_pairs.append((tokens[0], tokens[1]))

In [None]:
# Distinct terms on each side of the matched pairs, alphabetically ordered.
# KEGG terms become matrix rows and GO terms become matrix columns in the next cell.
selected_kegg_terms = sorted({kegg_term for kegg_term, _go in matched_pairs})
selected_go_terms = sorted({go_term for _kegg, go_term in matched_pairs})
print(len(selected_kegg_terms), len(selected_go_terms))

In [None]:
# Binary ground-truth matrix: labels[i, j] == 1 iff
# (selected_kegg_terms[i], selected_go_terms[j]) appears in matched_pairs.
# Mapping each term to its row/column position lets us visit every pair once,
# instead of scanning the whole pair list for every (KEGG, GO) cell.
kegg_index = {term: i for i, term in enumerate(selected_kegg_terms)}
go_index = {term: j for j, term in enumerate(selected_go_terms)}

labels = np.zeros((len(selected_kegg_terms), len(selected_go_terms)))
for kegg_term, go_term in matched_pairs:
    i = kegg_index.get(kegg_term)
    j = go_index.get(go_term)
    # A pair whose terms were not selected has no row/column in the matrix; skip it.
    if i is not None and j is not None:
        labels[i, j] = 1

In [None]:
# Load ANDES outputs restricted to the selected terms, then score them against `labels`.
andes_data, embed_names = helper.load_kegg_go_data(
    '/results/andes_out',
    selected_kegg_terms,
    selected_go_terms,
)
andes_raw_result = np.array(helper.generate_kegg_go_result(andes_data, labels))

# After the transpose, columns line up with `embed_names` (the next cell indexes
# embed_names by column position).
# NOTE(review): 50 looks like a fixed rank ceiling used to flip ranks so larger is
# better — TODO confirm against helper.generate_kegg_go_result.
andes_data_results = pd.DataFrame(50 - andes_raw_result.T)

In [None]:
# Order embeddings by mean score (highest first) and colour each one by the
# category recorded for it in the benchmark metadata table.
means = andes_data_results.mean()
# Plain ndarray of column positions, descending by mean.
sorted_indices = np.argsort(means.to_numpy())[::-1]
andes_data_sorted = andes_data_results.iloc[:, sorted_indices]
cleaned_embed_names_sorted = [embed_names[i] for i in sorted_indices]

meta_df = pd.read_csv('z_benchmark_embed_meta.csv', index_col=0, encoding='utf-8')
# Index entries are compared with all whitespace removed.
meta_df.index = meta_df.index.str.replace(r'\s+', '', regex=True)

color_map = {
    'gene expression (bulk)': "#920015",
    'gene expression (single cell)': "#ef476f",
    'amino acid sequence': "#ffd166",
    'PPI': "#06d6a0",
    'biomedical literature': "#118ab2",
    'mutation profile, biomedical literature, PPI': "#073b4c"
}

cleaned_embed_names = list(embed_names)


def _embed_category(name):
    """Return the metadata 'Category' for one embedding name.

    The name is normalised exactly like ``meta_df.index`` (all whitespace removed)
    before lookup; if several rows match, the first one wins.

    Raises:
        KeyError: if the embedding has no row in the metadata table.
    """
    key = ''.join(str(name).split())
    matches = meta_df.loc[meta_df.index == key, 'Category']
    if matches.empty:
        raise KeyError(f"embedding {name!r} not found in z_benchmark_embed_meta.csv")
    return matches.iloc[0]


# Categories in descending-mean order, aligned with cleaned_embed_names_sorted.
categories_sorted = [_embed_category(name) for name in cleaned_embed_names_sorted]

# The box plot draws columns in `sorted_indices[::-1]` order, and seaborn assigns
# list palette entries to boxes by position, so the palette follows that reversed order.
palette = [color_map[cat] for cat in categories_sorted[::-1]]

In [None]:
# Horizontal box plot: one box per embedding, coloured via `palette`.
fig, ax = plt.subplots(figsize=(8, 15))

sns.boxplot(
    data=andes_data_sorted,
    orient="h",
    order=sorted_indices[::-1],
    palette=palette,
    width=.5,
    ax=ax,
)

# Boxes sit at y positions 0..n-1 following `order`, so the tick labels use the
# same reversed (ascending-mean) order.
ax.set_yticks(range(len(cleaned_embed_names_sorted)))
ax.set_yticklabels(cleaned_embed_names_sorted[::-1], fontsize=15)
ax.set_title('KEGG-GO', fontsize=20)
ax.set_xlabel('Rank', fontsize=17)
ax.set_ylabel('', fontsize=17)

fig.tight_layout()

# savefig does not create missing directories; make sure the target exists first.
output_path = "/results/plots/kegg_go_andes.pdf"
os.makedirs(os.path.dirname(output_path), exist_ok=True)
fig.savefig(output_path, format="pdf", bbox_inches="tight")

plt.show()