In [24]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from collections import deque
import matplotlib.pyplot as plt

In [25]:
# Training environment, created without a render_mode (a separate rendering
# copy is built later for the evaluation rollout).
env = gym.make("CartPole-v1")

In [26]:
class DQN(nn.Module):
  """Fully connected Q-network: two hidden ReLU layers of width 128.

  Args:
    input_dimension: number of features in one observation.
    output_dimension: number of outputs (one Q-value per action).
  """

  def __init__(self,input_dimension,output_dimension):
    super().__init__()
    hidden_width = 128
    # Layers are created in input-to-output order and wrapped in a single
    # Sequential stored under `model`.
    layers = [
        nn.Linear(input_dimension, hidden_width),
        nn.ReLU(),
        nn.Linear(hidden_width, hidden_width),
        nn.ReLU(),
        nn.Linear(hidden_width, output_dimension),
    ]
    self.model = nn.Sequential(*layers)

  def forward(self,x):
    """Return the network output for `x` (last dim must equal input_dimension)."""
    q_values = self.model(x)
    return q_values

In [27]:
# Hyperparameters for the DQN run; every tunable value lives in this cell.
learning_rate = 0.001      # step size handed to the Adam optimizer
discount_factor = 0.99     # gamma applied to the bootstrapped next-state value
exploration = 1.0          # current epsilon; the training loop decays it once per episode
exploration_min = 0.01     # floor below which epsilon is never decayed
exploration_decay = 0.995  # multiplicative factor applied to epsilon after each episode
batch_size = 64            # transitions sampled from the replay buffer per update
target_update_freq = 1000  # environment steps between target-network syncs
memory_size = 10000        # maximum number of transitions kept in the replay buffer
episodes = 1000            # number of training episodes
epsiodes = episodes        # misspelled original name, kept because the training loop reads it

In [28]:
# Network sizes are read from the environment instead of being hard-coded.
input_dimension = env.observation_space.shape[0]  # length of one observation vector
output_dimension = env.action_space.n  # number of discrete actions

In [29]:
# Two networks with the same architecture: policy_network is the one the
# optimizer trains; target_network supplies the next-state values used to
# build the learning targets in optimize_model().
policy_network = DQN(input_dimension,output_dimension)
target_network = DQN(input_dimension,output_dimension)

In [30]:
# Start the target network as an exact copy of the freshly initialised policy network.
target_network.load_state_dict(policy_network.state_dict())
# Switch the target network to evaluation mode. eval() returns the module, so as
# the last expression of the cell it is what produces the printed model summary.
target_network.eval()

DQN(
  (model): Sequential(
    (0): Linear(in_features=4, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=2, bias=True)
  )
)

In [31]:
# Adam is given only the policy network's parameters; the target network is
# never updated by the optimizer.
optimizer = optim.Adam(policy_network.parameters(), lr=learning_rate)
# Replay buffer: a bounded deque, so once memory_size transitions are stored
# each append discards the oldest one.
memory = deque(maxlen=memory_size)

In [32]:
def select_action(state,epsilon):
  """Choose an action epsilon-greedily with respect to `policy_network`.

  Args:
    state: a single observation (array-like convertible to a float tensor).
    epsilon: probability of taking a random action instead of the greedy one.

  Returns:
    A sample from `env.action_space` when exploring, otherwise the index of
    the largest Q-value predicted by `policy_network` for `state`.
  """
  # Compare against the `epsilon` argument. The previous version read the
  # module-level `exploration` here, so the value passed by the caller was
  # silently ignored (e.g. epsilon=0 could not force greedy behaviour).
  if random.random() < epsilon:
    return env.action_space.sample()

  # Add a leading batch dimension of 1 before the forward pass.
  state = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
  with torch.no_grad():
    q_values = policy_network(state)

  return torch.argmax(q_values).item()

In [33]:
def optimize_model():
  """Run one gradient step of the DQN loss on a random minibatch from `memory`.

  Does nothing until the replay buffer holds at least `batch_size`
  transitions. Reads the module-level `memory`, `batch_size`,
  `discount_factor`, `policy_network`, `target_network` and `optimizer`;
  updates `policy_network` in place and returns None.
  """
  if len(memory) < batch_size:
    return

  batch = random.sample(memory, batch_size)
  states, actions, rewards, next_states, dones = zip(*batch)

  # Stack each field into a single ndarray before handing it to torch;
  # building a tensor directly from a tuple of arrays is very slow.
  states = torch.as_tensor(np.array(states), dtype=torch.float32)
  actions = torch.as_tensor(np.array(actions), dtype=torch.int64).unsqueeze(1)
  rewards = torch.as_tensor(np.array(rewards), dtype=torch.float32)
  next_states = torch.as_tensor(np.array(next_states), dtype=torch.float32)
  dones = torch.as_tensor(np.array(dones), dtype=torch.float32)

  # Q(s, a) for the action actually taken. gather() yields shape (B, 1);
  # squeeze to (B,) so it lines up element-wise with the targets. Without
  # this the loss broadcasts (B, 1) against (B,) into a (B, B) matrix and
  # every prediction is regressed toward every target in the batch.
  q_values = policy_network(states).gather(1, actions).squeeze(1)

  with torch.no_grad():
    # max(1) returns (values, indices); keep the values.
    max_next_q_values = target_network(next_states).max(1)[0]
    # (1 - dones) removes the bootstrapped term for transitions flagged done.
    target_q_values = rewards + discount_factor * max_next_q_values * (1 - dones)

  loss = nn.functional.mse_loss(q_values, target_q_values)

  optimizer.zero_grad()
  loss.backward()
  optimizer.step()


In [34]:
rewards_per_episode = []  # total reward of each completed training episode, in order
steps_done = 0  # environment steps taken so far; drives the target-network sync schedule

In [35]:
# Main DQN training loop: one outer iteration per episode, one inner
# iteration per environment step.
for episode in range(epsiodes):
  state, _ = env.reset()
  episode_reward = 0
  done = False

  while not done:
    action = select_action(state, exploration)

    next_state, reward, terminated, truncated, _ = env.step(action)

    # Either signal ends the episode loop ...
    done = terminated or truncated

    # ... but only a genuine termination is stored as the done flag. A
    # truncation (time limit) is not a terminal state, so optimize_model()
    # must still bootstrap from next_state for those transitions.
    memory.append((state,action,reward,next_state,terminated))
    state = next_state
    episode_reward += reward

    optimize_model()

    # Periodically copy the policy weights into the target network.
    if steps_done % target_update_freq == 0:
      target_network.load_state_dict(policy_network.state_dict())

    steps_done += 1

  # Decay epsilon once per episode, never below exploration_min.
  exploration = max(exploration_min, exploration_decay * exploration)
  rewards_per_episode.append(episode_reward)


In [None]:
# Greedy rollout of the trained policy for one episode in a rendering copy of
# the environment (no exploration, no learning).
render_env = gym.make("CartPole-v1",render_mode="human")
try:
  state, _ = render_env.reset()
  done = False

  while not done:
    # Add a batch dimension of 1 for the forward pass.
    state_tensor = torch.FloatTensor(state).unsqueeze(0)

    with torch.no_grad():
      action = torch.argmax(policy_network(state_tensor)).item()

    state, reward, terminated, truncated, _ = render_env.step(action)
    done = truncated or terminated
finally:
  # Always close the environment, even if the rollout raises or the cell is
  # interrupted, so the render window is not left open.
  render_env.close()