In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import sys
import os
# Add the project root directory to sys.path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '../..')))

from src.ParObsSnakeEnv import SnakeEnv

In [13]:
# Define the Policy Network
class PolicyNetwork(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=8):
        super(PolicyNetwork, self).__init__()
        self.hidden_dim = hidden_dim
        self.fc1 = nn.Linear(input_dim, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.fc3 = nn.Linear(self.hidden_dim, output_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return self.softmax(x)

In [None]:
class REINFORCEAgent:
    def __init__(self, state_dim, action_dim, lr=0.001, gamma=0.99, device='cpu'):
        self.device = device
        self.policy = PolicyNetwork(state_dim, action_dim).to(self.device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        self.gamma = gamma

    def act(self, state):
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)  # Add batch dimension
        action_probs = self.policy(state).detach().cpu().numpy().squeeze()
        print(action_probs)
        action = np.random.choice(len(action_probs), p=action_probs)
        return action

    def compute_returns(self, rewards):
        """Compute discounted returns for an episode."""
        returns = []
        G = 0
        for reward in reversed(rewards):
            G = reward + self.gamma * G
            returns.insert(0, G)
        returns = torch.FloatTensor(returns).to(self.device)
        # Normalize returns to improve training stability
        returns = (returns - returns.mean()) / (returns.std() + 1e-5)
        return returns

    def update_policy(self, log_probs, returns):
        """Perform policy gradient update."""
        loss = -torch.sum(torch.stack(log_probs) * returns)  # Negative log-prob * return
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

# Train the REINFORCE Agent
def train_reinforce(env, agent, episodes=1000):
    for episode in range(episodes):
        state = env.reset()
        log_probs = []
        rewards = []
        total_reward = 0
        done = False

        while not done:
            action = agent.act(state)
            state_tensor = torch.FloatTensor(state).to(agent.device)
            action_prob = agent.policy(state_tensor)[action]
            log_prob = torch.log(action_prob)
            log_probs.append(log_prob)

            next_state, reward, done, _ = env.step(action)
            rewards.append(reward)
            total_reward += reward
            state = next_state

        # Compute returns and update policy
        returns = agent.compute_returns(rewards)
        agent.update_policy(log_probs, returns)

        print(f"Episode {episode + 1}/{episodes}, Total Reward: {total_reward}")

In [15]:
env = SnakeEnv(grid_size=10, interact=False)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

# Determine the device (CPU or GPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

agent = REINFORCEAgent(state_dim, action_dim, device=device)

train_reinforce(env, agent, episodes=500)

Using device: cuda
[0.22000895 0.21754985 0.31720832 0.24523295]
[0.22000895 0.21754985 0.31720832 0.24523295]
[0.22000895 0.21754985 0.31720832 0.24523295]
[0.22777414 0.2177263  0.30226052 0.25223905]
[0.22777414 0.2177263  0.30226052 0.25223905]
[0.2214016  0.21717985 0.30085826 0.26056027]
[0.24398457 0.21784614 0.29525647 0.24291283]
[0.22536449 0.22169687 0.30950156 0.24343705]
[0.22536449 0.22169687 0.30950156 0.24343705]
[0.2441416  0.22446391 0.29220763 0.23918685]
Episode 1/500, Total Reward: -74
[0.2279274  0.23163623 0.29399836 0.24643804]
[0.23240316 0.23216662 0.2955215  0.23990868]
[0.2457022  0.2336044  0.27908117 0.24161229]
[0.25170746 0.22428249 0.2686575  0.2553525 ]
[0.25170746 0.22428249 0.2686575  0.2553525 ]
[0.23628801 0.22779863 0.28580946 0.2501039 ]
[0.23439482 0.21239856 0.28339705 0.2698096 ]
[0.22658981 0.21833886 0.3022195  0.25285184]
[0.22022016 0.21780613 0.3007522  0.26122156]
[0.22739038 0.23511699 0.2947845  0.24270815]
[0.23439482 0.21239856 0.283

  returns = (returns - returns.mean()) / (returns.std() + 1e-5)


ValueError: probabilities contain NaN