In [28]:
import os
import json
import numpy as np
import glob
from collections import defaultdict

import plotly.graph_objects as go
import pandas as pd
from plotly.subplots import make_subplots

In [18]:
# Path to the directory containing all JSON files
base_path = "/home/users/MTrappett/mtrl/BranchGatingProject/branchNetwork/data/hyper_search/train_and_evaluate_model_2024-04-24_21-05-07"


def read_result_file(file_path):
    """Parse one newline-delimited JSON result file.

    Each non-blank line of the file is decoded as a JSON object that must
    carry a 'mean_accuracy' value; the first object's 'config' supplies the
    grouping key.

    Args:
        file_path: Path of the result file to read.

    Returns:
        A ``(grouping_key, accuracies)`` tuple where ``grouping_key`` is
        ``(model_name, lr, batch_size)`` and ``accuracies`` lists every
        'mean_accuracy' in file order. Returns ``(None, [])`` when the file
        holds no records, so the caller can skip it instead of crashing.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks 'mean_accuracy' or the first record lacks
            'config' (or one of its grouping fields).
    """
    with open(file_path, 'r') as file:
        # Blank lines (e.g. a trailing newline pair) are not records.
        entries = [json.loads(line) for line in file if line.strip()]

    if not entries:
        return None, []

    # Assumes 'config' is identical across all records of one file, so the
    # first record is representative.
    config = entries[0]['config']
    grouping_key = (config['model_name'], config['lr'], config['batch_size'])
    return grouping_key, [entry['mean_accuracy'] for entry in entries]


# Dictionary to collect results by grouping key
results_dict = defaultdict(list)

# Sorted so the group order (and everything derived from it downstream, such
# as subplot placement and colours) does not depend on filesystem order.
for file_path in sorted(glob.glob(f"{base_path}/train_and_evaluate_model_*/result.json")):
    grouping_key, accuracies = read_result_file(file_path)
    if grouping_key is None:
        # Empty result file: nothing to aggregate for this run.
        continue
    results_dict[grouping_key].extend(accuracies)


In [19]:
# Collapse every (model_name, lr, batch_size) group into one summary record
# holding the mean and the (population) standard deviation of its accuracies.
final_results = [
    {
        'model_name': model_name,
        'lr': lr,
        'batch_size': batch_size,
        'ave_mean_accuracy': np.mean(group_accs),
        'std_mean_accuracy': np.std(group_accs),
    }
    for (model_name, lr, batch_size), group_accs in results_dict.items()
]


In [20]:

# Dump one summary record per line (nothing is printed when there are none).
if final_results:
    print("\n".join(str(summary) for summary in final_results))


{'model_name': 'SimpleModel', 'lr': 0.1, 'batch_size': 32, 'ave_mean_accuracy': 10.362, 'std_mean_accuracy': 0.6522519962918215}
{'model_name': 'SimpleModel', 'lr': 0.001, 'batch_size': 64, 'ave_mean_accuracy': 96.56316666666666, 'std_mean_accuracy': 0.769957773012405}
{'model_name': 'SimpleModel', 'lr': 0.1, 'batch_size': 256, 'ave_mean_accuracy': 10.713833333333332, 'std_mean_accuracy': 0.6345420702277264}
{'model_name': 'SimpleModel', 'lr': 0.0001, 'batch_size': 128, 'ave_mean_accuracy': 97.32783333333333, 'std_mean_accuracy': 1.2673280181371964}
{'model_name': 'ExpertModel', 'lr': 0.01, 'batch_size': 32, 'ave_mean_accuracy': 16.656166666666667, 'std_mean_accuracy': 4.179941822428739}
{'model_name': 'ExpertModel', 'lr': 0.1, 'batch_size': 32, 'ave_mean_accuracy': 10.357500000000003, 'std_mean_accuracy': 0.655063928178006}
{'model_name': 'ExpertModel', 'lr': 0.0001, 'batch_size': 32, 'ave_mean_accuracy': 97.60333333333332, 'std_mean_accuracy': 0.9263038858219744}
{'model_name': 'Bran

In [34]:
def create_heatmap(data, max_models=4, n_cols=2):
    """Show one accuracy heatmap (learning rate x batch size) per model.

    Args:
        data: List of dicts with the keys 'model_name', 'lr', 'batch_size'
            and 'ave_mean_accuracy'.
        max_models: Maximum number of models to plot; models beyond this
            count (in order of first appearance) are not shown.
        n_cols: Number of subplot columns; rows are added as needed.

    Raises:
        ValueError: If ``data`` contains no rows.
    """
    # Convert list of dicts to DataFrame
    df = pd.DataFrame(data)
    if df.empty:
        raise ValueError("create_heatmap needs at least one result row")

    # Models to plot (plain str so make_subplots receives a list of strings),
    # plus the full learning-rate / batch-size grid shared by every panel.
    models = [str(name) for name in df['model_name'].unique()[:max_models]]
    lrs = sorted(df['lr'].unique(), reverse=True)
    batch_sizes = sorted(df['batch_size'].unique())

    # Ceiling division: enough rows to hold every selected model.
    n_rows = max(1, (len(models) + n_cols - 1) // n_cols)
    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=models)

    # One colour range for all panels. Only the last panel draws a colour bar,
    # so without shared limits that bar would describe just that one panel.
    plotted = df[df['model_name'].isin(models)]['ave_mean_accuracy']
    z_low, z_high = plotted.min(), plotted.max()
    zmin = float(z_low) if pd.notna(z_low) else None
    zmax = float(z_high) if pd.notna(z_high) else None

    for i, model in enumerate(models, 1):
        # Filter data for the current model
        filtered_df = df[df['model_name'] == model]

        # Rows = learning rates, columns = batch sizes; reindex so every panel
        # uses the same grid even if a model is missing some combinations.
        z = filtered_df.pivot_table(
            index='lr', columns='batch_size', values='ave_mean_accuracy'
        ).reindex(index=lrs, columns=batch_sizes)

        # Fill the grid row by row
        row = (i - 1) // n_cols + 1
        col = (i - 1) % n_cols + 1

        fig.add_trace(
            go.Heatmap(
                # String labels make both axes categorical (evenly spaced).
                x=z.columns.astype(str),
                y=z.index.astype(str),
                z=z.values,
                zmin=zmin,
                zmax=zmax,
                colorbar=dict(title='Mean Accuracy'),
                showscale=(i == len(models))  # Only show the scale for the last plot
            ),
            row=row, col=col
        )

    # 300 x 350 px per panel keeps the original 600 x 700 figure for a 2 x 2 grid.
    fig.update_layout(
        height=300 * n_rows,
        width=350 * n_cols,
        title_text="Model Performance Comparison across Different Hyperparameters"
    )

    fig.show()


# Render the per-model accuracy heatmaps for the aggregated results.
create_heatmap(final_results)

In [42]:
def create_comparison_plots(data):
    """Plot mean accuracy per model against learning rate and batch size.

    The left subplot averages each model's results over batch sizes for every
    learning rate; the right subplot averages over learning rates for every
    batch size. Error bars are the mean of the corresponding
    'std_mean_accuracy' values.

    Args:
        data: List of dicts with the keys 'model_name', 'lr', 'batch_size',
            'ave_mean_accuracy' and 'std_mean_accuracy'.

    Raises:
        ValueError: If ``data`` contains no rows.
    """
    df = pd.DataFrame(data)
    if df.empty:
        raise ValueError("create_comparison_plots needs at least one result row")

    # Create two subplots: one for lr vs. mean accuracy, another for batch_size vs. mean accuracy
    fig = make_subplots(rows=1, cols=2, subplot_titles=('Mean Accuracy vs. Learning Rate', 'Mean Accuracy vs. Batch Size'))

    # Unique models
    models = df['model_name'].unique()

    # Fixed palette for visibility and consistency. Assigned by order of
    # appearance and cycled, so any number of models works (the first four
    # keep blue / orange / green / red).
    palette = ['blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan']
    colors = {model: palette[idx % len(palette)] for idx, model in enumerate(models)}

    # (column to group by, subplot column, whether its traces get a legend entry).
    # Legend entries come from the first subplot only, to avoid repetition.
    panels = (('lr', 1, True), ('batch_size', 2, False))

    for group_column, subplot_col, show_legend in panels:
        for model in models:
            filtered_df = df[df['model_name'] == model]
            grouped = filtered_df.groupby(group_column).agg({'ave_mean_accuracy': 'mean', 'std_mean_accuracy': 'mean'})

            fig.add_trace(
                go.Scatter(
                    x=grouped.index,
                    y=grouped['ave_mean_accuracy'],
                    error_y=dict(type='data', array=grouped['std_mean_accuracy'], visible=True),
                    mode='lines+markers',
                    name=model,
                    # Same group in both subplots: toggling the legend entry
                    # hides/shows the model everywhere.
                    legendgroup=str(model),
                    line=dict(color=colors[model]),
                    showlegend=show_legend
                ), row=1, col=subplot_col
            )

    # Update layout
    fig.update_layout(height=600, width=1200, title_text="Comparison of Mean Accuracy Across Models")

    # Categorical axes space the values evenly. The category order is pinned to
    # ascending numeric order so it does not depend on which model is drawn first.
    lr_order = df['lr'].drop_duplicates().sort_values().tolist()
    batch_order = df['batch_size'].drop_duplicates().sort_values().tolist()

    fig.update_xaxes(title_text="Learning Rate", type='category',
                     categoryorder='array', categoryarray=lr_order, row=1, col=1)
    fig.update_yaxes(title_text="Mean Accuracy", row=1, col=1)
    fig.update_xaxes(title_text="Batch Size", type='category',
                     categoryorder='array', categoryarray=batch_order, row=1, col=2)
    fig.update_yaxes(title_text="Mean Accuracy", row=1, col=2)

    fig.show()
    
# Draw the accuracy-vs-learning-rate and accuracy-vs-batch-size comparison.
create_comparison_plots(final_results)

In [43]:
def find_best_hyperparameters(data):
    """Find the (lr, batch_size) pair with the best accuracy averaged over models.

    Args:
        data: List of dicts with the keys 'lr', 'batch_size' and
            'ave_mean_accuracy' (one dict per model/lr/batch_size combination).

    Returns:
        A dict with 'lr', 'batch_size' and 'ave_mean_accuracy' for the winning
        combination, as native Python numbers that keep their original type
        (an integer batch size stays an int). If ``data`` is empty or has no
        non-NaN accuracy, the string "No best combination found" is returned.
        On ties, the first combination in (lr, batch_size) sort order wins.
    """
    not_found = "No best combination found"

    def _to_native(value):
        # NumPy scalars expose .item(); plain Python values pass through.
        return value.item() if hasattr(value, 'item') else value

    # Convert the list of dictionaries into a DataFrame
    df = pd.DataFrame(data)
    if df.empty:
        return not_found

    # Average accuracy across all models for each (lr, batch_size) pair.
    # Kept as a Series indexed by (lr, batch_size): pulling a whole row out of
    # a DataFrame instead would coerce an integer batch_size to float.
    grouped = df.groupby(['lr', 'batch_size'])['ave_mean_accuracy'].mean().dropna()
    if grouped.empty:
        return not_found

    # idxmax returns the index label of the first maximum: an (lr, batch_size) tuple.
    best_lr, best_batch_size = grouped.idxmax()

    return {
        'lr': _to_native(best_lr),
        'batch_size': _to_native(best_batch_size),
        'ave_mean_accuracy': _to_native(grouped.max())
    }

In [44]:
# Pick the (lr, batch_size) pair whose accuracy, averaged over all models, is highest.
best_hyperparameters = find_best_hyperparameters(final_results)
print("Best hyperparameters combination:", best_hyperparameters)


Best hyperparameters combination: {'lr': 0.0001, 'batch_size': 32.0, 'ave_mean_accuracy': 97.26050000000001}
