In [3]:
!pip install "gymnasium[atari]" ale-py opencv-python imageio




In [4]:
# Imports & basic setup

import random
from collections import deque
from typing import Deque, Tuple, List

import numpy as np
import imageio.v2 as imageio
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim

import gymnasium as gym
import ale_py
import cv2

# registriramo ALE env-e
gym.register_envs(ale_py)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


class ReplayBuffer:
    def __init__(self, capacity: int):
        self.buffer: Deque[Tuple[np.ndarray, int, float, np.ndarray, bool]] = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        # state, next_state: (4, 84, 84)
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size: int):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states      = np.array(states, dtype=np.float32)
        next_states = np.array(next_states, dtype=np.float32)
        actions     = np.array(actions, dtype=np.int64)
        rewards     = np.array(rewards, dtype=np.float32)
        dones       = np.array(dones, dtype=np.float32)

        return states, actions, rewards, next_states, dones

    def __len__(self):
        return len(self.buffer)


Using device: cpu


In [5]:
class QNetworkCNN(nn.Module):
    def __init__(self, num_actions: int):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=8, stride=4),  # 4x84x84 -> 32x20x20
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), # 64x9x9
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), # 64x7x7
            nn.ReLU(),
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 512),  # 3136 -> 512
            nn.ReLU(),
            nn.Linear(512, num_actions),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x shape: (batch, 4, 84, 84)
        x = self.conv(x)
        x = self.fc(x)
        return x


In [6]:
class DQNAgent:
    def __init__(
        self,
        num_actions: int,
        gamma: float = 0.99,
        lr: float = 1e-4,
        batch_size: int = 32,
        buffer_capacity: int = 200_000,
        epsilon_start: float = 1.0,
        epsilon_end: float = 0.01,
        epsilon_decay: float = 0.01,  # faktor za eksponentni decay
        target_update_every: int = 10_000,  # v korakih
        ddqn: bool = False,
    ):
        self.num_actions = num_actions
        self.gamma = gamma
        self.batch_size = batch_size
        self.ddqn = ddqn

        self.q_net = QNetworkCNN(num_actions).to(device)
        self.target_net = QNetworkCNN(num_actions).to(device)
        self.target_net.load_state_dict(self.q_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)

        self.replay_buffer = ReplayBuffer(capacity=buffer_capacity)

        self.epsilon = epsilon_start
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay

        self.target_update_every = target_update_every
        self.train_steps = 0

    def select_action(self, state: np.ndarray) -> int:
        # state: (4, 84, 84)
        if random.random() < self.epsilon:
            return random.randrange(self.num_actions)

        state_t = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
        with torch.no_grad():
            q_values = self.q_net(state_t)
        return q_values.argmax(dim=1).item()

    def update_epsilon(self):
        # eksponenten decay: eps = eps_end + (eps_start - eps_end)*exp(-decay * t)
        self.epsilon = max(
            self.epsilon_end,
            self.epsilon_end + (self.epsilon_start - self.epsilon_end) * np.exp(-self.epsilon_decay * self.train_steps),
        )

    def push(self, *transition):
        self.replay_buffer.push(*transition)

    def can_learn(self):
        return len(self.replay_buffer) >= self.batch_size

    def learn(self):
        if not self.can_learn():
            return None

        states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)

        states_t      = torch.tensor(states, dtype=torch.float32, device=device)
        next_states_t = torch.tensor(next_states, dtype=torch.float32, device=device)
        actions_t     = torch.tensor(actions, dtype=torch.int64, device=device)
        rewards_t     = torch.tensor(rewards, dtype=torch.float32, device=device)
        dones_t       = torch.tensor(dones, dtype=torch.float32, device=device)

        # Q(s,a)
        q_values = self.q_net(states_t).gather(1, actions_t.unsqueeze(1)).squeeze(1)

        with torch.no_grad():
            if self.ddqn:
                # DDQN: akcije iz online mreže, vrednosti iz target mreže
                next_q_online = self.q_net(next_states_t)
                next_actions = next_q_online.argmax(dim=1)  # (batch,)
                next_q_target = self.target_net(next_states_t).gather(
                    1, next_actions.unsqueeze(1)
                ).squeeze(1)
                target_q = rewards_t + self.gamma * next_q_target * (1.0 - dones_t)
            else:
                # DQN: max_a Q_target(s', a)
                next_q_values = self.target_net(next_states_t).max(dim=1)[0]
                target_q = rewards_t + self.gamma * next_q_values * (1.0 - dones_t)

        loss = nn.MSELoss()(q_values, target_q)

        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.q_net.parameters(), 10.0)
        self.optimizer.step()

        self.train_steps += 1
        self.update_epsilon()

        # posodobitev target mreže
        if self.train_steps % self.target_update_every == 0:
            self.target_net.load_state_dict(self.q_net.state_dict())

        return loss.item()

    def save(self, path: str):
        torch.save(self.q_net.state_dict(), path)

    def load(self, path: str):
        self.q_net.load_state_dict(torch.load(path, map_location=device))
        self.target_net.load_state_dict(self.q_net.state_dict())


In [7]:
# Preprocessing and frame stacking for Pong

def preprocess_frame(frame: np.ndarray) -> np.ndarray:
    """
    frame: (H, W, 3) uint8, RGB
    return: (84, 84) float32, [0,1]
    """
    # convert to grayscale
    frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    # resize to 84x84
    frame = cv2.resize(frame, (84, 84), interpolation=cv2.INTER_AREA)
    frame = frame.astype(np.float32) / 255.0
    return frame


class FrameStackEnv:
    """
    Wraps Atari env:
    - preprocess to 84x84 grayscale
    - stack last 4 frames into (4,84,84) state
    """

    def __init__(self, env: gym.Env, k: int = 4):
        self.env = env
        self.k = k
        self.frames: Deque[np.ndarray] = deque(maxlen=k)

        # action space is the same
        self.action_space = env.action_space

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        frame = preprocess_frame(obs)
        self.frames.clear()
        for _ in range(self.k):
            self.frames.append(frame)
        state = np.stack(self.frames, axis=0)  # (4,84,84)
        return state, info

    def step(self, action: int):
        obs, reward, terminated, truncated, info = self.env.step(action)
        done = terminated or truncated
        frame = preprocess_frame(obs)
        self.frames.append(frame)
        state = np.stack(self.frames, axis=0)
        return state, reward, done, info

    def close(self):
        self.env.close()


def make_pong_env(render_mode=None, repeat_action_probability=0.25, frameskip=4):
    """
    ALE/Pong-v5, po farama dokumentaciji.
    """
    env = gym.make(
        "ALE/Pong-v5",
        obs_type="rgb",
        frameskip=frameskip,
        repeat_action_probability=repeat_action_probability,
        render_mode=render_mode,
    )
    env = FrameStackEnv(env, k=4)
    return env


In [8]:
def run_pong(
    ddqn: bool,
    num_episodes: int = 1000,
    max_steps_per_episode: int = 18_000,
    solved_score: float = 18.0,
):
    # env brez renderja za trening
    env = make_pong_env(render_mode=None)

    num_actions = env.action_space.n
    print("Num actions:", num_actions)

    agent = DQNAgent(
        num_actions=num_actions,
        gamma=0.99,
        lr=1e-4,
        batch_size=32,
        buffer_capacity=200_000,
        epsilon_start=1.0,
        epsilon_end=0.01,
        epsilon_decay=0.01,
        target_update_every=10_000,
        ddqn=ddqn,
    )

    episode_rewards: List[float] = []
    moving_avg: List[float] = []

    global_step = 0

    for episode in range(1, num_episodes + 1):
        state, info = env.reset()
        episode_reward = 0.0

        for t in range(max_steps_per_episode):
            global_step += 1
            action = agent.select_action(state)
            next_state, reward, done, info_step = env.step(action)

            agent.push(state, action, reward, next_state, done)
            state = next_state
            episode_reward += reward

            loss = agent.learn()

            if done:
                break

        episode_rewards.append(episode_reward)
        if len(episode_rewards) >= 20:
            avg = np.mean(episode_rewards[-20:])
        else:
            avg = np.mean(episode_rewards)
        moving_avg.append(avg)

        if episode % 10 == 0 or episode == 1:
            print(
                f"Episode {episode}/{num_episodes} | "
                f"Reward: {episode_reward:.1f} | "
                f"Avg(20): {avg:.2f} | "
                f"Epsilon: {agent.epsilon:.3f}"
            )

        # opcijski stop pogoji
        if avg >= solved_score and episode >= 100:
            print(f"Solved Pong with avg reward {avg:.2f} at episode {episode}")
            break

    env.close()
    return agent, episode_rewards, moving_avg


def plot_rewards(rewards: List[float], moving_avg: List[float], title: str):
    plt.figure(figsize=(10, 5))
    plt.plot(rewards, label="Reward per episode")
    plt.plot(moving_avg, label="Moving average (last 20)")
    plt.xlabel("Episode")
    plt.ylabel("Reward")
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.show()


In [None]:
# DQN on Pong
dqn_agent, dqn_rewards, dqn_mavg = run_pong(
    ddqn=False,
    num_episodes=1000,
    max_steps_per_episode=18_000,
    solved_score=18.0,
)
plot_rewards(dqn_rewards, dqn_mavg, "DQN on Pong")

# Shrani model (za GIF)
dqn_agent.save("V4.4-Viktor-Rackov-DQN-Pong.pt")
np.save("V4.4-Viktor-Rackov-DQN-Pong-rewards.npy", np.array(dqn_rewards))

# DDQN on Pong
ddqn_agent, ddqn_rewards, ddqn_mavg = run_pong(
    ddqn=True,
    num_episodes=1000,
    max_steps_per_episode=18_000,
    solved_score=18.0,
)
plot_rewards(ddqn_rewards, ddqn_mavg, "DDQN on Pong")

ddqn_agent.save("V4.4-Viktor-Rackov-DDQN-Pong.pt")
np.save("V4.4-Viktor-Rackov-DDQN-Pong-rewards.npy", np.array(ddqn_rewards))


Num actions: 6


A.L.E: Arcade Learning Environment (version 0.11.2+ecc1138)
[Powered by Stella]


Episode 1/1000 | Reward: -21.0 | Avg(20): -21.00 | Epsilon: 0.010


In [None]:
def make_gif(
    agent: DQNAgent,
    filename: str,
    max_steps_per_episode: int = 18_000,
    fps: int = 30,
):
    env = make_pong_env(render_mode="rgb_array")

    frames = []
    state, info = env.reset()
    done = False

    while not done and len(frames) < max_steps_per_episode:
        # deterministična politika (brez epsilona)
        state_t = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
        with torch.no_grad():
            q_values = agent.q_net(state_t)
        action = q_values.argmax(dim=1).item()

        # render frame BEFORE step (ali po desire)
        raw_env = env.env.env  # FrameStackEnv.env -> ALE env
        frame = raw_env.render()
        frames.append(frame)

        next_state, reward, done, info_step = env.step(action)
        state = next_state

    env.close()

    # save gif
    imageio.mimsave(filename, frames, fps=fps)
    print(f"Saved GIF to {filename}")


# Primer: naredi GIF za DDQN agenta (če ti je ta boljši)
make_gif(
    ddqn_agent,
    filename="V4.4-Viktor-Rackov-Pong-DDQN.gif",
    max_steps_per_episode=18_000,
    fps=30,
)


In [None]:
make_gif(
    dqn_agent,
    filename="V4.4-Viktor-Rackov-Pong-DQN.gif",
    max_steps_per_episode=18_000,
    fps=30,
)
