In [1]:
import numpy as np 
from matplotlib import pyplot as plt 

# Global matplotlib styling applied to every figure created after this cell:
# text is rendered through LaTeX with a serif (Computer Modern Roman) font, and
# explicit font sizes are set for axis labels, general text, legends and ticks.
# NOTE(review): "text.usetex" presumably needs a working LaTeX installation on
# the machine that runs the notebook -- confirm before running elsewhere.
plt.rcParams.update({
    "text.usetex": True,            # Use LaTeX for all text
    "font.family": "serif",         # Use serif font
    "font.serif": ["Computer Modern Roman"],  # LaTeX default
    "axes.labelsize": 12,
    "font.size": 12,
    "legend.fontsize": 10,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
})

import os 
from pathlib import Path 

import re 


In [2]:
# --- Experiment selection ---------------------------------------------------
# Folder (below `<base_path>/experiments`) holding one sub-directory per run.
FOLDER = "FASHION_MNIST_STRATIFIED_CLASSIFIERS_MADGAN"
# FOLDER = "MNIST_STRATIFIED_CLASSIFIERS_MADGAN"
# Name of the run that was trained on real images only; it is plotted as the baseline.
BASE_EXPERIMENT = "2025-02-28_Stratified_classifierExperiment_FASHIONMNIST__BASE__images_real_5000_gen_0"

# Tags that end up in the file names of the saved figures.
TYPE = "MADGAN"
DATASET = "FASHION"

# History keys for which a figure is produced.
metrics = ["val_accuracy", "val_f1_score", "val_loss"]

base_path = Path("C:/Users/NiXoN/Desktop/_thesis/mad_gan_thesis/notebooks")
strat_exp_path = base_path / 'experiments' / FOLDER

# Names of all experiment directories.
# - Only directories are kept: each experiment is later opened as
#   `<strat_exp_path>/<name>/training_history.npy`, and stray files in the folder
#   would make the '_'-based name parsing fail with an IndexError.
# - The names are sorted because directory listings come back in arbitrary order;
#   sorting keeps plot order and tie-breaking reproducible between runs.
experiments = sorted(entry.name for entry in strat_exp_path.iterdir() if entry.is_dir())


In [5]:
# Colour lookup keyed by generator index given as a string ("0" .. "10"):
# the ten 'tab10' colours followed by a neutral grey as the eleventh entry.
cmap = plt.get_cmap('tab10')
colors = list(map(cmap, range(10)))
colors.append((0.5, 0.5, 0.5, 1.0))

color_dict_by_generator = {str(index): rgba for index, rgba in enumerate(colors)}


def extract_info_from_experiment_name(exp: str) -> dict:
    """
    Pick the metadata fields out of an underscore-separated experiment name.

    The name is split on '_' and the fields are read from fixed positions:
    'dataset' (3), 'n_gen' (5), 'used_gen' (8), 'n_real' (third from the end)
    and 'n_fake' (last). All values are returned as strings, in that key order
    (callers unpack ``.values()`` positionally).

    If any extracted field is an empty string, "ALARM" and the experiment name
    are printed; the dict is returned regardless. Names with too few parts
    raise ``IndexError``.
    """
    parts = exp.split('_')
    field_positions = (
        ('dataset', 3),
        ('n_gen', 5),
        ('used_gen', 8),
        ('n_real', -3),
        ('n_fake', -1),
    )
    info = {field: parts[position] for field, position in field_positions}

    # An empty field means two consecutive underscores landed on a parsed position.
    if '' in info.values():
        print("ALARM")
        print(exp)

    return info

def sort_dict_based_on_n_real_images(d: dict, reverse=True) -> dict:
    """
    Return a new dict with the entries of `d` ordered by the number of real images.

    The number is the integer that follows 'images_real_' in each key. With
    ``reverse=True`` (default) the largest count comes first. Keys with equal
    counts keep their relative order. A key without an 'images_real_<digits>'
    part raises ``AttributeError``.
    """
    def n_real_images(key):
        return int(re.search(r'images_real_(\d+)', key).group(1))

    ordered_keys = sorted(d, key=n_real_images, reverse=reverse)
    return {key: d[key] for key in ordered_keys}


def plot_history_strat_classifiers(histories: dict, meta_info: dict, save_path: Path, target_gen, METRIC, show: bool = False) -> None:
    """
    Plot the per-epoch history of `METRIC` for every run in `histories`.

    NOTE: a function with the same name is defined again further down in this
    notebook; when both cells are run, that later definition is the one in effect.

    Every run except ``BASE_EXPERIMENT`` is drawn as a faint line. On top of
    that the runs with the smallest and largest final value, the per-epoch
    mean and median over all runs, and the baseline run are highlighted.

    Args:
        histories: experiment name -> history; each history maps metric names
            to a per-epoch sequence of numbers.
        meta_info: experiment name -> dict whose values are, in order,
            dataset, n_gen, used_gen, n_real, n_fake.
        save_path: file the figure is written to when `show` is False.
        target_gen: value shown as "N-Gen" in the title.
        METRIC: history key to plot, e.g. "val_accuracy" or "val_loss".
        show: display the figure instead of saving it.
    """
    runs = {exp: hist[METRIC] for exp, hist in histories.items() if exp != BASE_EXPERIMENT}
    if not runs:
        # Nothing but (at most) the baseline: min/max/mean are undefined.
        print(f"No runs to plot for metric '{METRIC}'.")
        return

    for exp, values in runs.items():
        dataset, n_gen, used_gen, n_real, n_fake = meta_info[exp].values()

        # Black -> magenta gradient over the generator index; a single-generator
        # setup has no range to spread over, so it gets the start of the gradient.
        span = int(n_gen) - 1
        shade = float(np.clip((int(used_gen) - 1) / span, 0, 1)) if span != 0 else 0.0
        plt.plot(values, color=(shade, 0, shade, .25))

    # Extremes are chosen by the value of the last epoch; ties keep the first run.
    min_exp = min(runs, key=lambda exp: runs[exp][-1])
    max_exp = max(runs, key=lambda exp: runs[exp][-1])

    # Pad shorter histories with NaN so the per-epoch statistics ignore missing epochs.
    max_len = max(map(len, runs.values()))
    padded = [
        np.pad(np.asarray(values, dtype=float), (0, max_len - len(values)), constant_values=np.nan)
        for values in runs.values()
    ]
    avg = np.nanmean(padded, axis=0)
    med = np.nanmedian(padded, axis=0)

    # CUD color palette
    color_min = '#D55E00'  # Rust
    color_max = '#009E73'  # Teal
    color_avg = '#0072B2'  # Blue
    color_med = '#E69F00'  # Golden

    dataset, n_gen, used_gen, n_real, n_fake = meta_info[min_exp].values()
    plt.plot(runs[min_exp], color=color_min, linewidth=2, label=f"Minimum, gen: {used_gen}, N-real: {n_real}, N-fake: {n_fake}")

    dataset, n_gen, used_gen, n_real, n_fake = meta_info[max_exp].values()
    plt.plot(runs[max_exp], color=color_max, linewidth=2, label=f"Maximum, gen: {used_gen}, N-real: {n_real}, N-fake: {n_fake}")

    plt.plot(avg, color=color_avg, linewidth=2, label="Average")
    plt.plot(med, color=color_med, linewidth=2, linestyle='--', label="Median")

    if BASE_EXPERIMENT in histories:
        plt.plot(histories[BASE_EXPERIMENT][METRIC], color=(0, 0, 0, .5), linewidth=2.5, label='Baseline')
    else:
        print(f"Warning: Baseline experiment '{BASE_EXPERIMENT}' not found in histories.")

    # The title names the metric that is actually plotted.
    plt.title(f"{METRIC.replace('_', ' ').title()} - Dataset: {dataset}, N-Gen: {target_gen}")
    plt.legend()
    plt.tight_layout()

    if show:
        plt.show()
    else:
        plt.savefig(save_path)

    plt.close()

In [11]:
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

# Define BASE_EXPERIMENT somewhere accessible if not already defined globally
# Example: BASE_EXPERIMENT = 'your_baseline_experiment_key'
# Make sure BASE_EXPERIMENT exists as a key in the histories dictionary

def plot_history_strat_classifiers(histories: dict, meta_info: dict, save_path: Path, target_gen, METRIC, show: bool = False) -> None:
    """
    Plot the per-epoch history of one metric across multiple experiments.

    Every run except ``BASE_EXPERIMENT`` is drawn as a faint line whose blue
    component grows with the index of the generator that was used. On top of
    that the function highlights the best and the worst run (judged by the
    value of the last epoch), the per-epoch mean and median over all runs, and
    the baseline run. Metrics whose name contains "loss" are treated as
    lower-is-better, all other metrics as higher-is-better.

    Runs that lack the metric, have an empty metric sequence, or have no entry
    in `meta_info` are skipped with a warning. If no run remains, an error is
    printed and no figure is produced.

    Args:
        histories (dict): experiment name -> history; each history maps metric
            names to a per-epoch sequence of numbers (list or numpy array).
        meta_info (dict): experiment name -> metadata dict with the keys
            'dataset', 'n_gen', 'used_gen', 'n_real' and 'n_fake'.
        save_path (Path): file the figure is written to when `show` is False;
            missing parent directories are created.
        target_gen (str or int): number of generators, shown in the title.
        METRIC (str): history key to plot, e.g. "val_accuracy" or "val_loss".
        show (bool, optional): display the figure instead of saving it.
            Defaults to False.
    """
    lower_is_better = 'loss' in METRIC.lower()

    # --- Collect the runs that can actually be plotted -----------------------
    runs = {}
    for exp, hist in histories.items():
        if exp == BASE_EXPERIMENT:
            continue  # drawn separately; excluded from best/worst/mean/median
        if METRIC not in hist:
            print(f"Warning: Metric '{METRIC}' not found in history for experiment '{exp}'. Skipping.")
            continue
        # len() rather than truthiness: a numpy array has no unambiguous truth value.
        if len(hist[METRIC]) == 0:
            print(f"Warning: Metric list for '{METRIC}' is empty for experiment '{exp}'. Skipping.")
            continue
        if exp not in meta_info:
            print(f"Warning: Meta info not found for experiment '{exp}'. Skipping for plotting details.")
            continue
        runs[exp] = hist[METRIC]

    if not runs:
        print(f"Error: No valid data found for metric '{METRIC}' in the provided histories (excluding baseline). Cannot generate plot.")
        return

    # --- Per-epoch statistics --------------------------------------------------
    # Shorter runs are padded with NaN so they do not influence later epochs.
    try:
        max_len = max(len(values) for values in runs.values())
        padded = [
            np.pad(np.array(values, dtype=float), (0, max_len - len(values)), mode='constant', constant_values=np.nan)
            for values in runs.values()
        ]
        avg = np.nanmean(padded, axis=0)
        med = np.nanmedian(padded, axis=0)
    except ValueError as e:
        print(f"Error during padding or aggregation: {e}. Check data types in metric lists.")
        return

    fig, ax = plt.subplots(figsize=(10, 6))

    # --- Individual runs, tracking best/worst by final value -----------------
    best_key = worst_key = None
    best_val = float('inf') if lower_is_better else float('-inf')
    worst_val = float('-inf') if lower_is_better else float('inf')

    for exp, values in runs.items():
        meta = meta_info[exp]
        try:
            n_gen = int(meta['n_gen'])
            used_gen = int(meta['used_gen'])
            level = np.clip((used_gen + 1) / n_gen, 0, 1) if n_gen > 1 else 0.5
            run_color = (0.1, 0.2, level * 0.8 + 0.2, 0.25)
        except (KeyError, TypeError, ValueError):
            print(f"Warning: Could not derive a colour from the meta info of experiment '{exp}'. Using default grey.")
            run_color = (0.5, 0.5, 0.5, 0.25)
        ax.plot(values, color=run_color)

        final_val = values[-1]
        if lower_is_better:
            improves, worsens = final_val < best_val, final_val > worst_val
        else:
            improves, worsens = final_val > best_val, final_val < worst_val
        if improves:
            best_val, best_key = final_val, exp
        if worsens:
            worst_val, worst_key = final_val, exp

    def describe(run_key):
        # Read the metadata by key; unpacking the dict itself would yield the key names.
        meta = meta_info.get(run_key, {})
        return (
            f"gen: {meta.get('used_gen', 'N/A')}, "
            f"N-real: {meta.get('n_real', 'N/A')}, "
            f"N-fake: {meta.get('n_fake', 'N/A')}"
        )

    # CUD color palette
    color_worst = '#D55E00'  # Rust (worst performance)
    color_best = '#009E73'   # Teal (best performance)
    color_avg = '#0072B2'    # Blue (average)
    color_med = '#E69F00'    # Golden (median)
    color_base = '#000000'   # Black for baseline

    # Both stay None only if every final value is NaN (all comparisons are then False).
    if worst_key is not None:
        worst_label = f"Worst ({'Max' if lower_is_better else 'Min'} {METRIC}), {describe(worst_key)}"
        ax.plot(runs[worst_key], color=color_worst, linewidth=2, label=worst_label)
    else:
        print("Warning: Could not determine the worst run.")

    if best_key is not None:
        best_label = f"Best ({'Min' if lower_is_better else 'Max'} {METRIC}), {describe(best_key)}"
        ax.plot(runs[best_key], color=color_best, linewidth=2, label=best_label)
    else:
        print("Warning: Could not determine the best run.")

    ax.plot(avg, color=color_avg, linewidth=2, label="Average")
    ax.plot(med, color=color_med, linewidth=2, linestyle='--', label="Median")

    # --- Baseline ----------------------------------------------------------------
    if BASE_EXPERIMENT not in histories:
        print(f"Warning: Baseline experiment '{BASE_EXPERIMENT}' not found in histories.")
    else:
        baseline = histories[BASE_EXPERIMENT]
        if METRIC in baseline and len(baseline[METRIC]) > 0:
            ax.plot(baseline[METRIC], color=color_base, linewidth=2.5, linestyle=':', label='Baseline')
        else:
            print(f"Warning: Metric '{METRIC}' not found or empty in baseline history '{BASE_EXPERIMENT}'.")

    # --- Decoration --------------------------------------------------------------
    first_meta = meta_info[next(iter(runs))]
    plot_dataset_name = first_meta.get('dataset', "Unknown Dataset") if isinstance(first_meta, dict) else "Unknown Dataset"
    metric_title = METRIC.replace('_', ' ').title()

    ax.set_title(f"{metric_title} - Dataset: {plot_dataset_name}, N-Gen: {target_gen}")
    ax.set_xlabel("Epoch")
    ax.set_ylabel(metric_title)
    ax.legend(fontsize='small')
    ax.grid(True, linestyle='--', alpha=0.6)
    fig.tight_layout()

    if show:
        print("Displaying plot...")
        plt.show()
    else:
        # Best effort on purpose: one unwritable file must not abort a batch of plots.
        try:
            save_path.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(save_path)
            print(f"Plot saved to {save_path}")
        except Exception as e:
            print(f"Error saving plot to {save_path}: {e}")

    plt.close(fig)  # free the figure's memory when many plots are produced in a loop

### NOTE:
Only experiments with 3, 5, 7, and 10 generators exist.

In [12]:
# Render one figure per (metric, generator count) from the stored training histories.

experiments_by_used_gen = {}

# Generator counts for which experiments exist.
TARGET_GENERATOR_COUNTS = [3, 5, 7, 10]

# Figure directory inside the thesis sources, derived from `base_path` so the
# project location is configured in one place only.
PLOT_DIR = base_path.parent / "latex" / "master_thesis" / "abb" / "strat_classifier_performance"

# Histories that were already read from disk, keyed by experiment name.
history_cache = {}


def load_training_history(exp_name: str):
    """Return the object stored in `<strat_exp_path>/<exp_name>/training_history.npy`.

    Each file is read at most once; repeated calls return the cached object.
    """
    if exp_name not in history_cache:
        # NOTE(security): allow_pickle=True unpickles the file content, which can
        # execute arbitrary code. Only load histories written by trusted runs.
        history_cache[exp_name] = np.load(
            strat_exp_path / exp_name / 'training_history.npy', allow_pickle=True
        ).item()
    return history_cache[exp_name]


# Parse every experiment name once instead of once per metric and generator count.
experiment_meta = {exp: extract_info_from_experiment_name(exp) for exp in experiments}
baseline_meta = extract_info_from_experiment_name(BASE_EXPERIMENT)

for METRIC in metrics:
    print(f"CURRENT METRIC: {METRIC}")
    for target_gen in TARGET_GENERATOR_COUNTS:
        print(f"CURRENT GENERATOR: {target_gen}")

        selected = [exp for exp, meta in experiment_meta.items() if meta['n_gen'] == str(target_gen)]
        if not selected:
            # Without any run there is nothing to compare the baseline against.
            print(f"No experiments with {target_gen} generators found. Skipping.")
            continue

        histories = sort_dict_based_on_n_real_images({exp: load_training_history(exp) for exp in selected})
        meta_info = sort_dict_based_on_n_real_images({exp: experiment_meta[exp] for exp in selected})

        # The baseline is appended after sorting, so it is always the last entry.
        histories[BASE_EXPERIMENT] = load_training_history(BASE_EXPERIMENT)
        meta_info[BASE_EXPERIMENT] = baseline_meta

        plot_history_strat_classifiers(
            histories,
            meta_info,
            PLOT_DIR / f"{METRIC}_{TYPE}_{DATASET}_n_gen_{target_gen}_all.png",
            target_gen,
            METRIC
        )

# The following cell inspects `history`; keep that name bound to the baseline history.
history = load_training_history(BASE_EXPERIMENT)
        

    

CURRENT METRIC: val_accuracy
CURRENT GENERATOR: 3
Plot saved to C:\Users\NiXoN\Desktop\_thesis\mad_gan_thesis\latex\master_thesis\abb\strat_classifier_performance\val_accuracy_MADGAN_FASHION_n_gen_3_all.png
CURRENT GENERATOR: 5
Plot saved to C:\Users\NiXoN\Desktop\_thesis\mad_gan_thesis\latex\master_thesis\abb\strat_classifier_performance\val_accuracy_MADGAN_FASHION_n_gen_5_all.png
CURRENT GENERATOR: 7
Plot saved to C:\Users\NiXoN\Desktop\_thesis\mad_gan_thesis\latex\master_thesis\abb\strat_classifier_performance\val_accuracy_MADGAN_FASHION_n_gen_7_all.png
CURRENT GENERATOR: 10
Plot saved to C:\Users\NiXoN\Desktop\_thesis\mad_gan_thesis\latex\master_thesis\abb\strat_classifier_performance\val_accuracy_MADGAN_FASHION_n_gen_10_all.png
CURRENT METRIC: val_f1_score
CURRENT GENERATOR: 3
Plot saved to C:\Users\NiXoN\Desktop\_thesis\mad_gan_thesis\latex\master_thesis\abb\strat_classifier_performance\val_f1_score_MADGAN_FASHION_n_gen_3_all.png
CURRENT GENERATOR: 5
Plot saved to C:\Users\NiXoN\

In [None]:
# Show which metric names a training history contains. `history` is not defined
# in this cell: it is the variable left bound by the loading cell above, which
# last assigns it the history of BASE_EXPERIMENT.
history.keys()
