In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import gymnasium as gym

#### Define the neural network

In [None]:
class DQN(nn.Module):
    """Small fully connected Q-network.

    Maps an observation vector of size ``n_obs`` to ``n_act`` outputs, one per
    action. Two hidden layers of 128 units each use ReLU; the output layer is
    linear (no activation).
    """

    def __init__(self, n_obs, n_act):
        super().__init__()
        width = 128  # size of both hidden layers
        self.l1 = nn.Linear(n_obs, width)
        self.l2 = nn.Linear(width, width)
        self.l3 = nn.Linear(width, n_act)

    def forward(self, x):
        """Return the network output for ``x`` (last dimension must be ``n_obs``)."""
        features = x
        # Both hidden layers share the same "linear then ReLU" pattern.
        for hidden_layer in (self.l1, self.l2):
            features = F.relu(hidden_layer(features))
        q_values = self.l3(features)
        return q_values

#### Initialize memory, environment and neural network

In [None]:
# Replay buffer: each entry is a (state, action, reward, next_state, done) tuple.
memory = []

# Create the environment and reset once so the first observation can be used
# to size the network's input layer.
env = gym.make("CartPole-v1", render_mode="human")
state, info = env.reset()
n_obs = len(state)
n_act = env.action_space.n

# The Deep Q-Network, together with the criterion and optimiser used to fit it.
net = DQN(n_obs, n_act)
loss = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=1e-3)

#### Define the parameters

In [None]:
gamma = 0.95           # discount rate applied to the next-state term of the target
epsilon = 1.0          # exploration rate: probability of picking a random action
eps_decay = 0.995      # epsilon is multiplied by this factor at the end of replay()
batch_size = 128       # number of stored transitions sampled per replay() call
episodes = 1_000       # number of episodes the training loop runs
max_mem_size = 10_000  # buffer length at which the oldest transition is dropped

#### Define prediction and training functions

In [None]:
def act(epsilon, state, episode, explore_until=500):
    """Choose an action with an epsilon-greedy policy.

    Uses the module-level ``env`` (for random actions) and ``net`` (for greedy
    actions).

    Args:
        epsilon: Probability of taking a random action while exploration is
            still enabled.
        state: Current observation as a NumPy array (or anything
            ``torch.as_tensor`` accepts). It is converted to float32 before
            being fed to the network.
        episode: Index of the current episode.
        explore_until: Random actions are only taken while
            ``episode < explore_until``; from that episode on the policy is
            purely greedy. Defaults to 500.

    Returns:
        int: The selected action index.
    """
    if np.random.random() < epsilon and episode < explore_until:
        return int(env.action_space.sample())
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        act_vals = net(torch.as_tensor(state, dtype=torch.float32))
    # .item() yields a plain Python int rather than a 0-d array.
    return int(torch.argmax(act_vals).item())

def replay(memory, batch_size, epsilon, eps_decay):
    """Train the network on randomly sampled transitions and decay epsilon.

    Uses the module-level ``net``, ``gamma``, ``loss`` and ``optimizer``.
    ``batch_size`` indices are drawn uniformly with replacement, and one
    optimiser step is taken per sampled transition.

    For each transition the regression target for the taken action is
    ``reward`` when ``done`` is true, otherwise
    ``reward + gamma * max_a Q(next_state, a)``. The targets for the other
    actions are the network's own current outputs, so they contribute no error.

    Args:
        memory: Sequence of ``(state, action, reward, next_state, done)``
            tuples, with ``state`` and ``next_state`` as tensors.
        batch_size: Number of transitions to sample.
        epsilon: Current exploration rate.
        eps_decay: Multiplicative decay applied to ``epsilon``.

    Returns:
        The decayed exploration rate ``epsilon * eps_decay``. If ``memory`` is
        empty nothing is trained and ``epsilon`` is returned unchanged.
    """
    if len(memory) == 0:
        # Nothing to sample from; np.random.randint(0, ...) would raise.
        return epsilon

    indices = np.random.randint(len(memory), size=batch_size)
    for ind in indices:
        state, action, reward, next_state, done = memory[ind]

        target = float(reward)
        if not done:
            # Bootstrap from the best next-state VALUE (max), not the index of
            # the best action (argmax). The target is treated as a constant,
            # so it is computed without tracking gradients.
            with torch.no_grad():
                target += gamma * net(next_state).max().item()

        optimizer.zero_grad()
        out = net(state)
        # Start from a gradient-free copy of the prediction and overwrite only
        # the entry for the action that was actually taken.
        target_f = out.detach().clone()
        target_f[int(action)] = target
        l = loss(out, target_f)
        l.backward()
        optimizer.step()

    return epsilon * eps_decay

#### Execute the loop

In [None]:
# Training loop: play one episode at a time, store every transition in
# `memory`, and run one replay() pass after the episode once the buffer holds
# more than `batch_size` transitions.
for i in range(episodes):
    state, info = env.reset()
    done = False
    score = 0
    while not done:
        action = act(epsilon, state, i)
        # Gymnasium's step() returns (obs, reward, terminated, truncated, info).
        next_state, reward, terminated, truncated, _ = env.step(action)
        # The episode ends on either signal; otherwise a time-limit truncation
        # would be ignored and the loop would keep stepping a finished episode.
        done = terminated or truncated
        score += float(reward)
        # Penalise only a genuine failure, not hitting the time limit.
        reward = reward if not terminated else -10
        # Store `terminated` (not `done`) as the terminal flag so replay()
        # still bootstraps from next_state when the episode was merely truncated.
        memory.append((torch.from_numpy(state), action, reward, torch.from_numpy(next_state), terminated))
        # Drop the oldest transition once the buffer exceeds its capacity.
        if len(memory) > max_mem_size:
            memory.pop(0)
        state = next_state
    print("Episode:", i, " Score:", score)
    if len(memory) > batch_size:
        epsilon = replay(memory, batch_size, epsilon, eps_decay)

In [None]:
# Shut the environment down now that the training loop has finished.
env.close()