# In [None]:
# train_dqn_full.py
"""
Enhanced DQN training script for Flight Landing.
- Simulates 3 runway types: Dry, Wet, Icy (different drag physics)
- Easier start (alt=400m, dist=800m)
- Softer penalties, stronger landing reward
- Prints only summary every 200 episodes
- Works in Colab / Jupyter
"""

import argparse
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import gymnasium as gym
from gymnasium import spaces

# =====================================================
# 1️⃣ ENVIRONMENT
# =====================================================
class FlightEnv(gym.Env):
    """Simplified flight-landing environment with runway-condition physics.

    The aircraft starts ``start_dist`` from the runway at ``start_alt`` and is
    rewarded for reaching distance <= 0 low (0-50), at a speed of 100-200 and
    with |angle| < 10.

    Observation (float32, scaled), in order:
        0. altitude / 5000      -- altitude is floored at 0 but not capped
        1. speed / 300          -- speed is clipped to [0, 300] in ``step``
        2. distance / 10000     -- can be negative on the step that reaches the runway
        3. (angle + 30) / 60    -- angle is clipped to [-30, 30] in ``step``
        4. runway condition     -- 0.0 = dry, 0.5 = wet, 1.0 = icy

    Actions (Discrete(5)): 0 throttle up, 1 throttle down, 2 pitch up,
    3 pitch down, 4 no control input.

    Args:
        start_alt: initial altitude for every episode.
        start_dist: initial distance to the runway for every episode.
        max_steps: step count at which the episode ends with outcome "timeout".
    """

    # Per-step speed loss keyed by runway condition; any other value falls back
    # to the icy factor (0.25), matching the original if/elif/else chain.
    _DRAG_FACTORS = {0.0: 0.6, 0.5: 0.4, 1.0: 0.25}

    def __init__(self, start_alt=400.0, start_dist=800.0, max_steps=600):
        super().__init__()
        # Bounds describe the *scaled* vector returned by _get_obs, not raw units.
        # Altitude is not capped (repeated pitch-up keeps climbing) and distance
        # overshoots below zero on the final step, so those bounds stay open.
        self.observation_space = spaces.Box(
            low=np.array([0.0, 0.0, -np.inf, 0.0, 0.0], dtype=np.float32),
            high=np.array([np.inf, 1.0, np.inf, 1.0, 1.0], dtype=np.float32),
            dtype=np.float32,
        )
        self.action_space = spaces.Discrete(5)
        self.start_alt = start_alt
        self.start_dist = start_dist
        self.max_steps = int(max_steps)
        self.reset()

    def reset(self, seed=None, options=None):
        """Start a new approach and return ``(observation, info)``.

        Speed (160 +/- 10), angle (+/- 2) and the runway condition are randomized.
        NOTE(review): these draws use the global ``np.random`` state, not
        ``self.np_random``, so passing ``seed`` here does not make episodes
        reproducible; seed ``np.random`` globally for that.
        """
        super().reset(seed=seed)
        self.altitude = float(self.start_alt)
        self.speed = float(160.0 + np.random.uniform(-10, 10))
        self.distance = float(self.start_dist)
        self.prev_distance = self.distance
        self.angle = float(np.random.uniform(-2, 2))
        # Runway conditions: 0.0=dry, 0.5=wet, 1.0=icy
        self.runway_condition = float(np.random.choice([0.0, 0.5, 1.0]))
        self.steps = 0
        return self._get_obs(), {}

    def step(self, action):
        """Advance one control step.

        Returns:
            ``(observation, reward, terminated, truncated, info)``. ``truncated``
            is always False; ``info`` holds ``success``, ``outcome`` and
            ``runway_condition``.
        """
        self.steps += 1

        # --- Action effects ---
        if action == 0:  # throttle up
            self.speed += 6.0
        elif action == 1:  # throttle down
            self.speed -= 6.0
        elif action == 2:  # pitch up
            self.altitude += 35.0
            self.angle += 1.5
        elif action == 3:  # pitch down
            self.altitude -= 35.0
            self.angle -= 1.5
        # action 4: no control input

        # --- Dynamics: forward progress (at least 1 unit per step) and constant sink ---
        self.distance -= max(self.speed * 0.3, 1.0)
        self.altitude -= 8.0

        # --- Runway physics: slipperier runways shed less speed per step ---
        self.speed -= self._DRAG_FACTORS.get(self.runway_condition, 0.25)
        self.angle = float(np.clip(self.angle, -30, 30))
        self.altitude = max(self.altitude, 0.0)
        self.speed = float(np.clip(self.speed, 0.0, 300.0))

        # --- Reward shaping ---
        reward = 0.0
        reward += (self.prev_distance - self.distance) * 0.02  # progress this step
        self.prev_distance = self.distance
        reward -= 0.03  # per-step time cost
        reward -= 0.005 * abs(self.altitude - 100)
        reward -= 0.005 * abs(self.speed - 150)
        reward -= 0.01 * abs(self.angle)

        if self.distance < 400:
            reward += 0.8
        if 0 < self.altitude < 100 and 100 < self.speed < 200:
            reward += 1.5
        reward += (self.start_dist - self.distance) / max(1.0, self.start_dist)

        terminated = False
        success = False
        outcome = "in-flight"

        # --- Landing / crash logic ---
        # Terminal events are mutually exclusive and checked in priority order.
        # This way a single step is never penalized twice (failed landing +
        # stall), and a landing or crash on the last allowed step is never
        # relabelled "timeout".
        if self.distance <= 0:
            if 0 <= self.altitude <= 50 and 100 <= self.speed <= 200 and abs(self.angle) < 10:
                reward += 200.0
                success = True
                outcome = "successful landing"
            else:
                reward -= 40.0
                outcome = "failed landing"
            terminated = True
        elif self.altitude <= 0:
            reward -= 40.0
            terminated = True
            outcome = "crash before runway"
        elif self.speed <= 20 and self.altitude > 100:
            reward -= 40.0
            terminated = True
            outcome = "stall midair"
        elif self.steps >= self.max_steps:
            terminated = True
            outcome = "timeout"

        info = {"success": success, "outcome": outcome, "runway_condition": self.runway_condition}
        # NOTE: the step limit is reported through ``terminated`` (``truncated`` is
        # always False) so callers that only check the first flag still stop.
        return self._get_obs(), float(reward), bool(terminated), False, info

    def _get_obs(self):
        """Return the scaled 5-element float32 observation (see class docstring)."""
        return np.array([
            self.altitude / 5000.0,
            self.speed / 300.0,
            self.distance / 10000.0,
            (self.angle + 30.0) / 60.0,
            self.runway_condition
        ], dtype=np.float32)


# =====================================================
# 2️⃣ REPLAY BUFFER
# =====================================================
class ReplayBuffer:
    """Fixed-capacity circular store of transitions in preallocated NumPy arrays.

    Once ``capacity`` transitions have been pushed, each new push overwrites the
    oldest slot. ``sample`` draws indices uniformly *with* replacement using the
    global ``np.random`` state.
    """

    def __init__(self, capacity: int, obs_shape):
        self.capacity = int(capacity)
        self.obs_shape = tuple(obs_shape)
        self.ptr = 0   # index of the next slot to write
        self.size = 0  # number of slots holding valid data
        full_obs_shape = (self.capacity, *self.obs_shape)
        self.states = np.zeros(full_obs_shape, dtype=np.float32)
        self.next_states = np.zeros(full_obs_shape, dtype=np.float32)
        self.actions = np.zeros(self.capacity, dtype=np.int64)
        self.rewards = np.zeros(self.capacity, dtype=np.float32)
        self.dones = np.zeros(self.capacity, dtype=np.float32)

    def push(self, state, action, reward, next_state, done):
        """Write one transition into the current slot and advance the write pointer."""
        slot = self.ptr
        self.states[slot] = state
        self.next_states[slot] = next_state
        self.actions[slot] = int(action)
        self.rewards[slot] = float(reward)
        self.dones[slot] = float(bool(done))  # stored as 1.0 / 0.0
        self.ptr = (slot + 1) % self.capacity
        if self.size < self.capacity:
            self.size += 1

    def sample(self, batch_size: int):
        """Return a dict of ``batch_size`` randomly chosen transitions, one array per field."""
        picks = np.random.randint(0, self.size, size=batch_size)
        fields = ("states", "actions", "rewards", "next_states", "dones")
        return {name: getattr(self, name)[picks] for name in fields}

    def __len__(self):
        return self.size


# =====================================================
# 3️⃣ AGENT
# =====================================================
class QNetwork(nn.Module):
    """MLP mapping an observation vector to one Q-value per discrete action.

    Two hidden layers of 256 units with ReLU activations; the last layer is
    linear with ``n_actions`` outputs.
    """

    def __init__(self, obs_dim: int, n_actions: int):
        super().__init__()
        hidden = 256
        layers = [
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_actions),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return Q-values; the last dimension of ``x`` must equal ``obs_dim``."""
        return self.model(x)


class DQNAgent:
    """Deep Q-learning agent with an online network and a separately synced target network.

    Args:
        obs_dim: length of the observation vector.
        n_actions: number of discrete actions.
        lr: Adam learning rate for the online network.
        gamma: discount factor used in the bootstrapped target.
        device: torch device string, e.g. 'cpu' or 'cuda'.
    """

    def __init__(self, obs_dim, n_actions, lr=3e-4, gamma=0.99, device='cpu'):
        self.device = torch.device(device)
        online = QNetwork(obs_dim, n_actions).to(self.device)
        frozen = QNetwork(obs_dim, n_actions).to(self.device)
        frozen.load_state_dict(online.state_dict())  # start both networks identical
        self.q_net = online
        self.target_q = frozen
        self.opt = optim.Adam(online.parameters(), lr=lr)
        self.loss_fn = nn.MSELoss()
        self.gamma = gamma

    def act(self, obs, epsilon=0.0):
        """Choose an action epsilon-greedily.

        With probability ``epsilon`` a uniformly random action index is returned;
        otherwise the argmax of the online network's Q-values for ``obs``.
        """
        explore = np.random.rand() < epsilon
        if explore:
            action_count = self.q_net.model[-1].out_features
            return np.random.randint(0, action_count)
        with torch.no_grad():
            single = torch.tensor(obs, dtype=torch.float32, device=self.device)[None]
            return int(self.q_net(single).argmax().item())

    def update(self, batch):
        """Run one gradient step on a batch and return the scalar loss.

        ``batch`` is a mapping with keys 'states', 'next_states', 'actions',
        'rewards' and 'dones'. The regression target is
        ``reward + (1 - done) * gamma * max_a' Q_target(next_state, a')``, and the
        gradient norm is clipped to 10.
        """
        def to_tensor(key, dtype):
            return torch.tensor(batch[key], dtype=dtype, device=self.device)

        states = to_tensor('states', torch.float32)
        next_states = to_tensor('next_states', torch.float32)
        actions = to_tensor('actions', torch.int64)
        rewards = to_tensor('rewards', torch.float32)
        dones = to_tensor('dones', torch.float32)

        # Q-value of the action actually taken in each transition.
        q_taken = self.q_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            best_next = self.target_q(next_states).max(dim=1).values
            target = rewards + (1 - dones) * self.gamma * best_next

        loss = self.loss_fn(q_taken, target)
        self.opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.q_net.parameters(), 10.0)
        self.opt.step()
        return loss.item()

    def sync_target(self):
        """Copy the online network's weights into the target network."""
        weights = self.q_net.state_dict()
        self.target_q.load_state_dict(weights)


# =====================================================
# 4️⃣ TRAINING
# =====================================================
def parse_args(argv=None):
    """Parse the training hyperparameters from the command line.

    Args:
        argv: list of arguments to parse, excluding the program name. ``None``
            (the default) means ``sys.argv[1:]``, exactly like
            ``argparse.ArgumentParser.parse_args``. Pass ``[]`` in a notebook to
            get the defaults without having to modify ``sys.argv``.

    Returns:
        argparse.Namespace holding the options defined below.
    """
    p = argparse.ArgumentParser(description="Train a DQN agent on FlightEnv.")
    p.add_argument('--episodes', type=int, default=50000,
                   help="number of training episodes")
    p.add_argument('--lr', type=float, default=3e-4,
                   help="Adam learning rate")
    p.add_argument('--buffer_size', type=int, default=30000,
                   help="replay buffer capacity (transitions)")
    p.add_argument('--batch_size', type=int, default=128,
                   help="minibatch size per gradient step")
    p.add_argument('--eps_start', type=float, default=1.0,
                   help="initial exploration rate")
    p.add_argument('--eps_end', type=float, default=0.05,
                   help="minimum exploration rate")
    p.add_argument('--eps_decay', type=float, default=0.995,
                   help="multiplicative epsilon decay applied once per episode")
    p.add_argument('--target_update', type=int, default=1000,
                   help="environment steps between target-network syncs")
    p.add_argument('--save_dir', type=str, default='./checkpoints',
                   help="directory for best_model.pt")
    p.add_argument('--cpu', action='store_true',
                   help="force CPU even when CUDA is available")
    p.add_argument('--seed', type=int, default=42,
                   help="seed for torch, numpy and random")
    return p.parse_args(argv)


def _runway_success_rates(runway_counts, runway_success):
    """Return, per runway condition, successful landings / episodes played (0 if none)."""
    return {
        k: (runway_success[k] / runway_counts[k] if runway_counts[k] > 0 else 0)
        for k in runway_counts
    }


def _print_outcome_counts(outcomes):
    """Print one left-aligned ``outcome: count`` line per entry of ``outcomes``."""
    for name, count in outcomes.items():
        print(f"{name:<25}: {count}")


def train(args):
    """Train a DQN agent on FlightEnv and report landing outcomes.

    Args:
        args: namespace with the fields produced by ``parse_args``
            (episodes, lr, buffer_size, batch_size, eps_start, eps_end,
            eps_decay, target_update, save_dir, cpu, seed).

    Side effects:
        Saves the online Q-network's state dict to ``<save_dir>/best_model.pt``
        whenever an episode beats the best return so far. Prints a cumulative
        outcome summary every 200 episodes and once at the end.
    """
    device = 'cuda' if torch.cuda.is_available() and not args.cpu else 'cpu'
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    env = FlightEnv()
    obs_dim = env.observation_space.shape[0]
    n_actions = env.action_space.n
    agent = DQNAgent(obs_dim, n_actions, lr=args.lr, gamma=0.99, device=device)
    buffer = ReplayBuffer(args.buffer_size, (obs_dim,))
    epsilon = args.eps_start
    os.makedirs(args.save_dir, exist_ok=True)
    best_path = os.path.join(args.save_dir, 'best_model.pt')

    outcomes = {"successful landing": 0, "failed landing": 0,
                "crash before runway": 0, "stall midair": 0, "timeout": 0}
    runway_counts = {0.0: 0, 0.5: 0, 1.0: 0}
    runway_success = {0.0: 0, 0.5: 0, 1.0: 0}

    best_return = float('-inf')
    total_steps = 0

    for ep in range(1, args.episodes + 1):
        obs, _ = env.reset()
        ep_return = 0.0
        done = False
        outcome = None
        runway = env.runway_condition

        while not done:
            action = agent.act(obs, epsilon)
            nobs, rew, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            # Only true termination cuts bootstrapping; a truncated (time-limited)
            # transition still has future value, so it is stored as not-done.
            buffer.push(obs, action, rew, nobs, terminated)
            obs = nobs
            ep_return += rew
            total_steps += 1
            outcome = info["outcome"]

            # Start learning once the buffer holds more than one batch.
            if len(buffer) > args.batch_size:
                agent.update(buffer.sample(args.batch_size))

            if total_steps % args.target_update == 0:
                agent.sync_target()

        epsilon = max(args.eps_end, epsilon * args.eps_decay)  # decays once per episode
        outcomes[outcome] = outcomes.get(outcome, 0) + 1
        runway_counts[runway] += 1
        if info["success"]:
            runway_success[runway] += 1

        # Print summary every 200 episodes
        if ep % 200 == 0:
            print("\n--- Outcome Summary up to Episode", ep, "---")
            _print_outcome_counts(outcomes)
            print("Runway success probabilities:",
                  _runway_success_rates(runway_counts, runway_success))
            print("------------------------------------------\n")

        if ep_return > best_return:
            best_return = ep_return
            torch.save(agent.q_net.state_dict(), best_path)

    print("✅ Training done.")
    print("Best episode return:", best_return)
    print("Final outcomes summary:")
    _print_outcome_counts(outcomes)
    print("Final runway success probabilities:",
          _runway_success_rates(runway_counts, runway_success))


# =====================================================
# 5️⃣ RUN
# =====================================================
if __name__ == "__main__":
    import sys

    # Inside a Jupyter/Colab kernel, sys.argv carries the kernel's own flags
    # (typically "-f <connection file>"), which argparse would reject, so drop
    # them there. When run as a normal script, keep the user's CLI arguments
    # (previously they were always discarded).
    if "ipykernel" in sys.modules:
        sys.argv = [sys.argv[0]]
    args = parse_args()
    print("Starting training with args:", args)
    train(args)
