# DAgger MLP Policy: Training, Evaluation, and Plotting (2000 Episodes)
This notebook combines training, evaluation, and plotting for DAgger MLP models on the NeedlePick task.

In [None]:
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
import matplotlib.pyplot as plt
import csv
from surrol.tasks.needle_pick import NeedlePick

In [None]:
# --- Set random seeds for reproducibility ---
# Seed Python's `random`, NumPy's global RNG and PyTorch (CPU + current CUDA
# device) with one shared value, then request deterministic cuDNN kernels and
# turn off cuDNN autotuning so repeated runs can match.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
def concat_obs(obs):
    """Flatten a goal-conditioned observation dict into a single vector.

    The arrays ``obs['observation']``, ``obs['achieved_goal']`` and
    ``obs['desired_goal']`` are joined end to end, in that order.
    """
    keys = ('observation', 'achieved_goal', 'desired_goal')
    return np.concatenate([obs[key] for key in keys])

In [None]:
# --- Policy network (MLP) ---
class MLPPolicy(nn.Module):
    """Plain feed-forward policy: ``obs_dim -> hidden_sizes... -> act_dim``.

    Every hidden ``nn.Linear`` is followed by a ReLU; the output layer is
    linear with no activation. All layers live in ``self.net`` (an
    ``nn.Sequential``), which determines the ``net.<idx>.*`` state-dict keys
    written when checkpoints are saved.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes=(128,128)):
        super().__init__()
        widths = [obs_dim, *hidden_sizes]
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        stack.append(nn.Linear(widths[-1], act_dim))
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Map a batch of observations ``x`` to (unbounded) action predictions."""
        return self.net(x)

In [None]:
# --- Compute device ---
# Prefer CUDA when present (the seeding cell already configures CUDA seeding
# and cuDNN determinism, which would otherwise be dead code), then Apple's
# MPS backend, and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print("Selected device:", device)

In [None]:
# --- Training function ---
def train_mlp(obs_data, act_data, epochs=50, batch_size=128, lr=1e-3, hidden_sizes=(128,128), val_split=0.1):
    """Fit a freshly initialised ``MLPPolicy`` to (observation, action) pairs by MSE regression.

    The first ``int(len(obs_data) * val_split)`` samples are held out for
    validation. No shuffling is done before the split, so callers that want a
    random validation set must shuffle beforehand. The remaining samples are
    used for minibatch Adam training and reshuffled every epoch with NumPy's
    global RNG.

    Args:
        obs_data: Array-like of shape (N, obs_dim) with flattened observations.
        act_data: Array-like of shape (N, act_dim) with target actions.
        epochs: Number of passes over the training split.
        batch_size: Minibatch size.
        lr: Adam learning rate.
        hidden_sizes: Widths of the MLP's hidden layers.
        val_split: Fraction of samples held out for validation.

    Returns:
        ``(policy, losses, val_mse)``. These are the trained policy (on
        ``device``, in eval mode), the per-epoch mean training loss, and the
        per-epoch validation MSE (empty when no samples are held out).

    Raises:
        ValueError: If ``val_split`` leaves no samples for training.
    """
    # Convert up front so plain lists are accepted too; the ``.shape`` lookups
    # below would fail on a list.
    obs_data = np.array(obs_data)
    act_data = np.array(act_data)
    obs_dim = obs_data.shape[1]
    act_dim = act_data.shape[1]
    policy = MLPPolicy(obs_dim, act_dim, hidden_sizes=hidden_sizes).to(device)
    optimizer = optim.Adam(policy.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    def to_device(array):
        return torch.tensor(array, dtype=torch.float32).to(device)

    # Leading slice -> validation, remainder -> training.
    n_val = int(len(obs_data) * val_split)
    if n_val >= len(obs_data):
        raise ValueError(
            f"val_split={val_split} leaves no training samples out of {len(obs_data)}"
        )
    if n_val > 0:
        val_obs = to_device(obs_data[:n_val])
        val_act = to_device(act_data[:n_val])
    else:
        val_obs = val_act = None
    train_obs = to_device(obs_data[n_val:])
    train_act = to_device(act_data[n_val:])

    num_samples = train_obs.shape[0]
    losses = []
    val_mse = []
    for epoch in range(epochs):
        policy.train()
        epoch_loss = 0.0
        # NumPy's global RNG is used so the notebook-level seed fixes batch order.
        indices = np.arange(num_samples)
        np.random.shuffle(indices)
        obs_shuffled = train_obs[indices]
        act_shuffled = train_act[indices]
        for start in range(0, num_samples, batch_size):
            obs_batch = obs_shuffled[start:start + batch_size]
            act_batch = act_shuffled[start:start + batch_size]
            optimizer.zero_grad()
            loss = loss_fn(policy(obs_batch), act_batch)
            loss.backward()
            optimizer.step()
            # Weight by batch length so a short final batch does not skew the epoch mean.
            epoch_loss += loss.item() * len(obs_batch)
        epoch_loss /= num_samples
        losses.append(epoch_loss)
        if val_obs is not None:
            policy.eval()
            with torch.no_grad():
                val_mse.append(loss_fn(policy(val_obs), val_act).item())
        if (epoch + 1) % 10 == 0 or epoch == 0 or (epoch + 1) == epochs:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.6f}" +
                  (f", Val: {val_mse[-1]:.6f}" if val_mse else ""))
    # Hand the policy back in inference mode whether or not validation ran.
    policy.eval()
    return policy, losses, val_mse

In [None]:
# --- Collect DAgger data ---
def collect_dagger_data(policy, env, num_episodes=5, max_steps=200):
    """Roll out ``policy`` in ``env`` and label every visited state with the oracle.

    This is the DAgger data-collection step. The learner's action drives the
    environment, while ``env.get_oracle_action(obs)`` supplies the training
    target for the same state. An episode ends after ``max_steps`` steps, when
    the env reports ``done``, or when ``info['is_success']`` is truthy.

    Args:
        policy: Network mapping a (1, obs_dim) float tensor to actions.
        env: Environment exposing ``reset``, ``step`` and ``get_oracle_action``.
        num_episodes: Number of rollouts to collect.
        max_steps: Step limit per rollout.

    Returns:
        ``(new_obs, new_acts)``: lists of flattened observations (via
        ``concat_obs``) and the matching oracle actions.
    """
    policy.eval()
    action_space = getattr(env, 'action_space', None)
    new_obs = []
    new_acts = []
    for _ in range(num_episodes):
        obs = env.reset()
        for _ in range(max_steps):
            obs_in = concat_obs(obs)
            inp = torch.tensor(obs_in, dtype=torch.float32).unsqueeze(0).to(device)
            with torch.no_grad():
                action_policy = policy(inp).cpu().numpy()[0]
            # Clip the same way evaluate_policy does, so the state distribution
            # gathered here matches what the policy experiences at evaluation.
            if action_space is not None:
                action_policy = np.clip(action_policy, action_space.low, action_space.high)
            expert_act = env.get_oracle_action(obs)
            new_obs.append(obs_in)
            new_acts.append(expert_act)
            obs, _, done, info = env.step(action_policy)
            if done or info.get("is_success", False):
                break
    return new_obs, new_acts

In [None]:
def evaluate_policy(policy, episodes=10, max_steps=200):
    """Run ``policy`` in a fresh ``NeedlePick`` env and report summary metrics.

    Each episode lasts at most ``max_steps`` steps and stops early on success
    (``info['is_success']``) or when the env reports ``done``. Actions are
    clipped to ``env.action_space`` when the env exposes one. The env is
    always closed, even if a rollout raises.

    Args:
        policy: Network mapping a (1, obs_dim) float tensor to actions.
        episodes: Number of evaluation episodes; must be positive.
        max_steps: Step limit per episode.

    Returns:
        ``(avg_return, success_rate)``: the mean of the per-episode summed
        rewards, and the fraction of episodes that reached success.

    Raises:
        ValueError: If ``episodes`` is not positive.
    """
    if episodes <= 0:
        raise ValueError(f"episodes must be positive, got {episodes}")
    policy.eval()
    env = NeedlePick(render_mode=None)
    success_count = 0
    returns = []
    try:
        for _ in range(episodes):
            obs = env.reset()
            total_reward = 0
            for _ in range(max_steps):
                inp = torch.tensor(concat_obs(obs), dtype=torch.float32).unsqueeze(0).to(device)
                with torch.no_grad():
                    action = policy(inp).cpu().numpy()[0]
                if hasattr(env, 'action_space'):
                    action = np.clip(action, env.action_space.low, env.action_space.high)
                obs, reward, done, info = env.step(action)
                total_reward += reward
                if info.get('is_success', False):
                    success_count += 1
                    break
                if done:
                    break
            returns.append(total_reward)
    finally:
        # Release the simulator on every path. A failing close() is reported
        # rather than raised so it cannot mask the result or an earlier error.
        try:
            env.close()
        except Exception as e:
            print(f"Warning: Exception during env.close(): {e}")
    avg_return = np.mean(returns)
    success_rate = success_count / episodes
    return avg_return, success_rate

# 1. Training and Evaluation
For each hyperparameter configuration, train an MLP policy with DAgger and save the model and its training/validation curves. Then evaluate the trained policy and save its evaluation metrics.

In [None]:
# --- Main workflow ---
# Grid search over MLP hyperparameters. Each configuration runs DAgger: train,
# roll out the learner, label visited states with the oracle, aggregate, and
# retrain. The cell then saves the final policy and its loss curves, evaluates
# it, and appends a summary row to `results`.
import itertools

out_dir = "mlp_dagger_models"
os.makedirs(out_dir, exist_ok=True)

# Load expert data.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("expert_trajectories.pkl", "rb") as f:
    trajectories = pickle.load(f)

# Prepare the initial (expert-only) dataset: one flattened observation per step.
observations = []
actions = []
for episode in trajectories:
    for obs, act in zip(episode['observations'], episode['actions']):
        observations.append(concat_obs(obs))
        actions.append(act)
observations = np.array(observations)
actions = np.array(actions)
print(f"Collected {len(observations)} expert steps.")

obs_dim = observations.shape[1]
act_dim = actions.shape[1]

# Hyperparameter grid
learning_rates = [1e-3, 3e-4, 1e-4]
hidden_sizes_list = [(128,128), (256,256)]
batch_sizes = [64, 128]
epochs_list = [75, 150]
dagger_iterations = 5
num_dagger_episodes = 50
max_steps = 200
val_split = 0.1
eval_episodes = 10

# Each config's epoch budget is split evenly across DAgger iterations, so every
# iteration needs at least one epoch (and there must be at least one iteration).
if dagger_iterations < 1:
    raise ValueError(f"dagger_iterations must be >= 1, got {dagger_iterations}")
if min(epochs_list) < dagger_iterations:
    raise ValueError(
        f"every entry of epochs_list must be >= dagger_iterations ({dagger_iterations})"
    )

# Results storage
results = []

# Grid search: Train, DAgger, Evaluate (same order as nested loops, last key fastest)
for lr, hidden_sizes, batch_size, epochs in itertools.product(
        learning_rates, hidden_sizes_list, batch_sizes, epochs_list):
    print("\n========================================")
    print(f"MLP+DAgger: lr={lr}, hidden={hidden_sizes}, batch={batch_size}, epochs={epochs}")
    # Any remainder (epochs % dagger_iterations) is dropped.
    epochs_per_iter = epochs // dagger_iterations
    dagger_obs = list(observations)
    dagger_acts = list(actions)
    mlp_policy = None
    losses = []
    val_mse = []
    total_epochs = 0
    env = NeedlePick(render_mode=None)
    try:
        for i in range(dagger_iterations):
            print(f"[MLP DAgger] Iteration {i+1}/{dagger_iterations} (epochs {total_epochs + 1} to {total_epochs + epochs_per_iter})")
            # train_mlp builds a fresh network, so each iteration retrains from scratch.
            mlp_policy, iter_losses, iter_val_mse = train_mlp(
                np.array(dagger_obs), np.array(dagger_acts),
                epochs=epochs_per_iter, batch_size=batch_size, lr=lr,
                hidden_sizes=hidden_sizes, val_split=val_split
            )
            losses.extend(iter_losses)
            val_mse.extend(iter_val_mse)
            total_epochs += epochs_per_iter
            if i == dagger_iterations - 1:
                # Data gathered after the last training round would never be
                # trained on, so skip those rollouts.
                break
            new_obs, new_acts = collect_dagger_data(mlp_policy, env, num_episodes=num_dagger_episodes, max_steps=max_steps)
            dagger_obs.extend(new_obs)
            dagger_acts.extend(new_acts)
            # Shuffle so train_mlp's leading validation slice is a random sample
            # of the aggregated data rather than the oldest expert steps.
            combined = list(zip(dagger_obs, dagger_acts))
            random.shuffle(combined)
            dagger_obs = [o for o, _ in combined]
            dagger_acts = [a for _, a in combined]
            print(f"Aggregated dataset size: {len(dagger_obs)}")
    finally:
        try:
            env.close()
        except Exception as e:
            print(f"Warning: Exception during env.close(): {e}")
        del env
    hid_tag = "_".join(str(h) for h in hidden_sizes)
    model_name = f"mlp_dagger_lr{lr}_hid{hid_tag}_bs{batch_size}_ep{epochs}"
    torch.save(mlp_policy.state_dict(), os.path.join(out_dir, f"{model_name}.pth"))
    np.save(os.path.join(out_dir, f"{model_name}_train_losses.npy"), np.array(losses))
    np.save(os.path.join(out_dir, f"{model_name}_val_mse.npy"), np.array(val_mse))
    avg_return, success_rate = evaluate_policy(mlp_policy, episodes=eval_episodes, max_steps=max_steps)
    np.save(os.path.join(out_dir, f"{model_name}_eval_success_rate.npy"), np.array([success_rate]))
    np.save(os.path.join(out_dir, f"{model_name}_eval_return.npy"), np.array([avg_return]))
    print(f"Saved model and logs to {out_dir}: {model_name}")
    results.append({
        'model_name': model_name,
        'lr': lr,
        'hidden_sizes': hidden_sizes,
        'batch_size': batch_size,
        'epochs': epochs,
        'final_val_mse': val_mse[-1] if val_mse else None,
        'avg_return': avg_return,
        'success_rate': success_rate
    })

In [None]:
# Save results to CSV: one row per hyperparameter configuration, with columns
# matching the keys of each `results` entry.
csv_path = os.path.join(out_dir, "evaluation_results_mlp_dagger.csv")
csv_columns = [
    "model_name", "lr", "hidden_sizes", "batch_size", "epochs",
    "final_val_mse", "avg_return", "success_rate",
]
with open(csv_path, "w", newline="") as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=csv_columns)
    writer.writeheader()
    for row in results:
        writer.writerow(row)

In [None]:
# Print summary table, best configurations first. Rows are ranked by success
# rate, then by average return (both descending; missing values count as 0).
def _summary_rank(row):
    return (-(row['success_rate'] or 0), -(row['avg_return'] or 0))

print("\n===== MLP+DAgger Hyperparameter Comparison Results =====")
for res in sorted(results, key=_summary_rank):
    success_pct = res['success_rate'] * 100
    print(f"{res['model_name']}: Success={success_pct:.1f}%, Return={res['avg_return']:.2f}, Final Val MSE={res['final_val_mse']}")

# 2. Plotting: Visualize Training and Evaluation Metrics
Plot training loss and validation MSE for the best configurations (lowest validation MSE, highest success rate, highest return). Plot success rate and episode return for all hyperparameter configurations.

In [None]:
# ---- Plotting ----
def _plot_best_curves(best_configs, suffix, title, ylabel, filename):
    """Overlay the saved per-epoch curve ``{model_name}{suffix}`` of each best config, then save and show."""
    plt.figure(figsize=(10, 6))
    for label_prefix, cfg in best_configs:
        curve = np.load(os.path.join(out_dir, f"{cfg['model_name']}{suffix}"))
        plt.plot(curve, label=f"{label_prefix} ({cfg['model_name']})")
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.legend(fontsize=9)
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, filename))
    plt.show()


def _plot_ranked_bars(values, labels, color, xlabel, title, filename, xlim=None):
    """Draw a horizontal bar chart sorted by value (descending, ties by label), then save and show."""
    ranked_vals, ranked_labels = zip(*sorted(zip(values, labels), reverse=True))
    plt.figure(figsize=(10, max(6, len(labels)*0.3)))
    plt.barh(range(len(ranked_vals)), ranked_vals, color=color)
    plt.yticks(range(len(ranked_labels)), ranked_labels, fontsize=7)
    plt.xlabel(xlabel)
    plt.title(title)
    if xlim is not None:
        plt.xlim(xlim)
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, filename))
    plt.show()


if not results:
    # min()/max()/zip(*...) below would fail on an empty list.
    print("No results to plot; run the training cell first.")
else:
    models = [r['model_name'] for r in results]
    successes = [r['success_rate'] if r['success_rate'] is not None else 0.0 for r in results]
    returns = [r['avg_return'] if r['avg_return'] is not None else 0.0 for r in results]

    # Missing metrics never win a "best" slot.
    best_val = min(results, key=lambda r: r['final_val_mse'] if r['final_val_mse'] is not None else float('inf'))
    best_success = max(results, key=lambda r: r['success_rate'] if r['success_rate'] is not None else float('-inf'))
    best_return = max(results, key=lambda r: r['avg_return'] if r['avg_return'] is not None else float('-inf'))
    best_configs = [
        ("Best Val MSE", best_val),
        ("Best Success", best_success),
        ("Best Return", best_return),
    ]

    _plot_best_curves(best_configs, "_train_losses.npy",
                      "MLP+DAgger Training Loss (MSE) for Best Configs", "Loss",
                      "mlp_dagger_hyperparam_training_loss.png")
    _plot_best_curves(best_configs, "_val_mse.npy",
                      "MLP+DAgger Validation MSE for Best Configs", "MSE",
                      "mlp_dagger_hyperparam_val_mse.png")
    _plot_ranked_bars(successes, models, 'skyblue', "Success Rate",
                      "MLP+DAgger Task Success Rate (All Hyperparams) (2000 Episodes)",
                      "mlp_dagger_success_rate_2000.png", xlim=[0, 1])
    _plot_ranked_bars(returns, models, 'salmon', "Episode Return",
                      "MLP+DAgger Episode Return (All Hyperparams) (2000 Episodes)",
                      "mlp_dagger_episode_return_2000.png")

    print(
        f"Plotted and saved in {out_dir}:\n"
        " - Training loss: mlp_dagger_hyperparam_training_loss.png\n"
        " - Validation MSE: mlp_dagger_hyperparam_val_mse.png\n"
        " - Success rate: mlp_dagger_success_rate_2000.png\n"
        " - Episode return: mlp_dagger_episode_return_2000.png"
    )