In [1]:
import numpy as np
import gymnasium as gym
from gymnasium import spaces
import random
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from SnakeEnv import SnakeEnv

In [2]:
class ActorCritic(nn.Module):
    """Two-layer MLP with a shared trunk, a policy (actor) head and a value (critic) head.

    NOTE(review): this class is redefined with the same architecture in the
    next cell, which shadows this definition on Run All; one of the two
    copies should be removed.
    """

    def __init__(self, input_dim, action_dim, hidden_dim=128):
        """
        Args:
            input_dim: size of the (flattened) observation vector.
            action_dim: number of discrete actions (width of the policy head).
            hidden_dim: width of both hidden layers (defaults to the previous
                hard-coded value of 128).
        """
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.actor = nn.Linear(hidden_dim, action_dim)
        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return ``(policy_logits, value)`` for a batch of observations.

        ``policy_logits`` holds ``action_dim`` unnormalised scores per row;
        ``value`` holds a single state-value estimate per row.
        """
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.actor(x), self.critic(x)

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils.convert_parameters import parameters_to_vector
from tqdm import tqdm
import numpy as np

# Actor-Critic Network
class ActorCritic(nn.Module):
    """Two-layer MLP with a shared trunk, a policy (actor) head and a value (critic) head."""

    def __init__(self, input_dim, action_dim, hidden_dim=128):
        """
        Args:
            input_dim: size of the (flattened) observation vector.
            action_dim: number of discrete actions (width of the policy head).
            hidden_dim: width of both hidden layers (defaults to the previous
                hard-coded value of 128).
        """
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.actor = nn.Linear(hidden_dim, action_dim)
        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return ``(policy_logits, value)`` for a batch of observations.

        ``policy_logits`` holds ``action_dim`` unnormalised scores per row;
        ``value`` holds a single state-value estimate per row.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.actor(x), self.critic(x)

# Actor-Critic Agent
class ActorCriticAgent:
    """Monte-Carlo advantage actor-critic agent for a discrete-action environment.

    After each complete episode the agent computes discounted returns, uses
    ``returns - V(s)`` as the advantage for the policy-gradient term, and
    regresses the critic onto the returns with a squared-error loss.
    """

    def __init__(self, env, learning_rate=0.001, gamma=0.99):
        """
        Args:
            env: environment exposing ``observation_space.shape``,
                ``action_space.n``, ``reset()`` and ``step(action)``.
            learning_rate: Adam learning rate.
            gamma: discount factor used by ``compute_returns``.
        """
        self.env = env
        self.gamma = gamma
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Observations are flattened before reaching the MLP, so the input width
        # is the product of all observation dimensions (works for any rank).
        input_dim = int(np.prod(env.observation_space.shape))
        self.model = ActorCritic(input_dim, env.action_space.n).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)

    @staticmethod
    def _unpack_reset(reset_result):
        """Return the observation from ``env.reset()``.

        Accepts both the legacy API (``reset()`` returns ``obs``) and the
        gymnasium API (``reset()`` returns ``(obs, info)``).
        """
        if isinstance(reset_result, tuple):
            return reset_result[0]
        return reset_result

    @staticmethod
    def _unpack_step(step_result):
        """Return ``(next_obs, reward, done)`` from ``env.step()``.

        Accepts the legacy 4-tuple ``(obs, reward, done, info)`` and the
        gymnasium 5-tuple ``(obs, reward, terminated, truncated, info)``; in the
        latter case the episode ends on either termination or truncation.
        """
        if len(step_result) == 5:
            obs, reward, terminated, truncated, _ = step_result
            return obs, reward, terminated or truncated
        obs, reward, done, _ = step_result
        return obs, reward, done

    def choose_action(self, state):
        """Sample an action from the current policy for one flattened state.

        Returns:
            int: the sampled action index.
        """
        state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        # Acting needs no gradients: the update re-runs the forward pass over the
        # whole episode, so building an autograd graph here is wasted work.
        with torch.no_grad():
            policy_logits, _ = self.model(state_tensor)
        # Categorical(logits=...) applies the softmax internally.
        action_dist = torch.distributions.Categorical(logits=policy_logits)
        return action_dist.sample().item()

    def compute_returns(self, rewards, dones):
        """Compute discounted returns ``G_t = r_t + gamma * G_{t+1} * (1 - done_t)``.

        Args:
            rewards: per-step rewards in time order.
            dones: per-step terminal flags aligned with ``rewards``.

        Returns:
            1-D float32 tensor on ``self.device`` with one return per step.
        """
        returns = []
        g = 0.0
        for reward, done in zip(reversed(rewards), reversed(dones)):
            g = reward + self.gamma * g * (1.0 - float(done))
            returns.append(g)
        # Built back-to-front; reverse once instead of O(n) insert(0, ...) per step.
        returns.reverse()
        return torch.tensor(returns, dtype=torch.float32, device=self.device)

    def _run_episode(self):
        """Roll out one full episode with the current policy.

        Returns:
            ``(states, actions, rewards, dones)`` lists with one entry per step;
            states are flattened observations.
        """
        state = self._unpack_reset(self.env.reset()).flatten()
        done = False
        states, actions, rewards, dones = [], [], [], []

        while not done:
            action = self.choose_action(state)
            next_state, reward, done = self._unpack_step(self.env.step(action))

            states.append(state)
            actions.append(action)
            rewards.append(reward)
            dones.append(done)

            state = next_state.flatten()

        return states, actions, rewards, dones

    def _update(self, states, actions, rewards, dones):
        """Apply one actor-critic gradient step using a finished episode.

        Returns:
            float: the scalar loss that was minimised.
        """
        returns = self.compute_returns(rewards, dones).unsqueeze(-1)
        states_tensor = torch.as_tensor(np.array(states), dtype=torch.float32, device=self.device)
        actions_tensor = torch.as_tensor(actions, dtype=torch.long, device=self.device)

        policy_logits, values = self.model(states_tensor)
        advantages = returns - values

        log_probs = F.log_softmax(policy_logits, dim=-1)
        selected_log_probs = log_probs.gather(1, actions_tensor.unsqueeze(-1))
        # detach(): the policy term must not push gradients into the critic.
        actor_loss = -selected_log_probs * advantages.detach()
        critic_loss = advantages.pow(2)
        loss = (actor_loss + critic_loss).mean()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def train(self, num_episodes):
        """Train for ``num_episodes`` episodes, performing one update per episode."""
        for _ in tqdm(range(num_episodes), desc='Training', unit='Episode'):
            states, actions, rewards, dones = self._run_episode()
            self._update(states, actions, rewards, dones)

In [4]:
# Seed every RNG the run may touch so training and evaluation are reproducible.
# NOTE(review): assumes SnakeEnv draws its randomness from `random` / `numpy` — confirm.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

env = SnakeEnv(grid_size=5)
agent = ActorCriticAgent(env)

In [None]:
# Number of training episodes (one policy/critic update per episode).
NUM_EPISODES = 50_000
agent.train(num_episodes=NUM_EPISODES)

Training:   8%|▊         | 4052/50000 [00:17<03:41, 207.43Episode/s]

In [None]:
# Roll out one episode with the trained policy, rendering every step.
# MAX_EVAL_STEPS guards against a policy that never reaches a terminal state
# in case SnakeEnv has no step limit of its own (not visible here); raise it
# if episodes are legitimately longer.
MAX_EVAL_STEPS = 500
state = env.reset()
done = False
total_reward = 0
steps = 0
with torch.no_grad():
    while not done and steps < MAX_EVAL_STEPS:
        action = agent.choose_action(state.flatten())
        state, reward, done, _ = env.step(action)
        env.render()
        print(f"Reward: {reward}")
        total_reward += reward
        steps += 1
print(f"Episode finished after {steps} steps (done={done}), total reward: {total_reward}")

Previous distace: 1
Current distance: 2
. . . . . 
. . F . . 
. . . S . 
. . . . . 
. . . . . 
Reward: -1
Previous distace: 2
Current distance: 3
. . . . . 
. . F . . 
. . . . S 
. . . . . 
. . . . . 
Reward: -1
Previous distace: 3
Current distance: 2
. . . . . 
. . F . S 
. . . . . 
. . . . . 
. . . . . 
Reward: 1
Previous distace: 2
Current distance: 1
. . . . . 
. . F S . 
. . . . . 
. . . . . 
. . . . . 
Reward: 1
Previous distace: 1
Current distance: 0
. . . . . 
. . S S . 
. F . . . 
. . . . . 
. . . . . 
Reward: 11
Previous distace: 2
Current distance: 1
. . . . . 
. S S . . 
. F . . . 
. . . . . 
. . . . . 
Reward: 1
Previous distace: 1
Current distance: 2
. . . . . 
S S . . . 
. F . . . 
. . . . . 
. . . . . 
Reward: -1
. . . . . 
S S . . . 
. F . . . 
. . . . . 
. . . . . 
Reward: -10
