In [None]:
from pathlib import Path
import re
import json
import numpy as np
import matplotlib.pyplot as plt

# Collect per-model accuracies from every 'results-<model>' folder in the
# current directory. Each folder is expected to hold 'summary/*.txt'.
RESULTS_PREFIX = 'results-'

# "Average accuracy 0.7699 - biology": the FIRST dash separates the number from
# the category, so the category itself may contain dashes ("computer-science").
_CATEGORY_LINE = re.compile(r'Average accuracy\s+([\d.]+)[^-]*-(.*)')


def parse_summary_lines(lines):
    """Extract accuracies from the lines of a summary file.

    Two line shapes are recognised (surrounding whitespace is ignored):

    * ``Average accuracy: 0.7012``           -> overall average accuracy
    * ``Average accuracy 0.7699 - biology``  -> accuracy of one category

    Any other line is ignored. If a shape occurs more than once, the last
    occurrence wins.

    Args:
        lines: Iterable of strings, one per line (an open text file works).

    Returns:
        Tuple ``(average_accuracy, category_accuracies)`` where
        ``average_accuracy`` is a float or ``None`` when no overall line was
        found, and ``category_accuracies`` maps category name -> float.

    Raises:
        ValueError: If a recognised line carries a value ``float`` cannot parse.
    """
    average_accuracy = None
    category_accuracies = {}

    for raw_line in lines:
        line = raw_line.strip()

        # Overall average accuracy (with colon)
        if line.startswith('Average accuracy:'):
            average_accuracy = float(line.split(':')[1].strip())

        # Category-level accuracy (without colon, with dash)
        elif line.startswith('Average accuracy'):
            match = _CATEGORY_LINE.search(line)
            if match:
                category_accuracies[match.group(2).strip()] = float(match.group(1))

    return average_accuracy, category_accuracies


# Find all folders starting with 'results-'
results_folders = sorted(
    d for d in Path('.').iterdir() if d.is_dir() and d.name.startswith(RESULTS_PREFIX)
)

data = []

for folder in results_folders:
    folder_info = {
        'folder_name': folder.name,
        # Strip only the leading prefix; str.replace would also remove any
        # later 'results-' inside the model name.
        'model_name': folder.name[len(RESULTS_PREFIX):],
        'average_accuracy': None,
        'category_accuracies': {},
    }

    # Find the summary file. Sorting makes the choice deterministic when the
    # folder holds more than one .txt file (glob order is arbitrary).
    summary_dir = folder / 'summary'
    summary_files = sorted(summary_dir.glob('*.txt')) if summary_dir.is_dir() else []

    if summary_files:
        with open(summary_files[0], 'r') as f:
            average_accuracy, category_accuracies = parse_summary_lines(f)
        folder_info['average_accuracy'] = average_accuracy
        folder_info['category_accuracies'] = category_accuracies

    # Folders without a usable summary are still recorded (with None / {}),
    # so downstream cells can see which models are missing results.
    data.append(folder_info)

print(json.dumps(data, indent=2))

In [None]:
# Extract model names and average accuracies.
# A model whose summary file was missing or had no "Average accuracy:" line
# keeps average_accuracy = None; such a value cannot be drawn as a bar height
# or formatted with ':.4f', so those models are left out of the chart and
# reported instead.
scored = [
    (idx, item) for idx, item in enumerate(data) if item['average_accuracy'] is not None
]
skipped = [item['model_name'] for item in data if item['average_accuracy'] is None]
if skipped:
    print(f"No average accuracy found, skipping: {', '.join(skipped)}")

model_names = [item['model_name'] for _, item in scored]
avg_accuracies = [item['average_accuracy'] for _, item in scored]

# Create bar graph with different colors. The colour index is the model's
# position in `data` (not in the filtered list) so a model keeps the same
# 'C<n>' colour even when other models are skipped.
plt.figure(figsize=(16, 9))
colors = [f'C{idx}' for idx, _ in scored]
bars = plt.bar(range(len(model_names)), avg_accuracies, color=colors, edgecolor='black')

# Add value labels on top of each bar
for bar in bars:
    height = bar.get_height()
    plt.text(
        bar.get_x() + bar.get_width() / 2.0,
        height,
        f'{height:.4f}',
        ha='center',
        va='bottom',
        fontsize=16,
    )

plt.ylabel('Average Accuracy', fontsize=18)
plt.title('Average Accuracy by Model', fontsize=20, fontweight='bold')
plt.ylim(0, 1.0)  # Accuracy ranges from 0 to 1
plt.grid(axis='y', alpha=0.3)
plt.xticks([])  # Remove x-axis labels; the legend identifies each bar
plt.yticks(fontsize=16)

# Only build a legend when there is at least one bar to describe.
if model_names:
    plt.legend(bars, model_names, loc='best', fontsize=16)
plt.tight_layout()
plt.show()

In [None]:
# Get all unique categories across every model, in first-seen order.
# Taking the union (instead of data[0] only) keeps the cell working when
# `data` is empty, when the first model has no parsed summary, or when models
# report different category sets. If all models share the same categories the
# order is the same as the first model's.
categories = list(
    dict.fromkeys(cat for item in data for cat in item['category_accuracies'])
)
num_categories = len(categories)
num_models = len(data)

# Set up the bar positions
x = np.arange(num_categories)
# Width of each bar: 0.2 for up to four models; with more models shrink the
# bars so one category's group spans at most 0.8 of its slot and never
# overlaps the neighbouring category.
width = min(0.2, 0.8 / max(num_models, 1))
offset = width * (num_models - 1) / 2  # centres each group on its tick

fig, ax = plt.subplots(figsize=(16, 8))

# Plot bars for each model
for idx, item in enumerate(data):
    model_name = item['model_name']
    # NaN marks a category this model has no score for; no bar is drawn there.
    accuracies = [item['category_accuracies'].get(cat, np.nan) for cat in categories]

    # Position bars for this model
    positions = x - offset + (idx * width)
    ax.bar(
        positions,
        accuracies,
        width,
        label=model_name,
        color=f'C{idx}',
        edgecolor='black',
        alpha=0.8,
    )

# Customize the plot
ax.set_xlabel('Category', fontsize=18)
ax.set_ylabel('Average Accuracy', fontsize=18)
ax.set_title('Model Performance Across Categories', fontsize=20, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(categories, rotation=45, ha='right', fontsize=16)
ax.tick_params(axis='y', labelsize=16)
ax.set_ylim(0, 1.0)
# A legend needs at least one labelled artist.
if num_models:
    ax.legend(loc='best', fontsize=16)
ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()