In [37]:
import os
import sys
import pathlib
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

plt.style.use("ggplot")

# Location of the exported GATE results CSV. Set the GATE_RESULTS_CSV environment
# variable to point elsewhere instead of editing this machine-specific default.
data_path = pathlib.Path(
    os.environ.get(
        "GATE_RESULTS_CSV", "/devcode/GATE-private/notebooks/gate-results.csv"
    )
)

In [38]:
df = pd.read_csv(data_path)

# Peek at the first few rows; as the cell's last expression it renders as a rich table.
df.head()

In [3]:
# Summary statistics of the raw results (rendered via rich display, not print).
df.describe()

       testing/accuracy_top_1-epoch-mean  testing/accuracy_top_1-epoch-std   
count                         130.000000                        130.000000  \
mean                           65.825859                          4.967498   
std                            27.083131                          3.582391   
min                             2.274525                          1.162434   
25%                            56.174597                          3.216839   
50%                            77.484615                          3.790733   
75%                            82.613693                          4.446880   
max                            96.229073                         15.894830   

       testing/accuracy_top_5-epoch-mean  testing/accuracy_top_5-epoch-std   
count                         130.000000                        130.000000  \
mean                           82.054007                          3.441022   
std                            24.640253                       

In [39]:
import pandas as pd
from rich.table import Table
from rich.console import Console


def load_and_process_data(data_path: "str | os.PathLike") -> pd.DataFrame:
    """Load run results from CSV and aggregate validation accuracy per (Task, Name).

    Steps:
      * Strip double quotes from the "Name" column. If its last "-"-separated
        segment is all digits, store it as "Seed" and drop it from the name.
        Names without a numeric tail keep every segment, and their Seed is NaN.
      * Keep only runs whose name starts with "athena" and does not contain "tali".
      * Split the name: segment 1 becomes "Task", segments [2:-1] become "Name".
      * Coerce the validation accuracy columns to numbers (unparseable -> NaN).
      * Group by (Task, Name) and compute mean/std/count of those columns.

    Args:
        data_path: Anything ``pd.read_csv`` accepts (path string, Path, buffer).

    Returns:
        DataFrame indexed by (Task, Name), sorted by that index, with
        MultiIndex columns ``(metric, "mean" | "std" | "count")``.
    """
    runs = pd.read_csv(data_path)

    # Strip quotes up front; '"' never contains '-', so splitting is unaffected.
    clean_names = runs["Name"].astype(str).str.replace('"', "", regex=False)
    stem_and_tail = clean_names.str.rsplit("-", n=1)
    tail = stem_and_tail.str[-1]
    has_seed = tail.str.fullmatch(r"[0-9]+")

    runs["Seed"] = pd.to_numeric(tail.where(has_seed), errors="coerce")
    # Only remove the last segment when it really is a seed.
    stems = stem_and_tail.str[:-1].str.join("-")
    runs["Name"] = stems.where(has_seed, clean_names)

    # Filter, then copy so the column assignments below act on an independent frame.
    keep = runs["Name"].str.startswith("athena") & ~runs["Name"].str.contains(
        "tali", regex=False
    )
    runs = runs.loc[keep].copy()

    # Vectorised split: names that are too short yield NaN instead of IndexError.
    parts = runs["Name"].str.split("-")
    runs["Task"] = parts.str[1]
    # NOTE(review): [2:-1] drops the final segment remaining after the seed was
    # removed (e.g. "athena-cifar-vit-base" -> "vit"); confirm this is intended.
    runs["Name"] = parts.str[2:-1].str.join("-")

    numeric_columns = [
        "validation/accuracy_top_1-epoch-mean",
        "validation/accuracy_top_5-epoch-mean",
    ]
    for col in numeric_columns:
        runs[col] = pd.to_numeric(runs[col], errors="coerce")

    grouped = runs.groupby(["Task", "Name"])[numeric_columns].agg(
        ["mean", "std", "count"]
    )

    # Deterministic (Task, Name) order; an unstable sort on Task alone could
    # shuffle names within a task.
    return grouped.sort_index()


df = load_and_process_data(data_path)
# DataFrame.to_csv returns None when given a path, so its result is not printed.
df.to_csv("data.csv")

In [34]:
from rich.console import Console
from rich.table import Table

console = Console(width=100)


def print_rich_table(df: pd.DataFrame, task: str) -> None:
    """Print a "Task: <task>" line followed by ``df`` rendered as a rich table.

    The first column holds the index values and is headed by the index name
    (falling back to "Name"). Tuple column labels (MultiIndex columns) are
    joined with " / ". Float cells are shown with two decimals; other values
    use ``str``.

    Args:
        df: Frame to render, one table row per frame row.
        task: Label printed above the table.
    """
    table = Table(show_header=True, header_style="bold magenta")

    table.add_column(str(df.index.name or "Name"))
    for col in df.columns:
        label = " / ".join(map(str, col)) if isinstance(col, tuple) else str(col)
        # No fixed width: let rich fit all columns into the console width.
        table.add_column(label)

    # itertuples keeps per-column types (iterrows would upcast int counts to float).
    for index, *values in df.itertuples(name=None):
        cells = [f"{v:.2f}" if isinstance(v, float) else str(v) for v in values]
        table.add_row(str(index), *cells)

    console.print(f"Task: {task}")
    console.print(table)


# `df` is already aggregated per (Task, Name) with (metric, stat) columns, so each
# task's slice is shown as-is. Pivoting it again would aggregate the aggregates
# (one row per Name: std -> NaN, count -> 1).
for task in df.index.get_level_values("Task").unique():
    task_df = df.xs(task, level="Task")
    print_rich_table(task_df, task)