<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Complete_DQN_Code_with_Gym_API_Compatibility.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install numpy==1.23.5

In [None]:
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque

# DQN Network
class DQN(nn.Module):
    """Fully connected Q-network with two 128-unit ReLU hidden layers.

    Args:
        state_size: Number of input features expected by the first layer.
        action_size: Number of values produced by the output layer.
    """

    def __init__(self, state_size, action_size):
        super().__init__()
        hidden_units = 128
        self.fc1 = nn.Linear(state_size, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc3 = nn.Linear(hidden_units, action_size)

    def forward(self, x):
        """Return the output layer's values for the input tensor ``x``."""
        hidden = x
        # Both hidden layers use ReLU; the output layer stays linear.
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))
        return self.fc3(hidden)

# ε-greedy action selection
def select_action(state, policy_net, epsilon, action_size):
    """Pick an action with an epsilon-greedy rule.

    Args:
        state: Input passed to ``policy_net`` when acting greedily.
        policy_net: Callable returning one row of action values per input row.
        epsilon: Probability threshold for choosing a random action.
        action_size: Number of discrete actions to draw from when exploring.

    Returns:
        A ``torch.long`` tensor of shape ``(1, 1)`` holding the action index.
    """
    explore = random.random() < epsilon
    if not explore:
        # Greedy branch: no gradient tracking is needed for action selection.
        with torch.no_grad():
            q_values = policy_net(state)
        return q_values.argmax(dim=1).view(1, 1)

    random_index = random.randrange(action_size)
    return torch.tensor([[random_index]], dtype=torch.long)

# Replay Memory
class ReplayMemory:
    """Bounded FIFO store of transitions with uniform random sampling.

    Args:
        capacity: Maximum number of transitions retained.
    """

    def __init__(self, capacity):
        # A deque with ``maxlen`` discards its oldest entry once it is full.
        self.memory = deque([], maxlen=capacity)

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.memory)

    def push(self, transition):
        """Append ``transition`` to the newest end of the buffer."""
        buffer = self.memory
        buffer.append(transition)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen at random.

        Raises:
            ValueError: If ``batch_size`` exceeds the number of stored items
                (raised by ``random.sample``).
        """
        return random.sample(self.memory, k=batch_size)

# Optimize model
def optimize_model(memory, policy_net, target_net, optimizer, batch_size, gamma):
    """Perform one gradient step on ``policy_net`` from a sampled minibatch.

    Does nothing until ``memory`` holds at least ``batch_size`` transitions.
    Each transition is a tuple ``(state, action, reward, next_state, done)``
    whose first four entries are tensors concatenated along dim 0 and whose
    last entry is a flag that disables bootstrapping when truthy.

    Args:
        memory: Object supporting ``len()`` and ``sample(batch_size)``.
        policy_net: Network being trained; its outputs are indexed by action.
        target_net: Network supplying next-state values for the target.
        optimizer: Optimizer over ``policy_net``'s parameters.
        batch_size: Number of transitions sampled per update.
        gamma: Discount applied to the next-state value.

    Returns:
        None.
    """
    if len(memory) < batch_size:
        return

    # Transpose the list of transitions into one tuple per field.
    states, actions, rewards, next_states, dones = zip(*memory.sample(batch_size))

    states = torch.cat(states)
    actions = torch.cat(actions)
    rewards = torch.cat(rewards)
    next_states = torch.cat(next_states)
    # 1.0 where the episode continues, 0.0 where the done flag is set.
    keep_mask = 1 - torch.tensor(dones, dtype=torch.float32).unsqueeze(1)

    predicted_q = policy_net(states).gather(1, actions)
    # Highest target-network value per next state, cut off from autograd.
    bootstrap_q = target_net(next_states).max(dim=1, keepdim=True).values.detach()
    td_target = rewards + keep_mask * gamma * bootstrap_q

    loss = nn.functional.mse_loss(predicted_q, td_target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Set up environment and networks
env = gym.make("CartPole-v1")
state_size = env.observation_space.shape[0]  # length of the observation vector
action_size = env.action_space.n  # number of discrete actions

# The policy network is the one being trained; the target network starts as
# an exact copy and is switched to evaluation mode.
policy_net = DQN(state_size, action_size)
target_net = DQN(state_size, action_size)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.Adam(policy_net.parameters(), lr=0.001)
memory = ReplayMemory(10_000)

# Hyperparameters
batch_size, gamma = 64, 0.99
# Exploration rate: starts at 1.0, is multiplied by the decay factor after
# every episode, and is never allowed below the minimum.
epsilon, epsilon_decay, epsilon_min = 1.0, 0.995, 0.01
# Episodes between target-network syncs, and total episodes to run.
target_update_freq, num_episodes = 10, 300

# Training loop
try:
    for episode in range(num_episodes):
        # Newer Gym/Gymnasium returns (observation, info) from reset();
        # older versions return the observation alone.
        reset_result = env.reset()
        state_np = reset_result[0] if isinstance(reset_result, tuple) else reset_result
        state = torch.from_numpy(np.array(state_np)).unsqueeze(0).float()
        total_reward = 0

        # Hard cap of 500 steps per episode.
        for _ in range(500):
            action = select_action(state, policy_net, epsilon, action_size)
            step_result = env.step(action.item())
            if len(step_result) == 5:
                # New API: termination and truncation are reported separately.
                next_state_np, reward, terminated, truncated, _ = step_result
            else:
                # Legacy API: a single `done` flag; the TimeLimit wrapper marks
                # time-limit endings via info["TimeLimit.truncated"].
                next_state_np, reward, done, info = step_result
                truncated = bool(info.get("TimeLimit.truncated", False))
                terminated = done and not truncated
            done = terminated or truncated

            next_state = torch.from_numpy(np.array(next_state_np)).unsqueeze(0).float()
            reward_tensor = torch.tensor([[reward]], dtype=torch.float32)

            # Store only true termination as the terminal flag. A time-limit
            # truncation is not a terminal state, so the TD target must still
            # bootstrap from next_state in that case.
            memory.push((state, action, reward_tensor, next_state, bool(terminated)))
            state = next_state
            total_reward += reward

            optimize_model(memory, policy_net, target_net, optimizer, batch_size, gamma)

            # The episode itself ends on either termination or truncation.
            if done:
                break

        # Decay exploration once per episode, never below epsilon_min.
        epsilon = max(epsilon_min, epsilon * epsilon_decay)

        # Periodically copy the trained weights into the target network.
        if episode % target_update_freq == 0:
            target_net.load_state_dict(policy_net.state_dict())

        print(f"Episode {episode+1}, Total Reward: {total_reward:.1f}, Epsilon: {epsilon:.3f}")
finally:
    # Release the environment even if training is interrupted or raises.
    env.close()