In [2]:
import torch
from torch import nn
import numpy as np
from long_nardy import LongNardy

In [None]:
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

In [None]:
class ANN(nn.Module):
    """Deep MLP state-value network mapping a feature vector to one value in (0, 1).

    Layout: ``Linear(input_size, hidden_size) + ReLU``, then ``num_hidden_layers``
    blocks of ``Linear(hidden_size, hidden_size) + ReLU``, then
    ``Linear(hidden_size, 1) + Sigmoid``. The single output unit is what lets
    callers read the value of one state with ``.item()``.

    Args:
        input_size: Length of the state feature vector (default 100).
        hidden_size: Width of every hidden layer (default 40).
        num_hidden_layers: Number of hidden-to-hidden blocks (default 80).
            NOTE(review): 80 plain ReLU layers with no normalisation or residual
            connections is likely hard to train; consider a much shallower stack.
    """

    def __init__(self, input_size=100, hidden_size=40, num_hidden_layers=80):
        # Must be called (not `super.__init__()`) so nn.Module can register submodules.
        super().__init__()

        layers = [nn.Linear(input_size, hidden_size), nn.ReLU()]
        for _ in range(num_hidden_layers):
            layers += [nn.Linear(hidden_size, hidden_size), nn.ReLU()]
        # Project to a single unit *before* the sigmoid: feeding a ReLU straight into
        # the sigmoid would confine outputs to [0.5, 1), and a hidden_size-wide output
        # could not be read with .item().
        layers += [nn.Linear(hidden_size, 1), nn.Sigmoid()]

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return V(x) with shape ``(*x.shape[:-1], 1)``."""
        return self.net(x)

In [None]:
class Agent(nn.Module):
    """TD(lambda)-style value-learning agent wrapping an ``ANN`` state-value network.

    The agent keeps one eligibility trace per network parameter. Callers decay the
    traces with ``update_eligibility_traces``, add the value gradient to them, and
    then apply ``w <- w + lr * td_error * e`` with ``update_weights``.

    Args:
        lr: Step size used by ``update_weights``.
        gamma: Discount factor.
        lambda_: Trace-decay parameter; traces are multiplied by ``gamma * lambda_``.
        epsilon: Exploration probability for ``epsilon_greedy``.
    The defaults are illustrative starting points chosen here; tune as needed.
    """

    def __init__(self, lr=1e-3, gamma=1.0, lambda_=0.7, epsilon=0.1):
        # nn.Module.__init__ must run before any submodule is assigned to self.
        super().__init__()
        self.net = ANN()
        self.gamma = gamma
        self.lambda_ = lambda_
        self.epsilon = epsilon
        self.eligibility_traces = {name: torch.zeros_like(param) for name, param in self.net.named_parameters()}
        # Never stepped: only its learning rate is read, by update_weights.
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=lr)

    def get_value(self, state):
        """Predict V(s) for one state.

        ``state`` may be a tensor or anything ``torch.as_tensor`` accepts (e.g. a
        NumPy array or list); it is converted to the network's dtype and device.
        """
        ref = next(self.net.parameters())
        return self.net(torch.as_tensor(state, dtype=ref.dtype, device=ref.device))

    def update_eligibility_traces(self):
        """Decay every trace by ``gamma * lambda_``.

        Traces are also moved onto their parameter's device, so after
        ``agent.to(device)`` callers can add ``param.grad`` to them directly.
        """
        decay = self.gamma * self.lambda_
        with torch.no_grad():
            for name, param in self.net.named_parameters():
                trace = self.eligibility_traces[name].to(param.device)
                self.eligibility_traces[name] = trace.mul_(decay)

    def reset_eligibility_traces(self):
        """Reset all traces to zero, on the same device/dtype as their parameters."""
        for name, param in self.net.named_parameters():
            self.eligibility_traces[name] = torch.zeros_like(param)

    def epsilon_greedy(self, candidate_states):
        """Select a state from candidates using an epsilon-greedy policy.

        With probability ``epsilon`` a uniformly random candidate is returned;
        otherwise the candidate with the highest predicted value.

        Returns:
            ``(chosen_state, chosen_idx)``.
        Raises:
            ValueError: if ``candidate_states`` is empty.
        """
        if len(candidate_states) == 0:
            raise ValueError("epsilon_greedy needs at least one candidate state")
        if np.random.rand() < self.epsilon:
            # Random exploration
            chosen_idx = np.random.randint(len(candidate_states))
        else:
            # Exploitation: choose the state with the highest predicted value
            with torch.no_grad():
                values = [self.get_value(state).item() for state in candidate_states]
            chosen_idx = int(np.argmax(values))
        return candidate_states[chosen_idx], chosen_idx

    def update_weights(self, td_error):
        """Apply ``w <- w + lr * td_error * e`` to every network parameter."""
        lr = self.optimizer.param_groups[0]['lr']
        step = lr * float(td_error)
        with torch.no_grad():
            for name, param in self.net.named_parameters():
                param.add_(self.eligibility_traces[name].to(param.device), alpha=step)


In [None]:
# Long Nardy game environment; presumably passed as `state_generator` to
# train_two_agents below (the call is not shown here) — confirm.
game = LongNardy()

In [None]:
# Seed both RNGs the agents use (torch for weight init, NumPy for epsilon-greedy)
# so that Restart & Run All reproduces the same training run.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Agent's constructor takes lr and gamma; pass them explicitly.
LEARNING_RATE = 1e-3
GAMMA = 1.0  # undiscounted: each episode is one finite game

agent1 = Agent(lr=LEARNING_RATE, gamma=GAMMA).to(device)
agent2 = Agent(lr=LEARNING_RATE, gamma=GAMMA).to(device)

In [None]:
def _should_log_episode(episode_number: int) -> bool:
    """Return True if the 1-based ``episode_number`` should be printed.

    Logs every episode up to 10, every 10th up to 100, every 100th up to 1000,
    and every 1000th after that.
    """
    if episode_number <= 10:
        return True
    if episode_number <= 100:
        return episode_number % 10 == 0
    if episode_number <= 1000:
        return episode_number % 100 == 0
    return episode_number % 1000 == 0


def train_two_agents(agent1: Agent, agent2: Agent, num_episodes, state_generator: LongNardy):
    """Train two agents against each other with alternating-turn TD(lambda) updates.

    Every episode resets ``state_generator`` and both agents' eligibility traces.
    The agents then take turns. On each turn the mover picks an afterstate
    epsilon-greedily and observes reward 1 if that move finishes the game (else 0).
    It then updates its own network toward ``reward + gamma * max V(next candidates)``.

    NOTE(review): only the mover that finishes the game receives a terminal update;
    the other agent never sees the outcome. Also, next-candidate values are
    evaluated by the mover's own network even though the next move belongs to the
    opponent — confirm this is the intended value perspective.

    Args:
        agent1: Agent that moves first in every episode.
        agent2: Agent that moves second.
        num_episodes: Number of games to play.
        state_generator: Game environment; it is reset at the start of each episode.
    """
    for episode in range(num_episodes):
        # Start every episode from a fresh game with empty traces, so neither the
        # board nor credit assignment leaks across games.
        state_generator.reset()
        agent1.reset_eligibility_traces()
        agent2.reset_eligibility_traces()

        state_generator.roll_dice()
        candidate_states = state_generator.get_valid_moves()
        done = not candidate_states  # nothing to choose from -> nothing to play
        agent_turn = 0  # Alternate turns between agents (0: agent1, 1: agent2)

        while not done:
            agent = agent1 if agent_turn == 0 else agent2

            # 1. Select state using epsilon-greedy
            chosen_state, _ = agent.epsilon_greedy(candidate_states)

            # 2. Play the move and observe reward / next candidates
            state_generator.step(chosen_state)
            if state_generator.is_finished():
                # Presumably the mover has just won — TODO confirm against LongNardy.
                reward = 1.0
                next_candidate_states = []
            else:
                reward = 0.0
                state_generator.roll_dice()
                next_candidate_states = state_generator.get_valid_moves()
            # NOTE(review): if an empty move list means "must pass" rather than
            # "game over", this ends the episode early — confirm LongNardy semantics.
            done = not next_candidate_states

            # 3. Compute TD error (terminal states have value 0)
            with torch.no_grad():
                if done:
                    next_value = 0.0
                else:
                    next_value = max(agent.get_value(ns).item() for ns in next_candidate_states)

            current_value = agent.get_value(chosen_state)
            td_error = reward + agent.gamma * next_value - current_value.item()

            # 4. e <- gamma * lambda * e + grad V(s)
            agent.net.zero_grad()
            current_value.backward()
            agent.update_eligibility_traces()
            with torch.no_grad():
                for name, param in agent.net.named_parameters():
                    agent.eligibility_traces[name] += param.grad

            # 5. w <- w + lr * td_error * e
            agent.update_weights(td_error)

            # 6. Hand the turn to the other agent
            candidate_states = next_candidate_states
            agent_turn = 1 - agent_turn

        if _should_log_episode(episode + 1):
            print(f"Episode {episode+1}")