In [None]:
import torch
from torch import nn
import numpy as np
from long_nardy import LongNardy

In [None]:
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

In [None]:
class ANN(nn.Module):
    """Deep fully connected network mapping 100 input features to 40 outputs.

    Architecture (in order, all inside ``self.sequence``):
      * ``Linear(100, 40)`` + ``ReLU``
      * 80 x (``Linear(40, 40)`` + ``ReLU``)
      * ``Sigmoid``
    """

    def __init__(self):
        # nn.Module must be initialised through a bound super() proxy before any
        # sub-module is assigned; calling ``super.__init__()`` on the bare
        # ``super`` type raises TypeError and the model could never be built.
        super().__init__()

        # Input projection followed by its activation.
        hidden_layers = [nn.Linear(100, 40), nn.ReLU()]
        # 80 identical hidden blocks of width 40.
        for _ in range(80):
            hidden_layers.append(nn.Linear(40, 40))
            hidden_layers.append(nn.ReLU())
        # NOTE(review): the Sigmoid directly follows a ReLU, so its input is
        # never negative and every output lies in [0.5, 1] — confirm intended.
        # NOTE(review): the output has 40 features per sample; if a scalar state
        # value is expected, a final Linear(40, 1) may be missing — confirm.
        hidden_layers.append(nn.Sigmoid())

        self.sequence = nn.Sequential(*hidden_layers)

    def forward(self, x):
        """Apply the layer stack to ``x`` (last dimension must be 100) and return the result."""
        return self.sequence(x)

In [None]:
# Shared game instance. train_agent() reads it as a module-level global:
# it calls game.roll_dice() and game.get_valid_moves() at the start of each episode.
game = LongNardy()

In [None]:
# Build the network, then move its parameters onto the selected device.
model = ANN()
model = model.to(device)

In [None]:
# Step size; passed as the learning rate (lr=alpha) when the Adam optimizer is created.
# NOTE(review): 0.9 is far above Adam's default lr of 1e-3 — confirm this is intended.
alpha = 0.9

In [None]:
# One zero-filled trace tensor per named model parameter (same shape, dtype and device).
eligibility_traces = dict(
    (param_name, torch.zeros_like(weights))
    for param_name, weights in model.named_parameters()
)
# Adam over every model parameter, with alpha as the learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=alpha)

In [None]:
def train_agent(agent, num_episodes, state_generator):
    """Train ``agent`` for ``num_episodes`` episodes using TD errors and eligibility traces.

    Each step picks a state via ``agent.epsilon_greedy``, advances the
    environment with ``state_generator.step``, forms the TD error
    ``reward + gamma * max_next_value - current_value``, backpropagates the
    current value to obtain gradients, folds them into the agent's
    eligibility traces and finally asks the agent to update its weights.

    Args:
        agent: Object providing ``epsilon_greedy(candidate_states)`` (returns a
            pair whose first item is the chosen state), ``get_value(state)``
            (returns a single-element tensor attached to the autograd graph),
            ``gamma``, ``net`` (an ``nn.Module``), ``eligibility_traces``
            (mapping parameter name -> tensor), ``update_eligibility_traces()``,
            ``update_weights(td_error)`` and ``decay_epsilon()``.
        num_episodes: Number of episodes to run.
        state_generator: Object whose ``step(chosen_state)`` returns
            ``(reward, next_candidate_states)``; an empty/falsy
            ``next_candidate_states`` marks the terminal step.

    Returns:
        None. Progress is printed once per episode.

    Notes:
        Relies on the module-level ``game`` object for the opening dice roll
        and the initial candidate states.
        NOTE(review): ``game`` is not re-created or reset here between
        episodes, and the eligibility traces are not zeroed at episode start —
        confirm that ``game``/``agent`` handle this themselves.
    """
    for episode in range(num_episodes):
        # Initialize episode from the shared global game.
        game.roll_dice()
        candidate_states = game.get_valid_moves()
        done = False

        while not done:
            # 1. Select state using ε-greedy (second return value is not needed here).
            chosen_state, _ = agent.epsilon_greedy(candidate_states)

            # 2. Observe reward and next states
            reward, next_candidate_states = state_generator.step(chosen_state)

            # 3. Compute TD error. The bootstrap target is evaluated without
            #    tracking gradients; only the current value is differentiated.
            with torch.no_grad():
                if not next_candidate_states:  # Terminal state
                    next_value = 0.0
                    done = True
                else:
                    next_value = max(
                        agent.get_value(ns).item() for ns in next_candidate_states
                    )

            current_value = agent.get_value(chosen_state)
            td_error = reward + agent.gamma * next_value - current_value.item()

            # 4. Compute gradients and update eligibility traces
            agent.net.zero_grad()
            current_value.backward()
            # presumably decays the existing traces before the new gradient is added — verify in agent
            agent.update_eligibility_traces()
            for name, param in agent.net.named_parameters():
                # Parameters that did not take part in this forward pass have
                # grad None after zero_grad(); adding None would raise TypeError.
                if param.grad is not None:
                    agent.eligibility_traces[name] += param.grad

            # 5. Update weights
            agent.update_weights(td_error)

            # 6. Prepare for next step
            candidate_states = next_candidate_states
            agent.decay_epsilon()

        print(f"Episode {episode+1}, Final Reward: {reward:.2f}")