In [1]:
import numpy as np
import gymnasium as gym
from gymnasium import spaces
import random
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from SnakeEnv import SnakeEnv

In [3]:
class ActorCritic(nn.Module):
    """Two-headed MLP with a shared trunk for actor-critic learning.

    Two ReLU-activated fully connected layers feed both a policy head
    (one unnormalised logit per action) and a scalar state-value head.

    Args:
        input_dim: Size of the (flattened) observation vector.
        action_dim: Number of discrete actions, i.e. number of policy logits.
        hidden_dim: Width of both hidden layers. Defaults to 128.
    """

    def __init__(self, input_dim, action_dim, hidden_dim=128):
        super().__init__()
        # Layers are created in the same order and under the same attribute
        # names as before, so state_dict keys and seeded initialisation of
        # the default configuration are unchanged.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.actor = nn.Linear(hidden_dim, action_dim)
        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return ``(policy_logits, value)`` for observation tensor ``x``.

        ``x`` has ``input_dim`` as its last dimension. ``policy_logits`` has
        ``action_dim`` as its last dimension and is NOT softmax-normalised;
        ``value`` has a trailing dimension of size 1.
        """
        hidden = torch.relu(self.fc1(x))
        hidden = torch.relu(self.fc2(hidden))
        return self.actor(hidden), self.critic(hidden)

# Actor-Critic Agent
class ActorCriticAgent:
    """One-step (TD(0)) actor-critic agent for a discrete-action environment.

    The environment is used through the following calls only:
    ``env.observation_space.shape``, ``env.action_space.n``, ``env.reset()``
    returning an observation with a ``.flatten()`` method, and
    ``env.step(action)`` returning ``(next_state, reward, done, info)``.

    Args:
        env: Environment instance to act in and learn from.
        learning_rate: Adam step size.
        gamma: Discount factor applied to the bootstrapped next-state value.
    """

    def __init__(self, env, learning_rate=0.001, gamma=0.99):
        self.env = env
        self.gamma = gamma
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # The network consumes flattened observations, so its input width is
        # the product of every observation dimension (any rank, not just 3-D).
        input_dim = 1
        for dim in env.observation_space.shape:
            input_dim *= dim

        self.model = ActorCritic(input_dim, env.action_space.n).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)

    def choose_action(self, state):
        """Sample an action index from the current policy for ``state``.

        Args:
            state: Flattened observation (anything ``torch.tensor`` accepts).

        Returns:
            int: Index of the sampled action.
        """
        state = torch.tensor(state, dtype=torch.float32).to(self.device)
        # Acting never needs gradients; skip building an autograd graph.
        with torch.no_grad():
            policy_logits, _ = self.model(state)
        policy = torch.softmax(policy_logits, dim=-1)
        return torch.multinomial(policy, 1).item()

    def _compute_losses(self, state, action, reward, next_state, done):
        """Return scalar ``(actor_loss, critic_loss)`` for one transition."""
        state_tensor = torch.tensor(state, dtype=torch.float32).to(self.device)
        next_state_tensor = torch.tensor(next_state, dtype=torch.float32).to(self.device)
        reward_tensor = torch.tensor(reward, dtype=torch.float32).to(self.device)
        done_tensor = torch.tensor(done, dtype=torch.float32).to(self.device)

        policy_logits, value = self.model(state_tensor)

        # The bootstrapped target is treated as a constant (semi-gradient TD):
        # the critic must regress towards it, not drag the target along.
        with torch.no_grad():
            _, next_value = self.model(next_state_tensor)
        td_target = reward_tensor + (1 - done_tensor) * self.gamma * next_value

        advantage = td_target - value

        # log_softmax is numerically safer than log(softmax(...)).
        log_prob = torch.log_softmax(policy_logits, dim=-1)[action]

        # The advantage only weights the policy gradient; detaching it stops
        # the actor loss from back-propagating into the value estimate.
        actor_loss = (-log_prob * advantage.detach()).sum()
        critic_loss = advantage.pow(2).sum()
        return actor_loss, critic_loss

    def _update(self, state, action, reward, next_state, done):
        """Apply one optimiser step for a single observed transition."""
        actor_loss, critic_loss = self._compute_losses(
            state, action, reward, next_state, done
        )
        loss = actor_loss + critic_loss

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def train(self, num_episodes):
        """Run ``num_episodes`` episodes, updating the model after every step.

        Args:
            num_episodes: Number of complete episodes to play.
        """
        for _ in tqdm(range(num_episodes), desc='Training', unit='Episode'):
            state = self.env.reset()
            state = state.flatten()
            done = False
            while not done:
                action = self.choose_action(state)
                next_state, reward, done, _ = self.env.step(action)
                next_state = next_state.flatten()

                self._update(state, action, reward, next_state, done)

                state = next_state

In [6]:
# Tunable run settings, kept in one place instead of inline magic numbers.
SEED = 42
GRID_SIZE = 5
NUM_EPISODES = 10000

# Seed every RNG imported by this notebook so Restart & Run All is repeatable
# (action sampling uses torch; the environment presumably uses `random` or
# numpy for food placement -- TODO confirm against SnakeEnv).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

env = SnakeEnv(grid_size=GRID_SIZE)
agent = ActorCriticAgent(env)
agent.train(num_episodes=NUM_EPISODES)

Training: 100%|██████████| 10000/10000 [01:02<00:00, 161.10Episode/s]


In [7]:
# Play a single episode with the trained agent, drawing the board and
# printing the reward after every step until the environment reports done.
state = env.reset()
done = False
while True:
    action = agent.choose_action(state.flatten())
    step_result = env.step(action)
    state, reward, done, _ = step_result
    env.render()
    print(f"Reward: {reward}")
    if done:
        break

. . . . . 
. . . . . 
. . . S . 
. . . . . 
. . F . . 
Reward: -1
. . . . . 
. . . . . 
. . . . S 
. . . . . 
. . F . . 
Reward: -1
. . . . . 
. . . . . 
. . . . S 
. . . . . 
. . F . . 
Reward: -10
