In [1]:
import warnings
warnings.filterwarnings('ignore')
import pandas as pd
import sys
sys.path.insert(0, "../../src/")
import seaborn as sns
import matplotlib.pyplot as plt
from dimensionality_reductions.linear_reductions import LinearReduction
from dimensionality_reductions.non_linear_reductions import NonLinearReductions
import math

In [2]:
# Hex colours handed to seaborn as the scatter-plot palette; they are
# assigned, in order, to the hue levels of the "target" column.
palette_values = [
    '#00B0BE',  # teal
    '#F45F74',  # coral pink
    '#98C127',  # yellow-green
]
colors = sns.color_palette(palette_values)

In [3]:
# Dataset tag embedded in every representation file name.
_DATASET_TAG = "antiviral_homology_90"

# Property codes shared by the "fft" and "pc" encodings (they look like
# AAindex accession IDs — TODO confirm). Keeping a single list guarantees
# both groups cover exactly the same properties, in the same order.
_PROPERTY_CODES = [
    "FASG760101",
    "FAUJ880111",
    "FAUJ880112",
    "GEIM800101",
    "GEIM800105",
    "JOND750101",
    "KLEP840101",
    "ROBB760113",
    "ZIMJ680104",
]

# Representation files to plot, grouped by encoding family. Each entry is a
# CSV path relative to the results directory, without the ".csv" suffix.
name_repr_list = {
    "basic": [
        f"{encoding}_{_DATASET_TAG}"
        for encoding in ("ordinal", "one_hot", "frequency")
    ],
    "demo": [
        f"demo_{_DATASET_TAG}_{model}"
        for model in (
            "esm1b_t33_650M_UR50S",
            "esm2_t6_8M_UR50D",
            "esm2_t12_35M_UR50D",
            "esm2_t36_3B_UR50D",
            "prot_t5_xl_bfd",
            # "prot_t5_xl_uniref50" is currently excluded from the plots.
        )
    ],
    "fft": [f"fft/fft_{_DATASET_TAG}_{code}" for code in _PROPERTY_CODES],
    "pc": [f"pc/pc_{_DATASET_TAG}_{code}" for code in _PROPERTY_CODES],
}

In [4]:
# Projection methods plot_group knows how to draw.
_SUPPORTED_VISUALIZATIONS = ("PCA", "UMAP", "TSNE")


def _project_2d(df_values, visualization_type):
    """Reduce a feature frame with the requested method and return the projected frame.

    The returned frame is read by ``plot_group`` through its ``p_1``/``p_2``
    columns; it comes from the project's ``LinearReduction`` /
    ``NonLinearReductions`` helpers.
    """
    if visualization_type == "PCA":
        # applyPCA returns a pair; only the second element (the projected frame) is used.
        _, projected = LinearReduction(dataset=df_values).applyPCA()
        return projected
    reducer = NonLinearReductions(dataset=df_values)
    if visualization_type == "UMAP":
        return reducer.applyUMAP()
    return reducer.applyTSNE()


def plot_group(group_name, repr_list, visualization_type,
               data_dir="../../results_demos", output_dir="../../img", cols=3):
    """Draw a grid of 2-D projections, one panel per representation, and save it as a PNG.

    Each entry of ``repr_list`` names a CSV file ``{data_dir}/{name}.csv``. That
    file must contain a ``target`` column and an ``experimental_characteristics``
    column; every remaining column is treated as a feature. The features are
    projected with ``visualization_type`` and scattered, coloured by ``target``
    using the module-level ``colors`` palette.

    Args:
        group_name: Label embedded in the output file name.
        repr_list: Representation paths relative to ``data_dir`` (without ``.csv``).
        visualization_type: One of ``"PCA"``, ``"UMAP"`` or ``"TSNE"``.
        data_dir: Directory holding the representation CSV files.
        output_dir: Directory the PNG is written to; created if it does not exist.
        cols: Number of subplot columns in the grid.

    Returns:
        None. Writes ``{output_dir}/{visualization_type}_group_{group_name}.png``.
        Nothing is drawn or written when ``repr_list`` is empty.

    Raises:
        ValueError: If ``visualization_type`` is not one of the supported methods.
    """
    import os

    # Fail fast instead of silently saving a grid of blank panels.
    if visualization_type not in _SUPPORTED_VISUALIZATIONS:
        raise ValueError(
            f"Unsupported visualization_type {visualization_type!r}; "
            f"expected one of {_SUPPORTED_VISUALIZATIONS}"
        )

    n = len(repr_list)
    if n == 0:
        # plt.subplots(0, cols) would raise, and there is nothing to draw anyway.
        return

    rows = math.ceil(n / cols)
    # squeeze=False always yields a 2-D array of Axes, so flattening works for any grid shape.
    fig, axs = plt.subplots(rows, cols, figsize=(cols * 5, rows * 4), squeeze=False)
    axs = axs.flatten()
    try:
        for ax, name_repr in zip(axs, repr_list):
            df_data = pd.read_csv(os.path.join(data_dir, f"{name_repr}.csv"))
            df_values = df_data.drop(columns=["target", "experimental_characteristics"])

            projected = _project_2d(df_values, visualization_type)
            projected["target"] = df_data["target"].values
            sns.scatterplot(data=projected, x="p_1", y="p_2", hue="target",
                            palette=colors, ax=ax)

            ax.set_title(name_repr.split("/")[-1], fontsize=10)
            # Drop the per-panel legend seaborn created, keeping the grid uncluttered.
            legend = ax.get_legend()
            if legend is not None:
                legend.remove()

        # Hide the unused cells at the end of the last grid row.
        for ax in axs[n:]:
            ax.axis("off")

        fig.tight_layout()
        os.makedirs(output_dir, exist_ok=True)
        fig.savefig(
            os.path.join(output_dir, f"{visualization_type}_group_{group_name}.png"),
            dpi=300,
            bbox_inches="tight",
        )
    finally:
        # Close this specific figure even if a CSV fails to load, so calling
        # this in a loop over groups does not accumulate open figures.
        plt.close(fig)

In [None]:
# Render one figure per (projection method, representation group) pair.
# Add "PCA" and/or "TSNE" to this tuple to also produce those projections.
for visualization in ("UMAP",):
    for group_name, repr_list in name_repr_list.items():
        plot_group(group_name, repr_list, visualization)