# **Agent Training with Supervised Learning**

## Objectives

* Parse hand history data
* Train DQN agent with parsed data
* Save pre-trained model

## Inputs

* outputs/datasets/collection/0.05 - 0.1 - 6max.txt

## Outputs

* Pre-trained model: outputs/models/sl/pretrained_hero_agent.pt

---

# Change working directory

Since the Jupyter notebooks are in a subfolder, we need to change the working directory from its current folder to its parent folder.
* We access the current directory with os.getcwd()

In [None]:
import os
# Remember the directory the notebook was started from
current_dir = os.getcwd()
current_dir  # the last expression is displayed as the cell output

We want to make the parent of the current directory the new current directory
* os.path.dirname() gets the parent directory
* os.chdir() sets the new current directory

In [None]:
# The notebooks live in a subfolder, so move up to its parent folder.
parent_dir = os.path.dirname(current_dir)
os.chdir(parent_dir)
print("You set a new current directory")

Confirm the new current directory

In [None]:
# Re-read the working directory to confirm the change above took effect
current_dir = os.getcwd()
current_dir  # the last expression is displayed as the cell output

# Load Data

In [None]:
import glob

input_folder = "outputs/datasets/collection"
# glob returns files in arbitrary (filesystem-dependent) order; sort so the
# first entry is deterministic across machines.
txt_files = sorted(glob.glob(f"{input_folder}/*.txt"))
# Display the first hand history file, or a notice instead of raising IndexError
txt_files[0] if txt_files else f"No .txt files found in {input_folder}"

---

# Custom Class to Parse Data

In [None]:
import re
from typing import List, Dict, Tuple

class PokerHandHistoryParser:
    """Parser for poker hand history from a poker website.

    Hands are split on their ``Poker Hand #RC...:`` headers, and only hands in
    which ``hero_name`` appears are kept. Every dollar amount read from the
    history is multiplied by ``CHIP_SCALE``.
    """

    # Multiplier applied to every dollar amount in the history. At the
    # $0.05/$0.10 stakes of the default file it makes one big blind == 1.0.
    CHIP_SCALE = 10

    # Hand history file read by parse() when no path is given.
    DEFAULT_FILE_PATH = 'outputs/datasets/collection/0.05 - 0.1 - 6max.txt'

    def __init__(self, hero_name: str = 'Hero'):
        """
        Initialize the parser.

        Args:
            hero_name (str, optional): Screen name of the player whose hands and
                results are extracted. Defaults to 'Hero'.
        """
        self.hero_name = hero_name
        self.hands = []  # Parsed hand dicts, filled by parse()

    def parse(self, file_path: str = DEFAULT_FILE_PATH):
        """Parse a hand history file into structured data, focusing on hero's gameplay.

        Args:
            file_path: UTF-8 text file of hand histories. Defaults to
                ``DEFAULT_FILE_PATH``.

        Returns:
            The list of parsed hands (also stored in ``self.hands``), or an
            empty list if the file cannot be read.
        """
        # Start from scratch so repeated calls do not duplicate hands.
        self.hands = []
        try:
            with open(file_path, 'r', encoding="utf-8") as file:
                content = file.read()
        except (OSError, UnicodeDecodeError) as e:
            print(f"Error parsing hand history: {e}")
            return []

        for hand_text in self._split_into_hands(content):
            # Only parse hands where the hero is playing
            if self.hero_name not in hand_text:
                continue
            parsed_hand = self._parse_hand(hand_text)
            if parsed_hand:
                self.hands.append(parsed_hand)

        print(f"Successfully parsed {len(self.hands)} hands from history file where {self.hero_name} was playing")
        return self.hands

    def _scale(self, value: str) -> float:
        """Convert a dollar amount string from the history into scaled chips."""
        return float(value) * self.CHIP_SCALE

    def _split_into_hands(self, content: str) -> List[str]:
        """Split the entire file content into individual hands."""
        hand_pattern = r"Poker Hand #RC\d+: .*?(?=Poker Hand #RC\d+:|$)"
        return re.findall(hand_pattern, content, re.DOTALL)

    def _parse_hand(self, hand_text: str) -> Dict:
        """
        Parse a single hand's text, focusing on hero's perspective.

        This needs to be customized for your specific poker site's format.

        Returns:
            The parsed hand dict, or None if parsing raised an error.
        """
        hand_data = {
            'hand_id': None,
            'table_name': None,
            'max_players': None,
            'button_seat': None,
            'game_type': None,
            'stakes': None,
            'date_time': None,
            'hero_position': None,
            'hero_seat': None,
            'hero_stack': None,
            'hero_cards': [],  # Hero's hole cards
            'players': {},  # Player name -> seat, stack, position (no hole cards)
            'blinds': {'sb': 0, 'bb': 0},
            'players_actions': [],
            'community_cards': {
                'flop': [],
                'turn': None,
                'river': None
            },
            'showdown': False,  # Whether hand went to showdown
            'shown_cards': {},  # Cards shown at showdown (if any)
            'hero_result': {
                'won': False,
                'amount': 0,
                'net_win': 0  # Amount won minus amount invested
            }
        }

        try:
            # Hand ID, stakes and timestamp from the header line
            header_match = re.search(r"Poker Hand #(\w+):.*?(\$\d+\.\d+/\$\d+\.\d+).*?- (\d{4}/\d{2}/\d{2} \d{2}:\d{2}:\d{2})", hand_text)
            if header_match:
                hand_data['hand_id'], hand_data['stakes'], hand_data['date_time'] = header_match.groups()

            # Table name, table size and button seat
            table_match = re.search(r"Table '([\w\d]+)'.*?(\d+)-max Seat #(\d+) is the button", hand_text)
            if table_match:
                hand_data['table_name'] = table_match.group(1)
                hand_data['max_players'] = int(table_match.group(2))
                hand_data['button_seat'] = int(table_match.group(3))

            # Seated players and their starting stacks
            player_pattern = r"Seat (\d+): ([^\(]+) \(\$?([\d.]+) in chips\)"
            for seat, name, stack in re.findall(player_pattern, hand_text):
                name = name.strip()
                hand_data['players'][name] = {
                    'seat': int(seat),
                    'stack': self._scale(stack),
                    'position': None  # Filled in by _assign_positions
                }
                if name == self.hero_name:
                    hand_data['hero_seat'] = int(seat)
                    hand_data['hero_stack'] = self._scale(stack)

            self._assign_positions(hand_data)

            # Hero's hole cards
            hero_cards_match = re.search(r"Dealt to " + re.escape(self.hero_name) + r" \[([^\]]+)\]", hand_text)
            if hero_cards_match:
                hand_data['hero_cards'] = hero_cards_match.group(1).split()

            # Blind sizes, taken from whoever posted them
            sb_match = re.search(r"\w+: posts small blind \$(\d+\.\d+)", hand_text)
            if sb_match:
                hand_data['blinds']['sb'] = self._scale(sb_match.group(1))
            bb_match = re.search(r"\w+: posts big blind \$(\d+\.\d+)", hand_text)
            if bb_match:
                hand_data['blinds']['bb'] = self._scale(bb_match.group(1))

            self._parse_actions_by_street(hand_text, hand_data)
            self._parse_community_cards(hand_text, hand_data)
            # Uses players_actions, so it must run after _parse_actions_by_street
            self._parse_showdown_and_results(hand_text, hand_data)

            return hand_data

        except Exception as e:
            print(f"Error parsing individual hand: {e}")
            return None

    def _assign_positions(self, hand_data: Dict):
        """Assign positions in increasing seat order, starting at the button seat."""
        players = hand_data['players']
        num_players = len(players)
        if num_players < 2:
            return

        seats = sorted((info['seat'], name) for name, info in players.items())
        btn_idx = next(
            (i for i, (seat, _) in enumerate(seats) if seat == hand_data['button_seat']),
            -1,
        )
        if btn_idx == -1:
            return  # Button seat unknown or not occupied

        positions = ['BTN', 'SB', 'BB', 'UTG', 'MP', 'CO']
        for i in range(num_players):
            _, name = seats[(btn_idx + i) % num_players]
            # Tables larger than 6 get extra MP-style labels
            pos = positions[i] if i < len(positions) else f"MP{i-3}"
            players[name]['position'] = pos
            if name == self.hero_name:
                hand_data['hero_position'] = pos

    def _parse_actions_by_street(self, hand_text: str, hand_data: Dict):
        """Parse actions for each betting round, in street order."""
        preflop_match = re.search(r"\*\*\* HOLE CARDS \*\*\*(.*?)(?:\*\*\* FLOP \*\*\*|\*\*\* SUMMARY \*\*\*)", hand_text, re.DOTALL)
        if preflop_match:
            self._parse_street_actions(preflop_match.group(1), 'preflop', hand_data)

        flop_match = re.search(r"\*\*\* FLOP \*\*\* \[([^\]]+)\](.*?)(?:\*\*\* TURN \*\*\*|\*\*\* SUMMARY \*\*\*)", hand_text, re.DOTALL)
        if flop_match:
            self._parse_street_actions(flop_match.group(2), 'flop', hand_data)

        turn_match = re.search(r"\*\*\* TURN \*\*\* \[[^\]]+\] \[([^\]]+)\](.*?)(?:\*\*\* RIVER \*\*\*|\*\*\* SUMMARY \*\*\*)", hand_text, re.DOTALL)
        if turn_match:
            self._parse_street_actions(turn_match.group(2), 'turn', hand_data)

        river_match = re.search(r"\*\*\* RIVER \*\*\* \[[^\]]+\] \[([^\]]+)\](.*?)(?:\*\*\* SHOWDOWN \*\*\*|\*\*\* SUMMARY \*\*\*)", hand_text, re.DOTALL)
        if river_match:
            self._parse_street_actions(river_match.group(2), 'river', hand_data)

    def _parse_street_actions(self, street_text: str, street: str, hand_data: Dict):
        """Append every action found in one betting round's text to ``players_actions``."""
        # Matches e.g. "PlayerName: raises $X to $Y", "PlayerName: calls $X", "PlayerName: folds"
        action_pattern = r"(\w+): (folds|calls|bets|raises|checks)(?: \$?([\d.]+)(?: to \$?([\d.]+))?)?"
        for action_match in re.finditer(action_pattern, street_text):
            player, action_type, amount, raised_to = action_match.groups()
            hand_data['players_actions'].append({
                'street': street,
                'player': player.strip(),
                'action': action_type,
                'amount': self._scale(amount) if amount else None,
                'raised_to': self._scale(raised_to) if raised_to else None
            })

    def _parse_community_cards(self, hand_text: str, hand_data: Dict):
        """Parse community cards for flop, turn, and river."""
        flop_match = re.search(r"\*\*\* FLOP \*\*\* \[([^\]]+)\]", hand_text)
        if flop_match:
            hand_data['community_cards']['flop'] = flop_match.group(1).strip().split()

        turn_match = re.search(r"\*\*\* TURN \*\*\* \[[^\]]+\] \[([^\]]+)\]", hand_text)
        if turn_match:
            hand_data['community_cards']['turn'] = turn_match.group(1).strip()

        river_match = re.search(r"\*\*\* RIVER \*\*\* \[[^\]]+\] \[([^\]]+)\]", hand_text)
        if river_match:
            hand_data['community_cards']['river'] = river_match.group(1).strip()

    def _parse_shown_cards(self, hand_text: str, hand_data: Dict):
        """Mark showdown hands and record every player's shown hole cards."""
        if "*** SHOWDOWN ***" in hand_text:
            hand_data['showdown'] = True
            shown_cards_pattern = r"(\w+): shows \[([^\]]+)\]"
            for player, cards in re.findall(shown_cards_pattern, hand_text):
                hand_data['shown_cards'][player.strip()] = cards.strip().split()

    def _hero_total_invested(self, hand_text: str, hand_data: Dict) -> float:
        """Return the total (scaled) amount the hero put into the pot.

        Chips are tracked per street: blinds posted by the hero are preflop
        commitments, calls and bets add to the street total, and a raise
        "to $Y" sets the street total to Y (it already includes earlier chips).
        """
        hero = re.escape(self.hero_name)
        blinds_posted = sum(
            self._scale(amount)
            for amount in re.findall(hero + r": posts (?:small|big) blind \$(\d+\.\d+)", hand_text)
        )

        committed = {'preflop': blinds_posted}
        for action in hand_data['players_actions']:
            if action['player'] != self.hero_name:
                continue
            street = action['street']
            so_far = committed.get(street, 0.0)
            if action['action'] == 'raises' and action['raised_to'] is not None:
                committed[street] = action['raised_to']
            elif action['action'] in ('calls', 'bets', 'raises') and action['amount'] is not None:
                committed[street] = so_far + action['amount']
        return sum(committed.values())

    def _parse_showdown_and_results(self, hand_text: str, hand_data: Dict):
        """Fill in showdown info and the hero's won amount and net result.

        ``amount`` is what came back to the hero (returned uncalled bets plus
        pots collected). ``net_win`` is that minus everything the hero invested.
        ``won`` is True only when the hero collected a pot.
        """
        self._parse_shown_cards(hand_text, hand_data)

        hero = re.escape(self.hero_name)
        total_bet = self._hero_total_invested(hand_text, hand_data)

        uncalled_bet_match = re.search(r"Uncalled bet \(\$(\d+\.\d+)\) returned to " + hero + r"(?=\s|$)", hand_text)
        uncalled_bet = self._scale(uncalled_bet_match.group(1)) if uncalled_bet_match else 0.0

        # A hand can have several pots (side pots); sum every collection.
        collected = re.findall(hero + r" collected \$(\d+\.\d+) from (?:main |side )?pot", hand_text)
        collected_pot = sum(self._scale(amount) for amount in collected)

        total_won = round(uncalled_bet + collected_pot, 2)
        net_result = round(total_won - total_bet, 2)

        hand_data['hero_result']['won'] = collected_pot > 0
        hand_data['hero_result']['amount'] = total_won
        hand_data['hero_result']['net_win'] = net_result

* Initialize the custom hand history parser, parse the hand history, and print an example hand

In [None]:
import json
parser = PokerHandHistoryParser()
parser.parse()
# Show one parsed hand as an example: index 11 when available, otherwise the
# last parsed hand (avoids IndexError on small histories).
if parser.hands:
    example_hand = parser.hands[min(11, len(parser.hands) - 1)]
    print(f"Parsed Hand Example:\n{json.dumps(example_hand, indent=4)}")
else:
    print("No hands were parsed")

# Custom Class for Converting Hand Data To Game State

In [None]:
class HandDataConverter:
    """
    Converts parsed hand history data into state-action pairs for training the hero agent.

    Each hand is replayed action by action on a reconstructed game state.
    Immediately before every hero action, the state is turned into a feature
    vector with ``feature_extractor``.
    """

    def __init__(self, feature_extractor, hero_name: str = 'Hero'):
        """
        Initialize the converter.

        Args:
            feature_extractor: Callable ``(game_state, player_name)`` returning
                the state vector for that player
            hero_name: Hero's name in the hand history
        """
        self.feature_extractor = feature_extractor
        self.hero_name = hero_name

    def convert_hand_to_training_data(self, hand_data: Dict) -> List[Tuple]:
        """
        Convert a single hand's data into a list of (state, action, reward) tuples for the hero.

        Only the hero's last decision carries a reward. That reward is the
        hand's ``net_win`` divided by the big blind, or the raw ``net_win`` if
        no big blind was recorded. All earlier decisions get 0.0.

        Args:
            hand_data: Parsed hand data (the dict layout built by PokerHandHistoryParser)

        Returns:
            List of (state, action_idx, reward) tuples for supervised learning;
            an empty list if the hand cannot be converted.
        """
        training_data = []

        try:
            game_state = self._initialize_game_state(hand_data)

            for action_data in hand_data['players_actions']:
                street = action_data.get('street')
                # Enter the new street BEFORE capturing state, so a hero who
                # acts first on a street sees that street's board and reset bets.
                if street and street != game_state['betting_round']:
                    self._start_street(game_state, street)
                    self._update_community_cards(game_state, street, hand_data)

                if action_data['player'] == self.hero_name:
                    state_vector = self.feature_extractor(game_state, self.hero_name)
                    action_idx = self._action_to_index(
                        action_data['action'],
                        action_data.get('amount'),
                        action_data.get('raised_to'),
                        game_state,
                    )
                    # Reward is filled in for the last hero action below
                    training_data.append((state_vector, action_idx, 0.0))

                self._update_game_state(game_state, action_data)

            # Only hero actions were appended, so the last entry is the hero's last decision
            if training_data and hand_data['hero_result']:
                state, action, _ = training_data[-1]
                net_result = hand_data['hero_result']['net_win']
                big_blind = hand_data['blinds']['bb']
                # Normalize by the big blind; fall back to the raw result instead
                # of dividing by zero when no big blind was parsed.
                reward = net_result / big_blind if big_blind else net_result
                training_data[-1] = (state, action, reward)

            return training_data

        except Exception as e:
            # Best effort per hand: a malformed hand is skipped, not fatal
            print(f"Error converting hand to training data: {e}")
            return []

    def _initialize_game_state(self, hand_data: Dict) -> Dict:
        """Initialize a game state dictionary from hand data, with blinds already posted."""
        game_state = {
            'players': {},
            'pot': 0.0,
            'current_bet': 0.0,
            'community_cards': [],
            'betting_round': 'preflop',
            'hero_name': self.hero_name,
            'hero_cards': hand_data['hero_cards']
        }

        for name, info in hand_data['players'].items():
            game_state['players'][name] = {
                'stack': info['stack'],
                'position': info['position'],
                'current_bet': 0.0,
                'total_bet': 0.0,
                'is_active': True,
                'has_folded': False,
                'is_all_in': False,
                # Only include hero's hole cards
                'hole_cards': hand_data['hero_cards'] if name == self.hero_name else []
            }

        # Apply blinds to whoever sits in the SB / BB positions
        for name, info in hand_data['players'].items():
            if info['position'] == 'SB':
                game_state['players'][name]['current_bet'] = hand_data['blinds']['sb']
                game_state['players'][name]['total_bet'] = hand_data['blinds']['sb']
                game_state['players'][name]['stack'] -= hand_data['blinds']['sb']
                game_state['pot'] += hand_data['blinds']['sb']
            elif info['position'] == 'BB':
                game_state['players'][name]['current_bet'] = hand_data['blinds']['bb']
                game_state['players'][name]['total_bet'] = hand_data['blinds']['bb']
                game_state['players'][name]['stack'] -= hand_data['blinds']['bb']
                game_state['pot'] += hand_data['blinds']['bb']
                game_state['current_bet'] = hand_data['blinds']['bb']

        return game_state

    def _action_to_index(self, action_type: str, amount: float, raised_to: float, game_state: Dict) -> int:
        """
        Convert a poker action to an action index for the DQNAgent.

        0 fold, 1 check, 2 call, 3-6 bet sized against the pot, 7-9 raise sized
        against the pot. Unknown actions map to 1 (check).
        """
        # This is a simplified example - adapt to your actual action space
        if action_type == 'folds':
            return 0  # FOLD
        elif action_type == 'checks':
            return 1  # CHECK
        elif action_type == 'calls':
            return 2  # CALL
        elif action_type == 'bets':
            # Determine which bet sizing category this falls into
            pot_size = game_state['pot']
            if amount is None:
                return 3  # Default minimum bet

            if amount <= pot_size * 0.33:
                return 3  # BET 1/3 pot
            elif amount <= pot_size * 0.67:
                return 4  # BET 2/3 pot
            elif amount <= pot_size:
                return 5  # BET 1x pot
            else:
                return 6  # BET 2x+ pot
        elif action_type == 'raises':
            # Determine raise sizing category
            pot_size = game_state['pot']
            current_bet = game_state['current_bet']

            if raised_to is not None:
                raise_size = raised_to - current_bet
                if raise_size <= pot_size * 0.67:
                    return 7  # RAISE to 2/3 pot
                elif raise_size <= pot_size:
                    return 8  # RAISE to 1x pot
                else:
                    return 9  # RAISE to 2x+ pot
            else:
                return 7  # Default minimum raise

        # Default
        return 1  # CHECK

    def _start_street(self, game_state: Dict, street: str):
        """Enter a new betting round: reset per-street bets and reopen action."""
        game_state['betting_round'] = street
        game_state['current_bet'] = 0.0
        for p in game_state['players'].values():
            p['current_bet'] = 0.0
            # Everyone still in the hand (and not all-in) must act again
            p['is_active'] = not p['has_folded'] and not p['is_all_in']

    @staticmethod
    def _reopen_action(game_state: Dict, aggressor: str):
        """After a bet or raise, every other live player must act again."""
        for name, p in game_state['players'].items():
            if name == aggressor:
                p['is_active'] = False
            elif not p['has_folded'] and not p['is_all_in']:
                p['is_active'] = True

    @staticmethod
    def _commit_chips(game_state: Dict, player_state: Dict, chips: float):
        """Move ``chips`` from a player's stack into the pot."""
        player_state['stack'] -= chips
        player_state['total_bet'] += chips
        game_state['pot'] += chips
        # A player with nothing left behind is treated as all-in
        if player_state['stack'] <= 1e-9:
            player_state['is_all_in'] = True

    def _update_game_state(self, game_state: Dict, action_data: Dict):
        """Update the game state based on an action (``action_data`` is not modified)."""
        player = action_data['player']
        action = action_data['action']
        amount = action_data.get('amount')
        raised_to = action_data.get('raised_to')

        street = action_data.get('street')
        if street and street != game_state['betting_round']:
            self._start_street(game_state, street)

        player_state = game_state['players'].get(player)
        if player_state is None:
            return  # Action by a player not listed in the seats

        if action == 'folds':
            player_state['has_folded'] = True
            player_state['is_active'] = False

        elif action == 'checks':
            player_state['is_active'] = False

        elif action == 'calls':
            player_state['is_active'] = False
            owed = game_state['current_bet'] - player_state['current_bet']
            # Prefer the amount printed in the history (correct for all-in
            # calls for less); fall back to the computed amount owed.
            call_amount = amount if amount is not None else owed
            if call_amount > 0:
                self._commit_chips(game_state, player_state, call_amount)
                player_state['current_bet'] += call_amount

        elif action == 'bets':
            self._reopen_action(game_state, player)
            if amount is not None:
                self._commit_chips(game_state, player_state, amount)
                player_state['current_bet'] = amount
                game_state['current_bet'] = amount

        elif action == 'raises':
            self._reopen_action(game_state, player)
            if raised_to is not None:
                # "raises to Y": only the difference from chips already in is new
                additional = raised_to - player_state['current_bet']
                self._commit_chips(game_state, player_state, additional)
                player_state['current_bet'] = raised_to
                game_state['current_bet'] = raised_to

    def _update_community_cards(self, game_state: Dict, street: str, hand_data: Dict):
        """Set the visible board for ``street``: flop cards, plus turn, plus river."""
        board = hand_data['community_cards']
        cards = []
        if street in ('flop', 'turn', 'river'):
            cards.extend(board['flop'])
        if street in ('turn', 'river'):
            cards.append(board['turn'])
        if street == 'river':
            cards.append(board['river'])
        # Skip cards the history did not record instead of exposing None
        game_state['community_cards'] = [c for c in cards if c is not None]

* Add feature extractor function

In [None]:
import numpy as np

def custom_feature_extractor(game_state: Dict, player_name: str) -> np.ndarray:
    """
    Build a flat feature vector for ``player_name`` from ``game_state``.

    Layout (13 float32 values): stack/100, is_active, has_folded, is_all_in,
    current_bet/100, board cards/5, pot/200, current_bet/100, then a 5-way
    one-hot betting round (preflop, flop, turn, river, None). Unknown rounds
    fall into the preflop slot.

    This should be similar to your RL environment's state representation.
    NOTE: a later cell redefines this function, and that later version is the one used.
    """
    me = game_state['players'].get(player_name, {})
    table_bet = game_state["current_bet"] / 100.0

    features = [
        me["stack"] / 100.0,
        1.0 if me["is_active"] else 0.0,
        1.0 if me["has_folded"] else 0.0,
        1.0 if me["is_all_in"] else 0.0,
        table_bet,
        len(game_state["community_cards"]) / 5.0,
        game_state["pot"] / 200.0,
        table_bet,
    ]

    rounds = ("preflop", "flop", "turn", "river", None)
    current_round = game_state['betting_round']
    hot = rounds.index(current_round) if current_round in rounds else 0
    features.extend(1.0 if slot == hot else 0.0 for slot in range(len(rounds)))

    return np.array(features, dtype=np.float32)

In [None]:
import numpy as np

def custom_feature_extractor(game_state: Dict, player_name: str) -> np.ndarray:
    """
    Extract features from the game state for a specific player.

    This should be similar to your RL environment's state representation.
    Returns 16 float32 values:
    0 stack/100, 1 is_active, 2 has_folded, 3 is_all_in, 4 current_bet/100,
    5 board cards/5, 6 pot/200, 7 current_bet/100,
    8-11 one-hot betting round (preflop, flop, turn, river; unknown -> preflop),
    12 position index/5, 13 pot-to-stack (capped at 3, scaled to [0, 1]),
    14 bet-to-pot (capped at 2, scaled to [0, 1]),
    15 players still in the hand (not folded) / 6.
    """
    player_info = game_state['players'].get(player_name, {})

    state = []

    # Player stack and status flags
    state.append(player_info["stack"] / 100.0)
    state.append(1.0 if player_info["is_active"] else 0.0)
    state.append(1.0 if player_info["has_folded"] else 0.0)
    state.append(1.0 if player_info["is_all_in"] else 0.0)
    state.append(game_state["current_bet"] / 100.0)

    # Community cards
    state.append(len(game_state["community_cards"]) / 5.0)

    # Pot and current bet (current bet intentionally repeated to keep the 16-value layout)
    state.append(game_state["pot"] / 200.0)
    state.append(game_state["current_bet"] / 100.0)

    # Betting round, one-hot
    round_map = {"preflop": 0, "flop": 1, "turn": 2, "river": 3}
    betting_round = round_map.get(game_state['betting_round'], 0)
    for i in range(4):
        state.append(1.0 if betting_round == i else 0.0)

    # Player position
    position_map = {'BTN': 0, 'SB': 1, 'BB': 2, 'UTG': 3, 'MP': 4, 'CO': 5}
    state.append(position_map.get(player_info.get('position'), 0) / 5.0)

    # Pot size relative to player stack
    pot_to_stack = game_state['pot'] / max(1.0, player_info.get('stack', 1.0))
    state.append(min(3.0, pot_to_stack) / 3.0)  # Cap at 3x stack and normalize

    # Current bet relative to pot
    if game_state['pot'] > 0:
        bet_to_pot = game_state['current_bet'] / game_state['pot']
        bet_to_pot = min(2.0, bet_to_pot) / 2.0  # Cap at 2x pot and normalize
    else:
        bet_to_pot = 0.0
    state.append(bet_to_pot)

    # Players still in the hand: everyone who has not folded (all-in players
    # included). Previously this counted only all-in players.
    players_in_hand = sum(1 for p in game_state['players'].values()
                          if not p.get('has_folded', False))
    state.append(players_in_hand / 6.0)  # Normalize by max players

    return np.array(state, dtype=np.float32)

## Pre-train DQN agent

In [None]:
import random
import torch
import torch.nn as nn
from rl_model.agent import DQNAgent
import matplotlib.pyplot as plt

def pretrain_hero_agent_from_history(agent, num_epochs: int = 5):
    """
    Pretrain a DQN agent from poker hand history using hero-centric supervised learning.

    Each hero decision is one regression example. The network's Q-value for
    the action the hero took is pulled towards the recorded reward. The other
    Q-values are used as their own (detached) targets, so they contribute no error.

    Args:
        agent: The DQNAgent to pretrain. It is used through ``model``,
            ``optimizer`` and ``update_target_model()``.
        num_epochs: Number of passes over the training examples

    Returns:
        The same agent after training, or None if there was nothing to train on.
    """
    # Parse hand history
    parser = PokerHandHistoryParser()
    hands = parser.parse()

    if not hands:
        print("No hands parsed, cannot pretrain")
        return None

    # Convert to training data
    converter = HandDataConverter(custom_feature_extractor)
    training_data = []
    for hand in hands:
        training_data.extend(converter.convert_hand_to_training_data(hand))

    print(f"Generated {len(training_data)} training examples from {len(hands)} hands")

    # Guard the per-epoch average below against division by zero
    if not training_data:
        print("No hero decisions found, cannot pretrain")
        return None

    # Supervised regression of the taken action's Q-value towards the reward
    criterion = nn.MSELoss()
    loss_values = []

    # Switch to training mode once instead of on every sample
    agent.model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        random.shuffle(training_data)

        for state, action_idx, reward in training_data:
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            q_values = agent.model(state_tensor)

            # Target equals the prediction except for the taken action
            target = q_values.clone().detach()
            target[0, action_idx] = reward

            agent.optimizer.zero_grad()
            loss = criterion(q_values, target)
            loss.backward()
            agent.optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(training_data)
        loss_values.append(avg_loss)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.6f}")

    # Update target network to match the trained network
    agent.update_target_model()

    print("Pretraining complete")

    # Plot the loss values
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, num_epochs + 1), loss_values, marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss Throughout Epochs')
    plt.grid(True)
    plt.show()

    return agent


# Build the agent. state_size is the length of the vector produced by
# custom_feature_extractor (16 values); action_size covers the action
# indices 0-9 produced by HandDataConverter._action_to_index.
state_size, action_size = 16, 10
agent = DQNAgent(state_size, action_size, player_id=0)

# Pretrain from the hand history with 10 passes over the hero's decisions
pretrained_agent = pretrain_hero_agent_from_history(agent=agent, num_epochs=10)

---

# Push files to Repo

In [None]:
import os

output_folder = "outputs/models/sl"
# Create the model folder (and any missing parents); an existing folder is
# not an error, so re-running the cell is safe. Other OS errors still surface.
os.makedirs(name=output_folder, exist_ok=True)

In [None]:
# Save pretrained model. pretrain_hero_agent_from_history can return None
# (nothing to train on), which would otherwise raise AttributeError here.
# NOTE(review): meaning of player_total depends on DQNAgent.save — confirm.
if pretrained_agent is not None:
    pretrained_agent.save("outputs/models/sl/pretrained_hero_agent.pt", player_total=500)
else:
    print("No pretrained agent to save")