In [5]:
import numpy as np 
from matplotlib import pyplot as plt 

# Figure styling: render every piece of text through LaTeX in a serif face
# (Computer Modern Roman, the LaTeX default) and fix the font sizes used for
# labels, legend and ticks.
_latex_figure_style = {
    # --- text rendering ---
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Computer Modern Roman"],
    # --- font sizes ---
    "font.size": 12,
    "axes.labelsize": 12,
    "legend.fontsize": 10,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
}
plt.rcParams.update(_latex_figure_style)

import os 
from pathlib import Path 

import re 


In [11]:
# --- Experiment selection ---------------------------------------------------
# FOLDER names the directory that holds the stratified-classifier runs.
# TYPE and DATASET are used further down to build the output file names, so
# they have to be switched by hand together with FOLDER.
FOLDER = "FASHION_MNIST_STRATIFIED_CLASSIFIERS_MADGAN"
# FOLDER = "MNIST_STRATIFIED_CLASSIFIERS_MADGAN"

TYPE = "MADGAN"
DATASET = "FASHION"

# NOTE(review): machine-specific absolute path; adjust when running elsewhere.
base_path = Path("C:/Users/NiXoN/Desktop/_thesis/mad_gan_thesis/notebooks")
strat_exp_path = base_path / 'experiments' / FOLDER

# One sub-directory per run. Sort the names so the order is reproducible
# (a raw directory listing comes back in arbitrary order) and skip stray
# files: every entry is later opened as a directory that must contain
# 'training_history.npy'.
experiments = sorted(entry.name for entry in strat_exp_path.iterdir() if entry.is_dir())

experiments[0]

'2025-02-28_Stratified_classifierExperiment_FASHIONMNIST__10_used_generator_0__images_real_0_gen_5000'

In [12]:
# Eleven colours in total: the ten distinct 'tab10' entries followed by an
# opaque mid-gray as the eleventh.
cmap = plt.get_cmap('tab10')
colors = [cmap(idx) for idx in range(10)]
colors.append((0.5, 0.5, 0.5, 1.0))

# Lookup from the index written as a string ("0" .. "10") to its RGBA colour.
color_dict_by_generator = {str(idx): rgba for idx, rgba in enumerate(colors)}


def extract_info_from_experiment_name(exp: str) -> dict:
    """Pull the run parameters out of an experiment directory name.

    The name is split on ``'_'`` and the fields are read from fixed token
    positions, e.g. for
    ``'2025-02-28_Stratified_classifierExperiment_FASHIONMNIST__10_used_generator_0__images_real_0_gen_5000'``
    the result is ``dataset='FASHIONMNIST'``, ``n_gen='10'``, ``used_gen='0'``,
    ``n_real='0'``, ``n_fake='5000'``.

    Args:
        exp: Experiment (directory) name.

    Returns:
        Dict with the string-valued keys ``'dataset'``, ``'n_gen'``,
        ``'used_gen'``, ``'n_real'`` and ``'n_fake'``, in that order.

    Raises:
        IndexError: If the name has too few ``'_'``-separated tokens.
    """
    tokens = exp.split('_')

    # (field name, token position); negative positions count from the end.
    field_positions = (
        ('dataset', 3),
        ('n_gen', 5),
        ('used_gen', 8),
        ('n_real', -3),
        ('n_fake', -1),
    )
    info = {field: tokens[position] for field, position in field_positions}

    # An empty token means the name does not follow the expected layout;
    # report it but still return what was parsed.
    if '' in info.values():
        print("ALARM")
        print(exp)

    return info

def sort_dict_based_on_n_real_images(d: dict, reverse=True) -> dict:
    """Return a copy of ``d`` ordered by the real-image count in each key.

    The count is the integer following ``images_real_`` in the key.

    Args:
        d: Dict keyed by experiment names containing ``images_real_<int>``.
        reverse: If True (default) the largest count comes first.

    Returns:
        A new dict with the same items in sorted order; items with equal
        counts keep their relative order.

    Raises:
        AttributeError: If a key does not contain ``images_real_<int>``.
    """
    pattern = r'images_real_(\d+)'

    def n_real_of(item):
        name = item[0]
        return int(re.search(pattern, name).group(1))

    ordered_items = sorted(d.items(), key=n_real_of, reverse=reverse)
    return dict(ordered_items)


def plot_history_strat_classifiers(histories: dict, meta_info: dict, save_path: Path, show: bool = False) -> None:
    """Plot the validation-accuracy curves of a group of classifier runs.

    Every run is drawn as a faint line. On top of that, the runs with the
    lowest and highest *final* validation accuracy are highlighted, together
    with the per-step mean and median over all runs. Shorter runs are padded
    with NaN so they only contribute to the steps they actually contain.

    Args:
        histories: Maps experiment name -> history dict; only the
            ``'val_accuracy'`` sequence of each history is read.
        meta_info: Maps the same experiment names -> metadata dict with the
            keys ``'dataset'``, ``'n_gen'``, ``'used_gen'``, ``'n_real'`` and
            ``'n_fake'``.
        save_path: Image file the figure is written to when ``show`` is
            False. Missing parent directories are created.
        show: If True, display the figure instead of saving it.

    Raises:
        ValueError: If ``histories`` is empty.
    """
    if not histories:
        raise ValueError("histories is empty - nothing to plot")

    # Draw on a dedicated figure so nothing leaks into (or from) whatever
    # figure happens to be current, and always close it, even on error.
    fig, ax = plt.subplots()
    try:
        curves = []
        # +/-inf sentinels: any real final accuracy replaces them, regardless
        # of whether accuracies are stored as fractions or percentages.
        worst_exp, worst_final = None, np.inf
        best_exp, best_final = None, -np.inf

        for exp, hist in histories.items():
            curve = np.asarray(hist['val_accuracy'], dtype=float)
            ax.plot(curve, color=(0, 0, 0, 0.1))  # Light grey for individual runs
            curves.append(curve)

            if curve.size == 0:
                continue  # an empty run has no final value to rank by

            final = curve[-1]
            # Strict comparisons: on ties the first run encountered wins.
            if final > best_final:
                best_exp, best_final = exp, final
            if final < worst_final:
                worst_exp, worst_final = exp, final

        # Pad shorter histories with NaN so the nan-aware statistics ignore
        # the missing steps.
        max_len = max(len(curve) for curve in curves)
        padded = np.full((len(curves), max_len), np.nan)
        for row, curve in zip(padded, curves):
            row[:len(curve)] = curve

        avg = np.nanmean(padded, axis=0)
        med = np.nanmedian(padded, axis=0)

        # CUD color palette
        color_min = '#D55E00'  # Rust (worst)
        color_max = '#009E73'  # Teal (best)
        color_avg = '#0072B2'  # Blue (average)
        color_med = '#E69F00'  # Golden (median)

        def highlight(exp, prefix, color):
            # Emphasise one run and describe it in the legend by its metadata.
            meta = meta_info[exp]
            ax.plot(
                histories[exp]['val_accuracy'],
                color=color,
                linewidth=2,
                label=f"{prefix}, gen: {meta['used_gen']}, N-real: {meta['n_real']}, N-fake: {meta['n_fake']}",
            )

        if worst_exp is not None:
            highlight(worst_exp, "Minimum", color_min)
        if best_exp is not None:
            highlight(best_exp, "Maximum", color_max)

        ax.plot(avg, color=color_avg, linewidth=2, label="Average")
        ax.plot(med, color=color_med, linewidth=2, linestyle='--', label="Median")

        # The title fields come from the best run (or the first run if none
        # could be ranked); callers are expected to pass runs that share them.
        title_exp = best_exp if best_exp is not None else next(iter(histories))
        title_meta = meta_info[title_exp]
        ax.set_title(f"Validation Accuracy - Dataset: {title_meta['dataset']}, N-Gen: {title_meta['n_gen']}")
        # NOTE(review): assumes one 'val_accuracy' entry per epoch - confirm.
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Validation accuracy")
        ax.legend()
        fig.tight_layout()

        if show:
            plt.show()
        else:
            save_path = Path(save_path)
            save_path.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(save_path)
    finally:
        plt.close(fig)

### NOTE:
Experiments exist only for 3, 5, 7, and 10 generators.

In [14]:
# Load the training histories and write one summary plot per generator count.

experiments_by_used_gen = {}

# Accumulated over all generator counts; still available after the loop.
histories = {}
meta_info = {}

# Output directory for the thesis figures, located relative to base_path
# (sibling 'latex' tree of the notebooks directory).
plot_dir = base_path.parent / 'latex' / 'master_thesis' / 'abb' / 'strat_classifier_performance'

for target_gen in [3, 5, 7, 10]:

    print(f"CURRENT GENERATOR: {target_gen}")

    # Fresh containers per generator count. Reusing the accumulated dicts
    # here would mix runs from the previously processed generator counts into
    # this plot.
    histories_for_gen = {}
    meta_info_for_gen = {}

    for exp in experiments:

        meta = extract_info_from_experiment_name(exp)

        if meta['n_gen'] != str(target_gen):
            continue

        # NOTE: allow_pickle=True unpickles arbitrary objects; only load
        # history files that come from a trusted source.
        history = np.load(strat_exp_path / exp / 'training_history.npy', allow_pickle=True).item()
        histories_for_gen[exp] = history
        meta_info_for_gen[exp] = meta

    histories_for_gen = sort_dict_based_on_n_real_images(histories_for_gen)
    meta_info_for_gen = sort_dict_based_on_n_real_images(meta_info_for_gen)

    if histories_for_gen:
        plot_history_strat_classifiers(
            histories_for_gen,
            meta_info_for_gen,
            plot_dir / f"{TYPE}_{DATASET}_n_gen_{target_gen}_all.png",
        )

    histories.update(histories_for_gen)
    meta_info.update(meta_info_for_gen)

# Leave the combined dicts sorted by real-image count for any later use.
histories = sort_dict_based_on_n_real_images(histories)
meta_info = sort_dict_based_on_n_real_images(meta_info)

    

CURRENT GENERATOR: 3
CURRENT GENERATOR: 5
CURRENT GENERATOR: 7
CURRENT GENERATOR: 10
