In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
import numpy as np
import sys
import os

# Add the project root directory to sys.path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '../..')))

from src.ParObsSnakeEnv import ParObsSnakeEnv
from src.FullObsSnakeEnv import FullObsSnakeEnv

In [56]:
class DQN(nn.Module):
    """Fully connected network with two ReLU hidden layers and a linear head.

    Args:
        input_dim: Size of the flat input vector.
        output_dim: Number of output values produced per input row.
        hidden_dim: Width shared by both hidden layers (default 128).
    """

    def __init__(self, input_dim, output_dim, hidden_dim=128):
        super().__init__()
        width = hidden_dim
        self.hidden_dim = width
        # Layers are built in the order fc1 -> fc2 -> fc3; the attribute names
        # determine the keys of state_dict().
        self.fc1 = nn.Linear(input_dim, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, output_dim)

    def forward(self, x):
        """Return the raw (un-activated) outputs of the final layer for `x`."""
        hidden = self.fc1(x).relu()
        hidden = self.fc2(hidden).relu()
        return self.fc3(hidden)

In [73]:
# Define the DQN agent
class DQNAgent:
    """Epsilon-greedy DQN agent with a replay buffer and a separate target network.

    Args:
        state_dim: Length of the flattened state vector fed to the networks.
        action_dim: Number of discrete actions (network output size).
        lr: Learning rate for the Adam optimizer.
        gamma: Discount factor applied to the bootstrapped next-state value.
        epsilon: Initial probability of choosing a random action in `act`.
        epsilon_decay: Multiplicative decay applied to epsilon after every
            optimisation step performed by `replay`.
        epsilon_min: Lower bound for epsilon.
        memory_size: Maximum number of transitions kept in the replay buffer.
        batch_size: Number of transitions sampled per optimisation step.
    """

    def __init__(self, state_dim, action_dim, lr=0.001, gamma=0.99, epsilon=1.0, epsilon_decay=0.9995, epsilon_min=0.01, memory_size=10000, batch_size=64):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.batch_size = batch_size

        # Bounded FIFO buffer: the oldest transitions are dropped once full.
        self.memory = deque(maxlen=memory_size)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = DQN(state_dim, action_dim).to(self.device)
        self.target_model = DQN(state_dim, action_dim).to(self.device)
        # Start the target network as an exact copy of the online network.
        self.target_model.load_state_dict(self.model.state_dict())
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.criterion = nn.MSELoss()

    def _to_tensor(self, values, dtype):
        """Stack `values` into one array and move it to the agent's device.

        Going through a single ndarray avoids the slow element-by-element path
        taken when a tensor is built directly from a sequence of ndarrays.
        """
        return torch.as_tensor(np.asarray(values), dtype=dtype, device=self.device)

    def act(self, state):
        """Return an action index for `state` using an epsilon-greedy policy.

        With probability `self.epsilon` a uniformly random action is returned;
        otherwise the action with the highest predicted value is chosen.
        """
        if random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)
        # Flatten whatever shape the observation has into a single batch row.
        state = self._to_tensor(state, torch.float32).reshape((1, -1))
        with torch.no_grad():
            q_values = self.model(state)
        return torch.argmax(q_values).item()

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def replay(self):
        """Run one optimisation step on a random minibatch from the buffer.

        Does nothing until the buffer holds at least `batch_size` transitions.
        Epsilon is decayed after each optimisation step that is performed.
        """
        if len(self.memory) < self.batch_size:
            return

        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = self._to_tensor(states, torch.float32).reshape((self.batch_size, -1))
        actions = self._to_tensor(actions, torch.long)
        rewards = self._to_tensor(rewards, torch.float32)
        next_states = self._to_tensor(next_states, torch.float32).reshape((self.batch_size, -1))
        dones = self._to_tensor(dones, torch.float32)

        # Q(s, a) of the online network for the actions that were actually taken.
        q_values = self.model(states)
        q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        # The bootstrap target is a constant for this update: build it without
        # autograd so backward() neither walks the target network nor leaves
        # stale gradients on its parameters.
        with torch.no_grad():
            next_q_values = self.target_model(next_states).max(1)[0]
            # (1 - done) removes the bootstrap term for terminal transitions.
            target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        # Update the Q-network
        loss = self.criterion(q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Decay epsilon
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

    def update_target_model(self):
        """Copy the online network's weights into the target network."""
        self.target_model.load_state_dict(self.model.state_dict())

# Train the agent
def train_dqn(env, agent, episodes=1000, update_target_every=10):
    """Train `agent` on `env` for a fixed number of episodes.

    Args:
        env: Environment whose `reset()` returns an initial state and whose
            `step(action)` returns a `(next_state, reward, done, info)` tuple.
        agent: Object providing `act`, `remember`, `replay`,
            `update_target_model` and an `epsilon` attribute.
        episodes: Number of episodes to run.
        update_target_every: Sync the target network once every this many
            episodes.
    """
    for episode in range(episodes):
        state = env.reset()
        total_reward = 0
        done = False

        while not done:
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            agent.remember(state, action, reward, next_state, done)
            # One optimisation step per environment step.
            agent.replay()
            state = next_state
            total_reward += reward

        print(f"Episode {episode + 1}/{episodes}, Total Reward: {total_reward}, Epsilon: {agent.epsilon:.3f}")

        # Sync the target network only on the configured schedule, so the
        # log message matches what actually happened.
        if (episode + 1) % update_target_every == 0:
            agent.update_target_model()
            print("Updated target model.")

In [None]:
# Build the training environment; uncomment the first line (and comment out the
# second) to train on FullObsSnakeEnv instead.
# env = FullObsSnakeEnv(grid_size=3, interact=False)
env = ParObsSnakeEnv(grid_size=10, interact=False)

# The network takes a flat vector: FullObsSnakeEnv observations are flattened
# across their three dimensions, otherwise only the first dimension is used.
obs_shape = env.observation_space.shape
state_dim = (
    obs_shape[0] * obs_shape[1] * obs_shape[2]
    if isinstance(env, FullObsSnakeEnv)
    else obs_shape[0]
)
action_dim = env.action_space.n

agent = DQNAgent(state_dim, action_dim)

In [None]:
# Train the agent on the environment created above.
# NOTE(review): the saved output for this cell reports "Episode N/6000", so it
# appears to come from an earlier run with a different `episodes` value -- re-run
# the cell to refresh the output.
train_dqn(env, agent, episodes=600)

Episode 1/6000, Total Reward: 79, Epsilon: 1.000
Episode 2/6000, Total Reward: -74, Epsilon: 1.000
Episode 3/6000, Total Reward: -75, Epsilon: 1.000
Episode 4/6000, Total Reward: -76, Epsilon: 1.000
Episode 5/6000, Total Reward: -74, Epsilon: 1.000
Episode 6/6000, Total Reward: -75, Epsilon: 1.000
Episode 7/6000, Total Reward: -73, Epsilon: 1.000
Episode 8/6000, Total Reward: -74, Epsilon: 1.000
Episode 9/6000, Total Reward: -74, Epsilon: 1.000
Episode 10/6000, Total Reward: 154, Epsilon: 1.000
Updated target model.
Episode 11/6000, Total Reward: -75, Epsilon: 1.000
Episode 12/6000, Total Reward: -76, Epsilon: 1.000
Episode 13/6000, Total Reward: -76, Epsilon: 1.000
Episode 14/6000, Total Reward: -76, Epsilon: 1.000
Episode 15/6000, Total Reward: -76, Epsilon: 1.000
Episode 16/6000, Total Reward: -75, Epsilon: 1.000
Episode 17/6000, Total Reward: -76, Epsilon: 1.000
Episode 18/6000, Total Reward: -75, Epsilon: 1.000
Episode 19/6000, Total Reward: -73, Epsilon: 1.000
Episode 20/6000, To

KeyboardInterrupt: 

In [None]:
# Roll out the agent for one episode, rendering the environment at every step.
# NOTE(review): in the non-FullObs branch `env` is rebound to a new environment
# built with twice the current `env.grid_size`, so re-running this cell
# presumably doubles the grid again -- restart from the setup cell for a clean
# state.
if isinstance(env, FullObsSnakeEnv):
    env.interact = True
else:
    grid_size = env.grid_size
    env = ParObsSnakeEnv(grid_size=2*grid_size)

# Safety cap on rollout length: a policy that cycles without ever ending the
# episode would otherwise loop until the kernel is interrupted. Raise this if
# legitimate episodes can run longer.
MAX_EVAL_STEPS = 1000

state = env.reset()
done = False
steps = 0
while not done and steps < MAX_EVAL_STEPS:
    # agent.act still explores with probability agent.epsilon.
    action = agent.act(state)
    state, reward, done, _ = env.step(action)
    env.render()
    print(f"Reward: {reward}")
    steps += 1

if not done:
    print(f"Stopped after {MAX_EVAL_STEPS} steps without reaching a terminal state.")

Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: -1
Reward: -1
Reward: 1
Reward: 1
Reward: -1
Reward: -1
Reward: 1
Reward: 1
Reward: -1
Reward: -1
Reward: 1
Reward: 1
Reward: -1
Reward: -1
Reward: 1
Reward: 1
Reward: -1
Reward: -1
Reward: 1
Reward: 1
Reward: -1
Reward: -1
Reward: 1
Reward: 1
Reward: -1
Reward: -1
Reward: 1
Reward: 1


KeyboardInterrupt: 