In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import sys
import os
# Add the project root directory to sys.path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '../..')))

from src.ParObsSnakeEnv import ParObsSnakeEnv
from src.FullObsSnakeEnv import FullObsSnakeEnv

In [10]:
# Define the Policy Network
class PolicyNetwork(nn.Module):
    """Small fully connected policy: state vector in, action probabilities out.

    Architecture: two hidden ``Linear`` layers of width ``hidden_dim`` with
    ReLU activations, followed by a ``Linear`` output layer and a softmax
    over the last dimension.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=8):
        """Create the three linear layers and the softmax head.

        Args:
            input_dim: Size of the last dimension of the input.
            output_dim: Number of actions (size of the output distribution).
            hidden_dim: Width of both hidden layers.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        # Layers are created in fc1 -> fc2 -> fc3 order so parameter
        # initialisation consumes the torch RNG in a fixed sequence.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return softmax action probabilities along the last dimension of ``x``."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))
        logits = self.fc3(hidden)
        return self.softmax(logits)

In [None]:
class REINFORCEAgent:
    """Monte-Carlo policy-gradient (REINFORCE) agent.

    Holds a ``PolicyNetwork`` that outputs action probabilities, an Adam
    optimizer over its parameters, and the discount factor used to turn
    per-step rewards into returns.
    """

    def __init__(self, state_dim, action_dim, lr=0.001, gamma=0.99, device='cpu'):
        """Build the policy network and its optimizer.

        Args:
            state_dim: Input size of the policy network.
            action_dim: Number of discrete actions (policy output size).
            lr: Learning rate for Adam.
            gamma: Discount factor applied in ``compute_returns``.
            device: Device the policy (and all tensors created here) live on.
        """
        self.device = device
        self.policy = PolicyNetwork(state_dim, action_dim).to(self.device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        self.gamma = gamma

    def act(self, state):
        """Sample an action index from the current policy for one state.

        Args:
            state: Array-like observation accepted by the policy network
                (assumed to be a single 1-D state vector - TODO confirm
                against the environment).

        Returns:
            The sampled action index as produced by ``np.random.choice``.
        """
        state = torch.as_tensor(
            state, dtype=torch.float32, device=self.device
        ).unsqueeze(0)  # Add batch dimension
        # Acting never needs gradients, so skip building the autograd graph.
        with torch.no_grad():
            # reshape(-1) drops the batch dimension but, unlike a bare
            # squeeze(), keeps a 1-D vector even when there is one action.
            action_probs = self.policy(state).reshape(-1).cpu().numpy()
        action = np.random.choice(len(action_probs), p=action_probs)
        return action

    def compute_returns(self, rewards):
        """Compute discounted returns for an episode.

        Args:
            rewards: Sequence of per-step rewards in time order.

        Returns:
            1-D float tensor on ``self.device`` with one return per step.
            When there is more than one step and the returns are not
            (nearly) constant, they are standardised to zero mean and unit
            standard deviation.
        """
        discounted = []
        running_return = 0
        # Walk the episode backwards, appending (O(1)) and reversing once at
        # the end instead of inserting at the front on every step (O(n^2)).
        for reward in reversed(rewards):
            running_return = reward + self.gamma * running_return
            discounted.append(running_return)
        discounted.reverse()
        returns = torch.tensor(discounted, dtype=torch.float32, device=self.device)
        # Normalize returns to improve training stability. std() of a single
        # element is NaN, hence the length guard.
        if len(returns) > 1:
            std = returns.std()
            if std > 1e-5:
                returns = (returns - returns.mean()) / (std + 1e-5)
        return returns

    def update_policy(self, log_probs, returns):
        """Perform one policy-gradient update from a finished episode.

        Args:
            log_probs: List of log-probability tensors (one per step) that
                are still attached to the policy's autograd graph.
            returns: Tensor of per-step returns, e.g. from
                ``compute_returns``.
        """
        # Flatten both operands so per-step log-probs of shape () or (1,)
        # pair element-wise with the returns; without this, (T, 1) * (T,)
        # would silently broadcast to a (T, T) matrix.
        stacked_log_probs = torch.stack(log_probs).reshape(-1)
        step_returns = returns.reshape(-1)
        loss = -torch.sum(stacked_log_probs * step_returns)  # Negative log-prob * return
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

# Train the REINFORCE Agent
def train_reinforce(env, agent, episodes=1000, log_every=1):
    """Train ``agent`` on ``env`` with episodic REINFORCE updates.

    Each episode is rolled out to termination, then the discounted returns
    are computed and a single policy-gradient step is taken.

    Args:
        env: Environment whose ``reset()`` returns the initial state and
            whose ``step(action)`` returns a 4-tuple
            ``(next_state, reward, done, info)``.
        agent: Object exposing ``policy`` (callable returning a 1-D tensor of
            action probabilities for a single state), ``device``,
            ``compute_returns(rewards)`` and
            ``update_policy(log_probs, returns)``.
        episodes: Number of episodes to run.
        log_every: Print the progress line every ``log_every`` episodes; the
            default of 1 prints every episode, a falsy value disables
            printing.

    Returns:
        List with the total (undiscounted) reward of every episode, in order.
    """
    episode_rewards = []
    for episode in range(episodes):
        state = env.reset()
        log_probs = []
        rewards = []
        total_reward = 0
        done = False

        while not done:
            state_tensor = torch.as_tensor(
                state, dtype=torch.float32, device=agent.device
            )
            # One forward pass serves both purposes: the detached copy is
            # used for sampling, the attached tensor for the log-probability.
            # This avoids evaluating the policy twice per step and guarantees
            # the log-prob belongs to the distribution that was sampled.
            action_probs = agent.policy(state_tensor)
            sampling_probs = action_probs.detach().cpu().numpy()
            action = np.random.choice(len(sampling_probs), p=sampling_probs)
            log_probs.append(torch.log(action_probs[action]))

            next_state, reward, done, _ = env.step(action)
            rewards.append(reward)
            total_reward += reward
            state = next_state

        # Compute returns and update policy
        returns = agent.compute_returns(rewards)
        agent.update_policy(log_probs, returns)

        episode_rewards.append(total_reward)
        if log_every and (episode + 1) % log_every == 0:
            print(f"Episode {episode + 1}/{episodes}, Total Reward: {total_reward}")

    return episode_rewards

In [29]:
# Seed the RNGs this notebook draws from so a Restart & Run All reproduces the
# same run: actions are sampled with np.random.choice and the policy weights
# are initialised from the torch RNG.
# NOTE(review): the environment's own source of randomness is not visible
# here; seed it separately if it does not use these generators.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

env = ParObsSnakeEnv(grid_size=10, interact=False)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

# Determine the device (CPU or GPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

agent = REINFORCEAgent(state_dim, action_dim, device=device)

# Long-running: 50000 episodes, one progress line per episode.
train_reinforce(env, agent, episodes=50000)

Using device: cuda
Episode 1/50000, Total Reward: -73
Episode 2/50000, Total Reward: -75
Episode 3/50000, Total Reward: -75
Episode 4/50000, Total Reward: -78
Episode 5/50000, Total Reward: -79
Episode 6/50000, Total Reward: -79
Episode 7/50000, Total Reward: -80
Episode 8/50000, Total Reward: -72
Episode 9/50000, Total Reward: -80
Episode 10/50000, Total Reward: -76
Episode 11/50000, Total Reward: -75
Episode 12/50000, Total Reward: -76
Episode 13/50000, Total Reward: -77
Episode 14/50000, Total Reward: 11
Episode 15/50000, Total Reward: -77
Episode 16/50000, Total Reward: -77
Episode 17/50000, Total Reward: -76
Episode 18/50000, Total Reward: -79
Episode 19/50000, Total Reward: -74
Episode 20/50000, Total Reward: -77
Episode 21/50000, Total Reward: -74
Episode 22/50000, Total Reward: -79
Episode 23/50000, Total Reward: -75
Episode 24/50000, Total Reward: 9
Episode 25/50000, Total Reward: -77
Episode 26/50000, Total Reward: -71
Episode 27/50000, Total Reward: -79
Episode 28/50000, Tot

KeyboardInterrupt: 

In [30]:
# Base grid size. Same value as the grid_size=10 passed to the training
# environment; the evaluation cell builds its ParObsSnakeEnv on 2*grid_size.
grid_size = 10

In [31]:
# Choose the environment to watch: anything that is not a FullObsSnakeEnv is
# replaced by a fresh ParObsSnakeEnv on a grid twice the base size, while a
# FullObsSnakeEnv is kept and has its `interact` flag set to True.
if not isinstance(env, FullObsSnakeEnv):
    env = ParObsSnakeEnv(grid_size=2*grid_size)
else:
    env.interact = True

# Roll out a single episode with the current agent, rendering the board and
# printing the reward after every step.
state, done = env.reset(), False
while not done:
    action = agent.act(state)
    state, reward, done, _ = env.step(action)
    env.render()
    print(f"Reward: {reward}")

Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: -1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: -1
Reward: 1
Reward: 1
Reward: 1
Reward: 1