In [None]:
import torch
from torch import nn
import numpy as np
from long_nardy import LongNardy
from state import State
from typing import Tuple, List

In [None]:
# Select the compute device used by every tensor/model in this notebook.
# `torch.accelerator` only exists in PyTorch >= 2.6; on older releases fall
# back to the classic CUDA check instead of failing with AttributeError.
if hasattr(torch, "accelerator"):
    device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

In [None]:
class ANN(nn.Module):
    """Single-hidden-layer value network with sigmoid activations.

    The network maps a feature vector to one scalar squashed into (0, 1) by
    the final sigmoid.

    Args:
        input_dim: Number of input features. Defaults to 100, the size that
            was previously hard-coded.
        hidden_dim: Width of the hidden layer. Defaults to 80, the size that
            was previously hard-coded.
    """

    def __init__(self, input_dim: int = 100, hidden_dim: int = 80):
        super().__init__()

        # Keep a positional nn.Sequential so parameter names stay
        # "net.0.*" / "net.2.*" and existing checkpoints remain loadable.
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Sigmoid(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return the network output for ``x`` (last dimension = ``input_dim``)."""
        return self.net(x)

In [None]:
class Agent(nn.Module):
    """Self-play TD agent: a value network plus per-parameter eligibility traces.

    Args:
        lr: Step size stored on the agent; the training loop reads it when
            applying the weight update.
        epsilon: Probability of choosing a uniformly random candidate state
            instead of the greedy one.
    """

    def __init__(self, lr=0.1, epsilon=0.1):
        super().__init__()
        self.net = ANN().to(device)
        self.epsilon = epsilon
        self.lr = lr
        # One trace tensor per network parameter, keyed by parameter name.
        # This is a plain dict (not registered buffers), so the traces are
        # not included in state_dict().
        self.eligibility_traces = {name: torch.zeros_like(param)
                                   for name, param in self.net.named_parameters()}

    def get_value(self, state: "State", grad=False):
        """Get V(s) with optional gradient tracking.

        Args:
            state: Object exposing ``get_tensor_representation()``.
            grad: When True the forward pass is recorded by autograd so the
                result can be back-propagated; when False no graph is built.

        Returns:
            The network output for ``state`` as a tensor on ``device``.
        """
        with torch.set_grad_enabled(grad):
            # Build the input directly on the target device (no CPU round trip).
            state_tensor = torch.tensor(state.get_tensor_representation(),
                                        dtype=torch.float32, device=device)
            return self.net(state_tensor)

    def update_eligibility_traces(self):
        """Accumulate the current gradients into the traces (lambda=1, gamma=1).

        Parameters whose ``.grad`` is ``None`` (no backward pass has populated
        them yet) are skipped instead of raising a TypeError.
        """
        with torch.no_grad():
            for name, param in self.net.named_parameters():
                if param.grad is None:
                    continue
                self.eligibility_traces[name].add_(param.grad)

    def reset_eligibility_traces(self):
        """Zero every eligibility trace in place."""
        for name in self.eligibility_traces:
            self.eligibility_traces[name].zero_()

    def epsilon_greedy(self, candidate_states: List, current_player: int) -> Tuple["State", torch.Tensor]:
        """Epsilon-greedy selection considering player perspective.

        Args:
            candidate_states: Non-empty sequence of states to choose from.
            current_player: 0 or 1. For player 1 the network output ``v`` is
                turned into ``1 - v`` before comparing and returning.

        Returns:
            ``(chosen_state, value)`` where ``value`` is computed without
            gradient tracking and is expressed from ``current_player``'s
            perspective in BOTH the exploratory and the greedy branch.

        Raises:
            ValueError: If ``candidate_states`` is empty.
        """
        if len(candidate_states) == 0:
            raise ValueError("epsilon_greedy requires at least one candidate state")

        if np.random.rand() < self.epsilon:
            # Explore: uniformly random candidate.
            chosen_idx = np.random.randint(len(candidate_states))
            value = self.get_value(candidate_states[chosen_idx])
            # Apply the same perspective flip as the greedy branch so callers
            # can treat the returned value uniformly.
            if current_player == 1:
                value = 1 - value
        else:
            # Evaluate all states without gradient tracking
            with torch.no_grad():
                values = [self.get_value(state) for state in candidate_states]

            # Flip values for opponent's perspective
            if current_player == 1:
                values = [1 - v for v in values]

            chosen_idx = int(np.argmax([v.item() for v in values]))
            value = values[chosen_idx]

        return candidate_states[chosen_idx], value

In [None]:
# Single shared agent: the same network evaluates positions for both sides
# during self-play.
agent = Agent(
    epsilon=0.1,
    lr=0.1,
)

In [None]:
# Self-play training loop.
# Every move made by either side triggers one update of the shared network:
#   td_error = reward + V(next) - V(chosen)
#   trace   += grad V(chosen)
#   weights += lr * td_error * trace
# All values used in the update are the raw network output (no perspective
# flip), matching `reward`, which is 1 only when player 0 makes the final move.
num_episodes = 1000000
save_interval = 100
for episode in range(num_episodes):
    game = LongNardy()
    agent.reset_eligibility_traces()
    current_player = 0  # 0 = agent, 1 = opponent (both using same network)
    done = False

    while not done:
        # Get legal moves for current player
        candidate_states = game.get_states_after_dice(current_player)

        if not candidate_states:
            # Handle no valid moves according to game rules
            game.apply_dice(game.state)
            current_player = 1 - current_player
            continue

        # Select move using epsilon-greedy. The value it returns is computed
        # without autograd and may be perspective-flipped, so it is only used
        # for move selection and is discarded here.
        chosen_state, _ = agent.epsilon_greedy(candidate_states, current_player)

        # Re-evaluate the chosen state WITH gradient tracking. Calling
        # backward() on the no-grad tensor from epsilon_greedy raises
        # "element 0 of tensors does not require grad".
        # assumes the network output holds a single element so backward()
        # needs no explicit gradient argument - TODO confirm
        value = agent.get_value(chosen_state, grad=True)

        # Store pre-update value for the TD calculation
        prev_value = value.detach()

        # Make the move
        game.apply_dice(chosen_state)

        # Check terminal state
        if game.is_finished():
            reward = 1 if current_player == 0 else 0  # Reward for original agent
            next_value = torch.tensor(0.0, device=device)
            done = True
        else:
            reward = 0
            with torch.no_grad():
                # NOTE(review): called without a player argument here but with
                # `current_player` at the top of the loop - confirm against
                # LongNardy.get_states_after_dice.
                next_candidates = game.get_states_after_dice()
                if next_candidates:
                    next_values = [agent.get_value(s) for s in next_candidates]
                    # After player 0 moves the opponent replies (minimises the
                    # network value); after player 1 moves player 0 replies
                    # (maximises it).
                    next_value = min(next_values) if current_player == 0 else max(next_values)
                else:
                    # Next side has no candidates: pass and skip this update.
                    # NOTE(review): verify the turn bookkeeping on this path
                    # matches how apply_dice advances the game.
                    game.apply_dice(game.state)
                    current_player = 1 - current_player
                    continue

        # TD error in terms of the raw network output
        td_error = reward + next_value - prev_value

        agent.net.zero_grad()
        value.backward()
        agent.update_eligibility_traces()

        # Update weights for all moves:
        with torch.no_grad():
            for name, param in agent.net.named_parameters():
                param += agent.lr * td_error * agent.eligibility_traces[name]

        # Switch player perspective
        current_player = 1 - current_player

    # Periodic saving and logging
    if episode % save_interval == 0:
        torch.save(agent.state_dict(), f"td_gammon_selfplay_{episode}.pth")
        # NOTE(review): despite the label, this prints the TD error of the
        # episode's final update, not an average.
        print(f"Episode {episode} | Avg TD Error: {td_error.item():.4f}")