# Analytical Calculation

In [None]:
import math
import importlib
import os
import random
import pickle
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import pearsonr, spearmanr
from statsmodels.stats import multitest
import statsmodels.api as sm
import pandas as pd
from matplotlib.lines import Line2D
from matplotlib.colors import ListedColormap
from matplotlib.ticker import MultipleLocator, FuncFormatter
from collections import defaultdict
from scipy.stats import linregress
import torch
from AllFnc import eegvislib

import warnings
warnings.filterwarnings("ignore", category = FutureWarning)

# PATHS
# Every directory string ends with a separator so that file names can be
# appended to it by plain string concatenation in the cells below.
sep = os.path.sep
dataPath = '/data/delpup/datasets/eegpickle/'
osPath = os.path.abspath(os.getcwd())
imgPath = f'{osPath}{sep}imgs{sep}'
modelsPath = f'{osPath}{sep}AlzClassification{sep}Models{sep}'
resultsPath = f'{osPath}{sep}AlzClassification{sep}Results{sep}'
numericPath = f'{osPath}{sep}AnalyticalCalculation{sep}'

# FIGURE EXPORT OPTIONS
imgs_format, save_img = '.pdf', True

# CROSS-VALIDATION LAYOUT AND CLASS NAMES
outerFolds, innerFolds = 10, 5
classlabels = ['CTL', 'FTD', 'AD']

## Analysis First Layer Initialization

In [None]:
def weight_change_load(path, split, mod, model, metric):
    """
    Read the weight change and one score of a single split from disk.

    Two pickle files are opened. Their names are built from `path`, `model`,
    the two fold numbers in `split` and `mod`; the first one carries the
    ``dw`` tag and the second one the ``scores`` tag.

    Parameters:
    -----------
    - path (str): Prefix of the file names. It is concatenated as-is, so a
      directory must end with a path separator.
    - split (tuple): Pair (outerFold, innerFold) identifying the split.
    - mod (str): Modifier string that closes the file name.
    - model (str): Model identifier that opens the file name.
    - metric (str): Key of the score to read from the scores file.

    Returns:
    -----------
    - dw (numpy.ndarray): Flattened weight change multiplied by 100.
    - bacc: Value stored under `metric`, multiplied by 100.
    """
    outer_idx, inner_idx = split

    def _pickle_name(tag):
        # Shared naming scheme of the two files; only the tag differs.
        return f'{path}{model}_{tag}_{outer_idx}_{inner_idx}_{mod}.pkl'

    # Weight change: the stored object must expose `.numpy()`.
    with open(_pickle_name('dw'), 'rb') as handle:
        stored_change = pickle.load(handle)
    dw = stored_change.numpy().ravel() * 100

    # Scores: a mapping from metric name to value.
    with open(_pickle_name('scores'), 'rb') as handle:
        all_scores = pickle.load(handle)
    bacc = all_scores[metric] * 100

    return dw, bacc

def weight_change_plot(path, mod, model, splits, spec_dict):
    """
    Plot the histogram of weight changes across specified splits.

    One semi-transparent bar histogram is drawn per split. Bar heights are
    the percentage of that split's weights falling in each bin, after the
    values have been clipped to [-clip, clip].

    Parameters:
    -----------
    - path (str): Folder path forwarded to weight_change_load.
    - mod (str): Modifier string to identify the file.
    - model (str): Model identifier.
    - splits (list): List of split pairs [(outerFold1, innerFold1), (outerFold2, innerFold2), ...].
    - spec_dict (dict): Plot configuration. Keys read here: 'figdim',
      'title', 'font', 'metric', 'clip', 'bins', 'alpha',
      'ticks' ([x tick step, y tick step]) and 'loc'.

    Returns:
    --------
    - fig (matplotlib.figure.Figure): The generated plot figure.
    """
    clip = spec_dict['clip']
    x_step, y_step = spec_dict['ticks'][0], spec_dict['ticks'][1]

    fig, ax = plt.subplots(figsize=(spec_dict['figdim'][0], spec_dict['figdim'][1]))
    ax.set_title(spec_dict['title'], fontsize=spec_dict['font'])

    # Highest bar over ALL splits. The y ticks must reach the tallest
    # histogram, not just the one drawn last, and the value has to exist
    # even when `splits` is empty.
    tallest_bar = 0.0

    for split in splits:
        # Only the weight change is needed here; the score is discarded.
        dw_random, _ = weight_change_load(path, split, mod, model, spec_dict['metric'])

        # Clip weight change values to specified limits
        dw_random = np.clip(dw_random, -clip, clip)

        # Compute histogram counts and relative frequencies
        counts, bin_edges = np.histogram(dw_random, bins=spec_dict['bins'])
        relative_counts = counts * 100 / len(dw_random)  # Convert to percentage
        tallest_bar = max(tallest_bar, np.max(relative_counts))

        # Plot histogram as a bar chart with relative counts
        ax.bar(bin_edges[:-1], relative_counts, width=np.diff(bin_edges), alpha=spec_dict['alpha'],
               label=f'Split {split[0]}-{split[1]}', align='edge')

    # Set axis labels and customize ticks
    ax.set_ylabel('Relative Count %', fontsize=spec_dict['font'] - 2)
    ax.set_xlabel(r'$\Delta_{w\%} = (w - w_{init})/|w|$', fontsize=spec_dict['font'] - 2)
    ax.set_xlim(-clip, clip)
    ax.set_xticks(np.arange(-clip, clip + x_step, x_step))
    ax.set_yticks(np.arange(0, tallest_bar + y_step, y_step))
    ax.tick_params(axis='both', labelsize=spec_dict['font'] - 4)
    ax.legend(loc=spec_dict['loc'], fontsize=spec_dict['font'] - 6)

    return fig

In [None]:
# Settings for the first-layer weight-change histogram
mod, model = 'random', 'shn0'
splits = [[9, 2], [6, 5]]
spec_dict = dict(
    font=16,
    clip=50,
    bins=40,
    alpha=0.5,
    ticks=[10, 2],
    title='Weights Change in the First Layer ShallowNet',
    figdim=[7, 5],
    loc='upper right',
    metric='accuracy_weighted',
    filename='weightChangeFirstLayer',
)

# Draw the figure from the files stored in the 'Initialization' folder
init_folder = numericPath + 'Initialization' + sep
fig = weight_change_plot(init_folder, mod, model, splits, spec_dict)

# Export only when image saving is switched on
if save_img:
    target_file = imgPath + spec_dict['filename'] + imgs_format
    fig.savefig(target_file, transparent=False, bbox_inches='tight')

## Analysis Topographies Correlation

In [None]:
def pvals_calculation(scalps, indexs, spec_dict):
    """
    Correlate every filter of one scalp map with every filter of another and
    correct the resulting p-values for multiple comparisons.

    Entry [i, j] of each returned matrix refers to filter i of
    `scalps[indexs[0]]` against filter j of `scalps[indexs[1]]`.

    When both indices are equal the comparison is symmetric, so each
    unordered pair is tested once (upper triangle, diagonal included) and the
    result is mirrored. When the indices differ, [i, j] and [j, i] are two
    different tests, so all filters x filters p-values enter the correction
    and the returned matrices are in general not symmetric.

    Parameters:
    ----------
    - scalps (list of np.ndarray): List of 4D arrays representing the scalp maps.
    Each array is indexed as [filter, 0, :, 0], i.e. it is expected to have
    shape (filters, kernels=1, height=channels, width=1).
    - indexs (list of int): Indices for two specific scalp maps in `scalps` to compare.
    - spec_dict (dict): Configuration. Keys read here: 'pval' (significance
    level) and 'method' (correction method passed to `multipletests`).

    Returns:
    -------
    - corrected_pvals (np.ndarray): Matrix of corrected p-values for the pairwise filter comparisons.
    - significance_matrix (np.ndarray): Boolean matrix indicating where p-values are significant after correction.
    - stats (np.ndarray): Matrix of Pearson coefficients for the pairwise filter comparisons.
    """

    first_map = scalps[indexs[0]]
    second_map = scalps[indexs[1]]

    filters = scalps[0].shape[0]  # Number of filters in each scalp map
    pvals = np.zeros((filters, filters))  # Matrix to store p-values
    stats = np.zeros((filters, filters))  # Matrix to store correlation statistics

    # Calculate pairwise Pearson correlations between filters in specified scalp maps
    for i in range(filters):
        for j in range(filters):
            res = pearsonr(first_map[i, 0, :, 0], second_map[j, 0, :, 0])
            pvals[i, j] = res.pvalue
            stats[i, j] = res.statistic

    if indexs[0] == indexs[1]:
        # Same map on both axes: pvals is symmetric, so correcting the upper
        # triangle avoids counting every pair twice.
        upper_tri_indices = np.triu_indices(filters)
        corrected_res = multitest.multipletests(pvals[upper_tri_indices],
                                                alpha=spec_dict['pval'],
                                                method=spec_dict['method'])

        corrected_pvals = np.zeros((filters, filters))
        significance_matrix = np.zeros((filters, filters), dtype=bool)
        corrected_pvals[upper_tri_indices] = corrected_res[1]
        significance_matrix[upper_tri_indices] = corrected_res[0]

        # Reflect upper triangle to lower triangle for symmetry
        corrected_pvals = corrected_pvals + corrected_pvals.T - np.diag(np.diag(corrected_pvals))
        significance_matrix = significance_matrix | significance_matrix.T
    else:
        # Two different maps: mirroring the upper triangle would discard the
        # lower-triangle tests and pair p-values with the wrong coefficients
        # in `stats`, so every entry is corrected on its own.
        corrected_res = multitest.multipletests(pvals.ravel(),
                                                alpha=spec_dict['pval'],
                                                method=spec_dict['method'])
        corrected_pvals = corrected_res[1].reshape(filters, filters)
        significance_matrix = corrected_res[0].reshape(filters, filters)

    return corrected_pvals, significance_matrix, stats


def pvals_plot(corrected_pvals, significance_matrix, stats, splits, indexs, spec_dict):
    """
    Plot the significance matrix with annotated corrected p-values for significant comparisons.

    Rows are labelled with the split at `indexs[0]` and columns with the
    split at `indexs[1]`. Each significant cell is annotated with its
    corrected p-value, its Pearson coefficient and the adjusted R^2 returned
    by `eegvislib.adjusted_R2`.

    Parameters:
    ----------
    - corrected_pvals (np.ndarray): Matrix of corrected p-values for filter comparisons.
    - significance_matrix (np.ndarray): Boolean matrix indicating where p-values are significant.
    - stats (np.ndarray): Matrix of Pearson coefficients for the pairwise filter comparisons.
    - splits (list of lists): List of split identifiers.
    Each identifier is a list of integers representing the split structure.
    - indexs (list of int): Indices for two specific scalp maps from `splits` to label in the plot.
    - spec_dict (dict): Configuration. Keys read here: 'figdim', 'font',
    'xticks' (tick labels used on both axes), 'title' and 'C' (value passed
    to `adjusted_R2` and printed next to rho).

    Returns:
    -------
    fig (matplotlib.figure.Figure): The resulting figure with the significance matrix plot.
    """

    filters = np.shape(corrected_pvals)[0]  # Number of filters
    row_split, col_split = splits[indexs[0]], splits[indexs[1]]
    # NOTE(review): 'C' is printed as rho(C); the regression cells of this
    # notebook print r(n - 2) instead - confirm which convention is intended.
    n_obs = spec_dict['C']

    # Initialize plot
    fig, ax = plt.subplots(figsize=(spec_dict['figdim'][0], spec_dict['figdim'][1]))
    ax.imshow(significance_matrix, cmap='Greys')  # Display significance matrix
    ax.set_ylabel(f'Split {row_split[0]}-{row_split[1]}', fontsize=spec_dict['font']-2)
    ax.set_xlabel(f'Split {col_split[0]}-{col_split[1]}', fontsize=spec_dict['font']-2)
    ax.set_xticks(np.arange(0, filters, 1))
    ax.set_xticklabels(spec_dict['xticks'], fontsize=spec_dict['font']-4)
    ax.set_yticks(np.arange(0, filters, 1))
    ax.set_yticklabels(spec_dict['xticks'], fontsize=spec_dict['font']-4)
    ax.set_title(spec_dict['title'], fontsize=spec_dict['font'])

    # Annotate significant p-values on the plot
    for i in range(filters):
        for j in range(filters):
            if not significance_matrix[i, j]:
                continue  # Only annotate cells with significant p-values
            adj_R2 = eegvislib.adjusted_R2(stats[i, j], n_obs, 1)
            # `n_obs` is read outside the f-string: reusing the enclosing
            # quote inside a replacement field only parses on Python >= 3.12.
            label = (f'p$={corrected_pvals[i, j]:.1e}$ \n '
                     f'$\\rho({n_obs})={stats[i, j]:.2f}$ \n '
                     f'adj.$R^2={adj_R2:.2f}$')
            ax.text(j, i, label,
                    ha='center', va='center', color='white', fontsize=spec_dict['font']-8)

    return fig

In [None]:
# Splits whose spatial filters are compared, and the matching model files
splits = [[9, 2], [6, 5]]
modelsToimport = ['alz_flt_125_shn7db_009_002_000050_019_004','alz_flt_125_shn7db_006_005_000050_019_004']

spec_dict = {'pval': 0.05,
             'method': 'holm',
             'xticks': ['$\\delta$', '$\\theta$', '$\\alpha$', '$\\beta_1$',
                       '$\\beta_2$', '$\\beta_3$', '$\\gamma$'],
             'C': 19,
             'title': 'Significance Correlation (Holm corrected)',
             'font': 16,
             'figdim': [7, 7]}

# Collect the 'encoder.conv2.weight' tensor of every model as a NumPy array
scalps = []
for modelToimport in modelsToimport:
    # map_location='cpu': `.numpy()` only works on CPU tensors, so the
    # weights are loaded on the CPU whatever device they were saved from.
    shnm = torch.load(modelsPath + modelToimport + '.pt', map_location='cpu')
    scalps.append(shnm['encoder.conv2.weight'].numpy())

# Cross-split comparison: filters of the first model vs filters of the second
corrected_pvals, significance_matrix, stats = pvals_calculation(scalps, [0,1], spec_dict)
fig = pvals_plot(corrected_pvals, significance_matrix, stats, splits, [0,1], spec_dict)

In [None]:
# Within-model comparison: filters of the first imported model against each other
first_vs_first = [0, 0]
corrected_pvals, significance_matrix, stats = pvals_calculation(scalps, first_vs_first, spec_dict)
fig = pvals_plot(corrected_pvals, significance_matrix, stats,
                 splits, first_vs_first, spec_dict)

In [None]:
# Within-model comparison: filters of the second imported model against each other
second_vs_second = [1, 1]
corrected_pvals, significance_matrix, stats = pvals_calculation(scalps, second_vs_second, spec_dict)
fig = pvals_plot(corrected_pvals, significance_matrix, stats,
                 splits, second_vs_second, spec_dict)

## Analysis Overlap Embedding

In [None]:
# Gather, for every (outer, inner) split of one model, the six values stored
# in the overlap result file and the weighted accuracy of the trained model.
model = 'shn7db'
overlap_trainval  = []
overlap_traintest = []
overlap_valtest   = []
area_train = []
area_val = []
area_test = []
acc = []
for i in range(1, outerFolds+1):
    for j in range(1, innerFolds+1):
        # Three-digit zero padding; same names as before for the current
        # folds, and still well-formed if a fold count reaches two digits.
        file_suffix = f'{i:03d}_{j:03d}'
        filename = f'alz_flt_125_{model}_{file_suffix}_000050_019_004'
        with open(numericPath + 'Overlap' + sep + 'Results' + sep + filename + '.pickle', 'rb') as f:
            shn = pickle.load(f)
        # NOTE(review): the meaning of positions 0-5 is taken from the list
        # names below - confirm against the script that wrote these files.
        overlap_traintest.append(shn[0])
        overlap_trainval.append(shn[1])
        overlap_valtest.append(shn[2])
        area_train.append(shn[3])
        area_val.append(shn[4])
        area_test.append(shn[5])
        with open(resultsPath + filename + '.pickle', 'rb') as f:
            shnR = pickle.load(f)
        acc.append(shnR['accuracy_weighted'])

In [None]:
# Variables to analyze
variables = {
    "overlap_traintest": overlap_traintest,
    "overlap_trainval": overlap_trainval,
    "overlap_valtest": overlap_valtest,
    "area_train": area_train,
    "area_val": area_val,
    "area_test": area_test,
}

# Step 1: Calculate pairwise Spearman correlations and p-values
pairs = []
corrs = []
p_values = []

keys = list(variables.keys())
for i, key1 in enumerate(keys):
    for key2 in keys[i + 1:]:  # Each unordered pair once, no self-correlation
        corr, p_val = spearmanr(variables[key1], variables[key2])
        pairs.append((key1, key2))
        corrs.append(corr)
        p_values.append(p_val)

# Step 2: Apply Holm-Bonferroni correction
# multipletests returns (reject flags, corrected p-values, ...).
corrected_res = multitest.multipletests(p_values, method='holm')

# Step 3: Identify pairs with corrected p-values > 0.05
# The threshold must be applied to the corrected p-values (index 1).
# Index 0 holds the boolean reject flags, and `True > 0.05` would select
# exactly the significantly correlated pairs instead.
uncorrelated_pairs = [
    pair for pair, p_corrected in zip(pairs, corrected_res[1]) if p_corrected > 0.05
]

# Identify uncorrelated variables
# NOTE(review): a variable is kept as soon as it belongs to one uncorrelated
# pair, even if it is correlated with other variables - confirm this is the
# intended selection rule.
# The list follows the order of `variables` so that the regression columns
# are the same on every run (iterating a set of strings is not).
uncorrelated_vars = [
    key for key in keys if any(key in pair for pair in uncorrelated_pairs)
]

print("Uncorrelated variable pairs (Holm-Bonferroni corrected p > 0.05):")
for pair in uncorrelated_pairs:
    print(pair)

print("\nIndependent variables after correction:")
print(uncorrelated_vars)

# Step 4: Regression Analysis
if not uncorrelated_vars:
    raise ValueError("No uncorrelated variables left after correction; "
                     "nothing to regress the accuracy on.")

# Prepare feature matrix with uncorrelated variables (x1, x2, ... in the
# summary follow the order printed above)
X_features = np.column_stack([variables[var] for var in uncorrelated_vars])
X_features = sm.add_constant(X_features)
y = acc  # Target variable

# Fit the regression model
model = sm.OLS(y, X_features)
results = model.fit()

# Print the regression summary
print("\nRegression Results:")
print(results.summary())

In [None]:
import statsmodels.formula.api as smf

# Accuracy explained by the test-set area, the train-test overlap and their
# interaction. `acc` is not a column of the frame below, so the formula
# machinery has to find it among the notebook variables.
formula = "acc ~ area_test + overlap_traintest + area_test*overlap_traintest"
overlap_frame = pd.DataFrame(variables)

# Fit the ordinary least squares model described by the formula
model = smf.ols(formula=formula, data=overlap_frame).fit()

print("\nRegression Results:")
regression_table = model.summary()
print(regression_table)

## Analysis Model Variation

In [None]:
# Architectures compared in this section
models = [
    "shn0", "shn1", "shn2", "shn3", "shn4", "shn5", "shn67",
    "shn628", "shn663", "shn6119", "shn6126", "shn6127", "shn7", "shn7db",
]

# Models whose first layer is not frozen: its parameters are counted too
model_set = set(models[:4])

# Plot settings shared by the two "architecture variation" figures
spec_dict = dict(
    figname='modelVarAccvsWeight',
    font=18,
    linew=2,
    s=4,
    rotation=45,
    loc='upper center',
    title='Architecture Variation - Performance',
    accrandom=1 / len(classlabels),
    linestyle='--',
    markers=['X', 'o'],
    ylima=[20, 105],
    ylimw=[10, 1_000_000],
    jitter=0.02,
    figdim=[15, 5],
    color=['tab:orange', 'tab:blue'],
)

In [None]:
# State-dict entries of the batch normalisation layer that are never counted
bn_buffer_keys = ('encoder.batch1.running_mean',
                  'encoder.batch1.running_var',
                  'encoder.batch1.num_batches_tracked')
# First-layer entries, counted only for the models listed in `model_set`
first_layer_keys = ('encoder.conv1.weight', 'encoder.conv1.bias')

# Number of counted parameters per model, read from the split 1-1 checkpoint
total_weights_v = []
for model in models:
    filename = f'alz_flt_125_{model}_001_001_000050_019_004'
    shn = torch.load(modelsPath + filename + '.pt')

    count_first_layer = model in model_set
    total_weights = sum(
        tensor.numel()
        for key, tensor in shn.items()
        if key not in bn_buffer_keys
        and (count_first_layer or key not in first_layer_keys)
    )
    total_weights_v.append(total_weights)

# Order the models by parameter count, then exchange the two smallest ones.
# NOTE(review): the exchange is hard-coded - confirm it is still wanted if the
# model list changes.
sorted_pairs = sorted(zip(total_weights_v, models))
sorted_pairs[0], sorted_pairs[1] = sorted_pairs[1], sorted_pairs[0]

# Split the ordered pairs back into the two parallel lists
total_weights_v, models = (list(column) for column in zip(*sorted_pairs))

# Display names: a copy of the ordered identifiers with two entries renamed.
# NOTE(review): positions 0 and 9 are hard-coded and depend on the ordering
# obtained above - verify they still point at the intended models.
models1 = list(models)
models1[9] = 'ShallowNet'
models1[0] = 'Med-ShallowNet'

In [None]:
# Per model: weighted accuracy (in %) of every split, plus the training and
# validation loss curves of every split.
acc_dict = defaultdict(list)
acc_train = {}
acc_val = {}
for model in models:
    acct = []  # Training loss curves of this model, one per split
    accv = []  # Validation loss curves of this model, one per split
    for i in range(1, outerFolds+1):  # Outer fold
        for j in range(1, innerFolds+1):  # Inner fold
            # Three-digit zero padding; same names as before for the current
            # folds, and still well-formed if a fold count reaches two digits.
            file_suffix = f'{i:03d}_{j:03d}'
            filename = f'alz_flt_125_{model}_{file_suffix}_000050_019_004'

            # Each result file is read once and provides all three quantities
            with open(resultsPath + filename + '.pickle', 'rb') as f:
                shn = pickle.load(f)

            acc_dict[model].append(shn['accuracy_weighted'] * 100)
            acct.append(shn['training_loss_curve'])
            accv.append(shn['validation_loss_curve'])

    # Store the loss curves collected for this model
    acc_train[model] = acct
    acc_val[model] = accv
    print(f'{model} QCV: {eegvislib.quartile_coefficient_of_variation(acc_dict[model]):.4f}')

### Model Weights vs Accuracy

In [None]:
# Figure (A): accuracy distribution per architecture (left axis) against the
# number of counted parameters (right axis, log scale).
fig, ax1 = plt.subplots(1, 1,
                        figsize=(spec_dict['figdim'][0], spec_dict['figdim'][1]))
ax2 = ax1.twinx()

# Plot the weights vs models (on the second y-axis, ax2)
ax2.plot(range(1, len(models) + 1), total_weights_v, color=spec_dict['color'][0],
         marker=spec_dict['markers'][0], linestyle=spec_dict['linestyle'],
         linewidth=spec_dict['linew'], label='Weights')

# Plot boxplots for all models (on the first y-axis, ax1).
# The boxes are built by walking `models` explicitly, so they cannot get out
# of step with the scatter points, medians and tick labels below.
box = ax1.boxplot([acc_dict[model] for model in models], showfliers=False)

# Hide the median line of every box; the median is drawn separately below
for median in box['medians']:
    median.set_visible(False)

# Add scatter points (individual accuracy values) for each model.
# A dedicated name is used for the per-model values: the global `acc` holds
# the accuracies of the overlap analysis and must not be overwritten here.
for i, model in enumerate(models):
    model_acc = acc_dict[model]
    ax1.scatter(np.ones(len(model_acc)) * (i + 1) + spec_dict['jitter']*np.random.randn(len(model_acc)),
                model_acc, color='k', alpha=1, s=spec_dict['s'])

# Label the first y-axis (for accuracies)
ax1.set_ylabel("Median Weighted Accuracy %", color=spec_dict['color'][1], fontsize=spec_dict['font']-2)
ax1.set_ylim(spec_dict['ylima'][0], spec_dict['ylima'][1])
ax1.set_xlim(0.5, len(models) + 0.75)

# Plot the median value as text on the boxplot for each model
median_v = [np.median(acc_dict[model]) for model in models]
for i, median in enumerate(median_v):
    ax1.text(i + 1, spec_dict['ylima'][0]+2, f'{median:.1f}%', ha='center', va='bottom',
             color=spec_dict['color'][1], fontsize=spec_dict['font']-6, fontweight='bold')

ax1.plot(np.arange(1, 1 + len(models)), median_v, marker=spec_dict['markers'][1], linestyle=spec_dict['linestyle'],
         color=spec_dict['color'][1], linewidth=spec_dict['linew'])

# Label the second y-axis (for weights)
ax2.set_yscale('log')
ax2.set_ylabel("$N^{\\circ}$ Trainable Parameters", color=spec_dict['color'][0], fontsize=spec_dict['font']-2)
ax2.set_ylim(spec_dict['ylimw'][0], spec_dict['ylimw'][1])

# Add a horizontal reference line at the random-guess accuracy
random_guess = spec_dict['accrandom']
ax1.axhline(random_guess*100, linestyle=spec_dict['linestyle'], linewidth=spec_dict['linew'], color='red')

# Add a title to the plot
ax1.set_title(spec_dict['title'], fontsize=spec_dict['font'])

# Create legend. `random_guess` is read outside the f-string: reusing the
# enclosing quote inside a replacement field only parses on Python >= 3.12.
legend_elements = [Line2D([0], [0], marker='.', color='k', label='Single Split', linestyle='None'),
                   Line2D([0], [0], label=f'{random_guess:.0%} Random Guess',
                          linestyle=spec_dict['linestyle'], linewidth=spec_dict['linew'], color='red')]
ax1.legend(handles=legend_elements, loc=spec_dict['loc'], fontsize=spec_dict['font']-6)

# Set x-tick labels for models
ax1.set_xticklabels(models1, rotation=spec_dict['rotation'])
ax1.tick_params(axis='both', which='major', labelsize=spec_dict['font']-4)
ax2.tick_params(axis='both', which='major', labelsize=spec_dict['font']-4)

# Label the figure
ax1.text(len(models), spec_dict['ylima'][1]-10, '$(A)$', fontsize=spec_dict['font'])

# Save the image if needed
if save_img:
    fig.savefig(imgPath + spec_dict['figname'] + imgs_format,
                transparent=False, bbox_inches='tight')

In [None]:
# Linear fit of the median accuracies on the log of the parameter counts
log_parameters = np.log(total_weights_v)
slope, intercept, r_value, p_value, std_err = linregress(log_parameters, median_v)

# Adjusted R^2 for one predictor and as many observations as models
n_points = len(median_v)
adj_R2 = eegvislib.adjusted_R2(r_value, n_points, 1)

# Report r with its n - 2 degrees of freedom, the p-value and the adjusted R^2
print(f' r({n_points-2:.0f})={r_value:.2f}, p={p_value:.3f}, Adj. R^2={adj_R2:.2f}')

### Model Weights vs Number of Epochs

In [None]:
# Figure (B): training length per architecture (left axis) against the
# number of counted parameters (right axis, log scale).
spec_dict['figname'] = 'modelVarEpochsvsWeight'
spec_dict['title']   = "Architecture Variation - Training Length"
spec_dict['color']   = ['tab:orange','tab:green']
spec_dict['ylime']   = [-15, 200]
patience = 15

fig, ax1 = plt.subplots(1, 1,
                        figsize=(spec_dict['figdim'][0], spec_dict['figdim'][1]))

# Create a secondary y-axis for weights
ax2 = ax1.twinx()

# Plot weights on the secondary y-axis (ax2)
ax2.plot(range(1, len(models) + 1), total_weights_v, color=spec_dict['color'][0], marker=spec_dict['markers'][0],
         linestyle=spec_dict['linestyle'], linewidth=spec_dict['linew'], label='Weights')

# Precompute n_epoch for each model: length of every stored training loss
# curve minus `patience`, plus one. The curves are iterated directly, so the
# number of runs per model no longer has to equal outerFolds*innerFolds.
n_epoch = {
    model: [len(curve) - patience + 1 for curve in acc_train[model]]
    for model in models}

# Boxplot for n_epoch. No tick labels are passed here: they are set from
# `models1` further down, and the `labels` keyword is deprecated in recent
# matplotlib releases.
box = ax1.boxplot([n_epoch[model] for model in models], showfliers=False)

# Remove median lines from the boxplot
for median in box['medians']:
    median.set_visible(False)

# Scatter plot individual epoch values with slight random offsets on the x-axis
for i, model in enumerate(models):
    nepoch = n_epoch[model]
    ax1.scatter(np.ones(len(nepoch)) * (i + 1) + spec_dict['jitter'] * np.random.randn(len(nepoch)), nepoch,
                color='k', alpha=1, s=spec_dict['s'])

# Label the first y-axis (for epochs)
ax1.set_ylabel("Median $~N^{\\circ}$ Epochs", color=spec_dict['color'][1], fontsize=spec_dict['font']-2)
ax1.set_ylim(spec_dict['ylime'][0],spec_dict['ylime'][1])
ax1.set_xlim(0.5, len(models) + 0.75)

# Compute and plot the median values
median_v = [np.median(n_epoch[model]) for model in models]
for i, median in enumerate(median_v):
    ax1.text(i + 1, spec_dict['ylime'][0]+2, f'{median:.0f}', ha='center', va='bottom',
             color=spec_dict['color'][1], fontsize=spec_dict['font']-6, fontweight='bold')

ax1.plot(np.arange(1, 1 + len(models)), median_v, marker=spec_dict['markers'][1],
         linestyle=spec_dict['linestyle'], color=spec_dict['color'][1], linewidth=spec_dict['linew'])

# Label the second y-axis (for weights)
ax2.set_yscale('log')
ax2.set_ylabel("$N^{\\circ}$ Trainable Parameters", color=spec_dict['color'][0], fontsize=spec_dict['font']-2)
ax2.set_ylim(spec_dict['ylimw'][0],spec_dict['ylimw'][1])

# Add x-tick labels for models
ax1.set_xticklabels(models1, rotation=spec_dict['rotation'])

# Add title and other plot decorations
ax1.set_title(spec_dict['title'], fontsize=spec_dict['font'])
ax1.tick_params(axis='both', which='major', labelsize=spec_dict['font']-4)
ax2.tick_params(axis='both', which='major', labelsize=spec_dict['font']-4)

# Create the legend
legend_elements = [Line2D([0], [0], marker='.', color='k', label='Single Split', linestyle='None')]
ax1.legend(handles=legend_elements, loc='upper center', fontsize=spec_dict['font']-6)

# Annotate the plot with figure label
ax1.text(len(models), spec_dict['ylime'][1]-20, '$(B)$', fontsize=spec_dict['font'])

if save_img:
    fig.savefig(imgPath + spec_dict['figname'] + imgs_format,
                transparent=False, bbox_inches='tight')

In [None]:
# Linear fit of the median epoch counts on the log of the parameter counts
# (`median_v` was recomputed by the epochs figure above)
regressor = np.log(total_weights_v)
slope, intercept, r_value, p_value, std_err = linregress(regressor, median_v)

# Adjusted R^2 for one predictor and as many observations as models
sample_size = len(median_v)
adj_R2 = eegvislib.adjusted_R2(r_value, sample_size, 1)

# Report r with its n - 2 degrees of freedom, the p-value and the adjusted R^2
print(f' r({sample_size-2:.0f})={r_value:.2f}, p={p_value:.3f}, Adj. R^2={adj_R2:.2f}')

## Analysis Loss Curves

In [None]:
# Models whose loss curves are inspected, with their display names and colours
models      = ['shn0', 'shn5', 'shn6127', 'shn628', 'shn7db']

spec_dict = {
    'models':     ['ShallowNet', 'shn5', 'shn6$_{127}$', 'shn6$_{28}$', 'Med-ShallowNet'],
    'listcolors': ['black','tab:brown', 'tab:red', 'tab:orange', 'tab:green'],
    'font':     16,
    'figname':  'lossCurves',
    'linew':    0.5,
    'figdim':   [10,6],
    'titlet':   'Train Loss Curves',
    'titlev':   'Validation Loss Curves',
    'titlec':   'Correlation Train-Val Loss',
    'loc':      'upper right',
    'xlimt':    [0.9, 300],
    'ylimt':    [0.0, 3.0],
    'xlimv':    [0.9, 300],
    'ylimv':    [0.5, None],
    'xlimc':    [-1,1],
    'ylimc':    [0,50],
    'ytick':    5,
    'xtick':    0.2,
    'bin':      0.1,
    'rotation': 45,
}

# Training and validation loss curves of every split, grouped by model
acc_train = {}
acc_val = {}
for model in models:
    acct = []  # Training loss curves of this model, one per split
    accv = []  # Validation loss curves of this model, one per split
    for i in range(1, outerFolds+1):  # Outer fold
        for j in range(1, innerFolds+1):  # Inner fold
            # Three-digit zero padding; same names as before for the current
            # folds, and still well-formed if a fold count reaches two digits.
            file_suffix = f'{i:03d}_{j:03d}'
            filename = f'alz_flt_125_{model}_{file_suffix}_000050_019_004'

            # Read the result file of this split
            with open(resultsPath + filename + '.pickle', 'rb') as f:
                shn = pickle.load(f)

            # Keep the two loss curves
            acct.append(shn['training_loss_curve'])
            accv.append(shn['validation_loss_curve'])

    # Store the loss curves collected for this model
    acc_train[model] = acct
    acc_val[model] = accv

fig = eegvislib.overfitting_inspection(models, acc_train, acc_val, spec_dict)

if save_img:
    fig.savefig(imgPath + spec_dict['figname'] + imgs_format,
                transparent=False, bbox_inches='tight')

## Analysis Dense Layer

In [None]:
# Settings for the dense-layer analysis of one model
model = 'shn7db'
f_lim = [1,45]
# 'cmap' used to be listed twice in this literal; a repeated key silently
# overrides the earlier one, so a single entry is kept.
spec_dict = {'font':12,
             'linew':1,
             'cmap':'RdBu_r',
             # Frequency bands as [lower, upper] bounds; the last band ends
             # at the upper limit of f_lim
             'bands': {'$\\delta$':  [0,4],
                       '$\\theta$':  [4,8],
                       '$\\alpha$':  [8,12],
                       '$\\beta_1$': [12,16],
                       '$\\beta_2$': [16,20],
                       '$\\beta_3$': [20,28],
                       '$\\gamma$': [28,f_lim[1]]},
             'classlabels': classlabels,
             'figdim': [4,1],
             'plot_type': 'second',
             'figname': 'uniqueConfigDense'}

# One x tick label per band, in the order the bands are declared
spec_dict['xticks'] = list(spec_dict['bands'])

In [None]:
# Collect the mask returned by eegvislib.denseweights_plot for every split
matrix_weights = []
for i in range(1, outerFolds+1):
    for j in range(1, innerFolds+1):
        # Three-digit zero padding; same names as before for the current
        # folds, and still well-formed if a fold count reaches two digits.
        file_suffix = f'{i:03d}_{j:03d}'
        filename = f'alz_flt_125_{model}_{file_suffix}_000050_019_004'
        shn = torch.load(modelsPath + filename + '.pt')
        fig, mask = eegvislib.denseweights_plot(shn,'Dense',spec_dict)
        matrix_weights.append(mask)
        plt.close()  # Only the mask is needed; drop the figure just created

# Count occurrences of each matrix
matrix_count = {}
for matrix in matrix_weights:
    matrix_tuple = tuple(map(tuple, matrix))  # Convert to a hashable format
    matrix_count[matrix_tuple] = matrix_count.get(matrix_tuple, 0) + 1

# Convert counts back to a list of matrices
unique_matrices = [np.array(matrix) for matrix in matrix_count.keys()]
counts = list(matrix_count.values())

cmap = ListedColormap(['white', 'red', 'blue', 'green'])
class_names = spec_dict['classlabels']

# One panel per unique matrix. squeeze=False always returns a 2-D array of
# axes, so indexing also works when every split shares a single configuration
# (a lone Axes object cannot be indexed).
fig, ax = plt.subplots(len(unique_matrices),1,
                       figsize=(spec_dict['figdim'][0], spec_dict['figdim'][1]*len(unique_matrices)),
                       constrained_layout=True, squeeze=False)
ax = ax[:, 0]
for i in range(len(unique_matrices)):
    ax[i].set_title(f'Number of Splits with this config.: {counts[i]}', fontsize=spec_dict['font'])
    # NOTE(review): no vmin/vmax is given, so each panel is colour-scaled on
    # its own values - confirm that a colour means the same in every panel.
    ax[i].imshow(unique_matrices[i], aspect='auto',
                 cmap=cmap, interpolation='nearest')
    # One tick per class label instead of a hard-coded 3
    ax[i].set_yticks(np.arange(0, len(class_names)))
    ax[i].set_yticklabels(labels=class_names, fontsize=spec_dict['font']-4)
    ax[i].set_xticks(np.arange(0, len(spec_dict['xticks']), 1))
    ax[i].set_xticklabels(spec_dict['xticks'], fontsize=spec_dict['font']-4)

if save_img:
    fig.savefig(imgPath + spec_dict['figname'] + imgs_format,
                transparent=False, bbox_inches='tight')

## Analysis Statistical Models Comparison

In [None]:
def load_accuracy_data(model, path, folds, task='alz', pipe='flt', srate='125', lr='000050', ch='019', w='004', metric='accuracy_weighted'):
    """
    Load one score per (outer, inner) split of a model from its result files.

    The file of a split is
    ``{path}{task}_{pipe}_{srate}_{model}_{outer}_{inner}_{lr}_{ch}_{w}.pickle``
    with both fold numbers zero-padded to three digits.

    Parameters:
    -----------
    - model (str): Model identifier used in the file name.
    - path (str): Prefix of the file names. It is concatenated as-is, so a
      directory must end with a path separator.
    - folds (list): [[first outer, last outer], [first inner, last inner]],
      both ranges inclusive.
    - task, pipe, srate, lr, ch, w (str): Remaining fields of the file name.
    - metric (str): Key to read when the pickled object is a dict.

    Returns:
    --------
    - acc (list): One entry per split, outer folds varying slowest. The entry
      is the value stored under `metric` when the file holds a dict, otherwise
      the pickled object itself.
    """
    acc = []
    for i in range(folds[0][0], folds[0][1]+1):  # Outer fold
        for j in range(folds[1][0], folds[1][1]+1):  # Inner fold
            # Three-digit zero padding; same names as the former hand-written
            # rule for outer folds up to 10 and inner folds up to 9.
            file_suffix = f'{i:03d}_{j:03d}'
            filename = f'{task}_{pipe}_{srate}_{model}_{file_suffix}_{lr}_{ch}_{w}'
            with open(path + filename + '.pickle', 'rb') as f:
                shn = pickle.load(f)
            # Honour the requested metric; the key used to be hard-coded, so
            # the `metric` argument had no effect.
            acc.append(shn[metric] if isinstance(shn, dict) else shn)
    return acc

In [None]:
# Weighted accuracies of the four approaches compared below, one list each
metrics = {}
folds = [[1, outerFolds], [1, innerFolds]]

# Fields of the result file names
task, pipe, srate = 'alz', 'flt', '125'
model = 'shn7db'
lr, ch, w = '000050', '019', '004'

# The two regression models share one result file per split
accMLRM, accshnMLRM = [], []
path = numericPath + 'StatisticalModels' + sep + 'Results' + sep
outer_range = range(folds[0][0], folds[0][1] + 1)
inner_range = range(folds[1][0], folds[1][1] + 1)
for i in outer_range:
    for j in inner_range:
        file_suffix = f'0{i}_00{j}' if i == 10 else f'00{i}_00{j}'
        filename = '_'.join([task, pipe, srate, model, file_suffix, lr, ch, w])
        with open(path + filename + '.pickle', 'rb') as f:
            shn = pickle.load(f)
        accMLRM.append(shn['MLRM_accuracy_weighted'])
        accshnMLRM.append(shn['shnMLRM_accuracy_weighted'])

metrics['MLRM'] = accMLRM
metrics['shnMLRM'] = accshnMLRM

# The two networks are read from the classification results folder
path = resultsPath
for model in ('shn7db', 'shn0'):
    metrics[model] = load_accuracy_data(model, path, [[1, outerFolds], [1, innerFolds]])

In [None]:
# Plot settings for the comparison figure; 'xticks' lists the display names
# in the same order as the entries of `metrics`
spec_dict = dict(
    figdim=[7, 5],
    jitter=0.02,
    s=5,
    font=16,
    rotation=45,
    xticks=['MLRM', 'shnMLRM', 'Med-ShallowNet', 'ShallowNet'],
    ylim=[20, 105],
    linew=2,
    accrandom=1 / len(classlabels),
    marker='o',
    linestyle='--',
    color='tab:blue',
    loc='upper center',
    figname='statModelComparison',
)

def stat_comparison(metrics, spec_dict):
    """
    Draw one box per entry of `metrics` with its single values and median.

    The values are multiplied by 100 before plotting. Every box gets jittered
    scatter points for the single values, its median printed near the bottom
    of the axes, and the medians are joined by a line. A red horizontal line
    marks the random-guess level.

    Parameters:
    -----------
    - metrics (dict): Maps a name to a list of scores. The lists must be
      usable as columns of one pandas DataFrame.
    - spec_dict (dict): Plot configuration. Keys read here: 'figdim',
      'jitter', 's', 'font', 'xticks' (one label per entry of `metrics`, in
      the same order), 'rotation', 'ylim', 'accrandom', 'linew', 'linestyle',
      'marker', 'color' and 'loc'.

    Returns:
    --------
    - fig (matplotlib.figure.Figure): The generated figure.
    """
    df = pd.DataFrame(metrics)*100
    random_guess = spec_dict['accrandom']

    # Plot setup
    fig, ax = plt.subplots(figsize=(spec_dict['figdim'][0], spec_dict['figdim'][1]))

    # Boxes without a visible median line; the medians are drawn below
    ax.boxplot(df, showfliers=False, widths=0.6, medianprops=dict(color="none"))

    # Add jittered scatter points
    for i, col in enumerate(df.columns, start=1):
        jittered_x = np.random.normal(i, spec_dict['jitter'], size=len(df[col]))
        ax.scatter(jittered_x, df[col], color='black', alpha=1, s=spec_dict['s'])

    # Set labels and ticks. The axis label takes the configured colour, the
    # same one used for the median text and line it refers to.
    ax.set_ylabel('Median Weighted Accuracy %', fontsize=spec_dict['font']-2, color=spec_dict['color'])
    ax.set_xticks(range(1, len(df.columns) + 1))
    ax.set_xticklabels(spec_dict['xticks'], rotation=spec_dict['rotation'], fontsize=spec_dict['font']-4)
    ax.set_ylim(spec_dict['ylim'][0], spec_dict['ylim'][1])
    # The reference line uses the configured line style, like its legend entry
    ax.axhline(random_guess*100, linestyle=spec_dict['linestyle'], linewidth=spec_dict['linew'], color='red')
    ax.set_title("Comparison with Multinomial Logistic Regression", fontsize=spec_dict['font'])

    # Display median values and connect them with a line
    medians = [np.median(df[col]) for col in df.columns]
    for i, median in enumerate(medians, start=1):
        ax.text(i, spec_dict['ylim'][0]+2, f'{median:.1f}%', ha='center', va='bottom',
                color=spec_dict['color'], fontsize=spec_dict['font']-6, fontweight='bold')
    ax.plot(range(1, len(medians) + 1), medians, marker=spec_dict['marker'],
            linestyle=spec_dict['linestyle'], color=spec_dict['color'], linewidth=spec_dict['linew'])

    # Custom legend. `random_guess` is read outside the f-string: reusing the
    # enclosing quote inside a replacement field only parses on Python >= 3.12.
    legend_elements = [Line2D([0], [0], marker='.', color='k', label='Single Split', linestyle='None'),
                       Line2D([0], [0], label=f'{random_guess:.0%} Random Guess', linestyle=spec_dict['linestyle'],
                              linewidth=spec_dict['linew'], color='red')
                      ]

    ax.legend(handles=legend_elements, loc=spec_dict['loc'], fontsize=spec_dict['font']-6)
    return fig

# Build the comparison figure from the collected accuracies
fig = stat_comparison(metrics, spec_dict)

# Write it to the image folder when saving is enabled
if save_img:
    output_file = imgPath + spec_dict['figname'] + imgs_format
    fig.savefig(output_file, transparent=False, bbox_inches='tight')

In [None]:
# Mean of the stored accuracies of each approach, in the order of `metrics`
[np.mean(scores) for scores in metrics.values()]