In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gymnasium as gym
from collections import deque
import random

In [None]:
# Parameters
# -- environment / reproducibility
env_name = "CartPole-v1"
seed = 42

# -- DQN schedule
timesteps_before_update = 1_000   # environment steps between gradient updates (training loop)
updates_before_target_update = 5  # gradient updates between target-network syncs

# -- exploration and optimisation
epsilon = 0.10                    # probability of taking a random action instead of argmax-Q
gamma = 0.99                      # discount factor for the bootstrap target
lr = 0.01                         # Adam learning rate

# -- experience replay
buffer_size = 100000              # maximum number of transitions retained
batch_fraction = 0.20             # share of the buffer sampled per update (20%)

In [None]:
# Environment and seeds
env = gym.make(env_name)
# Seed every RNG the training loop draws from: numpy (epsilon test), torch (weight init),
# Python `random` (replay sampling) and the action space (exploratory actions). Gymnasium
# does not seed the action space through `env.reset(seed=...)`, so it is seeded explicitly.
np.random.seed(seed)
torch.manual_seed(seed)
random.seed(seed)
env.action_space.seed(seed)
obs_dim = env.observation_space.shape[0]
n_actions = env.action_space.n

In [None]:
class QNet(nn.Module):
    """Two-layer MLP mapping an observation vector to one Q-value per discrete action.

    Args:
        obs_dim: length of the flat observation vector (input width).
        n_actions: number of discrete actions (output width).
        hidden_dim: width of the single hidden layer; defaults to 128.
    """

    def __init__(self, obs_dim, n_actions, hidden_dim=128):
        super().__init__()
        # The attribute name `net` determines the state_dict keys ("net.0.weight", ...);
        # keep it stable so weights can be copied between QNet instances.
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions),
        )

    def forward(self, x):
        """Return Q-values with last dimension ``n_actions`` for inputs whose last dim is ``obs_dim``."""
        return self.net(x)

def copy_weights(target, source):
    """Overwrite ``target``'s parameters and buffers in place with those of ``source``.

    Both modules must have matching state_dict keys and tensor shapes.
    """
    source_state = source.state_dict()
    target.load_state_dict(source_state)

In [None]:
class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions for experience replay.

    Once ``maxlen`` transitions are stored, each new one evicts the oldest.
    """

    def __init__(self, maxlen):
        self.buffer = deque(maxlen=maxlen)

    def add(self, state, action, reward, next_state, done):
        """Append one ``(state, action, reward, next_state, done)`` transition tuple."""
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, fraction):
        """Draw a random subset of stored transitions without replacement.

        Args:
            fraction: share of the current buffer to draw, in (0, 1].

        Returns:
            A list of ``max(1, int(len(self) * fraction))`` transitions, or ``[]`` when the
            buffer is empty. The floor of one avoids an empty batch for small buffers, which
            callers cannot unpack with ``zip(*batch)``.

        Raises:
            ValueError: if ``fraction`` is not in (0, 1].
        """
        if not 0 < fraction <= 1:
            raise ValueError(f"fraction must be in (0, 1], got {fraction!r}")
        if not self.buffer:
            return []
        n = max(1, int(len(self.buffer) * fraction))
        # Snapshot to a list: random.sample indexes the population, and deque indexing
        # is O(n) away from the ends.
        return random.sample(list(self.buffer), n)

    def __len__(self):
        return len(self.buffer)

buffer = ReplayBuffer(maxlen=buffer_size)  # shared replay memory for the training loop

In [None]:
# Main (trained) network and target network used for the bootstrap targets
q_net = QNet(obs_dim, n_actions)
target_q_net = QNet(obs_dim, n_actions)
target_q_net.load_state_dict(q_net.state_dict())  # start both networks from identical weights
loss_fn = nn.MSELoss()
optimizer = optim.Adam(params=q_net.parameters(), lr=lr)

In [None]:
# Training loop
# Runs until interrupted (Kernel -> Interrupt) or until `max_timesteps` env steps have been
# taken. The KeyboardInterrupt handler lets the cell exit cleanly so later cells can use q_net.
max_timesteps = None  # set to an int (e.g. 50_000) so Restart & Run All terminates


def _select_action(obs):
    """Epsilon-greedy action for a single observation using the online network."""
    if np.random.rand() < epsilon:
        return env.action_space.sample()
    with torch.no_grad():
        obs_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)
        return q_net(obs_tensor).argmax(dim=1).item()


def _dqn_update(batch):
    """Take one optimizer step on the TD(0) MSE loss for a batch of transitions.

    Each transition is ``(state, action, reward, next_state, terminated)``. The bootstrap term
    is masked only on true termination; time-limit truncation still bootstraps.
    Returns the scalar loss value.
    """
    states, actions, rewards, next_states, terminals = zip(*batch)
    states = torch.tensor(np.array(states), dtype=torch.float32)
    actions = torch.tensor(actions, dtype=torch.long)
    rewards = torch.tensor(rewards, dtype=torch.float32)
    next_states = torch.tensor(np.array(next_states), dtype=torch.float32)
    terminals = torch.tensor(terminals, dtype=torch.float32)

    # Q(s, a) for the actions actually taken
    q_values = q_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    # r + gamma * max_a' Q_target(s', a'), with no bootstrap past a terminal state
    with torch.no_grad():
        next_q_values = target_q_net(next_states).max(dim=1).values
        targets = rewards + gamma * next_q_values * (1.0 - terminals)

    loss = loss_fn(q_values, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


global_step = 0
update_step = 0

state, _ = env.reset(seed=seed)
episode_reward = 0

print("Starting training. Interrupt this cell to stop.")

try:
    while max_timesteps is None or global_step < max_timesteps:
        action = _select_action(state)

        next_state, reward, terminated, truncated, _ = env.step(action)
        # Store `terminated`, not `terminated or truncated`: a time-limit truncation is not a
        # real terminal state, so its target must still bootstrap from next_state.
        buffer.add(state, action, reward, next_state, terminated)
        state = next_state
        episode_reward += reward
        global_step += 1

        if terminated or truncated:
            state, _ = env.reset()
            print(f"Episode finished, reward: {episode_reward}, steps: {global_step}")
            episode_reward = 0

        # Gradient update every `timesteps_before_update` environment steps
        if len(buffer) >= 10 and global_step % timesteps_before_update == 0:
            _dqn_update(buffer.sample(batch_fraction))
            update_step += 1

            # Sync the target net once per `updates_before_target_update` updates. Checking
            # here (not on every env step) makes the sync fire exactly once per qualifying
            # update instead of on every step while update_step sits at a multiple.
            if update_step % updates_before_target_update == 0:
                copy_weights(target_q_net, q_net)
                print(f"Target network updated at step {global_step}")
except KeyboardInterrupt:
    print(f"Training stopped at step {global_step} after {update_step} updates.")