In [10]:
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import gymnasium as gym
from gymnasium.wrappers import MaxAndSkipObservation
from gymnasium.wrappers import FrameStackObservation, ResizeObservation
from collections import deque
import matplotlib.pyplot as plt
import imageio
from IPython.display import Video
import ale_py

gym.register_envs(ale_py)

# Environment

In [11]:
def make_env(env_name):
    env = gym.make(env_name, obs_type='grayscale')
    env = MaxAndSkipObservation(env, skip=4) 
    env = ResizeObservation(env, (84, 84))
    env = FrameStackObservation(env, 4)
    return env

# Neural Network

In [13]:
class DQN(nn.Module):
    def __init__(self, num_actions):
        super().__init__()
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self.fc1 = nn.Linear(64*7*7, 512)
        self.fc2 = nn.Linear(512, num_actions)
    
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [14]:
# class DQN(nn.Module):
#     def __init__(self, input_shape, num_actions):
#         super(DQN, self).__init__()
#         c, h, w = input_shape  # (4, 84, 84) after transpose
#         assert h == 84 and w == 84, "Input must be 84x84"

#         self.conv = nn.Sequential(
#             nn.Conv2d(c, 32, kernel_size=8, stride=4),
#             nn.ReLU(),
#             nn.Conv2d(32, 64, kernel_size=4, stride=2),
#             nn.ReLU(),
#             nn.Conv2d(64, 64, kernel_size=3, stride=1),
#             nn.ReLU()
#         )

#         # compute conv output size dynamically
#         def conv2d_size_out(size, kernel_size, stride):
#             return (size - kernel_size) // stride + 1
#         convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w,8,4),4,2),3,1)
#         convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h,8,4),4,2),3,1)
#         linear_input_size = convw * convh * 64

#         self.fc = nn.Sequential(
#             nn.Linear(linear_input_size, 512),
#             nn.ReLU(),
#             nn.Linear(512, num_actions)
#         )

#     def forward(self, x):
#         x = x / 255.0  # normalize [0,255] → [0,1]
#         x = self.conv(x)
#         x = x.view(x.size(0), -1)
#         return self.fc(x)

# Replay Memory

In [15]:
class ReplayMemory:
    def __init__(self, capacity, device):
        self.memory = deque(maxlen=capacity)
        self.device = device  # 'cuda' or 'cpu'

    def push(self, state, action, reward, next_state, done):
        # Store raw (numpy or list) to save memory instead of storing tensors directly
        self.memory.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.memory, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        # Convert to tensors and pin memory for faster GPU transfer
        states = torch.tensor(np.array(states), dtype=torch.float32).pin_memory()
        next_states = torch.tensor(np.array(next_states), dtype=torch.float32).pin_memory()
        actions = torch.tensor(actions, dtype=torch.long).pin_memory()
        rewards = torch.tensor(rewards, dtype=torch.float32).pin_memory()
        dones = torch.tensor(dones, dtype=torch.float32).pin_memory()

        # Transfer to GPU asynchronously for speed
        return (
            states.to(self.device, non_blocking=True),
            actions.to(self.device, non_blocking=True),
            rewards.to(self.device, non_blocking=True),
            next_states.to(self.device, non_blocking=True),
            dones.to(self.device, non_blocking=True)
        )

    def __len__(self):
        return len(self.memory)


In [16]:
# class ReplayBuffer:
#     def __init__(self, capacity):
#         self.buffer = deque(maxlen=capacity)

#     def push(self, state, action, reward, next_state, done):
#         self.buffer.append((state, action, reward, next_state, done))

#     def sample(self, batch_size):
#         batch = random.sample(self.buffer, batch_size)
#         states, actions, rewards, next_states, dones = map(np.array, zip(*batch))

#         # convert to torch tensors
#         states = torch.tensor(states, dtype=torch.float32).permute(0,3,1,2)  # NHWC -> NCHW
#         actions = torch.tensor(actions, dtype=torch.long)
#         rewards = torch.tensor(rewards, dtype=torch.float32)
#         next_states = torch.tensor(next_states, dtype=torch.float32).permute(0,3,1,2)
#         dones = torch.tensor(dones, dtype=torch.float32)
#         return states, actions, rewards, next_states, dones

#     def __len__(self):
#         return len(self.buffer)


# Training Loop

In [17]:
def train_dqn(episodes=1000, batch_size=32, gamma=0.99, lr=1e-4, 
              epsilon_start=1.0, epsilon_end=0.1, epsilon_decay=200000, 
              target_update=1000, buffer_size=100000):
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    env = make_env("ALE/Tetris-v5")
    num_actions = env.action_space.n

    policy_net = DQN(num_actions).to(device)  # Main network
    target_net = DQN(num_actions).to(device)  # Target network (copy)
    target_net.load_state_dict(policy_net.state_dict())  # Sync weights initially
    optimizer = optim.Adam(policy_net.parameters(), lr=lr)
    memory = ReplayMemory(buffer_size, device=device)

    steps_done = 0
    epsilon = epsilon_start

    for ep in range(episodes):
        state, _ = env.reset()
        state = np.array(state)

        done = False
        total_reward = 0

        while not done:
            # Epsilon-greedy action
            epsilon = epsilon_end + (epsilon_start - epsilon_end) * \
                    np.exp(-1. * steps_done / epsilon_decay)

            if random.random() < epsilon:
                action = env.action_space.sample()
            else:
                s = torch.tensor(state, dtype=torch.float32).unsqueeze(0).cuda()
                with torch.no_grad():
                    action = policy_net(s).argmax(1).item()

            next_state, reward, terminated, truncated, _ = env.step(action)
            next_state = np.array(next_state)
            done = terminated or truncated
            
            shaped_reward = reward
            
            if reward == 0:
                shaped_reward = -0.01  # penalty for doing nothing

            if reward > 0:
                shaped_reward = reward * 2  # amplify reward for clearing lines

            memory.push(state, action, shaped_reward, next_state, done)
            state = next_state
            total_reward += reward
            steps_done += 1

            # Learn
            if len(memory) > batch_size:
                states, actions, rewards, next_states, dones = memory.sample(batch_size)

                states, next_states = states.cuda(), next_states.cuda()
                actions, rewards, dones = actions.cuda(), rewards.cuda(), dones.cuda()

                q_values = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
                next_q_values = target_net(next_states).max(1)[0]
                targets = rewards + gamma * next_q_values * (1 - dones)

                loss = nn.MSELoss()(q_values, targets)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Update target net
            if steps_done % target_update == 0:
                target_net.load_state_dict(policy_net.state_dict())

        print(f"Episode {ep}, Reward: {total_reward:.2f}, Epsilon: {epsilon:.3f}")


In [18]:
q_net, rewards = train_dqn()

Episode 0, Reward: 0.00, Epsilon: 0.999
Episode 1, Reward: 0.00, Epsilon: 0.999
Episode 2, Reward: 0.00, Epsilon: 0.998
Episode 3, Reward: 0.00, Epsilon: 0.997
Episode 4, Reward: 0.00, Epsilon: 0.997
Episode 5, Reward: 0.00, Epsilon: 0.996
Episode 6, Reward: 0.00, Epsilon: 0.995
Episode 7, Reward: 0.00, Epsilon: 0.995
Episode 8, Reward: 0.00, Epsilon: 0.994
Episode 9, Reward: 0.00, Epsilon: 0.994
Episode 10, Reward: 0.00, Epsilon: 0.993
Episode 11, Reward: 0.00, Epsilon: 0.993
Episode 12, Reward: 0.00, Epsilon: 0.992
Episode 13, Reward: 0.00, Epsilon: 0.991
Episode 14, Reward: 0.00, Epsilon: 0.991
Episode 15, Reward: 0.00, Epsilon: 0.990
Episode 16, Reward: 0.00, Epsilon: 0.990
Episode 17, Reward: 0.00, Epsilon: 0.989
Episode 18, Reward: 0.00, Epsilon: 0.989
Episode 19, Reward: 0.00, Epsilon: 0.988
Episode 20, Reward: 0.00, Epsilon: 0.988
Episode 21, Reward: 0.00, Epsilon: 0.987
Episode 22, Reward: 0.00, Epsilon: 0.986
Episode 23, Reward: 0.00, Epsilon: 0.986
Episode 24, Reward: 0.00, 

KeyboardInterrupt: 