In [None]:
import pandas as pd
import numpy as np
import loader as load
import config
import matplotlib
import matplotlib.pyplot as plt

# Prediction-feature tables under random_sampling/stage for the linreg selection,
# one DataFrame per entry in `files` (same order).
# Forward slashes are used because a backslash is only a path separator on Windows;
# "/" works on every platform.
files = ['aak_ge', 'tcma_gen', 'tcma_gen_aak_ge']
c_data_features = [
    pd.read_csv(f"Data/Descriptor/Prediction_Tables/random_sampling/stage/{x}_linreg_predfeature.csv", index_col=None)
    for x in files
]

In [None]:
# The final table loaded above is the combined 'tcma_gen_aak_ge' layer; show it.
overlap = c_data_features[len(c_data_features) - 1]
overlap

In [None]:
# Settings: p = 200, target = tumor/stage, selection = linreg, layer = overlap.
# Which bacterial features are selected most often, per cancer and across cancers?
# Which microbial features are selected most often, per cancer and across cancers?

def _draw_frequency_table(ax, frequencies, aak_ge_features, column_headers, column_colours):
    """Draw a one-column table of feature -> selection count on ``ax`` and hide the axes.

    ``frequencies`` is a value-counts Series (index = feature name, value = count).
    Row labels are blue when the feature is ``in aak_ge_features`` and red otherwise.
    """
    feature_names = list(frequencies.index)
    if feature_names:
        row_colours = ["#5b8afd" if name in aak_ge_features else "#ff4848" for name in feature_names]
        table = ax.table(cellText=frequencies.to_numpy().reshape(-1, 1),
                         rowLabels=feature_names,
                         rowColours=row_colours,
                         rowLoc='right',
                         colColours=column_colours,
                         colLabels=column_headers,
                         loc='center')
        # Make the rows taller (i.e., make cell y scale larger).
        table.scale(0.5, 2.5)
    else:
        # ax.table indexes cellText[0], so it cannot build a table with zero rows;
        # show a placeholder instead of crashing the whole figure.
        ax.text(0.5, 0.5, "No features selected", ha="center", va="center")
    # Hide the axes and their border.
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.set_frame_on(False)


def plot_freq_features_selected(layer_features, files, layer_names, target, selection_type,
                                root_folder="Visual/selection", p=200,
                                cancers=("COAD", "ESCA", "HNSC", "STAD"), top_n=10,
                                close_figures=True):
    """Save one figure per layer tabulating the most frequently selected features per cancer.

    For each DataFrame in ``layer_features``, rows whose ``p`` column equals ``p`` are kept.
    For every cancer in ``cancers``, the ``top_n`` most frequent values of the ``features``
    column are drawn as a one-column table in their own subplot. Row labels are coloured
    blue when the feature belongs to the ``aak_ge`` feature set returned by
    ``load.getFeatures()``, and red otherwise.

    Parameters
    ----------
    layer_features : list of pandas.DataFrame
        One table per layer; each needs ``p``, ``cancer`` and ``features`` columns.
    files : list of str
        File stems parallel to ``layer_features``; used to name the saved PNGs.
    layer_names : list of str
        Display names parallel to ``layer_features``; used in the figure title.
    target, selection_type : str
        Used in the figure title and as output sub-directories.
    root_folder : str
        Root directory for the saved figures.
    p : int
        Value of the ``p`` column to keep (was hard-coded to 200).
    cancers : sequence of str
        Cancer codes, one subplot each, in display order.
    top_n : int
        Number of most frequent features listed per cancer (was hard-coded to 10).
    close_figures : bool
        Close each figure after saving so that long loops do not accumulate open figures.
        Closed figures are not rendered inline in a notebook; pass False to keep them.

    Side effects: writes ``{root_folder}/{target}/{selection_type}/{files[i]}_freq.png`` per layer.
    """
    all_features, _ = load.getFeatures()
    # Only the aak_ge set is needed (for row colouring); the tcma_gen set is unused.
    _, aak_ge_features = all_features

    column_headers = ["Freq."]
    # The header colour does not depend on layer or cancer, so compute it once.
    column_colours = plt.cm.BuPu(np.full(len(column_headers), 0.1))

    for layer_idx, features_data in enumerate(layer_features):
        p_features_data = features_data[features_data["p"] == p]

        fig = plt.figure(linewidth=2, figsize=(20, 7))
        # Make space for the main title; used instead of tight_layout.
        fig.subplots_adjust(top=0.8)
        main_title = f"{target} pred {layer_names[layer_idx]} Feature selection frequency | {selection_type}"

        for k, cancer in enumerate(cancers):
            cancer_features_data = p_features_data[p_features_data["cancer"] == cancer]
            top_frequencies = cancer_features_data["features"].value_counts().head(top_n)

            # The grid width is the number of cancers (the original used len() of the
            # cancer-code string, which only worked because every code has 4 letters).
            ax = fig.add_subplot(1, len(cancers), k + 1)
            _draw_frequency_table(ax, top_frequencies, aak_ge_features, column_headers, column_colours)
            ax.set_title(f"{cancer}", fontsize=15)

        fig.suptitle(main_title, fontsize=20, y=0.92, x=0.5)
        # Force the figure to update, so backends center objects correctly within the figure.
        fig.canvas.draw_idle()

        filename = f"{root_folder}/{target}/{selection_type}/{files[layer_idx]}_freq.png"
        load.createDirectory(filename)
        fig.savefig(filename, transparent=False, facecolor="white")
        if close_figures:
            plt.close(fig)

# Render and save the feature-frequency tables for every prediction target and the
# configured selection types.
files = config.visualization_packages["base"]
# Display names depend only on the layer files, so look them up once.
layer_names = [config.modality_file_name_to_name[x] for x in files]
for target in config.prediction_targets:
    prediction_tables_dir = f"Data/Descriptor/Prediction_Tables/random_sampling/{target}"
    # NOTE(review): only the first two selection types are plotted — confirm this cap is intended.
    for selection_type in config.selection_types[:2]:
        # Join with "/" rather than "\": a backslash is not a path separator on POSIX.
        layer_features = [
            pd.read_csv(f"{prediction_tables_dir}/{x}_{selection_type}_predfeature.csv", index_col=None)
            for x in files
        ]
        plot_freq_features_selected(layer_features, files, layer_names, target, selection_type)