# Import Dependencies

In [1]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
import cv2

In [6]:
# Environment id handed to gym.make below.
environment_name = 'Seaquest-v4'
# Module-level environment shared by the rest of the notebook: it is passed to
# train_dqn and is also read as a global by select_action when it samples a
# random action.
# NOTE(review): render_mode="human" presumably draws every step to a window,
# which would slow training down considerably — confirm rendering is wanted.
env = gym.make(environment_name, render_mode="human")


# Define the neural network with convolutional layers for the Q-learning agent

In [7]:
class QNetwork(nn.Module):
    """Convolutional network that maps a stack of 4 frames to one Q-value per action.

    The feature extractor is three ReLU-activated convolutions followed by a
    512-unit hidden layer. The first linear layer expects 64 * 7 * 7 inputs,
    which is what the convolution stack produces for 84x84 frames.

    Args:
        action_size: Number of outputs of the final linear layer (one per action).
    """

    def __init__(self, action_size):
        super().__init__()
        # Layers are created in the same order and under the same attribute
        # names as before so state_dict keys stay conv1..conv3, fc1, fc2.
        self.conv1 = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
        # 84x84 input -> 20x20 -> 9x9 -> 7x7 feature maps with 64 channels.
        self.fc1 = nn.Linear(in_features=64 * 7 * 7, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=action_size)

    def forward(self, x):
        """Return the Q-values for a batch of stacked frames.

        Args:
            x: Tensor whose first dimension is the batch and whose second
                dimension holds the 4 input channels.

        Returns:
            Tensor of shape (batch, action_size).
        """
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = torch.relu(conv(features))
        # Collapse channels and spatial dims into one vector per sample.
        flattened = features.view(features.size(0), -1)
        hidden = torch.relu(self.fc1(flattened))
        return self.fc2(hidden)

# Preprocess the frames to grayscale and resize
def preprocess_frame(frame):
    """Turn one raw frame into a scaled 84x84 grayscale image.

    The frame is converted with ``cv2.COLOR_RGB2GRAY``, resized to 84x84 with
    ``cv2.resize`` and then divided by 255.0.

    Args:
        frame: Image array accepted by ``cv2.cvtColor`` (presumably the RGB
            uint8 observation returned by the environment — TODO confirm).

    Returns:
        The resized grayscale image divided by 255.0.
    """
    target_size = (84, 84)
    grayscale = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    shrunk = cv2.resize(grayscale, target_size)
    return shrunk / 255.0

# Stack frames for better temporal understanding
def stack_frames(frames, frame, stack_size=4):
    """Append ``frame`` to the buffer and return the latest frames stacked.

    The buffer is mutated in place: ``frame`` is appended, and if the buffer
    then holds fewer than ``stack_size`` entries it is padded with further
    copies of ``frame`` (this is what happens at the start of an episode).

    Only the most recent ``stack_size`` entries are stacked, so the result has
    a fixed depth even when ``frames`` is an unbounded container such as a
    plain list. With a ``deque(maxlen=4)`` and the default ``stack_size`` the
    behaviour is identical to the previous implementation.

    Args:
        frames: Mutable sequence supporting ``append`` and ``extend``
            (e.g. ``collections.deque`` or ``list``) holding earlier frames.
        frame: The newest frame, an array with the same shape as the others.
        stack_size: Number of frames in the returned stack. Defaults to 4.

    Returns:
        ``np.ndarray`` with the stacked frames along axis 0, oldest first.

    Raises:
        ValueError: If ``stack_size`` is smaller than 1.
    """
    if stack_size < 1:
        raise ValueError("stack_size must be at least 1")

    frames.append(frame)
    missing = stack_size - len(frames)
    if missing > 0:
        # Not enough history yet: repeat the newest frame to fill the stack.
        frames.extend([frame] * missing)

    # A deque cannot be sliced, so copy to a list before taking the tail.
    recent = list(frames)[-stack_size:]
    return np.stack(recent, axis=0)

# Function to select an action using an epsilon-greedy policy
def select_action(state, q_network, epsilon, action_space=None):
    """Pick an action with an epsilon-greedy policy.

    With probability ``epsilon`` a random action is sampled; otherwise the
    action with the highest predicted Q-value is returned.

    Args:
        state: Array-like observation without a batch dimension; a batch axis
            is added before it is passed to ``q_network``.
        q_network: Callable module mapping a batched float tensor to Q-values.
        epsilon: Exploration probability, compared against ``random.random()``.
        action_space: Optional object with a ``sample()`` method used for the
            exploratory branch. When omitted the notebook-level
            ``env.action_space`` is used, which keeps existing three-argument
            calls working exactly as before.

    Returns:
        The sampled action, or the greedy action index as a Python int.
    """
    if random.random() < epsilon:
        if action_space is None:
            # Backward-compatible fallback on the module-level environment.
            action_space = env.action_space
        return action_space.sample()

    # Greedy branch: inference only, so no autograd graph is needed.
    with torch.no_grad():
        state_batch = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        q_values = q_network(state_batch)
        return q_values.argmax().item()

# Training the DQN agent

In [8]:
# Training the DQN agent
def train_dqn(env, num_episodes=1000, gamma=0.99, epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.995, lr=0.001):
    """Train a QNetwork on ``env`` with experience replay and epsilon-greedy exploration.

    Each step stores a transition of stacked, preprocessed frames in a replay
    buffer (capacity 10000). Once the buffer holds at least 64 transitions, a
    random batch of 64 is used for one Adam step on the MSE between the
    predicted Q-value and the one-step TD target.

    Args:
        env: Environment following the Gymnasium API: ``reset()`` returns
            ``(observation, info)`` and ``step(action)`` returns
            ``(observation, reward, terminated, truncated, info)``.
        num_episodes: Number of episodes to run.
        gamma: Discount factor applied to the bootstrapped next-state value.
        epsilon_start: Initial exploration probability.
        epsilon_end: Lower bound for the exploration probability.
        epsilon_decay: Multiplicative decay applied to epsilon after each episode.
        lr: Learning rate of the Adam optimizer.

    Returns:
        The trained ``QNetwork``.
    """
    action_size = env.action_space.n
    q_network = QNetwork(action_size)
    optimizer = optim.Adam(q_network.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    memory = deque(maxlen=10000)
    epsilon = epsilon_start
    batch_size = 64
    frames = deque(maxlen=4)

    for episode in range(num_episodes):
        state, _ = env.reset()
        state = preprocess_frame(state)
        # Start every episode with an empty frame buffer; otherwise the first
        # stacked state would still contain frames from the previous episode.
        frames.clear()
        stacked_state = stack_frames(frames, state)
        done = False
        total_reward = 0

        while not done:
            action = select_action(stacked_state, q_network, epsilon)
            next_state, reward, terminated, truncated, _ = env.step(action)
            # The episode is over on either signal; ignoring `truncated` would
            # keep stepping an environment that expects a reset.
            done = terminated or truncated
            next_state = preprocess_frame(next_state)
            stacked_next_state = stack_frames(frames, next_state)
            # Only a true terminal state zeroes the bootstrap term; a
            # time-limit truncation still has a meaningful next-state value.
            memory.append((stacked_state, action, reward, stacked_next_state, terminated))
            stacked_state = stacked_next_state
            total_reward += reward

            if len(memory) >= batch_size:
                batch = random.sample(memory, batch_size)
                states, actions, rewards, next_states, dones = zip(*batch)

                # np.array first: building a tensor from a tuple of arrays is slow.
                states = torch.FloatTensor(np.array(states))
                actions = torch.LongTensor(actions).unsqueeze(1)
                rewards = torch.FloatTensor(rewards).unsqueeze(1)
                next_states = torch.FloatTensor(np.array(next_states))
                dones = torch.FloatTensor(dones).unsqueeze(1)

                q_values = q_network(states).gather(1, actions)
                # The TD target is a fixed regression label: compute it without
                # autograd so the loss does not back-propagate through it.
                with torch.no_grad():
                    next_q_values = q_network(next_states).max(1)[0].unsqueeze(1)
                    target_q_values = rewards + gamma * next_q_values * (1 - dones)

                loss = loss_fn(q_values, target_q_values)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        epsilon = max(epsilon_end, epsilon * epsilon_decay)
        print(f"Episode {episode + 1}/{num_episodes}, Total Reward: {total_reward:.2f}")

    return q_network

In [9]:

try:
    q_network = train_dqn(env)

    # Save the trained model
    torch.save(q_network.state_dict(), 'seaquest_dqn.pth')
finally:
    # Close the environment even when training is interrupted or fails, so
    # its resources are not left open.
    env.close()

  states = torch.FloatTensor(states)


KeyboardInterrupt: 

In [11]:
# Explicitly close the environment (the recorded output above shows the
# training run ending in a KeyboardInterrupt).
env.close()