In [4]:
# -*- coding: utf-8 -*-
"""
train_card_policy_bc.py

Trains a card prediction policy with Behavioral Cloning (BC), using the
`imitation` library's BC trainer on data parsed into an HDF5 file.

Assumes each game group in the HDF5 file contains 'state_bits' (bool, N x 929)
and 'action_card_bits' (bool, N x 13) datasets, as generated by the parser.
Each played card is converted to an action index (0-8, the hand slot it was
played from) during an initial scan of the file.

Implements a custom PyTorch Dataset and DataLoader to avoid loading the
entire dataset into memory at once, addressing potential Out-of-Memory errors.
The DataLoader reads data samples on-demand from the HDF5 file during training.
"""

import h5py
import numpy as np
import gymnasium as gym
# stable_baselines3 is implicitly used by imitation for policy structure
import stable_baselines3 as sb3
# Import core BC algorithm and data types
from imitation.algorithms import bc
from imitation.data import types
# Need PyTorch Dataset and DataLoader for memory-efficient loading
import torch as th
from torch.utils.data import Dataset, DataLoader

import logging
import os
import traceback
from typing import Tuple, Optional, List, Any
from sklearn.model_selection import train_test_split

# --- Logging Configuration ---
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S')

# --- Configuration ---
HDF5_FILE_PATH = r'../../../Training_Data/jass.hdf5'  # Path to the HDF5 file written by the parser
POLICY_SAVE_PATH = 'jass_bc_card_policy_591'  # Model save path (without file extension)

# --- Data Keys (must match the parser's HDF5 output structure) ---
OBS_KEY = 'state_bits'                # N x 929 boolean array per game group
ACTION_BITS_KEY = 'action_card_bits'  # N x 13 boolean array: bits of the card that was played

# --- Model Dimensions (based on the parser) ---
INPUT_DIM = 929      # Length of the parser's state bit vector (STATE_BITS)
CARD_ACTION_DIM = 9  # Action = hand slot (index 0-8) the card was played from

# --- Constants from the Parser (needed to interpret state_bits) ---
CARD_BITS = 13          # Bits used to encode a single card
NUM_CARDS_HISTORY = 32  # History card slots in the state (32 * 13 bits)
NUM_CARDS_TABLE = 3     # Table card slots in the state (3 * 13 bits); follows the parser's code, not its comment
NUM_CARDS_HAND = 9      # Hand card slots in the state (9 * 13 bits); one per action in CARD_ACTION_DIM

# Start of the player's hand bits within state_bits.
# Parser layout order: History, Table, Hand, Shown, Trump.
#   32 * 13 = 416 history bits + 3 * 13 = 39 table bits -> hand starts at 455
#   and spans 9 * 13 = 117 bits, i.e. state_bits[455:572].
HAND_START_BIT_INDEX = (NUM_CARDS_HISTORY * CARD_BITS) + (NUM_CARDS_TABLE * CARD_BITS)


# --- Training Hyperparameters ---
VALIDATION_SPLIT_SIZE = 0.15  # Fraction of valid samples held out for validation
RANDOM_SEED = 42              # Seed for the train/validation split and the BC trainer's RNG
BC_BATCH_SIZE = 32
BC_LEARNING_RATE = 3e-4
BC_N_EPOCHS = 10  # Number of passes over the training data
# DataLoader worker processes (0 = load in the main process). Kept at 0 while
# debugging HDF5 access; each worker needs its own HDF5 file handle if raised.
BC_NUM_WORKERS = 0

# --- Device: use CUDA when available, otherwise fall back to CPU ---
device = th.device("cuda" if th.cuda.is_available() else "cpu")
logging.info(f"Using device: {device}")


# --- 1. Custom Dataset for Memory-Efficient Loading ---

class JassCardDataset(Dataset):
    """
    Custom PyTorch Dataset that loads observations on demand from an HDF5 file.

    Only the (group_name, index_in_group, derived_action_index) bookkeeping is
    kept in memory; each observation row is read from disk when requested.

    The HDF5 file is opened lazily, once per process, and the handle is reused
    for later reads instead of re-opening the file for every sample. The id of
    the process that opened the handle is recorded, so a DataLoader worker never
    reuses a handle inherited from its parent. The handle is also dropped when
    the dataset is pickled (e.g. for worker start-up).
    """

    def __init__(self, filepath: str, sample_list: List[Tuple[str, int, int]]):
        """
        Args:
            filepath (str): Path to the HDF5 file.
            sample_list (List[Tuple[str, int, int]]): One tuple per sample,
                (group_name, index_in_group, derived_action_index), typically
                produced by an initial scan of the HDF5 file.
        """
        self.filepath = filepath
        self.sample_list = sample_list
        # Lazily opened h5py.File handle and the PID of the process that opened it.
        self._h5_file: Optional[Any] = None
        self._h5_pid: Optional[int] = None
        logging.info(f"Initialized JassCardDataset with {len(self.sample_list)} samples.")

    def __len__(self) -> int:
        """Returns the total number of samples in the dataset."""
        return len(self.sample_list)

    def __getstate__(self) -> dict:
        """Drop the open HDF5 handle when pickling; it is re-opened lazily."""
        state = self.__dict__.copy()
        state["_h5_file"] = None
        state["_h5_pid"] = None
        return state

    def _get_h5_file(self) -> Any:
        """Return this process's HDF5 handle, opening it on first use."""
        pid = os.getpid()
        if self._h5_file is None or self._h5_pid != pid:
            # A handle inherited from another process (e.g. via fork) is not
            # safe to use here; replace it rather than reading through it.
            self._h5_file = h5py.File(self.filepath, 'r')
            self._h5_pid = pid
        return self._h5_file

    def close(self) -> None:
        """Close the HDF5 handle if this process opened one (no-op otherwise)."""
        if self._h5_file is not None and self._h5_pid == os.getpid():
            self._h5_file.close()
        self._h5_file = None
        self._h5_pid = None

    def __getitem__(self, idx: int) -> Tuple[th.Tensor, th.Tensor]:
        """
        Loads and returns the observation and action index for a given index `idx`.

        Args:
            idx (int): The index in `self.sample_list`.

        Returns:
            Tuple[th.Tensor, th.Tensor]: The observation tensor (float32) and
                the action index tensor (int64, scalar).

        Raises:
            Exception: Any error raised while reading from the HDF5 file is
                logged and re-raised. Invalid samples are expected to have been
                filtered out before the sample list was built, so an error here
                indicates a read problem.
        """
        group_name, sample_idx_in_group, derived_action_index = self.sample_list[idx]

        try:
            h5_file = self._get_h5_file()
            # Read only this one row of the group's observation dataset.
            obs_data = h5_file[group_name][OBS_KEY][sample_idx_in_group, :]
        except Exception as e:
            logging.error(f"Error loading data for index {idx} ({group_name}/{sample_idx_in_group}) from HDF5: {e}")
            raise

        obs_tensor = th.tensor(obs_data, dtype=th.float32)
        action_tensor = th.tensor(derived_action_index, dtype=th.int64)
        return obs_tensor, action_tensor


# --- 2. Function to Scan HDF5 and Identify Valid Samples ---

def _derive_action_indices(obs_group_data: np.ndarray,
                           actions_bits_group_data: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Locate each played card among the hand slots of its own state vector.

    Args:
        obs_group_data: (N, INPUT_DIM) state bits for one game group.
        actions_bits_group_data: (N, CARD_BITS) bits of the card played per sample.

    Returns:
        (match_counts, first_match): match_counts[j] is how many of the
        NUM_CARDS_HAND hand slots equal sample j's played card; first_match[j]
        is the lowest matching slot index (0 when nothing matches, so it is only
        meaningful where match_counts[j] >= 1).
    """
    hand_end = HAND_START_BIT_INDEX + NUM_CARDS_HAND * CARD_BITS
    # Cast to int8 so bool and 0/1 integer datasets compare identically.
    hands = (obs_group_data[:, HAND_START_BIT_INDEX:hand_end]
             .astype(np.int8)
             .reshape(-1, NUM_CARDS_HAND, CARD_BITS))                   # (N, 9, 13)
    played = actions_bits_group_data.astype(np.int8)[:, np.newaxis, :]  # (N, 1, 13)
    slot_matches = np.all(hands == played, axis=2)                      # (N, 9)
    return slot_matches.sum(axis=1), slot_matches.argmax(axis=1)


def scan_hdf5_for_valid_samples(filepath: str,
                                max_groups: Optional[int] = 2000) -> List[Tuple[str, int, int]]:
    """
    Scan the HDF5 file group by group and collect valid state/action samples.

    A sample is valid when its played card matches exactly one of the hand
    slots encoded in its state vector; that slot (0-8) becomes the action index.
    Only one group's data is held in memory at a time.

    Args:
        filepath: Path to the HDF5 file.
        max_groups: Scan at most this many groups (in file key order). Pass
            None to scan every group. The default of 2000 keeps the previous
            behaviour of this function.

    Returns:
        A list of (group_name, index_in_group, derived_action_index) tuples.
        Returns an empty list if the file is missing, has no groups, or cannot
        be opened/processed.
    """
    valid_samples: List[Tuple[str, int, int]] = []
    n_unmatched = 0  # Samples whose played card does not appear in the hand bits

    if not os.path.exists(filepath):
        logging.error(f"HDF5 file not found at: {filepath}")
        return []

    try:
        with h5py.File(filepath, 'r') as f:
            game_groups = list(f.keys())
            if max_groups is not None:
                game_groups = game_groups[:max_groups]
            logging.info(f"Scanning {len(game_groups)} game groups in HDF5 file for valid samples...")

            if not game_groups:
                logging.error("No game groups found in the HDF5 file.")
                return []

            for i, group_name in enumerate(game_groups):
                if (i + 1) % 500 == 0:  # Log progress every 500 groups
                    logging.info(f"Scanning group {i+1}/{len(game_groups)}: {group_name}")
                try:
                    group = f[group_name]
                    # --- Check required keys ---
                    if OBS_KEY not in group:
                        logging.warning(f"Scan Skipping group '{group_name}': Missing dataset '{OBS_KEY}'")
                        continue
                    if ACTION_BITS_KEY not in group:
                        logging.error(f"Scan CRITICAL: Skipping group '{group_name}': Missing dataset '{ACTION_BITS_KEY}'.")
                        logging.error(f"Ensure your parser saves the 13-bit card representation as '{ACTION_BITS_KEY}'.")
                        continue

                    # Load this group only; the arrays are rebound on the next iteration.
                    obs_group_data = group[OBS_KEY][:]                    # (N, INPUT_DIM)
                    actions_bits_group_data = group[ACTION_BITS_KEY][:]  # (N, CARD_BITS)

                    # --- Basic Validation ---
                    if obs_group_data.ndim != 2 or actions_bits_group_data.ndim != 2:
                        logging.warning(f"Scan Skipping group '{group_name}': Expected 2-D datasets (got {obs_group_data.ndim}-D and {actions_bits_group_data.ndim}-D).")
                        continue
                    if obs_group_data.shape[0] != actions_bits_group_data.shape[0]:
                        logging.warning(f"Scan Skipping group '{group_name}': Data length mismatch ({obs_group_data.shape[0]} vs {actions_bits_group_data.shape[0]}).")
                        continue
                    if obs_group_data.shape[0] == 0:
                        logging.info(f"Scan Skipping empty group '{group_name}'.")
                        continue
                    if obs_group_data.shape[1] != INPUT_DIM:
                        logging.warning(f"Scan Skipping group '{group_name}': Observation dimension mismatch (Expected {INPUT_DIM}, Got {obs_group_data.shape[1]}).")
                        continue
                    if actions_bits_group_data.shape[1] != CARD_BITS:
                        logging.warning(f"Scan Skipping group '{group_name}': Action bits dimension mismatch (Expected {CARD_BITS}, Got {actions_bits_group_data.shape[1]}).")
                        continue

                    # --- Convert played-card bits to a hand-slot action index (0-8) ---
                    match_counts, first_match = _derive_action_indices(obs_group_data, actions_bits_group_data)

                    # Ambiguous samples (card appears in several slots) are skipped.
                    for j in np.flatnonzero(match_counts > 1):
                        played_bits = actions_bits_group_data[j].astype(np.int8).tolist()
                        logging.warning(f"Scan WARNING: Found multiple matches ({match_counts[j]}) for played card (bits: {played_bits}) in hand for sample {j} in group '{group_name}'. Skipping this sample.")

                    # Samples whose card is not in the hand are skipped as well.
                    n_unmatched += int(np.count_nonzero(match_counts == 0))

                    valid_samples.extend(
                        (group_name, int(j), int(first_match[j]))
                        for j in np.flatnonzero(match_counts == 1)
                    )

                except Exception as e:
                    logging.error(f"Scan Error processing group '{group_name}': {e}")
                    traceback.print_exc()
                    continue  # Skip this group on error

        if n_unmatched:
            logging.warning(f"Scan: skipped {n_unmatched} samples whose played card was not found in the hand bits.")
        logging.info(f"Scan complete. Identified {len(valid_samples)} valid samples across all groups.")
        return valid_samples

    except Exception as e:
        logging.critical(f"Scan: Failed to open or process HDF5 file '{filepath}': {e}")
        traceback.print_exc()
        return []


# --- Main Execution ---
if __name__ == "__main__":
    logging.info("--- Starting Card Policy Training (BC) ---")

    # --- 1. Scan HDF5 and get list of valid samples ---
    logging.info(f"Scanning HDF5 file '{HDF5_FILE_PATH}' to identify valid samples...")
    all_valid_samples = scan_hdf5_for_valid_samples(HDF5_FILE_PATH)

    if not all_valid_samples:
        logging.critical("No valid samples found in the HDF5 file after scanning. Exiting.")
        exit()

    logging.info(f"Total valid data samples identified: {len(all_valid_samples)}.")

    # --- 2. Split the list of valid samples ---
    # Split the list of (group_name, index_in_group, derived_action_index) tuples
    logging.info(f"Splitting valid samples (keeping {1-VALIDATION_SPLIT_SIZE:.0%} for training)...")
    train_samples_list, val_samples_list = train_test_split(
        all_valid_samples,
        test_size=VALIDATION_SPLIT_SIZE,
        random_state=RANDOM_SEED,
        shuffle=True,
        # Stratify is tricky here because we only have the index (0-8), not the full action.
        # If distribution is heavily skewed, consider a custom stratified split on the index_in_tuple[2].
        # For simplicity, skipping stratification for now.
    )
    logging.info(f"BC Training samples: {len(train_samples_list)}, Validation samples: {len(val_samples_list)}")
    # Free up memory from the full list
    del all_valid_samples, val_samples_list # Keep val_samples_list if needed for evaluation later


    # --- 3. Create Datasets and DataLoaders ---
    train_dataset = JassCardDataset(HDF5_FILE_PATH, train_samples_list)

    # Create DataLoader for batching and shuffling training data
    train_loader = DataLoader(
        train_dataset,
        batch_size=BC_BATCH_SIZE,
        shuffle=True, # Shuffle training data
        num_workers=BC_NUM_WORKERS, # 0 means main process, adjust if needed
        pin_memory=True if device.type == 'cuda' else False # Pin memory for faster GPU transfer
    )
    logging.info(f"Created DataLoader with batch_size={BC_BATCH_SIZE}, num_workers={BC_NUM_WORKERS}.")
    # Free up memory from the list now held by the dataset
    del train_samples_list


    # --- 4. Define Environment Spaces for Imitation Library ---
    observation_space = gym.spaces.Box(low=0, high=1, shape=(INPUT_DIM,), dtype=np.float32)
    action_space = gym.spaces.Discrete(CARD_ACTION_DIM) # Discrete space for action index (0-8)


    # --- 5. Train Behavioral Cloning Model using DataLoader ---
    rng = np.random.default_rng(RANDOM_SEED)

    # The BC trainer can accept a DataLoader directly via the 'data_loader' argument
    # This is the key difference to avoid loading all data into 'demonstrations'.
    bc_trainer = bc.BC(
        observation_space=observation_space,
        action_space=action_space,
        batch_size=BC_BATCH_SIZE, # Note: batch_size here is used by BC internally but the DataLoader controls the actual batching
        optimizer_kwargs=dict(lr=BC_LEARNING_RATE),
        device=device, # Ensure the trainer/policy are created on the correct device
        rng=rng,
        # Note: BC trainer calculates loss batch by batch from the DataLoader
    )

    logging.info(f"Starting BC training for {BC_N_EPOCHS} epochs using DataLoader...")
    if len(train_dataset) > 0:
        # When using data_loader, train method uses n_epochs argument directly.
        # The loader yields batches, and train runs for the specified number of epochs.
        bc_trainer.train(n_epochs=BC_N_EPOCHS, data_loader=train_loader) # <-- ADD data_loader here
        # Optional: Add validation evaluation here after training if val_samples_list is kept

    else:
        logging.warning("No training samples available. Skipping training.")


    # --- 6. Save Policy ---
    # We assume training finished if we got here without critical errors and train_dataset was not empty
    if len(train_dataset) > 0:
        try:
            # The policy object is stored in bc_trainer.policy
            # Save it using the underlying SB3 policy save method
            # This saves an SB3-compatible policy, usually as a .zip file
            save_path_zip = f"{POLICY_SAVE_PATH}.zip"
            bc_trainer.policy.save(save_path_zip)
            logging.info(f"Card prediction policy saved successfully to {save_path_zip}")
        except Exception as e:
            logging.error(f"Error saving BC policy: {e}")
            traceback.print_exc()
    else:
        logging.warning("Skipping policy save because no training was performed.")


    logging.info("--- Card Policy Training Finished ---")

2025-04-20 14:55:21 - INFO - Using device: cpu
2025-04-20 14:55:21 - INFO - --- Starting Card Policy Training (BC) ---
2025-04-20 14:55:21 - INFO - Scanning HDF5 file '../../../Training_Data/jass.hdf5' to identify valid samples...
2025-04-20 14:55:28 - INFO - Scanning 2000 game groups in HDF5 file for valid samples...
2025-04-20 14:55:30 - INFO - Scan Skipping empty group 'S2500_00ac20ed3234029404a30fb0e06d7109'.
2025-04-20 14:55:33 - INFO - Scan Skipping empty group 'S2500_0262972aa114bc043c893e8a1e99aad1'.
2025-04-20 14:55:36 - INFO - Scanning group 500/2000: S2500_03c94bff41ef615ddec3bdac0d88ea9a
2025-04-20 14:55:38 - INFO - Scan Skipping empty group 'S2500_049292ad4c8103be55bcc178aa3a3620'.
2025-04-20 14:55:42 - INFO - Scan Skipping empty group 'S2500_06a7d30b3751fde4fc347acb627b58c8'.
2025-04-20 14:55:44 - INFO - Scanning group 1000/2000: S2500_0730d9d611061cc0a24fb83c62cdca11
2025-04-20 14:55:50 - INFO - Scanning group 1500/2000: S2500_0b05817feed1dd5180a24f86a4df7b1a
2025-04-20 

TypeError: BC.train() got an unexpected keyword argument 'data_loader'