In [4]:
import os
import json
import pymc as pm
import arviz as az
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Placeholder paths -- set these before running the notebook.
# baseDir: root folder searched recursively for training-report *.json files.
# output_dir: folder the top-5 plot is written to.
baseDir = r""
output_dir = r""

def loadReportJsons(rootDir=None):
    """Recursively load every ``*.json`` training report under a directory.

    Parameters
    ----------
    rootDir : str or None
        Directory to search. ``None`` (the default) uses the module-level ``baseDir``.

    Returns
    -------
    list
        One parsed JSON object per ``.json`` file, in a deterministic order:
        directories and file names are visited in sorted order. The list position
        is what later cells use as the model number, so it must be stable.
    """
    searchRoot = baseDir if rootDir is None else rootDir
    reports = []
    for root, dirs, files in os.walk(searchRoot):
        # os.walk yields entries in filesystem order, which differs between machines.
        # Sorting dirs in place fixes the traversal order; sorting files fixes the
        # order within each directory.
        dirs.sort()
        for fileName in sorted(files):
            if not fileName.endswith(".json"):
                continue
            # JSON text is UTF-8; don't depend on the platform default encoding.
            with open(os.path.join(root, fileName), 'r', encoding='utf-8') as f:
                reports.append(json.load(f))
    if not reports:
        # An empty/placeholder path makes os.walk yield nothing without error.
        print(f"Warning: no .json reports found under {searchRoot!r}")
    return reports

trainedModelData = loadReportJsons()

In [None]:
def getCriteriaData(trainedModelData, maxEpoch=20):
    """Reduce one model's training report to the criteria used for ranking.

    Parameters
    ----------
    trainedModelData : dict
        A single parsed report. Must contain ``"validation_accuracy"``; may contain
        ``"epochs_data"``, a sequence of per-epoch dicts each holding ``"val_accuracy"``.
    maxEpoch : int
        Configured epoch budget. A run with fewer recorded epochs counts as early-stopped.

    Returns
    -------
    dict
        ``finalValAcc``  -- the report's ``validation_accuracy``;
        ``usedEpochs``   -- number of recorded epochs;
        ``wasEarlyStop`` -- ``usedEpochs < maxEpoch``;
        ``gradient``     -- mean epoch-to-epoch change of ``val_accuracy`` over the
        last 5 recorded epochs. It is 0.0 for early-stopped runs, and 0.0 when that
        window holds fewer than two epochs.
    """
    finalAccuracy = trainedModelData["validation_accuracy"]
    epochHistory = trainedModelData.get("epochs_data", [])
    epochCount = len(epochHistory)
    stoppedEarly = epochCount < maxEpoch

    # Only runs that used their full epoch budget get a trend estimate;
    # early-stopped runs are treated as flat.
    trend = 0.0
    if not stoppedEarly:
        tail = epochHistory[-5:]
        steps = [
            later["val_accuracy"] - earlier["val_accuracy"]
            for earlier, later in zip(tail, tail[1:])
        ]
        if steps:
            trend = sum(steps) / len(steps)

    return {
        "finalValAcc": finalAccuracy,
        "usedEpochs": epochCount,
        "wasEarlyStop": stoppedEarly,
        "gradient": trend,
    }

def getCriteriaDataForAllModels(trainedModelData, maxEpoch=20):
    """Run getCriteriaData on every report, keeping the input order.

    Parameters
    ----------
    trainedModelData : iterable of dict
        Parsed training reports, one per model.
    maxEpoch : int
        Epoch budget forwarded to getCriteriaData.

    Returns
    -------
    list of dict
        One criteria dict per report, in the same order as the input.
    """
    return [getCriteriaData(report, maxEpoch) for report in trainedModelData]

# Extract ranking criteria for every loaded report and list them.
# Models are numbered from 1 here, matching the best-model printout and the
# "Model N" plot labels in the later cells (those use list position + 1).
criteriaDataForAllModels = getCriteriaDataForAllModels(trainedModelData, maxEpoch=20)
for modelNumber, modelCriteria in enumerate(criteriaDataForAllModels, start=1):
    print(f"Model {modelNumber}: {modelCriteria}")


In [None]:
def adjustForCriteria(finalAccuracy, wasEarlyStop, gradient, gradientWeight=5):
    """Convert a model's final validation accuracy into the score fed to the Bayesian model.

    Runs that used their full epoch budget get a bonus of ``gradient * gradientWeight``.
    That rewards runs whose validation accuracy was still rising at the end.
    Early-stopped runs are returned unadjusted.

    NOTE(review): getCriteriaData already reports ``gradient == 0.0`` for
    early-stopped runs, so they never receive a bonus either way. If a *penalty*
    for early stopping is wanted, it needs its own explicit term -- confirm intent.

    Parameters
    ----------
    finalAccuracy : float
        Reported final validation accuracy.
    wasEarlyStop : bool
        Whether the run stopped before its epoch budget.
    gradient : float
        Mean recent change in validation accuracy per epoch.
    gradientWeight : float
        Multiplier applied to ``gradient`` for the bonus (default 5).

    Returns
    -------
    float
        The adjusted accuracy.
    """
    if wasEarlyStop:
        # No adjustment; return the value exactly, without subtract-then-add rounding.
        return float(finalAccuracy)
    return finalAccuracy + gradient * gradientWeight

def bayesianModelSelection(criteriaDataForAllModels, draws=2000, tune=1000,
                           chains=1, cores=6, random_seed=42):
    """Fit a simple Bayesian model of each candidate's "true" adjusted accuracy.

    Each model's observation is its ``finalValAcc`` passed through
    ``adjustForCriteria``. The model is::

        theta[i] ~ Uniform(0, 100)
        sigma    ~ HalfNormal(10)
        obs[i]   ~ Normal(theta[i], sigma)

    NOTE(review): the Uniform(0, 100) prior presumably assumes accuracies are
    percentages -- confirm against the report JSONs. With a single observation per
    theta[i] and a flat prior, sigma's posterior is likely close to its prior and the
    theta means close to the observed values. Consider whether this adds information
    beyond ranking the adjusted scores directly.

    Parameters
    ----------
    criteriaDataForAllModels : iterable of dict
        Criteria dicts with keys ``finalValAcc``, ``wasEarlyStop`` and ``gradient``.
    draws, tune, chains, cores, random_seed
        Forwarded to ``pm.sample``. The defaults reproduce the original settings.
        R-hat convergence diagnostics need at least two chains, so with
        ``chains=1`` the r_hat column of ``az.summary`` is not meaningful.

    Returns
    -------
    The object returned by ``pm.sample``; later cells read ``trace.posterior["theta"]``.

    Raises
    ------
    ValueError
        If no criteria data is supplied.
    """
    adjusted = [
        adjustForCriteria(
            criteriaData["finalValAcc"],
            criteriaData["wasEarlyStop"],
            criteriaData["gradient"],
        )
        for criteriaData in criteriaDataForAllModels
    ]
    if not adjusted:
        raise ValueError("bayesianModelSelection() needs criteria data for at least one model")

    # Shape (N,). dtype=float keeps the observations numeric even if every adjusted value is an int.
    observedVals = np.array(adjusted, dtype=float)
    N = len(observedVals)

    with pm.Model():
        # Prior for each model's true (adjusted) final accuracy.
        theta = pm.Uniform("theta", lower=0, upper=100, shape=N)
        # Shared observation-noise scale.
        sigma = pm.HalfNormal("sigma", sigma=10)
        # Observed adjusted accuracy = noisy measurement of theta.
        pm.Normal("obs", mu=theta, sigma=sigma, observed=observedVals)
        # MCMC sampling of the joint posterior over theta and sigma.
        trace = pm.sample(draws, tune=tune, chains=chains, cores=cores, random_seed=random_seed)

    return trace

# Criteria (val accuracy, gradient, early stopping) for all models; recomputed so
# this cell only depends on `trainedModelData`.
criteriaDataForAllModels = getCriteriaDataForAllModels(trainedModelData, maxEpoch=20)

# Posterior over each model's "true" (adjusted) final accuracy.
trace = bayesianModelSelection(criteriaDataForAllModels)

# Posterior summary table (mean, sd, HDI, diagnostics) for theta and sigma.
summary = az.summary(trace, var_names=["theta", "sigma"])
print(summary)

thetaSamples = trace.posterior["theta"].values   # shape (chain, draw, N)
thetaMeans = thetaSamples.mean(axis=(0, 1))      # shape (N,), averaged over chains and draws

# Best model = highest posterior mean accuracy.
# bestModelPos is the 0-based array position used for indexing; bestModelIdx is the
# 1-based label that matches the "Model N" labels in the plot.
bestModelPos = int(np.argmax(thetaMeans))
bestModelIdx = bestModelPos + 1
print("Best model idx =", bestModelIdx, "with posterior mean accuracy =", thetaMeans[bestModelPos])

In [None]:
# Rank models by posterior mean accuracy (descending) and keep the top 5.
sortedModels = np.argsort(thetaMeans)[::-1]
top5SortedModels = sortedModels[:5]
assert len(top5SortedModels) > 0, "No models to plot - check that report JSONs were loaded."

print(thetaSamples.shape)
# Flatten (chain, draw, N) -> (chain*draw, N): each column holds every posterior sample for one model.
thetaSamples2d = thetaSamples.reshape(-1, thetaSamples.shape[-1])

# Posterior mean and equal-tailed 95% credible interval (2.5th / 97.5th percentiles).
meansof5Models = []
lower95of5Models = []
upper95of5Models = []
for i in top5SortedModels:
    currSamples = thetaSamples2d[:, i]
    lowerCredibleInterval, upperCredibleInterval = np.percentile(currSamples, [2.5, 97.5])
    meansof5Models.append(np.mean(currSamples))
    lower95of5Models.append(lowerCredibleInterval)
    upper95of5Models.append(upperCredibleInterval)

yIntervals = np.arange(len(top5SortedModels))

# errorbar takes distances from the point, not absolute interval bounds.
lowerError = np.array(meansof5Models) - np.array(lower95of5Models)
upperError = np.array(upper95of5Models) - np.array(meansof5Models)

fig, ax = plt.subplots(figsize=(4.5, 3))  # compact figure

ax.errorbar(
    x=meansof5Models,
    y=yIntervals,
    xerr=[lowerError, upperError],
    fmt='o',
    color='blue',
    ecolor='gray',
    capsize=3,
)

# Labels are 1-based (position + 1), matching the best-model printout.
ax.set_yticks(yIntervals)
ax.set_yticklabels([f"Model {idx+1}" for idx in top5SortedModels], fontsize=10)
ax.invert_yaxis()  # highest-ranked model at the top
ax.set_xlabel("Posterior Mean Accuracy\n(95% credible interval)", fontsize=11)
ax.set_title("Top 5 Models by Accuracy\n3-channel on Original Labels", fontsize=12)
ax.grid(True, axis='x', linestyle='--', alpha=0.5)
ax.tick_params(axis='both', labelsize=10)

fig.tight_layout()

# Annotate the top model's mean. The +1.6 x-offset and "%" suffix presumably assume
# a percentage accuracy scale -- adjust if accuracies are fractions.
top_model_y = 0
top_model_mean = meansof5Models[0]
ax.text(top_model_mean + 1.6, top_model_y + 0.6, f"{top_model_mean:.2f}%",
        ha='center', va='bottom', fontsize=10, color='green')

# Save and show. os.path.join keeps an empty output_dir relative (f"{''}/x" would be
# "/x" at the filesystem root), and os.makedirs("") raises, so only create real dirs.
output_path = os.path.join(output_dir, "top5Models_3chnl_0_orig.png")
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
fig.savefig(output_path, dpi=300)
print(f"Plot saved to {output_path}")
plt.show()