In [31]:
import os
import sys

# Repository root; expected to provide the `mcts_agent` and `env_machi_koro_2`
# modules imported below. Set the MACHI_KORO_REPO environment variable to point
# at another checkout instead of editing the notebook; the default is the
# original author's location.
cwd = os.environ.get("MACHI_KORO_REPO", "/Users/jordydikkers/Documents/repos/machi-koro-ai")
if cwd not in sys.path:  # re-running this cell must not pile up duplicate entries
    sys.path.append(cwd)
# Later cells use relative paths ("checkpoints/...", the card-info YAML), so the
# working directory is switched to the repository root as well.
os.chdir(cwd)

from mcts_agent import PVNet, HDF5DataLoader
import h5py
import torch
from env_machi_koro_2 import GymMachiKoro2
import pickle
import numpy as np

In [62]:
class BufferAnalyzer:
    def __init__(self, checkpoint_path):
        self.buffer_path = checkpoint_path+"/buffers.h5"
        with open(checkpoint_path+"/env.pickle", "rb") as f:
            self.env = pickle.load(f)
        with h5py.File(self.buffer_path, "r") as h5f:
            self.columns_indices = {col: i for i, col in enumerate(h5f.attrs["columns"])}

    def nth_action_count(self, nth_action: int):
        actions = {}

        with h5py.File(self.buffer_path, "r") as h5f:
            for split in ["train", "val"]:
                for iteration in h5f[split].keys():
                    if iteration not in actions.keys():
                        actions[iteration] = {action: 0 for action in self.env._action_idx_to_str.values()}
                    for game in h5f[split][iteration].keys():
                        actions[iteration][self.env._action_idx_to_str[h5f[split][iteration][game][nth_action, self.columns_indices["action"]]]] += 1
        return actions

    def find_buffers_with_action_sequence(self, player_idx: int, sequence: list[str | int]):
        if isinstance(sequence[0], str):
            sequence = [self.env._action_str_to_idx[action] for action in sequence]
        buffers = {}
        with h5py.File(self.buffer_path, "r") as h5f:
            for split in ["train", "val"]:
                for iteration in h5f[split].keys():
                    for game in h5f[split][iteration].keys():
                        players = h5f[split][iteration][game][:, self.columns_indices["player_id"]]
                        targeted_player_indices = np.argwhere(players == player_idx)
                        player_actions = h5f[split][iteration][game][targeted_player_indices.flatten(), self.columns_indices["action"]]
                        # if any section in player_actions overlaps with the sequence, save the buffer
                        for i in range(len(player_actions)-len(sequence)+1):
                            if np.array_equal(player_actions[i:i+len(sequence)], sequence):
                                buffer_name = f"{split}/{iteration}/{game}"
                                sequence_range = (i, i+len(sequence))
                                print(f"Found sequence in {buffer_name} at indices {sequence_range}")
                                if buffer_name not in buffers.keys():
                                    buffers[buffer_name] = {
                                        "buffer": h5f[split][iteration][game][:],
                                        "sequence_ranges": [(i, i+len(sequence))]
                                    }
                                else:
                                    buffers[buffer_name]["sequence_ranges"].append((i, i+len(sequence)))
        return buffers



In [63]:
# Analyzer for one checkpoint directory: loads its env.pickle and reads the
# column names from its buffers.h5 (path is relative to the repo root).
ba = BufferAnalyzer("checkpoints/2025-01-21 10:11:40.769705")

In [None]:
# Per iteration, tally which action is stored in row 0 of every game buffer.
action_counts = ba.nth_action_count(0)

In [None]:
# Collect every game in which player 0's own consecutive actions contain this
# three-action sequence (each hit is also printed by the method).
buffers = ba.find_buffers_with_action_sequence(player_idx=0, sequence=["Build nothing", "1 dice", "Launch Pad"])

In [None]:
# Action and value columns of one of the matching games, over all of its rows.
game_buffer = buffers["train/iteration_0/game_204"]["buffer"]
for column_name in ("action", "value"):
    print(game_buffer[:, ba.columns_indices[column_name]])

In [None]:
# Column indices of all observation columns (names starting with "obs").
# Uses the indices stored in the mapping rather than the position of each key
# in the dict; the two only coincide when every column name is unique.
obs_indices = [idx for col, idx in ba.columns_indices.items() if col.startswith("obs")]

In [None]:
pvnet = PVNet(env=ba.env)

# Compare what each saved model predicts for the last row of the inspected game.
game_buffer = buffers["train/iteration_0/game_204"]["buffer"]
# Relies on the key order of _action_str_to_idx matching the order of the
# predicted values — TODO confirm against the env.
action_names = list(ba.env._action_str_to_idx.keys())

for model in ["model_0", "model_1", "model_2"]:
    pvnet.load(f"checkpoints/2025-01-21 10:11:40.769705/{model}.pt")
    print(model)
    action_probs = pvnet.predict(game_buffer[-1, obs_indices])[0]
    # Plain loop instead of a list comprehension that was only used for printing.
    for action, prob in zip(action_names, action_probs):
        print(action, prob)


In [None]:
# Per iteration: the row-0 action tallies, most frequent action first.
for iteration, counts in action_counts.items():
    ranked = sorted(counts.items(), key=lambda pair: pair[1], reverse=True)
    print(dict(ranked))

In [112]:
# Buffer file of the same checkpoint directory that `ba` was created from.
h5f_path = "checkpoints/2025-01-21 10:11:40.769705/buffers.h5"

In [113]:
# Read-only handle that later cells use for interactive inspection; it stays
# open until h5f.close() is called explicitly.
h5f = h5py.File(h5f_path, "r")

In [120]:
# NOTE(review): the cells after this one still read from `h5f`; when the notebook
# is run top to bottom this close makes them fail. Presumably it was executed out
# of order — consider moving it to the very end of the notebook.
h5f.close()

In [None]:
# Column names of the buffer datasets (the same attribute BufferAnalyzer reads).
h5f.attrs["columns"]

In [None]:
# Reads the whole game_0 dataset of train/iteration_0 into memory and displays it.
h5f["train"]["iteration_0"]["game_0"][:]

In [101]:
# NOTE(review): cells further down (structure exploration, column index lookups,
# init_obs / init_prob) still read from `h5f`; closing it here breaks them when
# the notebook is run top to bottom. Consider moving this to the end.
h5f.close()

In [3]:
def explore_hdf5_structure(h5_file, group=None, indent=0):
    """
    Print the tree of groups and datasets below a node of an HDF5 file.

    Groups are printed by name and descended into with four extra spaces of
    indentation; datasets are printed with their shape and the length of their
    first axis (0 when the shape is empty). Any other kind of member is ignored.

    Parameters:
    - h5_file: h5py File object; used as the starting node when `group` is None
    - group: node to list (None means start at the root of `h5_file`)
    - indent: number of spaces to prefix each printed line with
    """
    node = h5_file if group is None else group
    padding = " " * indent

    for name in node:
        child = node[name]
        if isinstance(child, h5py.Group):
            print(f"{padding}Group: {name}")
            # Recurse one level deeper, passing the file along unchanged.
            explore_hdf5_structure(h5_file, child, indent + 4)
            continue
        if isinstance(child, h5py.Dataset):
            dims = child.shape
            row_count = dims[0] if len(dims) > 0 else 0
            print(f"{padding}Dataset: {name}, Rows: {row_count}, Shape: {dims}")

In [None]:
# The train/iteration_0 node, whose members are the individual game datasets.
h5f["train"]["iteration_0"]

In [None]:
# Print the full group/dataset tree of the buffer file, starting at its root.
explore_hdf5_structure(h5f)

In [13]:
# Build train/val loaders over the buffer file, restricted by subset_rules to
# "iteration_0" with weight 1.0 (exact meaning of the rule is defined by
# HDF5DataLoader — TODO confirm).
# NOTE(review): 64e5 is a float (6400000.0); confirm that HDF5DataLoader accepts
# a non-integer chunk_size.
data_manager = HDF5DataLoader(h5f_path, subset_rules={"iteration_0": 1.0}, chunk_size=64e5)
train_loader, val_loader = data_manager.get_dataloaders()

In [97]:
# Split the buffer columns by name prefix: observation inputs vs. stored probabilities.
buffer_columns = h5f.attrs["columns"]
obs_col_indices = [pos for pos, name in enumerate(buffer_columns) if name.startswith("obs")]
probs_col_indices = [pos for pos, name in enumerate(buffer_columns) if name.startswith("prob")]

In [None]:
# Observation and stored probabilities of row 0 of train/iteration_0/game_0.
first_game = h5f["train"]["iteration_0"]["game_0"]
init_obs = first_game[0, obs_col_indices]
init_prob = first_game[0, probs_col_indices]

In [None]:
# Display the probability columns stored for that first row.
init_prob

In [109]:
# Network output for the first observation. `pvnet` comes from the model
# comparison cell above, so it holds whichever weights were loaded last there.
pred = pvnet.predict(init_obs)

In [None]:
# First element of the prediction; the bar chart cell plots it with one bar per action.
pred[0]

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Bar chart of the prediction for the first observation, one bar per action.
# The labels come from ba.env (the env `pvnet` was constructed with); a bare
# `env` is only defined in a later cell and is undefined here on a fresh run.
action_labels = [ba.env._action_idx_to_str[idx] for idx in range(len(pred[0]))]

fig, ax = plt.subplots()
ax.bar(action_labels, pred[0])
ax.tick_params(axis="x", labelrotation=90)  # action names are long; rotate to keep them readable
ax.set(xlabel="action", ylabel="predicted probability", title="PVNet prediction for the first observation")
plt.show()

In [None]:
# Name of action index 0. Uses ba.env because a bare `env` is only created in a
# later cell and does not exist yet when the notebook is run top to bottom.
ba.env._action_idx_to_str[0]

In [None]:
CARD_INFO_PATH = "card_info_machi_koro_2_quick_game.yaml"

# Fresh two-player game whose state is edited by hand before taking a step.
env = GymMachiKoro2(n_players=2, card_info_path=CARD_INFO_PATH)
env.reset()
state = env.state_dict()

# Give both players the same coin amount.
for player_name in ("player 0", "player 1"):
    state["player_info"][player_name]["coins"] = 30

# Optional starting cards, currently disabled:
# state["player_info"]["player 0"]["cards"]["Forge"] = 1
# state["player_info"]["player 0"]["cards"]["Park"] = 1
# state["player_info"]["player 1"]["cards"]["Forge"] = 1
# state["player_info"]["player 1"]["cards"]["Park"] = 1

# Fix which landmark card sits in each marketplace position (pos_0 .. pos_4).
landmark_cards = ("Launch Pad", "Loan Office", "Soda Bottling Plant", "Charterhouse", "Temple")
for position, card_name in enumerate(landmark_cards):
    state["marketplace"]["landmark"][f"pos_{position}"]["card"] = card_name

# Push the edited state back into the environment.
env.set_state(env.state_dict_to_array(state))

# GAME_START_STATE = env.state_dict_to_array(state)
# next_obs, reward, done, truncated, info = env.step(33)
# NOTE(review): 39 is a raw action index; which action it stands for is not
# visible in this notebook (see env._action_idx_to_str).
env.step(39)

In [None]:
# Ask the wrapped environment (env._env) for player 1's "Landmark" icon count
# after the step taken above.
env._env.player_icon_count("player 1", "Landmark")

In [None]:
# Action index that the env assigns to the "Observatory" action name.
env._action_str_to_idx["Observatory"]

In [None]:
# Landmark marketplace section of the hand-edited `state` dict.
# NOTE(review): `state` was captured before env.step(...) and is not refreshed
# afterwards, so this presumably shows the pre-step marketplace — confirm.
state["marketplace"]["landmark"]