In [1]:
import torch
from torch import nn
import numpy as np
from long_nardy import LongNardy
from state import State
from typing import Tuple, List

In [2]:
# Pick the compute device used by every tensor/module in this notebook.
# `torch.accelerator` only exists in PyTorch >= 2.6; on older versions fall
# back to the classic CUDA check instead of failing with AttributeError.
_accelerator = getattr(torch, "accelerator", None)
if _accelerator is not None and _accelerator.is_available():
    device = _accelerator.current_accelerator().type
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [3]:
class ANN(nn.Module):
    """Fully connected value network: Linear/ReLU stack ending in a sigmoid.

    The architecture is ``Linear(input_size, hidden_size) -> ReLU`` followed by
    ``num_hidden_layers`` blocks of ``Linear(hidden_size, hidden_size) -> ReLU``
    and a final ``Linear(hidden_size, 1) -> Sigmoid``, so the output always
    lies in the interval [0, 1].

    Args:
        input_size: Number of input features (default 100).
        hidden_size: Width of every hidden layer (default 40).
        num_hidden_layers: Number of ``hidden_size -> hidden_size`` blocks
            placed after the input layer (default 80).

    Raises:
        ValueError: If a size is not positive or ``num_hidden_layers`` is
            negative.

    The defaults reproduce the original hard-coded network, including the
    parameter names inside ``self.net`` (``net.0.weight``, ...).
    """

    def __init__(self, input_size=100, hidden_size=40, num_hidden_layers=80):
        super().__init__()

        if input_size < 1 or hidden_size < 1:
            raise ValueError("input_size and hidden_size must be positive")
        if num_hidden_layers < 0:
            raise ValueError("num_hidden_layers must be non-negative")

        # NOTE(review): the default depth (80 hidden blocks, no normalisation
        # or skip connections) is unusually deep for a 40-wide MLP -- confirm
        # that this is intentional.
        layers = [nn.Linear(input_size, hidden_size), nn.ReLU()]
        for _ in range(num_hidden_layers):
            layers.append(nn.Linear(hidden_size, hidden_size))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(hidden_size, 1))
        layers.append(nn.Sigmoid())

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the sigmoid output."""
        return self.net(x)

In [4]:
class Agent(nn.Module):
    """TD(lambda) agent: a value network plus one eligibility trace per parameter.

    Args:
        lr: Step size used by ``update_weights`` (read back from the optimizer's
            first parameter group).
        gamma: Discount factor.
        epsilon: Probability of picking a random candidate in ``epsilon_greedy``.
        lambda_: Trace-decay factor.

    Attributes:
        net: The ``ANN`` value network.
        eligibility_traces: Dict mapping each name from
            ``net.named_parameters()`` to a tensor shaped like that parameter.
            The traces live in a plain dict, so ``Module.to(...)`` does not move
            them; they are re-aligned with the parameters' device/dtype whenever
            one of the trace/update methods of this class is called.
        optimizer: Adam optimizer over the network parameters. Within this class
            only its learning rate is read; ``step()`` is never called here.
    """

    def __init__(self, lr, gamma, epsilon, lambda_):
        super().__init__()
        self.net = ANN()
        self.gamma = gamma
        self.epsilon = epsilon
        self.lambda_ = lambda_
        # Allocated next to the parameters instead of on a notebook-global
        # device, so the class does not depend on module-level state.
        self.eligibility_traces = {
            name: torch.zeros_like(param)
            for name, param in self.net.named_parameters()
        }
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=lr)

    def _sync_traces(self):
        """Move traces to the device/dtype of their parameters if they differ."""
        for name, param in self.net.named_parameters():
            trace = self.eligibility_traces[name]
            if trace.device != param.device or trace.dtype != param.dtype:
                self.eligibility_traces[name] = trace.to(
                    device=param.device, dtype=param.dtype
                )

    def get_value(self, state):
        """Predict V(s) for ``state``.

        ``state`` must provide ``get_tensor_representation()``; its result is
        copied into a float32 tensor on the network's device. The input itself
        does not require gradients -- only the network parameters do.
        """
        net_device = next(self.net.parameters()).device
        state_tensor = torch.tensor(
            state.get_tensor_representation(),
            dtype=torch.float32,
            device=net_device,
        )
        return self.net(state_tensor)

    def update_eligibility_traces(self):
        """Decay every eligibility trace in place by ``gamma * lambda_``."""
        self._sync_traces()
        decay = self.gamma * self.lambda_
        with torch.no_grad():
            for trace in self.eligibility_traces.values():
                trace.mul_(decay)

    def reset_eligibility_traces(self):
        """Reset traces to zero."""
        self._sync_traces()
        with torch.no_grad():
            for trace in self.eligibility_traces.values():
                trace.zero_()

    def epsilon_greedy(self, candidate_states: List["State"]):
        """Select a state from candidates using an epsilon-greedy policy.

        Returns:
            ``(chosen_state, value)`` where ``value`` is the network output for
            the chosen state with its autograd graph attached, so the caller can
            call ``value.backward()``.

        Raises:
            ValueError: If ``candidate_states`` is empty.
        """
        if len(candidate_states) == 0:
            raise ValueError("candidate_states must not be empty")

        if np.random.rand() < self.epsilon:
            # Random exploration
            chosen_idx = int(np.random.randint(len(candidate_states)))
        else:
            # Rank the candidates without building one autograd graph per
            # candidate; only the winner is re-evaluated with gradients below.
            with torch.no_grad():
                scores = [self.get_value(state).item() for state in candidate_states]
            chosen_idx = int(np.argmax(scores))

        chosen_state = candidate_states[chosen_idx]
        return chosen_state, self.get_value(chosen_state)

    def update_weights(self, td_error):
        """Add ``lr * td_error * trace`` to every network parameter in place."""
        self._sync_traces()
        step = self.optimizer.param_groups[0]['lr'] * td_error
        # no_grad() instead of writing through ``.data`` keeps the in-place
        # update out of the autograd graph.
        with torch.no_grad():
            for name, param in self.net.named_parameters():
                param += step * self.eligibility_traces[name]

In [5]:
# Game instance shared by the whole notebook; the training cell below drives it
# through get_states_after_dice() / step() / is_finished().
game = LongNardy()

In [6]:
# Two separately initialised agents that take alternate turns during self-play.
# Both use the same hyperparameters, spelled out by keyword for readability.
agent1 = Agent(lr=0.9, gamma=0.9, epsilon=0.9, lambda_=0.9).to(device)
agent2 = Agent(lr=0.9, gamma=0.9, epsilon=0.9, lambda_=0.9).to(device)

In [None]:
# Self-play TD(lambda) training: agent1 and agent2 alternate moves on the shared
# game; after each move the acting agent updates its traces and weights.
num_episodes = 1
# Above this many successor states the generator is assumed to have misbehaved;
# training stops and the offending state is kept for the debugging cells below.
MAX_CANDIDATE_STATES = 1000
state_generator = game
aborted = False

for episode in range(num_episodes):
    # Eligibility traces are per-episode: start every episode from zero so that
    # gradients from a previous game do not leak into this one.
    agent1.reset_eligibility_traces()
    agent2.reset_eligibility_traces()

    # Initialize episode
    candidate_states = state_generator.get_states_after_dice()
    done = False
    agent_turn = 0  # Alternate turns between agents (0: agent1, 1: agent2)

    turn = 0
    while not done:
        turn += 1
        print(f"Turn {turn}")
        # Select agent based on turn
        agent = agent1 if agent_turn == 0 else agent2

        # 1. Select state using epsilon-greedy
        chosen_state, current_value = agent.epsilon_greedy(candidate_states)

        # 2. Observe reward and next states
        state_generator.step(chosen_state)

        if state_generator.is_finished():
            reward = 1
            next_value = 0.0
            # Nothing follows a terminal move; avoid carrying stale candidates.
            next_candidate_states = []
            done = True
        else:
            reward = 0
            next_candidate_states = state_generator.get_states_after_dice()
            if len(next_candidate_states) > MAX_CANDIDATE_STATES:
                # Keep the evidence and stop training altogether. (Assigning to
                # the loop variable would not stop the outer `for` loop.)
                error_state = state_generator.state.copy()
                error_states = next_candidate_states
                aborted = True
                break

            if len(next_candidate_states) == 0:
                # NOTE(review): no successor states. The original behaviour is
                # kept: step again with the same state and retry without a TD
                # update, without switching the agent and without refreshing
                # `candidate_states` -- confirm this matches LongNardy's
                # pass-turn semantics.
                state_generator.step(chosen_state)
                continue

            # 3. Bootstrap from the best successor value (no gradients needed)
            with torch.no_grad():
                next_value = max(
                    agent.get_value(ns).item() for ns in next_candidate_states
                )

        td_error = reward + agent.gamma * next_value - current_value.item()

        # 4. Compute gradients and update eligibility traces:
        #    e <- gamma * lambda * e + grad V(s)
        agent.net.zero_grad()
        current_value.backward()
        agent.update_eligibility_traces()
        for name, param in agent.net.named_parameters():
            agent.eligibility_traces[name] += param.grad

        # 5. Update weights
        agent.update_weights(td_error)

        # 6. Prepare for next step
        candidate_states = next_candidate_states

        # Switch turns
        agent_turn = 1 - agent_turn

    # Put the board back to its initial position so the next episode (or the
    # next run of this cell) does not start from a finished game.
    game.state.reset()

    if aborted:
        break

    # Progress log with decreasing frequency: every episode, then every 10th,
    # every 100th and finally every 1000th.
    if episode < 10:
        print(f"Episode {episode+1}")
    elif episode < 100 and episode % 10 == 0:
        print(f"Episode {episode+1}")
    elif episode < 1000 and episode % 100 == 0:
        print(f"Episode {episode+1}")
    elif episode % 1000 == 0:
        print(f"Episode {episode+1}")

        # NOTE(review): checkpoints are written only from this every-1000th
        # branch, as in the original indentation -- confirm that is intended.
        torch.save(agent1.state_dict(), f"agent1_epoch_{episode}_latest.pth")
        torch.save(agent2.state_dict(), f"agent2_epoch_{episode}_latest.pth")


Nardy Board State:
-------------------------------------------------
12 [ .]  13 [ .]  14 [ .]  15 [ .]  16 [ .]  17 [ .]  18 [ .]  19 [ 1]  20 [ .]  21 [ .]  22 [ .]  23 [14]  
-------------------------------------------------
11 [-15]  10 [ .]   9 [ .]   8 [ .]   7 [ .]   6 [ .]   5 [ .]   4 [ .]   3 [ .]   2 [ .]   1 [ .]   0 [ .]  
-------------------------------------------------

White Turn:  True
Dice:  [4, 4, 4]
White off:  0
Black off:  0

Nardy Board State:
-------------------------------------------------
12 [ .]  13 [ .]  14 [ .]  15 [ 1]  16 [ .]  17 [ .]  18 [ .]  19 [ .]  20 [ .]  21 [ .]  22 [ .]  23 [14]  
-------------------------------------------------
11 [-15]  10 [ .]   9 [ .]   8 [ .]   7 [ .]   6 [ .]   5 [ .]   4 [ .]   3 [ .]   2 [ .]   1 [ .]   0 [ .]  
-------------------------------------------------

White Turn:  True
Dice:  [4, 4]
White off:  0
Black off:  0

Nardy Board State:
-------------------------------------------------
12 [ .]  13 [ .]  14 [ .]  

KeyboardInterrupt: 

In [None]:
# Debugging: `error_state` exists only if the training cell aborted because a
# move produced more than 1000 candidate states; otherwise this is a NameError.
error_state.dice_remaining

In [None]:
# Debugging: print the state captured when the training cell aborted
# (requires `error_state` from that cell).
error_state.pretty_print()

In [None]:
# Debugging: load the captured state into a fresh game object so the candidate
# generation can be re-run in isolation (requires `error_state`).
test = LongNardy()
test.state = error_state

In [None]:
# Debugging: regenerate the candidate states for the captured state.
test_states = test.get_states_after_dice()

In [None]:
# import json
# with open("error_state.json", "w") as f:
#     json.dump(error_state.to_dict(), f)