# Setup

In [None]:
import math, torch, torchrl

import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torch import Tensor
from torch.nn.attention import SDPBackend, sdpa_kernel
from torchrl.collectors import SyncDataCollector
from torchrl.data import (
    Binary,
    Bounded,
    Categorical,
    Composite,
    LazyTensorStorage,
    ReplayBuffer,
    TensorSpec,
    UnboundedContinuous
)
from torchrl.envs import EnvBase, SerialEnv
from torchrl.envs.transforms import ActionMask, TransformedEnv
from torchrl.modules import MaskedCategorical, ProbabilisticActor
from torchrl.objectives import SoftUpdate
from torchrl.objectives.sac import DiscreteSACLoss

In [None]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' # Device for the actor/critic models; also HexEnv's default device
BATCH_SIZE = 64 # Batch size for training
LR = 1e-3 # Learning rate for the optimizer
WEIGHT_DECAY = 1e-4 # Weight decay for the optimizer

MAX_BOARD_SIZE = 4 # Maximum (padded) Hex board side; fixes the observation/action shapes
N_CHANNEL = 4 # Number of channels for the observation (Red, Blue, Current Player, Valid Board)
BOARD_SIZE = 4 # Playable board side (board_size x board_size); must be <= MAX_BOARD_SIZE
SWAP_RULE = True # Whether HexEnv lets player 1 take over player 0's first stone (swap rule)

BUFFER_SIZE = 10000 # Size of the replay buffer
N_FRAMES_PER_BATCH = 1024 # Presumably frames gathered per collector batch (not per episode) - confirm where the collector is built
STORAGE_DEVICE = 'cpu' # Device of the HexEnv instances (and, per its name, the replay-buffer storage)
MODEL_PARAMS = { # Keyword arguments for HexModel (used for both actor and critic)
    "conv_layers": [(32, 3), (64, 3)], # (out_channels, kernel_size) per residual hex-conv block
    "n_encoder_layers": 2, # Number of Transformer encoder layers
    "d_input": N_CHANNEL, # Input channels = observation channels
    "n_heads": 2, # Attention heads; must divide the last conv out_channels (64)
    "d_ff": 1024, # Transformer feed-forward width
    "dropout": 0.01, # Dropout inside the Transformer encoder
    "output_flatten": True, # Model returns (N, H*W): one value per board cell
}

# Environment

## Base Environment

In [None]:
class HexEnv(EnvBase):
    """Hex board game as a single (unbatched) TorchRL environment.

    The board is padded to ``max_board_size x max_board_size`` so boards of
    different sizes share one spec; only the top-left ``board_size x board_size``
    region is playable.

    Keys produced by ``reset`` / ``step``:
        observation: float32 (max_board_size, max_board_size, n_channel).
            Channel 0 = red (player 0) stones, 1 = blue (player 1) stones,
            -2 = index of the player to move (same value in every cell),
            -1 = playable-board mask.
        action_mask: bool (max_board_size ** 2,), the cells the player to move may select.
        done: bool scalar.
        reward (step only): float32 (1,), +1 if player 0 just won, -1 if player 1
            just won, 0 otherwise.

    Red wins by connecting the top row to the bottom row; blue wins by connecting
    the left column to the right column.
    """

    def __init__(self,
                 board_size: int,
                 max_board_size: int = MAX_BOARD_SIZE,
                 swap_rule: bool = SWAP_RULE,
                 device: torch.device = DEVICE,
                ):
        """
        Args:
            board_size: Playable board side (1 <= board_size <= max_board_size).
            max_board_size: Padded board side; fixes every spec shape.
            swap_rule: Whether player 1 may take over player 0's first stone.
            device: Device of the specs and of every tensor this env produces.
        """
        assert board_size >= 1, "Board size must be greater than or equal to 1."
        assert board_size <= max_board_size, "Board size must be less than or equal to max Board size."

        super().__init__(device=device, spec_locked=False)

        self.board_size: int = board_size
        self.max_board_size: int = max_board_size
        self.n_channel: int = N_CHANNEL
        self.swap_rule: bool = swap_rule

        self.board_shape: torch.Size = torch.Size(
            (self.max_board_size, self.max_board_size)
        ) # (max_board_size, max_board_size)

        # True on the playable top-left region, False on the padding.
        self.valid_board: Tensor = torch.zeros(
            self.board_shape,
            dtype=torch.bool,
            device=self.device
        ) # (max_board_size, max_board_size)
        self.valid_board[:self.board_size, :self.board_size] = True

        self.observation_spec = Composite({
            "observation": Binary(
                shape=self.board_shape + (self.n_channel,),
                # (max_board_size, max_board_size, n_channel)
                device=self.device,
                dtype=torch.float32
            ),
            "action_mask": Binary(
                shape=(self.max_board_size ** 2,),
                # (max_board_size ** 2,)
                device=self.device,
                dtype=torch.bool
            )
        })
        # One discrete action per (padded) board cell, flat index = row * max_board_size + col.
        self.action_spec = Categorical(
            n=self.max_board_size ** 2,
            device=self.device,
            dtype=torch.long
        )
        # Single reward from player 0's perspective (see class docstring).
        self.reward_spec = UnboundedContinuous(
            shape=(1,),
            device=self.device,
            dtype=torch.float32
        )

    def _legal_action_mask(self, observation: Tensor, current_player: int, done: bool) -> Tensor:
        """Return the flat bool mask of cells the player to move may select.

        Empty playable cells are legal. With the swap rule enabled, on the second
        move of an unfinished game (player 1 to move, exactly one stone on the
        board) the cell holding player 0's stone is legal too: selecting it
        performs the swap in ``_step``.
        """
        occupied = observation[..., 0:2].sum(dim=-1) > 0 # (max_board_size, max_board_size)
        mask = self.valid_board & ~occupied
        if self.swap_rule and not done and current_player == 1 and int(occupied.sum()) == 1:
            mask = mask | (observation[..., 0] == 1)
        return mask.flatten() # (max_board_size ** 2,)

    def _reset(self, tensordict: TensorDict | None = None, **kwargs) -> TensorDict:
        """Return the initial state: empty board, player 0 (red) to move."""
        current_player: int = 0

        observation: Tensor = torch.zeros(
            (self.max_board_size, self.max_board_size, self.n_channel),
            dtype=torch.float32,
            device=self.device
        ) # Stone channels 0/1 stay zero: the board starts empty.
        observation[..., -2] = current_player
        observation[..., -1] = self.valid_board.float()

        return TensorDict({
            "action": torch.tensor([0], dtype=torch.long, device=self.device), # Placeholder action
            "observation": observation,
            "action_mask": self._legal_action_mask(observation, current_player, False),
            "done": torch.tensor(False, dtype=torch.bool, device=self.device),
        }, device=self.device)

    def _step(self, tensordict: TensorDict, **kwargs) -> TensorDict:
        """Apply one move (a placement or the swap) and return the next state.

        Raises:
            ValueError: if the action is out of range, off the playable board, or
                selects a cell that is neither legal nor a swap.
        """
        action: Tensor = tensordict.get("action").clone()
        observation: Tensor = tensordict.get("observation").clone() # (max_board_size, max_board_size, n_channel)
        action_mask: Tensor = tensordict.get("action_mask").clone() # (max_board_size ** 2,)

        index: int = int(action.item())
        n_cells = self.max_board_size ** 2
        if not 0 <= index < n_cells:
            # Checked first so the message below can never index out of range.
            raise ValueError(f"Invalid action {index}: expected an index in [0, {n_cells}).")
        row, col = divmod(index, self.max_board_size)

        current_player: int = int(observation[0, 0, -2].item()) # 0: red, 1: blue

        # Swap: on the second move, player 1 may select player 0's only stone.
        n_stones = torch.sum(observation[..., 0:2]).item()
        is_swap_action = bool(
            self.swap_rule
            and n_stones == 1
            and current_player == 1
            and observation[row, col, 0] == 1
        )

        is_valid = bool(self.valid_board[row, col]) and (bool(action_mask[index]) or is_swap_action)
        if not is_valid:
            raise ValueError(f"Invalid action {action.item()} at row={row}; col={col}; valid={self.valid_board[row, col]}, action_mask={action_mask[index]}.")

        if is_swap_action:
            # NOTE(review): this recolours the stone in place. The usual Hex swap
            # also mirrors the stone across the long diagonal; confirm the intended variant.
            observation[..., 0], observation[..., 1] = observation[..., 1].clone(), observation[..., 0].clone()
        else:
            observation[row, col, current_player] = 1.0

        won = self._check_done(observation, current_player)
        if won:
            # +1 if player 0 (red) wins, -1 if player 1 (blue) wins.
            reward: Tensor = torch.tensor([1.0 if current_player == 0 else -1.0], dtype=torch.float32, device=self.device)
        else:
            reward = torch.tensor([0.0], dtype=torch.float32, device=self.device)
            current_player = 1 - current_player # Hand the turn over
        done: Tensor = torch.tensor(won, dtype=torch.bool, device=self.device)

        new_observation: Tensor = torch.zeros(
            (self.max_board_size, self.max_board_size, self.n_channel),
            dtype=torch.float32,
            device=self.device
        )
        new_observation[..., 0] = observation[..., 0] # Red stones
        new_observation[..., 1] = observation[..., 1] # Blue stones
        new_observation[..., -2] = float(current_player) # Player to move next
        new_observation[..., -1] = observation[..., -1] # Playable board (never changes)

        # Rebuilt from the board (instead of editing the incoming mask) so the swap
        # cell can be offered on move 2 and is withdrawn again afterwards.
        new_action_mask: Tensor = self._legal_action_mask(new_observation, current_player, won)

        return TensorDict({
            "action": action,
            "observation": new_observation,
            "action_mask": new_action_mask,
            "done": done,
            "reward": reward
        }, device=self.device)

    def _check_done(self, observation: Tensor, current_player: int) -> bool:
        """Return True if ``current_player`` has connected their two sides.

        Player 0 (channel 0) must connect row 0 to row board_size-1; player 1
        (channel 1) must connect column 0 to column board_size-1. Adjacency is the
        six-neighbour hex layout also used by the hex convolution mask.
        """
        n = self.board_size
        channel = 0 if current_player == 0 else 1
        # One transfer to a Python list instead of per-cell tensor indexing.
        stones = observation[:n, :n, channel].tolist()
        directions = ((-1, 0), (1, 0), (0, -1), (0, 1), (1, -1), (-1, 1))

        if channel == 0:
            starts = [(0, c) for c in range(n)]
            def reached_goal(r: int, c: int) -> bool:
                return r == n - 1
        else:
            starts = [(r, 0) for r in range(n)]
            def reached_goal(r: int, c: int) -> bool:
                return c == n - 1

        # Multi-source DFS from every own stone on the starting edge.
        visited = [[False] * n for _ in range(n)]
        stack = []
        for r, c in starts:
            if stones[r][c] == 1:
                visited[r][c] = True
                stack.append((r, c))
        while stack:
            r, c = stack.pop()
            if reached_goal(r, c):
                return True
            for dr, dc in directions:
                nr, nc = r + dr, c + dc
                if 0 <= nr < n and 0 <= nc < n and not visited[nr][nc] and stones[nr][nc] == 1:
                    visited[nr][nc] = True
                    stack.append((nr, nc))
        return False

    def _set_seed(self, seed: int) -> None:
        np.random.seed(seed)
        torch.manual_seed(seed)

## Test Environment

In [None]:
def create_hex_env() -> HexEnv:
    """Build one HexEnv with the notebook-level board settings."""
    return HexEnv(board_size=BOARD_SIZE, max_board_size=MAX_BOARD_SIZE, device=STORAGE_DEVICE)


# SerialEnv with a single worker gives every tensor a leading batch dimension of 1;
# the ActionMask transform applies the env's "action_mask" entry to the action spec.
serial_env = TransformedEnv(
    SerialEnv(
        num_workers=1,
        create_env_fn=create_hex_env
    ),
    ActionMask()
)

# Smoke test: without a policy, rollout samples random actions from the (masked) action spec.
r = serial_env.rollout(100)
serial_env.reset()

# Policy

## Custom Sub-modules

In [None]:
class HexConv2d(nn.Module):
    """2-D convolution whose kernel is restricted to a hexagonal neighbourhood.

    The weights live in an ordinary ``nn.Conv2d`` (so initialisers and optimisers
    see a normal conv layer). On every forward pass they are multiplied by a fixed
    0/1 hexagon mask, so the masked-out taps never influence the output and receive
    zero gradient. Stride is 1 and padding is ``kernel_size // 2``, which preserves
    the spatial size.
    """

    def __init__(self, in_channels, out_channels, kernel_size: int = 3, bias: bool = True):
        super().__init__()
        assert kernel_size % 2 == 1 and kernel_size > 0, "kernel_size must be odd and positive."

        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=kernel_size // 2,
            bias=bias,
        )
        # Registered as a buffer: not trainable, follows .to(device), saved in state_dict.
        self.register_buffer('mask', self._create_hex_mask(kernel_size)) # (k, k)

    @staticmethod
    def _create_hex_mask(kernel_size: int) -> Tensor:
        """Return a float32 (k, k) mask that is 1 inside the hexagon of radius k // 2.

        With offsets (dr, dc) from the kernel centre, the hex distance is
        max(|dr|, |dc|, |dr + dc|); cells at distance <= radius are kept.
        """
        assert kernel_size % 2 == 1 and kernel_size > 0, "kernel_size must be odd and positive."

        radius = kernel_size // 2
        offsets = torch.arange(kernel_size) - radius
        dr = offsets.view(-1, 1) # row offsets, broadcast over columns
        dc = offsets.view(1, -1) # column offsets, broadcast over rows
        hex_distance = torch.maximum(torch.maximum(dr.abs(), dc.abs()), (dr + dc).abs())
        return (hex_distance <= radius).to(torch.float32)

    def forward(self, x: Tensor) -> Tensor:
        # Mask out of place: the stored weights are untouched; masked taps get zero grad.
        conv = self.conv
        return F.conv2d(
            x,
            conv.weight * self.mask,
            conv.bias,
            conv.stride,
            conv.padding,
            conv.dilation,
            conv.groups,
        )


class SkipConnection(nn.Module):
    def __init__(self, adjust_input: nn.Module, original_input: nn.Module = nn.Identity()):
        super(SkipConnection, self).__init__()
        self.adjust_input = adjust_input
        self.original_input = original_input

    def forward(self, x: Tensor) -> Tensor:
        return self.adjust_input(x) + self.original_input(x)


class TriAxialPositionalEmbedding(nn.Module):
    """Sinusoidal positional embedding over the three hex (cube) axes.

    For a cell at row r and column q, the cube coordinates are (q, r, s = -q - r).
    Each axis gets a standard 1-D sin/cos embedding of size ``d_model``. The three
    embeddings are summed and divided by sqrt(3), which keeps their scale
    comparable to a single embedding.
    """

    def __init__(self, d_model: int):
        """
        Args:
            d_model: Channel size C of the input. Must be even (sin/cos pairs).
        """
        super().__init__()
        assert d_model % 2 == 0, "d_model (C) phải là số chẵn để tính sin/cos."

        # Inverse frequencies 10000^(-2i / d_model), i = 0 .. d_model/2 - 1 (d_model/2 values).
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        self.d_model = d_model
        self.register_buffer('div_term', div_term)

    def _get_1d_sinusoidal(self, coords: torch.Tensor) -> torch.Tensor:
        """Sinusoidal embedding of one coordinate axis.

        Args:
            coords: Coordinate values, shape (H, W).
        Returns:
            Embedding of shape (H, W, d_model), interleaved as [sin, cos, sin, cos, ...].
        """
        phase = coords.unsqueeze(-1).float() * self.div_term # (H, W, d_model/2)
        # Interleave with stack + flatten rather than strided in-place writes
        # (pe[..., 0::2] = ...); the original author notes the in-place form breaks under vmap.
        return torch.stack([torch.sin(phase), torch.cos(phase)], dim=-1).flatten(-2, -1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input of shape (N, H, W, C) with C == d_model.
        Returns:
            Embedding of shape (1, H, W, C), broadcastable against ``x``, on
            ``x``'s device and in ``x``'s dtype.
        """
        _, H, W, C = x.shape
        assert C == self.d_model, f"Input channel {C} không khớp với d_model khởi tạo {self.d_model}"

        rows = torch.arange(H, device=x.device, dtype=torch.float)
        cols = torch.arange(W, device=x.device, dtype=torch.float)
        r_grid, q_grid = torch.meshgrid(rows, cols, indexing='ij')
        s_grid = -q_grid - r_grid # third cube coordinate: q + r + s = 0

        full_pe = (
            self._get_1d_sinusoidal(q_grid)
            + self._get_1d_sinusoidal(r_grid)
            + self._get_1d_sinusoidal(s_grid)
        ) / math.sqrt(3)

        # Computed in float32 for accuracy. Cast so that `x + pe` keeps x's dtype
        # (e.g. after model.half()) instead of silently promoting to float32.
        return full_pe.unsqueeze(0).to(dtype=x.dtype)

## Base Model

In [None]:
class HexModel(nn.Module):
    """Conv + Transformer trunk that outputs one scalar per board cell.

    Pipeline: residual hex-convolution blocks -> tri-axial positional embedding ->
    Transformer encoder over the H*W cells -> linear projection to one value per
    cell. In this notebook it serves both as the actor (per-cell logits) and as
    the critic (per-cell Q-values).
    """

    def __init__(self,
                 conv_layers: list[tuple[int, int]],
                 n_encoder_layers: int,
                 d_input: int,
                 n_heads: int = 8,
                 d_ff: int = 2048,
                 dropout: float = 0.1,
                 output_flatten: bool = True):
        """Args:
            conv_layers: List of (out_channels, kernel_size), one per residual conv block.
                Block i takes the previous block's out_channels as input (d_input for
                the first block). Every out_channels must be divisible by 4
                (GroupNorm with 4 groups). The last one is the encoder width d_model,
                so it must also be even and divisible by n_heads.
            n_encoder_layers: Number of transformer encoder layers.
            d_input: Number of input channels (C of the (N, H, W, C) input).
            n_heads: Number of attention heads in the transformer.
            d_ff: Dimension of the feedforward network in the transformer.
            dropout: Dropout rate inside the transformer.
            output_flatten: If True, forward returns (N, H*W); otherwise (N, H, W).
        """
        super(HexModel, self).__init__()
        self.output_flatten = output_flatten
        self.d_encoder: int = conv_layers[-1][0] # Last conv block's out_channels is d_model

        # Input channels of each block: d_input for the first, then the previous block's output.
        in_channels = [d_input] + [layer[0] for layer in conv_layers[:-1]]
        self.conv = nn.Sequential(*[
            self._residual_block(in_ch, layer[0], layer[1])
            for in_ch, layer in zip(in_channels, conv_layers)
        ])
        self.positional_embedding = TriAxialPositionalEmbedding(self.d_encoder)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=self.d_encoder,
                nhead=n_heads,
                dim_feedforward=d_ff,
                dropout=dropout,
                activation='gelu',
                batch_first=True
            ),
            num_layers=n_encoder_layers
        )
        self.projection = nn.Linear(self.d_encoder, 1) # Per-cell output: actor logit / critic Q-value

    @staticmethod
    def _residual_block(in_ch: int, out_ch: int, kernel_size: int) -> nn.Module:
        """Build one residual block: two masked hex convs with GroupNorm, plus a shortcut.

        The shortcut is the identity when ``in_ch == out_ch``. Otherwise it is a
        1x1 conv followed by GroupNorm, which matches the channel count.
        """
        main = nn.Sequential(
            HexConv2d(in_ch, out_ch, kernel_size, bias=False),
            nn.GroupNorm(num_groups=4, num_channels=out_ch),
            nn.GELU(),
            HexConv2d(out_ch, out_ch, kernel_size, bias=False),
            nn.GroupNorm(num_groups=4, num_channels=out_ch),
        )
        if in_ch == out_ch:
            shortcut = nn.Identity()
        else:
            shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1), # 1x1 conv for channel adjustment
                nn.GroupNorm(num_groups=4, num_channels=out_ch),
            )
        return SkipConnection(main, shortcut)

    def forward(self, x: Tensor) -> Tensor:
        """Args:
            x: Input tensor of shape (N, H, W, C) or unbatched (H, W, C), with C = d_input.
        Returns:
            (N, H*W) if ``output_flatten`` else (N, H, W); one value per board cell.
            Unbatched input is treated as N = 1.
        """
        if len(x.shape) == 3:
            x = x.unsqueeze(0)  # Add batch dimension if missing
        elif len(x.shape) != 4:
            raise ValueError(f"Input tensor x must have shape (N, H, W, C) or (H, W, C), but got {x.shape}.")

        # 1. Convolutional layers (Conv2d expects channels first)
        batch_size, height, width = x.size(0), x.size(1), x.size(2)
        x = x.permute(0, 3, 1, 2).contiguous() # (N, C, H, W)
        x = self.conv(x) # (N, d_encoder, H, W)

        # 2. Positional embedding + Transformer encoder over the H*W cells
        x = x.permute(0, 2, 3, 1).contiguous() # (N, H, W, d_encoder)
        pe: Tensor = self.positional_embedding(x) # (1, H, W, d_encoder)
        x = (x + pe).flatten(1, 2).contiguous() # (N, H*W, d_encoder)
        x = self.encoder(x) # (N, H*W, d_encoder)
        # NOTE: if DiscreteSACLoss is used with vmap (deactivate_vmap=False), the
        # author wraps the encoder in `with sdpa_kernel(SDPBackend.MATH):`; it is
        # not needed otherwise (slower).

        # 3. Projection to per-cell outputs for Actor/Critic
        x = self.projection(x) # (N, H*W, 1)
        if self.output_flatten:
            return x.view(batch_size, -1) # (N, H*W)
        else:
            return x.view(batch_size, height, width) # (N, H, W)

## Policy Wrappers

In [None]:
class ActorWrapper(nn.Module):
    """Wrap a HexModel for the actor: return per-cell logits and the action mask.

    In this notebook the two outputs feed ``ProbabilisticActor`` with
    ``MaskedCategorical``. The distribution applies the mask, so the logits are
    returned unmasked.
    """
    def __init__(self, model: "HexModel"):
        super().__init__()
        self.model = model # Underlying model (may be shared with other wrappers)

    def forward(self, observation: Tensor, action_mask: Tensor) -> tuple[Tensor, Tensor]:
        logits = self.model(observation) # (N, H*W) for the default HexModel config
        # Give the mask exactly the logits' shape. The model adds a batch dimension
        # to unbatched (H, W, C) input, so using observation.shape[0] (= H there)
        # as the batch size would produce a mask that does not line up with the logits.
        action_mask = action_mask.reshape(logits.shape)
        return logits, action_mask


class CriticWrapper(nn.Module):
    """Wrap a HexModel for the critic: return one Q-value per board cell.

    ``action_mask`` is accepted only so the signature matches the
    TensorDictModule in_keys; it is ignored and the Q-values are returned unmasked.
    """
    def __init__(self, model: HexModel):
        super().__init__()
        self.model = model # Underlying model (may be shared with other wrappers)

    def forward(self, observation: Tensor, action_mask: Tensor) -> Tensor:
        del action_mask # intentionally unused
        return self.model(observation) # (N, H*W) for the default HexModel config

# Utilities

In [None]:
# 1. Weight-initialisation helper
def init_params(model: nn.Module):
    """Re-initialise ``model``'s weights in place (tuned for RL with GELU).

    * ``nn.Linear`` / ``nn.Conv2d`` (including the conv inside ``HexConv2d``) get
      orthogonal weights with gain sqrt(2). The module stored as
      ``model.projection`` (if any) gets gain 0.01 instead, so the initial
      outputs are near zero (near-uniform policy / max entropy). Biases are zeroed.
    * ``nn.GroupNorm`` / ``nn.LayerNorm`` / ``nn.BatchNorm2d`` get weight 1 and
      bias 0, when those affine parameters exist.
    Parameters owned directly by other module types are left untouched.
    """
    output_head = getattr(model, 'projection', None)
    linear_like = (nn.Linear, nn.Conv2d)
    norm_like = (nn.GroupNorm, nn.LayerNorm, nn.BatchNorm2d)

    for module in model.modules():
        if isinstance(module, linear_like):
            gain = 0.01 if module is output_head else math.sqrt(2)
            nn.init.orthogonal_(module.weight, gain=gain)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, norm_like):
            if module.weight is not None:
                nn.init.ones_(module.weight) # gamma = 1
            if module.bias is not None:
                nn.init.zeros_(module.bias) # beta = 0

def get_optimizer_params(model: nn.Module, weight_decay: float = 1e-5):
    """Split ``model``'s trainable parameters into AdamW groups with and without decay.

    Weights of ``nn.Linear`` / ``nn.Conv2d`` modules (including the conv inside
    ``HexConv2d``) get ``weight_decay``. Everything else gets 0.0: biases, norm
    gammas, and any other parameter. A parameter shared by several modules (tied
    weights) is placed once, in the group of the first module that owns it, so no
    parameter appears twice in the optimizer.

    Args:
        model: Module whose parameters are grouped.
        weight_decay: Desired weight decay for the decayed group.
    Returns:
        ``[{'params': decay, 'weight_decay': weight_decay},
           {'params': no_decay, 'weight_decay': 0.0}]``.
    """
    decay_modules = (nn.Linear, nn.Conv2d)
    decay_params = []
    no_decay_params = []
    seen_ids = set()

    for module in model.modules():
        for pn, p in module.named_parameters(recurse=False):
            # Frozen params are skipped; buffers (e.g. HexConv2d.mask) are not parameters at all.
            if not p.requires_grad or id(p) in seen_ids:
                continue
            seen_ids.add(id(p))
            if pn.endswith('weight') and isinstance(module, decay_modules):
                decay_params.append(p)
            else:
                no_decay_params.append(p)

    # Sanity check: every trainable parameter was placed exactly once.
    param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
    inter_params = len(decay_params) + len(no_decay_params)
    assert len(param_dict.keys()) == inter_params, f"Lỗi: Tổng tham số lọc được ({inter_params}) không khớp với model ({len(param_dict.keys())})"

    return [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0}
    ]

def merge_optimizer_params(loss_fn_params, actor_groups: list, qvalue_groups: list):
    """Merge parameter groups into one optimizer config, dropping duplicates by priority.

    Each parameter object is kept only in the first place it appears, in this order:
        1. ``actor_groups`` (their own weight-decay settings),
        2. ``qvalue_groups`` (their own weight-decay settings),
        3. whatever is left in ``loss_fn_params`` (e.g. SAC's ``log_alpha``),
           collected into one extra group with ``weight_decay`` 0.0.
    Groups that end up empty are dropped. The input group dicts are shallow-copied,
    never modified.

    Args:
        loss_fn_params: Iterable of parameters, typically ``loss_fn.parameters()``.
        actor_groups: Group dicts for the actor, e.g. from ``get_optimizer_params``.
        qvalue_groups: Group dicts for the critic, e.g. from ``get_optimizer_params``.
    Returns:
        List of group dicts for the optimizer.
    """
    seen_param_ids = set()

    def take_unseen(params):
        # Keep only parameters not claimed by an earlier group, and claim them.
        fresh = []
        for p in params:
            if id(p) not in seen_param_ids:
                seen_param_ids.add(id(p))
                fresh.append(p)
        return fresh

    merged = []
    for group in (*actor_groups, *qvalue_groups):
        fresh = take_unseen(group['params'])
        if fresh:
            merged.append({**group, 'params': fresh})

    leftovers = take_unseen(loss_fn_params)
    if leftovers:
        # Loose coefficients such as alpha should not be weight-decayed.
        merged.append({'params': leftovers, 'weight_decay': 0.0})

    return merged

def check_params_changed(model, model_name, old_params):
    """Report whether a model's trainable parameters differ from a snapshot.

    Every parameter of ``model`` with ``requires_grad=True`` is compared with
    the tensor stored under the same name in ``old_params``; a per-parameter
    summary is printed.

    Args:
        model: ``nn.Module`` whose current parameters are inspected.
        model_name: label used in the printed messages.
        old_params: mapping ``name -> tensor`` snapshot, e.g.
            ``{n: p.detach().clone() for n, p in model.named_parameters()}``.

    Returns:
        True if at least one trainable parameter differs from the snapshot
        (including parameters that became NaN/Inf), otherwise False.

    Raises:
        KeyError: if a trainable parameter name is missing from ``old_params``.
    """
    changed = False
    print(f"Checking updates for {model_name}...")
    # Pure diagnostics: run outside autograd so no graph is built through the
    # live parameters (or through a snapshot cloned without ``detach()``).
    with torch.no_grad():
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            # Snapshot tensor stored under the same parameter name.
            old_p = old_params[name]
            # Total absolute (L1) difference between current value and snapshot.
            diff = (param - old_p).abs().sum().item()

            if not math.isfinite(diff):
                # ``nan > 0`` is False, so without this branch a parameter that
                # blew up to NaN/Inf would be reported as "did NOT change".
                changed = True
                print(f"  ❌ {name} contains non-finite values (diff: {diff})")
            elif diff > 0:
                changed = True
                # Only echo weight tensors with a non-negligible change to keep
                # the output short.
                if "weight" in name and diff > 1e-6:
                    print(f"  ✓ {name} changed (diff: {diff:.6f})")
            else:
                print(f"  ⚠️ {name} did NOT change (diff: 0.0)")

    if changed:
        print(f"✅ {model_name} weights updated successfully.")
    else:
        print(f"❌ {model_name} weights did NOT update. Check gradients or learning rate.")
    return changed

# Training Loop

## Initialize

In [None]:
# 1. Create the base models (separate networks for actor and critic).
actor_model = HexModel(**MODEL_PARAMS).train().to(DEVICE) # Model for the actor
qvalue_model = HexModel(**MODEL_PARAMS).train().to(DEVICE) # Model for the critic
init_params(actor_model)
init_params(qvalue_model)
# Alternative: one model shared by both actor and critic.
# model = HexModel(**MODEL_PARAMS).train().to(DEVICE) # Shared by both actor and critic
# init_params(model)
# actor_model, qvalue_model = model, model

# 2. Wrap both models as TensorDict modules and build the stochastic policy.
# Both wrappers read the observation together with the legal-move mask.
actor_network = TensorDictModule(
    ActorWrapper(actor_model),
    in_keys=["observation", "action_mask"],
    out_keys=["logits", "mask"]
)
qvalue_network = TensorDictModule(
    CriticWrapper(qvalue_model),
    in_keys=["observation", "action_mask"],
    out_keys=["action_value"]
)
# The policy samples from a MaskedCategorical built from "logits" and "mask".
actor = ProbabilisticActor(
    actor_network,
    in_keys=["logits", "mask"],
    spec=serial_env.action_spec,
    distribution_class=MaskedCategorical
)

# 3. Create the loss module, the optimizer and the target-network updater.
loss_fn = DiscreteSACLoss(
    actor_network=actor,
    qvalue_network=qvalue_network,
    action_space=serial_env.action_spec,
    num_actions=serial_env.action_spec.n,
    deactivate_vmap=True
    # vmap is disabled because the Transformer does not work well with it.
    # To use vmap, the SDPBackend.MATH attention backend is required.
).to(DEVICE)

## Get the parameter groups for the Actor and Critic models
## (get_optimizer_params is expected to assign per-group weight_decay).
# NOTE(review): DiscreteSACLoss (num_qvalue_nets defaults to 2) may keep its own
# expanded copy of the Q-network parameters in loss_fn.qvalue_network_params
# instead of training qvalue_model's tensors directly. If so, those copies end
# up in the "leftover" group (weight_decay=0.0) and qvalue_params_groups holds
# tensors the loss never uses — verify against the installed TorchRL version.
actor_params_groups = get_optimizer_params(actor_model, WEIGHT_DECAY)
qvalue_params_groups = get_optimizer_params(qvalue_model, WEIGHT_DECAY)

## Merge the groups, de-duplicating parameters with priority
## Actor > Critic > loss-module leftovers.
combined_params = merge_optimizer_params(
    loss_fn_params=loss_fn.parameters(),
    actor_groups=actor_params_groups,
    qvalue_groups=qvalue_params_groups
)

# Per-group "weight_decay" entries override this default.
optimizer = torch.optim.AdamW(
    params=combined_params,
    lr=LR,
    weight_decay=WEIGHT_DECAY 
)
# Soft (tau-weighted) target-network updates, applied on each updater.step().
updater = SoftUpdate(
    loss_module=loss_fn,
    tau=0.005
)

# 4. Set up the replay buffer and the data collector.
replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(BUFFER_SIZE, device=STORAGE_DEVICE),
    batch_size=BATCH_SIZE
)
collector = SyncDataCollector(
    create_env_fn=serial_env,
    policy=actor,
    frames_per_batch=N_FRAMES_PER_BATCH,
    total_frames=-1, # -1: no frame limit (collect indefinitely)
    device=DEVICE,
    storing_device=STORAGE_DEVICE
)

## Functionality Test

In [None]:
# Quick sanity check for env, actor, loss, and collector.
# Relies on serial_env, actor, qvalue_network, loss_fn, optimizer, updater and
# collector created in the "Initialize" cell.
from torchrl.envs.utils import step_mdp  # promotes a stepped tensordict's "next" entries to the root


def _grad_l2_norm(params):
    """Return the global L2 norm over the ``.grad`` of ``params``.

    Parameters without a gradient are skipped; returns 0.0 when none has one.
    (Adding up per-tensor norms instead would over-state the global norm.)
    """
    return math.sqrt(sum(
        p.grad.detach().pow(2).sum().item() for p in params if p.grad is not None
    ))


print("=" * 50)
print("Testing Environment")
print("=" * 50)

# Test environment reset and step
test_td = serial_env.reset()
print(f"Reset output keys: {test_td.keys()}")
print(f"Observation shape: {test_td['observation'].shape}")
print(f"Action mask shape: {test_td['action_mask'].shape}")

# Take a random action. The outcome of the step (reward, done, next
# observation) is written under "next"; the root still holds the pre-step state.
test_td = serial_env.rand_step(test_td)
print(f"Step output keys: {test_td.keys()}")
next_td = test_td.get("next", {})
print(f"Reward: {next_td.get('reward', 'N/A')}")
print(f"Done: {next_td.get('done', 'N/A')}")

print("\n" + "=" * 50)
print("Testing Actor")
print("=" * 50)

# Test actor forward pass
test_td = serial_env.reset()
with torch.no_grad():
    actor_output = actor(test_td)

print(f"Actor output keys: {actor_output.keys()}")
print(f"Action shape: {actor_output['action'].shape}")
print(f"Action value: {actor_output['action'].item()}")

# Test loss calculation with dummy data
print("\n" + "=" * 50)
print("Testing Loss Function and Backpropagation")
print("=" * 50)

# Detached snapshots so the comparison later does not hold autograd graphs.
actor_params_before = {name: p.detach().clone() for name, p in actor.named_parameters()}
# NOTE(review): DiscreteSACLoss may train its own expanded copy of the Q-network
# parameters (loss_fn.qvalue_network_params) rather than qvalue_network's
# tensors; in that case the Q-net check below reports "did NOT update" — verify.
try:
    qvalue_params_before = {name: p.detach().clone() for name, p in qvalue_network.named_parameters()}
except NameError:
    print("⚠️ Could not find 'qvalue_network' variable to check. Skipping Q-net check.")
    qvalue_params_before = None

test_td = serial_env.reset()
test_td = serial_env.rand_step(test_td)
test_batch = test_td.repeat(BATCH_SIZE).contiguous().to(DEVICE)  # Create batch
loss_dict = loss_fn(test_batch)

print(f"Loss keys: {loss_dict.keys()}")
for key, value in loss_dict.items():
    if isinstance(value, Tensor):
        print(f"{key}: {value.item():.4f}")

# Test loss back-propagation
optimizer.zero_grad()
loss = loss_dict['loss_actor'] + loss_dict['loss_alpha'] + loss_dict['loss_qvalue']
loss.backward()
optimizer.step()
updater.step()

# Gradients are still populated here (zero_grad ran before backward), so these
# norms describe the update that was just applied.
print("Gradient check for Actor (after optimizer step):")
total_norm = _grad_l2_norm(actor.parameters())
print(f"  Actor grad norm: {total_norm:.6f}")
print("-" * 30)
check_params_changed(actor, "Actor", actor_params_before)

if qvalue_params_before is not None:
    print("Gradient check for Q-value network (after optimizer step):")
    total_norm = _grad_l2_norm(qvalue_network.parameters())
    print(f"  Q-Value Network grad norm: {total_norm:.6f}")
    if qvalue_params_before:
        check_params_changed(qvalue_network, "Q-Value Network", qvalue_params_before)

print("\n" + "=" * 50)
print("Testing Loss Backpropagation Completed")
print("=" * 50)

# Test collector (collect a small batch)
print("\n" + "=" * 50)
print("Testing Collector")
print("=" * 50)

collector_iter = iter(collector)
batch = next(collector_iter)
print(f"Collected batch keys: {batch.keys()}")
print(f"Batch size: {batch.batch_size}")
print(f"Number of frames: {len(batch)}")

print("\n" + "=" * 50)
print("All components working! ✓")
print("=" * 50)

# Test if actor selects invalid actions
print("=" * 50)
print("Testing Actor Action Validity")
print("=" * 50)

# Run multiple episodes and check for invalid actions
n_test_episodes = 10
invalid_actions_count = 0
total_steps = 0

for episode in range(n_test_episodes):
    test_td = serial_env.reset()
    done = False
    step_count = 0

    while not done and step_count < 100:  # Max 100 steps per episode
        # Get action from actor
        with torch.no_grad():
            test_td = actor(test_td)

        # Extract action and action_mask
        action = test_td['action'].item()
        action_mask = test_td['action_mask']

        # Check if action is valid
        is_valid = action_mask[..., action].item()

        if not is_valid:
            invalid_actions_count += 1
            print(f"⚠️ Episode {episode}, Step {step_count}: Invalid action {action}")
            # The last index tensor of nonzero() is the action dimension, which
            # is correct for both unbatched (n,) and batched (1, n) masks.
            print(f"   Action mask: {action_mask.nonzero(as_tuple=True)[-1].tolist()}")

        # Take step in environment
        try:
            test_td = serial_env.step(test_td)
        except Exception as e:
            print(f"❌ Error at episode {episode}, step {step_count}: {e}")
            break
        total_steps += 1
        step_count += 1
        # The result of this move lives under "next"; the root "done" still
        # describes the state before the move.
        done = test_td["next", "done"].item()
        # Advance to the new state so the actor sees the updated board.
        test_td = step_mdp(test_td)

print(f"\n{'='*50}")
print("Results:")
print(f"Total episodes: {n_test_episodes}")
print(f"Total steps: {total_steps}")
print(f"Invalid actions: {invalid_actions_count}")
if total_steps:
    print(f"Invalid action rate: {invalid_actions_count/total_steps*100:.2f}%")
else:
    print("Invalid action rate: N/A (no steps completed)")

if invalid_actions_count == 0:
    print("✅ No invalid actions detected!")
else:
    print("⚠️ Invalid actions detected! Check MaskedCategorical configuration.")
print("=" * 50)