In [None]:
import os
import pickle

import flax.core
import jax.random
import wandb

In [11]:
# Weights & Biases coordinates of the runs this notebook works with.
# Adjust them if necessary; `entity` is typically your username or team name.
entity, project = "miladink", "loqa-ipd"
# API client used by the cells below to look up runs.
api = wandb.Api()

# Load runs by tag

In [12]:

# Tag selecting the batch of runs to load. Other tags previously used here:
#   "iconic-planet-190-seeds", "loqa_10_seeds_v1"
chosen_tag = "loqa_10_seeds_v1_rb_ablation"

runs = api.runs(
    path=f"{entity}/{project}",
    filters={"tags": {"$in": [chosen_tag]}},
)

# Print specifics of the runs: one labelled line per attribute, then a separator.
# Attributes are read and printed one at a time, in this order.
run_fields = (
    ("Run ID:", "id"),
    ("Name:", "name"),
    ("Config:", "config"),
    ("Summary:", "summary"),
    ("Notes:", "notes"),
    ("Tags:", "tags"),
)
for run in runs:
    for label, attr_name in run_fields:
        print(label, getattr(run, attr_name))
    print("=" * 50)

print(f"Found {len(runs)} runs with tag {chosen_tag}")

Run ID: 1abkzsx3
Name: logical-wildflower-402
Config: {'game': {'width': 3, 'height': 3, 'game_length': 50, 'gnumactions': 4}, 'seed': 51, 'actor': {'train': {'separate': 'disabled', 'advantage': 'TD0', 'clip_grad': {'mode': 'norm', 'max_norm': 1}, 'optimizer': 'adam', 'entropy_beta': 0.1, 'lr_loss_actor': 0.001}, 'inf_weight': 0.5, 'hidden_size': 128, 'layers_before_gru': 2}, 'reset': {'mode': 'disabled'}, 'qvalue': {'mode': 'mean', 'train': {'optimizer': 'adam', 'lr_loss_qvalue': 0.01, 'target_ema_gamma': 0.99}, 'hidden_size': 64, 'replay_buffer': {'mode': 'disabled', 'capacity': 1000}, 'layers_before_gru': 2}, 'agent_0': 'loqa', 'agent_1': 'loqa', 'save_dir': '/home/mila/a/aghajohm/scratch/loqa', 'batch_size': 512, 'eval_every': 100, 'save_every': 1000, 'just_self_play': True, 'op_softmax_temp': 1, 'reward_discount': 0.96, 'agent_replay_buffer': {'mode': 'disabled', 'capacity': 10000, 'update_freq': 10, 'cur_agent_frac': 0}, 'differentiable_opponent': {'method': 'loaded-dice', 'disc

# Load runs by run_ids

In [None]:
# Fetch specific runs directly by id. This rebinds `runs`, replacing whatever the
# tag-based query produced if that cell was executed earlier.
run_ids = ['301bjspk']
runs = [
    api.run(f"{entity}/{project}/{run_id}")
    for run_id in run_ids
]

# Copy to Local

In [13]:
from pathlib import Path
import subprocess

# Destination directory on this machine for the downloaded checkpoints.
local_save_dir = Path("/Users/miladaghajohari/PycharmProjects/loqa/league/checkpoints/loqa_ckpt_10_seeds_v1_rb_ablation/")
# Create the local dir (and any missing parents); no-op if it already exists.
local_save_dir.mkdir(parents=True, exist_ok=True)

# scp source prefix. Path is used only for `/`-joining the run id; the leading
# "mila:" is presumably an SSH host alias for the cluster — TODO confirm.
mila_save_dir = Path("mila:/home/mila/a/aghajohm/scratch/loqa/")

# Download each run's checkpoint directory from the mila cluster to local.
for run in runs:
    cluster_name = run.summary['slurm/cluster_name']
    run_id = run.id
    if cluster_name != 'mila':
        print(f"cluster name {cluster_name} is not recognized.")
        print(f"skipping {run_id}")
        continue
    # Argument list instead of a shell string: nothing is re-parsed by a shell,
    # and the exit status is available to check.
    scp_result = subprocess.run(
        ["scp", "-r", "-P", "2222",
         str(mila_save_dir / run_id),
         str(local_save_dir / run_id)]
    )
    if scp_result.returncode != 0:
        # Previously the exit status was ignored and success was always reported.
        print(f"scp failed for {run_id} (exit code {scp_result.returncode})")
        continue
    print(f'finished writing {run_id}')

finished writing 1abkzsx3
finished writing ejpsvuwk
finished writing d3g7n3wp
finished writing 375wzono
finished writing 2zpspi7d
finished writing 3tocppgh
finished writing 3o6xocrm
finished writing 1tieobkn
finished writing 3kfuinhg
finished writing 2xndvhhp


# Load agents (play a few episodes with them as a sanity check)

In [None]:
from coin_train import GRUCoinAgent, CoinAgent
import pickle

def load_loqa_agent(path, player_id: int):
    """Load a pickled LOQA checkpoint and wrap its parameters in a ``CoinAgent``.

    The checkpoint is a pickled dict ("minimal state") that must contain:
      * ``'hp'``: hyper-parameter dict with ``['actor']`` and ``['qvalue']``
        sub-dicts, each providing ``'hidden_size'`` and ``'layers_before_gru'``;
      * ``'agent0'``: dict holding the network parameters under ``'params'``.

    Args:
        path: Filesystem path of the checkpoint file.
        player_id: Passed to ``CoinAgent`` as ``player``. It does NOT select
            which stored agent is loaded: parameters always come from
            ``'agent0'``.

    Returns:
        A ``(agent, hp)`` tuple: the constructed ``CoinAgent`` and the
        hyper-parameter dict read from the checkpoint.

    Raises:
        OSError: If ``path`` cannot be opened.
        KeyError: If the checkpoint lacks one of the keys listed above.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load checkpoints you produced yourself or otherwise trust.
    with open(path, 'rb') as f:
        minimal_state = pickle.load(f)
    hp = minimal_state['hp']
    agent_params = minimal_state['agent0']['params']
    agent_module = GRUCoinAgent(hidden_size_actor=hp['actor']['hidden_size'],
                                hidden_size_qvalue=hp['qvalue']['hidden_size'],
                                layers_before_gru_actor=hp['actor']['layers_before_gru'],
                                layers_before_gru_qvalue=hp['qvalue']['layers_before_gru'], )
    # IMPORTANT: this is a hack to change player_id, but it works because the agent
    # just sees the obsseq with GRU, and we just use inference, so it does not matter.
    return CoinAgent(params=agent_params, model=agent_module, player=player_id), hp

# Load a single checkpoint to play with. Note this path points at the
# "iconic_planet_190_seeds" checkpoint directory, not the directory populated by
# the copy cell above.
agent, hp = load_loqa_agent(
    path='/Users/miladaghajohari/PycharmProjects/loqa/league/checkpoints/loqa_ckpt_iconic_planet_190_seeds/1bwbcc0q/minimal_state_30000',
    player_id=0,
)

In [None]:
# Initial recurrent carries of the loaded agent, split into actor and q-value parts.
c_0 = agent.get_initial_carries()
c_0_actor, c_0_qvalue = c_0['carry_actor'], c_0['carry_qvalue']
# The second player's slots reuse the very same initial carries.
c_1_actor, c_1_qvalue = c_0_actor, c_0_qvalue

# Bundle all four carries under the key names passed on as `carries=` below.
carries = dict(
    c_0_actor=c_0_actor,
    c_0_qvalue=c_0_qvalue,
    c_1_actor=c_1_actor,
    c_1_qvalue=c_1_qvalue,
)

In [None]:
from coin_train import generate_episodes
import jax, flax
# Wrap the hyper-parameters in an immutable FrozenDict before handing them on.
hp = flax.core.FrozenDict(hp)
# Roll out episodes with the same agent object passed both as the first argument
# and as the sole entry of the list argument (presumably the opponents, i.e.
# self-play — TODO confirm against coin_train.generate_episodes). Fixed PRNG seed 42.
episodes = generate_episodes(
    agent,
    [agent],
    player_id=0,
    rng=jax.random.PRNGKey(42),
    hp=hp,
    carries=carries,
)

In [None]:
# Sum `rew` over axis 1, then average the result over axis 0; the value is shown
# as the cell output.
# NOTE(review): presumably axis 1 is the time step and axis 0 the batch, making
# this the mean episode return — confirm against generate_episodes.
episodes['rew'].sum(axis=1).mean(axis=0)

In [None]:
from coin_game import do_eval_agent_against_always_cooperate
# 16 PRNG keys derived from seed 78, stacked into one array to map over.
play_rngs = jax.numpy.stack(jax.random.split(jax.random.PRNGKey(78), 16))
# Run the always-cooperate evaluation once per key, vectorised with vmap;
# `agent` and `hp` are shared across all mapped calls.
episodes = jax.vmap(
    lambda rng_key: do_eval_agent_against_always_cooperate(agent=agent,
                                                           hp=hp,
                                                           rng=rng_key)
)(play_rngs)

In [None]:
# Same reduction as for the self-play episodes, now applied to the vmapped
# evaluation result: sum `rew` over axis 1, then average over axis 0.
# NOTE(review): vmap presumably adds a leading axis (one entry per PRNG key), so
# the axes may not mean the same thing as in the self-play cell — confirm.
episodes['rew'].sum(axis=1).mean(axis=0)