In [1]:
import ast
import json

from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import numpy as np
import os
import wandb

In [2]:
# Shared W&B API client; get_experiment_data() uses it to list runs and fetch artifacts.
api = wandb.Api()

In [3]:
def get_experiment_data(filters, name):
    """Gather the ``table_results`` artifact of every matching run into one DataFrame.

    Runs are looked up in the ``bartekcupial/sf2_nethack`` project. For each run
    the table is read from the local ``artifacts/`` directory when the JSON file
    is already there, otherwise it is fetched through the artifact. Runs whose
    table or metadata cannot be loaded are skipped (best effort) and reported.

    Args:
        filters: Filter dict forwarded to ``api.runs``.
        name: Label stored in the ``name`` column of every row.

    Returns:
        A DataFrame with the rows of all loaded tables plus the columns
        ``seed``, ``train/env_steps``, ``train_dir``, ``name`` and ``exp_tag``.
        The index is a fresh 0..n-1 range; the pre-concat index is moved into
        a column by ``reset_index()``.

    Raises:
        ValueError: If no run yielded a table.
    """
    runs = api.runs("bartekcupial/sf2_nethack", filters=filters)
    frames = []
    for run in runs:
        try:
            artifact_name = f"run-{run.id}-table_results:v1"
            artifact = api.artifact(f"bartekcupial/sf2_nethack/{artifact_name}")
            artifact_path = Path("artifacts") / artifact_name / "table_results.table.json"
            if artifact_path.exists():
                # Read-only access is enough; the context manager closes the handle.
                with open(artifact_path, "r") as table_file:
                    table = wandb.Table.from_json(json.load(table_file), artifact)
            else:
                table = artifact.get("table_results")
            df = table.get_dataframe()
            df["seed"] = run.config["seed"]
            df["train/env_steps"] = run.summary_metrics["train/env_steps"]
            df["train_dir"] = run.config["train_dir"]
            df["name"] = name
            # exp_tags is stored as the string repr of a list; keep its first entry.
            df["exp_tag"] = ast.literal_eval(run.config["exp_tags"])[0]
            frames.append(df)
        except Exception as e:
            # Still best effort, but say which run was dropped and why instead
            # of hiding it.
            run_id = getattr(run, "id", "<unknown>")
            print(f"Skipping run {run_id}: {type(e).__name__}: {e}")

    if not frames:
        raise ValueError(f"No table_results could be loaded for filters={filters!r}")

    # reset_index() already produces a clean RangeIndex, so a second reset is not needed.
    return pd.concat(frames, axis=0).reset_index()

In [4]:
# Load every run tagged with the full-evaluation experiment into one frame.
data = get_experiment_data(
    filters={"config.exp_tags": "['2024_02_02_eval_full']"},
    name="all",
)

In [10]:
def _tag_from_train_dir(train_dir):
    """Build the experiment tag from a ``train_dir`` path string.

    Takes the second-to-last "/"-separated component and drops its final
    "-"-separated piece.
    """
    parent_dir = train_dir.split("/")[-2]
    pieces = parent_dir.split("-")
    return "-".join(pieces[:-1])


# Replace the config-derived tag with one computed from the training directory.
data["exp_tag"] = data["train_dir"].apply(_tag_from_train_dir)

In [5]:
# Display the `death` column of the combined results.
data["death"]

0                           killed by a bolt of lightning
1                                       killed by a guard
2       killed by a hobbit while fainted from lack of ...
3       killed by an iguana while fainted from lack of...
4       killed by a rothe while fainted from lack of food
                              ...                        
2386                                                 quit
2387                                                 quit
2388                                                 quit
2389                                                 quit
2390                                                 quit
Name: death, Length: 2391, dtype: object

In [49]:
# Number of non-null `score` entries for each experiment tag.
data.groupby("exp_tag")[["score"]].count()

Unnamed: 0_level_0,score
exp_tag,Unnamed: 1_level_1
2024-01-23-monk-appo,500
2024-01-23-monk-appo-bc-t,389
2024-01-23-monk-appo-ks-t,499
2024-01-23-monk-appo-t,515
2024-01-29-monk-appo-ewc-t,488


In [92]:
# Regex for deaths whose description mentions "food" or "starvation".
FOOD_DEATH_PATTERN = "food|starvation"

# For every experiment tag: share of food-related deaths, then the ten most
# frequent remaining causes of death.
for tag in data["exp_tag"].unique():
    method_data = data[data["exp_tag"] == tag]
    deaths = method_data["death"]

    # Frequency table, most common first; value_counts() ignores missing values.
    df = (
        deaths.value_counts()
        .rename_axis("death")
        .reset_index(name="count")[["count", "death"]]
    )

    # na=False treats a missing death description as "not food related"
    # instead of producing NaN in the boolean mask.
    food_deaths = deaths.str.contains(FOOD_DEATH_PATTERN, na=False).sum()
    food_share = food_deaths / len(method_data) * 100

    print(f"Tag: {tag}")
    print(f"Death related to lack of food: {food_share:.2f}% of the time!")
    print("Excluding food related deaths")
    other_deaths = df[~df["death"].str.contains(FOOD_DEATH_PATTERN)]
    print(other_deaths[:10].to_string(index=False))


Tag: 2024-01-23-monk-appo-ks-t
Death related to lack of food: 38.68% of the time!
Excluding food related deaths
 count                                  death
    18                                   quit
    18                killed by a soldier ant
    13              killed by a white unicorn
    10   killed by the wrath of Chih Sung-tzu
    10 poisoned by a rotted glob of gray ooze
     9               killed by a Woodland-elf
     9             petrified by a chickatrice
     7                       killed by a wand
     7                      killed by a guard
     6                      killed by a mumak
Tag: 2024-01-23-monk-appo-bc-t
Death related to lack of food: 50.64% of the time!
Excluding food related deaths
 count                                death
    46                                 quit
     8 killed by the wrath of Chih Sung-tzu
     7                 killed by a hill orc
     5              killed by a soldier ant
     5                     killed by a wand
     4