In [1]:
import torch, torchrl

import numpy as np

from tensordict import TensorDict
from torch import Tensor
from torchrl.envs import EnvBase, SerialEnv
from torchrl.data import (
    Binary,
    Bounded,
    Categorical,
    Composite,
    TensorSpec,
    UnboundedContinuous
)

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
board_size = 5 # Size of the Hex board (board_size x board_size)
N_CHANNEL = 4 # Number of channels for the observation (Red, Blue, Current Player, Valid Board)
MAX_BOARD_SIZE = 32 # Maximum board size for the Hex game
SWAP_RULE = False # Whether to use the swap rule in the Hex game

In [3]:
class HexEnv(EnvBase):
    def __init__(self, 
                 board_size: int,
                 max_board_size: int = MAX_BOARD_SIZE,
                #  swap_rule: bool = SWAP_RULE,
                 device: torch.device = 'cpu',
                #  batch_size: torch.Size = torch.Size()
                ):

        # Assertions
        assert board_size >= 1, "Board size must be greater than or equal to 1."
        assert board_size <= max_board_size, "Board size must be less than or equal to max Board size."

        super().__init__(device=device, spec_locked=False)

        # Parameters
        self.board_size: int = board_size
        self.max_board_size: int = max_board_size
        self.n_channel: int = N_CHANNEL
        # self.swap_rule: bool = swap_rule # Not implemented yet
        # self.device: torch.device = device
        # self.batch_size: torch.Size = batch_size # No batching at all

        # Create shape variables
        self.board_shape: torch.Size = torch.Size(
            (self.max_board_size, self.max_board_size)
        ) # (max_board_size, max_board_size)

        # Valid board mask
        self.valid_board: Tensor = torch.zeros(
            self.board_shape, 
            dtype=torch.bool, 
            device=self.device
        ) # (max_board_size, max_board_size)
        self.valid_board[:self.board_size, :self.board_size] = 1

        # Create private spec variables
        self.observation_spec = Composite({
            "observation": Binary(
                shape=self.board_shape + (self.n_channel,),
                # (max_board_size, max_board_size, n_channel)
                device=self.device,
                dtype=torch.float32
            ),
            "mask": Binary(
                shape=self.board_shape,
                # (max_board_size, max_board_size)
                device=self.device,
                dtype=torch.bool
            )
        })
        self.action_spec = Categorical(
            n=self.max_board_size ** 2,
            # Number of discrete actions for each side of the board
            shape=(1,),
            # (1,) because action is a scalar representing the flat index of the board
            device=self.device,
            dtype=torch.long,
            mask=(self.valid_board.flatten() == 1) # (max_board_size ** 2)
        )
        self.reward_spec = UnboundedContinuous(
            shape=(1,),
            device=device,
            dtype=torch.float32
        ) # Reward for both players

    def _reset(self, tensordict: TensorDict | None = None, **kwargs) -> TensorDict:
        # Initialize a fresh board
        board: Tensor = torch.full((self.max_board_size, self.max_board_size), -1, dtype=torch.long) # -1: empty, 0: player 0 (red), 1: player 1 (blue)
        # current_player: int = 0 # 0: player 0 (red), 1: player 1 (blue)
        # valid_move: Tensor = self.valid_board.float() # All valid moves at the start
        done: Tensor = torch.tensor(False, dtype=torch.bool) # Game not done
        reward: Tensor = torch.tensor([0.0], dtype=torch.float32) # No reward at the start

        # Create fresh observation, mask, done, reward
        fresh_action: Tensor = torch.tensor([0], dtype=torch.long) # Placeholder action
        fresh_observation: Tensor = torch.zeros((self.max_board_size, self.max_board_size, self.n_channel), dtype=torch.float32) # (max_board_size, max_board_size, n_channel)
        fresh_observation[..., 0] = (board == 0).float() # Red pieces channel
        fresh_observation[..., 1] = (board == 1).float() # Blue pieces channel
        fresh_observation[..., 2] = 0 # 0: player 0 (red), 1: player 1 (blue)
        fresh_observation[..., -1] = self.valid_board.clone().float() # (max_board_size, max_board_size) Playable board mask
        fresh_mask: Tensor = self.valid_board.clone().bool() # (max_board_size ** 2) Valid move mask
        fresh_done: Tensor = done # Not done
        fresh_reward: Tensor = reward # No reward at the start

        # Update action spec for the environment
        self.action_spec.update_mask(fresh_mask.flatten())

        # Update tensordict
        fresh_tensordict = TensorDict({
            "action": fresh_action,
            "observation": fresh_observation,
            "mask": fresh_mask,
            "done": fresh_done,
            "reward": fresh_reward
        })

        return fresh_tensordict

    def _step(self, tensordict: TensorDict, **kwargs) -> TensorDict:
        # Extract action
        action: Tensor = tensordict.get("action").clone() # Scalar tensor representing the action
        observation: Tensor = tensordict.get("observation").clone() # (max_board_size, max_board_size, n_channel)
        mask: Tensor = tensordict.get("mask").clone() # (max_board_size, max_board_size)
        done: Tensor = tensordict.get("done").clone() # Scalar tensor representing if the game is done
        reward: Tensor = tensordict.get("reward").clone() # (2,)

        # Extract indexes of action from observation
        index: int = int(action.item())
        row, col = divmod(index, self.max_board_size) # Convert flat index to 2D coordinates

        # Extract current state from observation
        current_player: int = int(observation[..., 0, 0, 2].item()) # 0: player 0 (red), 1: player 1 (blue)

        # Validate action
        is_valid = (
            0 <= row < self.max_board_size and # Must be within board's max bounds
            0 <= col < self.max_board_size and # Must be within board's max bounds
            self.valid_board[row, col] and # Must be in valid board area
            mask[row, col] == 1  # Must be empty to place a piece
        )

        # If action is not valid (only when action_spec mask is not working properly)
        if not is_valid:
            raise ValueError(f"Invalid action {action.item()} at row={row}; col={col}; valid={self.valid_board[row, col]}, mask={mask[row, col]}.")
            # reward[self.current_player - 1] = -1.0 # Penalty for invalid move
            # self.done = False # Continue the game even if the move is invalid
            # new_observation, new_mask = tensordict.get("observation"), tensordict.get("mask") # Keep previous observation and mask
        else:
            # Place the piece
            observation[row, col, current_player] = 1.0 # Update observation for the current player
            mask[row, col] = 0 # Update mask to prevent placing another piece here

            # Check for win condition (placeholder logic)
            if self._check_done(observation, current_player):
                reward: Tensor = torch.tensor([1.0 * (1 - current_player) - 1.0 * current_player], dtype=torch.float32, device=self.device) # Single reward for the current player (+1 if player 0 wins, -1 if player 1 wins)
                done = torch.tensor(True, dtype=torch.bool) # Game done
            else:
                reward: Tensor = torch.tensor([0.0], dtype=torch.float32, device=self.device) # Initialize reward
                done = torch.tensor(False, dtype=torch.bool) # Game not done

                # Switch player
                current_player = 1 - current_player # Switch between 0 and 1

            # Update observation, mask
            new_observation: Tensor = torch.zeros((self.max_board_size, self.max_board_size, self.n_channel), dtype=torch.float) # (max_board_size, max_board_size, n_channel)
            new_observation[..., 0] = observation[..., 0] # Red pieces channel
            new_observation[..., 1] = observation[..., 1] # Blue pieces channel
            new_observation[..., 2] = float(current_player) # Current player channel
            new_observation[..., -1] = observation[..., -1] # (max_board_size, max_board_size) Playable board mask (doesn't change)
            new_mask: Tensor = mask.bool() # Valid move mask

        # Create done, reward tensors
        new_action: Tensor = action
        new_done: Tensor = done
        new_reward: Tensor = reward

        # Update action spec for the environment
        self.action_spec.update_mask(new_mask.flatten())

        # Update tensordict
        new_tensordict = TensorDict({
            "action": new_action,
            "observation": new_observation,
            "mask": new_mask,
            "done": new_done,
            "reward": new_reward
        })

        return new_tensordict

    def _check_done(self, observation: Tensor, current_player: int) -> bool:
        def dfs(board, start_positions, target_condition, directions):
            visited = torch.zeros((self.board_size, self.board_size), dtype=torch.bool)
            for start in start_positions:
                if board[start] == 1 and not visited[start]:
                    stack = [start]
                    visited[start] = True
                    while stack:
                        r, c = stack.pop()
                        if target_condition(r, c):
                            return True
                        for dr, dc in directions:
                            nr, nc = r + dr, c + dc
                            if 0 <= nr < self.board_size and 0 <= nc < self.board_size and board[nr, nc] == 1 and not visited[nr, nc]:
                                visited[nr, nc] = True
                                stack.append((nr, nc))
            return False

        directions = [(-1,0), (1,0), (0,-1), (0,1), (1,-1), (-1,1)] # 6 possible directions in a hex grid
        # Use DFS to check if player 0 (red) has connected top to bottom
        if current_player == 0:
            board = (observation[..., 0] == 1)[..., :self.board_size, :self.board_size] # Player 0 pieces
            start_positions = [(0, col) for col in range(self.board_size)]
            target_condition = lambda r, c: r == self.board_size - 1
            if dfs(board, start_positions, target_condition, directions):
                return True

        # Use DFS to check if player 1 (blue) has connected left to right
        else:
            board = (observation[..., 1] == 1)[..., :self.board_size, :self.board_size] # Player 1 pieces
            start_positions = [(row, 0) for row in range(self.board_size)]
            target_condition = lambda r, c: c == self.board_size - 1
            if dfs(board, start_positions, target_condition, directions):
                return True

        return False # No winner yet

    def _set_seed(self, seed: int) -> None:
        np.random.seed(seed)
        torch.manual_seed(seed)

env = HexEnv(board_size=board_size, max_board_size=MAX_BOARD_SIZE, device=device)
env.set_seed(42)

1445067485

# Use rand_action()

In [4]:
# Reset the environment once at the beginning
create_env_fn = HexEnv(board_size=board_size, max_board_size=MAX_BOARD_SIZE, device=device)
serial_env = SerialEnv(
    num_workers=1,
    create_env_fn=create_env_fn
)
state = serial_env.reset()
step = 1

# The loop condition remains the same
while not state.get("done").any(): # Using .any() is safer for batched envs
    # A single call to rand_step() performs a random action and advances the state.
    # It returns a TensorDict with the full transition data.
    transition_data = serial_env.rand_step(state)

    # Extract information for printing from the results
    action = transition_data.get("action")
    current_observation = transition_data.get("observation")
    current_mask = transition_data.get("mask")
    current_done = transition_data.get("done")
    next_state = transition_data.get("next")

    reward = next_state.get("reward")
    next_observation = next_state.get("observation")
    next_mask = next_state.get("mask")
    next_done = next_state.get("done")

    # The printing logic is the same
    x, y = divmod(action.item(), env.max_board_size)
    print(
        f"STEP: {step}. \n"
        f"Current State: "
        f"Player: {current_observation[..., 2].mean().item():.0f}. "
        f"Sum: {current_observation.sum().item():.0f}. "
        f"Remained: {current_mask.sum().item():.0f}. "
        f"{'Done: True. ' if current_done.item() else ''}\n"
        f"Action: {x}-{y}. \n"
        f"Next State: "
        f"Player: {next_observation[..., 2].mean().item():.0f}. "
        f"Sum: {next_observation.sum().item():.0f}. "
        f"Remained: {next_mask.sum().item():.0f}. "
        f"{'Done: True. ' if next_done.item() else ''}\n"
        f"Reward: {reward.item()} \n"
    )
    
    # IMPORTANT: Update the state for the next loop iteration
    step += 1
    state = next_state.clone()

NotImplementedError: EnvBase.forward is not implemented. If you ended here during a call to `ParallelEnv(...)`, please use a constructor such as `ParallelEnv(num_env, lambda env=env: env)` instead. Batched envs require constructors because environment instances may not always be serializable.

In [5]:
# Reset the environment once at the beginning
state = env.reset()

# The loop condition remains the same
while not state.get("done").any(): # Using .any() is safer for batched envs
    # A single call to rand_step() performs a random action and advances the state.
    # It returns a TensorDict with the full transition data.
    transition_data = env.rand_step(state)

    # Extract information for printing from the results
    action = transition_data.get("action")
    next_state = transition_data.get("next")
    
    observation = next_state.get("observation")
    mask = next_state.get("mask")
    reward = next_state.get("reward")
    done = next_state.get("done")

    # The printing logic is the same
    x, y = divmod(action.item(), env.max_board_size)
    print(
        f"Player: {observation[..., 2].mean().item():.0f}. "
        f"Action: {x}-{y}. "
        f"Observation Sum: {observation.sum().item():.0f}. "
        f"Remained: {mask.sum().item():.0f}. "
        f"Done: {done.item()}. "
        f"Reward: {reward.item()}"
    )
    
    # IMPORTANT: Update the state for the next loop iteration
    state = next_state.clone()

Player: 1. Action: 4-0. Observation Sum: 1050. Remained: 24. Done: False. Reward: 0.0
Player: 0. Action: 3-4. Observation Sum: 27. Remained: 23. Done: False. Reward: 0.0
Player: 1. Action: 3-0. Observation Sum: 1052. Remained: 22. Done: False. Reward: 0.0
Player: 0. Action: 4-1. Observation Sum: 29. Remained: 21. Done: False. Reward: 0.0
Player: 1. Action: 0-1. Observation Sum: 1054. Remained: 20. Done: False. Reward: 0.0
Player: 0. Action: 2-3. Observation Sum: 31. Remained: 19. Done: False. Reward: 0.0
Player: 1. Action: 1-4. Observation Sum: 1056. Remained: 18. Done: False. Reward: 0.0
Player: 0. Action: 4-4. Observation Sum: 33. Remained: 17. Done: False. Reward: 0.0
Player: 1. Action: 1-2. Observation Sum: 1058. Remained: 16. Done: False. Reward: 0.0
Player: 0. Action: 2-4. Observation Sum: 35. Remained: 15. Done: False. Reward: 0.0
Player: 1. Action: 2-0. Observation Sum: 1060. Remained: 14. Done: False. Reward: 0.0
Player: 0. Action: 2-1. Observation Sum: 37. Remained: 13. Done:

# Use rand_step()

In [6]:
# Giả sử 'env' đã được khởi tạo
state = env.reset()
done = state.get("done")

while not done.any():
    # Thực hiện một bước ngẫu nhiên.
    # 'transition_data' giờ chứa toàn bộ (S, A, R, S')
    transition_data = env.rand_step(state)

    # Lấy thông tin từ transition để in ra
    action = transition_data.get("action")
    next_state = transition_data.get("next") # Lấy ra state kết quả

    observation = next_state.get("observation")
    mask = next_state.get("mask")
    reward = next_state.get("reward")
    done = next_state.get("done")

    # In thông tin
    x, y = divmod(action.item(), env.max_board_size)
    print(
        f"Player: {observation[..., 2].mean().item():.0f}. "
        f"Action: {x}-{y}. "
        f"Observation Sum: {observation.sum().item():.0f}. "
        f"Remained: {mask.sum().item():.0f}. "
        f"Done: {done.item()}. "
        f"Reward: {reward.item()}"
    )

    # !!! ĐÂY LÀ DÒNG QUAN TRỌNG NHẤT !!!
    # Cập nhật state hiện tại để chuẩn bị cho vòng lặp tiếp theo.
    # .clone() là một thói quen tốt để tránh các lỗi tham chiếu.
    state = next_state.clone()

Player: 1. Action: 4-4. Observation Sum: 1050. Remained: 24. Done: False. Reward: 0.0
Player: 0. Action: 3-2. Observation Sum: 27. Remained: 23. Done: False. Reward: 0.0
Player: 1. Action: 2-2. Observation Sum: 1052. Remained: 22. Done: False. Reward: 0.0
Player: 0. Action: 1-0. Observation Sum: 29. Remained: 21. Done: False. Reward: 0.0
Player: 1. Action: 0-4. Observation Sum: 1054. Remained: 20. Done: False. Reward: 0.0
Player: 0. Action: 0-2. Observation Sum: 31. Remained: 19. Done: False. Reward: 0.0
Player: 1. Action: 3-1. Observation Sum: 1056. Remained: 18. Done: False. Reward: 0.0
Player: 0. Action: 0-0. Observation Sum: 33. Remained: 17. Done: False. Reward: 0.0
Player: 1. Action: 1-4. Observation Sum: 1058. Remained: 16. Done: False. Reward: 0.0
Player: 0. Action: 2-3. Observation Sum: 35. Remained: 15. Done: False. Reward: 0.0
Player: 1. Action: 2-0. Observation Sum: 1060. Remained: 14. Done: False. Reward: 0.0
Player: 0. Action: 4-3. Observation Sum: 37. Remained: 13. Done:

# Use rollout()

In [8]:
# env.rollout() handles the reset internally, so you don't need to call it.
# It collects data for a maximum of 50 steps.
rollout_data = env.rollout(max_steps=50, break_when_any_done=False)

# The 'rollout_data' now contains the entire history of the episode.
# Its shape will be [batch_size, num_steps]. For a single environment, this is [1, T].
# We can squeeze the batch dimension and iterate through the steps to print them.
episode_data = rollout_data.squeeze(0)

print("--- Rollout Results ---")
for i in range(episode_data.shape[0]):
    # Get the data for the i-th step from the recorded history
    step_data = episode_data[i]
    
    action = step_data.get("action")
    next_state = step_data.get("next")
    
    observation = next_state.get("observation")
    mask = next_state.get("mask")
    reward = next_state.get("reward")
    done = next_state.get("done")

    # The printing logic is the same as before
    x, y = divmod(action.item(), env.max_board_size)
    print(
        f"{i + 1:02d}. "
        f"Player: {observation[..., 2].mean().item():.0f}. "
        f"Action: {x}-{y}. "
        f"Remained: {mask.sum().item():.0f}. "
        f"Done: {done.item()}. "
        f"Reward: {reward.item()}"
    )

AttributeError: 'NoneType' object has no attribute 'clone'

# Multi-Environments

# DQN model

In [None]:
import torch

import torch.nn as nn

from torch import Tensor


class HexConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super(HexConv2d, self).__init__()
        assert kernel_size % 2 == 1 and kernel_size > 0, "kernel_size must be odd and positive."
        stride, padding = 1, kernel_size // 2  # To maintain spatial dimensions

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.mask = nn.Parameter(self._create_hex_mask(kernel_size), requires_grad=False) # (k, k), requires_grad=False to keep it fixed
        with torch.no_grad():
            self.conv.weight.mul_(self.mask)  # Apply mask to weights

    @staticmethod
    def _create_hex_mask(kernel_size: int) -> Tensor:
        assert kernel_size % 2 == 1 and kernel_size > 0, "kernel_size must be odd and positive."

        mask = torch.zeros((kernel_size, kernel_size), dtype=torch.float32)
        center = kernel_size // 2

        for r in range(kernel_size): # Row index
            for c in range(kernel_size): # Column index
                # Using axial distance for a vertically oriented hex grid
                # mapped to an offset coordinate system in the kernel.
                # (r, c) are kernel indices, (dr, dc) are relative to center.
                dr, dc = r - center, c - center
                chebyshev_distance = max(abs(dr), abs(dc), abs(dr + dc))

                if chebyshev_distance <= center: # Inside or on the hexagon
                    mask[r, c] = 1.0

        return mask # (kernel_size, kernel_size)

    def forward(self, x: Tensor) -> Tensor:
        # Assuming x is of shape (batch_size, channels, height, width)
        # Apply convolution
        x = self.conv(x)
        return x


class SkipConnection(nn.Module):
    def __init__(self, adjust_input: nn.Module, original_input: nn.Module = nn.Identity()):
        super(SkipConnection, self).__init__()
        self.adjust_input = adjust_input
        self.original_input = original_input

    def forward(self, x: Tensor) -> Tensor:
        return self.adjust_input(x) + self.original_input(x)


class TransformerQL(nn.Module):
    def __init__(self, 
                 conv_layers: list[tuple[int, int]],
                 n_encoder_layers: int,
                 d_input: int,
                 n_heads: int = 8,
                 d_ff: int = 2048,
                 dropout: float = 0.1,
                 output_flatten: bool = True):
        """Args:
            conv_layers: List of tuples (out_channels, kernel_size) for each conv layer.
                Note that, in_channels is inferred from the previous layer's out_channels (d_input for the first layer).
            n_encoder_layers: Number of transformer encoder layers.
            d_input: Dimension of input features to the transformer.
            n_heads: Number of attention heads in the transformer.
            d_ff: Dimension of the feedforward network in the transformer.
            dropout: Dropout rate.
        """
        super(TransformerQL, self).__init__()
        self.output_flatten = output_flatten
        self.d_encoder: int = conv_layers[-1][0] # Last conv layer's out_channels as d_model
        self.conv = nn.Sequential(*[
            SkipConnection(
                nn.Sequential(
                    HexConv2d(conv_layers[i-1][0] if i > 0 else d_input, conv_layers[i][0], conv_layers[i][1]),
                    nn.BatchNorm2d(conv_layers[i][0]),
                    nn.GELU(),
                    HexConv2d(conv_layers[i][0], conv_layers[i][0], conv_layers[i][1]),
                    nn.BatchNorm2d(conv_layers[i][0]),
                ),
                nn.Identity() if conv_layers[i][0] == conv_layers[i-1][0] # Skip connection (identity)
                else HexConv2d(conv_layers[i-1][0] if i > 0 else d_input, conv_layers[i][0], 1) # Combine with Conv2d (kernel_size = 1) for channel adjustment
            )
            for i in range(len(conv_layers))
        ])
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=self.d_encoder,
                nhead=n_heads,
                dim_feedforward=d_ff,
                dropout=dropout,
                activation='gelu',
                batch_first=True
            ),
            num_layers=n_encoder_layers
        )
        self.projection = nn.Linear(self.d_encoder, 1)

    def forward(self, x: Tensor, tensordict: TensorDict | None = None) -> Tensor:
        """Args:
            x: Input tensor of shape (N, H, W, C) where
               N = batch size, H = height, W = width, C = channels (d_input).
            Returns: Tensor of shape (N, H, W) with Q-values for each position.
        """
        # Reshape input to (N, C, H, W) for Conv2d
        batch_size, height, width = x.size(0), x.size(1), x.size(2)
        x = x.cuda() if torch.cuda.is_available() else x.cpu()
        x = x.permute(0, 3, 1, 2).contiguous() # (N, C, H, W)
        x = self.conv(x) # (N, d_encoder, H, W)
        # Reshape to (N, H*W, d_encoder)
        x = x.view(batch_size, -1, self.d_encoder) # (N, H*W, d_encoder)
        x = self.encoder(x) # (N, H*W, d_encoder)
        x = self.projection(x) # (N, H*W, 1)
        if self.output_flatten:
            return x.view(batch_size, -1) # (N, H*W)
        else:
            return x.view(batch_size, height, width) # (N, H, W)

In [None]:
class TurnWrapper(nn.Module):
    """
    A custom policy for Hex that wraps a DQN.
    - If it's Player 0's turn, it maximizes the Q-value.
    - If it's Player 1's turn, it minimizes the Q-value.
    It always respects the action mask.
    """
    def __init__(self, dqn_network: nn.Module):
        super().__init__()
        self.dqn_network = dqn_network

    def forward(self, x: Tensor, tensordict: TensorDict | None = None) -> Tensor:
        # x: (N, H, W, C)
        if x.dim() == 3:
            x = x.unsqueeze(0) # Add batch dimension if missing
        batch_size = x.size(0)
        # Step 1: Get the raw Q-values from your network
        mask = ~(x[..., 0].bool() | x[..., 1].bool() | ~x[..., -1].bool()).view(batch_size, -1) # Assuming mask is the sum of red and blue channels
        q_values: Tensor = self.dqn_network(x) # (N, num_actions)

        # Step 2: Determine the current player and adjust Q-values accordingly.
        # We assume the 3rd channel (index 2) of the observation indicates the player.
        # Player 0: channel is all 0s. Player 1: channel is all 1s.
        current_player: float = x[..., 2].mean().item() # Will be 0.0 or 1.0

        if current_player == 1.0:
            # Player 1 (blue) wants to MINIMIZE the Q-value.
            # This is equivalent to MAXIMIZING the negative Q-value.
            effective_q_values = -q_values
        else:
            # Player 0 (red) wants to MAXIMIZE the Q-value.
            effective_q_values = q_values

        # Step 3: Apply the action mask. This is crucial.
        # Set the Q-value of all illegal moves to negative infinity.
        effective_q_values[~mask] = -torch.inf

        return effective_q_values

In [None]:
model = TransformerQL(
    conv_layers=[(32, 3), (64, 3), (128, 3), (256, 3)],
    n_encoder_layers=8,
    d_input=N_CHANNEL,
    n_heads=8,
    d_ff=2048,
    dropout=0.1,
    output_flatten=True
).to(device)
test_state = env.reset().to(device).unsqueeze(0) # Add batch dimension

print(f"Initial state keys: {test_state.keys()}")
print(f"Mask sum (valid moves): {test_state.get('mask').sum().item()}")
print(f"Board size: {board_size}x{board_size}")

from torchrl.modules import QValueActor

# This actor takes observations and an action_mask,
# feeds the observation to the model,
# applies the mask to the resulting Q-values,
# and then outputs the best valid action.
policy = TurnWrapper(model).to(device)
policy_actor = QValueActor(
    module=policy,
    in_keys=["observation"], # CRUCIAL: Tell the actor to use the mask
    spec=env.action_spec
)
with torch.no_grad():
    result = policy_actor(test_state)

print(f"\nResult keys: {result.keys()}")
action = result.get("action")
chosen_value = result.get("chosen_action_value")

if action is not None:
    action_idx = action.item()
    row, col = divmod(action_idx, env.max_board_size)
    print(f"\nChosen action: {action_idx} -> Position ({row}, {col})")
    print(f"Chosen action value: {chosen_value.item() if chosen_value is not None else 'N/A'}")
    print(f"Is valid move: {test_state.get('mask')[..., row, col].item()}")
else:
    print("No action returned!")

# Optional: Check all Q-values
if "action_value" in result.keys():
    q_values = result.get("action_value")
    print(f"\nQ-values shape: {q_values.shape}")
    print(f"Q-values range: [{q_values.min().item():.4f}, {q_values.max().item():.4f}]")

Initial state keys: _StringKeys(dict_keys(['action', 'observation', 'mask', 'done', 'reward', 'terminated']))
Mask sum (valid moves): 25
Board size: 5x5

Result keys: _StringKeys(dict_keys(['action', 'observation', 'mask', 'done', 'reward', 'terminated', 'action_value', 'chosen_action_value']))

Chosen action: 32 -> Position (1, 0)
Chosen action value: 0.7713963985443115
Is valid move: True

Q-values shape: torch.Size([1, 1024])
Q-values range: [-inf, 0.7714]


In [None]:
# env.rollout() handles the reset internally, so you don't need to call it.
# It collects data for a maximum of 50 steps.
rollout_data = env.rollout(max_steps=50, policy=policy_actor, break_when_any_done=False)

# The 'rollout_data' now contains the entire history of the episode.
# Its shape will be [batch_size, num_steps]. For a single environment, this is [1, T].
# We can squeeze the batch dimension and iterate through the steps to print them.
episode_data = rollout_data.squeeze(0)

print("--- Rollout Results ---")
for i in range(episode_data.shape[0]):
    # Get the data for the i-th step from the recorded history
    step_data = episode_data[i]
    
    action = step_data.get("action")
    next_state = step_data.get("next")
    
    observation = next_state.get("observation")
    mask = next_state.get("mask")
    reward = next_state.get("reward")
    done = next_state.get("done")

    # The printing logic is the same as before
    x, y = divmod(action.item(), env.max_board_size)
    print(
        f"{i + 1:02d}. "
        f"Player: {observation[..., 2].mean().item():.0f}. "
        f"Action: {x}-{y}. "
        f"Remained: {mask.sum().item():.0f}. "
        f"Done: {done.item()}. "
        f"Reward: {reward.item()}"
    )

RuntimeError: Expected all tensors to be on the same device, but got weight is on cuda:0, different from other tensors on cpu (when checking argument in method wrapper_CUDA___slow_conv2d_forward)