In [1]:
import pygame, random, math, time, os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque, namedtuple

  from pkg_resources import resource_stream, resource_exists


pygame 2.6.1 (SDL 2.28.4, Python 3.13.9)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [None]:
SCREEN_W, SCREEN_H = 600, 400
ROBOT_R = 10
GOAL_R = 12
NUM_OBS = 8
MAX_EPISODE_STEPS = 200
GAMMA = 0.99
BATCH_SIZE = 64
LR = 1e-3
BUFFER_SIZE = 20000
MIN_REPLAY = 500
EPS_START, EPS_END, EPS_DECAY = 1.0, 0.05, 50000
TRAIN_EPISODES_PER_LOOP = 4
MODEL_PATH = "dqn_nav.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class NavEnv:
    def __init__(self):
        self.w, self.h = SCREEN_W, SCREEN_H
        self.reset()

    def reset(self):
        self.robot = np.array([random.choice([40, self.w-40]), random.choice([40, self.h-40])], dtype=np.float32)
        
        while True:
            self.goal = np.array([random.uniform(60, self.w-60), random.uniform(60, self.h-60)], dtype=np.float32)
            if np.linalg.norm(self.goal - self.robot) > 120:
                break
        
        self.obs = []
        for _ in range(NUM_OBS):
            ow, oh = random.randint(30,80), random.randint(20,60)
            ox = random.uniform(60, self.w-60-ow)
            oy = random.uniform(60, self.h-60-oh)
            self.obs.append((ox,oy,ow,oh))
        self.t = 0
        return self._get_obs()

    def _get_obs(self):
        obs = [self.robot[0]/self.w, self.robot[1]/self.h, self.goal[0]/self.w, self.goal[1]/self.h]
        for (ox,oy,ow,oh) in self.obs:
            cx = (ox+ow/2)/self.w
            cy = (oy+oh/2)/self.h
            obs += [cx, cy, ow/self.w, oh/self.h]
        return np.array(obs, dtype=np.float32)

    def step(self, action):
        step_size = 4.5
        if action==0: self.robot[1] -= step_size
        elif action==1: self.robot[1] += step_size
        elif action==2: self.robot[0] -= step_size
        elif action==3: self.robot[0] += step_size
        self.robot = np.clip(self.robot, [ROBOT_R, ROBOT_R], [self.w-ROBOT_R, self.h-ROBOT_R])
        self.t += 1
        done = False
        reward = -0.01
        for (ox,oy,ow,oh) in self.obs:
            if ox <= self.robot[0] <= ox+ow and oy <= self.robot[1] <= oy+oh:
                reward -= 1.0
                done = True
        if np.linalg.norm(self.robot - self.goal) < (ROBOT_R + GOAL_R):
            reward += 2.0
            done = True
        if self.t >= MAX_EPISODE_STEPS:
            done = True
        return self._get_obs(), reward, done, {}

Transition = namedtuple('Transition', ('s','a','r','s2','d'))
class ReplayBuffer:
    def __init__(self, capacity=BUFFER_SIZE):
        self.buffer = deque(maxlen=capacity)
    def push(self, *args): self.buffer.append(Transition(*args))
    def sample(self, n):
        batch = random.sample(self.buffer, n)
        return Transition(*zip(*batch))
    def __len__(self): return len(self.buffer)

class QNet(nn.Module):
    def __init__(self, input_dim, n_actions):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim,128),
            nn.ReLU(),
            nn.Linear(128,128),
            nn.ReLU(),
            nn.Linear(128,n_actions)
        )
    def forward(self,x): return self.net(x)

class Agent:
    def __init__(self, obs_dim, n_actions):
        self.n_actions = n_actions
        self.policy = QNet(obs_dim, n_actions).to(DEVICE)
        self.target = QNet(obs_dim, n_actions).to(DEVICE)
        self.target.load_state_dict(self.policy.state_dict())
        self.opt = optim.Adam(self.policy.parameters(), lr=LR)
        self.buffer = ReplayBuffer()
        self.steps = 0
        self.eps = EPS_START

    def act(self, obs, eval=False):
        self.steps += 1
        if not eval and random.random() < self.eps:
            return random.randrange(self.n_actions)
        with torch.no_grad():
            x = torch.tensor(obs, dtype=torch.float32, device=DEVICE).unsqueeze(0)
            q = self.policy(x)
            return int(q.argmax().item())

    def push(self, *transition): self.buffer.push(*transition)

    def update(self):
        if len(self.buffer) < MIN_REPLAY:
            return 0.0
        batch = self.buffer.sample(BATCH_SIZE)
        s = torch.tensor(np.array(batch.s), dtype=torch.float32, device=DEVICE)
        a = torch.tensor(batch.a, dtype=torch.int64, device=DEVICE).unsqueeze(1)
        r = torch.tensor(batch.r, dtype=torch.float32, device=DEVICE).unsqueeze(1)
        s2 = torch.tensor(np.array(batch.s2), dtype=torch.float32, device=DEVICE)
        d = torch.tensor(batch.d, dtype=torch.float32, device=DEVICE).unsqueeze(1)
        q = self.policy(s).gather(1,a)
        with torch.no_grad():
            q2 = self.target(s2).max(1)[0].unsqueeze(1)
            target = r + (1-d)*GAMMA*q2
        loss = nn.functional.mse_loss(q, target)
        self.opt.zero_grad(); loss.backward(); self.opt.step()
        # epsilon decay
        self.eps = max(EPS_END, EPS_START - self.steps/ EPS_DECAY)
        # soft update
        for p, tp in zip(self.policy.parameters(), self.target.parameters()):
            tp.data.copy_(0.995*tp.data + 0.005*p.data)
        return loss.item()

    def save(self, path=MODEL_PATH):
        torch.save(self.policy.state_dict(), path)

    def load(self, path=MODEL_PATH):
        if os.path.exists(path):
            self.policy.load_state_dict(torch.load(path, map_location=DEVICE))
            self.target.load_state_dict(self.policy.state_dict())
            print("Model loaded:", path)
        else:
            print("No model found at", path)

def draw_env(screen, env):
    screen.fill((30,30,30))
    # obstacles
    for (ox,oy,ow,oh) in env.obs:
        pygame.draw.rect(screen,(120,60,60), pygame.Rect(int(ox),int(oy),int(ow),int(oh)))
    # goal
    pygame.draw.circle(screen, (50,200,50), (int(env.goal[0]), int(env.goal[1])), GOAL_R)
    # robot
    pygame.draw.circle(screen, (200,200,50), (int(env.robot[0]), int(env.robot[1])), ROBOT_R)


In [3]:
def main():
    pygame.init()
    screen = pygame.display.set_mode((SCREEN_W, SCREEN_H))
    pygame.display.set_caption("RL Nav (DQN) - press t to train, p play, s save, l load, r reset")
    clock = pygame.time.Clock()
    env = NavEnv()
    obs0 = env.reset()
    agent = Agent(obs0.shape[0], 4)
    mode = "idle"  # 'train', 'play'
    running = True
    train_losses = []
    episode = 0

    while running:
        for event in pygame.event.get():
            if event.type==pygame.QUIT:
                running=False
            elif event.type==pygame.KEYDOWN:
                if event.key==pygame.K_t:
                    mode = "train" if mode!="train" else "idle"
                    print("Mode:", mode)
                elif event.key==pygame.K_p:
                    mode = "play" if mode!="play" else "idle"
                    print("Mode:", mode)
                elif event.key==pygame.K_r:
                    obs0 = env.reset()
                    episode = 0
                    print("Reset")
                elif event.key==pygame.K_s:
                    agent.save()
                    print("Saved model.")
                elif event.key==pygame.K_l:
                    agent.load()
        if mode=="train":
            for _ in range(TRAIN_EPISODES_PER_LOOP):
                s = env.reset()
                done=False
                ep_reward=0.0
                while not done:
                    a = agent.act(s, eval=False)
                    s2, r, done, _ = env.step(a)
                    agent.push(s,a,r,s2, float(done))
                    loss = agent.update()
                    s = s2
                    ep_reward += r
                episode +=1
            if episode % 10 == 0:
                print(f"Episode {episode}, eps {agent.eps:.3f}, buffer {len(agent.buffer)}")
        elif mode=="play":
            s = env.reset()
            done=False
            while not done and mode=="play":
                for ev in pygame.event.get():
                    if ev.type==pygame.QUIT:
                        running=False
                        mode="idle"
                a = agent.act(s, eval=True)
                s2, r, done, _ = env.step(a)
                s = s2
                draw_env(screen, env)
                pygame.display.flip()
                clock.tick(60)
            mode="idle"
        draw_env(screen, env)
        font = pygame.font.SysFont("Arial", 14)
        txt = font.render(f"Mode: {mode}  Episodes: {episode}  Eps: {agent.eps:.2f}", True, (230,230,230))
        screen.blit(txt, (8,8))
        pygame.display.flip()
        clock.tick(60)
    pygame.quit()

if __name__=="__main__":
    main()

Mode: train
Episode 20, eps 0.050, buffer 3435
Episode 40, eps 0.050, buffer 7175
Episode 60, eps 0.050, buffer 10979
Episode 80, eps 0.050, buffer 14793
Episode 100, eps 0.050, buffer 18709
Episode 120, eps 0.050, buffer 20000
Mode: play
Mode: train
Episode 140, eps 0.050, buffer 20000
Episode 160, eps 0.050, buffer 20000
Episode 180, eps 0.050, buffer 20000
Mode: play
Mode: train
Episode 200, eps 0.050, buffer 20000
Episode 220, eps 0.050, buffer 20000
Episode 240, eps 0.050, buffer 20000
Episode 260, eps 0.050, buffer 20000
Episode 280, eps 0.050, buffer 20000
Episode 300, eps 0.050, buffer 20000
Episode 320, eps 0.050, buffer 20000
Mode: idle
Mode: play
Mode: play
Mode: play
Mode: play
Mode: play
Mode: train
Episode 340, eps 0.050, buffer 20000
Episode 360, eps 0.050, buffer 20000
Episode 380, eps 0.050, buffer 20000
Episode 400, eps 0.050, buffer 20000
Episode 420, eps 0.050, buffer 20000
Episode 440, eps 0.050, buffer 20000
Episode 460, eps 0.050, buffer 20000
Episode 480, eps 0.