# Dependencies

In [1]:
# Install PyTorch into the running kernel's environment (%pip, unlike !pip, targets the kernel).
# NOTE(review): unpinned — consider pinning a version (e.g. torch==X.Y.Z) for reproducible runs.
%pip install torch



In [2]:
# System package: SWIG (presumably needed to build the box2d-py bindings from source below).
!apt-get install -y swig

# Build box2d-py from its GitHub source in editable mode.
# NOTE: on a re-run `git clone` fails with "destination path ... already exists";
# that is harmless — the existing checkout is reinstalled by the next lines.
!git clone https://github.com/openai/box2d-py
%cd box2d-py
!pip install -e .

%cd ..
# Install gymnasium's box2d extra without dependency resolution, presumably so pip
# does not replace the box2d-py build above. NOTE(review): --no-deps also skips any
# other gymnasium dependencies; this assumes they are already present in the runtime.
!pip install gymnasium[box2d] --no-deps

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
swig is already the newest version (4.0.2-1ubuntu1).
0 upgraded, 0 newly installed, 0 to remove and 34 not upgraded.
fatal: destination path 'box2d-py' already exists and is not an empty directory.
/content/box2d-py
Obtaining file:///content/box2d-py
  Preparing metadata (setup.py) ... [?25l[?25hdone
Installing collected packages: box2d-py
  Attempting uninstall: box2d-py
    Found existing installation: box2d-py 2.3.8
    Uninstalling box2d-py-2.3.8:
      Successfully uninstalled box2d-py-2.3.8
  Running setup.py develop for box2d-py
Successfully installed box2d-py-2.3.8
/content


# Loading and evaluating the trained DAgger models

In [3]:
import os
import torch
import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
import pickle
import seaborn as sns

# Device used by the models and tensors below: CUDA when torch reports a GPU as
# available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Our DQN architecture: an MLP from an observation vector to one score per action.
class DQN(nn.Module):
    """Fully connected Q-network with two 128-unit ReLU hidden layers and a linear head.

    The attribute names ``layer1``/``layer2``/``layer3`` are part of the contract:
    ``load_state_dict`` matches saved checkpoints by these names.

    Args:
        n_observations: size of the input observation vector.
        n_actions: number of outputs (one per discrete action).
    """

    def __init__(self, n_observations, n_actions):
        super().__init__()
        width = 128
        self.layer1 = nn.Linear(n_observations, width)
        self.layer2 = nn.Linear(width, width)
        self.layer3 = nn.Linear(width, n_actions)

    def forward(self, x):
        """Return unnormalized per-action outputs for the batch ``x``."""
        activations = x
        for hidden_layer in (self.layer1, self.layer2):
            activations = F.relu(hidden_layer(activations))
        return self.layer3(activations)

# Load the environment
env = gym.make("LunarLander-v3")

n_observations = env.observation_space.shape[0]
n_actions = env.action_space.n

# Load the five DAgger checkpoints (one per DAgger iteration) so we can compare how the
# policy performed across the iterations produced by the earlier training code.
DAGGER_CHECKPOINT_TEMPLATE = "/content/dagger_model_iter{}.pth"  # replace with your actual path


def load_dagger_model(path):
    """Build a ``DQN`` on ``device``, load the state dict saved at ``path``, and set eval mode.

    ``map_location=device`` lets a checkpoint saved on a GPU load on a CPU-only runtime
    (plain ``torch.load`` would try to restore tensors onto the original device).
    """
    dagger_model = DQN(n_observations, n_actions).to(device)
    dagger_model.load_state_dict(torch.load(path, map_location=device))
    dagger_model.eval()
    return dagger_model


model1, model2, model3, model4, model5 = (
    load_dagger_model(DAGGER_CHECKPOINT_TEMPLATE.format(iteration)) for iteration in range(1, 6)
)

# Evaluation loop
def evaluate_policy(env, model, episodes=100, seed=None):
    """Roll out ``model`` greedily in ``env`` and report the mean undiscounted episode return.

    At every step the action is the argmax of ``model``'s output for the current
    observation. An episode ends when the env reports ``terminated`` or ``truncated``.

    Args:
        env: Gymnasium-style env (``reset()`` -> ``(obs, info)``,
            ``step(a)`` -> ``(obs, reward, terminated, truncated, info)``).
        model: torch module mapping a ``(1, n_obs)`` float tensor to per-action scores.
        episodes: number of episodes to average over.
        seed: optional seed passed to the first ``env.reset`` so the evaluated episode
            sequence is reproducible; ``None`` keeps the env's current RNG state.

    Returns:
        Mean episode return as a float (``nan`` when ``episodes`` < 1).
    """
    # Send observations to wherever the model's weights live rather than relying on a
    # global; parameter-less callables fall back to CPU.
    try:
        model_device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        model_device = torch.device("cpu")

    total_rewards = []
    for ep in range(episodes):
        # Only the first reset is seeded; Gymnasium keeps using that RNG on later resets.
        obs, _ = env.reset(seed=seed if ep == 0 else None)
        done = False
        total_reward = 0.0
        while not done:
            obs_tensor = torch.tensor(obs, dtype=torch.float32, device=model_device).unsqueeze(0)
            with torch.no_grad():
                q_values = model(obs_tensor)
                action = torch.argmax(q_values, dim=1).item()
            obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            total_reward += reward
        total_rewards.append(total_reward)

    # np.mean([]) would emit a RuntimeWarning; make the empty case explicit.
    avg_reward = float(np.mean(total_rewards)) if total_rewards else float("nan")
    print(f"Average Reward over {episodes} episodes: {avg_reward:.2f}")
    return avg_reward

# Evaluate each DAgger iteration's model on the same env and keep the averages for comparison.
dagger_avg_rewards = {}
for iteration, dagger_model in enumerate([model1, model2, model3, model4, model5], start=1):
    # Prefix the line so each printed average can be attributed to its DAgger iteration.
    print(f"DAgger iteration {iteration}:", end=" ")
    dagger_avg_rewards[iteration] = evaluate_policy(env, dagger_model)
env.close()


Average Reward over 100 episodes: 242.07
Average Reward over 100 episodes: 246.25
Average Reward over 100 episodes: 229.65
Average Reward over 100 episodes: 236.86
Average Reward over 100 episodes: 237.37


# Self-play with MCTS

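We implement self-play a little differently. Self-play is usually applied in two-player competitive environments, where an agent improves by playing against copies of itself. Here we use a single-agent variant: the agent plays episodes in which Monte Carlo tree search (MCTS), guided by the current policy network, chooses every action, and the network is then trained to imitate the actions that MCTS selected.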

In [4]:
# Full modified MCTS self-play code with visualization and video recording


from collections import deque
import copy

# Re-selects the compute device with the same expression as the model-loading cell:
# CUDA when torch reports a GPU as available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class DQN(nn.Module):
    """Q-network: two 128-unit ReLU hidden layers followed by a linear output head.

    This re-declares the ``DQN`` class from the model-loading cell with the same
    architecture and shadows it from here on. Keep the two in sync: the layer attribute
    names (``layer1``..``layer3``) are the keys ``load_state_dict`` matches when the
    trainer below loads the DAgger checkpoint.
    """

    def __init__(self, n_observations, n_actions):
        super().__init__()
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    def forward(self, x):
        # Two ReLU hidden layers, then raw (unnormalized) per-action outputs.
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)

class Node:
    """A node of the MCTS search tree.

    Attributes:
        state: observation stored for this node (the root's is passed in by the search;
            children's are the observations returned by ``env_copy.step`` in ``expand``).
        parent: node this one was expanded from (``None`` for the root).
        action: action taken from ``parent`` to reach this node.
        children: expanded child nodes.
        visits / total_value / mean_value: backup statistics maintained by ``update``.
        prior: policy probability of ``action`` at the parent, set by ``expand``.
        virtual_loss: accumulator bumped by ``update_virtual_loss``; ``best_child`` does not read it.
        terminal: whether the simulated step into this node ended the episode.
    """

    def __init__(self, state, parent=None, action=None):
        self.state = state
        self.parent = parent
        self.action = action
        self.children = []
        self.visits = 0
        self.total_value = 0.0
        self.mean_value = 0.0
        self.prior = 0.0
        self.virtual_loss = 0
        self.terminal = False

    def is_fully_expanded(self):
        """True when the node has children and every child has been visited at least once."""
        return len(self.children) > 0 and all(child.visits > 0 for child in self.children)

    def best_child(self, cpuct=1.0):
        """Return the child with the best value + exploration score, or ``None`` if childless.

        Child mean values are min-max rescaled to [-1, 1] (all ones when they are all
        equal), then ``cpuct * prior * sqrt(parent_visits) / (1 + child_visits)`` is added.
        """
        if len(self.children) == 0:
            return None
        values = np.array([child.mean_value for child in self.children])
        if len(np.unique(values)) > 1:
            norm_values = 2 * (values - np.min(values)) / (np.max(values) - np.min(values)) - 1
        else:
            norm_values = np.ones_like(values)
        scores = norm_values + cpuct * np.array([child.prior * np.sqrt(self.visits) / (1 + child.visits) for child in self.children])
        return self.children[np.argmax(scores)]

    def expand(self, action_probs, env):
        """Add one child per action with a non-zero prior by simulating one step on a copy of ``env``.

        Idempotent: a node that already has children is left unchanged. Previously every
        call appended a fresh set of children, so a search loop revisiting a partially
        explored node kept growing duplicate, unvisited siblings.

        NOTE(review): assigning ``env_copy.unwrapped.state`` does not obviously restore the
        simulator to ``self.state`` (LunarLander's physics live in its Box2D world), so each
        child is likely simulated from the freshly reset state instead — confirm.

        Args:
            action_probs: sequence of per-action prior probabilities (index == action).
            env: environment to deep-copy for the one-step simulation.
        """
        if self.children:
            return
        for action, prob in enumerate(action_probs):
            if prob <= 0:
                continue
            env_copy = copy.deepcopy(env)  # never step the caller's env
            env_copy.reset()  # fresh episode so the copy is in a steppable state
            env_copy.unwrapped.state = self.state.copy() if hasattr(self.state, 'copy') else self.state
            next_state, _reward, terminated, truncated, _ = env_copy.step(action)
            child = Node(next_state, self, action)
            child.prior = prob
            child.terminal = terminated or truncated
            self.children.append(child)

    def update(self, value):
        """Record one backed-up ``value``: bump the visit count and refresh the running mean."""
        self.visits += 1
        self.total_value += value
        self.mean_value = self.total_value / self.visits

    def update_virtual_loss(self, virtual_loss_weight=0.01):
        """Accumulate ``virtual_loss_weight * |mean_value|`` into ``virtual_loss``."""
        self.virtual_loss += virtual_loss_weight * abs(self.mean_value)

class MCTS:
    """Policy-guided Monte Carlo tree search over deep copies of a Gymnasium env.

    Selection uses ``Node.best_child``. Leaves are expanded with the policy network's
    softmax as priors. Leaf values come from a short rollout that samples actions from
    the same policy.

    NOTE(review): simulations deep-copy ``env``, call ``reset()`` and assign
    ``env_copy.unwrapped.state``. For LunarLander that attribute does not appear to drive
    the Box2D physics, so simulations likely start from a fresh reset rather than from
    the node's state — confirm before trusting search quality.

    Args:
        policy_net: torch module mapping a ``(1, n_obs)`` tensor to per-action logits.
        env: environment to copy for simulations (never stepped directly).
        cpuct: exploration constant passed to ``Node.best_child``.
        num_simulations: simulations per ``search`` call.
        rollout_depth: maximum steps per rollout (default 30, the previous hard-coded value).
    """

    def __init__(self, policy_net, env, cpuct=1.0, num_simulations=10, rollout_depth=30):
        self.policy_net = policy_net
        self.env = env
        self.cpuct = cpuct
        self.num_simulations = num_simulations
        self.rollout_depth = rollout_depth
        self.virtual_loss_weight = 0.01

    def _policy_probs(self, state):
        """Softmax of the policy network's output for a single observation, as a flat numpy array."""
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
            return torch.softmax(self.policy_net(state_tensor), dim=1).cpu().numpy().flatten()

    def search(self, root_state):
        """Run ``num_simulations`` select/expand/rollout/backup passes from ``root_state``; return the root."""
        root = Node(root_state)
        root.expand(self._policy_probs(root_state), self.env)

        for _ in range(self.num_simulations):
            node = root
            search_path = [node]
            # Selection: descend while every child has already been tried.
            while node.is_fully_expanded() and not node.terminal:
                node = node.best_child(self.cpuct)
                node.update_virtual_loss(self.virtual_loss_weight)
                search_path.append(node)
            if not node.terminal:
                # Expansion happens once per node. Re-expanding a partially explored node
                # (the previous behaviour) appended duplicate children on every visit, so
                # the root never became fully expanded and actions were picked ~at random.
                if not node.children:
                    node.expand(self._policy_probs(node.state), self.env)
                unvisited = [child for child in node.children if child.visits == 0]
                if unvisited:
                    node = random.choice(unvisited)
                    search_path.append(node)
            value = 0 if node.terminal else self.rollout(node.state)
            # Backup: every node on the path receives the same leaf value.
            for path_node in reversed(search_path):
                path_node.update(value)
                path_node.virtual_loss = 0
        return root

    def rollout(self, state):
        """Sum of rewards over up to ``rollout_depth`` policy-sampled steps on a copy of ``env``."""
        env_copy = copy.deepcopy(self.env)
        env_copy.reset()
        env_copy.unwrapped.state = state.copy() if hasattr(state, 'copy') else state
        # Feed the policy the observation the env actually returns each step; reading
        # ``unwrapped.state`` kept returning the value assigned above, since ``step``
        # does not update it.
        observation = state
        total_reward = 0.0
        for _ in range(self.rollout_depth):
            action_probs = self._policy_probs(observation)
            action = np.random.choice(len(action_probs), p=action_probs)
            observation, reward, terminated, truncated, _ = env_copy.step(action)
            total_reward += reward
            if terminated or truncated:
                break
        return total_reward

    def get_action(self, state, temperature=1.0):
        """Search from ``state`` and pick an action from the root's child visit counts.

        ``temperature == 0`` returns the most-visited action. Otherwise the action is
        sampled with probability proportional to ``visits ** (1 / temperature)``. When no
        child was visited (e.g. ``num_simulations == 0``), the expansion priors are used
        instead of dividing by zero.
        """
        root = self.search(state)
        visit_counts = np.array([child.visits for child in root.children], dtype=np.float64)
        actions = [child.action for child in root.children]
        if temperature == 0:
            return actions[int(np.argmax(visit_counts))]
        if visit_counts.sum() == 0:
            weights = np.array([child.prior for child in root.children], dtype=np.float64)
        else:
            weights = visit_counts ** (1.0 / temperature)
        return np.random.choice(actions, p=weights / weights.sum())

class PolicyIterationTrainer:
    """Alternate MCTS self-play data generation with supervised policy updates.

    Each generation runs four steps:
      1. Play ``games_per_gen`` episodes in which ``MCTS`` (guided by ``policy_net``) picks
         every action, and store the (state, chosen action) pairs in a replay buffer.
      2. Fit ``policy_net`` to those actions with cross-entropy.
      3. Evaluate it greedily.
      4. Checkpoint the best evaluation.

    Args:
        env_name: Gymnasium environment id.
        num_generations: number of self-play / training generations.
        games_per_gen: self-play episodes per generation.
        num_simulations: MCTS simulations per move.
        buffer_size: replay buffer capacity (oldest samples are evicted first).
        batch_size: minibatch size for each gradient step.
        lr: Adam learning rate.
        updates_per_gen: gradient steps per generation (default 1, the original behaviour).
    """

    def __init__(self, env_name="LunarLander-v3", num_generations=5, games_per_gen=10, num_simulations=10,
                 buffer_size=50000, batch_size=64, lr=1e-4, updates_per_gen=1):
        self.env_name = env_name
        self.env = gym.make(env_name)
        self.num_generations = num_generations
        self.games_per_gen = games_per_gen
        self.num_simulations = num_simulations
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.lr = lr
        self.updates_per_gen = updates_per_gen

        n_obs = self.env.observation_space.shape[0]
        n_actions = self.env.action_space.n
        self.policy_net = DQN(n_obs, n_actions).to(device)
        # NOTE(review): target_net is only ever synced from policy_net in ``train``; no loss reads it.
        self.target_net = DQN(n_obs, n_actions).to(device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.replay_buffer = deque(maxlen=buffer_size)
        # MCTS shares self.env but only ever steps deep copies of it.
        self.mcts = MCTS(self.policy_net, self.env, num_simulations=num_simulations)

        self.loss_history = []
        self.eval_history = []
        self.action_log = []

    def generate_self_play_data(self):
        """Play ``games_per_gen`` MCTS-driven episodes and append (state, action) pairs to the buffer."""
        for _ in range(self.games_per_gen):
            state, _ = self.env.reset()
            done = False
            episode_data = []
            while not done:
                action = self.mcts.get_action(state)
                self.action_log.append(action)
                episode_data.append((state.copy(), action))
                state, _, terminated, truncated, _ = self.env.step(action)
                done = terminated or truncated
            self.replay_buffer.extend(episode_data)

    def train_policy(self):
        """Take one cross-entropy step toward the MCTS-chosen actions.

        Returns:
            The batch loss, or ``None`` when the buffer holds fewer than ``batch_size`` samples.
        """
        if len(self.replay_buffer) < self.batch_size:
            return None
        batch = random.sample(self.replay_buffer, self.batch_size)
        states, actions = zip(*batch)
        states = torch.FloatTensor(np.array(states)).to(device)
        actions = torch.LongTensor(actions).to(device)
        logits = self.policy_net(states)
        loss = F.cross_entropy(logits, actions)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def evaluate(self, episodes=5):
        """Mean undiscounted return of the greedy (argmax) policy over ``episodes`` episodes."""
        total_reward = 0
        self.policy_net.eval()
        for _ in range(episodes):
            state, _ = self.env.reset()
            done = False
            while not done:
                with torch.no_grad():
                    state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
                    action = torch.argmax(self.policy_net(state_tensor)).item()
                state, reward, terminated, truncated, _ = self.env.step(action)
                done = terminated or truncated
                total_reward += reward
        self.policy_net.train()
        return total_reward / episodes

    def plot_visuals(self):
        """Save training curves to ``training_curves.png`` and an action histogram to ``action_dist.png``."""
        fig, (loss_ax, eval_ax) = plt.subplots(1, 2, figsize=(10, 4))
        loss_ax.plot(self.loss_history)
        loss_ax.set(title="Training Loss", xlabel="Training update", ylabel="Cross-entropy loss")
        eval_ax.plot(self.eval_history)
        eval_ax.set(title="Evaluation Reward", xlabel="Generation", ylabel="Mean episode reward")
        fig.tight_layout()
        fig.savefig("training_curves.png")
        plt.close(fig)

        fig, ax = plt.subplots()
        sns.countplot(x=self.action_log, ax=ax)
        ax.set(title="Action Distribution", xlabel="Action", ylabel="Count")
        fig.savefig("action_dist.png")
        plt.close(fig)

    def record_policy_video(self, name):
        """Record one greedy episode of ``policy_net`` into ``videos/`` with file prefix ``name``."""
        env = gym.make(self.env_name, render_mode="rgb_array")
        env = gym.wrappers.RecordVideo(env, video_folder="videos", name_prefix=name)
        state, _ = env.reset()
        done = False
        while not done:
            with torch.no_grad():
                state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
                action = torch.argmax(self.policy_net(state_tensor)).item()
            state, _, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
        env.close()

    def train(self):
        """Run all generations, saving the best-evaluating weights to ``best_policy_net.pth``."""
        os.makedirs("videos", exist_ok=True)
        self.record_policy_video("initial_policy")

        best_reward = -float("inf")
        for g in range(self.num_generations):
            self.generate_self_play_data()
            step_losses = [
                step_loss
                for step_loss in (self.train_policy() for _ in range(self.updates_per_gen))
                if step_loss is not None
            ]
            # None when the buffer was still smaller than one batch and no step was taken.
            loss = float(np.mean(step_losses)) if step_losses else None
            reward = self.evaluate()
            if loss is not None:
                self.loss_history.append(loss)
            self.eval_history.append(reward)
            # Formatting None with ":.4f" raised TypeError and aborted the whole run.
            loss_text = f"{loss:.4f}" if loss is not None else "n/a"
            print(f"Gen {g+1}: Loss={loss_text} | Eval={reward:.2f}")
            if reward > best_reward:
                best_reward = reward
                torch.save(self.policy_net.state_dict(), "best_policy_net.pth")
                print(">> Best model saved.")
            if g % 5 == 0:
                self.target_net.load_state_dict(self.policy_net.state_dict())

        self.plot_visuals()
        self.record_policy_video("final_policy")

if __name__ == "__main__":
    # DAgger checkpoint used to warm-start self-play (replace with your actual path).
    PRETRAINED_PATH = "/content/dagger_model_iter2.pth"
    trainer = PolicyIterationTrainer()
    try:
        # map_location lets a GPU-saved checkpoint load on a CPU-only runtime.
        pretrained_state = torch.load(PRETRAINED_PATH, map_location=device)
    except FileNotFoundError:
        # Only a missing file means "train from scratch". Other errors (a corrupt file, or
        # an architecture mismatch in load_state_dict) are real problems and are no longer
        # hidden behind this message by a bare except.
        print("No pre-trained model found.")
    else:
        trainer.policy_net.load_state_dict(pretrained_state)
        trainer.target_net.load_state_dict(pretrained_state)
        print("Loaded pre-trained model.")
    trainer.train()


Loaded pre-trained model.


  logger.warn(


Gen 1: Loss=40.3473 | Eval=214.49
>> Best model saved.
Gen 2: Loss=37.5454 | Eval=232.15
>> Best model saved.
Gen 3: Loss=28.6853 | Eval=240.97
>> Best model saved.
Gen 4: Loss=33.0428 | Eval=252.18
>> Best model saved.
Gen 5: Loss=36.7560 | Eval=216.21


  logger.warn(
