**方針**
- マッチに勝利した方を抽出する
- stepごとにaction -> observation -> action -> observation -> ... となる
- 0step目: actionは1マッチ目は空で、それ以降は最後のaction。observationは1step目は無しで、それ以降は最後のobservation
- 1step目のactionは全て-1, observationは初期状態
- 最初の情報は使えるのか?

In [245]:
import json
import os
import numpy as np
import polars as pl

In [246]:
# Directory that holds the match replay JSON files.
DATA_REPLAY_DIR = os.path.join("..", "data", "match", "replay")

# os.listdir returns entries in arbitrary order; sorting keeps the row order of
# the generated dataset reproducible across runs and machines.
# NOTE(review): every directory entry is treated as a replay file — confirm the
# folder never contains anything else (hidden files, checkpoints, sub-folders).
replay_files = sorted(os.listdir(DATA_REPLAY_DIR))


def load_json(filename):
    """Load one file from ``DATA_REPLAY_DIR`` and return the parsed JSON.

    Args:
        filename: Name of a file inside ``DATA_REPLAY_DIR`` (not a full path).

    Returns:
        The decoded JSON document.
    """
    file_path = os.path.join(DATA_REPLAY_DIR, filename)
    # JSON is UTF-8 by specification; pin the encoding so decoding does not
    # depend on the platform's locale default.
    with open(file_path, encoding="utf-8") as f:
        return json.load(f)

In [247]:
def get_winners(steps: list) -> list:
    """Return the winning player index (0 or 1) for each of the five matches.

    A match boundary is expected every 101 steps (indices 101, 202, ..., 505).
    Player 0 is reported as the winner when its ``reward`` at the boundary step
    differs from its ``reward`` one step earlier; otherwise player 1 is
    reported.

    NOTE(review): presumably ``reward`` counts matches won so far, so a change
    means player 0 just won — confirm against the replay format.

    Args:
        steps: Replay steps, indexable by integer; ``steps[k][0]["reward"]``
            must exist for the boundary indices and the index before each.

    Returns:
        A list of five ints, one winner index per match.
    """
    STEPS_BETWEEN_BOUNDARIES = 101
    NUM_MATCHES = 5

    winners = []
    for match_number in range(1, NUM_MATCHES + 1):
        boundary = match_number * STEPS_BETWEEN_BOUNDARIES
        reward_0 = steps[boundary][0]["reward"]
        prev_reward_0 = steps[boundary - 1][0]["reward"]
        winners.append(0 if reward_0 != prev_reward_0 else 1)

    return winners


def flip_coords(coords: list, map_size: int = 24) -> list:
    """Mirror coordinates through the centre of a square map.

    Every non-negative component ``pos`` becomes ``map_size - 1 - pos``.
    Negative components are returned unchanged.

    NOTE(review): negative values are presumably "not visible" sentinels —
    confirm against the observation format.

    Args:
        coords: Iterable of coordinate lists, e.g. ``[[x, y], ...]``.
        map_size: Side length of the map; defaults to 24.

    Returns:
        A new list of flipped coordinate lists; the input is not modified.
    """
    last_index = map_size - 1
    return [
        [last_index - pos if pos >= 0 else pos for pos in coord]
        for coord in coords
    ]


def flip_map(map_data: list) -> list:
    """Return a copy of a 2-D grid rotated by 180 degrees.

    The row order is reversed and each row is reversed as well; the input
    grid is left untouched.
    """
    rotated = []
    for row in reversed(map_data):
        rotated.append(row[::-1])
    return rotated

In [248]:
# Move codes paired with the move that points the opposite way.
# NOTE(review): assumes the Lux AI Season 3 action encoding
# (0=stay, 1=up, 2=right, 3=down, 4=left, 5=sap with a dx/dy target offset)
# — confirm against the environment specification.
_OPPOSITE_MOVE = {1: 3, 2: 4, 3: 1, 4: 2}


def _flip_actions(actions: list) -> list:
    """Express per-unit actions in the 180-degree rotated frame.

    The first element of each unit action is remapped to the opposite move
    (unknown codes pass through unchanged) and every remaining element, taken
    to be a relative offset, is negated. Empty inputs are returned as-is.
    """
    if not actions:
        return actions

    flipped = []
    for unit_action in actions:
        if not unit_action:
            flipped.append(unit_action)
            continue
        kind, *offsets = unit_action
        flipped.append([_OPPOSITE_MOVE.get(kind, kind), *(-offset for offset in offsets)])
    return flipped


def _parse_obs(step: list, player: int) -> dict:
    """Decode the JSON-encoded observation of ``player`` at one replay step."""
    return json.loads(step[player]["observation"]["obs"])


def create_dataframe_from_steps(steps: list) -> pl.DataFrame:
    """Build a training table from the winning player's view of each match.

    For every match the winner reported by ``get_winners`` is selected and one
    row is emitted per step: the winner's observation at step ``ind`` paired
    with the action recorded at step ``ind + 1``. When the winner is player 1,
    positions, maps, team-ordered lists and actions are rotated/reordered so
    that every row is expressed from player 0's point of view.

    Args:
        steps: Replay steps, indexable by integer. Indices 0..505 are read.

    Returns:
        A DataFrame with the columns ``units_positions``,
        ``enemy_units_positions``, ``units_energy``, ``enemy_units_energy``,
        ``sensor_mask``, ``tile_energy``, ``tile_type``, ``relic_nodes``,
        ``team_points``, ``team_rewards`` and ``action``.
    """
    NUM_STEPS_PER_MATCH = 100
    winners = get_winners(steps)

    step_data_list = []
    # Relic nodes seen so far. The memory is seeded from player 0's first
    # observation and is carried across matches.
    # NOTE(review): it is then updated from whichever player won each match,
    # so it can mix both players' discoveries — confirm this is intended.
    relic_nodes_memory = _parse_obs(steps[0], 0)["relic_nodes"]

    for match_index, winner in enumerate(winners):
        enemy = 1 - winner
        # Matches are 101 steps apart; the first usable step of match k is k * 101 + 1.
        ind_start = match_index * NUM_STEPS_PER_MATCH + match_index + 1

        # Each observation is needed twice (as "current" and as "next"), so
        # the decoded next observation is reused instead of parsed again.
        obs = _parse_obs(steps[ind_start], winner)

        for ind in range(ind_start, ind_start + NUM_STEPS_PER_MATCH):
            next_obs = _parse_obs(steps[ind + 1], winner)

            units_positions = obs["units"]["position"][winner]
            enemy_units_positions = obs["units"]["position"][enemy]
            units_energy = obs["units"]["energy"][winner]
            enemy_units_energy = obs["units"]["energy"][enemy]
            sensor_mask = obs["sensor_mask"]
            tile_energy = obs["map_features"]["energy"]
            tile_type = obs["map_features"]["tile_type"]

            # Keep the last known location of a relic node when the current
            # observation reports a negative first coordinate for it.
            relic_nodes = obs["relic_nodes"]
            relic_nodes_memory = [
                obs_node if obs_node[0] >= 0 else memory_node
                for obs_node, memory_node in zip(relic_nodes, relic_nodes_memory)
            ]

            # Points gained by each team between this step and the next one.
            # NOTE(review): for the last step of a match the "next" step is the
            # match boundary; confirm team_points is not reset there.
            team_points = obs["team_points"]
            next_team_points = next_obs["team_points"]
            team_rewards = [next_team_points[team] - team_points[team] for team in range(2)]

            action = steps[ind + 1][winner]["action"]

            if winner == 1:
                units_positions = flip_coords(units_positions)
                enemy_units_positions = flip_coords(enemy_units_positions)
                sensor_mask = flip_map(sensor_mask)
                tile_energy = flip_map(tile_energy)
                tile_type = flip_map(tile_type)
                team_points = team_points[::-1]
                team_rewards = team_rewards[::-1]
                # The board is rotated for player 1, so the action labels must
                # be rotated too or they would contradict the observations.
                action = _flip_actions(action)

            step_data = {
                "units_positions": units_positions,
                "enemy_units_positions": enemy_units_positions,
                "units_energy": units_energy,
                "enemy_units_energy": enemy_units_energy,
                "sensor_mask": sensor_mask,
                "tile_energy": tile_energy,
                "tile_type": tile_type,
                "relic_nodes": relic_nodes_memory if winner == 0 else flip_coords(relic_nodes_memory),
                "team_points": team_points,
                "team_rewards": team_rewards,
                "action": action,
            }

            step_data_list.append(step_data)
            obs = next_obs

    df = pl.DataFrame(step_data_list)

    # Narrow the integer columns to compact dtypes.
    df = df.with_columns(
        pl.col("units_positions").cast(pl.List(pl.List(pl.Int8))),
        pl.col("enemy_units_positions").cast(pl.List(pl.List(pl.Int8))),
        pl.col("units_energy").cast(pl.List(pl.Int16)),
        pl.col("enemy_units_energy").cast(pl.List(pl.Int16)),
        pl.col("tile_energy").cast(pl.List(pl.List(pl.Int16))),
        pl.col("tile_type").cast(pl.List(pl.List(pl.Int8))),
        pl.col("relic_nodes").cast(pl.List(pl.List(pl.Int8))),
        pl.col("team_points").cast(pl.List(pl.Int16)),
        pl.col("team_rewards").cast(pl.List(pl.Int16)),
        pl.col("action").cast(pl.List(pl.List(pl.Int8))),
    )

    return df

In [249]:
# Stack the per-replay tables into a single DataFrame. Replays with fewer
# than 506 steps are reported and left out.
df = None
for replay_file in replay_files:
    replay = load_json(replay_file)
    steps = replay["steps"]

    if len(steps) >= 506:
        tmp_df = create_dataframe_from_steps(steps)
        df = tmp_df if df is None else df.vstack(tmp_df)
    else:
        print(f"Skipping {replay_file} because it has only {len(steps)} steps")

In [250]:
# Persist the combined dataset as a Parquet file.
# NOTE(review): `df` is still None when every replay was skipped, and the
# target directory presumably has to exist already — confirm before running.
df.write_parquet("../data/preprocessed/train.parquet")

In [251]:
# Reload the dataset from disk, rebinding `df` to the stored copy.
df = pl.read_parquet("../data/preprocessed/train.parquet")