In [1]:
import pathlib
from typing import Any, List, Tuple

import numpy as np
from loguru import logger
from PIL import Image

from src.utils.load import load_game_state

def load_sample_file(
    sample_file_path_pair: Tuple[pathlib.Path, pathlib.Path]
) -> Tuple[np.ndarray, Any]:
    """
    Load one recorded sample: a JPEG frame plus its paired binary game-state file.

    Args:
        sample_file_path_pair: ``(jpeg_file, binary_file)`` paths, e.g. one entry
            of the list returned by ``get_sample_file_paths``.

    Returns:
        A ``(jpeg_image, game_state)`` tuple. ``jpeg_image`` is the decoded frame
        as a NumPy array; ``game_state`` is whatever ``load_game_state`` returns
        for the binary file (this notebook uses it as a dict via ``.items()`` —
        confirm the exact type against ``src.utils.load``).
    """
    jpeg_file, binary_file = sample_file_path_pair
    # Use a context manager so the image's file handle is released as soon as
    # the pixels have been copied into the array; a bare Image.open() keeps the
    # file open until the Image object is garbage-collected.
    with Image.open(jpeg_file) as image:
        jpeg_image = np.array(image)
    game_state = load_game_state(binary_file)
    return jpeg_image, game_state


def _sample_sort_key(path: pathlib.Path) -> Tuple[int, int, str]:
    """Sort key: integer stems first (numerically), then any other stems lexically."""
    try:
        return (0, int(path.stem), path.stem)
    except ValueError:
        # A non-numeric stem must not abort the whole scan; order it after
        # the numbered samples instead.
        return (1, 0, path.stem)


def get_sample_file_paths(directory: str) -> List[Tuple[pathlib.Path, pathlib.Path]]:
    """
    Pair every ``*.jpeg`` file in ``directory`` with its same-stem ``.bin`` file.

    Files whose stem parses as an integer are ordered numerically (so ``2.jpeg``
    comes before ``10.jpeg``); any files with non-integer stems follow them in
    lexical order rather than raising ``ValueError``.

    Args:
        directory: Directory to scan (a ``pathlib.Path`` is accepted as well).

    Returns:
        A list of ``(jpeg_file, binary_file)`` tuples. JPEGs without a matching
        ``.bin`` file are skipped with a warning.
    """
    jpeg_files = sorted(pathlib.Path(directory).glob("*.jpeg"), key=_sample_sort_key)
    if not jpeg_files:
        # Path.glob on a missing/mistyped directory yields nothing silently;
        # surface that instead of quietly returning an empty list.
        logger.warning(f"No .jpeg files found in {directory}.")

    sample_file_paths = []
    for jpeg_file in jpeg_files:
        binary_file = jpeg_file.with_suffix(".bin")
        if binary_file.exists():
            sample_file_paths.append((jpeg_file, binary_file))
        else:
            logger.warning(
                f"Could not find corresponding .bin file for {jpeg_file}. Skipping file."
            )

    return sample_file_paths

In [3]:
if __name__ == "__main__":
    # Recording directory to read, and how many of its samples to load
    # (set max_samples to None to load every sample in the recording).
    recording_path = "../../recordings/monza_audi_r8_lms_1"
    max_samples = 100
    sample_file_paths = get_sample_file_paths(recording_path)

    # Collect the game state for each of the first `max_samples` samples.
    # Only the .bin half of each pair is read: the JPEG frame is not used here,
    # so decoding it (as load_sample_file would) is wasted work.
    datas = [
        load_game_state(binary_file)
        for _, binary_file in sample_file_paths[:max_samples]
    ]

In [19]:
import numpy as np

def cast_numpy_types(obj):
    """
    Convert NumPy values into plain Python objects so they can be JSON-encoded.

    - Any NumPy scalar (``np.int32``, ``np.int64``, ``np.float32``, ``np.bool_``,
      ...) becomes the matching Python scalar via ``.item()``.
    - NumPy arrays become (nested) Python lists via ``.tolist()``.
    - dicts, lists and tuples are converted recursively (keys included).
    - Anything else is returned unchanged.

    Args:
        obj: The value to convert.

    Returns:
        An equivalent value containing no NumPy scalar or array objects.
    """
    # Match every NumPy scalar type, not just int32/float32: the json module
    # cannot encode e.g. np.int64 or np.bool_ and raises TypeError on them.
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        return {cast_numpy_types(key): cast_numpy_types(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [cast_numpy_types(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(cast_numpy_types(item) for item in obj)
    return obj

# Build a JSON-friendly copy of every recorded state: each value of each state
# mapping is passed through cast_numpy_types, keys are kept as they are.
converted_datas = [
    dict(
        (key, cast_numpy_types(value))
        for key, value in state.items()
    )
    for state in datas
]

In [23]:
import json
import io
output_path = "data.json"
# ensure_ascii=True (json's default) escapes every non-ASCII character, so the
# file body is pure ASCII; the explicit UTF-8 encoding keeps the write
# independent of the platform's default encoding. json.dump streams straight
# into the file instead of first building the whole JSON string in memory.
with open(output_path, "w", encoding="utf-8") as f:
    json.dump(converted_datas, f, ensure_ascii=True)