In [None]:
# 🧠 RL Algorithm Comparison: DQN vs PPO vs A2C

import os, random
import numpy as np
import pygame
import matplotlib.pyplot as plt
from collections import deque, namedtuple
from itertools import count

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Environment setup
class SnakeEnv:
    """Grid-based Snake environment with a 9-feature observation vector.

    Actions are integer indices 0-3 that select an absolute movement delta
    ``(dx, dy)``: 0 -> (-1, 0), 1 -> (0, 1), 2 -> (1, 0), 3 -> (0, -1).
    Rewards: +10 for eating food, -10 for hitting a wall or the snake,
    -0.1 for every other move.

    Attributes set by ``reset``:
        direction: current ``(dx, dy)`` heading.
        snake: list of ``(x, y)`` cells, head first.
        food: ``(x, y)`` of the food, or ``None`` when no free cell is left.
        done: whether the episode has terminated.
        score: number of food items eaten this episode.
    """

    def __init__(self, width=10, height=10, block_size=20):
        """Create a ``width`` x ``height`` grid and start the first episode.

        ``block_size`` is the side length, in pixels, of one grid cell in
        the frame produced by ``render``.
        """
        pygame.init()
        self.width, self.height, self.block = width, height, block_size
        # The surface and clock are created here but not used by any other
        # method of this class; ``render`` builds a NumPy frame instead.
        self.display = pygame.Surface((self.width * self.block, self.height * self.block))
        self.clock = pygame.time.Clock()
        self.reset()

    def reset(self, seed=None):
        """Start a new episode and return the initial observation.

        When ``seed`` is given, both the ``random`` and ``numpy.random``
        global generators are re-seeded with it.
        """
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
        self.direction = (1, 0)
        self.snake = [(self.width // 2, self.height // 2)]
        self._place_food()
        self.done, self.score = False, 0
        return self._get_obs()

    def _place_food(self):
        """Put the food on a random cell not occupied by the snake.

        Sets ``self.food`` to ``None`` when the snake covers the whole
        board; without this guard the rejection loop below never ends.
        """
        if len(set(self.snake)) >= self.width * self.height:
            self.food = None
            return
        while True:
            self.food = (random.randrange(self.width), random.randrange(self.height))
            if self.food not in self.snake:
                break

    def _get_obs(self):
        """Return the observation as a float32 array of length 9.

        Layout: three danger flags (front, right, left offsets relative to
        the heading), four one-hot heading flags (dx == 1, dx == -1,
        dy == 1, dy == -1), then the sign of the food offset along x and y
        (both 0 when there is no food on the board).
        """
        head_x, head_y = self.snake[0]
        dir_x, dir_y = self.direction

        def danger_at(offset):
            # 1 if stepping by ``offset`` leaves the grid or lands on the
            # snake (the current tail cell included), else 0.
            dx, dy = offset
            new_x, new_y = head_x + dx, head_y + dy
            return int(
                new_x < 0 or new_x >= self.width or
                new_y < 0 or new_y >= self.height or
                (new_x, new_y) in self.snake
            )

        # The two offsets perpendicular to the heading, plus straight ahead.
        left = (-dir_y, dir_x)
        right = (dir_y, -dir_x)
        front = (dir_x, dir_y)

        danger = [danger_at(front), danger_at(right), danger_at(left)]

        if self.food is None:
            # Board is full: there is no food to point at.
            food_dx = food_dy = 0
        else:
            food_x, food_y = self.food
            food_dx = int(np.sign(food_x - head_x))
            food_dy = int(np.sign(food_y - head_y))

        dir_features = [int(dir_x == 1), int(dir_x == -1), int(dir_y == 1), int(dir_y == -1)]

        return np.array(danger + dir_features + [food_dx, food_dy], dtype=np.float32)

    def step(self, action):
        """Advance one move and return ``(obs, reward, done, info)``.

        ``action`` indexes the movement deltas listed on the class. A
        request to reverse straight into the current heading is ignored and
        the snake keeps going. The episode ends on a collision (reward -10)
        or when eating food leaves no free cell for new food (reward 10).
        ``info`` is always an empty dict.
        """
        dirs = [(-1, 0), (0, 1), (1, 0), (0, -1)]
        new_dir = dirs[action]
        if new_dir[0] == -self.direction[0] and new_dir[1] == -self.direction[1]:
            new_dir = self.direction
        self.direction = new_dir

        head = (self.snake[0][0] + new_dir[0], self.snake[0][1] + new_dir[1])
        out_of_bounds = not 0 <= head[0] < self.width or not 0 <= head[1] < self.height
        if out_of_bounds or head in self.snake:
            self.done = True
            return self._get_obs(), -10, True, {}

        self.snake.insert(0, head)
        if head == self.food:
            reward, self.score = 10, self.score + 1
            self._place_food()
            if self.food is None:
                # The snake now fills the board: nothing left to play for.
                self.done = True
                return self._get_obs(), reward, True, {}
        else:
            reward = -0.1
            # No growth this move, so drop the tail cell.
            self.snake.pop()

        return self._get_obs(), reward, False, {}

    def render(self):
        """Return the board as a ``(height*block, width*block, 3)`` uint8 frame.

        Rows are indexed by y and columns by x. Snake cells are filled with
        ``[0, 255, 0]`` and the food cell, if any, with ``[255, 0, 0]``; the
        food is drawn last.
        """
        H_px = self.height * self.block
        W_px = self.width * self.block
        frame = np.zeros((H_px, W_px, 3), dtype=np.uint8)

        for x, y in self.snake:
            y0, y1 = y * self.block, (y + 1) * self.block
            x0, x1 = x * self.block, (x + 1) * self.block
            frame[y0:y1, x0:x1] = [0, 255, 0]

        if self.food is not None:
            fx, fy = self.food
            y0, y1 = fy * self.block, (fy + 1) * self.block
            x0, x1 = fx * self.block, (fx + 1) * self.block
            frame[y0:y1, x0:x1] = [255, 0, 0]

        return frame

# 🧱 Models
class ActorCritic(nn.Module):
    """Actor-critic network with two independent one-hidden-layer MLP heads.

    ``policy`` maps a state to a softmax distribution over ``action_dim``
    actions; ``value`` maps the same state to a single scalar estimate.
    The heads share no parameters.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden = 128
        # Build the policy head first, then the value head, so parameter
        # creation order stays policy -> value.
        policy_layers = [
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, action_dim),
            nn.Softmax(dim=-1),
        ]
        self.policy = nn.Sequential(*policy_layers)
        value_layers = [
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        ]
        self.value = nn.Sequential(*value_layers)

    def forward(self, x):
        """Return ``(action_probabilities, state_value)`` for input ``x``."""
        action_probs = self.policy(x)
        state_value = self.value(x)
        return action_probs, state_value

class DQN(nn.Module):
    """Q-network: a two-hidden-layer MLP with one output per action."""

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden = 128
        layers = [
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, action_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for ``x`` (last dim is ``action_dim``)."""
        q_values = self.net(x)
        return q_values

# 🛠️ Placeholder Training Loops

def train_dqn(env, policy_net, target_net, memory, optimizer, config):
    """Placeholder DQN loop: roll out random-action episodes and log returns.

    ``policy_net``, ``target_net``, ``memory`` and ``optimizer`` are accepted
    but not used; no learning takes place. Actions are drawn uniformly from
    ``{0, 1, 2, 3}`` with ``np.random.randint``.

    Args:
        env: object exposing ``reset()`` and ``step(action)``, where ``step``
            returns a 4-tuple whose second item is the reward and third the
            termination flag.
        config: mapping that provides ``"num_episodes"``.

    Returns:
        ``(env_steps_list, episode_rewards)``: the cumulative environment
        step count after each episode and that episode's summed reward.
    """
    returns_per_episode = []
    cumulative_steps = []
    steps_so_far = 0
    for _ in range(config["num_episodes"]):
        env.reset()
        episode_return = 0
        finished = False
        while not finished:
            random_action = np.random.randint(4)
            _, step_reward, finished, _ = env.step(random_action)
            episode_return += step_reward
            steps_so_far += 1
        returns_per_episode.append(episode_return)
        cumulative_steps.append(steps_so_far)
    return cumulative_steps, returns_per_episode

def train_ppo(env, model, optimizer, config):
    """Placeholder PPO loop: fabricate one random reward per update.

    ``env``, ``model`` and ``optimizer`` are accepted but not used. Each of
    the ``config["num_updates"]`` iterations advances the step counter by
    ``config["update_steps"]`` and records ``np.random.randint(0, 100)`` as
    that update's reward.

    Returns:
        ``(env_steps_list, rewards)``: cumulative step counts and the
        random reward values, one entry per update.
    """
    cumulative_steps = []
    fake_rewards = []
    steps_so_far = 0
    for _ in range(config["num_updates"]):
        steps_so_far = steps_so_far + config["update_steps"]
        fake_rewards.append(np.random.randint(0, 100))
        cumulative_steps.append(steps_so_far)
    return cumulative_steps, fake_rewards

def train_a2c(env, model, optimizer, config):
    """Placeholder A2C loop: roll out random-action episodes and log returns.

    ``model`` and ``optimizer`` are accepted but not used; no learning takes
    place. Every step takes an action drawn with ``np.random.randint(4)``.

    Args:
        env: object exposing ``reset()`` and ``step(action)``, where ``step``
            returns a 4-tuple whose second item is the reward and third the
            termination flag.
        config: mapping that provides ``"num_episodes"``.

    Returns:
        ``(env_steps_list, rewards)``: the cumulative environment step count
        after each episode and that episode's summed reward.
    """
    steps_log, returns_log = [], []
    step_counter = 0
    for _ in range(config["num_episodes"]):
        env.reset()
        episode_return = 0
        # At least one step is always taken before the flag is checked.
        while True:
            _, step_reward, terminal, _ = env.step(np.random.randint(4))
            episode_return += step_reward
            step_counter += 1
            if terminal:
                break
        returns_log.append(episode_return)
        steps_log.append(step_counter)
    return steps_log, returns_log

# 📊 Final Plot

def smooth(y, window=10):
    """Return the moving average of ``y`` over ``window`` samples.

    The window is clamped to ``[1, len(y)]`` so the result never has more
    samples than the input: with ``mode='valid'`` NumPy swaps the operands
    when the kernel is longer than the signal, which would otherwise yield
    ``window - len(y) + 1`` meaningless values.

    Args:
        y: 1-D sequence of numbers.
        window: requested averaging width in samples.

    Returns:
        A float array of length ``len(y) - w + 1`` where ``w`` is the
        clamped window; an empty array when ``y`` is empty.
    """
    values = np.asarray(y, dtype=float)
    if values.size == 0:
        # np.convolve rejects empty operands; nothing to average anyway.
        return values
    window = max(1, min(int(window), values.size))
    return np.convolve(values, np.ones(window) / window, mode='valid')

def uniform_plot(dqn_steps, dqn_rewards, ppo_steps, ppo_rewards, a2c_steps, a2c_rewards):
    """Plot smoothed reward curves for DQN, PPO and A2C on one step axis.

    Each ``*_steps`` list holds cumulative environment step counts and the
    matching ``*_rewards`` list holds the reward logged at that point. Every
    reward series is smoothed, aligned with the trailing step counts, and
    resampled with ``np.interp`` onto 300 evenly spaced points between 0 and
    the largest final step count.

    NOTE: ``np.interp`` holds the first/last value constant outside a
    series' own step range, so a run that logged fewer steps is drawn flat
    out to the right edge of the plot.

    Raises:
        IndexError: if any of the step lists is empty.
    """
    max_step = max(dqn_steps[-1], ppo_steps[-1], a2c_steps[-1])
    common_x = np.linspace(0, max_step, 300)

    def prep_interp(steps, rewards):
        # Cap the window at the series length: with fewer rewards than the
        # window, a 'valid' convolution returns MORE points than there are
        # steps, and np.interp then fails with
        # "fp and xp are not of the same length".
        window = max(1, min(10, len(rewards)))
        smoothed = smooth(rewards, window=window)
        # A smoothed point summarises the window ending at that step, so
        # pair the values with the last len(smoothed) step counts.
        steps_trimmed = steps[-len(smoothed):]
        return steps_trimmed, smoothed

    series = (
        ("DQN", dqn_steps, dqn_rewards),
        ("PPO", ppo_steps, ppo_rewards),
        ("A2C", a2c_steps, a2c_rewards),
    )

    plt.figure(figsize=(10, 5))
    for label, steps, rewards in series:
        xs, ys = prep_interp(steps, rewards)
        plt.plot(common_x, np.interp(common_x, xs, ys), label=label)
    plt.xlabel("Environment Steps")
    plt.ylabel("Smoothed Reward")
    plt.title("Fair Comparison of RL Algorithms")
    plt.grid(True)
    plt.legend()
    plt.show()

# 🚀 Launch Experiments

# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Run lengths: episodes for DQN/A2C, updates (and steps per update) for PPO.
config = dict(num_episodes=200, num_updates=200, update_steps=1024)

# One independent environment per algorithm.
env_dqn, env_ppo, env_a2c = SnakeEnv(), SnakeEnv(), SnakeEnv()

# Observation length is read from a live environment; there are 4 actions.
state_dim = env_dqn._get_obs().shape[0]
action_dim = 4

# The training loops are placeholders, so no networks, replay memory or
# optimizers are built; None is passed in their place.
print("Training DQN...")
dqn_steps, dqn_rewards = train_dqn(
    env_dqn, policy_net=None, target_net=None, memory=None, optimizer=None, config=config
)

print("Training PPO...")
ppo_steps, ppo_rewards = train_ppo(env_ppo, model=None, optimizer=None, config=config)

print("Training A2C...")
a2c_steps, a2c_rewards = train_a2c(env_a2c, model=None, optimizer=None, config=config)

uniform_plot(
    dqn_steps, dqn_rewards,
    ppo_steps, ppo_rewards,
    a2c_steps, a2c_rewards,
)


Training DQN...
Training PPO...
Training A2C...


ValueError: fp and xp are not of the same length.