In [None]:
import os
import pickle

import flax.core
import jax.random
import wandb

In [None]:
# W&B location of the runs analysed in this notebook.
# `entity` is the W&B username or team that owns `project`; edit both if needed.
api = wandb.Api()
entity = "miladink"
project = "loqa-ipd"

In [None]:
# Load runs by tag

In [None]:

chosen_tag = "ipd_10_seeds_v4"
tag_filter = {"tags": {"$in": [chosen_tag]}}
runs = api.runs(path=f"{entity}/{project}", filters=tag_filter)

# Dump the metadata of every matching run, with a separator line after each run.
# Fields are read and printed one at a time, in this order.
run_fields = (
    ("Run ID:", "id"),
    ("Name:", "name"),
    ("Config:", "config"),
    ("Summary:", "summary"),
    ("Notes:", "notes"),
    ("Tags:", "tags"),
)
for run in runs:
    for label, attr in run_fields:
        print(label, getattr(run, attr))
    print("=" * 50)

print(f"Found {len(runs)} runs")


# Load runs by run ID (this replaces the `runs` loaded by tag above; later cells use whichever cell ran last)

In [None]:
# NOTE: this rebinds `runs`, discarding the tag-based query result from the
# earlier cell; every later cell operates on this list of explicit run IDs.
run_ids = ['301bjspk']
runs = []
for run_id in run_ids:
    run_path = f"{entity}/{project}/{run_id}"
    runs.append(api.run(run_path))

In [None]:
# NOTE(review): `target_step` is not referenced anywhere else in this notebook;
# the summary cell further down hard-codes its own `step` instead.
target_step = 6000
# Logged per-state metric names pulled from each run's history.
run_keys = [f"p_0_{state}_C" for state in ("START", "CC", "CD", "DC", "DD")]
# get history of the keys
def grab_run_data(run, keys=None):
    """Collect the values of ``keys`` from every row of ``run.scan_history``.

    Args:
        run: W&B run object; ``run.scan_history(keys=...)`` is called once.
        keys: metric names to extract. Defaults to the module-level
            ``run_keys`` so existing calls ``grab_run_data(run)`` keep working.

    Returns:
        A list with one dict per history row returned by ``scan_history``,
        mapping each requested key to that row's value.
    """
    if keys is None:
        keys = run_keys
    # Materialise once so the same sequence is sent to scan_history and used
    # for the per-row projection (a generator would be exhausted after one use).
    keys = list(keys)
    return [{k: row[k] for k in keys} for row in run.scan_history(keys=keys)]

# Preview the selected metrics for the first loaded run.
first_run = runs[0]
grab_run_data(first_run)

In [None]:
def extract_data(run, key):
    """Yield ``{key: value}`` for each row returned by ``run.scan_history(keys=[key])``.

    Args:
        run: W&B run object exposing ``scan_history``.
        key: name of the single logged metric to extract.

    Yields:
        A one-entry dict per history row.

    NOTE: the next cell redefines ``extract_data(run)`` with a different
    signature; once that cell has run, this version is no longer reachable.
    """
    # The leftover debug ``print(row)`` was removed: it echoed every history row.
    for row in run.scan_history(keys=[key]):
        yield {key: row[key]}

# Preview the `p_0_CC_C` history of the first loaded run.
# Index 0, not 1: the run-ID cell above leaves `runs` with a single element,
# so `runs[1]` raised IndexError on a fresh Restart & Run All.
list(extract_data(runs[0], "p_0_CC_C"))

In [None]:
def extract_data(run):
    """Yield each row of ``run.scan_history()`` as a new dict with a ``'step'`` key.

    Args:
        run: W&B run object exposing ``scan_history``.

    Yields:
        One dict per history row. ``'step'`` (copied from the row's ``'_step'``)
        comes first, followed by all of the row's own keys and values. If a row
        already has a ``'step'`` key, the row's value wins.

    Raises:
        KeyError: if a row has no ``'_step'`` entry.
    """
    # The debug ``print(row)`` (which dumped the whole history) and the unused
    # enumerate index were removed.
    for row in run.scan_history():
        yield {'step': row['_step'], **row}


# Materialise the full history of every run, keyed by run name
# (a later run with a duplicate name overwrites an earlier one).
run_data = {run.name: list(extract_data(run)) for run in runs}

In [None]:
from collections import defaultdict

# Group each run's history rows into 100-step buckets: the bucket key is the
# row's 'step' rounded down to a multiple of 100. Rows falling in the same
# bucket are merged, with later rows overwriting earlier values.
organized_data = {}
for run in runs:
    buckets = defaultdict(dict)
    organized_data[run.name] = buckets
    for record in run_data[run.name]:
        bucket = (record['step'] // 100) * 100
        buckets[bucket].update(record)

In [None]:
import numpy as np

# Step bucket at which the per-state values are compared across runs.
# It must be one of the 100-step bucket keys built in the previous cell.
step = 7000

stats = {}
state_names = ['START', 'CC', 'CD', 'DC', 'DD']
for state in state_names:
    key = f'p_0_{state}_C'
    values = []
    for run in runs:
        # `.get` instead of `[step]`: indexing the defaultdict would silently
        # insert an empty bucket into `organized_data` for a missing step and
        # then fail with a KeyError that names only the metric.
        bucket = organized_data[run.name].get(step, {})
        if key not in bucket:
            raise KeyError(f"{key!r} not found in step bucket {step} of run {run.name!r}")
        values.append(bucket[key])
    stats[key] = np.array(values)

stats


In [None]:
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style="whitegrid", font_scale=1.5)

# Bar chart of the across-run mean of each `p_0_<state>_C` value, with the
# across-run standard deviation (numpy `.std()`, ddof=0) as error bars.
# Heights follow the insertion order of `stats`, which the previous cell fills
# in `state_names` order, so they line up with the x tick labels.
fig, ax = plt.subplots(figsize=(8, 8))
means = [per_run.mean() for per_run in stats.values()]
spreads = [per_run.std() for per_run in stats.values()]
bars = ax.bar(
    x=state_names,
    height=means,
    yerr=spreads,
    align='center',
    alpha=1.0,
    ecolor='black',
    capsize=20,
    color=['purple'],
)

label_size = 20
ax.set_ylabel('Probability of Cooperation', fontsize=label_size)
ax.set_xlabel('State', fontsize=label_size)
ax.set_title('LOQA (Ours)', fontsize=label_size)

# Thick gray outline around every bar.
for patch in bars:
    patch.set_edgecolor('gray')
    patch.set_linewidth(3.0)

fig.savefig('ipd_exp.pdf', bbox_inches='tight')

In [None]:
# Show the per-run arrays behind each bar, in plotting order.
list(stats.values())