# In [None]:
import os
import json
import pandas as pd
import matplotlib.pyplot as plt
import re

# Function to parse the time.txt file
def parse_time_file(filepath):
    """Read the elapsed time, in seconds, from the first line of a time file.

    The first line is split on ``': '``; the piece that follows the first
    separator is then cut at ``' seconds'`` and converted to a float.

    Args:
        filepath: Path of the text file to read.

    Returns:
        The parsed number of seconds as a float.

    Raises:
        IndexError: If the first line contains no ``': '`` separator.
        ValueError: If the extracted text is not a valid float.
    """
    with open(filepath, 'r') as handle:
        first_line = handle.readline().strip()

    # Keep what follows the label, then drop the trailing unit.
    after_label = first_line.split(': ')[1]
    number_text = after_label.split(' seconds')[0]
    return float(number_text)

# Function to load training metrics and time from a directory
def load_data_from_dir(directory):
    """Load one run directory into a DataFrame with one row per generation.

    Reads ``training_metrics.json`` and ``time.txt`` from ``directory`` when
    they exist, and decodes hyper-parameters from the directory name, which
    is split on ``__`` / ``_``; the first two pieces are skipped and the rest
    are read as alternating ``key``/``value`` pairs.

    Metric columns are only populated when BOTH the metrics file and the time
    file are present; otherwise they are filled with ``None``. The number of
    rows is the length of ``generation_indices`` in the metrics, or 1 when no
    metrics were loaded.

    Args:
        directory: Path of the run directory. A trailing path separator is
            tolerated.

    Returns:
        A ``pandas.DataFrame`` with the columns ``generation_indices``,
        ``percent_correct``, ``percent_incorrect``, ``percent_blank``,
        ``total_time``, one column per decoded hyper-parameter, and
        ``dir_name``.

    Raises:
        KeyError: If the metrics file lacks one of the expected keys.
    """
    metrics_filepath = os.path.join(directory, 'training_metrics.json')
    time_filepath = os.path.join(directory, 'time.txt')

    metrics = None
    total_time = None

    if os.path.exists(metrics_filepath):
        # JSON is UTF-8 by specification; do not rely on the platform default.
        with open(metrics_filepath, 'r', encoding='utf-8') as file:
            metrics = json.load(file)

    if os.path.exists(time_filepath):
        total_time = parse_time_file(time_filepath)

    # Determine the number of entries (either from metrics or default to 1)
    num_entries = len(metrics['generation_indices']) if metrics else 1

    if metrics and total_time is not None:
        # round() instead of int(): 0.29 * 100 is 28.999999999999996 in
        # binary floating point, which int() would truncate to 28.
        data = {
            'generation_indices': metrics['generation_indices'],
            'percent_correct': [round(p * 100) for p in metrics['percent_correct']],
            'percent_incorrect': [round(p * 100) for p in metrics['percent_incorrect']],
            'percent_blank': [round(p * 100) for p in metrics['percent_blank']],
        }
    else:
        # A run missing either file is reported with blank metrics.
        data = {
            column: [None] * num_entries
            for column in ('generation_indices', 'percent_correct',
                           'percent_incorrect', 'percent_blank')
        }
    # total_time is None when time.txt is missing, so this covers both cases.
    data['total_time'] = [total_time] * num_entries

    # Extract parameters from directory name. normpath drops a trailing
    # separator, which would otherwise make basename() return ''.
    dir_name = os.path.basename(os.path.normpath(directory))
    params = re.split(r'__|_', dir_name)[2:]  # Skip the first 'v3_1'

    if len(params) % 2 == 0:
        for key, raw_value in zip(params[::2], params[1::2]):
            try:
                value = float(raw_value) if '.' in raw_value else int(raw_value)
            except ValueError:
                # Not numeric: keep the raw text.
                value = raw_value
            # Whole-number floats such as "2.0" are stored as ints.
            if isinstance(value, float) and value.is_integer():
                value = int(value)
            data[key] = [value] * num_entries
    else:
        print(f"Warning: Skipping directory {dir_name} due to unmatched parameters.")

    data['dir_name'] = [dir_name] * num_entries  # Add directory name for identification in plots

    return pd.DataFrame(data)

# Main function to load data from all directories and aggregate
def load_and_aggregate_data(base_dir):
    """Collect every ``v3_1__*`` run directory under ``base_dir`` into one frame.

    Walks ``base_dir`` recursively, loads each sub-directory whose name starts
    with ``v3_1__`` through ``load_data_from_dir``, prints the loaded frame,
    and concatenates all of them with a fresh index.

    Args:
        base_dir: Root directory to search.

    Returns:
        The concatenated ``pandas.DataFrame``, or an empty DataFrame when no
        matching directory was found.
    """
    frames = []
    for parent, subdirs, _files in os.walk(base_dir):
        run_names = [name for name in subdirs if name.startswith("v3_1__")]
        for run_name in run_names:
            frame = load_data_from_dir(os.path.join(parent, run_name))
            print(frame)
            if frame is None:
                continue
            frames.append(frame)

    if not frames:
        return pd.DataFrame()
    return pd.concat(frames, ignore_index=True)

# Function to convert total time from seconds to hours, minutes, and seconds
def format_time(seconds):
    """Render a duration given in seconds as ``"<H>h <M>m <S>s"``.

    Each component is truncated to an integer; fractional seconds are
    dropped, not rounded.

    Args:
        seconds: Duration in seconds (int or float).

    Returns:
        The formatted string, e.g. ``"1h 1m 1s"`` for 3661.
    """
    whole_hours, within_hour = divmod(seconds, 3600)
    hours = int(whole_hours)
    minutes = int(within_hour // 60)
    secs = int(seconds % 60)
    return f"{hours}h {minutes}m {secs}s"

# Function to create the model generation vs validation metrics plots
def create_generation_plots(data, max_generation, filename):
    """Plot validation percentages against model generation for every run.

    For each ``dir_name`` group in ``data`` three curves are drawn: percent
    correct (default line style), percent incorrect (dashed) and percent
    blank (dotted). The legend label of a run is built from the first row's
    ``eval``, ``tokens``, ``rollouts``, ``branch``, ``games``, ``batch`` and
    ``gen`` values. The figure is saved to ``filename`` and then shown.

    Args:
        data: DataFrame holding ``dir_name``, ``generation_indices``, the
            three ``percent_*`` columns and the hyper-parameter columns
            listed above.
        max_generation: Upper limit of the x axis (the lower limit is 1).
        filename: Path the figure is written to.
    """
    # (column, legend suffix, extra plot keyword arguments) for each curve.
    curves = (
        ('percent_correct', '% Correct', {}),
        ('percent_incorrect', '% Incorrect', {'linestyle': '--'}),
        ('percent_blank', '% Blank', {'linestyle': ':'}),
    )
    label_columns = ('eval', 'tokens', 'rollouts', 'branch', 'games', 'batch', 'gen')

    plt.figure(figsize=(15, 10))

    for _run_name, run_rows in data.groupby('dir_name'):
        # Hyper-parameters are constant within a run, so the first row is enough.
        head = {column: run_rows[column].iloc[0] for column in label_columns}
        run_label = (f"EVAL: {head['eval']}, TOKENS: {head['tokens']}, ROLLOUTS: {head['rollouts']}, "
                     f"BRANCH: {head['branch']}, GAMES: {head['games']}, BATCH: {head['batch']}, GEN: {head['gen']}")
        generations = run_rows['generation_indices']
        for column, suffix, style in curves:
            plt.plot(generations, list(run_rows[column]), label=f"{run_label} - {suffix}", **style)

    plt.xlabel('Model Generation')
    plt.ylabel('Percentage')
    plt.ylim(0, 100)
    plt.xlim(1, max_generation)
    # The legend sits outside the axes, to the right of the plot.
    plt.legend(loc='upper left', bbox_to_anchor=(1.05, 1), borderaxespad=0.)
    plt.title(f'Model Generation vs Validation Metrics (up to {max_generation} generations)')
    plt.grid(True)
    plt.tight_layout(rect=[0, 0, 0.75, 1])
    plt.savefig(filename, bbox_inches='tight')
    plt.show()

# Function to analyze and visualize the data
def analyze_and_visualize(data):
    """Produce the generation plots, the ranking bar chart and the summary table.

    Steps:
      1. Draw the per-generation plots twice (x axis up to 2 and up to 20).
      2. Reduce ``data`` to one row per ``dir_name``: last recorded
         percentages and total time, first recorded percentages, and the
         hyper-parameters.
      3. Draw a bar chart of final percent correct with total time on a
         second y axis, for runs that have both values.
      4. Render a table of hyper-parameters, final metrics and the change
         from the first to the last generation.

    Figures are saved as ``final_percent_correct_and_total_time.png`` and
    ``hyperparameters_and_final_metrics.png`` (plus the two files written by
    ``create_generation_plots``) and shown.

    Args:
        data: DataFrame with the columns ``dir_name``, ``generation_indices``,
            ``percent_correct``, ``percent_incorrect``, ``percent_blank``,
            ``total_time``, ``eval``, ``tokens``, ``rollouts``, ``branch``,
            ``games``, ``gen`` and ``batch``. It is not modified.

    Raises:
        KeyError: If one of the columns listed above is missing.
    """
    # Create two views for the model generation vs validation metrics plot
    create_generation_plots(data, 2, 'model_generation_vs_metrics_up_to_2.png')
    create_generation_plots(data, 20, 'model_generation_vs_metrics_up_to_20.png')

    param_columns = ['eval', 'tokens', 'rollouts', 'branch', 'games', 'gen', 'batch']

    # Named aggregation lets the same source column feed both the "final"
    # ('last') and "initial" ('first') values without adding helper columns
    # to the caller's DataFrame.
    aggregations = {
        'generation_indices': ('generation_indices', 'max'),
        'percent_correct': ('percent_correct', 'last'),
        'percent_incorrect': ('percent_incorrect', 'last'),
        'percent_blank': ('percent_blank', 'last'),
        'total_time': ('total_time', 'last'),
    }
    aggregations.update({column: (column, 'first') for column in param_columns})
    aggregations.update({
        'initial_correct': ('percent_correct', 'first'),
        'initial_blank': ('percent_blank', 'first'),
        'initial_incorrect': ('percent_incorrect', 'first'),
    })
    final_results = data.groupby('dir_name').agg(**aggregations).reset_index()

    # Calculate deltas
    final_results['Δ Correct (%)'] = final_results['percent_correct'] - final_results['initial_correct']
    final_results['Δ Blank (%)'] = final_results['percent_blank'] - final_results['initial_blank']
    final_results['Δ Incorrect (%)'] = final_results['percent_incorrect'] - final_results['initial_incorrect']

    # Bar chart: rank each folder by the final % correct and total time.
    # Only runs that have both values are drawn; testing with notna() works
    # regardless of the column dtype.
    has_result = final_results['percent_correct'].notna() & final_results['total_time'].notna()
    filtered_results = final_results[has_result].sort_values(by='percent_correct', ascending=False)

    fig, ax1 = plt.subplots(figsize=(15, 10))

    color = 'tab:blue'
    ax1.set_xlabel('Directory')
    ax1.set_ylabel('Percent Correct', color=color)
    ax1.bar(filtered_results['dir_name'], filtered_results['percent_correct'], color=color)
    ax1.tick_params(axis='y', labelcolor=color)
    plt.xticks(rotation=90, ha='right')
    ax1.set_ylim(0, 100)

    # Second y axis sharing the same x axis, used for the run time.
    ax2 = ax1.twinx()
    color = 'tab:red'
    ax2.set_ylabel('Total Time (seconds)', color=color)
    ax2.plot(filtered_results['dir_name'], filtered_results['total_time'], color=color, marker='o')
    ax2.tick_params(axis='y', labelcolor=color)

    fig.tight_layout()
    plt.title('Final Percent Correct and Total Time by Directory')
    plt.savefig('final_percent_correct_and_total_time.png')
    plt.show()

    # Table: Hyperparameters and final metrics
    table_columns = param_columns + [
        'percent_correct', 'percent_blank', 'percent_incorrect', 'total_time',
        'Δ Correct (%)', 'Δ Blank (%)', 'Δ Incorrect (%)',
    ]
    # Sort while missing values are still NaN, then take an explicit copy so
    # the assignments below never touch a view of final_results.
    table_data = final_results[table_columns].sort_values(
        by=['eval', 'tokens', 'rollouts', 'branch', 'games', 'batch']
    ).copy()

    # Convert total time to hours, minutes, and seconds. NaN is a float, so
    # it must be excluded explicitly before calling format_time.
    table_data['total_time'] = table_data['total_time'].apply(
        lambda x: format_time(x) if isinstance(x, (int, float)) and pd.notna(x) else ''
    )

    # Show missing values as empty cells. Casting to object first allows ''
    # to be stored in columns that were numeric.
    table_data = table_data.astype(object).where(table_data.notna(), '')

    table_data.insert(0, 'ID', range(1, len(table_data) + 1))
    table_data.columns = ['ID', 'EVAL', 'TOKENS', 'ROLLOUTS', 'BRANCH', 'GAMES', 'GEN', 'BATCH', 'Correct (%)', 'Blank (%)', 'Incorrect (%)', 'Total Time', 'Δ Correct (%)', 'Δ Blank (%)', 'Δ Incorrect (%)']

    fig, ax = plt.subplots(figsize=(15, 10))
    ax.axis('tight')
    ax.axis('off')
    table = ax.table(cellText=table_data.values, colLabels=table_data.columns, cellLoc='center', loc='center')
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1.2, 1.2)

    plt.title('Hyperparameters and Final Metrics')
    plt.savefig('hyperparameters_and_final_metrics.png')
    plt.show()

# Example usage
# NOTE(review): base_dir is an empty string here -- presumably a placeholder
# for the folder that holds the v3_1__* run directories; set it before running.
base_dir = ''
data = load_and_aggregate_data(base_dir)
if data.empty:
    print("No data found.")
else:
    analyze_and_visualize(data)
