In [None]:
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import gymnasium as gym
from collections import deque
import matplotlib.pyplot as plt
import imageio
from IPython.display import Video
from tetris_gymnasium.envs.tetris import Tetris

### **Wrapper**

In [67]:
import gymnasium as gym
import numpy as np

class TetrisObsWrapper(gym.ObservationWrapper):
    """Flatten tetris_gymnasium's dict observation into one (4, H, W) float32 array.

    Channels, in order:
      0. board, divided by 9.0
      1. active tetromino mask, cast to float32 (not rescaled)
      2. holder, divided by 9.0, placed top-left on a zero H x W canvas
      3. queue, divided by 9.0, placed top-left on a zero H x W canvas

    H x W is taken from the wrapped env's board observation space, so the declared
    observation space always matches what ``observation()`` returns. Holder and queue
    are zero-padded when smaller than the canvas and cropped when larger.
    """

    def __init__(self, env):
        super().__init__(env)

        # Every channel is laid out on a board-sized canvas so they can be stacked.
        # NOTE(review): assumes the wrapped env exposes a Dict observation space with a
        # "board" Box (tetris_gymnasium does; 24x18 in the default configuration).
        self._canvas_shape = tuple(env.observation_space["board"].shape)

        self.observation_space = gym.spaces.Box(
            low=0,
            high=1,
            shape=(4, *self._canvas_shape),
            dtype=np.float32,
        )

    def _fit_to_canvas(self, array):
        """Place a 2-D array in the top-left corner of a zero canvas, cropping any overflow."""
        canvas = np.zeros(self._canvas_shape, dtype=np.float32)
        rows = min(canvas.shape[0], array.shape[0])
        cols = min(canvas.shape[1], array.shape[1])
        canvas[:rows, :cols] = array[:rows, :cols]
        return canvas

    def observation(self, obs):
        """Convert the env's dict observation into a (4, H, W) float32 array.

        NOTE(review): dividing by 9.0 presumably maps cell ids into [0, 1]; confirm
        that 9 is the largest value the env emits for board/holder/queue.
        """
        board = obs["board"].astype(np.float32) / 9.0
        mask = obs["active_tetromino_mask"].astype(np.float32)
        holder = obs["holder"].astype(np.float32) / 9.0
        queue = obs["queue"].astype(np.float32) / 9.0

        return np.stack(
            [board, mask, self._fit_to_canvas(holder), self._fit_to_canvas(queue)],
            axis=0,
        )


### **Neural Network**

In [68]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DQN(nn.Module):
    """Convolutional Q-network: maps a (C, H, W) observation to one Q-value per action.

    Three 3x3 convolutions (strides 1, 2, 2; padding 1) with ReLU activations feed a
    256-unit hidden layer and a linear head with ``n_actions`` outputs. The size of
    the first linear layer is derived from ``obs_shape``, so any H x W works.
    """

    def __init__(self, obs_shape, n_actions):
        super().__init__()
        channels, height, width = obs_shape

        # Spatial size is kept by the first conv and roughly halved by each stride-2
        # conv: 24x18 -> 24x18 -> 12x9 -> 6x5 for the Tetris wrapper's observations.
        self.conv = nn.Sequential(
            nn.Conv2d(channels, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )

        # Run one dummy sample through the conv stack to size the first linear layer.
        with torch.no_grad():
            probe = torch.zeros(1, channels, height, width)
            flat_features = self.conv(probe).numel()

        self.fc = nn.Sequential(
            nn.Linear(flat_features, 256),
            nn.ReLU(),
            nn.Linear(256, n_actions),
        )

    def forward(self, x):
        """Return Q-values of shape (batch, n_actions) for a (batch, C, H, W) input."""
        features = self.conv(x)
        return self.fc(torch.flatten(features, 1))


### **Replay Buffer (`ReplayMemory`)**

In [69]:
class ReplayMemory:
    """Fixed-capacity FIFO experience-replay buffer.

    Transitions are stored exactly as pushed (e.g. NumPy arrays and Python scalars)
    and only converted to tensors when a batch is sampled. Once ``capacity``
    transitions are held, each push discards the oldest one.
    """

    def __init__(self, capacity, device):
        """
        Args:
            capacity: Maximum number of transitions kept.
            device: Device for sampled batches ('cuda', 'cpu', or a torch.device).
        """
        self.memory = deque(maxlen=capacity)
        self.device = device

    def push(self, state, action, reward, next_state, done):
        """Append one transition. Objects are stored by reference, not copied."""
        # Store raw (numpy or list) to save memory instead of storing tensors directly
        self.memory.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            (states, actions, rewards, next_states, dones) tensors on ``self.device``;
            actions are int64, everything else float32.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored transitions
                (propagated from ``random.sample``).
        """
        batch = random.sample(self.memory, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        tensors = (
            torch.tensor(np.array(states), dtype=torch.float32),
            torch.tensor(actions, dtype=torch.long),
            torch.tensor(rewards, dtype=torch.float32),
            torch.tensor(np.array(next_states), dtype=torch.float32),
            torch.tensor(dones, dtype=torch.float32),
        )

        # Pinned host memory only speeds up host->CUDA copies, and pin_memory()
        # raises when no CUDA device is available, so pin only for CUDA targets.
        if torch.device(self.device).type == "cuda":
            tensors = tuple(t.pin_memory() for t in tensors)

        return tuple(t.to(self.device, non_blocking=True) for t in tensors)

    def __len__(self):
        return len(self.memory)

### **Training**

In [70]:
def _shape_reward(reward):
    """Reward shaping: amplify positive rewards x10, penalise zero reward with -0.01."""
    if reward > 0:
        return reward * 10  # amplify reward for clearing lines
    if reward == 0:
        return -0.01  # penalty for doing nothing
    return reward


def _select_action(policy_net, state, epsilon, action_space, device):
    """Epsilon-greedy: random action with probability ``epsilon``, else argmax of Q."""
    if random.random() < epsilon:
        return action_space.sample()
    s = torch.as_tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    with torch.no_grad():
        return policy_net(s).argmax(1).item()


def _optimize_step(policy_net, target_net, optimizer, memory, batch_size, gamma, loss_fn):
    """One DQN gradient step on a replay batch, bootstrapping from the target net."""
    # memory.sample() already returns tensors on the training device.
    states, actions, rewards, next_states, dones = memory.sample(batch_size)

    q_values = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

    # The TD target must be a constant: no graph through target_net and no .grad
    # accumulating on its parameters.
    with torch.no_grad():
        next_q_values = target_net(next_states).max(1).values
        targets = rewards + gamma * next_q_values * (1 - dones)

    loss = loss_fn(q_values, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def train_dqn(episodes=50, batch_size=32, gamma=0.99, lr=1e-4,
              epsilon_start=1.0, epsilon_end=0.1, epsilon_decay=5000,
              target_update=1000, buffer_size=100000, seed=None):
    """Train a DQN agent on tetris_gymnasium's Tetris environment.

    Uses epsilon-greedy exploration with exponential decay, an experience-replay
    buffer and a target network that is hard-synced periodically. Runs on CUDA when
    available, otherwise on CPU.

    Args:
        episodes: Number of episodes to play.
        batch_size: Transitions per gradient step; learning starts once the buffer
            holds more than this many transitions.
        gamma: Discount factor.
        lr: Adam learning rate.
        epsilon_start: Exploration rate at step 0.
        epsilon_end: Asymptotic exploration rate.
        epsilon_decay: Time constant, in env steps, of the exponential epsilon decay.
        target_update: Copy policy weights into the target net every this many env steps.
        buffer_size: Replay buffer capacity.
        seed: Optional seed for Python ``random``, NumPy, PyTorch and the environment.
            ``None`` (the default) leaves everything unseeded.

    Returns:
        (policy_net, rewards_history): the trained network and the raw (unshaped)
        total reward of every episode, in order.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    env = gym.make("tetris_gymnasium/Tetris", render_mode="ansi")
    env = TetrisObsWrapper(env)

    try:
        obs, _ = env.reset(seed=seed)
        if seed is not None:
            env.action_space.seed(seed)
        obs_shape = obs.shape
        num_actions = env.action_space.n

        policy_net = DQN(obs_shape, num_actions).to(device)  # Main network
        target_net = DQN(obs_shape, num_actions).to(device)  # Target network (copy)
        target_net.load_state_dict(policy_net.state_dict())  # Sync weights initially
        optimizer = optim.Adam(policy_net.parameters(), lr=lr)
        memory = ReplayMemory(buffer_size, device=device)
        loss_fn = nn.MSELoss()

        steps_done = 0
        epsilon = epsilon_start
        rewards_history = []  # one entry per episode, across the whole run

        for ep in range(episodes):
            state, _ = env.reset()
            state = np.array(state)
            done = False
            total_reward = 0

            while not done:
                epsilon = epsilon_end + (epsilon_start - epsilon_end) * \
                    np.exp(-1. * steps_done / epsilon_decay)
                action = _select_action(policy_net, state, epsilon, env.action_space, device)

                next_state, reward, terminated, truncated, _ = env.step(action)
                next_state = np.array(next_state)
                done = terminated or truncated

                # Only true termination zeroes the bootstrap term; a truncated
                # episode still has future value from next_state.
                memory.push(state, action, _shape_reward(reward), next_state, terminated)
                state = next_state
                total_reward += reward
                steps_done += 1

                if len(memory) > batch_size:
                    _optimize_step(policy_net, target_net, optimizer, memory,
                                   batch_size, gamma, loss_fn)

                if steps_done % target_update == 0:
                    target_net.load_state_dict(policy_net.state_dict())

            rewards_history.append(total_reward)
            print(f"Episode {ep}, Reward: {total_reward}, Epsilon: {epsilon:.2f}")
    finally:
        env.close()

    return policy_net, rewards_history


In [71]:
# Train with the default hyperparameters; train_dqn returns (policy_net, rewards_history).
q_net, rewards = train_dqn()

Episode 0, Reward: 11, Epsilon: 0.98
Episode 1, Reward: 11, Epsilon: 0.97
Episode 2, Reward: 11, Epsilon: 0.96
Episode 3, Reward: 11, Epsilon: 0.95
Episode 4, Reward: 11, Epsilon: 0.93
Episode 5, Reward: 10, Epsilon: 0.92
Episode 6, Reward: 8, Epsilon: 0.92
Episode 7, Reward: 10, Epsilon: 0.91
Episode 8, Reward: 10, Epsilon: 0.90
Episode 9, Reward: 14, Epsilon: 0.89
Episode 10, Reward: 8, Epsilon: 0.89
Episode 11, Reward: 9, Epsilon: 0.88
Episode 12, Reward: 10, Epsilon: 0.88
Episode 13, Reward: 10, Epsilon: 0.87
Episode 14, Reward: 9, Epsilon: 0.86
Episode 15, Reward: 8, Epsilon: 0.86
Episode 16, Reward: 11, Epsilon: 0.85
Episode 17, Reward: 10, Epsilon: 0.84
Episode 18, Reward: 7, Epsilon: 0.84
Episode 19, Reward: 9, Epsilon: 0.83
Episode 20, Reward: 9, Epsilon: 0.83
Episode 21, Reward: 8, Epsilon: 0.83
Episode 22, Reward: 10, Epsilon: 0.82
Episode 23, Reward: 10, Epsilon: 0.81
Episode 24, Reward: 9, Epsilon: 0.81
Episode 25, Reward: 9, Epsilon: 0.80
Episode 26, Reward: 9, Epsilon: 0