In [None]:
import os
import argparse
import random
from dataclasses import dataclass
from typing import Tuple, Deque, List, Optional
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

import torch
torch.set_num_threads(1)
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from collections import deque

# ==============================
# 1) DroneDelivery Environment
# ==============================
import gym
from gym import spaces

class DroneDelivery(gym.Env):
    """Grid-world drone delivery task.

    A drone moves on a ``size x size`` grid. It has to reach the package cell
    (pickup) and afterwards the customer cell (delivery) while a single
    no-fly-zone cell is penalised.

    Rewards, as implemented in :meth:`step`:

    * ``+25`` on pickup,
    * ``-100`` whenever the drone stands on the no-fly-zone cell,
    * ``+100`` on delivery (the episode terminates),
    * ``-1`` on every step that is not the delivery step.

    Observations are dicts with the keys ``agent``, ``pick``, ``package``,
    ``customer`` and ``no_fly_zone``.  A location equal to ``[-1, -1]`` is the
    sentinel for "no longer on the grid" (the package after pickup, the
    customer after delivery).
    """

    metadata = {'render.modes': []}

    def __init__(self, size: int=6):
        super().__init__()
        self.max_step = 200       # episode is truncated after this many steps
        self.current_step = 0
        self.reward = 0

        self.size = size
        self.num_control = 4      # number of discrete movement actions

        # All locations start at the off-grid sentinel until reset() is called.
        self.agent_location = np.array([-1,-1], dtype=np.int32)
        self.pick_status = np.array([0], dtype=np.int32)
        self.package_location = np.array([-1, -1], dtype=np.int32)
        self.customer_location = np.array([-1,-1], dtype=np.int32)
        self.no_fly_zone = np.array([-1, -1], dtype=np.int32)

        # Structure of the dict returned by _get_obs().
        self.observation = spaces.Dict(
            {
            "agent":spaces.Box(0, size-1, shape=(2,), dtype=int),
            "pick":spaces.Box(0,1,shape=(1,), dtype=int),
            "package":spaces.Box(0,size-1,shape=(2,), dtype=int),
            "customer":spaces.Box(0, size-1, shape=(2,), dtype=int),
            "no_fly_zone": spaces.Box(0, size - 1, shape=(2,), dtype=int),
            }
        )

        self.action_space = spaces.Discrete(self.num_control)

        # Size of the joint (agent, pick, package, customer, no-fly-zone) space.
        self._num_states = (self.size * self.size) * 2 * (self.size * self.size) * (self.size * self.size) * (
                    self.size * self.size)

        # NOTE(review): observation_space is Discrete although reset()/step()
        # return dict observations (described by self.observation). Kept as-is
        # for backward compatibility.
        self.observation_space = spaces.Discrete(self._num_states)

        # Action index -> displacement added to agent_location.
        self.action_to_direction = {
            0: np.array([1,0]),  # +1 on the first coordinate
            1: np.array([0,1]),  # +1 on the second coordinate
            2: np.array([-1,0]), # -1 on the first coordinate
            3: np.array([0,-1]), # -1 on the second coordinate
        }

        # Created here so render() does not fail when called before reset().
        self.state_grid = np.zeros((self.size, self.size))

    def _get_obs(self):
        """Return the current observation dict."""
        return{
            "agent": self.agent_location,
            "pick": self.pick_status,
            "package": self.package_location,
            "customer": self.customer_location,
            "no_fly_zone": self.no_fly_zone
        }

    def _get_info(self):
        """Return the auxiliary info dict (L1 distance from agent to customer)."""
        return {
            "distance": np.linalg.norm(
                self.agent_location - self.customer_location, ord=1
            )
        }

    def _update_state_grid(self):
        """Rebuild ``state_grid`` (the image shown by :meth:`render`).

        Entities are painted in a fixed order so that later ones win when two
        share a cell.  Off-grid sentinel locations (``[-1, -1]``) are skipped:
        indexing with them would wrap around and paint the bottom-right cell.
        """
        grid = np.zeros((self.size, self.size))
        markers = (
            (self.agent_location, 1),
            (self.package_location, 0.8),
            (self.customer_location, 0.6),
            (self.no_fly_zone, 0.4),
        )
        for location, value in markers:
            if np.all(np.asarray(location) >= 0):
                grid[tuple(location)] = value
        self.state_grid = grid

    def obs_to_key(self, obs):
        """Compress an observation into ``(picked, dx, dy, nfz_bucket)``.

        ``dx``/``dy`` are the signed offsets from the current goal (package
        before pickup, customer afterwards) to the agent, clamped to
        ``[-(size-1), size-1]``.  ``nfz_bucket`` is the Manhattan distance to
        the no-fly zone, capped at 3.
        """
        picked = int(obs["pick"][0])
        goal = obs["package"] if picked == 0 else obs["customer"]

        dx = int(obs["agent"][0] - goal[0])
        dy = int(obs["agent"][1] - goal[1])

        m = self.size - 1
        dx = max(-m, min(m, dx))
        dy = max(-m, min(m, dy))

        d_nfz = abs(obs["agent"][0] - obs["no_fly_zone"][0]) + abs(obs["agent"][1] - obs["no_fly_zone"][1])
        # Distances 0, 1, 2 keep their own bucket; everything farther is bucket 3.
        nfz_bucket = int(min(d_nfz, 3))
        return (picked, dx, dy, nfz_bucket)

    def reset(self,seed: Optional[int] = None, options: Optional[dict] = None):
        """Start a new episode with a freshly sampled layout.

        ``seed`` is forwarded to ``gym.Env.reset``.  The layout itself is drawn
        with ``np.random.choice`` (NumPy's global RNG), so it is reproducible
        through ``np.random.seed`` rather than through ``seed``.

        Returns:
            ``(observation, info)``.
        """
        super().reset(seed=seed)

        self.reward = 0
        self.current_step = 0

        # Draw four distinct cells: agent, package, customer, no-fly zone.
        grid = np.array([[x,y] for x in range(self.size) for y in range(self.size)])
        idx = np.random.choice(self.size*self.size, size=4, replace=False)

        self.agent_location = grid[idx[0]]
        self.pick_status = np.array([0], dtype=np.int32)
        self.package_location = grid[idx[1]]
        self.customer_location = grid[idx[2]]
        self.no_fly_zone = grid[idx[3]]

        self._update_state_grid()

        observation = self._get_obs()
        info = self._get_info()
        return observation, info

    def step(self, action: int):
        """Apply one movement action.

        Args:
            action: index into ``action_to_direction`` (0..3).

        Returns:
            ``(observation, reward, terminated, truncated, info)``.

        Raises:
            KeyError: if ``action`` is not a valid action index.
        """
        self.reward = 0
        terminated = False

        direction = self.action_to_direction[int(action)]

        # Moves that would leave the grid are clamped to the border.
        self.agent_location = np.clip(self.agent_location + direction, 0, self.size - 1)

        picked = np.array_equal(self.agent_location, self.package_location)
        drop_off = np.array_equal(self.agent_location, self.customer_location)
        in_no_fly_zone = np.array_equal(self.agent_location, self.no_fly_zone)

        if picked and self.pick_status[0] == 0:
            # Keep pick_status an int32 array so the observation type is stable.
            self.pick_status = np.array([1], dtype=np.int32)
            self.package_location = np.array([-1, -1], dtype=np.int32)
            self.reward += 25

        self.reward += -100 if in_no_fly_zone else 0

        if drop_off and self.pick_status[0] == 1:
            self.pick_status = np.array([0], dtype=np.int32)
            self.customer_location = np.array([-1, -1], dtype=np.int32)
            terminated = True
            self.reward += 100
        else:
            self.reward += -1

        self.current_step += 1
        # ">=" so stepping past the limit can never "un-truncate" the episode.
        truncated = self.current_step >= self.max_step

        self._update_state_grid()

        observation = self._get_obs()
        info = self._get_info()
        reward = self.reward
        return observation, reward, terminated, truncated, info

    def render(self):
        """Show the current ``state_grid`` with matplotlib."""
        plt.title("Drone Delivery Environment")
        plt.imshow(self.state_grid)
        plt.show()

# ==============================
# 2) DQN with Experience Replay
# ==============================

def set_seed(seed: int):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def safe_reset(env):
    """Reset ``env`` and always return an ``(observation, info)`` pair.

    A 2-tuple returned by ``env.reset()`` is passed through unchanged; any
    other return value is treated as a bare observation and paired with an
    empty info dict.
    """
    result = env.reset()
    has_info = isinstance(result, tuple) and len(result) == 2
    return result if has_info else (result, {})

def safe_step(env, action):
    """Step ``env`` and normalise the result to the 5-tuple API.

    Returns:
        ``(obs, reward, terminated, truncated, info)``.  A 4-tuple
        ``(obs, reward, done, info)`` is converted with ``truncated=False``.

    Raises:
        RuntimeError: if ``env.step`` returns anything other than a 4- or
            5-element tuple.
    """
    result = env.step(action)

    if not isinstance(result, tuple) or len(result) not in (4, 5):
        raise RuntimeError("Unsupported env.step return signature.")

    if len(result) == 5:
        return result

    obs, reward, done, info = result
    return obs, reward, bool(done), False, info

def obs_to_state_vector(obs, env) -> np.ndarray:
    """Turn an observation into the 4-feature float32 network input.

    The features come from ``env.obs_to_key(obs)``: the pick flag, the goal
    offsets ``dx``/``dy`` divided by ``max(1, env.size - 1)``, and the
    no-fly-zone bucket divided by 3.
    """
    picked, dx, dy, nfz_bucket = env.obs_to_key(obs)
    # Guard against a 1x1 grid, where size - 1 would be zero.
    scale = float(max(1, env.size - 1))
    features = (
        float(picked),
        float(dx) / scale,
        float(dy) / scale,
        float(nfz_bucket) / 3.0,
    )
    return np.asarray(features, dtype=np.float32)

@dataclass
class Transition:
    """One replay-buffer entry ``(s, a, r, s2, done)``.

    ``ReplayBuffer.sample`` also reuses this container for a whole minibatch,
    in which case every field holds a stacked NumPy array instead of a single
    value.
    """
    s: np.ndarray   # state vector before the action
    a: int          # index of the action taken
    r: float        # reward received for the action
    s2: np.ndarray  # state vector after the action
    done: bool      # done flag stored with the transition

class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random sampling.

    Once ``capacity`` entries are held, pushing a new transition drops the
    oldest one (``deque(maxlen=capacity)``).
    """

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.buffer: Deque[Transition] = deque(maxlen=capacity)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

    def push(self, *args):
        """Append ``Transition(*args)``, i.e. ``(s, a, r, s2, done)``."""
        entry = Transition(*args)
        self.buffer.append(entry)

    def sample(self, batch_size: int) -> Transition:
        """Draw ``batch_size`` distinct transitions and stack them field-wise.

        Returns:
            A ``Transition`` whose fields are arrays: ``s``/``s2`` stacked along
            axis 0, ``a`` as int64, ``r`` and ``done`` as float32.
        """
        chosen = random.sample(self.buffer, batch_size)

        def column(field):
            return [getattr(item, field) for item in chosen]

        states = np.stack(column("s"), axis=0)
        actions = np.array(column("a"), dtype=np.int64)
        rewards = np.array(column("r"), dtype=np.float32)
        next_states = np.stack(column("s2"), axis=0)
        dones = np.array(column("done"), dtype=np.float32)
        return Transition(states, actions, rewards, next_states, dones)

class QNetwork(nn.Module):
    """Two-hidden-layer MLP mapping a state vector to one Q-value per action.

    Layout: ``Linear -> ReLU -> Linear -> ReLU -> Linear``.  Every linear layer
    is re-initialised with Xavier-uniform weights and zero biases.
    """

    def __init__(self, state_dim: int, n_actions: int, hidden: Tuple[int, int] = (128, 128)):
        super().__init__()
        h1, h2 = hidden

        # Hidden stack first, then the linear output head.
        widths = (state_dim, h1, h2)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(h2, n_actions))
        self.net = nn.Sequential(*layers)

        # Initialise only after all layers exist, visiting them front to back.
        for layer in self.net:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the Q-values produced by the network for input ``x``."""
        return self.net(x)

class EpsilonGreedy:
    """Multiplicatively decaying exploration rate with a lower bound."""

    def __init__(self, start: float, end: float, decay: float):
        self.start = start
        self.end = end
        self.decay = decay
        # Current epsilon; begins at ``start``.
        self.eps = start

    def value(self) -> float:
        """Return the current epsilon."""
        return self.eps

    def step(self):
        """Multiply epsilon by ``decay``, never going below ``end``."""
        decayed = self.eps * self.decay
        self.eps = decayed if decayed > self.end else self.end

def dqn_train(
    env,
    episodes: int = 1500,
    buffer_size: int = 50_000,
    batch_size: int = 64,
    gamma: float = 0.99,
    lr: float = 1e-3,
    start_learning_after: int = 1_000,
    learn_every: int = 4,
    target_sync_every: int = 1_000,
    eps_start: float = 1.0,
    eps_end: float = 0.05,
    eps_decay: float = 0.995,
    max_grad_norm: float = 5.0,
    seed: int = 0,
    device: Optional[str] = None,
):
    """Train a DQN agent with experience replay and a target network.

    Args:
        env: environment exposing ``action_space.n``, ``obs_to_key`` and
            ``size`` (used through ``obs_to_state_vector``).
        episodes: number of training episodes.
        buffer_size: replay-buffer capacity.
        batch_size: minibatch size for each gradient step.
        gamma: discount factor.
        lr: Adam learning rate.
        start_learning_after: minimum number of stored transitions before
            learning starts.
        learn_every: perform a gradient step every this many environment steps.
        target_sync_every: copy online weights into the target network every
            this many environment steps.
        eps_start, eps_end, eps_decay: epsilon-greedy schedule; epsilon is
            decayed once per episode.
        max_grad_norm: gradient-norm clipping threshold.
        seed: seed passed to ``set_seed``.
        device: torch device string; defaults to CUDA when available.

    Returns:
        Dict with ``online_q``, ``target_q``, ``episode_returns``,
        ``epsilon_history`` (one entry per environment step) and ``device``.
    """
    set_seed(seed)
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    n_actions = int(env.action_space.n)
    state_dim = 4  # [picked, dx_norm, dy_norm, nfz_norm]

    online_q = QNetwork(state_dim, n_actions).to(device)
    target_q = QNetwork(state_dim, n_actions).to(device)
    target_q.load_state_dict(online_q.state_dict())
    target_q.eval()

    optimizer = optim.Adam(online_q.parameters(), lr=lr)
    loss_fn = nn.SmoothL1Loss()  # Huber

    buffer = ReplayBuffer(buffer_size)
    eps_sched = EpsilonGreedy(eps_start, eps_end, eps_decay)

    episode_returns: List[float] = []
    epsilon_history: List[float] = []

    global_step = 0

    for ep in range(episodes):
        obs, _ = safe_reset(env)
        s = obs_to_state_vector(obs, env)
        ep_ret = 0.0
        terminated = truncated = False

        while not (terminated or truncated):
            epsilon = eps_sched.value()
            epsilon_history.append(epsilon)

            if random.random() < epsilon:
                # Use the module-level `random` RNG seeded by set_seed();
                # env.action_space has its own RNG that is never seeded here,
                # which would make exploration non-reproducible.
                a = random.randrange(n_actions)
            else:
                with torch.no_grad():
                    qs = online_q(torch.from_numpy(s).unsqueeze(0).to(device))
                    a = int(qs.argmax(dim=1).item())

            obs2, r, terminated, truncated, _ = safe_step(env, a)
            s2 = obs_to_state_vector(obs2, env)

            # Store only true termination as the done flag. A time-limit
            # truncation ends the rollout but the next state still has value,
            # so the TD target must keep bootstrapping from it.
            buffer.push(s, a, float(r), s2, bool(terminated))

            ep_ret += float(r)
            s = s2
            global_step += 1

            if len(buffer) >= max(batch_size, start_learning_after) and (global_step % learn_every == 0):
                batch = buffer.sample(batch_size)

                s_b = torch.from_numpy(batch.s).to(device)
                a_b = torch.from_numpy(batch.a).to(device)
                r_b = torch.from_numpy(batch.r).to(device)
                s2_b = torch.from_numpy(batch.s2).to(device)
                d_b = torch.from_numpy(batch.done).to(device)

                # Q(s, a) for the actions that were actually taken.
                q_values = online_q(s_b).gather(1, a_b.unsqueeze(1)).squeeze(1)

                # Target: r + gamma * (1 - d) * max_a' Q_target(s', a')
                with torch.no_grad():
                    next_q = target_q(s2_b).max(dim=1).values
                    target = r_b + gamma * (1.0 - d_b) * next_q

                loss = loss_fn(q_values, target)

                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                nn.utils.clip_grad_norm_(online_q.parameters(), max_grad_norm)
                optimizer.step()

            if global_step % target_sync_every == 0:
                target_q.load_state_dict(online_q.state_dict())

        episode_returns.append(ep_ret)
        eps_sched.step()

        if (ep + 1) % 50 == 0:
            # episode_returns is never empty here (appended just above).
            mean100 = np.mean(episode_returns[-100:])
            print(f"[Episode {ep+1:4d}] return={ep_ret:8.2f}  eps={eps_sched.value():.3f}  "
                  f"mean100={mean100:8.2f}  buffer={len(buffer):6d}")

    # Leave both networks identical at the end of training.
    target_q.load_state_dict(online_q.state_dict())

    return {
        "online_q": online_q,
        "target_q": target_q,
        "episode_returns": episode_returns,
        "epsilon_history": epsilon_history,
        "device": device
    }

# ==============================
# 3) Greedy Evaluation & Verification
# ==============================

@torch.no_grad()
def select_greedy_action(online_q: QNetwork, s: np.ndarray, device: str) -> int:
    """Return the index of the highest-valued action for state vector ``s``."""
    # Add a leading batch dimension before feeding the network.
    state_batch = torch.from_numpy(s).unsqueeze(0).to(device)
    q_values = online_q(state_batch)
    best_action = q_values.argmax(dim=1)
    return int(best_action.item())

def evaluate_greedy(env, online_q: QNetwork, episodes: int = 20, device: str = "cpu", plot: bool = True):
    """Roll out the greedy policy and report per-episode total rewards.

    Args:
        env: environment to evaluate on.
        online_q: network used for greedy action selection.
        episodes: number of evaluation episodes.
        device: torch device string the network lives on.
        plot: when true, show a matplotlib line plot of the rewards.

    Returns:
        List with the total reward of each episode.
    """
    rewards = []
    for _ in range(episodes):
        first_obs, _ = safe_reset(env)
        state = obs_to_state_vector(first_obs, env)
        episode_total = 0.0
        finished = False
        while not finished:
            action = select_greedy_action(online_q, state, device)
            next_obs, reward, terminated, truncated, _ = safe_step(env, action)
            episode_total += float(reward)
            state = obs_to_state_vector(next_obs, env)
            finished = terminated or truncated
        rewards.append(episode_total)

    print("\n=== Greedy Evaluation (DQN) ===")
    print("Rewards:", rewards)
    print(f"Mean: {np.mean(rewards):.3f} | Median: {np.median(rewards):.3f} | "
          f"Best: {np.max(rewards):.3f} | Worst: {np.min(rewards):.3f}")

    if not plot:
        return rewards

    episode_numbers = range(1, len(rewards) + 1)
    plt.figure()
    plt.title("Greedy Evaluation: Total Reward per Episode (DQN)")
    plt.plot(episode_numbers, rewards, marker="o")
    plt.xlabel("Episode")
    plt.ylabel("Total Reward per Episode")
    plt.tight_layout()
    plt.show()

    return rewards

def run_one_greedy_episode_with_render_and_verify(env, online_q: QNetwork, device: str = "cpu", max_steps: int = 200):
    """Run one rendered greedy episode and check that the mission completed.

    The mission counts as verified when the pick flag was seen set at least
    once, the customer location was seen as ``(-1, -1)`` (delivered), and the
    episode ended (terminated or truncated).

    Returns:
        ``(ok, total_reward)``.
    """
    def _as_int_tuple(values):
        return tuple(int(v) for v in np.array(values).ravel())

    def _render_best_effort():
        # Rendering is optional: a missing or failing render() never aborts
        # the rollout.
        if not hasattr(env, "render"):
            return
        try:
            env.render()
        except Exception:
            pass

    first_obs, _ = safe_reset(env)
    state = obs_to_state_vector(first_obs, env)

    picked_once = False
    delivered = False
    terminated = truncated = False
    total_reward = 0.0

    print("\n[Greedy Episode START]")
    _render_best_effort()

    step_idx = 0
    while not (terminated or truncated):
        action = select_greedy_action(online_q, state, device)
        next_obs, reward, terminated, truncated, _ = safe_step(env, action)
        total_reward += float(reward)

        print(f"[Step {step_idx:03d}] action={action}, reward={reward}")
        _render_best_effort()

        # Progress tracking is best-effort too: observations without the
        # expected keys simply leave the flags unchanged.
        try:
            if int(next_obs["pick"][0]) == 1:
                picked_once = True
            if _as_int_tuple(next_obs["customer"]) == (-1, -1):
                delivered = True
        except Exception:
            pass

        state = obs_to_state_vector(next_obs, env)
        step_idx += 1
        if step_idx >= max_steps:
            print("[WARN] max_steps reached, stopping.")
            break

    print(f"[Greedy Episode END] total_reward={total_reward:.1f}, terminated={terminated}, truncated={truncated}")
    ok = picked_once and delivered and (terminated or truncated)
    if not ok:
        print("VERIFICATION FAILED:",
              f"picked_once={picked_once}, delivered={delivered}, terminated={terminated}, truncated={truncated}")
    else:
        print("VERIFIED: Mission Complete.")
    return ok, total_reward

# ==============================
# 4) Main
# ==============================
def main():
    """Command-line entry point: train, save, plot and evaluate the DQN agent."""
    # (flag, type, default) for every command-line option, in help order.
    cli_options = (
        ("--episodes", int, 1500),
        ("--seed", int, 0),
        ("--buffer_size", int, 50000),
        ("--batch_size", int, 64),
        ("--gamma", float, 0.99),
        ("--lr", float, 1e-3),
        ("--start_learning_after", int, 1000),
        ("--learn_every", int, 4),
        ("--target_sync_every", int, 1000),
        ("--eps_start", float, 1.0),
        ("--eps_end", float, 0.05),
        ("--eps_decay", float, 0.995),
        ("--max_grad_norm", float, 5.0),
        ("--eval_episodes", int, 10),
        ("--save_path", str, "dqn_drone_delivery.pt"),
    )
    parser = argparse.ArgumentParser()
    for flag, value_type, default in cli_options:
        parser.add_argument(flag, type=value_type, default=default)
    args = parser.parse_args()

    set_seed(args.seed)

    env = DroneDelivery()

    # Options that are forwarded unchanged to dqn_train().
    train_option_names = (
        "episodes",
        "buffer_size",
        "batch_size",
        "gamma",
        "lr",
        "start_learning_after",
        "learn_every",
        "target_sync_every",
        "eps_start",
        "eps_end",
        "eps_decay",
        "max_grad_norm",
        "seed",
    )
    train_kwargs = {name: getattr(args, name) for name in train_option_names}
    results = dqn_train(env, **train_kwargs)

    online_q = results["online_q"]
    device = results["device"]
    episode_returns = results["episode_returns"]
    epsilon_history = results["epsilon_history"]

    checkpoint = {"state_dict": online_q.state_dict()}
    torch.save(checkpoint, args.save_path)
    print(f"Saved DQN weights to: {args.save_path}")

    if episode_returns:
        print(f"\nTrained episodes: {len(episode_returns)}")
        print(f"Average return: {np.mean(episode_returns):.3f} | Median: {np.median(episode_returns):.3f}")
        plt.figure()
        plt.title("DQN: Total Reward per Episode")
        plt.plot(episode_returns)
        plt.xlabel("Episode")
        plt.ylabel("Total Reward per Episode")
        plt.tight_layout()
        plt.show()

    if epsilon_history:
        plt.figure()
        plt.title("Epsilon Decay per Step")
        plt.plot(epsilon_history)
        plt.xlabel("Step")
        plt.ylabel("Epsilon")
        plt.tight_layout()
        plt.show()

    evaluate_greedy(env, online_q, episodes=args.eval_episodes, device=device, plot=True)
    run_one_greedy_episode_with_render_and_verify(env, online_q, device=device, max_steps=200)


if __name__ == "__main__":
    main()