In [None]:
import os
import json
import re
from datetime import datetime
from glob import glob
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import sem
import scipy.stats as stats

# Summarizing experiments

class Experiment:
    """Parsed view of one saved experiment result dictionary.

    The per-client matrices have shape (n_rounds, n_clients); row
    ``round_num - 1`` holds the values reported for ``round_num``. Rounds
    that are absent from the results keep their initial zeros.
    ``client_epistemic_mse`` is ``client_mse`` minus the per-client
    ``client_noise_var`` values. The ``avg_*`` / ``var_*`` attributes are the
    mean / variance across clients for every round.
    """

    def __init__(self, experiment_data):
        """Populate all attributes from ``experiment_data`` (a loaded JSON dict)."""
        parsed = self._dict_to_experiment(experiment_data)
        (
            self.algorithm, self.date, self.sim, self.params, self.sim_info,
            self.n_clients, self.n_rounds, self.n_checkpoints,
            self.n_cp_train_samps, self.n_cp_val_samps,
            self.client_mse, self.client_clean_mse,
            self.client_epistemic_mse, self.client_aleatoric_var,
            self.avg_mse, self.var_mse,
            self.avg_clean_mse, self.var_clean_mse,
            self.avg_epistemic_mse, self.var_epistemic_mse,
        ) = parsed

    def _dict_to_experiment(self, experiment_data):
        """Extract every attribute value from the raw dictionary.

        Returns:
            A 20-tuple in the order expected by ``__init__``.
        """
        params = experiment_data['params']
        algorithm = params['algorithm']
        date = datetime.strptime(experiment_data['date'], "%Y-%m-%d_%H:%M:%S")
        sim = params['sim']
        sim_info = experiment_data['simulation']

        n_clients = int(params['min_available_clients'])
        n_rounds = int(params['n_rounds'])

        n_checkpoints = int(sim_info['n_checkpoints'])
        n_cp_train_samps = int(sim_info['n_checkpoint_train_samples'])
        n_cp_val_samps = int(sim_info['n_checkpoint_val_samples'])

        def per_client_table(metric_key):
            # One row per round, one column per client id "0".."n_clients-1".
            table = np.zeros((n_rounds, n_clients))
            history = experiment_data['results']['metrics_distributed'][metric_key]
            for round_num, by_client in history:
                table[round_num - 1] = np.array(
                    [by_client[str(client)] for client in range(n_clients)]
                )
            return table

        client_mse = per_client_table('client_mse')
        client_clean_mse = per_client_table('client_clean_mse')
        client_aleatoric_var = per_client_table('client_noise_var')

        # Epistemic part = total mse minus the reported noise variance.
        client_epistemic_mse = client_mse - client_aleatoric_var

        # Aggregate across clients (axis 1) for every round.
        avg_mse, var_mse = client_mse.mean(axis=1), client_mse.var(axis=1)
        avg_clean_mse = client_clean_mse.mean(axis=1)
        var_clean_mse = client_clean_mse.var(axis=1)
        avg_epistemic_mse = client_epistemic_mse.mean(axis=1)
        var_epistemic_mse = client_epistemic_mse.var(axis=1)

        return (
            algorithm, date, sim, params, sim_info,
            n_clients, n_rounds, n_checkpoints,
            n_cp_train_samps, n_cp_val_samps,
            client_mse, client_clean_mse,
            client_epistemic_mse, client_aleatoric_var,
            avg_mse, var_mse,
            avg_clean_mse, var_clean_mse,
            avg_epistemic_mse, var_epistemic_mse,
        )

    def __str__(self):
        """Return a tab-aligned, multi-line summary ending in the raw params."""
        lines = [
            f"Experiment {self.date}",
            "",
            f"algorithm:\t\t{self.algorithm}",
            f"data:\t\t\t{self.sim}",
            f"n_clients:\t\t{self.n_clients}",
            f"n_rounds:\t\t{self.n_rounds}",
            f"n_checkpoints:\t\t{self.n_checkpoints}",
            f"n_cp_train_samps:\t{self.n_cp_train_samps}",
            f"n_cp_val_samps:\t\t{self.n_cp_val_samps}",
            # Final-round values of the across-client statistics.
            f"avg_mse:\t\t\t{self.avg_mse[-1]:.5g}",
            f"var_mse:\t\t\t{self.var_mse[-1]:.5g}",
            "    ",
            f"params: {self.params}",
        ]
        return "\n".join(lines) + "\n"


def load_experiment(select="all", directory="/proj/fair-ai/fair-fl/out"):
    """Load one or all experiments from the JSON files in ``directory``.

    Only files whose path contains ``experiment-YYYY_MM_DD-HH_MM_SS`` with a
    valid timestamp are considered; they are ordered newest first.

    Args:
        select: ``"all"`` (list of every experiment), ``"latest"`` (newest
            one), an ``int`` index into the newest-first ordering, or a
            ``"%Y_%m_%d-%H_%M_%S"`` timestamp string identifying one file.
        directory: Folder that is searched (non-recursively) for ``*.json``.

    Returns:
        A list of ``Experiment`` for ``"all"``; otherwise a single
        ``Experiment``, or ``None`` (after printing a message) when nothing
        matches the selection.
    """
    stamp_format = "%Y_%m_%d-%H_%M_%S"
    stamp_pattern = r"experiment-(\d{4}_\d{2}_\d{2}-\d{2}_\d{2}_\d{2})"

    def read_experiment(path):
        with open(path, 'r') as handle:
            return Experiment(json.load(handle))

    # Pair every recognisable file with the timestamp embedded in its path.
    dated_paths = []
    for path in glob(os.path.join(directory, "*.json")):
        found = re.search(stamp_pattern, path)
        if found is None:
            continue
        try:
            stamp = datetime.strptime(found.group(1), stamp_format)
        except ValueError:
            continue
        dated_paths.append((stamp, path))

    # Newest first.
    dated_paths.sort(key=lambda pair: pair[0], reverse=True)

    if isinstance(select, int):
        count = len(dated_paths)
        if not (-count <= select < count):
            print(f"Index {select} out of range")
            return None
        return read_experiment(dated_paths[select][1])

    if select == "all":
        return [read_experiment(path) for _, path in dated_paths]

    if select == "latest":
        if not dated_paths:
            print("No experiment files found")
            return None
        return read_experiment(dated_paths[0][1])

    # Anything else is interpreted as a timestamp string.
    wanted = datetime.strptime(select, stamp_format)
    for stamp, path in dated_paths:
        if stamp == wanted:
            return read_experiment(path)
    print(f"No experiment file found with datetime: {select}")
    return None
    

def list_files(dir_path):
    """Return the entry names of ``dir_path`` in descending sorted order."""
    entries = os.listdir(dir_path)
    entries.sort(reverse=True)
    return entries


# Plotting

# Line/bar colour for the experiment at each position in a plot call.
colors = [
    'blue',
    'green',
    'red',
    'cyan',
    'magenta',
    'yellow',
    'black',
]

def plot_avg_mse(*experiments, epistemic=False, ax=None):
    """Plot the per-round average mse of each experiment as one line.

    Args:
        *experiments: Objects exposing ``n_rounds``, ``algorithm`` and the
            ``avg_mse`` / ``avg_epistemic_mse`` arrays.
        epistemic: Plot ``avg_epistemic_mse`` instead of ``avg_mse``.
        ax: Axes to draw on. When omitted, a new 10x6 figure is created and
            shown before returning.
    """
    # Remember whether the figure is ours: ``ax`` is rebound below, so testing
    # it afterwards can never tell the two cases apart.
    owns_figure = ax is None
    if owns_figure:
        _, ax = plt.subplots(figsize=(10, 6))

    for idx, exp in enumerate(experiments):
        history = exp.avg_epistemic_mse if epistemic else exp.avg_mse
        # Cycle through the palette so more experiments than colours still plot.
        ax.plot(np.arange(1, exp.n_rounds + 1), history,
                label=exp.algorithm, color=colors[idx % len(colors)])

    ax.set_xlabel('Round')
    ax.set_ylabel(f'Average {"epistemic " if epistemic else ""}mse')
    ax.set_title('Training history')
    ax.legend()

    # Caller-supplied axes are left for the caller to show.
    if owns_figure:
        plt.show()


def plot_var_mse(*experiments, epistemic=False, ax=None):
    """Plot the per-round across-client variance of the mse, one line each.

    Args:
        *experiments: Objects exposing ``n_rounds``, ``algorithm`` and the
            ``var_mse`` / ``var_epistemic_mse`` arrays.
        epistemic: Plot ``var_epistemic_mse`` instead of ``var_mse``.
        ax: Axes to draw on. When omitted, a new 10x6 figure is created and
            shown before returning.
    """
    # Remember whether the figure is ours: ``ax`` is rebound below, so testing
    # it afterwards can never tell the two cases apart.
    owns_figure = ax is None
    if owns_figure:
        _, ax = plt.subplots(figsize=(10, 6))

    for idx, exp in enumerate(experiments):
        history = exp.var_epistemic_mse if epistemic else exp.var_mse
        # Cycle through the palette so more experiments than colours still plot.
        ax.plot(np.arange(1, exp.n_rounds + 1), history,
                label=exp.algorithm, color=colors[idx % len(colors)])

    ax.set_ylim(bottom=0)
    ax.set_xlabel('Round')
    ax.set_ylabel(f'Variance of {"epistemic " if epistemic else ""}mse')
    ax.set_title('Training history')
    ax.legend()

    # Caller-supplied axes are left for the caller to show.
    if owns_figure:
        plt.show()



def plot_client_mse_dist(*experiments, ax=None):
    """Draw per-client stacked bars of epistemic mse plus aleatoric variance.

    For every experiment the round with the lowest average *epistemic* mse is
    selected, and each client's values at that round are drawn as a stacked
    bar (epistemic at the bottom, aleatoric on top). Experiments are placed
    side by side within each client's slot.

    Args:
        *experiments: At least one object exposing ``n_clients``,
            ``algorithm``, ``avg_epistemic_mse``, ``client_epistemic_mse``
            and ``client_aleatoric_var``. The client count of the first
            experiment defines the x axis.
        ax: Axes to draw on. When omitted, a new 10x6 figure is created and
            shown before returning.
    """
    # Remember whether the figure is ours: ``ax`` is rebound below, so testing
    # it afterwards can never tell the two cases apart.
    owns_figure = ax is None
    if owns_figure:
        _, ax = plt.subplots(figsize=(10, 6))

    n_clients = experiments[0].n_clients
    bar_width = 0.9 / len(experiments)
    client_indices = np.arange(n_clients)
    # Clients are shown in id order (identity permutation).
    client_order = client_indices

    for idx, exp in enumerate(experiments):
        best_round = np.argmin(exp.avg_epistemic_mse)  # round with minimum avg epistemic mse
        aleatoric = exp.client_aleatoric_var[best_round][client_order]
        epistemic = exp.client_epistemic_mse[best_round][client_order]
        # Cycle through the palette so more experiments than colours still plot.
        color = colors[idx % len(colors)]
        positions = client_indices + idx * bar_width

        # Create stacked bar by first plotting epistemic mse and then aleatoric variance on top of it
        ax.bar(positions, epistemic, bar_width, alpha=0.8,
               label=f'{exp.algorithm} - Epistemic', color=color)
        ax.bar(positions, aleatoric, bar_width, bottom=epistemic, alpha=0.4,
               label=f'{exp.algorithm} - Aleatoric', color=color)

    ax.set_xlabel('Clients')
    ax.set_ylabel('Best mse')
    ax.set_title('Client mse distribution at minimum avg mse round')
    ax.set_xticks(client_indices + bar_width / 2 * len(experiments))
    ax.set_xticklabels(client_order.astype(str))
    ax.legend()

    # Caller-supplied axes are left for the caller to lay out and show.
    if owns_figure:
        plt.tight_layout()
        plt.show()

        
def generate_experiment_summary(*experiments):
    """Build a one-row-per-experiment summary table, newest date first.

    Every metric except ``av_mse`` is read at the round with the lowest
    average epistemic mse; ``av_mse`` is the minimum of ``avg_mse`` over all
    rounds. The ``sim`` column is reduced to the second- and third-to-last
    ``/``-separated components of the simulation path.

    Returns:
        A ``pandas.DataFrame`` indexed by ``date`` in descending order.
    """
    column_names = [
        'algo', 'sim', 'recollect',
        'av_mse', 'v_mse',
        'av_epi_mse', 'v_epi_mse',
        'av_clean_mse', 'v_clean_mse',
        'b_sz', 'lr', 'date',
    ]

    def summary_row(exp):
        best = np.argmin(exp.avg_epistemic_mse)
        return [
            exp.algorithm,
            exp.sim,
            exp.params['recollection_strategy'],
            np.min(exp.avg_mse),
            exp.var_mse[best],
            exp.avg_epistemic_mse[best],
            exp.var_epistemic_mse[best],
            exp.avg_clean_mse[best],
            exp.var_clean_mse[best],
            exp.params['batch_size'],
            exp.params['learning_rate'],
            exp.date,
        ]

    table = pd.DataFrame([summary_row(exp) for exp in experiments],
                         columns=column_names)

    # Keep only the two path components preceding the final one.
    path_parts = table.sim.str.split('/')
    table.sim = path_parts.str[-3:-1]

    table = table.set_index('date')
    table = table.sort_index(ascending=False)
    return table


def summarize_experiments(directory="/proj/fair-ai/fair-fl/out", rows=-1):
    """Display the summary table of every experiment found in ``directory``.

    Args:
        directory: Folder passed to ``load_experiment``.
        rows: Number of leading (newest) rows to display. The default ``-1``
            (or ``None``) means "no limit" and shows the whole table.
    """
    # Load all experiments
    experiments = load_experiment("all", directory=directory)
    summary = generate_experiment_summary(*experiments)
    # -1 is a "show everything" sentinel; slicing with it would drop the last row.
    if rows is not None and rows != -1:
        summary = summary.iloc[:rows]
    display(summary)

def describe_experiment(*experiments):
    """Display the summary table and a five-panel figure for the experiments.

    Panel layout (3x2 grid on a 14x14 figure): average and variance of the
    mse on the first row, their epistemic counterparts on the second row, and
    the per-client distribution across the whole third row.
    """
    display(generate_experiment_summary(*experiments))

    figure = plt.figure(figsize=(14, 14))
    layout = plt.GridSpec(3, 2, wspace=0.4, hspace=0.3)

    # Axes are created in reading order: row by row, left to right.
    cells = (layout[0, 0], layout[0, 1], layout[1, 0], layout[1, 1], layout[2, :])
    top_left, top_right, mid_left, mid_right, bottom = (
        figure.add_subplot(cell) for cell in cells
    )

    history_panels = (
        (plot_avg_mse, top_left, False),
        (plot_var_mse, top_right, False),
        (plot_avg_mse, mid_left, True),
        (plot_var_mse, mid_right, True),
    )
    for draw, axis, use_epistemic in history_panels:
        draw(*experiments, ax=axis, epistemic=use_epistemic)

    plot_client_mse_dist(*experiments, ax=bottom)

    plt.show()

In [None]:
# Show up to ten rows of the summary table for the runs stored in one
# results sub-directory (the commented call uses the default directory).
# summarize_experiments(rows=10)
summarize_experiments(rows=10, directory="/proj/fair-ai/fair-fl/out/bike_hom/nsr_0.2/fedavg/")

In [None]:
# Index 0 selects the newest experiment file in the directory.
exp = load_experiment(0, directory="/proj/fair-ai/fair-fl/out/")
# load_experiment prints a message and returns None when no file matches;
# skip the report in that case instead of failing on attribute access.
if exp is not None:
    describe_experiment(exp)