In [None]:
import pandas as pd
import numpy as np
from itertools import combinations
from collections import Counter
import plotly.graph_objects as go
import plotly.express as px
from functools import partial
import matplotlib.pyplot as plt
import seaborn as sns


def find_min_max_columns(row, find_min):
    """Return the index labels of ``row`` whose value equals the row's extreme.

    When ``find_min`` is truthy the extreme is ``row.min()``, otherwise it is
    ``row.max()``. Every label tied at that value is returned, in the order
    the labels appear in ``row.index``.
    """
    if find_min:
        extreme = row.min()
    else:
        extreme = row.max()
    at_extreme = row == extreme
    return row.index[at_extreme].tolist()

def get_VBS_and_SBS_info(metric, learning_task, task_output, find_min=True):
    """Load the performance table for ``metric`` and annotate it with best-value columns.

    Reads ``../processed_data/{metric}/{learning_task}/performance.csv`` (first
    column used as the index) and the solver portfolio stored in
    ``../processed_data/{metric}/algo_portfolio.npy``, then adds these columns:

    * ``VBS``: list of column labels that reach the row-wise best value
      (all tied labels are kept).
    * ``VBS_value``: the row-wise best numeric value.
    * ``SBS_{solver}``: absolute difference between the solver's value and
      ``VBS_value``, one column per portfolio solver.

    Args:
        metric: Metric name; used as a directory name in both paths.
        learning_task: Sub-directory that holds ``performance.csv``.
        task_output: Unused here; kept so existing callers keep working.
        find_min: If truthy the best value is the row minimum, else the maximum.

    Returns:
        Tuple ``(df, solvers)`` of the annotated DataFrame and the loaded
        portfolio array. Every portfolio entry must be a column of the table.
    """
    df = pd.read_csv(f'../processed_data/{metric}/{learning_task}/performance.csv', index_col=0)
    solvers = np.load(f'../processed_data/{metric}/algo_portfolio.npy', allow_pickle=True)

    # Computed before any helper column exists, so only the columns read from
    # the CSV take part in the comparison.
    df['VBS'] = df.apply(partial(find_min_max_columns, find_min=find_min), axis=1)

    # numeric_only skips the list-valued 'VBS' column that was just added.
    if find_min:
        df['VBS_value'] = df.min(axis=1, numeric_only=True)
    else:
        df['VBS_value'] = df.max(axis=1, numeric_only=True)

    # Column-wise arithmetic instead of a Python-level pass over every row
    # for every solver.
    for solver in solvers:
        df[f'SBS_{solver}'] = (df[f'{solver}'] - df['VBS_value']).abs()
    return df, solvers


def get_predicted_VBS_regression(seed, metric, learning_task, task_output, find_min, solvers, df):
    """Pick, per row of ``df``, the solver with the best predicted value.

    Loads the stored test predictions of 40 folds from
    ``../results/seed_{seed}/{metric}/{learning_task}/{task_output}/predictions``:

    * ``task_output == 'multi'``: one file per fold; every entry along its
      first axis becomes a row with one value per solver.
    * ``task_output == 'single'``: one file per fold and solver; the first
      entry of each file is used and the values are stacked into one row.

    Args:
        seed: Seed number used in the results path.
        metric: Metric name used in the results path.
        learning_task: Learning-task directory name.
        task_output: ``'multi'`` or ``'single'``.
        find_min: If truthy the lowest prediction wins, else the highest.
        solvers: Solver names; they define the column order of the predictions.
        df: Frame whose index is reused for the result.

    Returns:
        Series named ``predicted_VBS`` indexed like ``df`` holding the chosen
        solver name. On ties the first tied solver in column order is chosen
        and ``"more than one prediction!"`` is printed once per tied row.

    Raises:
        ValueError: If ``task_output`` is neither ``'multi'`` nor ``'single'``.
    """
    # Fail early: an unknown mode would otherwise collect no predictions and
    # surface later as an unrelated error.
    if task_output not in ('multi', 'single'):
        raise ValueError(f"task_output must be 'multi' or 'single', got {task_output!r}")

    suffix = "_pred"
    predictions = []
    for fold in range(0, 40):
        if task_output == 'multi':
            preds = np.array(np.load(f'../results/seed_{seed}/{metric}/{learning_task}/{task_output}/predictions/test_predictions_fold_{fold}.npy', allow_pickle=True))
            predictions.extend(preds)
        else:
            preds = np.array([np.load(f'../results/seed_{seed}/{metric}/{learning_task}/{task_output}/predictions/test_predictions_fold_{fold}_{solver}.npy', allow_pickle=True)[0] for solver in solvers])
            predictions.append(preds)

    df_pred = pd.DataFrame(predictions, columns=[solver + suffix for solver in solvers], index=df.index)
    best_columns = df_pred.apply(partial(find_min_max_columns, find_min=find_min), axis=1)
    for el in best_columns:
        if len(el) > 1:
            print("more than one prediction!")

    # Strip only the trailing suffix so solver names that themselves contain
    # "_pred" are left intact.
    df_pred['predicted_VBS'] = [el[0][:-len(suffix)] for el in best_columns]
    return df_pred['predicted_VBS']
    
    
def get_predicted_VBS_classification(seed, metric, df):
    """Collect the stored classifier predictions of 40 folds into one Series.

    Each fold file under
    ``../results/seed_{seed}/{metric}/classification/single/predictions`` is
    loaded and its entries are appended in fold order. The result is a Series
    named ``predicted_VBS`` that reuses the index of ``df``.
    """
    predictions = [
        label
        for fold in range(0, 40)
        for label in np.array(np.load(f'../results/seed_{seed}/{metric}/classification/single/predictions/test_predictions_fold_{fold}.npy', allow_pickle=True))
    ]
    collected = pd.DataFrame(predictions, columns=['predicted_VBS'], index=df.index)
    return collected['predicted_VBS']

def get_predicted_VBS_pairwise_regression(seed, metric, task_output, df, find_min):
    """Choose a solver per fold by majority vote over pairwise regression outputs.

    The portfolio is loaded from ``../processed_data/{metric}/algo_portfolio.npy``
    and every unordered pair ``(a, b)`` is formed in ``combinations`` order.
    For each of the 40 folds there is one prediction per pair (a single file
    when ``task_output == 'multi'``, otherwise one file per pair). A pair
    votes as follows:

    * ``find_min`` truthy: prediction > 0 votes for ``b``, otherwise for ``a``.
    * ``find_min`` falsy:  prediction > 0 votes for ``a``, otherwise for ``b``.

    The solver(s) with the most votes are joined with ``', '`` in
    first-vote order. Returns a Series named ``predicted_VBS`` indexed like
    ``df``.
    """
    algos = np.load(f'../processed_data/{metric}/algo_portfolio.npy', allow_pickle=True)
    pairwise_combinations = list(combinations(algos, 2))
    minimise = bool(find_min)

    predictions = []
    for fold in range(0, 40):
        if task_output != 'multi':
            preds = np.array([np.load(f'../results/seed_{seed}/{metric}/pairwise_regression/{task_output}/predictions/test_predictions_fold_{fold}_{pair[0]}_vs_{pair[1]}.npy', allow_pickle=True)[0] for pair in pairwise_combinations])
        else:
            preds = np.array(np.load(f'../results/seed_{seed}/{metric}/pairwise_regression/{task_output}/predictions/test_predictions_fold_{fold}.npy', allow_pickle=True))[0]

        # The second member of the pair wins exactly when "prediction is
        # positive" agrees with "we are minimising".
        votes = Counter(
            pairwise_combinations[i][1 if bool(pred > 0) == minimise else 0]
            for i, pred in enumerate(preds)
        )
        top_votes = max(votes.values())
        predictions.append(', '.join(name for name, n_votes in votes.items() if n_votes == top_votes))

    collected = pd.DataFrame(predictions, columns=['predicted_VBS'], index=df.index)
    return collected['predicted_VBS']

def get_predicted_VBS_pairwise_classification(seed, metric, task_output, df, cost_sensitive=False):
    """Choose a solver per fold by majority vote over pairwise classifier outputs.

    The portfolio is loaded from ``../processed_data/{metric}/algo_portfolio.npy``
    and every unordered pair is formed in ``combinations`` order. For each of
    the 40 folds there is one prediction per pair (a single file when
    ``task_output == 'multi'``, otherwise one file per pair). Each prediction
    is used directly as the position (0 or 1) of the winning pair member.
    With ``cost_sensitive`` truthy the files are read from the
    ``cost_sensitive_pairwise_classification`` directory instead of
    ``pairwise_classification``.

    The solver(s) with the most votes are joined with ``', '`` in
    first-vote order. Returns a Series named ``predicted_VBS`` indexed like
    ``df``.
    """
    algos = np.load(f'../processed_data/{metric}/algo_portfolio.npy', allow_pickle=True)
    pairwise_combinations = list(combinations(algos, 2))

    predictions = []
    for fold in range(0, 40):
        if task_output != 'multi':
            preds = np.array([np.load(f'../results/seed_{seed}/{metric}/{"cost_sensitive_" if cost_sensitive else ""}pairwise_classification/{task_output}/predictions/test_predictions_fold_{fold}_{pair[0]}_vs_{pair[1]}.npy', allow_pickle=True)[0] for pair in pairwise_combinations])
        else:
            preds = np.array(np.load(f'../results/seed_{seed}/{metric}/{"cost_sensitive_" if cost_sensitive else ""}pairwise_classification/{task_output}/predictions/test_predictions_fold_{fold}.npy', allow_pickle=True))[0]

        votes = Counter(pairwise_combinations[i][pred] for i, pred in enumerate(preds))
        top_votes = max(votes.values())
        predictions.append(', '.join(name for name, n_votes in votes.items() if n_votes == top_votes))

    collected = pd.DataFrame(predictions, columns=['predicted_VBS'], index=df.index)
    return collected['predicted_VBS']

def build_AS(seed, metric, find_min):
    """Assemble one frame with the best-value columns and every selector's choice.

    Starts from the frame returned by ``get_VBS_and_SBS_info`` for the
    ``regression`` performance table, prints the ``find_min`` flag, and then
    adds one column per selector, in this order:

    ``AS-R-MO``, ``AS-R-SO`` (regression), ``AS-C-SO`` (classification),
    ``AS-PR-MO``, ``AS-PR-SO`` (pairwise regression), ``AS-PC-MO``,
    ``AS-PC-SO`` (pairwise classification) and ``AS-CS-PC-SO``
    (cost-sensitive pairwise classification, single output).

    Returns the augmented DataFrame.
    """
    df, solvers = get_VBS_and_SBS_info(metric, "regression", 'multi', find_min)
    print("find min: ", find_min)

    def store(column, predicted):
        # Each prediction Series is wrapped in a one-column frame on assignment.
        df[column] = pd.DataFrame(predicted)

    output_modes = ('multi', 'single')

    # regression
    learning_task = 'regression'
    for task_output in output_modes:
        store(
            f'AS-{learning_task[0].upper()}-{task_output[0].upper()}O',
            get_predicted_VBS_regression(seed, metric, learning_task, task_output, find_min, solvers, df),
        )

    # classification
    store('AS-C-SO', get_predicted_VBS_classification(seed, metric, df))

    # pairwise_regression
    for task_output in output_modes:
        store(
            f'AS-PR-{task_output[0].upper()}O',
            get_predicted_VBS_pairwise_regression(seed, metric, task_output, df, find_min),
        )

    # pairwise_classification
    for task_output in output_modes:
        store(
            f'AS-PC-{task_output[0].upper()}O',
            get_predicted_VBS_pairwise_classification(seed, metric, task_output, df),
        )

    # cost sensitive pairwise classification
    store(
        'AS-CS-PC-SO',
        get_predicted_VBS_pairwise_classification(seed, metric, 'single', df, cost_sensitive=True),
    )

    return df


def _split_algorithms(entry):
    """Normalise one cell into a list of algorithm names."""
    if isinstance(entry, str):
        # Handles both a single name and a comma-separated list of names.
        return [algo.strip() for algo in entry.split(',')]
    if isinstance(entry, list):
        return entry
    # Any other value is treated as a single name.
    return [entry]


def calculate_algorithm_frequencies(df, n_instances=40):
    """Count how often each algorithm is named in every column of ``df``.

    A cell may be a string (one name, or several separated by commas), a list
    of names, or any other single value. Every name in a cell adds one to that
    name's count for the cell's column.

    Args:
        df: Frame whose cells name algorithms.
        n_instances: Divisor applied to the raw counts. Defaults to 40, the
            value that used to be hard-coded here.

    Returns:
        DataFrame with one row per algorithm and one column per column of
        ``df``, holding ``count / n_instances``. Rows are ordered by first
        appearance when scanning ``df`` column by column, so the layout is
        the same on every run.
    """
    counts = {column: Counter() for column in df.columns}
    first_seen = {}  # dict used as an ordered set of algorithm names

    for column in df.columns:
        column_counts = counts[column]
        for entry in df[column]:
            # One parser for both discovery and counting, so a padded name
            # such as " A" is counted under the same key it is registered as.
            for algo in _split_algorithms(entry):
                first_seen.setdefault(algo, None)
                column_counts[algo] += 1

    algorithms = list(first_seen)
    table = {
        column: [counts[column][algo] for algo in algorithms]
        for column in df.columns
    }
    return pd.DataFrame(table, index=algorithms) / n_instances

# For every seed and metric: build the selector table, turn it into
# per-selector algorithm frequencies and render them as a heatmap.
for seed in range(42, 43):
    print("seed ", seed)

    for metric in ['HAMMING LOSS example based', 'MACRO F1', 'MICRO F1', 'AUCROC MICRO', 'F1 example based']:
        # Only the Hamming-loss metric is minimised; the others are maximised.
        find_min = metric == 'HAMMING LOSS example based'
        algos = np.load(f"../processed_data/{metric}/algo_portfolio.npy", allow_pickle=True)

        df_AS = build_AS(seed, metric, find_min)
        # Keep the reference column plus the selector columns, in plot order.
        df_AS = df_AS[['VBS', 'AS-R-MO', 'AS-R-SO', 'AS-PR-MO', 'AS-PR-SO', 'AS-C-SO', 'AS-PC-MO', 'AS-PC-SO', 'AS-CS-PC-SO']]

        heatmap_data = calculate_algorithm_frequencies(df_AS)
        # Swap two portfolio names for the labels shown in the figure.
        heatmap_data.index = [
            col.replace('RSMLCC', 'RSLP').replace('Ada300', 'AdaBoost.MH')
            for col in heatmap_data.index
        ]

        fig = px.imshow(heatmap_data.round(3), color_continuous_scale='reds', text_auto=True)
        fig.update_coloraxes(showscale=False)
        fig.update_xaxes(side="top")

        # Adjust layout
        fig.update_layout(
            autosize=True,
            width=720,
            margin={'l': 0, 'r': 0, 'b': 0, 't': 0},
            font={'size': 18},
        )
        print(metric, algos)
        fig.show()
        fig.write_image(f'../figures/heatmap_perc_{metric}.pdf')

        
        


seed  42
find min:  True
HAMMING LOSS example based ['DEEP4' 'RFPCT' 'CC' 'Ada300' 'TREMLC']


find min:  False
MACRO F1 ['CLR' 'Ada300' 'RFDTBR' 'CC' 'RSMLCC']


find min:  False
MICRO F1 ['RFDTBR' 'RFPCT' 'Ada300' 'CLR' 'PSt']


find min:  False
AUCROC MICRO ['RFPCT' 'PSt' 'RFDTBR' 'EBRJ48' 'TREMLC']


find min:  False
F1 example based ['RFPCT' 'RFDTBR' 'RSMLCC' 'PSt' 'Ada300']
