In [1]:
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from typing import Optional, Tuple, Dict, Any


class YahtzeeEnv(gym.Env):
    """
    A custom Yahtzee environment that follows the OpenAI Gym interface.
    Modified to work better with standard RL algorithms by using discrete spaces
    and providing a flattened action space.
    """
    
    def __init__(self):
        super(YahtzeeEnv, self).__init__()
        
        # Define scoring categories and their maximum scores
        self.categories = {
            'ones': 5,         # Max score: 5 (1×5)
            'twos': 10,        # Max score: 10 (2×5)
            'threes': 15,      # Max score: 15 (3×5)
            'fours': 20,       # Max score: 20 (4×5)
            'fives': 25,       # Max score: 25 (5×5)
            'sixes': 30,       # Max score: 30 (6×5)
            'three_of_a_kind': 30,
            'four_of_a_kind': 30,
            'full_house': 25,
            'small_straight': 30,
            'large_straight': 40,
            'yahtzee': 50,
            'chance': 30
        }
        
        # Flatten action space into a single discrete space
        # Actions 0-12: Choose scoring category
        # Actions 13-44: Reroll combinations (2^5 = 32 possible reroll combinations)
        self.action_space = spaces.Discrete(45)
        
        # Define observation space using MultiDiscrete
        self.observation_space = spaces.MultiDiscrete([
            6, 6, 6, 6, 6,  # dice values (0-5 representing 1-6)
            2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,  # available categories (binary) (6-18) (13 categories)
            3  # remaining rolls (0,1,2) (19)
        ])
        
        # Initialize game state
        self.reset()
    
    def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[np.ndarray, Dict]:
        """Reset the environment to initial state."""
        super().reset(seed=seed)
        if seed is not None:
            np.random.seed(seed)
        
        self.dice = self._roll_dice()
        self.remaining_rolls = 2
        self.available_categories = np.ones(13, dtype=np.int8)
        self.scores = np.zeros(13, dtype=np.int32)
        self.total_score = 0
        self.is_bonus = False
        
        return self._get_observation(), {}
    
    def _roll_dice(self, reroll_mask: Optional[np.ndarray] = None) -> np.ndarray:
        """Roll the dice according to reroll mask."""
        if reroll_mask is None:
            return np.random.randint(0, 6, size=5)  # 0-5 representing 1-6
        
        new_dice = self.dice.copy()
        for i, reroll in enumerate(reroll_mask):
            if reroll:
                new_dice[i] = np.random.randint(0, 6)
        return new_dice
    
    def _decode_action(self, action: int) -> Tuple[int, np.ndarray]:
        """Convert flat action space to category and reroll mask."""
        if action < 13:  # Scoring actions
            return action, np.zeros(5, dtype=np.int8)
        else:  # Reroll actions
            reroll_idx = action - 13
            return -1, np.array([int(x) for x in format(reroll_idx, '05b')])
    
    def _check_bonus(self) -> int:
        """Check if bonus is earned and return bonus score."""
        upper_section_score = np.sum(self.scores[:6])
        if not self.is_bonus and upper_section_score >= 63:
            self.is_bonus = True
            return 35
        return 0
    
    def _calculate_score(self, category_idx: int, dice: np.ndarray) -> int:
        """Calculate score for given category and dice combination."""
        dice = dice + 1  # Convert from 0-5 to 1-6
        dice_counts = np.bincount(dice, minlength=7)
        category_name = list(self.categories.keys())[category_idx]
        
        if category_name in ['ones', 'twos', 'threes', 'fours', 'fives', 'sixes']:
            number = category_idx + 1
            return (number * dice_counts[number]) + self._check_bonus()
        
        elif category_name == 'three_of_a_kind':
            if np.any(dice_counts >= 3):
                return np.sum(dice)
            return 0
        
        elif category_name == 'four_of_a_kind':
            if np.any(dice_counts >= 4):
                return np.sum(dice)
            return 0
        
        elif category_name == 'full_house':
            if np.any(dice_counts == 3) and np.any(dice_counts == 2):
                return 25
            return 0
        
        elif category_name == 'small_straight':
            for straight in [(1,2,3,4), (2,3,4,5), (3,4,5,6)]:
                if all(dice_counts[s] >= 1 for s in straight):
                    return 30
            return 0
        
        elif category_name == 'large_straight':
            if (all(dice_counts[1:7] == 1) or all(dice_counts[2:8] == 1)) or \
               (all(dice_counts[1:6] == 1) or all(dice_counts[2:7] == 1)):
                return 40
            return 0
        
        elif category_name == 'yahtzee':
            if np.any(dice_counts == 5):
                return 50
            return 0
        
        elif category_name == 'chance':
            return np.sum(dice)
        
        return 0
    
    def step(self, action: int) -> Tuple[np.ndarray, float, bool, bool, Dict]:
        """Take a step in the environment using the given action."""
        category, reroll = self._decode_action(action)
        info = {}
            
        # Handle reroll action
        if category == -1:
            if self.remaining_rolls > 0:
                self.dice = self._roll_dice(reroll)
                self.remaining_rolls -= 1
                reward = 0  # Neutral reward for rerolling
                return self._get_observation(), reward, False, False, info
                
            else:
                reward = -50  # Penalty for invalid reroll
                return self._get_observation(), reward, False, False, info
        
        # Handle scoring action
        if not self.available_categories[category]:
            return self._get_observation(), -50, False, False, {'error': 'Category already used'}
        
        # Calculate score and update state
        score = self._calculate_score(category, self.dice)
        self.scores[category] = score
        self.available_categories[category] = 0
        self.total_score += score
        
        # Reset dice and rolls for next turn
        self.dice = self._roll_dice()
        self.remaining_rolls = 2
        
        # Check if game is done
        done = np.sum(self.available_categories) == 0
        
        # Calculate reward (use the score as the reward)
        reward = score
        
        return self._get_observation(), reward, done, False, {
            'total_score': self.total_score,
            'scores': self.scores.copy()
        }
    
    def _get_observation(self) -> np.ndarray:
        """Return current observation of the environment."""
        return np.concatenate([
            self.dice,  # 5 dice values (0-5)
            self.available_categories,  # 13 binary values
            [self.remaining_rolls]  # 1 value (0-2)
        ])
    
    def render(self, mode='human'):
        """Render the current state of the game."""
        if mode == 'human':
            print("\nCurrent Dice:", self.dice + 1)  # Convert back to 1-6 for display
            print("Remaining Rolls:", self.remaining_rolls)
            print("\nAvailable Categories:")
            for i, (category, available) in enumerate(zip(self.categories.keys(), self.available_categories)):
                if available:
                    possible_score = self._calculate_score(i, self.dice)
                    print(f"{category}: {possible_score} points possible")
            print("\nScored Categories:")
            for i, (category, score) in enumerate(zip(self.categories.keys(), self.scores)):
                if not self.available_categories[i]:
                    print(f"{category}: {score}")
            print("\nTotal Score:", self.total_score)
            if self.is_bonus:
                print("Bonus achieved: +35 points")
    
    def get_legal_actions(self) -> np.ndarray:
        """
        Return an indicator (binary) vector of legal actions.
        
        - For scoring actions (0-12): legal if that category is still available.
        - For reroll actions (13-44): legal if remaining_rolls > 0 and at least one die is rerolled.
        """
        legal = np.zeros(self.action_space.n, dtype=np.int8)
        # Scoring actions: allowed only if the category hasn't been used yet.
        for i in range(13):
            if self.available_categories[i]:
                legal[i] = 1
        # Reroll actions: allowed only if there are remaining rolls.
        if self.remaining_rolls > 0:
            for action in range(13, 45):
                # _, reroll = self._decode_action(action)
                # if np.any(reroll):  # Must reroll at least one die.
                legal[action] = 1
        return legal

In [2]:
class HardCodedStrategy:
    def __init__(self):
        self.category_scores = {
            'ones': {'type': 'upper', 'number': 1},
            'twos': {'type': 'upper', 'number': 2},
            'threes': {'type': 'upper', 'number': 3},
            'fours': {'type': 'upper', 'number': 4},
            'fives': {'type': 'upper', 'number': 5},
            'sixes': {'type': 'upper', 'number': 6},
            'three_of_a_kind': {'type': 'three_kind'},
            'four_of_a_kind': {'type': 'four_kind'},
            'full_house': {'type': 'full_house'},
            'small_straight': {'type': 'small_straight'},
            'large_straight': {'type': 'large_straight'},
            'yahtzee': {'type': 'yahtzee'},
            'chance': {'type': 'chance'},
        }
    
    def calculate_reroll_strategy(self, dice, target_category, rolls_remaining):
        """
        Optimized reroll strategy based on target category.
        Expects dice as list of ints in 0-5 (will be converted to 1-6).
        Returns a list of 5 binary values (1 = reroll, 0 = keep).
        """
        if rolls_remaining == 0:
            return [0, 0, 0, 0, 0]  # No rerolls left
        
        # Convert dice from 0-5 to 1-6
        dice = [d + 1 for d in dice]
        dice_counter = Counter(dice)
        
        category_info = self.category_scores[target_category]
        category_type = category_info['type']
        
        # Default strategy: reroll all dice
        reroll = [1, 1, 1, 1, 1]
        
        # Upper section (ones through sixes)
        if category_type == 'upper':
            target_value = category_info['number']
            
            # Keep all dice of target value
            for i, value in enumerate(dice):
                if value == target_value:
                    reroll[i] = 0

        # Three of a Kind
        elif category_type == 'three_kind':
            most_common = dice_counter.most_common(2)
            
            # Already have three or more of a kind
            if most_common and most_common[0][1] >= 3:
                value_to_keep = most_common[0][0]
                # Keep the three of a kind
                for i, value in enumerate(dice):
                    if value == value_to_keep:
                        reroll[i] = 0
                        
                # With remaining dice, keep high values if last roll
                if rolls_remaining == 1:
                    for i, value in enumerate(dice):
                        if value != value_to_keep and value >= 5:
                            reroll[i] = 0
            
            # Have a pair
            elif most_common and most_common[0][1] == 2:
                value_to_keep = most_common[0][0]
                
                # If multiple pairs, keep the higher pair
                if len(most_common) > 1 and most_common[1][1] == 2:
                    if most_common[0][0] < most_common[1][0]:
                        value_to_keep = most_common[1][0]
                
                # Keep the pair
                for i, value in enumerate(dice):
                    if value == value_to_keep:
                        reroll[i] = 0
                
                # If it's the last roll, also keep high values
                if rolls_remaining == 1:
                    for i, value in enumerate(dice):
                        if reroll[i] == 1 and value >= 5:
                            reroll[i] = 0
            
            # No pairs yet, but last roll - keep highest value
            elif rolls_remaining == 1:
                highest_value = max(dice) if dice else 6
                for i, value in enumerate(dice):
                    if value == highest_value:
                        reroll[i] = 0
                        break

        # Four of a Kind
        elif category_type == 'four_kind':
            most_common = dice_counter.most_common(1)
            
            # Already have four or more of a kind
            if most_common and most_common[0][1] >= 4:
                value_to_keep = most_common[0][0]
                # Keep the four of a kind
                kept = 0
                for i, value in enumerate(dice):
                    if value == value_to_keep and kept < 4:
                        reroll[i] = 0
                        kept += 1
                        
                # With remaining dice, keep high values if last roll
                if rolls_remaining == 1:
                    for i, value in enumerate(dice):
                        if reroll[i] == 1 and value >= 5:
                            reroll[i] = 0
            
            # Have three of a kind
            elif most_common and most_common[0][1] == 3:
                value_to_keep = most_common[0][0]
                # Keep the three of a kind
                for i, value in enumerate(dice):
                    if value == value_to_keep:
                        reroll[i] = 0
            
            # Have a pair and more rolls remaining
            elif most_common and most_common[0][1] == 2:
                # With multiple rolls, keep highest pair
                high_pair = 0
                for val, count in dice_counter.items():
                    if count == 2 and val > high_pair:
                        high_pair = val
                
                if high_pair > 0:
                    for i, value in enumerate(dice):
                        if value == high_pair:
                            reroll[i] = 0
            
            # Last roll and no good combos - keep highest value dice
            elif rolls_remaining == 1:
                sorted_dice = sorted(enumerate(dice), key=lambda x: x[1], reverse=True)
                for idx, _ in sorted_dice[:1]:  # Keep the highest die
                    reroll[idx] = 0

        # Full House
        elif category_type == 'full_house':
            # Already have a full house
            if len(dice_counter) == 2 and 2 in dice_counter.values() and 3 in dice_counter.values():
                reroll = [0, 0, 0, 0, 0]  # Keep all
            else:
                counts = dice_counter.most_common(2)
                
                # Have three of a kind and a different pair
                if len(counts) == 2 and counts[0][1] >= 3 and counts[1][1] >= 2:
                    reroll = [0, 0, 0, 0, 0]  # Keep all
                
                # Have three of a kind - keep it and try for a pair
                elif len(counts) >= 1 and counts[0][1] >= 3:
                    three_kind_value = counts[0][0]
                    
                    # Keep the three of a kind
                    for i, value in enumerate(dice):
                        if value == three_kind_value:
                            reroll[i] = 0
                    
                    # If we also have a single of a different value and last roll, keep it
                    if rolls_remaining == 1 and len(counts) > 1:
                        other_value = counts[1][0]
                        kept = 0
                        for i, value in enumerate(dice):
                            if value == other_value and reroll[i] == 1 and kept < 2:
                                reroll[i] = 0
                                kept += 1
                
                # Have two pairs - keep both pairs
                elif len(counts) >= 2 and counts[0][1] == 2 and counts[1][1] == 2:
                    for i, value in enumerate(dice):
                        if value == counts[0][0] or value == counts[1][0]:
                            reroll[i] = 0
                
                # Have one pair - keep it
                elif len(counts) >= 1 and counts[0][1] == 2:
                    pair_value = counts[0][0]
                    
                    # Keep the pair
                    for i, value in enumerate(dice):
                        if value == pair_value:
                            reroll[i] = 0
                    
                    # If last roll and we have a single of another value, keep it too
                    if rolls_remaining == 1 and len(counts) > 1:
                        other_values = [val for val, _ in counts[1:]]
                        highest_other = max(other_values)
                        kept = 0
                        for i, value in enumerate(dice):
                            if value == highest_other and reroll[i] == 1 and kept < 1:
                                reroll[i] = 0
                                kept += 1

        # Small Straight
        elif category_type == 'small_straight':
            values_set = set(dice)
            
            # Check how close we are to each large straight
            low_straight = [1, 2, 3, 4]
            mid_straight = [2, 3, 4, 5]
            high_straight = [3, 4, 5, 6]
            
            low_matches = []
            mid_matches = []
            high_matches = []
            for value in values_set:
                if value in low_straight:
                    low_matches.append(value)
                if value in high_straight:
                    high_matches.append(value)
                if value in mid_straight:
                    mid_matches.append(value)
                    
            low_matches = list(set(low_matches))
            high_matches = list(set(high_matches))
            mid_matches = list(set(mid_matches))
            
            maxm = max(len(low_matches), len(mid_matches), len(high_matches))
            maxm_list = []
            if len(low_matches) == maxm:
                maxm_list = low_matches
            elif len(mid_matches) == maxm:
                maxm_list = mid_matches
            elif len(high_matches) == maxm: 
                maxm_list = high_matches
                                
            for die in dice:
                if die in maxm_list:
                    reroll[dice.index(die)] = 0
                    maxm_list.remove(die)  # Remove to avoid duplicates

        # Large Straight
        elif category_type == 'large_straight':
            values_set = set(dice)
            
            # Already have a large straight
            if (all(v in values_set for v in [1, 2, 3, 4, 5]) or 
                all(v in values_set for v in [2, 3, 4, 5, 6])):
                reroll = [0, 0, 0, 0, 0]  # Keep all
            else:
                # Check how close we are to each large straight
                low_straight = [1, 2, 3, 4, 5]
                high_straight = [2, 3, 4, 5, 6]
                
                low_matches = []
                high_matches = []
                for value in values_set:
                    if value in low_straight:
                        low_matches.append(value)
                    if value in high_straight:
                        high_matches.append(value)
                low_matches = list(set(low_matches))
                high_matches = list(set(high_matches))
                
                if len(low_matches) >= len(high_matches):
                    for die in dice:
                        if die in low_matches:
                            reroll[dice.index(die)] = 0
                            low_matches.remove(die)  # Remove to avoid duplicates
                else:
                    for die in dice:
                        if die in high_matches:
                            reroll[dice.index(die)] = 0
                            high_matches.remove(die)  # Remove to avoid duplicates

        # Yahtzee
        elif category_type == 'yahtzee':
            
            most_common = dice_counter.most_common(1)
            for i, value in enumerate(dice):
                if most_common and most_common[0][0] == value:
                    reroll[i] = 0
        

        # Chance - keep high values
        elif category_type == 'chance':
            # Always keep 6s
            for i, value in enumerate(dice):
                if value == 6:
                    reroll[i] = 0
            
            # Keep 5s
            for i, value in enumerate(dice):
                if value == 5 and reroll[i] == 1:
                    reroll[i] = 0
            
            # On last roll, also keep 4s
            if rolls_remaining == 1:
                for i, value in enumerate(dice):
                    if value == 4 and reroll[i] == 1:
                        reroll[i] = 0
            
            # If we're keeping too few dice and it's the last roll, keep 3s too
            if rolls_remaining == 1 and sum(1 for r in reroll if r == 0) <= 2:
                for i, value in enumerate(dice):
                    if value == 3 and reroll[i] == 1:
                        reroll[i] = 0
                        
        return reroll

  and should_run_async(code)


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
import time
from collections import deque, Counter

import torch.nn as nn
import torch.nn.functional as F

import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super(ResidualBlock, self).__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.linear2 = nn.Linear(dim, dim)
        self.activation = nn.ReLU()

        nn.init.kaiming_normal_(self.linear1.weight)
        nn.init.kaiming_normal_(self.linear2.weight)

    def forward(self, x):
        identity = x
        out = self.activation(self.linear1(x))
        out = self.linear2(out)
        out += identity  # Residual connection
        return self.activation(out)

class TargetIntuitionNet(nn.Module):
    def __init__(self, input_dim, output_dim=13):
        super(TargetIntuitionNet, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.res_block = ResidualBlock(128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, output_dim)

        # Initialize weights
        nn.init.kaiming_normal_(self.fc1.weight)
        nn.init.kaiming_normal_(self.fc2.weight)
        nn.init.kaiming_normal_(self.fc3.weight)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.res_block(x)
        x = F.relu(self.fc2(x))
        return self.fc3(x)


class YahtzeeAgent:
    def __init__(self, env, device):
        self.env = env
        self.device = device
        
        # Calculate input dimension from the environment
        obs_dim = env.observation_space.shape[0]
        # Enhanced features as in the original implementation (obs + 22 derived features)
        self.enhanced_dim = obs_dim + 22
        self.reroll_input_size = 5 * 6 + 13 + 1  # dice one-hot (30) + category one-hot (13) + rolls_remaining (1)
        
        # Target intuition network (replaces high_level_net)
        # This outputs the expected value of targeting each category
        self.target_intuition_net = TargetIntuitionNet(self.enhanced_dim, 13).to(self.device)
        # self.reroll_net = RerollNet(self.reroll_input_size).to(self.device)

        self.reroller = HardCodedStrategy()

        # reroll_path = r"/kaggle/input/reroll_pretained/pytorch/default/1/reroll_net_pretrained.pth"
        # self.reroll_net.load_state_dict(torch.load(reroll_path, map_location=self.device))
        # self.reroll_net.to(self.device)
        
        # Optimizer for the target intuition network
        self.optimizer = optim.Adam(self.target_intuition_net.parameters(), lr=0.00005)
        
        # Hyperparameters
        self.gamma = 1.0  # Full credit for future rewards (no discount)
        self.epsilon = 0.5  # Higher starting epsilon for more exploration
        self.epsilon_min = 1.2
        self.epsilon_decay = 0.9999
        self.batch_size = 256  # Larger batch size for more stable learning
        self.buffer = deque(maxlen=200000)  # Larger buffer for more diverse experiences
        
        # Category names for reference
        self.categories = [
            'ones', 'twos', 'threes', 'fours', 'fives', 'sixes',
            'three_of_a_kind', 'four_of_a_kind', 'full_house',
            'small_straight', 'large_straight', 'yahtzee', 'chance'
        ]
        
        # Category score descriptions for the reroll strategy
        self.category_scores = {
            0: {'type': 'upper', 'number': 1},  # Ones
            1: {'type': 'upper', 'number': 2},  # Twos
            2: {'type': 'upper', 'number': 3},  # Threes
            3: {'type': 'upper', 'number': 4},  # Fours
            4: {'type': 'upper', 'number': 5},  # Fives
            5: {'type': 'upper', 'number': 6},  # Sixes
            6: {'type': 'three_kind', 'score': 'sum'},  # Three of a Kind
            7: {'type': 'four_kind', 'score': 'sum'},   # Four of a Kind
            8: {'type': 'full_house', 'score': 25},     # Full House
            9: {'type': 'small_straight', 'score': 30},  # Small Straight
            10: {'type': 'large_straight', 'score': 40}, # Large Straight
            11: {'type': 'yahtzee', 'score': 50},        # Yahtzee
            12: {'type': 'chance', 'score': 'sum'}       # Chance
        }
        
        # Statistics tracking
        self.stats = {
            'upper_bonus_achieved': 0,
            'total_games': 0,
            'category_usage': {cat: 0 for cat in self.categories},
            'scores': [],
            'target_switches': 0,  # Track how often the target category changes
            'final_category_matches_target': 0  # Track if final category matches initial target
        }
        
        # Timing information
        self.times = {
            'select_action': [],
            'optimize_model': []
        }

    def enhance_observation(self, observation):
        """Add derived features to the observation to help the agent learn better"""
        dice = observation[:5] + 1  # Convert 0-5 to 1-6
        categories_available = observation[5:18]
        rolls_remaining = observation[18]
        
        # Dice value counts
        dice_counts = np.zeros(6)
        for i in range(5):
            if 1 <= dice[i] <= 6:
                dice_counts[int(dice[i])-1] += 1
        
        # Key statistics about dice
        has_three_kind = int(any(count >= 3 for count in dice_counts))
        has_four_kind = int(any(count >= 4 for count in dice_counts))
        has_yahtzee = int(any(count == 5 for count in dice_counts))
        has_pair = int(any(count >= 2 for count in dice_counts))
        
        # Upper section scoring potential
        upper_potentials = np.zeros(6)
        for i in range(6):
            if categories_available[i] == 1:  # Category is available
                upper_potentials[i] = (i+1) * dice_counts[i]
        
        # Upper section bonus tracking (need 63+ for bonus)
        upper_filled = sum(1 for i in range(6) if categories_available[i] == 0)
        remaining_turns = sum(categories_available)
        
        # Calculate straight potential
        unique_values = sum(1 for count in dice_counts if count > 0)
        small_straight_potential = 1.0 if unique_values >= 4 else (unique_values / 4.0)
        large_straight_potential = 1.0 if unique_values >= 5 else (unique_values / 5.0)
        
        # Full house potential
        has_three = any(count == 3 for count in dice_counts)
        has_two = any(count == 2 for count in dice_counts)
        full_house_potential = 1.0 if (has_three and has_two) else 0.5 if has_three or has_two else 0.0
        
        # Game progress (normalized)
        game_progress = (13 - remaining_turns) / 13.0
        
        # Upper section bonus situation
        upper_score = 0
        for i in range(6):
            if categories_available[i] == 0:  # Category already filled
                # Try to extract the score from the environment if available
                if hasattr(self.env, 'scorecard'):
                    upper_score += self.env.scorecard.get(self.categories[i], 0)
        
        upper_bonus_threshold = 63
        upper_bonus_progress = min(1.0, upper_score / upper_bonus_threshold)
        
        # Estimated potential to reach upper bonus
        remaining_upper_potential = 0
        for i in range(6):
            if categories_available[i] == 1:
                # Use average expected value for each category
                remaining_upper_potential += min((i+1) * 3, (i+1) * 5)  # Conservative estimate
        
        upper_bonus_potential = min(1.0, (upper_score + remaining_upper_potential) / upper_bonus_threshold)
        
        # Combine original observation with derived features
        enhanced = np.concatenate([
            observation,
            dice_counts,
            upper_potentials,
            [has_three_kind, has_four_kind, has_yahtzee, has_pair],
            [small_straight_potential, large_straight_potential, full_house_potential],
            [game_progress, upper_bonus_progress, upper_bonus_potential]
        ])
        
        return enhanced.astype(np.float32)

    def select_target_category(self, observation, is_eval=False):
        """Select a target category based on the current observation"""
        enhanced_obs = self.enhance_observation(observation)
        state = torch.from_numpy(enhanced_obs).float().to(self.device)
        
        # Get available categories
        categories_available = observation[5:18]
        legal_categories = [i for i in range(13) if categories_available[i] == 1]
        
        if not legal_categories:
            return None  # No legal categories left
        
        # Use epsilon-greedy for exploration during training
        if not is_eval and random.random() < self.epsilon:
            return random.choice(legal_categories)
        
        with torch.no_grad():
            # Get Q-values for all categories
            q_values = self.target_intuition_net(state.unsqueeze(0)).squeeze(0)
            
            # Mask unavailable categories with large negative values
            mask = torch.ones(13, device=self.device) * -1000000
            for i in legal_categories:
                mask[i] = 0
            masked_q = q_values + mask
            
            # Select category with highest Q-value
            return torch.argmax(masked_q).item()

    def calculate_reroll_strategy(self, dice, target_category, rolls_remaining):
        
        return self.reroller.calculate_reroll_strategy(dice, self.categories[target_category], rolls_remaining)

    def select_action(self, observation, is_eval=False, target_category=None):
        """
        Select an action based on the current observation
        Returns: (is_category_selection, action)
        """
        start = time.time()
        
        dice = observation[:5]
        categories_available = observation[5:18]
        rolls_remaining = observation[18]
        
        # If no target provided, compute one
        if target_category is None:
            target_category = self.select_target_category(observation, is_eval)
        
        # When no rerolls are left or no categories are available for targeting,
        # we must select a category to play
        if rolls_remaining == 0 or all(categories_available[i] == 0 for i in range(13)):
            is_category_selection = True
            action = self.select_target_category(observation, is_eval)
        else:
            # Otherwise, calculate reroll pattern
            is_category_selection = False
            decisions = self.calculate_reroll_strategy(dice, target_category, rolls_remaining)
            action = int(''.join(map(str, decisions)), 2) + 13
        
        # Verify action is legal
        legal_actions = self.env.get_legal_actions()
        if action >= len(legal_actions) or legal_actions[action] != 1:
            # Fallback to a legal action
            legal_indices = [i for i, is_legal in enumerate(legal_actions) if is_legal == 1]
            if legal_indices:
                if action < 13:  # Category selection
                    is_category_selection = True
                    action = random.choice([i for i in legal_indices if i < 13])
                else:  # Reroll action
                    is_category_selection = False
                    action = random.choice([i for i in legal_indices if i >= 13])
        
        end = time.time()
        self.times['select_action'].append(end - start)
        return is_category_selection, action, target_category

    def optimize_model(self):
        """Train the target intuition network from experiences"""
        start = time.time()
        
        if len(self.buffer) < self.batch_size:
            return
        
        # Sample batch from the replay buffer
        samples = random.sample(self.buffer, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*samples)
        
        # Convert to tensors
        states = torch.tensor(np.array(states), dtype=torch.float32).to(self.device)
        actions = torch.tensor(np.array(actions), dtype=torch.long).to(self.device)
        rewards = torch.tensor(np.array(rewards), dtype=torch.float32).to(self.device)
        next_states = torch.tensor(np.array(next_states), dtype=torch.float32).to(self.device)
        dones = torch.tensor(np.array(dones), dtype=torch.float32).to(self.device)
        
        # Filter for category selection actions (0-12) only
        category_mask = actions < 13
        if category_mask.any():
            category_states = states[category_mask]
            category_actions = actions[category_mask]
            category_rewards = rewards[category_mask]
            category_next_states = next_states[category_mask]
            category_dones = dones[category_mask]
            
            # Calculate current Q-values
            current_q = self.target_intuition_net(category_states).gather(1, category_actions.unsqueeze(1)).squeeze(1)
            
            # Calculate target Q-values (Double DQN approach)
            with torch.no_grad():
                next_q_values = self.target_intuition_net(category_next_states)
                next_actions = next_q_values.max(1)[1].unsqueeze(1)
                next_q = next_q_values.gather(1, next_actions).squeeze(1)
                target_q = category_rewards + (1 - category_dones) * self.gamma * next_q
            
            # Calculate loss and optimize
            loss = F.smooth_l1_loss(current_q, target_q)
            self.optimizer.zero_grad()
            loss.backward()
            # Clip gradients to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(self.target_intuition_net.parameters(), 1.0)
            self.optimizer.step()
            
            end = time.time()
            self.times['optimize_model'].append(end - start)
            return loss.item()
        
        end = time.time()
        self.times['optimize_model'].append(end - start)
        return None

    def shape_reward(self, reward, action, observation, next_observation, done, info):
        """Apply reward shaping to encourage better strategic play"""
        shaped_reward = reward
        dice = observation[:5] + 1  # Convert 0-5 to 1-6
        categories_available = observation[5:18]
        
        # Category selection (actions 0-12)
        if action < 13:
            # Track the chosen category
            category = self.categories[action]
            if reward == 0:
                remaining_categories = sum(categories_available)
                if remaining_categories <= 3:
                    shaped_reward = -2  # Less penalty for strategic zeros late game
                else:
                    shaped_reward = -5  # Standard penalty
            
            # Upper section scoring
            if action < 6:
                # Calculate current upper section total
                upper_total = 0
                if hasattr(self.env, 'scorecard'):
                    for i in range(6):
                        if i != action and not categories_available[i]:  # Category already filled
                            upper_total += self.env.scorecard.get(self.categories[i], 0)
                
                # Add current category score
                upper_with_current = upper_total + reward

                if reward >= (action + 1) * 3:
                    shaped_reward += 15  # Reward good upper section scores
            
            # Reward efficient use of categories
            if category == 'yahtzee' and reward >= 50:
                shaped_reward += 30  # Extra bonus for Yahtzee
            elif category in ['small_straight', 'large_straight'] and reward > 0:
                shaped_reward += 25  # Bonus for straights
            elif category == 'full_house' and reward > 0:
                shaped_reward += 25  # Bonus for full house
            elif category in ['three_of_a_kind', 'four_of_a_kind'] and reward >= (3 if category == 'three_of_a_kind' else 4) * 4:
                shaped_reward += 10  # Bonus for good three/four of a kind
        
        return shaped_reward

    def decay_epsilon(self):
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
        
    def evaluate(self, num_episodes=10):
        """Evaluate the agent's performance without exploration"""
        total_scores = []
        upper_bonus_count = 0
        category_usage = {cat: 0 for cat in self.categories}
        
        for _ in range(num_episodes):
            observation, _ = self.env.reset()
            done = False
            episode_reward = 0
            
            # Initial target category selection
            target_category = self.select_target_category(observation, is_eval=True)
            
            while not done:
                # Select action based on target category
                is_category_selection, action, target_category = self.select_action(
                    observation, is_eval=True, target_category=target_category
                )
                
                # Track category usage
                if is_category_selection:
                    category_usage[self.categories[action]] += 1
                
                # Take action
                observation, reward, terminated, truncated, info = self.env.step(action)
                episode_reward += reward
                done = terminated or truncated
                
                # Recalculate target after each step if there are rerolls left
                if not done and observation[18] > 0:  # rolls_remaining > 0
                    target_category = self.select_target_category(observation, is_eval=True)
            
            total_scores.append(episode_reward)
            
            # Check if upper bonus was achieved
            if hasattr(self.env, 'scorecard') and 'upper_bonus' in self.env.scorecard:
                if self.env.scorecard['upper_bonus'] > 0:
                    upper_bonus_count += 1
        
        # Update statistics
        self.stats['scores'].extend(total_scores)
        self.stats['upper_bonus_achieved'] += upper_bonus_count
        self.stats['total_games'] += num_episodes
        
        # Merge category usage
        for cat, count in category_usage.items():
            self.stats['category_usage'][cat] += count
            
        return np.mean(total_scores)
class YahtzeeTrainerWithCurriculum:
    def __init__(self, env, device):
        self.env = env
        self.device = device
        self.agent = YahtzeeAgent(self.env, self.device)
        
        # Curriculum learning parameters
        self.curriculum_stage = 0
        self.curriculum_stages = [
            {"name": "Upper Section Focus", "episodes": 2500},
            {"name": "Basic Combinations", "episodes": 5000},
            {"name": "Full Game", "episodes": 9000},
        ]
        
        # Override agent's reward shaping based on curriculum
        self.original_shape_reward = self.agent.shape_reward
        self.agent.shape_reward = self.curriculum_shape_reward
        
        # Original select_target method
        self.original_select_target = self.agent.select_target_category
        
        # Override target selection based on curriculum
        self.agent.select_target_category = self.curriculum_select_target
    
    def get_curriculum_stage(self, episode):
        """Determine the current curriculum stage based on episode number"""
        completed_episodes = 0
        for i, stage in enumerate(self.curriculum_stages):
            completed_episodes += stage["episodes"]
            if episode < completed_episodes or i == len(self.curriculum_stages) - 1:
                return i
        return len(self.curriculum_stages) - 1
    
    def curriculum_select_target(self, observation, is_eval=False):
        """Improved target selection using Q-values within curriculum constraints"""
        # Use original method for evaluation
        if is_eval:
            return self.original_select_target(observation, is_eval=True)
            
        categories_available = observation[5:18]
        legal_categories = [i for i in range(13) if categories_available[i] == 1]
        
        if not legal_categories:
            return None
    
        # Prepare state
        enhanced_obs = self.agent.enhance_observation(observation)
        state = torch.from_numpy(enhanced_obs).float().to(self.agent.device)
    
        with torch.no_grad():
            q_values = self.agent.target_intuition_net(state.unsqueeze(0)).squeeze(0)
    
        # Stage 0: Focus on upper section
        if self.curriculum_stage == 0:
            upper_categories = [i for i in range(6) if i in legal_categories]
            if upper_categories and random.random() < 0.8:
                # Choose best among upper section
                best_upper = max(upper_categories, key=lambda i: q_values[i].item())
                return best_upper
    
        # Stage 1: Basic combinations
        elif self.curriculum_stage == 1:
            upper_categories = [i for i in range(6) if i in legal_categories]
            if upper_categories and random.random() < 0.5:
                best_upper = max(upper_categories, key=lambda i: q_values[i].item())
                return best_upper
            
            basic_lower = [i for i in [6, 7, 12] if i in legal_categories]
            if basic_lower and random.random() < 0.3:
                best_basic = max(basic_lower, key=lambda i: q_values[i].item())
                return best_basic
    
        # Stage 2 or fallback: Use full legal category set
        epsilon_scale = 1.0
        if random.random() < (self.agent.epsilon * epsilon_scale):
            return random.choice(legal_categories)
        
        # Mask unavailable categories for final selection
        mask = torch.ones(13, device=self.agent.device) * -1000000
        for i in legal_categories:
            mask[i] = 0
        masked_q = q_values + mask
        return torch.argmax(masked_q).item()

    
    def curriculum_shape_reward(self, reward, action, observation, next_observation, done, info):
        
        return self.original_shape_reward(reward, action, observation, next_observation, done, info)
        
    def train(self, num_episodes=50000, eval_freq=100):
        """
        Train the agent with curriculum learning over a specified number of episodes
        """
        rewards = []
        eval_scores = []
        losses = []
        curriculum_transitions = []
        
        # Performance tracking
        best_eval_score = 0
        best_model_state = None
        
        for episode in range(num_episodes):

            if episode == 17500:
                for param_group in self.agent.optimizer.param_groups:
                    param_group['lr'] = 0.00001

            if episode == 30000:
                for param_group in self.agent.optimizer.param_groups:
                    param_group['lr'] = 0.000005
                        
            # Check if we need to update curriculum stage
            curr_stage = self.get_curriculum_stage(episode)
            if curr_stage != self.curriculum_stage:
                # Transition to new curriculum stage
                old_stage = self.curriculum_stage
                self.curriculum_stage = curr_stage
                curriculum_transitions.append((episode, 
                                             self.curriculum_stages[old_stage]["name"],
                                             self.curriculum_stages[curr_stage]["name"]))
                print(f"\nAdvancing curriculum: {self.curriculum_stages[old_stage]['name']} -> {self.curriculum_stages[curr_stage]['name']} at episode {episode}\n")
            
            observation, _ = self.agent.env.reset()
            total_reward = 0
            episode_losses = []
            done = False
            
            # Enhanced observation for the buffer
            enhanced_obs = self.agent.enhance_observation(observation)
            
            # Initial target category selection using curriculum method
            target_category = self.agent.select_target_category(observation)
            initial_target = target_category
            
            # Track curriculum statistics
            curriculum_stats = {
                "stage": self.curriculum_stage,
                "upper_section_used": 0,
                "lower_section_used": 0
            }
            
            # Safety limit to prevent infinite loops
            step_count = 0
            max_steps = 100
            
            while not done and step_count < max_steps:
                # Select action based on the current target category
                is_category_selection, action, target_category = self.agent.select_action(
                    observation, target_category=target_category
                )
                
                # Take the action
                new_observation, reward, terminated, truncated, info = self.agent.env.step(action)
                done = terminated or truncated
                
                # Track curriculum-specific statistics
                if is_category_selection:
                    if action < 6:  # Upper section
                        curriculum_stats["upper_section_used"] += 1
                    else:  # Lower section
                        curriculum_stats["lower_section_used"] += 1
                
                # Enhanced observation for the buffer
                enhanced_new_obs = self.agent.enhance_observation(new_observation)
                
                # Apply curriculum-specific reward shaping
                shaped_reward = self.agent.shape_reward(reward, action, observation, new_observation, done, info)
                total_reward += reward  # Track original reward for reporting
                
                # If this was a category selection, store it in replay buffer
                if is_category_selection:
                    # Store transition in replay buffer
                    self.agent.buffer.append((enhanced_obs, action, shaped_reward, enhanced_new_obs, done))
                    
                    # Update target category statistics
                    if target_category == action:
                        self.agent.stats['final_category_matches_target'] += 1
                    
                    # Train the network
                    loss = self.agent.optimize_model()
                    if loss is not None:
                        episode_losses.append(loss)
                elif not done:
                    # After reroll, recalculate target category based on new observation
                    new_target = self.agent.select_target_category(new_observation)
                    
                    # Track target changes
                    if new_target != target_category:
                        self.agent.stats['target_switches'] += 1
                    
                    target_category = new_target
                
                # Update for next step
                observation = new_observation
                enhanced_obs = enhanced_new_obs
                step_count += 1
            
            decay_factor = 0.99995
            self.agent.epsilon = max(self.agent.epsilon_min, self.agent.epsilon * decay_factor)
            
            # Record episode reward
            rewards.append(total_reward)
            
            # Record mean loss if available
            if episode_losses:
                losses.append(np.mean(episode_losses))
            
            # Evaluate periodically
            if (episode + 1) % eval_freq == 0:
                eval_score = self.agent.evaluate(num_episodes=50)
                eval_scores.append(eval_score)
                
                # Print progress with curriculum info
                print(f"Episode {episode+1}/{num_episodes}, Stage: {self.curriculum_stages[self.curriculum_stage]['name']}")
                print(f"Avg. Reward: {np.mean(rewards[-eval_freq:]):.1f}, Eval Score: {eval_score:.1f}, Epsilon: {self.agent.epsilon:.3f}")
                
                # Save best model
                if eval_score > best_eval_score:
                    best_eval_score = eval_score
                    best_model_state = {
                        'target_intuition_net': self.agent.target_intuition_net.state_dict(),
                        'optimizer': self.agent.optimizer.state_dict(),
                        'epsilon': self.agent.epsilon,
                        'curriculum_stage': self.curriculum_stage
                    }

                print()
      
        # Save the best model
        if best_model_state:
            torch.save(best_model_state, "best_yahtzee_curriculum_model.pt")

        final_model_state = {
            'target_intuition_net': self.agent.target_intuition_net.state_dict(),
            'optimizer': self.agent.optimizer.state_dict(),
            'epsilon': self.agent.epsilon,
            'curriculum_stage': self.curriculum_stage
        }
        torch.save(final_model_state, "final_curriculum_model.pt")
        
        # Return training history with curriculum information
        return {
            'rewards': rewards,
            'eval_scores': eval_scores,
            'losses': losses,
            'stats': self.agent.stats,
            'curriculum_transitions': curriculum_transitions
        }

    def load_model(self, path):
        """Load a trained model from a file (unchanged)"""
        # This method is unchanged
        checkpoint = torch.load(path)
        self.agent.target_intuition_net.load_state_dict(checkpoint['target_intuition_net'])
        self.agent.optimizer.load_state_dict(checkpoint['optimizer'])
        self.agent.epsilon = checkpoint['epsilon']
        print(f"Model loaded from {path}")
        
        # Also load curriculum stage if available
        checkpoint = torch.load(path)
        if 'curriculum_stage' in checkpoint:
            self.curriculum_stage = checkpoint['curriculum_stage']
            print(f"Loaded curriculum stage: {self.curriculum_stages[self.curriculum_stage]['name']}")

In [4]:
env = YahtzeeEnv()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

seed = 50
random.seed(seed)  # Python random module
np.random.seed(seed)  # NumPy
torch.manual_seed(seed)  # PyTorch CPU
torch.cuda.manual_seed(seed)  # PyTorch GPU (if available)
torch.cuda.manual_seed_all(seed)  # Multi-GPU
torch.backends.cudnn.deterministic = True  # Ensures deterministic behavior
torch.backends.cudnn.benchmark = False  # Disables auto-tuning for convolutions

trainer = YahtzeeTrainerWithCurriculum(env, device)
history = trainer.train(eval_freq=100)

  and should_run_async(code)


Episode 100/50000, Stage: Upper Section Focus
Avg. Reward: 74.8, Eval Score: 114.7, Epsilon: 1.200

Episode 200/50000, Stage: Upper Section Focus
Avg. Reward: 81.8, Eval Score: 140.5, Epsilon: 1.200

Episode 300/50000, Stage: Upper Section Focus
Avg. Reward: 82.2, Eval Score: 139.9, Epsilon: 1.200

Episode 400/50000, Stage: Upper Section Focus
Avg. Reward: 89.9, Eval Score: 143.0, Epsilon: 1.200

Episode 500/50000, Stage: Upper Section Focus
Avg. Reward: 87.2, Eval Score: 147.5, Epsilon: 1.200

Episode 600/50000, Stage: Upper Section Focus
Avg. Reward: 86.7, Eval Score: 150.8, Epsilon: 1.200

Episode 700/50000, Stage: Upper Section Focus
Avg. Reward: 89.9, Eval Score: 159.4, Epsilon: 1.200

Episode 800/50000, Stage: Upper Section Focus
Avg. Reward: 87.0, Eval Score: 152.9, Epsilon: 1.200

Episode 900/50000, Stage: Upper Section Focus
Avg. Reward: 82.1, Eval Score: 148.7, Epsilon: 1.200

Episode 1000/50000, Stage: Upper Section Focus
Avg. Reward: 85.0, Eval Score: 147.1, Epsilon: 1.200


In [5]:
torch.save(trainer.agent.target_intuition_net.state_dict(), "/kaggle/working/target_intuition_net.pth")

  and should_run_async(code)
