In [14]:
import matplotlib.pyplot as plt
import tabulate

import numpy as np

import random

import pandas as pd

import pandas as pd
import wandb
api = wandb.Api()

USERS = ["David", "Jeff", "Kay", "Sara", "Tana"]
METRICS = ["BLEU", "Rouge", "Mauve"]
METRIC_NAMES = {
    "BLEU": "BLEU",
    "Rouge": "rougeLsum_fmeasure",
    "Mauve": "MAUVE"

}
MODEL_TYPES = ["FFT", "RoSA", "LoRA"]

COLUMNS = ["model", "user", "model_type", "metric", "RAFT", "RAG", "score", "seed", "metadata"]

In [32]:
def extract_config_from_run_name(run_name):
    user = [u for u in USERS if u.lower() in run_name][0]
    model_type = [m for m in MODEL_TYPES if m.lower() in run_name][0]
    if model_type == "RoSA" and "lr0.0-epochs" in run_name:
        model_type = "LoRA"
    raft = "RAFT" in run_name
    seed = int(run_name.split("seed")[-1].split("-")[0])
    return user, model_type, raft, seed


def load_wandb_project_runs(project_name):
    results_df = pd.DataFrame(columns=COLUMNS)
    runs = api.runs(project_name)
    for run in runs:
        if run.summary["_timestamp"] < 1717813799.4198122:
            continue  # Skip older runs
        user, model_type, raft, seed = extract_config_from_run_name(run.name)
        summary = run.summary._json_dict
        for rag in [False, True]:
            for metric in METRICS:
                score = summary[f"EVAL/{METRIC_NAMES[metric]}{'-RAG' if rag else ''}-mean"]
                results_df = results_df._append({
                    "model": "Phi3",
                    "user": user,
                    "model_type": model_type,
                    "metric": metric,
                    "RAFT": raft,
                    "RAG": rag,
                    "score": score,
                    "seed": seed,
                    "metadata": None
                }, ignore_index=True)

    print(len(results_df))
    return results_df

all_projects = [
    ("david", "diverse-vit/panza-david_anonymous-Phi3-June8"),
    ("jeff", "diverse-vit/panza-jeff_johnson-Phi3-June8"),
    ("kay", "diverse-vit/panza-kay_brown-Phi3-June8"),
    ("sara", "diverse-vit/panza-shackleton_sara-Phi3-June8"),
    ("tana", "diverse-vit/panza-tana_williams-Phi3-June8"),
]

results_df = pd.DataFrame(columns=COLUMNS)
for _, project_name in all_projects:
    results_df = pd.concat([results_df, load_wandb_project_runs(project_name)])

# results_df = load_wandb_project_runs("diverse-vit/panza-jeff_johnson-Phi3-June8")

results_df = results_df.groupby(['model', 'user', 'model_type', 'metric', 'RAFT', 'RAG']).agg({'score': ['mean', 'std']}).reset_index()
results_df.columns = ['model', 'user', 'model_type', 'metric', 'RAFT', 'RAG', 'score', 'score_std']
print(results_df)

  results_df = results_df._append({
  results_df = pd.concat([results_df, load_wandb_project_runs(project_name)])
  results_df = results_df._append({
  results_df = results_df._append({
  results_df = results_df._append({


108
108
108
108
108
    model   user model_type metric   RAFT    RAG     score  score_std
0    Phi3  David        FFT   BLEU  False  False  0.329567   0.018969
1    Phi3  David        FFT   BLEU  False   True  0.309882   0.008343
2    Phi3  David        FFT   BLEU   True  False  0.330211   0.014018
3    Phi3  David        FFT   BLEU   True   True  0.327001   0.023329
4    Phi3  David        FFT  Mauve  False  False  0.999684   0.000471
..    ...    ...        ...    ...    ...    ...       ...        ...
175  Phi3   Tana       RoSA  Mauve   True   True  0.945010   0.016103
176  Phi3   Tana       RoSA  Rouge  False  False  0.336564   0.007991
177  Phi3   Tana       RoSA  Rouge  False   True  0.339023   0.005040
178  Phi3   Tana       RoSA  Rouge   True  False  0.331023   0.006132
179  Phi3   Tana       RoSA  Rouge   True   True  0.354281   0.007302

[180 rows x 8 columns]


  results_df = results_df._append({


In [33]:
def get_mock_results():
    df = pd.DataFrame(columns=COLUMNS)

    for user in USERS:
        for model_type in MODEL_TYPES:
            for metric in METRICS:
                for rag in [False, True]:
                    for raft in [False, True]:
                        df = df._append(
                            {
                                "model": "model",
                                "user": user,
                                "model_type": model_type,
                                "metric": metric,
                                "RAFT": raft,
                                "RAG": rag,
                                "score": random.random(),
                                "metadata": None,
                            },
                            ignore_index=True,
                        )

    return df


def create_all_results_table(results_df):
    table = []
    for model_type in MODEL_TYPES:
        for raft in [False, True]:
            for rag in [False, True]:
                # Create line for given model type and raft
                filtered_df = results_df[
                    (results_df["model_type"] == model_type)
                    & (results_df["RAFT"] == raft)
                    & (results_df["RAG"] == rag)
                ]
                raft_str = "-RAFT" if raft else ""
                rag_str = "-RAG" if rag else ""
                model_str = f"{model_type}{raft_str}{rag_str}"
                line = [model_str]
                for user in USERS:
                    for metric in METRICS:
                        try:
                            score = filtered_df[
                                (filtered_df["user"] == user) & (filtered_df["metric"] == metric)
                            ]["score"].values[0]
                        except Exception as e:
                            print(e)
                            score = "-"
                        line.append(score)
                table.append(line)
    return table

mock_df = get_mock_results()
table = create_all_results_table(results_df)
for line in table:
    print(line)

  df = df._append(


['FFT', 0.32956710199515027, 0.488078807592392, 0.9996842647524252, 0.16916790710098395, 0.28105342263615757, 0.8597102933295929, 0.19915374772734107, 0.2985645563867702, 0.8656762003858293, 0.27011771550169217, 0.3791506297708977, 0.8838354668654723, 0.264710224947651, 0.36415353960668045, 0.8686559280988825]
['FFT-RAG', 0.3098816779255867, 0.4741027146577835, 0.9936033432268913, 0.1788094564348172, 0.27978655180439005, 0.7262057197565728, 0.19587726811059167, 0.2833636980590031, 0.8583141116843129, 0.23062316548307318, 0.3218195655338821, 0.9593771411943007, 0.2539928340392413, 0.35317645590042784, 0.9111948342846895]
['FFT-RAFT', 0.33021117707093556, 0.5081067683299382, 0.9922753770520328, 0.16557368286140803, 0.27639164469500843, 0.9149353630038947, 0.19358916039345786, 0.2889097487571201, 0.9233778943827771, 0.2608540862593578, 0.3657797660678625, 0.8805893938791374, 0.26241157709268054, 0.3714597349587296, 0.8772436896304493]
['FFT-RAFT-RAG', 0.32700149118900296, 0.49509853581587

In [34]:
from pylatex import Document, Section, Tabular, MultiColumn, MultiRow, Command, NoEscape

def format_as_math(number):
    """Return number formatted as LaTeX math."""
    return NoEscape(f"${number}$")

def get_colored_column():
    """Return a colored column formatter."""
    return NoEscape(">{\columncolor{gray!10}}")

def create_document(results_table):
    doc = Document("results")

    columns_format = ""
    columns_format += "l"
    for i, user in enumerate(USERS):
        for metric in METRICS:
            if i % 2 == 0:
                columns_format += get_colored_column()
            columns_format += "c"
    # with doc.create(Tabular("c" * (1 + len(USERS) * len(METRICS)))) as table:
    with doc.create(Tabular(columns_format)) as table:
        table.add_hline()

        # Add user columns
        user_columns = []
        user_columns.append(" ")
        for i, user in enumerate(USERS):
            if False: #i % 2 == 0:
                user_columns.append(MultiColumn(3, align='c', data=Command('cellcolor', arguments=['gray!10'], extra_arguments=Command('texttt', user))))
            else:
                user_columns.append(MultiColumn(3, align='c', data=Command('texttt', user)))
        table.add_row(user_columns)
        table.add_hline()

        # Add metric columns
        metrics_columns = []
        metrics_columns.append("Method")
        for user in USERS:
            for metric in METRICS:
                metrics_columns.append(metric)
        table.add_row(metrics_columns)

        for i, line in enumerate(results_table, 1):
            clean_line = []
            for element in line:
                if isinstance(element, str):
                    element = Command('texttt', element)
                    clean_line.append(element)
                else:
                    element = float(element)
                    element = round(element, 3)
                    element = format_as_math(element)
                    clean_line.append(element)
            table.add_row(clean_line)

            # Add horizontal line after each model type
            if i % 4 == 0:
                table.add_hline()

        table.add_hline()

    return doc

doc = create_document(table)
print(doc.dumps())

\documentclass{article}%
\usepackage[T1]{fontenc}%
\usepackage[utf8]{inputenc}%
\usepackage{lmodern}%
\usepackage{textcomp}%
\usepackage{lastpage}%
%
%
%
\begin{document}%
\normalsize%
\begin{tabular}{l>{\columncolor{gray!10}}c>{\columncolor{gray!10}}c>{\columncolor{gray!10}}cccc>{\columncolor{gray!10}}c>{\columncolor{gray!10}}c>{\columncolor{gray!10}}cccc>{\columncolor{gray!10}}c>{\columncolor{gray!10}}c>{\columncolor{gray!10}}c}%
\hline%
 &\multicolumn{3}{c}{\texttt{David}}&\multicolumn{3}{c}{\texttt{Jeff}}&\multicolumn{3}{c}{\texttt{Kay}}&\multicolumn{3}{c}{\texttt{Sara}}&\multicolumn{3}{c}{\texttt{Tana}}\\%
\hline%
Method&BLEU&Rouge&Mauve&BLEU&Rouge&Mauve&BLEU&Rouge&Mauve&BLEU&Rouge&Mauve&BLEU&Rouge&Mauve\\%
\texttt{FFT}&$0.33$&$0.488$&$1.0$&$0.169$&$0.281$&$0.86$&$0.199$&$0.299$&$0.866$&$0.27$&$0.379$&$0.884$&$0.265$&$0.364$&$0.869$\\%
\texttt{FFT{-}RAG}&$0.31$&$0.474$&$0.994$&$0.179$&$0.28$&$0.726$&$0.196$&$0.283$&$0.858$&$0.231$&$0.322$&$0.959$&$0.254$&$0.353$&$0.911$\\%
\texttt