# DAgger LSTM (Weighted+Filtered) Policy: Training, Evaluation, and Plotting (2000 Episodes)
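This notebook combines training, evaluation, and plotting for DAgger LSTM (weighted+filtered) models on the NeedlePick task. "Weighted" means expert episodes contribute more to the training loss than DAgger episodes (`expert_weight = 0.8` vs. `1 - expert_weight = 0.2`); "filtered" means only DAgger rollouts that reach success are added to the aggregated dataset.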

In [None]:
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
import matplotlib.pyplot as plt
import csv
from surrol.tasks.needle_pick import NeedlePick

In [None]:
# --- Set random seeds for reproducibility ---
SEED = 42
# Seed every RNG this notebook touches, in order: Python's `random`,
# NumPy's global generator, torch on CPU, and torch on CUDA.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(SEED)
del _seed_fn
# Request deterministic cuDNN kernels and turn off cuDNN benchmarking
# (autotuned algorithm choice), trading some speed for repeatability.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
def concat_obs(obs):
    """Flatten a goal-conditioned observation dict into one 1-D vector.

    Concatenates ``obs['observation']``, ``obs['achieved_goal']`` and
    ``obs['desired_goal']``, in that order, with ``np.concatenate``.
    """
    parts = (obs['observation'], obs['achieved_goal'], obs['desired_goal'])
    return np.concatenate(parts)

In [None]:
# --- LSTM Policy ---
class LSTMPolicy(nn.Module):
    """Recurrent policy: an ``nn.LSTM`` encoder followed by a linear action head.

    Args:
        obs_dim: length of each observation vector (LSTM input size).
        act_dim: length of each predicted action vector.
        hidden_size: LSTM hidden-state size.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self, obs_dim, act_dim, hidden_size=128, num_layers=1):
        super().__init__()
        # batch_first=True: inputs/outputs are laid out as (batch, seq, features).
        self.lstm = nn.LSTM(obs_dim, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, act_dim)

    def forward(self, x, hidden=None):
        """Predict an action at every time step of ``x``.

        Args:
            x: tensor of shape (batch, seq, obs_dim).
            hidden: optional ``(h, c)`` state from a previous call; ``None``
                starts from the LSTM's default zero state.

        Returns:
            ``(actions, hidden)`` where ``actions`` is (batch, seq, act_dim)
            and ``hidden`` is the LSTM's final ``(h, c)`` state.
        """
        features, next_hidden = self.lstm(x, hidden)
        return self.fc(features), next_hidden

In [None]:
# Use Apple's Metal (MPS) backend when torch reports it available; otherwise the CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print("Selected device:", device)

In [None]:
# --- Training function with expert/DAgger weighting ---
def _split_train_val(obs_episodes, act_episodes, weights, val_split):
    """Split aligned episode lists into train and validation parts.

    The first ``int(len(obs_episodes) * val_split)`` entries become the
    validation set. If that count is not positive, everything is used for
    training and the validation lists are empty.

    Returns:
        ``((train_obs, train_act, train_weights), (val_obs, val_act, val_weights))``
    """
    n_val = int(len(obs_episodes) * val_split)
    if n_val <= 0:
        return (obs_episodes, act_episodes, weights), ([], [], [])
    train = (obs_episodes[n_val:], act_episodes[n_val:], weights[n_val:])
    val = (obs_episodes[:n_val], act_episodes[:n_val], weights[:n_val])
    return train, val


def _episode_to_batch(seq):
    """Convert one episode array to a float32 tensor with a leading batch dim of 1, on ``device``."""
    return torch.tensor(seq, dtype=torch.float32).unsqueeze(0).to(device)


def _weighted_val_loss(policy, loss_fn, obs_episodes, act_episodes, weights):
    """Return the step-weighted mean of (per-episode MSE * episode weight); 0 if there are no steps."""
    total = 0
    steps = 0
    with torch.no_grad():
        for obs_seq, act_seq, w in zip(obs_episodes, act_episodes, weights):
            obs_t = _episode_to_batch(obs_seq)
            act_t = _episode_to_batch(act_seq)
            act_pred, _ = policy(obs_t)
            total += (loss_fn(act_pred, act_t) * w).item() * obs_t.shape[1]
            steps += obs_t.shape[1]
    return total / steps if steps > 0 else 0


def train_lstm_weighted(expert_obs_episodes, expert_act_episodes, dagger_obs_episodes, dagger_act_episodes,
                        expert_weight=0.8, epochs=50, lr=1e-3, hidden_size=128, num_layers=1, val_split=0.1,
                        init_policy=None):
    """Train an LSTMPolicy on expert + DAgger episodes with per-source loss weights.

    Every episode is fed as one full sequence (batch size 1). Expert episodes
    scale their MSE by ``expert_weight``; DAgger episodes by ``1 - expert_weight``.
    Episodes are shuffled once with NumPy's global RNG, and the first
    ``val_split`` fraction is held out for validation.

    Args:
        expert_obs_episodes, expert_act_episodes: per-episode 2-D arrays
            (steps, dim). Must be non-empty: feature sizes are read from
            element 0.
        dagger_obs_episodes, dagger_act_episodes: same layout; may be empty.
        expert_weight: loss weight for expert episodes (DAgger gets 1 - this).
        epochs: number of passes over the training episodes.
        lr: Adam learning rate.
        hidden_size, num_layers: architecture of a newly created policy.
        val_split: fraction of episodes held out for validation.
        init_policy: optional existing LSTMPolicy to keep training (warm
            start). It is moved to ``device`` and trained in place, and
            ``hidden_size``/``num_layers`` are ignored. ``None`` (default)
            builds a fresh policy.

    Returns:
        ``(policy, losses, val_mse)``. ``losses`` holds the per-epoch training
        loss and ``val_mse`` the per-epoch validation loss (empty when no
        validation split). Both are weighted by episode weight, so they are
        not raw MSE.
    """
    obs_dim = expert_obs_episodes[0].shape[1]
    act_dim = expert_act_episodes[0].shape[1]
    if init_policy is None:
        policy = LSTMPolicy(obs_dim, act_dim, hidden_size=hidden_size, num_layers=num_layers).to(device)
    else:
        policy = init_policy.to(device)
    optimizer = optim.Adam(policy.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    obs_episodes = list(expert_obs_episodes) + list(dagger_obs_episodes)
    act_episodes = list(expert_act_episodes) + list(dagger_act_episodes)
    weights = np.array([expert_weight] * len(expert_obs_episodes)
                       + [1.0 - expert_weight] * len(dagger_obs_episodes))
    # Shuffle once so the validation slice mixes expert and DAgger episodes.
    perm = np.arange(len(obs_episodes))
    np.random.shuffle(perm)
    obs_episodes = [obs_episodes[i] for i in perm]
    act_episodes = [act_episodes[i] for i in perm]
    weights = weights[perm]

    (train_obs, train_act, train_w), (val_obs, val_act, val_w) = _split_train_val(
        obs_episodes, act_episodes, weights, val_split)

    losses = []
    val_mse = []
    for epoch in range(epochs):
        policy.train()
        epoch_loss = 0
        total_steps = 0
        order = np.arange(len(train_obs))
        np.random.shuffle(order)
        for epi_idx in order:
            obs_seq = _episode_to_batch(train_obs[epi_idx])
            act_seq = _episode_to_batch(train_act[epi_idx])
            optimizer.zero_grad()
            act_pred, _ = policy(obs_seq)
            weighted_loss = loss_fn(act_pred, act_seq) * train_w[epi_idx]
            weighted_loss.backward()
            optimizer.step()
            # Accumulate per-step so the epoch average weights long episodes proportionally.
            epoch_loss += weighted_loss.item() * obs_seq.shape[1]
            total_steps += obs_seq.shape[1]
        avg_loss = epoch_loss / total_steps if total_steps > 0 else 0
        losses.append(avg_loss)
        if val_obs:
            policy.eval()
            val_mse.append(_weighted_val_loss(policy, loss_fn, val_obs, val_act, val_w))
        if (epoch + 1) % 10 == 0 or epoch == 0 or (epoch + 1) == epochs:
            if val_mse:
                print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.6f}, Val: {val_mse[-1]:.6f}")
            else:
                print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.6f}")
    return policy, losses, val_mse

In [None]:
def make_lstm_episode_dataset(trajectories):
    """Turn recorded trajectories into per-episode arrays for LSTM training.

    Each trajectory must provide ``'observations'`` (a sequence of observation
    dicts, flattened with ``concat_obs``) and ``'actions'``.

    Returns:
        ``(obs_episodes, act_episodes)``: parallel lists with one ``np.array``
        per trajectory.
    """
    obs_episodes, act_episodes = [], []
    for traj in trajectories:
        raw_obs, raw_acts = traj['observations'], traj['actions']
        obs_episodes.append(np.array([concat_obs(step_obs) for step_obs in raw_obs]))
        act_episodes.append(np.array(raw_acts))
    return obs_episodes, act_episodes

In [None]:
def collect_dagger_lstm_episodes_filtered(policy, env, num_episodes=5, max_steps=200):
    """Roll out ``policy`` in ``env`` and label each visited state with the expert action.

    The policy's actions drive the environment. ``env.get_oracle_action(obs)``
    provides the training label for every visited observation (DAgger
    relabeling). Only episodes where ``info['is_success']`` becomes true are
    returned ("filtered"); rollouts that end via ``done`` or hit ``max_steps``
    are discarded.

    The LSTM hidden state is carried from step to step within an episode and
    reset after each ``env.reset()``. This matches ``evaluate_lstm_policy`` and
    the full-sequence training in ``train_lstm_weighted``. Actions are clipped
    to ``env.action_space`` when the env has one, also matching evaluation.

    Args:
        policy: LSTMPolicy to roll out (switched to eval mode).
        env: environment with ``reset()``, ``get_oracle_action(obs)`` and a
            ``step(action)`` returning ``(obs, reward, done, info)``.
        num_episodes: number of rollouts to attempt.
        max_steps: step cap per rollout.

    Returns:
        ``(obs_episodes, act_episodes)``: one array per successful rollout,
        holding the stacked ``concat_obs`` observations and the matching
        expert actions.
    """
    dagger_obs_episodes = []
    dagger_act_episodes = []
    policy.eval()
    for _ in range(num_episodes):
        obs = env.reset()
        hidden = None  # fresh recurrent state for every episode
        obs_seq = []
        act_seq = []
        success = False
        for _ in range(max_steps):
            obs_in = concat_obs(obs)
            obs_seq.append(obs_in)
            inp = torch.tensor(obs_in, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)
            with torch.no_grad():
                # Feed the previous hidden state back in; dropping it would roll out a
                # memoryless policy that differs from the one trained/evaluated.
                action_t, hidden = policy(inp, hidden)
            action = action_t.cpu().numpy().squeeze(0).squeeze(0)
            # Label the visited state with the expert's action before stepping.
            act_seq.append(env.get_oracle_action(obs))
            if hasattr(env, 'action_space'):
                action = np.clip(action, env.action_space.low, env.action_space.high)
            obs, reward, done, info = env.step(action)
            if info.get("is_success", False):
                success = True
                break
            if done:
                break
        if success:
            dagger_obs_episodes.append(np.array(obs_seq))
            dagger_act_episodes.append(np.array(act_seq))
    return dagger_obs_episodes, dagger_act_episodes

In [None]:
def evaluate_lstm_policy(policy, episodes=10, max_steps=200):
    """Roll out ``policy`` in a fresh headless NeedlePick env and report its performance.

    The LSTM hidden state is carried across steps within an episode and reset
    for each episode. Actions are clipped to ``env.action_space`` when the env
    has one. An episode counts as a success (and ends) the first time
    ``info['is_success']`` is true.

    Args:
        policy: LSTMPolicy to evaluate (switched to eval mode).
        episodes: number of evaluation episodes; must be positive.
        max_steps: step cap per episode.

    Returns:
        ``(avg_return, success_rate)``: the mean of per-episode summed rewards
        and the fraction of successful episodes.

    Raises:
        ValueError: if ``episodes`` is not positive.
    """
    if episodes <= 0:
        raise ValueError(f"episodes must be positive, got {episodes}")
    policy.eval()
    env = NeedlePick(render_mode=None)
    success_count = 0
    returns = []
    try:
        for _ in range(episodes):
            obs = env.reset()
            total_reward = 0
            hidden = None
            for _ in range(max_steps):
                obs_in = concat_obs(obs)
                inp = torch.tensor(obs_in, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)
                with torch.no_grad():
                    action_tensor, hidden = policy(inp, hidden)
                    action = action_tensor.cpu().numpy().squeeze(0).squeeze(0)
                if hasattr(env, 'action_space'):
                    action = np.clip(action, env.action_space.low, env.action_space.high)
                obs, reward, done, info = env.step(action)
                total_reward += reward
                if info.get('is_success', False):
                    success_count += 1
                    break
                if done:
                    break
            returns.append(total_reward)
    finally:
        # Release the simulator even if a rollout raised; a failing close is
        # reported rather than masking the original outcome.
        try:
            env.close()
        except Exception as e:
            print(f"Warning: Exception during env.close(): {e}")
    avg_return = np.mean(returns)
    success_rate = success_count / episodes
    return avg_return, success_rate

# 1. Training and Evaluation
Train DAgger LSTM policies over a grid of hyperparameters (learning rate, hidden size, epochs, number of layers), saving each model and its training metrics. Then evaluate each trained policy and save the results.

In [None]:
# --- Main workflow ---
import itertools

out_dir = "lstm_dagger_models_tune"
os.makedirs(out_dir, exist_ok=True)

# NOTE: pickle.load can execute arbitrary code from the file; only load trajectory files you trust.
with open("expert_trajectories.pkl", "rb") as f:
    trajectories = pickle.load(f)

expert_obs_episodes, expert_act_episodes = make_lstm_episode_dataset(trajectories)
print(f"Collected {len(expert_obs_episodes)} expert LSTM episodes.")

obs_dim = expert_obs_episodes[0].shape[1]
act_dim = expert_act_episodes[0].shape[1]

# Hyperparameter grid (every combination is trained) and fixed DAgger/eval settings.
learning_rates = [1e-3, 3e-4, 1e-4]
hidden_sizes = [128, 256]
epochs_list = [75, 150]
num_layers = [1, 2]
dagger_iterations = 5
num_dagger_episodes = 10
max_steps = 200
val_split = 0.1
expert_weight = 0.8
eval_episodes = 10


def _run_dagger(lr, hidden_size, epochs, layers):
    """Run every DAgger iteration for one hyperparameter configuration.

    Reads the expert data and DAgger settings defined at module level above.
    The ``epochs`` budget is split evenly across ``dagger_iterations``, with at
    least one epoch per iteration. Each iteration trains on expert data plus
    the aggregated DAgger episodes, then rolls the new policy out to collect
    more successful episodes.

    Returns:
        ``(policy, losses, val_mse)``: the policy from the last iteration
        (``None`` if ``dagger_iterations`` is 0), plus the per-epoch training
        and validation curves of all iterations concatenated.
    """
    epochs_per_iter = max(1, epochs // dagger_iterations)
    dagger_obs_episodes = []
    dagger_act_episodes = []
    policy = None
    losses = []
    val_mse = []
    total_epochs = 0
    env = NeedlePick(render_mode=None)
    try:
        for i in range(dagger_iterations):
            print(f"[LSTM DAgger] Iter {i+1}/{dagger_iterations} (epochs {total_epochs+1}-{total_epochs+epochs_per_iter})")
            # NOTE(review): no init_policy is passed, so train_lstm_weighted builds a fresh policy
            # on every call. Each iteration therefore retrains from scratch for `epochs_per_iter`
            # epochs, and the concatenated curves restart at each iteration boundary.
            policy, iter_losses, iter_val_mse = train_lstm_weighted(
                expert_obs_episodes, expert_act_episodes,
                dagger_obs_episodes, dagger_act_episodes,
                expert_weight=expert_weight,
                epochs=epochs_per_iter, lr=lr, hidden_size=hidden_size,
                num_layers=layers, val_split=val_split
            )
            losses.extend(iter_losses)
            val_mse.extend(iter_val_mse)
            total_epochs += epochs_per_iter
            new_obs, new_acts = collect_dagger_lstm_episodes_filtered(
                policy, env, num_episodes=num_dagger_episodes, max_steps=max_steps)
            print(f"DAgger: Collected {len(new_obs)} successful new episodes.")
            dagger_obs_episodes.extend(new_obs)
            dagger_act_episodes.extend(new_acts)
            print(f"Aggregated dataset: {len(expert_obs_episodes)+len(dagger_obs_episodes)} episodes")
    finally:
        # Release the simulator even if training or a rollout raised.
        try:
            env.close()
        except Exception as e:
            print(f"Warning: Exception during env.close(): {e}")
    return policy, losses, val_mse


results = []

# itertools.product varies the last list fastest, matching the original nested loops.
for lr, hidden_size, epochs, layers in itertools.product(learning_rates, hidden_sizes, epochs_list, num_layers):
    print("\n========================================")
    print(f"LSTM+DAgger (filtered+weighted): lr={lr}, hidden={hidden_size}, epochs={epochs}, layers={layers}")
    lstm_policy, losses, val_mse = _run_dagger(lr, hidden_size, epochs, layers)
    if lstm_policy is None:
        raise RuntimeError("dagger_iterations must be >= 1 to produce a trained policy")
    model_name = f"lstm_dagger_lr{lr}_hid{hidden_size}_ep{epochs}_lay{layers}_fullseq"
    torch.save(lstm_policy.state_dict(), os.path.join(out_dir, f"{model_name}.pth"))
    np.save(os.path.join(out_dir, f"{model_name}_train_losses.npy"), np.array(losses))
    np.save(os.path.join(out_dir, f"{model_name}_val_mse.npy"), np.array(val_mse))
    avg_return, success_rate = evaluate_lstm_policy(lstm_policy, episodes=eval_episodes, max_steps=max_steps)
    np.save(os.path.join(out_dir, f"{model_name}_eval_success_rate.npy"), np.array([success_rate]))
    np.save(os.path.join(out_dir, f"{model_name}_eval_return.npy"), np.array([avg_return]))
    print(f"Saved model and logs to {out_dir}: {model_name}")
    results.append({
        'model_name': model_name,
        'lr': lr,
        'hidden_size': hidden_size,
        'epochs': epochs,
        'num_layers': layers,
        'final_val_mse': val_mse[-1] if val_mse else None,
        'avg_return': avg_return,
        'success_rate': success_rate
    })

In [None]:
# Save results to CSV: one row per trained configuration, columns in a fixed order.
csv_fields = ["model_name", "lr", "hidden_size", "epochs", "num_layers", "final_val_mse", "avg_return", "success_rate"]
csv_path = os.path.join(out_dir, "evaluation_results_lstm_dagger.csv")
with open(csv_path, "w", newline="") as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=csv_fields)
    writer.writeheader()
    for row in results:
        writer.writerow(row)

In [None]:
# Print summary table, ranked by success rate and then by average return (both descending).
print("\n===== LSTM+DAgger (Weighted+Filtered) Hyperparameter Comparison Results =====")
ranked_results = sorted(
    results,
    key=lambda r: (-(r['success_rate'] or 0), -(r['avg_return'] or 0)),
)
for res in ranked_results:
    success_pct = res['success_rate'] * 100
    print(f"{res['model_name']}: Success={success_pct:.1f}%, Return={res['avg_return']:.2f}, Final Val MSE={res['final_val_mse']}")
print(f"\nAll models and metrics are saved in the folder: {out_dir}")

# 2. Plotting: Visualize Training and Evaluation Metrics
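Plot training loss and validation MSE for the best configurations (lowest final validation MSE, highest success rate, highest return), and plot success rate and episode return for all hyperparameter configurations.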

In [None]:
# ---- Plotting ----
def _or_zero(value):
    """Map a missing metric (None) to 0.0 so it can be ranked and plotted."""
    return 0.0 if value is None else value


def _plot_best_curves(best_configs, metric_suffix, title, ylabel, filename):
    """Overlay the saved ``<model_name>_<metric_suffix>.npy`` curve of each best config and save the figure.

    For runs saved by the training cell, each point is one epoch, with the
    epochs of successive DAgger iterations concatenated.
    """
    plt.figure(figsize=(10, 6))
    for label, res in best_configs:
        curve = np.load(os.path.join(out_dir, f"{res['model_name']}_{metric_suffix}.npy"))
        plt.plot(curve, label=f"{label} ({res['model_name']})")
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.legend(fontsize=9)
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, filename))
    plt.show()


def _plot_ranked_bars(values, labels, color, xlabel, title, filename, xlim=None):
    """Draw a horizontal bar chart of ``values`` sorted in descending order and save it."""
    ranked_vals, ranked_labels = zip(*sorted(zip(values, labels), reverse=True))
    plt.figure(figsize=(10, max(6, len(labels) * 0.3)))
    plt.barh(range(len(ranked_vals)), ranked_vals, color=color)
    plt.yticks(range(len(ranked_labels)), ranked_labels, fontsize=7)
    plt.xlabel(xlabel)
    plt.title(title)
    if xlim is not None:
        plt.xlim(xlim)
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, filename))
    plt.show()


if not results:
    print("No results to plot.")
else:
    models = [r['model_name'] for r in results]
    successes = [_or_zero(r['success_rate']) for r in results]
    returns = [_or_zero(r['avg_return']) for r in results]

    # Rank with the same None handling as the bar data: a missing val MSE never wins,
    # a missing success rate / return counts as 0.
    best_val = min(results, key=lambda r: r['final_val_mse'] if r['final_val_mse'] is not None else float('inf'))
    best_success = max(results, key=lambda r: _or_zero(r['success_rate']))
    best_return = max(results, key=lambda r: _or_zero(r['avg_return']))
    best_configs = [
        ("Best Val MSE", best_val),
        ("Best Success", best_success),
        ("Best Return", best_return),
    ]

    _plot_best_curves(best_configs, "train_losses",
                      "LSTM+DAgger (Weighted+Filtered) Training Loss (MSE) for Best Configs",
                      "Loss", "lstm_dagger_hyperparam_training_loss.png")
    _plot_best_curves(best_configs, "val_mse",
                      "LSTM+DAgger (Weighted+Filtered) Validation MSE for Best Configs",
                      "MSE", "lstm_dagger_hyperparam_val_mse.png")
    _plot_ranked_bars(successes, models, 'skyblue', "Success Rate",
                      "LSTM+DAgger (Weighted+Filtered) Task Success Rate (All Hyperparams) (2000 episodes)",
                      "lstm_dagger_tuned_success_rate_2000.png", xlim=[0, 1])
    _plot_ranked_bars(returns, models, 'salmon', "Episode Return",
                      "LSTM+DAgger (Weighted+Filtered) Episode Return (All Hyperparams) (2000 episodes)",
                      "lstm_dagger_tuned_episode_return_2000.png")

    print(
        f"Plotted and saved in {out_dir}:\n"
        " - Training loss: lstm_dagger_hyperparam_training_loss.png\n"
        " - Validation MSE: lstm_dagger_hyperparam_val_mse.png\n"
        " - Success rate: lstm_dagger_tuned_success_rate_2000.png\n"
        " - Episode return: lstm_dagger_tuned_episode_return_2000.png"
    )