In [1]:
import numpy as np
import pickle
import json
import os
import shutil
import glob

In [2]:
def get_training_info (agent_name):
    """Read the model name and training-round count from an agent's params file.

    Parameters
    ----------
    agent_name : str
        Suffix of the agent directory ``../agent_code/agent_<agent_name>/``.

    Returns
    -------
    tuple
        ``(params['training']['MODEL_NAME'], params['training']['TRAINING_ROUNDS'])``
        as read from ``logs/params.json``.
    """
    params_path = f"../agent_code/agent_{agent_name}/logs/params.json"
    with open(params_path, 'r') as handle:
        training_section = json.load(handle)['training']
    return training_section['MODEL_NAME'], training_section['TRAINING_ROUNDS']

In [3]:
def make_analysis_directory (agent_name, model_name):
    """Create the analysis folder ``./<agent_name>/<model_name>/``.

    Missing parent directories are created too. ``exist_ok`` is left at its default,
    so this raises ``FileExistsError`` if the folder already exists.
    """
    target_dir = f"./{agent_name}/{model_name}/"
    os.makedirs(target_dir)

In [4]:
def copy_model (agent_name, model_name):
    """Copy the agent's ``model_<agent_name>_<model_name>.pt`` file into the analysis folder.

    The original checkpoint stays in ``../agent_code/agent_<agent_name>/``. A copy is
    placed in ``./<agent_name>/<model_name>/``, which is expected to already exist.
    """
    checkpoint = f"../agent_code/agent_{agent_name}/model_{agent_name}_{model_name}.pt"
    shutil.copy(checkpoint, f"./{agent_name}/{model_name}/")

In [5]:
def move_log (agent_name, model_name, mode = "train"):
    """Move the agent's log into the analysis folder, renamed to ``log_<mode>.log``.

    The file ``../agent_code/agent_<agent_name>/logs/agent_<agent_name>.log`` is moved
    (not copied), so it no longer exists in the agent's logs folder afterwards.
    """
    log_path = f"../agent_code/agent_{agent_name}/logs/agent_{agent_name}.log"
    shutil.move(log_path, f"./{agent_name}/{model_name}/log_{mode}.log")

In [6]:
def move_sa_counter (agent_name, model_name):
    """Move ``state_action_counter.npy`` from the agent's logs into the analysis folder.

    The file keeps its name. It is moved, so it is removed from
    ``../agent_code/agent_<agent_name>/logs/``.
    """
    counter_name = "state_action_counter.npy"
    shutil.move(
        f"../agent_code/agent_{agent_name}/logs/{counter_name}",
        f"./{agent_name}/{model_name}/{counter_name}",
    )

In [7]:
def copy_analysis_template (agent_name, model_name, mode = "train"):
    """Copy the ``analysis_<mode>.ipynb`` notebook template into the analysis folder.

    The template is read from ``./templates/`` and left in place. The copy keeps its
    file name inside ``./<agent_name>/<model_name>/``.
    """
    template = f"./templates/analysis_{mode}.ipynb"
    shutil.copy(template, f"./{agent_name}/{model_name}/")

In [8]:
def move_results (agent_name, model_name, mode = "train"):
    """Move the latest results file from ``../results/`` into the analysis folder.

    The "latest" file is the last ``*.json`` in ``../results/`` in lexicographic
    order. Presumably the file names start with a sortable timestamp; TODO confirm.
    It is renamed to ``results_<mode>.json`` inside ``./<agent_name>/<model_name>/``.
    If that file already exists, nothing is moved and a notice is printed, so an
    earlier results file is never overwritten.

    Raises
    ------
    FileNotFoundError
        If ``../results/`` contains no ``*.json`` file.
    """
    candidates = sorted(glob.glob("../results/*.json"))
    if not candidates:
        raise FileNotFoundError("No '*.json' results file found in '../results/'.")
    source_file        = candidates[-1]   # latest json log file
    destination_folder = f"./{agent_name}/{model_name}/"
    destination_file   = f"{destination_folder}results_{mode}.json"

    # Check the target file itself. glob.glob(destination_folder) only returns the
    # folder path, so it can never contain destination_file.
    if not os.path.exists(destination_file):
        shutil.move(source_file, destination_file)
    else:
        print(f"There's already a 'results_{mode}.json' file in '{destination_folder}'.")

In [9]:
def collect_Q_data (agent_name, model_name, number_of_rounds):
    """Stack the per-round Q files of a training run into a single ``Qtrain.npy``.

    Loads ``Q1.npy`` ... ``Q<number_of_rounds>.npy`` from
    ``../agent_code/agent_<agent_name>/logs/Q_data/`` into one float array of shape
    ``(number_of_rounds, *Q1.shape)``. Index ``r - 1`` holds round ``r``. The array
    is saved as ``./<agent_name>/<model_name>/Qtrain.npy``.

    If that file already exists, the user is asked (y/n) whether to overwrite it.
    Answering "n" returns without saving or deleting anything. After the array is
    saved, every ``Q*.npy`` file in the source folder is deleted.

    Parameters
    ----------
    agent_name : str
        Suffix of the agent directory ``../agent_code/agent_<agent_name>/``.
    model_name : str
        Name of the analysis sub-folder ``./<agent_name>/<model_name>/``.
    number_of_rounds : int
        Number of per-round files to load. ``Q1.npy`` is always read because it
        defines the per-round shape.
    """
    source_folder = f"../agent_code/agent_{agent_name}/logs/Q_data/"

    def load_round(round_number):
        # One file per round, numbered from 1.
        return np.load(f"{source_folder}Q{round_number}.npy")

    first_round = load_round(1)
    allQ = np.zeros((number_of_rounds, *first_round.shape))
    for round_number in range(1, number_of_rounds + 1):
        print(f"Loading 'Q{round_number}'", end="\r")
        # Round 1 was already read to size the array; don't load it twice.
        allQ[round_number - 1] = first_round if round_number == 1 else load_round(round_number)
    print('\n')

    # Same './<agent>/<model>/' folder the other collection helpers write to,
    # independent of the name of the current working directory.
    analysis_folder = f"./{agent_name}/{model_name}/"
    allQ_file_name  = f"{analysis_folder}Qtrain.npy"

    # Check to prevent accidental overwrites.
    if os.path.exists(allQ_file_name):
        print(f"Do you want to overwrite the Qtrain.npy file of agent_{agent_name}_{model_name}? (y/n)")
        answer = ""
        while answer not in ("y", "n"):
            answer = input().strip().lower()
        if answer == "n":
            print("Didn't overwrite.")
            return
        print(f"Overwriting '{allQ_file_name}'.")
    np.save(allQ_file_name, allQ)

    # Only delete the per-round files once the stacked array is on disk.
    if os.path.exists(allQ_file_name):
        print(f"Removing Q-files in '{source_folder}'.")
        for file in glob.glob(f"{source_folder}Q*.npy"):
            os.remove(file)
    print("Done.")

In [10]:
def move_params (agent_name, model_name, mode = "train"):
    """Move the agent's ``logs/params.json`` into the analysis folder as ``params_<mode>.json``.

    The file is moved, so it no longer exists in
    ``../agent_code/agent_<agent_name>/logs/`` afterwards.
    """
    params_path = f"../agent_code/agent_{agent_name}/logs/params.json"
    archived_as = f"./{agent_name}/{model_name}/params_{mode}.json"
    shutil.move(params_path, archived_as)

In [11]:
def collect_analysis_data (agent_name, mode = "train"):
    """Gather the artefacts of an agent's run into ``./<agent_name>/<model_name>/``.

    The model name and round count come from the agent's ``logs/params.json``. The
    steps below then:

    - create the analysis folder;
    - copy the model checkpoint and the analysis notebook template;
    - move the log, the state-action counter and the latest results file;
    - stack the per-round Q files;
    - finally move ``params.json``.

    Steps run strictly in order with no rollback: if one raises, the earlier steps
    have already taken effect.
    """
    model_name, rounds_trained = get_training_info(agent_name)

    # Every later step writes into this folder, so it is created first.
    make_analysis_directory(agent_name, model_name)

    copy_model(agent_name, model_name)
    move_log(agent_name, model_name, mode=mode)
    move_sa_counter(agent_name, model_name)
    copy_analysis_template(agent_name, model_name, mode=mode)
    move_results(agent_name, model_name, mode=mode)
    collect_Q_data(agent_name, model_name, number_of_rounds=rounds_trained)

    # params.json is moved last. If an earlier step fails, it is still in the
    # agent's logs folder for get_training_info on a rerun.
    move_params(agent_name, model_name, mode=mode)
    

In [12]:
# Collect the artefacts of agent "h2b" (mode defaults to "train").
collect_analysis_data(agent_name="h2b")

Loading 'Q1000'

Removing Q-files in '../agent_code/agent_h2b/logs/Q_data/'.
Done.
