# Lunar Lander

In this notebook, we will train an agent using a Deep Q-Network (DQN) to land a spacecraft safely in the LunarLander-v3 environment.

Steps:
- Set up the environment
- Define the DQN agent
- Train the agent
- Plot the training results
---

### 1. Imports and Device Setup

Load the required packages and check whether a GPU is available for faster training.


In [None]:
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque, namedtuple

# Prefer CUDA when PyTorch reports a usable GPU; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

### 2. Build the Q-Network
This is the brain of our agent.
It learns to estimate how good each action is in a given state, so the agent can pick the best one.

In [None]:
class QNetwork(nn.Module):
    """Fully connected network mapping a state vector to one value per action.

    The stack is ``Linear -> ReLU -> Linear -> ReLU -> Linear``; both hidden
    layers have ``hidden_dim`` units and the output has ``action_dim`` units.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=64):
        """Build the layer stack.

        Args:
            state_dim: Number of input features.
            action_dim: Number of output units.
            hidden_dim: Width of each of the two hidden layers.
        """
        super().__init__()
        # Create the linear layers in input-to-output order, then wrap them in
        # a single ``Sequential`` stored under the ``layers`` attribute.
        input_layer = nn.Linear(state_dim, hidden_dim)
        hidden_layer = nn.Linear(hidden_dim, hidden_dim)
        output_layer = nn.Linear(hidden_dim, action_dim)
        self.layers = nn.Sequential(
            input_layer,
            nn.ReLU(),
            hidden_layer,
            nn.ReLU(),
            output_layer,
        )

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        outputs = self.layers(x)
        return outputs

### 3. Build the Replay Buffer
This is our agent's "memory": it stores past experiences so the agent can replay them later and learn from them.

In [None]:
class ReplayBuffer:
    """Bounded store of transitions that can be sampled at random."""

    def __init__(self, capacity=100_000):
        """Create an empty buffer.

        Args:
            capacity: Maximum number of stored transitions. The underlying
                ``deque`` discards the oldest entry when a new one is appended
                to a full buffer.
        """
        self.experience = namedtuple(
            "Experience",
            ["state", "action", "reward", "next_state", "done"],
        )
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append a single transition to the buffer."""
        transition = self.experience(state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions and stack them into tensors.

        Returns:
            A tuple ``(states, actions, rewards, next_states, dones)``. Every
            field is stacked row-wise with ``np.vstack`` and moved to the
            module-level ``device``; ``actions`` is cast to long, the other
            fields to float (``dones`` via ``uint8``).
        """
        picked = random.sample(self.buffer, batch_size)

        def stacked(field):
            # One row per sampled transition for the requested field.
            return np.vstack([getattr(item, field) for item in picked])

        def as_float(array):
            return torch.from_numpy(array).float().to(device)

        states = as_float(stacked("state"))
        actions = torch.from_numpy(stacked("action")).long().to(device)
        rewards = as_float(stacked("reward"))
        next_states = as_float(stacked("next_state"))
        dones = as_float(stacked("done").astype(np.uint8))
        return states, actions, rewards, next_states, dones

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)

### 4. Build the DQN Agent
This agent will:
- Pick actions
- Store experiences
- Learn from past mistakes

In [None]:
class DQNAgent:
    """DQN agent with an online network, a target network and a replay buffer.

    ``q_eval`` is the network trained by the optimizer. ``q_target`` provides
    the bootstrap values for the TD targets and is moved toward ``q_eval`` by a
    soft update after every learning step.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=64, lr=1e-4, gamma=0.99,
                 tau=1e-3, batch_size=64, update_every=5, buffer_size=100_000):
        """Create the networks, optimizer and replay buffer.

        Args:
            state_dim: Size of a state vector (network input width).
            action_dim: Number of discrete actions (network output width).
            hidden_dim: Width of the hidden layers of both networks.
            lr: Learning rate for the Adam optimizer on ``q_eval``.
            gamma: Discount factor applied to the bootstrap term.
            tau: Interpolation factor for the soft target update.
            batch_size: Number of transitions sampled per learning step.
            update_every: Learn once every this many calls to ``step``.
            buffer_size: Capacity of the replay buffer.
        """
        # Stored so ``act`` does not need to inspect the network's internals.
        self.action_dim = action_dim

        self.q_eval = QNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.q_target = QNetwork(state_dim, action_dim, hidden_dim).to(device)
        # Start the target network as an exact copy of the online network so
        # the first TD targets are not produced by unrelated random weights.
        self.q_target.load_state_dict(self.q_eval.state_dict())
        self.optimizer = optim.Adam(self.q_eval.parameters(), lr=lr)

        self.gamma = gamma
        self.tau = tau
        self.batch_size = batch_size
        self.update_every = update_every
        self.step_counter = 0

        self.buffer = ReplayBuffer(buffer_size)
        self.loss_fn = nn.MSELoss()

    def act(self, state, epsilon=0.1):
        """Choose an action with an epsilon-greedy policy.

        Args:
            state: Current observation, convertible by ``torch.tensor``.
            epsilon: Probability of picking a uniformly random action.

        Returns:
            The chosen action index as a Python ``int``.
        """
        if random.random() < epsilon:
            return random.randrange(self.action_dim)

        # Add a batch dimension so the network sees a (1, state_dim) input.
        state_t = torch.tensor(state).float().unsqueeze(0).to(device)
        with torch.no_grad():
            q_values = self.q_eval(state_t)
        return int(torch.argmax(q_values).item())

    def step(self, state, action, reward, next_state, done):
        """Store a transition and learn every ``update_every`` calls.

        Learning is skipped until the buffer holds at least ``batch_size``
        transitions. ``done`` is the terminal flag that ``learn`` uses to
        switch off the bootstrap term.
        """
        self.buffer.push(state, action, reward, next_state, done)

        self.step_counter = (self.step_counter + 1) % self.update_every
        if self.step_counter == 0 and len(self.buffer) >= self.batch_size:
            self.learn()

    def learn(self):
        """Run one gradient step on a sampled batch, then soft-update the target."""
        states, actions, rewards, next_states, dones = self.buffer.sample(self.batch_size)

        # Targets are constants for the loss, so build them without autograd.
        with torch.no_grad():
            q_targets_next = self.q_target(next_states).max(dim=1, keepdim=True)[0]
            # (1 - dones) removes the bootstrap term for terminal transitions.
            q_targets = rewards + (self.gamma * q_targets_next * (1 - dones))

        # Q-value of the action that was actually taken in each transition.
        q_expected = self.q_eval(states).gather(1, actions)

        loss = self.loss_fn(q_expected, q_targets)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.soft_update()

    def soft_update(self):
        """Blend the online weights into the target network.

        Each target parameter becomes
        ``tau * eval_param + (1 - tau) * target_param``.
        """
        with torch.no_grad():
            for target_param, eval_param in zip(self.q_target.parameters(), self.q_eval.parameters()):
                target_param.copy_(self.tau * eval_param + (1.0 - self.tau) * target_param)


### 5. Train the Agent
Let's make our lander smarter with each episode.

In [None]:
def train_dqn(env, agent, n_episodes=1500, max_steps=1000, eps_start=1.0, eps_end=0.01, eps_decay=0.995):
    """Train ``agent`` on ``env`` with an epsilon-greedy policy.

    Args:
        env: Environment whose ``reset(seed=...)`` returns ``(state, info)``
            and whose ``step(action)`` returns
            ``(next_state, reward, terminated, truncated, info)``.
        agent: Object providing ``act(state, epsilon)`` and
            ``step(state, action, reward, next_state, done)``.
        n_episodes: Number of episodes to run.
        max_steps: Upper bound on steps per episode.
        eps_start: Exploration rate used for the first episode.
        eps_end: Lower bound for the exploration rate.
        eps_decay: Factor applied to the exploration rate after each episode.

    Returns:
        List with the total reward of every episode, in order.
    """
    scores = []
    epsilon = eps_start

    for ep in range(1, n_episodes + 1):
        # The episode number doubles as the reset seed.
        state, _ = env.reset(seed=ep)
        total_reward = 0

        for _ in range(max_steps):
            action = agent.act(state, epsilon)
            next_state, reward, terminated, truncated, _ = env.step(action)

            # Only true termination is stored as the terminal flag. Truncation
            # merely cuts the episode short, so the agent should still
            # bootstrap from next_state for such transitions.
            agent.step(state, action, reward, next_state, terminated)

            state = next_state
            total_reward += reward

            # Either signal ends the episode loop.
            if terminated or truncated:
                break

        scores.append(total_reward)
        epsilon = max(eps_end, eps_decay * epsilon)

        if ep % 100 == 0:
            avg_score = np.mean(scores[-100:])
            print(f"Episode {ep} | Average (last 100 episodes): {avg_score:.2f}")

    return scores

### 6. Main Training Execution

In [None]:
# Create environment
env = gym.make("LunarLander-v3")

try:
    # Network input/output sizes are read from the environment's spaces.
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    # Create agent
    agent = DQNAgent(state_dim, action_dim)

    # Train agent
    scores = train_dqn(env, agent)

    # Save trained model (weights of the online network only)
    torch.save(agent.q_eval.state_dict(), "dqn_lander.pth")
finally:
    # Close the environment even if training raises or is interrupted.
    env.close()

### 7. Plot the Results
Let's see how the lander improved over time.

In [None]:
# Draw the per-episode total reward on a single set of axes.
episode_index = np.arange(len(scores))

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(episode_index, scores)
ax.set_xlabel('Episode')
ax.set_ylabel('Total Reward')
ax.set_title('Training Performance Over Episodes')
ax.grid(True)

# Write the figure to disk before showing it.
fig.savefig("training_curve.png")
plt.show()

print("Training complete!")