In [1]:
import gym, random
import numpy as np
import torch, os
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from collections import Counter
from collections import deque

Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
Users of this version of Gym should be able to simply replace 'import gym' with 'import gymnasium as gym' in the vast majority of cases.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.


In [2]:
# Create the CartPole-v1 environment that the rest of the notebook trains on.
env = gym.make("CartPole-v1", render_mode='rgb_array')
# reset() returns a pair; only the first element (the initial observation) is kept.
state, _ = env.reset()
# Show the first observation and both spaces; the network input/output sizes
# are later read from env.observation_space and env.action_space.
print("Initial state:", state)
print("Observation space:", env.observation_space)
print("Action space:", env.action_space)

Initial state: [ 0.04262742 -0.02276057  0.01158579  0.04954371]
Observation space: Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)
Action space: Discrete(2)


In [3]:
class QNetwork(nn.Module):
    """Fully connected network that maps a state vector to one value per action.

    Layer widths are ``num_features -> hidden_features -> 2 * hidden_features
    -> num_actions``. ReLU follows the first two layers; the last layer is
    left linear.
    """

    def __init__(self, num_features=4, num_actions=2, hidden_features=128) -> None:
        super().__init__()
        wide_features = hidden_features * 2
        # The attribute names become the state_dict keys, so they are kept as fc1/fc2/fc3.
        self.fc1 = nn.Linear(num_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, wide_features)
        self.fc3 = nn.Linear(wide_features, num_actions)

    def forward(self, x):
        """Run ``x`` through the three layers and return the un-activated output."""
        hidden = torch.relu(self.fc1(x))
        hidden = torch.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [4]:
class ReplayBuffer():
    """Bounded FIFO store of transitions with uniform random mini-batch sampling."""

    def __init__(self, max_buffer_size=10000, batch_size=16):
        self.batch_size = batch_size
        # A deque with maxlen discards its oldest entry once capacity is reached.
        self.buffer = deque(maxlen=max_buffer_size)

    def __len__(self):
        """Number of transitions currently held."""
        return len(self.buffer)

    def add_sample(self, element: tuple):
        """Append one transition to the end of the buffer."""
        self.buffer.append(element)

    def get_batch(self):
        """Return a list of transitions to train on.

        If more than ``batch_size`` transitions are stored, ``batch_size`` of
        them are drawn without replacement. Otherwise every stored transition
        is returned, in insertion order.
        """
        stored = self.buffer
        if len(stored) <= self.batch_size:
            return list(stored)
        return random.sample(stored, k=self.batch_size)

In [5]:
# --- Reproducibility --------------------------------------------------------
# Seed every RNG the notebook draws from *before* the networks are built, so
# weight initialisation, epsilon-greedy draws and replay sampling repeat on a
# fresh Restart & Run All.
SEED = 42
random.seed(SEED)        # select_action and ReplayBuffer.get_batch use `random`
np.random.seed(SEED)     # record_video uses np.random for its epsilon draw
torch.manual_seed(SEED)  # nn.Linear weight initialisation
# Seeds the environment's own RNG. Per the Gym/Gymnasium reset API, later
# reset() calls made without a seed keep using this seeded generator.
env.reset(seed=SEED)

# --- Hyperparameters --------------------------------------------------------
N_BUFFER_SIZE = 10000          # replay buffer capacity (transitions)
N_TRAINING_STEPS = 50000       # total environment steps taken by the training loop
N_START_LEARNING = 5000        # gradient updates begin only after this many steps
TARGET_UPDATE_FREQUENCY = 128  # steps between copies of q_net into target_q_net
LEARNING_RATE = 5e-4
GAMMA = 0.99                   # discount factor used in the TD target
BATCH_SIZE = 128
# NOTE(review): EPSILON is not referenced by any visible cell; exploration
# starts from `epsilon` below. Kept so that any external reference still works.
EPSILON = 0.95
DECAY = 0.995                  # multiplicative epsilon decay, applied every 10th learning step
MIN_EPSILON = 0.01             # floor for epsilon
epsilon = 1.0                  # current exploration rate, mutated by the training loop

# --- Problem dimensions, read from the environment ---------------------------
state_size = env.observation_space.shape[0]
action_size = env.action_space.n

# --- Online network, target network, replay buffer and optimiser -------------
q_net = QNetwork(num_features=state_size, num_actions=action_size)
target_q_net = QNetwork(num_features=state_size, num_actions=action_size)
buffer = ReplayBuffer(batch_size=BATCH_SIZE, max_buffer_size=N_BUFFER_SIZE)

# Start the target network as an exact copy of the online network.
target_q_net.load_state_dict(q_net.state_dict())
optimizer = torch.optim.AdamW(q_net.parameters(), lr=LEARNING_RATE)

In [6]:
def select_action(state, q_net, action_size, epsilon=0.5):
    """Choose an action with an epsilon-greedy rule.

    Args:
        state: NumPy array holding the current observation; it is converted
            with ``torch.from_numpy`` and given a leading batch axis.
        q_net: Callable that maps that float tensor to per-action scores.
        action_size: Number of discrete actions; random picks come from
            ``range(action_size)``.
        epsilon: Probability of taking a uniformly random action instead of
            the highest-scoring one.

    Returns:
        The chosen action index as a Python int.
    """
    explore = random.random() < epsilon
    if explore:
        return random.randrange(action_size)

    # Greedy branch: score the single observation without tracking gradients.
    observation = torch.from_numpy(state).float().unsqueeze(0)
    with torch.no_grad():
        q_values = q_net(observation)
    return torch.argmax(q_values, dim=-1).item()

In [7]:
# DQN training loop: act epsilon-greedily, store each transition in the replay
# buffer and, once N_START_LEARNING steps have passed, take one gradient step
# per environment step on a sampled mini-batch.
state, _ = env.reset()
episode_reward = 0
episode = 0
losses = []

# Nothing inside the loop switches these modes, so they are set once up front
# instead of on every learning step.
q_net.train()
target_q_net.eval()

for step in range(N_TRAINING_STEPS):
    action = select_action(
        state=state,
        q_net=q_net,
        action_size=action_size,
        epsilon=epsilon
    )

    new_state, reward, terminated, truncated, info = env.step(action=action)
    # `done` only decides when to reset the environment.
    done = terminated or truncated
    episode_reward += reward
    # Store `terminated`, not `done`: a time-limit truncation is not a terminal
    # state, so the TD target must still bootstrap from `new_state` there.
    buffer.add_sample((state, action, reward, new_state, terminated))

    if step % 100 == 0 and step < N_BUFFER_SIZE:
        print(f"Buffer filled so far: {len(buffer)}")

    if step > N_START_LEARNING and len(buffer) > BATCH_SIZE:
        if step % 10 == 0:
            epsilon = max(MIN_EPSILON, epsilon * DECAY)

        # Transpose the list of transitions into per-field tuples and convert
        # each field in one shot rather than building a tensor per sample.
        batch = buffer.get_batch()
        batch_states, batch_actions, batch_rewards, batch_next_states, batch_terminals = zip(*batch)
        states = torch.as_tensor(np.asarray(batch_states, dtype=np.float32))
        # gather() needs an int64 index with one column per row.
        actions = torch.as_tensor(np.asarray(batch_actions, dtype=np.int64)).unsqueeze(1)
        rewards = torch.as_tensor(np.asarray(batch_rewards, dtype=np.float32))
        next_states = torch.as_tensor(np.asarray(batch_next_states, dtype=np.float32))
        terminals = torch.as_tensor(np.asarray(batch_terminals, dtype=np.float32))

        # Q(s, a) of the online network for the actions that were actually taken.
        predicted_q_values = q_net(states)
        extracted_q_values = predicted_q_values.gather(1, actions).squeeze(1)

        # Bootstrap value comes from the target network; no gradient flows through it.
        with torch.no_grad():
            max_next_q = target_q_net(next_states).max(dim=1)[0]

        # (1 - terminals) removes the bootstrap term only for true terminal states.
        target = rewards + GAMMA * max_next_q * (1 - terminals)

        loss = F.smooth_l1_loss(extracted_q_values, target)
        optimizer.zero_grad()
        loss.backward()

        # Cap the global gradient norm before the optimiser step.
        torch.nn.utils.clip_grad_norm_(q_net.parameters(), max_norm=1.0)

        optimizer.step()
        losses.append(loss.item())

        if step % TARGET_UPDATE_FREQUENCY == 0:
            # Sync the target network with the online network and log progress.
            target_q_net.load_state_dict(q_net.state_dict())
            # The slice already yields the whole list when it has fewer than 100 entries.
            avg_loss = np.mean(losses[-100:])
            print(f"[STEP]: {step}, [LOSS]: {loss.item():.4f}, [AVG_LOSS]: {avg_loss:.4f}, [EPISODE]: {episode}, [REWARD]: {episode_reward}, [EPSILON]: {epsilon:.4f}")

    if done:
        state, _ = env.reset()
        episode += 1
        episode_reward = 0
    else:
        state = new_state

Buffer filled so far: 1
Buffer filled so far: 101
Buffer filled so far: 201
Buffer filled so far: 301
Buffer filled so far: 401
Buffer filled so far: 501
Buffer filled so far: 601
Buffer filled so far: 701
Buffer filled so far: 801
Buffer filled so far: 901
Buffer filled so far: 1001
Buffer filled so far: 1101
Buffer filled so far: 1201
Buffer filled so far: 1301
Buffer filled so far: 1401
Buffer filled so far: 1501
Buffer filled so far: 1601
Buffer filled so far: 1701
Buffer filled so far: 1801
Buffer filled so far: 1901
Buffer filled so far: 2001
Buffer filled so far: 2101
Buffer filled so far: 2201
Buffer filled so far: 2301
Buffer filled so far: 2401
Buffer filled so far: 2501
Buffer filled so far: 2601
Buffer filled so far: 2701
Buffer filled so far: 2801
Buffer filled so far: 2901
Buffer filled so far: 3001
Buffer filled so far: 3101
Buffer filled so far: 3201
Buffer filled so far: 3301
Buffer filled so far: 3401
Buffer filled so far: 3501
Buffer filled so far: 3601
Buffer filled

In [8]:
def record_video(q_net, video_folder=r"B:\Pytorch\RL\videos", episodes=1, epsilon=0.0, env_name='CartPole-v1'):
    """Roll out ``q_net`` in a fresh environment and record the episodes to video.

    Args:
        q_net: Callable mapping a ``(1, state_dim)`` float tensor to per-action
            values; the arg-max action is taken.
        video_folder: Directory handed to ``RecordVideo`` for the output files.
            NOTE(review): the default is an absolute, machine-specific path;
            pass a project-relative folder when running elsewhere.
        episodes: Number of episodes to play; every one of them is recorded.
        epsilon: Probability of replacing the greedy action with a random one.
        env_name: Environment id passed to ``gym.make``.

    The per-episode reward and the output folder are printed. The environment
    is closed even if the rollout raises.
    """
    env = gym.wrappers.RecordVideo(
        gym.make(env_name, render_mode="rgb_array"),
        video_folder=video_folder,
        name_prefix="cartpole_dqn",
        # Without an explicit trigger the wrapper records only some episodes
        # (its default schedule); record all that were asked for.
        episode_trigger=lambda episode_id: True,
    )

    try:
        for ep in range(episodes):
            state, _ = env.reset()
            terminated, truncated = False, False
            total_reward = 0

            while not (terminated or truncated):
                if np.random.rand() < epsilon:
                    action = env.action_space.sample()
                else:
                    with torch.no_grad():
                        state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
                        q_values = q_net(state_tensor)
                        action = q_values.argmax(dim=1).item()

                state, reward, terminated, truncated, _ = env.step(action)
                total_reward += reward

            print(f"[EPISODE {ep+1}] Reward: {total_reward}")
    finally:
        # Always close the wrapper so it is not left open if the rollout raises.
        env.close()

    print(f"Video saved in '{video_folder}'")


In [9]:
# Record the trained online network with the function defaults
# (episodes=1, epsilon=0.0, i.e. a single fully greedy episode).
record_video(q_net)

  logger.warn(


MoviePy - Building video B:\Pytorch\RL\videos\cartpole_dqn-episode-0.mp4.
MoviePy - Writing video B:\Pytorch\RL\videos\cartpole_dqn-episode-0.mp4



                                                                          

MoviePy - Done !
MoviePy - video ready B:\Pytorch\RL\videos\cartpole_dqn-episode-0.mp4
[EPISODE 1] Reward: 500.0
Video saved in 'B:\Pytorch\RL\videos'




In [10]:
# Persist only the online network's parameters (state_dict), not the whole module.
# NOTE(review): absolute, machine-specific path; consider a project-relative location.
MODEL_FILE_PATH = r"B:\Pytorch\RL\models\dqn_cartpole.pth"

# torch.save does not create missing parent directories, so make sure the
# target folder exists first. The guard covers a path with no directory part.
model_dir = os.path.dirname(MODEL_FILE_PATH)
if model_dir:
    os.makedirs(model_dir, exist_ok=True)

torch.save(q_net.state_dict(), MODEL_FILE_PATH)
print(f"Model state_dict saved to {MODEL_FILE_PATH}")

Model state_dict saved to B:\Pytorch\RL\models\dqn_cartpole.pth
