### FrameStack

Run this to have in memory, as importing FrameStack didn't work.

In [1]:
"""Wrapper that stacks frames."""
from collections import deque
from typing import Union

import numpy as np
# !pip install gymnasium[accept-rom-license]
# !pip install gymnasium[atari]==0.28.1 ale-py==0.8.1
# !pip install gymnasium[other]==0.28.1
import gymnasium as gym
from gymnasium.error import DependencyNotInstalled
from gymnasium.spaces import Box


class LazyFrames:
    """Ensures common frames are only stored once to optimize memory use.

    To further reduce the memory use, it is optionally to turn on lz4 to compress the observations.

    Note:
        This object should only be converted to numpy array just before forward pass.
    """

    __slots__ = ("frame_shape", "dtype", "shape", "lz4_compress", "_frames")

    def __init__(self, frames: list, lz4_compress: bool = False):
        """Lazyframe for a set of frames and if to apply lz4.

        Args:
            frames (list): The frames to convert to lazy frames
            lz4_compress (bool): Use lz4 to compress the frames internally

        Raises:
            DependencyNotInstalled: lz4 is not installed
        """
        self.frame_shape = tuple(frames[0].shape)
        self.shape = (len(frames),) + self.frame_shape
        self.dtype = frames[0].dtype
        if lz4_compress:
            try:
                from lz4.block import compress
            except ImportError as e:
                raise DependencyNotInstalled(
                    "lz4 is not installed, run `pip install gymnasium[other]`"
                ) from e

            frames = [compress(frame) for frame in frames]
        self._frames = frames
        self.lz4_compress = lz4_compress

    def __array__(self, dtype=None):
        """Gets a numpy array of stacked frames with specific dtype.

        Args:
            dtype: The dtype of the stacked frames

        Returns:
            The array of stacked frames with dtype
        """
        arr = self[:]
        if dtype is not None:
            return arr.astype(dtype)
        return arr

    def __len__(self):
        """Returns the number of frame stacks.

        Returns:
            The number of frame stacks
        """
        return self.shape[0]

    def __getitem__(self, int_or_slice: Union[int, slice]):
        """Gets the stacked frames for a particular index or slice.

        Args:
            int_or_slice: Index or slice to get items for

        Returns:
            np.stacked frames for the int or slice

        """
        if isinstance(int_or_slice, int):
            return self._check_decompress(self._frames[int_or_slice])  # single frame
        return np.stack(
            [self._check_decompress(f) for f in self._frames[int_or_slice]], axis=0
        )

    def __eq__(self, other):
        """Checks that the current frames are equal to the other object."""
        return self.__array__() == other

    def _check_decompress(self, frame):
        if self.lz4_compress:
            from lz4.block import decompress

            return np.frombuffer(decompress(frame), dtype=self.dtype).reshape(
                self.frame_shape
            )
        return frame


class FrameStack(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
    """Observation wrapper that stacks the observations in a rolling manner.

    For example, if the number of stacks is 4, then the returned observation contains
    the most recent 4 observations. For environment 'Pendulum-v1', the original observation
    is an array with shape [3], so if we stack 4 observations, the processed observation
    has shape [4, 3].

    Note:
        - To be memory efficient, the stacked observations are wrapped by :class:`LazyFrame`.
        - The observation space must be :class:`Box` type. If one uses :class:`Dict`
          as observation space, it should apply :class:`FlattenObservation` wrapper first.
        - After :meth:`reset` is called, the frame buffer will be filled with the initial observation.
          I.e. the observation returned by :meth:`reset` will consist of `num_stack` many identical frames.

    Example:
        >>> import gymnasium as gym
        >>> from gymnasium.wrappers import FrameStack
        >>> env = gym.make("CarRacing-v2")
        >>> env = FrameStack(env, 4)
        >>> env.observation_space
        Box(0, 255, (4, 96, 96, 3), uint8)
        >>> obs, _ = env.reset()
        >>> obs.shape
        (4, 96, 96, 3)
    """

    def __init__(
        self,
        env: gym.Env,
        num_stack: int,
        lz4_compress: bool = False,
    ):
        """Observation wrapper that stacks the observations in a rolling manner.

        Args:
            env (Env): The environment to apply the wrapper
            num_stack (int): The number of frames to stack
            lz4_compress (bool): Use lz4 to compress the frames internally
        """
        gym.utils.RecordConstructorArgs.__init__(
            self, num_stack=num_stack, lz4_compress=lz4_compress
        )
        gym.ObservationWrapper.__init__(self, env)

        self.num_stack = num_stack
        self.lz4_compress = lz4_compress

        self.frames = deque(maxlen=num_stack)

        low = np.repeat(self.observation_space.low[np.newaxis, ...], num_stack, axis=0)
        high = np.repeat(
            self.observation_space.high[np.newaxis, ...], num_stack, axis=0
        )
        self.observation_space = Box(
            low=low, high=high, dtype=self.observation_space.dtype
        )

    def observation(self, observation):
        """Converts the wrappers current frames to lazy frames.

        Args:
            observation: Ignored

        Returns:
            :class:`LazyFrames` object for the wrapper's frame buffer,  :attr:`self.frames`
        """
        assert len(self.frames) == self.num_stack, (len(self.frames), self.num_stack)
        return LazyFrames(list(self.frames), self.lz4_compress)

    def step(self, action):
        """Steps through the environment, appending the observation to the frame buffer.

        Args:
            action: The action to step through the environment with

        Returns:
            Stacked observations, reward, terminated, truncated, and information from the environment
        """
        observation, reward, terminated, truncated, info = self.env.step(action)
        self.frames.append(observation)
        return self.observation(None), reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Reset the environment with kwargs.

        Args:
            **kwargs: The kwargs for the environment reset

        Returns:
            The stacked observations
        """
        obs, info = self.env.reset(**kwargs)

        [self.frames.append(obs) for _ in range(self.num_stack)]

        return self.observation(None), info

## Training DQN

This code didn't learn well, the score didn't improve during the 6000 episodes.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
%pip install gymnasium[accept-rom-license]
import gymnasium as gym
import ale_py
import numpy as np
import random
import threading
import os
import pickle
import time
from collections import deque
from queue import Queue
from gymnasium.wrappers.atari_preprocessing import AtariPreprocessing
from tqdm import tqdm

# Benchmark mode in CuDNN for potentially faster convolutions
torch.backends.cudnn.benchmark = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class DQN(nn.Module):
    def __init__(self, num_actions):
        super(DQN, self).__init__()
        # Convolutional layers
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self.flatten = nn.Flatten()

        # Fully connected layers
        self.fc1 = nn.Linear(64 * 7 * 7, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, num_actions)

        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = self.flatten(x)

        x = F.relu(self.fc1(x))
        x = self.dropout(x)

        x = F.relu(self.fc2(x))
        x = self.dropout(x)

        x = F.relu(self.fc3(x))
        x = self.dropout(x)

        x = self.fc4(x)
        return x

class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)
        self.lock = threading.Lock()
            
    def push(self, state, action, reward, next_state, done):
        with self.lock:
            self.buffer.append((state, action, reward, next_state, done))
                
    def sample(self, batch_size):
        with self.lock:
            return random.sample(self.buffer, batch_size)
                            
    def __len__(self):
        with self.lock:
            return len(self.buffer)


def prefetch_batches(memory, batch_size, prefetch_queue, stop_event):
    while not stop_event.is_set():
        if len(memory) >= batch_size:
            batch = memory.sample(batch_size)
            prefetch_queue.put(batch)
        else:
            stop_event.wait(0.1)


def save_checkpoint(model, optimizer, episode, epsilon, total_steps, checkpoint_path="checkpoint.pth"):
    print("Starting torch.save for checkpoint (no replay buffer)...")
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "episode": episode,
        "epsilon": epsilon,
        "total_steps": total_steps
    }
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved at {checkpoint_path}")

def load_checkpoint(model, optimizer, memory, checkpoint_path="checkpoint.pth"):
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        episode = checkpoint["episode"]
        epsilon = checkpoint["epsilon"]
        total_steps = checkpoint["total_steps"]
        print(f"Checkpoint loaded from {checkpoint_path}")
        return episode, epsilon, total_steps
    else:
        print("No checkpoint found, starting from scratch.")
        return 0, 1.0, 0


def prefill_replay_buffer(env, memory, prefill_size):
    print(f"Prefilling replay buffer with {prefill_size} transitions...")
    state, _ = env.reset()
    for _ in tqdm(range(prefill_size), desc="Filling Replay Buffer"):
        action = env.action_space.sample()
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated

        state_array = np.array(state)
        next_state_array = np.array(next_state)

        # Kontrollera format: Om (H, W, C), transponera till (C, H, W)
        if state_array.shape == (84, 84, 4):
            state_array = state_array.transpose(2, 0, 1)
        if next_state_array.shape == (84, 84, 4):
            next_state_array = next_state_array.transpose(2, 0, 1)

        memory.push(state_array, action, reward, next_state_array, done)

        if done:
            state, _ = env.reset()
        else:
            state = next_state
    print("Replay buffer prefilled.")

num_actions = 6
batch_size = 16   
gamma = 0.99
epsilon_start = 1.0
epsilon_end = 0.1
epsilon_decay = 0.995
learning_rate = 1e-4
target_update_freq = 1000
memory_size = 100000
num_episodes = 10000
prefetch_queue_size = 20
save_frequency = 50
updates_per_step = 2  

env = gym.make("SpaceInvadersNoFrameskip-v4")
env = AtariPreprocessing(env)
env = FrameStack(env, 4)

policy_net = DQN(num_actions).to(device)
target_net = DQN(num_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.Adam(policy_net.parameters(), lr=learning_rate)

memory = ReplayBuffer(memory_size)

prefill_size = max(batch_size * 1000, batch_size)
prefill_replay_buffer(env, memory, prefill_size)

prefetch_queue = Queue(maxsize=prefetch_queue_size)
stop_event = threading.Event()
prefetch_thread = threading.Thread(target=prefetch_batches, args=(memory, batch_size, prefetch_queue, stop_event))
prefetch_thread.start()


start_episode, epsilon, total_steps = load_checkpoint(policy_net, optimizer, memory, "checkpoint.pth")

scaler = torch.cuda.amp.GradScaler()

try:
    for episode in range(start_episode, num_episodes):
        start_time = time.time()
        state, _ = env.reset()
        done = False
        total_reward = 0

        while not done:
            state_array = np.array(state)
            if state_array.shape == (84, 84, 4):
                state_array = state_array.transpose(2, 0, 1)

            if random.random() < epsilon:
                action = env.action_space.sample()
            else:
                state_tensor = torch.tensor(state_array, dtype=torch.float32).unsqueeze(0).to(device)
                with torch.no_grad():
                    action_values = policy_net(state_tensor)
                action = torch.argmax(action_values, dim=1).item()

            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            total_reward += reward

            next_state_array = np.array(next_state)
            if next_state_array.shape == (84, 84, 4):
                next_state_array = next_state_array.transpose(2, 0, 1)

            memory.push(state_array, action, reward, next_state_array, done)

            for _ in range(updates_per_step):
                if not prefetch_queue.empty():
                    batch = prefetch_queue.get()
                    prefetch_queue.task_done()

                    states, actions, rewards, next_states, dones = zip(*batch)
                    states = torch.tensor(np.array(states), dtype=torch.float32).to(device)
                    actions = torch.tensor(np.array(actions), dtype=torch.int64).unsqueeze(1).to(device)
                    rewards = torch.tensor(np.array(rewards), dtype=torch.float32).unsqueeze(1).to(device)
                    next_states = torch.tensor(np.array(next_states), dtype=torch.float32).to(device)
                    dones = torch.tensor(np.array(dones), dtype=torch.float32).unsqueeze(1).to(device)

                    # Double DQN logic
                    with torch.no_grad():
                        next_actions = policy_net(next_states).argmax(dim=1, keepdim=True)
                        next_q_values = target_net(next_states).gather(1, next_actions)
                    target_q_values = rewards + gamma * next_q_values * (1 - dones)

                    with torch.cuda.amp.autocast():
                        q_values = policy_net(states).gather(1, actions)
                        loss = F.smooth_l1_loss(q_values, target_q_values)

                    optimizer.zero_grad()
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()

            epsilon = max(epsilon_end, epsilon_decay * epsilon)
            if total_steps % target_update_freq == 0:
                target_net.load_state_dict(policy_net.state_dict())

            state = next_state
            total_steps += 1

        if episode % 10 == 0:
            torch.cuda.empty_cache()

        if (episode + 1) % save_frequency == 0:
            save_checkpoint(policy_net, optimizer, episode, epsilon, total_steps)

        episode_time = time.time() - start_time
        print(f"Episode {episode+1}/{num_episodes}, Total Reward: {total_reward}, Epsilon: {epsilon:.2f}, Time: {episode_time:.2f}s")

except KeyboardInterrupt:
    print("Training interrupted by user.")
finally:
    stop_event.set()
    prefetch_thread.join()
    # Save final checkpoint without replay buffer
    save_checkpoint(policy_net, optimizer, num_episodes, epsilon, total_steps)
    torch.save(policy_net.state_dict(), 'dqn_space_invaders_final.pth')


Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.comNote: you may need to restart the kernel to use updated packages.





Prefilling replay buffer with 16000 transitions...


Filling Replay Buffer: 100%|██████████| 16000/16000 [00:32<00:00, 486.53it/s]
  scaler = torch.cuda.amp.GradScaler()


Replay buffer prefilled.
No checkpoint found, starting from scratch.


  with torch.cuda.amp.autocast():


Episode 1/10000, Total Reward: 80.0, Epsilon: 0.24, Time: 22.84s
Episode 2/10000, Total Reward: 155.0, Epsilon: 0.10, Time: 21.07s
Episode 3/10000, Total Reward: 110.0, Epsilon: 0.10, Time: 25.01s
Episode 4/10000, Total Reward: 210.0, Epsilon: 0.10, Time: 20.75s


# Loading the saved model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Number of actions in Space Invaders
num_actions = 6

class DQN(nn.Module):
    def __init__(self, num_actions):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64 * 7 * 7, 512)
        self.fc2 = nn.Linear(512, num_actions)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))   # (batch_size, 32, 20, 20)
        x = F.relu(self.conv2(x))   # (batch_size, 64, 9, 9)
        x = F.relu(self.conv3(x))   # (batch_size, 64, 7, 7)
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Initialize the model and load the trained weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
agent = DQN(num_actions).to(device)
agent.load_state_dict(torch.load('dqn_space_invaders.pth', map_location=device))
agent.eval()  # Set the model to evaluation mode


  agent.load_state_dict(torch.load('dqn_space_invaders.pth', map_location=device))


DQN(
  (conv1): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
  (conv2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
  (conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=3136, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=6, bias=True)
)

## Setting up env with rendering

In [None]:
import gymnasium as gym
import ale_py
from gymnasium.wrappers.atari_preprocessing import AtariPreprocessing

# Create the environment with rendering enabled
env = gym.make("SpaceInvadersNoFrameskip-v4", render_mode="human")
env = AtariPreprocessing(env)
env = FrameStack(env, num_stack=4)


## Running the agent in the env

In [None]:
import numpy as np
import time

state, _ = env.reset()
done = False

while not done:
    state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
    
    with torch.no_grad():
        action_values = agent(state_tensor)
        action = torch.argmax(action_values, dim=1).item()
    
    next_state, reward, done, _, _ = env.step(action)
    
    state = next_state
    
    env.render()
    
env.close()


: 

In [None]:
# with recording

from gymnasium.wrappers import RecordVideo

env = gym.make("SpaceInvadersNoFrameskip-v4")
env = AtariPreprocessing(env)
env = FrameStack(env, 4)

# Wrap the environment to record videos
env = RecordVideo(env, video_folder='videos', episode_trigger=lambda x: True)

state, _ = env.reset()
done = False
total_reward = 0

while not done:
    state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
    
    with torch.no_grad():
        # Get action from the trained model
        action_values = agent(state_tensor)
        action = torch.argmax(action_values, dim=1).item()
    
    next_state, reward, done, _, _ = env.step(action)
    total_reward += reward
    state = next_state

env.close()
print(f"Total Reward: {total_reward}")

## Show video

In [None]:
import os
import io
import base64
from IPython.display import HTML, display

# Function to display the video
def show_video():
    video_files = [f for f in os.listdir('./videos') if f.endswith('.mp4')]
    if len(video_files) > 0:
        video_path = os.path.join('./videos', video_files[-1])
        video = io.open(video_path, 'r+b').read()
        encoded = base64.b64encode(video)
        display(HTML(data='''
            <video width="480" height="320" controls>
                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
            </video>
            '''.format(encoded.decode('ascii'))))
    else:
        print("No video found.")

# Display the recorded video
show_video()