In [1]:
import os
import torch
import torch.optim as optim
import torch.nn.functional as F
import random
from collections import deque

import matplotlib.pyplot as plt

from env import PursuitEvasionEnv
from models import GCN_QNetwork
from torch_geometric.utils import dense_to_sparse
import numpy as np

In [2]:
class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random sampling.

    Each stored item is the 6-tuple
    ``(id, state, actions, reward, next_state, done)`` exactly as passed to
    :meth:`push`. Once ``capacity`` items are held, appending a new one
    discards the oldest (standard ``deque(maxlen=...)`` behaviour).
    """

    def __init__(self, capacity=100000):
        """Create an empty buffer holding at most ``capacity`` transitions."""
        self.buffer = deque([], capacity)

    def push(self, id, state, actions, reward, next_state, done):
        """Append one transition to the newest end of the buffer."""
        transition = (id, state, actions, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            A ``zip`` iterator yielding six tuples (ids, states, actions,
            rewards, next_states, dones), i.e. the batch transposed
            field-by-field.

        Raises:
            ValueError: (from ``random.sample``) if ``batch_size`` exceeds
                the number of stored transitions.
        """
        picked = random.sample(self.buffer, batch_size)
        # Transpose the list of transitions into per-field columns.
        return zip(*picked)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

In [3]:
def obs_to_tensors(obs, pid, device):
    """Convert one pursuer's observation into graph tensors.

    Args:
        obs: Mapping from pursuer id to an observation entry. Each entry
            holds ``'features'`` (np.array of shape (N, F)) and ``'adj'``
            (np.array of shape (N, N)).
        pid: Id of the pursuer whose entry should be converted.
        device: Torch device the resulting tensors are created on.

    Returns:
        ``(x, edge_index)`` where ``x`` is a float32 tensor built from the
        features and ``edge_index`` is the sparse edge list produced by
        ``dense_to_sparse`` from the adjacency matrix. Returns
        ``(None, None)`` when ``pid`` has no entry in ``obs`` or its
        feature matrix has zero rows.
    """
    entry = obs.get(pid)
    if entry is None:
        return None, None

    node_feats = entry["features"]
    # An empty graph has nothing to feed the network.
    if node_feats.shape[0] == 0:
        return None, None

    dense_adj = entry["adj"]
    x = torch.tensor(node_feats, dtype=torch.float32, device=device)
    adj_t = torch.tensor(dense_adj, dtype=torch.float32, device=device)
    # dense_to_sparse also returns edge weights, which are not needed here.
    edge_index, _edge_weight = dense_to_sparse(adj_t)
    return x, edge_index

In [4]:
def calc_evader_speed(e, total_eps=400, min_speed=0.35, max_speed=1.0,
                      hold_eps=100, power=2):
    """Curriculum schedule for the evader's speed as a function of episode.

    The speed eases from ``min_speed`` up to ``max_speed`` over the first
    ``total_eps - hold_eps`` episodes and then stays at ``max_speed`` for
    the remaining ``hold_eps`` episodes. The ramp uses the ease-out curve
    ``1 - (1 - t)**power``, so with ``power > 1`` the speed rises quickly
    at first and flattens as it approaches ``max_speed``.

    Args:
        e: Episode index. Values below 0 are treated as 0.
        total_eps: Total number of episodes in the schedule.
        min_speed: Speed returned at ``e == 0``.
        max_speed: Speed returned once the ramp is finished.
        hold_eps: Number of final episodes spent at ``max_speed``
            (default 100, the previously hard-coded value).
        power: Exponent of the ease-out curve (default 2, the previously
            hard-coded value).

    Returns:
        The evader speed for episode ``e`` as a float.
    """
    ramp_eps = total_eps - hold_eps   # 300 with the defaults

    # No ramp phase at all (hold covers the whole schedule), or already
    # past the ramp: run at full speed. The explicit ramp_eps check also
    # avoids dividing by zero below.
    if ramp_eps <= 0 or e >= ramp_eps:
        return max_speed

    # Clamp so that a negative episode index cannot extrapolate the curve
    # below min_speed.
    t = max(e, 0) / ramp_eps          # normalized progress in [0, 1)
    return min_speed + (max_speed - min_speed) * (1 - (1 - t) ** power)

In [5]:
def _plot_reward_curve(rewards_history, window, title, out_path):
    """Save a plot of per-episode rewards plus a ``window``-episode moving average."""
    plt.figure()
    # Actual rewards (transparent)
    plt.plot(rewards_history, label='Episode Reward', alpha=0.3)
    # Moving average (only once at least one full window is available)
    if len(rewards_history) >= window:
        mov_avg = np.convolve(
            rewards_history,
            np.ones(window) / window,
            mode='valid'
        )
        plt.plot(
            np.arange(window - 1, len(rewards_history)),
            mov_avg,
            label=f'{window}-Ep Moving Avg',
            color='blue'
        )
    plt.xlabel('Episode')
    plt.ylabel('Reward')
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(out_path)
    # Close the figure so repeated plotting does not accumulate open figures.
    plt.close()


def _dqn_update(policy_net, target_net, optimizer, batch, gamma, device):
    """Take one optimizer step on a sampled batch of transitions.

    ``batch`` is the iterable of six columns returned by
    ``ReplayBuffer.sample``. Transitions whose state cannot be decoded, or
    whose stored action index is out of range for that state, are skipped.
    """
    ids, states, acts, rewards, next_states, dones = batch
    losses = []
    for pid, s, action_idx, r, ns, d in zip(ids, states, acts, rewards, next_states, dones):
        x_s, ei_s = obs_to_tensors(s, pid, device)
        if x_s is None or action_idx >= x_s.size(0):
            continue
        q_s = policy_net(x_s, ei_s)[action_idx]

        x_ns, ei_ns = obs_to_tensors(ns, pid, device)
        with torch.no_grad():
            # Bootstrap from the target network; no next graph => value 0.
            q_ns = target_net(x_ns, ei_ns).max().item() if x_ns is not None else 0.0
        target = r + gamma * (1 - float(d)) * q_ns
        target_t = torch.tensor(target, dtype=torch.float32, device=device)
        losses.append(F.mse_loss(q_s, target_t))

    if losses:
        optimizer.zero_grad()
        torch.stack(losses).mean().backward()
        optimizer.step()


def train_shared_policy(
    num_episodes=200,
    batch_size=16,
    gamma=0.99,
    lr=1e-3,
    save_dir='checkpoints',
    save_interval=50,
    plot_interval=10
):
    """Train one GCN Q-network shared by all pursuers with DQN.

    Every pursuer that is not currently moving picks an action
    epsilon-greedily from the shared ``policy_net``; an action is a row
    index into that pursuer's feature matrix. Transitions returned by
    ``env.step`` go into a replay buffer and one gradient step is taken per
    environment step once the buffer holds at least ``batch_size`` items.
    Epsilon decays multiplicatively once per environment step. The target
    network is synced every 10 episodes. The first episode runs with an
    evader speed of 0.1; later episodes use ``calc_evader_speed`` scaled to
    ``num_episodes``.

    Args:
        num_episodes: Number of training episodes.
        batch_size: Number of transitions sampled per gradient step.
        gamma: Discount factor used in the TD target.
        lr: Adam learning rate.
        save_dir: Directory for checkpoints and reward plots (created if
            missing).
        save_interval: Save ``policy_ep{ep}.pt`` every this many episodes.
        plot_interval: Re-plot the reward curve every this many episodes;
            also used as the moving-average window.

    Side effects:
        Writes ``policy_ep*.pt``, ``policy_final.pt``, ``reward_curve.png``
        and ``reward_curve_final.png`` into ``save_dir`` and prints progress
        to stdout.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(save_dir, exist_ok=True)

    env = PursuitEvasionEnv(random_start=True, evader_strat="potential_fields")
    evader_speed = 0.1
    feature_dim = env._get_observation()[0]["features"].shape[1]

    policy_net = GCN_QNetwork(in_features=feature_dim, hidden_dim=64).to(device)
    target_net = GCN_QNetwork(in_features=feature_dim, hidden_dim=64).to(device)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()

    optimizer = optim.Adam(policy_net.parameters(), lr=lr)
    replay_buffer = ReplayBuffer()

    epsilon, eps_end, eps_decay = 1.0, 0.1, 0.99995

    rewards_history = []

    for ep in range(1, num_episodes+1):
        successful_pursuers = 0
        obs = env.reset(random_start=True, evader_strat="potential_fields")
        done = False
        episode_reward = 0

        print('##################################################')
        print(f'Evader Start Pos: {env.evader.x}, {env.evader.y}')

        while not done:
            # --- Epsilon-greedy action selection for idle pursuers ---
            actions = {}
            for p in env.pursuers:
                if not p.is_moving:
                    x, edge_index = obs_to_tensors(obs, p.id, device)
                    if x is None:
                        actions[p.id] = 0
                    elif random.random() < epsilon:
                        actions[p.id] = random.randrange(x.size(0))
                    else:
                        with torch.no_grad():
                            q_vals = policy_net(x, edge_index)
                            actions[p.id] = int(q_vals.argmax().item())

            obs, transitions, done = env.step(actions, evader_speed=evader_speed)
            if done:
                for p in env.pursuers:
                    if p.found_evader:
                        successful_pursuers += 1

            if env.step_count % 200 == 0:
                print(f'Episode {ep} step count: {env.step_count}, evader pos: {env.evader.x}, {env.evader.y}')

            for pid, prev_obs, prev_action, reward, p_done in transitions:
                # Store the next state keyed by pursuer id: the update step
                # decodes it with obs_to_tensors(ns, pid, ...), which looks
                # the entry up via ns.get(pid). Storing the bare obs[pid]
                # entry made that lookup miss, so the bootstrap value was
                # always 0.
                # NOTE(review): assumes prev_obs is already keyed by pursuer
                # id, as obs_to_tensors requires - confirm against env.step.
                next_obs = {pid: obs[pid]}
                replay_buffer.push(pid, prev_obs, prev_action, reward, next_obs, p_done)
                episode_reward += reward

            # --- DQN update ---
            if len(replay_buffer) >= batch_size:
                _dqn_update(
                    policy_net,
                    target_net,
                    optimizer,
                    replay_buffer.sample(batch_size),
                    gamma,
                    device,
                )

            # Epsilon decays once per environment step, not per episode.
            epsilon = max(eps_end, epsilon * eps_decay)

        # --- End of episode bookkeeping ---
        rewards_history.append(episode_reward)

        # Increase Evader Speed. The schedule is scaled to this run's length
        # rather than calc_evader_speed's default of 400 episodes.
        evader_speed = calc_evader_speed(ep, total_eps=num_episodes)

        # Sync target
        if ep % 10 == 0:
            target_net.load_state_dict(policy_net.state_dict())

        # Save model
        if ep % save_interval == 0:
            torch.save(policy_net.state_dict(), os.path.join(save_dir, f'policy_ep{ep}.pt'))

        # Plot rewards
        if ep % plot_interval == 0:
            _plot_reward_curve(
                rewards_history,
                plot_interval,
                'Training Reward Curve',
                os.path.join(save_dir, 'reward_curve.png'),
            )

        print(f"Length of Buffer: {len(replay_buffer)}")
        # Note: evader_speed was already advanced above, so the value printed
        # here is the speed the NEXT episode will use.
        print(f"Episode {ep}/{num_episodes}  Reward: {episode_reward:.2f}  Epsilon: {epsilon:.3f}   | Successful Pursuers: {successful_pursuers}   Evader Speed: {evader_speed}")
        # One entry per pursuer, so any team size is reported without indexing errors.
        print("  ".join(
            f"P{idx} Area Seen: {p.prev_area_seen}"
            for idx, p in enumerate(env.pursuers, start=1)
        ))

    # --- Final save & plot ---
    torch.save(policy_net.state_dict(), os.path.join(save_dir, 'policy_final.pt'))
    _plot_reward_curve(
        rewards_history,
        plot_interval,
        'Training Reward Curve (Final)',
        os.path.join(save_dir, 'reward_curve_final.png'),
    )

    print("Training complete!")

In [6]:
# Launch the training run: 400 episodes with a lower learning rate (1e-4)
# than the function default; checkpoints and reward plots go to 'checkpoints'.
training_kwargs = {
    "num_episodes": 400,
    "batch_size": 16,
    "gamma": 0.99,
    "lr": 1e-4,
    "save_dir": 'checkpoints',
    "save_interval": 50,
    "plot_interval": 10,
}
train_shared_policy(**training_kwargs)

##################################################
Evader Start Pos: 56, 34
Episode 1 step count: 200, evader pos: 53, 18
Episode 1 step count: 400, evader pos: 56, 3
Length of Buffer: 65
Episode 1/400  Reward: 163.97  Epsilon: 0.971   | Successful Pursuers: 3   Evader Speed: 0.35432611111111106
P1 Area Seen: 667  P2 Area Seen: 1843  P3 Area Seen: 759
##################################################
Evader Start Pos: 57, 48
Episode 2 step count: 200, evader pos: 59, 48
Episode 2 step count: 400, evader pos: 59, 48
Episode 2 step count: 600, evader pos: 59, 48
Episode 2 step count: 800, evader pos: 59, 48
Episode 2 step count: 1000, evader pos: 59, 48
Episode 2 step count: 1200, evader pos: 59, 48
Episode 2 step count: 1400, evader pos: 59, 48
Episode 2 step count: 1600, evader pos: 59, 48
Length of Buffer: 165
Episode 2/400  Reward: 95.20  Epsilon: 0.896   | Successful Pursuers: 1   Evader Speed: 0.35863777777777783
P1 Area Seen: 766  P2 Area Seen: 3477  P3 Area Seen: 3249
##########