# Best Model: Training, Evaluation and Plotting (100 Episodes)
This notebook trains, evaluates, and plots results for each best-hyperparameter model across multiple random seeds.

In [None]:
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from surrol.tasks.needle_pick import NeedlePick
import csv
import matplotlib.pyplot as plt
from collections import defaultdict
from matplotlib.lines import Line2D

In [None]:
# == CONFIG: Define your best hyperparameters for each model ==
# Every entry carries a checkpoint "name", a "type" that selects the training
# routine, and the "params" handed to that routine.
BEST_MODELS = [
    {
        "name": "mlp_best",
        "type": "mlp_bc",
        "params": {
            "epochs": 75,
            "batch_size": 64,
            "lr": 0.0003,
            "hidden_sizes": (256, 256),
        },
    },
    {
        "name": "lstm_best",
        "type": "lstm_bc",
        "params": {
            "epochs": 150,
            "lr": 0.0003,
            "hidden_size": 256,
            "num_layers": 2,
        },
    },
    {
        "name": "mlp_dagger_best",
        "type": "mlp_dagger",
        "params": {
            "dagger_iters": 5,
            "epochs_dagger": 30,
            "batch_size": 64,
            "lr": 0.001,
            "hidden_sizes": (128, 128),
        },
    },
    {
        "name": "mlp_dagger_tuned_best",
        "type": "mlp_dagger_tuned",
        "params": {
            "dagger_iters": 5,
            "epochs_dagger": 30,
            "batch_size": 64,
            "lr": 0.0003,
            "hidden_sizes": (256, 256),
            "expert_weight": 0.8,
        },
    },
    {
        "name": "lstm_dagger_best",
        "type": "lstm_dagger",
        "params": {
            "dagger_iters": 5,
            "epochs_dagger": 30,
            "lr": 0.001,
            "hidden_size": 128,
            "num_layers": 2,
        },
    },
    {
        "name": "lstm_dagger_tuned_best",
        "type": "lstm_dagger_tuned",
        "params": {
            "dagger_iters": 5,
            "epochs_dagger": 30,
            "lr": 0.001,
            "hidden_size": 128,
            "num_layers": 2,
            "expert_weight": 0.8,
        },
    },
]
NUM_SEEDS = 5
SEED_LIST = list(range(NUM_SEEDS))  # seeds 0..4
OUT_DIR = "best6_multiseed"
# Checkpoints and loss curves are written here; create it up front.
os.makedirs(OUT_DIR, exist_ok=True)

In [None]:
def set_all_seeds(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every CUDA device rather than only the current one; the call is
    # silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run repeatability.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
def concat_obs(obs):
    """Join the parts of a goal-style observation dict into one array.

    The pieces are concatenated along the first axis in the fixed order
    'observation', 'achieved_goal', 'desired_goal'.
    """
    ordered_keys = ('observation', 'achieved_goal', 'desired_goal')
    pieces = [obs[key] for key in ordered_keys]
    return np.concatenate(pieces)

In [None]:
class MLPPolicy(nn.Module):
    """Feed-forward policy: a Linear+ReLU pair per hidden size, then a linear head."""

    def __init__(self, obs_dim, act_dim, hidden_sizes=(128,128)):
        super().__init__()
        # Layer widths from input through the last hidden layer.
        widths = [obs_dim, *hidden_sizes]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Output head has no activation.
        stack.append(nn.Linear(widths[-1], act_dim))
        # Attribute name `model` is part of the checkpoint key names.
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Map observations to predicted actions."""
        return self.model(x)

In [None]:
class LSTMPolicy(nn.Module):
    """Recurrent policy: a batch-first LSTM followed by a per-step linear head."""

    def __init__(self, obs_dim, act_dim, hidden_size=128, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=obs_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=act_dim)

    def forward(self, x, hidden=None):
        """Return per-step action predictions and the updated LSTM state.

        `hidden` is passed straight to the LSTM; None starts from its default state.
        """
        features, next_hidden = self.lstm(x, hidden)
        return self.fc(features), next_hidden

In [None]:
def make_mlp_dataset(trajectories):
    """Pool every step of every episode into flat observation/action arrays.

    Each episode is a dict with 'observations' (run through `concat_obs`)
    and 'actions'. Episode boundaries are discarded.

    Returns:
        Tuple (observations, actions) as NumPy arrays.
    """
    flat_obs, flat_acts = [], []
    for episode in trajectories:
        flat_obs.extend(concat_obs(step_obs) for step_obs in episode['observations'])
        flat_acts += episode['actions']
    return np.array(flat_obs), np.array(flat_acts)

In [None]:
def make_lstm_episode_dataset(trajectories):
    """Convert trajectories into per-episode observation and action arrays.

    Unlike `make_mlp_dataset`, episodes stay separate so sequence models can
    consume one episode at a time.

    Returns:
        Tuple (obs_episodes, act_episodes): two parallel lists of NumPy arrays.
    """
    pairs = [
        (
            np.array([concat_obs(step_obs) for step_obs in episode['observations']]),
            np.array(episode['actions']),
        )
        for episode in trajectories
    ]
    obs_episodes = [obs_arr for obs_arr, _ in pairs]
    act_episodes = [act_arr for _, act_arr in pairs]
    return obs_episodes, act_episodes

In [None]:
def train_mlp_bc(obs, acts, epochs=75, batch_size=128, lr=0.0001, hidden_sizes=(128, 128), val_split=0.1):
    """Train an MLPPolicy by behaviour cloning (MSE between predicted and given actions).

    Args:
        obs: 2-D array-like of observations, one row per sample.
        acts: 2-D array-like of target actions, one row per sample.
        epochs: Number of passes over the training split.
        batch_size: Mini-batch size.
        lr: Adam learning rate.
        hidden_sizes: Hidden layer widths for the MLP.
        val_split: Fraction of samples randomly held out for validation.

    Returns:
        Tuple (policy, losses, val_mse): the trained policy, the per-epoch mean
        training loss, and the per-epoch validation MSE (empty when no samples
        are held out).
    """
    def as_float_tensor(array):
        # float32 copy placed on the module-level device
        return torch.tensor(array, dtype=torch.float32).to(device)

    obs, acts = np.array(obs), np.array(acts)
    policy = MLPPolicy(obs.shape[1], acts.shape[1], hidden_sizes=hidden_sizes).to(device)
    optimizer = optim.Adam(policy.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    # Random hold-out: the first n_val shuffled indices become the validation set.
    order = np.arange(len(obs))
    np.random.shuffle(order)
    n_val = int(val_split * len(obs))
    held_out, kept = order[:n_val], order[n_val:]
    train_obs = as_float_tensor(obs[kept])
    train_acts = as_float_tensor(acts[kept])
    has_val = n_val > 0
    val_obs = as_float_tensor(obs[held_out]) if has_val else None
    val_acts = as_float_tensor(acts[held_out]) if has_val else None

    n_train = train_obs.shape[0]
    losses, val_mse = [], []
    for epoch in range(epochs):
        policy.train()
        # Fresh sample order every epoch.
        perm = np.random.permutation(n_train)
        epoch_obs, epoch_acts = train_obs[perm], train_acts[perm]
        running = 0
        for start in range(0, n_train, batch_size):
            stop = start + batch_size
            obs_batch = epoch_obs[start:stop]
            optimizer.zero_grad()
            loss = loss_fn(policy(obs_batch), epoch_acts[start:stop])
            loss.backward()
            optimizer.step()
            # Weight by batch size so the short final batch is averaged correctly.
            running += loss.item() * obs_batch.shape[0]
        avg_loss = running / n_train
        losses.append(avg_loss)

        # Validation
        if val_obs is not None:
            policy.eval()
            with torch.no_grad():
                val_mse.append(loss_fn(policy(val_obs), val_acts).item())

        # Log the first epoch, every tenth epoch, and the last epoch.
        should_log = epoch == 0 or (epoch + 1) % 10 == 0 or (epoch + 1) == epochs
        if not should_log:
            continue
        if val_mse:
            print(f"    Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.6f}, Val: {val_mse[-1]:.6f}")
        else:
            print(f"    Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.6f}")
    return policy, losses, val_mse

In [None]:
def train_mlp_weighted(expert_obs, expert_acts, dagger_obs, dagger_acts,
                       expert_weight=0.8, epochs=150, batch_size=128, lr=0.001, hidden_sizes=(128,128), val_split=0.1):
    """Train an MLPPolicy on expert plus DAgger samples with per-sample loss weights.

    Expert samples are weighted by `expert_weight` and DAgger samples by
    `1 - expert_weight`. The training loss is the mean over the batch of
    (per-sample MSE * sample weight), so mixed batches weight each sample
    individually. When every sample in a batch has the same weight this equals
    the plain MSE scaled by that weight.

    Args:
        expert_obs: 2-D array-like of expert observations.
        expert_acts: 2-D array-like of expert actions.
        dagger_obs: 2-D array of DAgger observations (may have zero rows).
        dagger_acts: 2-D array of DAgger action labels (may have zero rows).
        expert_weight: Loss weight for expert samples.
        epochs: Number of passes over the training split.
        batch_size: Mini-batch size.
        lr: Adam learning rate.
        hidden_sizes: Hidden layer widths for the MLP.
        val_split: Fraction of samples randomly held out for validation.

    Returns:
        Tuple (policy, losses, val_mse): the trained policy, the per-epoch mean
        weighted training loss, and the per-epoch unweighted validation MSE
        (empty when no samples are held out).
    """
    expert_obs = np.asarray(expert_obs)
    expert_acts = np.asarray(expert_acts)
    n_expert, n_dagger = len(expert_obs), len(dagger_obs)
    if n_dagger:
        obs = np.vstack([expert_obs, dagger_obs])
        acts = np.vstack([expert_acts, dagger_acts])
    else:
        obs, acts = expert_obs, expert_acts
    weights = np.concatenate([
        np.full(n_expert, expert_weight, dtype=np.float32),
        np.full(n_dagger, 1.0 - expert_weight, dtype=np.float32),
    ])

    policy = MLPPolicy(obs.shape[1], acts.shape[1], hidden_sizes=hidden_sizes).to(device)
    optimizer = optim.Adam(policy.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    # Split through the shuffled index array. Slicing the unshuffled arrays
    # would always hold out the first expert rows instead of a random subset.
    num_samples = len(obs)
    order = np.arange(num_samples)
    np.random.shuffle(order)
    n_val = int(val_split * num_samples)
    val_idx, train_idx = order[:n_val], order[n_val:]

    def as_float_tensor(array):
        return torch.tensor(array, dtype=torch.float32).to(device)

    train_obs = as_float_tensor(obs[train_idx])
    train_acts = as_float_tensor(acts[train_idx])
    train_weights = as_float_tensor(weights[train_idx])
    val_obs = as_float_tensor(obs[val_idx]) if n_val > 0 else None
    val_acts = as_float_tensor(acts[val_idx]) if n_val > 0 else None

    n_train = train_obs.shape[0]
    losses = []
    val_mse = []
    for epoch in range(epochs):
        policy.train()
        epoch_loss = 0
        perm = np.random.permutation(n_train)
        obs_shuffled = train_obs[perm]
        acts_shuffled = train_acts[perm]
        weights_shuffled = train_weights[perm]
        for i in range(0, n_train, batch_size):
            obs_batch = obs_shuffled[i:i+batch_size]
            act_batch = acts_shuffled[i:i+batch_size]
            w_batch = weights_shuffled[i:i+batch_size]
            optimizer.zero_grad()
            act_pred = policy(obs_batch)
            # Per-sample MSE (mean over action dims), then weight each sample.
            per_sample = ((act_pred - act_batch) ** 2).mean(dim=1)
            weighted_loss = (per_sample * w_batch).mean()
            weighted_loss.backward()
            optimizer.step()
            epoch_loss += weighted_loss.item() * obs_batch.shape[0]
        avg_loss = epoch_loss / n_train
        losses.append(avg_loss)
        # Validation (unweighted MSE so it is comparable across trainers)
        if val_obs is not None:
            policy.eval()
            with torch.no_grad():
                val_pred = policy(val_obs)
                val_mse.append(loss_fn(val_pred, val_acts).item())
        if (epoch + 1) % 10 == 0 or epoch == 0 or (epoch + 1) == epochs:
            if val_mse:
                print(f"    Epoch {epoch+1}/{epochs}, Weighted Loss: {avg_loss:.6f}, Val: {val_mse[-1]:.6f}")
            else:
                print(f"    Epoch {epoch+1}/{epochs}, Weighted Loss: {avg_loss:.6f}")
    return policy, losses, val_mse

In [None]:
def train_lstm_bc(obs_episodes, act_episodes, epochs=150, lr=0.001, hidden_size=256, num_layers=2, val_split=0.1):
    """Train an LSTMPolicy by behaviour cloning, one whole episode per gradient step.

    The first `int(len(obs_episodes) * val_split)` episodes (in the given
    order) are held out for validation; the rest are trained on in a fresh
    random order every epoch.

    Args:
        obs_episodes: List of 2-D observation arrays, one per episode.
        act_episodes: List of 2-D action arrays, parallel to `obs_episodes`.
        epochs: Number of passes over the training episodes.
        lr: Adam learning rate.
        hidden_size: LSTM hidden width.
        num_layers: Number of stacked LSTM layers.
        val_split: Fraction of episodes held out for validation.

    Returns:
        Tuple (policy, losses, val_mse): the trained policy plus per-epoch
        step-weighted training loss and validation MSE (the latter empty when
        no episodes are held out).
    """
    def as_batch(sequence):
        # Add a leading batch axis of size 1 and move to the module-level device.
        return torch.tensor(sequence, dtype=torch.float32).unsqueeze(0).to(device)

    policy = LSTMPolicy(
        obs_episodes[0].shape[1],
        act_episodes[0].shape[1],
        hidden_size=hidden_size,
        num_layers=num_layers,
    ).to(device)
    optimizer = optim.Adam(policy.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    n_val = int(len(obs_episodes) * val_split)
    fit_obs, fit_acts = obs_episodes[n_val:], act_episodes[n_val:]
    val_obs_episodes = obs_episodes[:n_val] if n_val > 0 else []
    val_act_episodes = act_episodes[:n_val] if n_val > 0 else []
    n_fit = len(fit_obs)

    losses, val_mse = [], []
    for epoch in range(epochs):
        policy.train()
        loss_sum, step_count = 0, 0
        for epi_idx in np.random.permutation(n_fit):
            obs_seq = as_batch(fit_obs[epi_idx])
            act_seq = as_batch(fit_acts[epi_idx])
            optimizer.zero_grad()
            act_pred, _ = policy(obs_seq)
            loss = loss_fn(act_pred, act_seq)
            loss.backward()
            optimizer.step()
            # Weight by episode length so the epoch average is per step.
            seq_len = obs_seq.shape[1]
            loss_sum += loss.item() * seq_len
            step_count += seq_len
        avg_loss = loss_sum / step_count if step_count > 0 else 0
        losses.append(avg_loss)

        # Validation
        if val_obs_episodes:
            policy.eval()
            with torch.no_grad():
                val_sum, val_steps = 0, 0
                for raw_obs, raw_acts in zip(val_obs_episodes, val_act_episodes):
                    obs_seq_t = as_batch(raw_obs)
                    act_seq_t = as_batch(raw_acts)
                    act_pred, _ = policy(obs_seq_t)
                    seq_len = obs_seq_t.shape[1]
                    val_sum += loss_fn(act_pred, act_seq_t).item() * seq_len
                    val_steps += seq_len
                val_mse.append(val_sum / val_steps if val_steps > 0 else 0)

        # Log the first epoch, every tenth epoch, and the last epoch.
        should_log = epoch == 0 or (epoch + 1) % 10 == 0 or (epoch + 1) == epochs
        if not should_log:
            continue
        if val_mse:
            print(f"    Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.6f}, Val: {val_mse[-1]:.6f}")
        else:
            print(f"    Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.6f}")
    return policy, losses, val_mse

In [None]:
def train_lstm_weighted(expert_obs_episodes, expert_act_episodes, dagger_obs_episodes, dagger_act_episodes,
                        expert_weight=0.8, epochs=150, lr=0.001, hidden_size=128, num_layers=1, val_split=0.1):
    """Train an LSTMPolicy on expert plus DAgger episodes with per-episode loss weights.

    Expert episodes are weighted by `expert_weight`, DAgger episodes by
    `1 - expert_weight`. Episodes are shuffled once before a fraction is held
    out for validation. The validation series is the plain (unweighted)
    step-averaged MSE, so it is comparable with the other trainers' `val_mse`.

    Args:
        expert_obs_episodes: List of 2-D expert observation arrays.
        expert_act_episodes: List of 2-D expert action arrays.
        dagger_obs_episodes: List of 2-D DAgger observation arrays (may be empty).
        dagger_act_episodes: List of 2-D DAgger action-label arrays (may be empty).
        expert_weight: Loss weight for expert episodes.
        epochs: Number of passes over the training episodes.
        lr: Adam learning rate.
        hidden_size: LSTM hidden width.
        num_layers: Number of stacked LSTM layers.
        val_split: Fraction of episodes held out for validation.

    Returns:
        Tuple (policy, losses, val_mse): the trained policy, the per-epoch
        step-averaged weighted training loss, and the per-epoch validation MSE
        (empty when no episodes are held out).
    """
    def as_batch(sequence):
        # Add a leading batch axis of size 1 and move to the module-level device.
        return torch.tensor(sequence, dtype=torch.float32).unsqueeze(0).to(device)

    obs_dim = expert_obs_episodes[0].shape[1]
    act_dim = expert_act_episodes[0].shape[1]
    policy = LSTMPolicy(obs_dim, act_dim, hidden_size=hidden_size, num_layers=num_layers).to(device)
    optimizer = optim.Adam(policy.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    obs_episodes = list(expert_obs_episodes) + list(dagger_obs_episodes)
    act_episodes = list(expert_act_episodes) + list(dagger_act_episodes)
    weights = np.array([expert_weight]*len(expert_obs_episodes) + [1.0-expert_weight]*len(dagger_obs_episodes))

    # Shuffle once so the hold-out mixes expert and DAgger episodes.
    indices = np.arange(len(obs_episodes))
    np.random.shuffle(indices)
    obs_episodes = [obs_episodes[i] for i in indices]
    act_episodes = [act_episodes[i] for i in indices]
    weights = weights[indices]

    # The split is fixed for the whole run, so build it once, outside the epoch loop.
    n_val = int(len(obs_episodes) * val_split)
    val_obs_episodes = obs_episodes[:n_val]
    val_act_episodes = act_episodes[:n_val]
    train_obs_episodes = obs_episodes[n_val:]
    train_act_episodes = act_episodes[n_val:]
    train_weights = weights[n_val:]
    num_episodes = len(train_obs_episodes)

    losses = []
    val_mse = []
    for epoch in range(epochs):
        policy.train()
        epoch_loss = 0
        total_steps = 0
        order = np.arange(num_episodes)
        np.random.shuffle(order)
        for epi_idx in order:
            obs_seq = as_batch(train_obs_episodes[epi_idx])
            act_seq = as_batch(train_act_episodes[epi_idx])
            optimizer.zero_grad()
            act_pred, _ = policy(obs_seq)
            loss = loss_fn(act_pred, act_seq)
            # Plain Python float keeps the product in the tensor's own dtype.
            weighted_loss = loss * float(train_weights[epi_idx])
            weighted_loss.backward()
            optimizer.step()
            epoch_loss += weighted_loss.item() * obs_seq.shape[1]
            total_steps += obs_seq.shape[1]
        avg_loss = epoch_loss / total_steps if total_steps > 0 else 0
        losses.append(avg_loss)
        # Validation: unweighted MSE, averaged per step.
        if n_val > 0:
            policy.eval()
            with torch.no_grad():
                val_loss = 0
                val_steps = 0
                for obs_seq, act_seq in zip(val_obs_episodes, val_act_episodes):
                    obs_seq_t = as_batch(obs_seq)
                    act_seq_t = as_batch(act_seq)
                    act_pred, _ = policy(obs_seq_t)
                    val_loss += loss_fn(act_pred, act_seq_t).item() * obs_seq_t.shape[1]
                    val_steps += obs_seq_t.shape[1]
                val_mse.append(val_loss / val_steps if val_steps > 0 else 0)
        if (epoch + 1) % 10 == 0 or epoch == 0 or (epoch + 1) == epochs:
            if val_mse:
                print(f"    Epoch {epoch+1}/{epochs}, Weighted Loss: {avg_loss:.6f}, Val: {val_mse[-1]:.6f}")
            else:
                print(f"    Epoch {epoch+1}/{epochs}, Weighted Loss: {avg_loss:.6f}")
    return policy, losses, val_mse

In [None]:
def collect_dagger_mlp_episodes_filtered(policy, env, num_episodes=5, max_steps=200):
    """Roll out an MLP policy and collect oracle-labelled states from successful episodes.

    The policy's own action drives the environment, while the recorded label
    for each visited state is `env.get_oracle_action(obs)`. Episodes that do
    not report `is_success` are discarded.

    Args:
        policy: Callable mapping a batch of flat observations to actions.
        env: Environment providing `reset`, `step` and `get_oracle_action`.
        num_episodes: Number of rollouts to attempt.
        max_steps: Step cap per rollout.

    Returns:
        Tuple (observations, oracle_actions) as NumPy arrays pooled over all
        kept episodes.
    """
    kept_obs, kept_acts = [], []
    for _ in range(num_episodes):
        obs = env.reset()
        episode_obs, episode_acts = [], []
        succeeded = False
        for _ in range(max_steps):
            flat = concat_obs(obs)
            episode_obs.append(flat)
            batch = torch.tensor(flat, dtype=torch.float32).unsqueeze(0).to(device)
            with torch.no_grad():
                action = policy(batch).cpu().numpy().squeeze(0)
            # Label the state the learner visited with the expert's choice.
            episode_acts.append(env.get_oracle_action(obs))
            obs, reward, done, info = env.step(action)
            succeeded = bool(info.get("is_success", False))
            if succeeded or done:
                break
        if not succeeded:
            continue
        kept_obs.extend(episode_obs)
        kept_acts.extend(episode_acts)
    return np.array(kept_obs), np.array(kept_acts)

In [None]:
def collect_dagger_lstm_episodes_filtered(policy, env, num_episodes=5, max_steps=200):
    """Roll out an LSTM policy and collect oracle-labelled successful episodes.

    The policy's own action drives the environment, while the recorded label
    for each visited state is `env.get_oracle_action(obs)`. The recurrent
    state is reset at the start of every episode and carried from step to
    step, so the policy sees the episode as one sequence, the same way it is
    trained and evaluated. Episodes that do not report `is_success` are
    discarded.

    Args:
        policy: Recurrent policy called as `policy(inp, hidden)` and returning
            `(action, hidden)`.
        env: Environment providing `reset`, `step` and `get_oracle_action`.
        num_episodes: Number of rollouts to attempt.
        max_steps: Step cap per rollout.

    Returns:
        Tuple (obs_episodes, act_episodes): parallel lists with one NumPy
        array per kept episode.
    """
    dagger_obs_episodes = []
    dagger_act_episodes = []
    for ep in range(num_episodes):
        obs = env.reset()
        obs_seq = []
        act_seq = []
        success = False
        hidden = None  # fresh recurrent state for each episode
        for t in range(max_steps):
            obs_in = concat_obs(obs)
            obs_seq.append(obs_in)
            # Two unsqueezes: add batch and time axes of size 1.
            inp = torch.tensor(obs_in, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)
            with torch.no_grad():
                # Thread the hidden state through; dropping it would make every
                # step look like the first step of a new sequence.
                action, hidden = policy(inp, hidden)
                action = action.cpu().numpy().squeeze(0).squeeze(0)
            expert_act = env.get_oracle_action(obs)
            act_seq.append(expert_act)
            obs, reward, done, info = env.step(action)
            if info.get("is_success", False):
                success = True
                break
            if done:
                break
        if success:
            dagger_obs_episodes.append(np.array(obs_seq))
            dagger_act_episodes.append(np.array(act_seq))
    return dagger_obs_episodes, dagger_act_episodes

In [None]:
# Pick the compute device once for the whole notebook: CUDA if present,
# otherwise Apple's MPS backend, otherwise CPU. The getattr guard covers
# PyTorch builds that do not ship `torch.backends.mps`.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("Selected device:", device)

## Training: Train all models for all seeds and save checkpoints.

In [None]:
def _dagger_mlp(label, params, expert_obs, expert_acts, train_step):
    """Run the MLP DAgger loop and return (policy, train_losses, val_mse).

    `train_step(dagger_obs, dagger_acts)` trains a fresh policy on the expert
    data plus the DAgger samples aggregated so far. Losses from all
    iterations are concatenated in order.
    """
    dagger_obs = np.empty((0, expert_obs.shape[1]), dtype=np.float32)
    dagger_acts = np.empty((0, expert_acts.shape[1]), dtype=np.float32)
    # `env` is local, so its reference is dropped when this helper returns.
    env = NeedlePick(render_mode=None)
    policy = None
    all_train_losses, all_val_mse = [], []
    n_iters = params['dagger_iters']
    for i in range(n_iters):
        print(f"  [{label}] Iter {i+1}/{n_iters}")
        policy, train_losses, val_mse = train_step(dagger_obs, dagger_acts)
        if train_losses is not None:
            all_train_losses += list(train_losses)
        if val_mse is not None:
            all_val_mse += list(val_mse)
        new_obs, new_acts = collect_dagger_mlp_episodes_filtered(
            policy, env, num_episodes=10, max_steps=200)
        if new_obs.shape[0] > 0:
            dagger_obs = np.vstack([dagger_obs, new_obs])
            dagger_acts = np.vstack([dagger_acts, new_acts])
        # The MLP collector returns pooled samples, so report samples, not episodes.
        print(f"    DAgger: Collected {new_obs.shape[0]} new samples. Aggregated dataset: {expert_obs.shape[0] + dagger_obs.shape[0]} samples")
    return policy, all_train_losses, all_val_mse


def _dagger_lstm(label, params, expert_obs_episodes, train_step):
    """Run the LSTM DAgger loop and return (policy, train_losses, val_mse).

    `train_step(dagger_obs_episodes, dagger_act_episodes)` trains a fresh
    policy on the expert episodes plus the DAgger episodes aggregated so far.
    Losses from all iterations are concatenated in order.
    """
    dagger_obs_episodes, dagger_act_episodes = [], []
    # `env` is local, so its reference is dropped when this helper returns.
    env = NeedlePick(render_mode=None)
    policy = None
    all_train_losses, all_val_mse = [], []
    n_iters = params['dagger_iters']
    for i in range(n_iters):
        print(f"  [{label}] Iter {i+1}/{n_iters}")
        policy, train_losses, val_mse = train_step(dagger_obs_episodes, dagger_act_episodes)
        if train_losses is not None:
            all_train_losses += list(train_losses)
        if val_mse is not None:
            all_val_mse += list(val_mse)
        new_obs, new_acts = collect_dagger_lstm_episodes_filtered(
            policy, env, num_episodes=10, max_steps=200)
        dagger_obs_episodes += new_obs
        dagger_act_episodes += new_acts
        print(f"    DAgger: Collected {len(new_obs)} new episodes. Aggregated dataset: {len(expert_obs_episodes)+len(dagger_obs_episodes)} episodes")
    return policy, all_train_losses, all_val_mse


if __name__ == "__main__":
    # Train on the module-level `device` chosen earlier in the notebook. The
    # previous override here selected 'mps' when *CUDA* was available, which
    # breaks on CUDA hosts and silently forced CPU on Macs.
    # NOTE: pickle can execute arbitrary code; only load trusted trajectory files.
    with open("expert_trajectories.pkl", "rb") as f:
        expert_trajectories = pickle.load(f)
    obs, acts = make_mlp_dataset(expert_trajectories)
    obs_dim = obs.shape[1]
    act_dim = acts.shape[1]
    obs_episodes, act_episodes = make_lstm_episode_dataset(expert_trajectories)

    for model_cfg in BEST_MODELS:
        params = model_cfg['params']
        model_type = model_cfg['type']
        print(f"\n========== Training model: {model_cfg['name']} ==========")
        for seed in SEED_LIST:
            print(f"[{model_cfg['name']}] Seed {seed} ...")
            set_all_seeds(seed)
            model_base = f"{model_cfg['name']}_seed{seed}"
            train_losses, val_mse = None, None

            if model_type == "mlp_bc":
                policy, train_losses, val_mse = train_mlp_bc(
                    obs, acts,
                    epochs=params['epochs'],
                    batch_size=params['batch_size'],
                    lr=params['lr'],
                    hidden_sizes=params['hidden_sizes']
                )
            elif model_type == "lstm_bc":
                policy, train_losses, val_mse = train_lstm_bc(
                    obs_episodes, act_episodes,
                    epochs=params['epochs'],
                    lr=params['lr'],
                    hidden_size=params['hidden_size'],
                    num_layers=params['num_layers']
                )
            elif model_type == "mlp_dagger":
                def train_step(d_obs, d_acts):
                    # Plain BC on expert data stacked with the aggregated DAgger data.
                    return train_mlp_bc(
                        np.vstack([obs, d_obs]) if d_obs.shape[0] else obs,
                        np.vstack([acts, d_acts]) if d_acts.shape[0] else acts,
                        epochs=params['epochs_dagger'],
                        batch_size=params['batch_size'],
                        lr=params['lr'],
                        hidden_sizes=params['hidden_sizes'])
                policy, train_losses, val_mse = _dagger_mlp(
                    "MLP DAgger", params, obs, acts, train_step)
            elif model_type == "mlp_dagger_tuned":
                def train_step(d_obs, d_acts):
                    # Weighted BC: expert and DAgger samples get different loss weights.
                    return train_mlp_weighted(
                        obs, acts, d_obs, d_acts,
                        expert_weight=params['expert_weight'],
                        epochs=params['epochs_dagger'],
                        batch_size=params['batch_size'],
                        lr=params['lr'],
                        hidden_sizes=params['hidden_sizes'])
                policy, train_losses, val_mse = _dagger_mlp(
                    "MLP DAgger Tuned", params, obs, acts, train_step)
            elif model_type == "lstm_dagger":
                def train_step(d_obs_eps, d_act_eps):
                    # Plain BC on expert episodes followed by the aggregated DAgger episodes.
                    return train_lstm_bc(
                        obs_episodes + d_obs_eps,
                        act_episodes + d_act_eps,
                        epochs=params['epochs_dagger'],
                        lr=params['lr'],
                        hidden_size=params['hidden_size'],
                        num_layers=params['num_layers'])
                policy, train_losses, val_mse = _dagger_lstm(
                    "LSTM DAgger", params, obs_episodes, train_step)
            elif model_type == "lstm_dagger_tuned":
                def train_step(d_obs_eps, d_act_eps):
                    # Weighted BC: expert and DAgger episodes get different loss weights.
                    return train_lstm_weighted(
                        obs_episodes, act_episodes,
                        d_obs_eps, d_act_eps,
                        expert_weight=params['expert_weight'],
                        epochs=params['epochs_dagger'],
                        lr=params['lr'],
                        hidden_size=params['hidden_size'],
                        num_layers=params['num_layers'])
                policy, train_losses, val_mse = _dagger_lstm(
                    "LSTM DAgger Tuned", params, obs_episodes, train_step)
            else:
                raise ValueError(f"Unknown model type {model_cfg['type']}")

            # ---- SAVE LOSSES & VAL ----
            if train_losses is not None:
                np.save(os.path.join(OUT_DIR, f"{model_base}_train_losses.npy"), np.array(train_losses))
            if val_mse is not None:
                np.save(os.path.join(OUT_DIR, f"{model_base}_val_mse.npy"), np.array(val_mse))

            torch.save(policy.state_dict(), os.path.join(OUT_DIR, f"{model_base}.pth"))
        print(f"\nAll models for {model_cfg['name']} trained and saved in {OUT_DIR}")

## Evaluation: Evaluate all models for all seeds and save metrics

In [None]:
# --- Evaluation setup for all models and all seeds ---
results = []
with open("expert_trajectories.pkl", "rb") as f:
    trajectories = pickle.load(f)
# Infer network input/output sizes from the first recorded step.
obs_example = trajectories[0]['observations'][0]
obs_dim = sum(
    obs_example[part].shape[0]
    for part in ('observation', 'achieved_goal', 'desired_goal')
)
act_dim = trajectories[0]['actions'][0].shape[0]

def evaluate_policy(policy, model_type, episodes=10, max_steps=200, save_traj_path=None):
    """Roll out ``policy`` in a fresh NeedlePick environment and summarise the result.

    Relies on the module-level ``device`` and ``concat_obs`` defined earlier in the
    notebook.

    Args:
        policy: Network to evaluate. When ``model_type`` starts with ``"mlp"`` it is
            called as ``policy(obs)``; otherwise it is called as
            ``policy(obs, hidden)`` and must return ``(action, hidden)``.
        model_type: Model type string; only its ``"mlp"`` prefix is inspected.
        episodes: Number of episodes to run.
        max_steps: Maximum number of environment steps per episode.
        save_traj_path: Optional file path. When given, the list of per-episode
            trajectory dicts (keys ``'obs'``, ``'actions'``, ``'rewards'``,
            ``'infos'``) is pickled there.

    Returns:
        Tuple ``(avg_return, success_rate)``: the mean summed reward per episode and
        the fraction of episodes in which ``info['is_success']`` became truthy.

    Raises:
        ZeroDivisionError: If ``episodes`` is 0.
    """
    is_mlp = model_type.startswith("mlp")
    env = NeedlePick(render_mode=None)
    success_count = 0
    returns = []
    all_trajectories = []
    try:
        for _ in range(episodes):
            obs = env.reset()
            total_reward = 0
            hidden = None  # recurrent state starts fresh every episode
            traj = {'obs': [], 'actions': [], 'rewards': [], 'infos': []}
            for _ in range(max_steps):
                obs_in = concat_obs(obs)
                if is_mlp:
                    # Add a batch dimension for the feed-forward policy.
                    inp = torch.tensor(obs_in, dtype=torch.float32).unsqueeze(0).to(device)
                    with torch.no_grad():
                        action = policy(inp).cpu().numpy().squeeze(0)
                else:
                    # Add two leading singleton dimensions and carry the hidden
                    # state from the previous step.
                    inp = torch.tensor(obs_in, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)
                    with torch.no_grad():
                        action_tensor, hidden = policy(inp, hidden)
                        action = action_tensor.cpu().numpy().squeeze(0).squeeze(0)
                if hasattr(env, 'action_space'):
                    action = np.clip(action, env.action_space.low, env.action_space.high)
                obs_next, reward, done, info = env.step(action)
                total_reward += reward
                traj['obs'].append(obs)
                traj['actions'].append(action)
                traj['rewards'].append(reward)
                traj['infos'].append(info)
                obs = obs_next
                if info.get('is_success', False):
                    success_count += 1
                    break
                if done:
                    break
            returns.append(total_reward)
            all_trajectories.append(traj)
    finally:
        # Always release the environment, even when a rollout raises; a failure
        # inside close() itself is reported but does not mask the evaluation.
        try:
            env.close()
        except Exception as e:
            print(f"Warning: Exception during env.close(): {e}")
    if save_traj_path is not None:
        with open(save_traj_path, "wb") as f:
            pickle.dump(all_trajectories, f)
    avg_return = np.mean(returns)
    success_rate = success_count / episodes
    return avg_return, success_rate

# Evaluate every (model config, seed) checkpoint that exists in OUT_DIR, record its
# metrics in `results`, and persist the two scalar metrics as .npy files.
for model_cfg in BEST_MODELS:
    model_type = model_cfg["type"]
    params = model_cfg["params"]
    for seed in SEED_LIST:
        model_name = f"{model_cfg['name']}_seed{seed}"
        model_path = os.path.join(OUT_DIR, f"{model_name}.pth")
        if not os.path.exists(model_path):
            print(f"Model {model_name} not found in {OUT_DIR}. Skipping...")
            continue
        print(f"\nEvaluating {model_name}")

        # Rebuild the architecture matching the saved weights.
        if model_type.startswith("mlp"):
            network = MLPPolicy(obs_dim, act_dim, hidden_sizes=params["hidden_sizes"])
        else:
            hidden_size = params.get("hidden_size", 128)
            num_layers = params.get("num_layers", 1)
            network = LSTMPolicy(obs_dim, act_dim, hidden_size=hidden_size, num_layers=num_layers)
        policy = network.to(device)
        state_dict = torch.load(model_path, map_location=device)
        policy.load_state_dict(state_dict)
        policy.eval()

        traj_save_path = os.path.join(OUT_DIR, f"{model_name}_eval_traj.pkl")
        avg_return, success_rate = evaluate_policy(
            policy,
            model_type,
            episodes=10,
            max_steps=200,
            save_traj_path=traj_save_path,
        )
        print(f"Result: Success rate: {success_rate*100:.1f}%, Avg return: {avg_return:.2f}")

        # Identification and metrics first, then the hyperparameters of this config.
        record = {
            'model_name': model_name,
            'model_type': model_type,
            'seed': seed,
            'avg_return': avg_return,
            'success_rate': success_rate,
        }
        record.update(params)
        results.append(record)

        for suffix, value in (("eval_success_rate", success_rate), ("eval_return", avg_return)):
            np.save(os.path.join(OUT_DIR, f"{model_name}_{suffix}.npy"), np.array([value]))

# --- Save results to CSV ---
# Rows that lack a hyperparameter column (e.g. MLP rows have no "hidden_size")
# are written with that cell left empty by DictWriter.
csv_path = os.path.join(OUT_DIR, "evaluation_results_best6_multiseed.csv")
id_columns = ["model_name", "model_type", "seed"]
metric_columns = ["avg_return", "success_rate"]
hyperparam_columns = [
    "epochs", "batch_size", "lr", "hidden_sizes", "hidden_size", "num_layers",
    "dagger_iters", "epochs_dagger", "expert_weight",
]
columns = id_columns + metric_columns + hyperparam_columns
with open(csv_path, "w", newline="") as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=columns)
    writer.writeheader()
    for row in results:
        writer.writerow(row)

# --- Print summary across seeds for each model ---
# Bucket the per-seed records under their model name with the "_seed<k>" suffix removed.
agg = defaultdict(list)
for res in results:
    agg[res['model_name'].rsplit("_seed", 1)[0]].append(res)

print("\n===== BEST6 Multiseed Evaluation Results (ALL SEEDS) =====")
for model_base, group in agg.items():
    sr = np.array([r['success_rate'] for r in group])
    ret = np.array([r['avg_return'] for r in group])
    sr_mean_pct = sr.mean()*100
    sr_std_pct = sr.std()*100
    print(f"{model_base}:")
    print(f"    Success Rate: {sr_mean_pct:.2f}% ± {sr_std_pct:.2f}%")
    print(f"    Avg Return: {ret.mean():.2f} ± {ret.std():.2f}")

print("\n===== All Individual Evaluation Results =====")
# Best runs first: highest success rate, ties broken by highest return.
ranked = sorted(results, key=lambda rec: (-rec['success_rate'], -rec['avg_return']))
for res in ranked:
    success_pct = res['success_rate']*100
    print(f"{res['model_name']}: Success={success_pct:.1f}%, Return={res['avg_return']:.2f}")

print(f"\nSaved evaluation results and metrics to folder: {OUT_DIR}")
print("Trajectory for each evaluation is saved as *_eval_traj.pkl")

## Plotting: Visualize results across all models and seeds

In [None]:
# ---- Directory where models/metrics are saved ----
# Alias of OUT_DIR: the plotting cells below load the per-run .npy metrics from this
# folder and write their figures into it.
out_dir = OUT_DIR

In [None]:
# ---- Seeds to collect results for (must match your training/eval code) ----
# Same values as the SEED_LIST defined in the config cell at the top of the notebook.
SEED_LIST = [0, 1, 2, 3, 4]

In [None]:
# ---- Collect results ----
# Build one record per (model config, seed). Each of the four metric arrays is loaded
# from out_dir when its .npy file exists and is stored as None otherwise
# (train/val loss may not be saved for all runs).
results = []
for model_cfg in BEST_MODELS:
    for seed in SEED_LIST:
        model_name = f"{model_cfg['name']}_seed{seed}"
        res = {'model_name': model_name, 'model_type': model_cfg['type'], 'seed': seed}
        res.update(model_cfg['params'])
        for metric in ("train_losses", "val_mse", "eval_success_rate", "eval_return"):
            metric_path = os.path.join(out_dir, f"{model_name}_{metric}.npy")
            res[metric] = np.load(metric_path) if os.path.exists(metric_path) else None
        results.append(res)

In [None]:
# The collection cell appends a placeholder record for every (model, seed) pair, so
# `results` itself is never empty. What matters is whether at least one metric file
# was actually found and loaded.
metric_keys = ("train_losses", "val_mse", "eval_success_rate", "eval_return")
has_any_metric = any(r.get(key) is not None for r in results for key in metric_keys)
if not has_any_metric:
    print("No results found, check that your metrics files exist and names match.")
    # `exit()` is an interactive helper injected by the `site` module and is not
    # guaranteed to exist; raising SystemExit stops execution in any context.
    raise SystemExit

In [None]:
# ---- Group by model type for summary stats ----
# Key = model name with the trailing "_seed<k>" removed, value = that model's per-seed records.
grouped = defaultdict(list)
for r in results:
    base_name = r['model_name'].rsplit("_seed", 1)[0]
    grouped[base_name].append(r)

In [None]:
# ---- Compute mean/std for each model config ----
def _mean_and_std(values):
    """Return (mean, std) of ``values``, or (None, None) when the list is empty."""
    if not values:
        return None, None
    return np.mean(values), np.std(values)


summary = []
for model_base, group in grouped.items():
    n = len(group)  # number of per-seed records in this group
    # Single pass over the seed records, keeping only the metrics that were found on disk.
    success_rates, returns, train_curves, val_curves = [], [], [], []
    for run in group:
        if run['eval_success_rate'] is not None:
            success_rates.append(float(run['eval_success_rate'][0]))
        if run['eval_return'] is not None:
            returns.append(float(run['eval_return'][0]))
        if run['train_losses'] is not None:
            train_curves.append(run['train_losses'])
        if run['val_mse'] is not None:
            val_curves.append(run['val_mse'])
    sr_mean, sr_std = _mean_and_std(success_rates)
    ret_mean, ret_std = _mean_and_std(returns)
    summary.append(dict(
        model_base=model_base,
        model_type=group[0]['model_type'],
        success_rate_mean=sr_mean,
        success_rate_std=sr_std,
        return_mean=ret_mean,
        return_std=ret_std,
        success_rates=success_rates,
        returns=returns,
        train_curves=train_curves,
        val_curves=val_curves,
    ))

In [None]:
# ---- 1. Bar plot: Success Rate for all model configs (mean ± std) ----
def _descending_success_key(entry):
    """Sort key: highest mean success rate first, entries without a mean last."""
    mean = entry['success_rate_mean']
    return 1 if mean is None else -mean


sorted_summary = sorted(summary, key=_descending_success_key)
labels = [s['model_base'] for s in sorted_summary]
means = [s['success_rate_mean'] for s in sorted_summary]
stds = [s['success_rate_std'] for s in sorted_summary]

# The figure grows with the number of model configs so the row labels stay readable.
fig, ax = plt.subplots(figsize=(10, max(6, len(summary)*0.5)))
ax.barh(range(len(means)), means, xerr=stds, color='skyblue')
ax.set_yticks(range(len(labels)))
ax.set_yticklabels(labels, fontsize=9)
ax.set_xlabel("Success Rate (mean ± std, N=5 seeds)")
ax.set_title("Task Success Rate (100 Demonstration Episodes)")
ax.set_xlim([0, 1])
fig.tight_layout()
fig.savefig(os.path.join(out_dir, "best6_multiseed_success_rate.png"))
plt.show()

In [None]:
# ---- 1b. Success-rate bars again, with every seed's value overlaid as a tick mark ----
fig, ax = plt.subplots(figsize=(10, max(6, len(summary)*0.5)))
ax.barh(range(len(means)), means, xerr=stds, color='skyblue', alpha=0.7, label="Mean ± std")
for i, s in enumerate(sorted_summary):
    seed_values = list(s['success_rates'])
    # Only the first row's markers carry a legend label, so the legend lists them once.
    legend_label = "Seed values" if i == 0 else ""
    ax.scatter(seed_values, [i] * len(seed_values), color='k', marker='|', s=100, label=legend_label)
ax.set_yticks(range(len(labels)))
ax.set_yticklabels(labels, fontsize=9)
ax.set_xlabel("Success Rate")
ax.set_title("Task Success Rate (100 Demonstration Episodes)")
ax.set_xlim([0, 1])
ax.legend()
fig.tight_layout()
fig.savefig(os.path.join(out_dir, "best6_multiseed_success_rate_with_seeds.png"))
plt.show()

In [None]:
# ---- 2. Bar plot: Return for all model configs (mean ± std) ----
# Rows keep the ordering and labels of the success-rate plot (sorted_summary / labels).
means = [s['return_mean'] for s in sorted_summary]
stds = [s['return_std'] for s in sorted_summary]
fig, ax = plt.subplots(figsize=(10, max(6, len(summary)*0.5)))
ax.barh(range(len(means)), means, xerr=stds, color='salmon')
ax.set_yticks(range(len(labels)))
ax.set_yticklabels(labels, fontsize=9)
ax.set_xlabel("Episode Return (mean ± std, N=5 seeds)")
ax.set_title("Episode Return (100 Demonstration Episodes)")
fig.tight_layout()
fig.savefig(os.path.join(out_dir, "best6_multiseed_episode_return.png"))
plt.show()

In [None]:
# ---- 2b. Return bars again, with every seed's value overlaid as a tick mark ----
fig, ax = plt.subplots(figsize=(10, max(6, len(summary)*0.5)))
ax.barh(range(len(means)), means, xerr=stds, color='salmon', alpha=0.7, label="Mean ± std")
for i, s in enumerate(sorted_summary):
    seed_values = list(s['returns'])
    # Only the first row's markers carry a legend label, so the legend lists them once.
    legend_label = "Seed values" if i == 0 else ""
    ax.scatter(seed_values, [i] * len(seed_values), color='k', marker='|', s=100, label=legend_label)
ax.set_yticks(range(len(labels)))
ax.set_yticklabels(labels, fontsize=9)
ax.set_xlabel("Episode Return")
ax.set_title("Episode Return (100 Demonstration Episodes)")
ax.legend()
fig.tight_layout()
fig.savefig(os.path.join(out_dir, "best6_multiseed_episode_return_with_seeds.png"))
plt.show()

In [None]:
# List the figure files written into out_dir, one per line.
saved_plot_lines = (
    f"Plotted and saved in {out_dir}:",
    " - Success rate: best6_multiseed_success_rate.png",
    " - Success rate with all seeds: best6_multiseed_success_rate_with_seeds.png",
    " - Episode return: best6_multiseed_episode_return.png",
    " - Episode return with all seeds: best6_multiseed_episode_return_with_seeds.png",
    " - Training loss (best model): *_training_loss_seeds.png",
    " - Validation MSE (best model): *_val_mse_seeds.png",
)
print("\n".join(saved_plot_lines))

In [None]:
def pretty_model_name(model_base):
    """Turn a model base name such as ``"mlp_dagger_best"`` into a plot label.

    Underscores become spaces, the known tokens get their display casing
    (``MLP``, ``LSTM``, ``Dagger``, ``Tuned``), a trailing `` best`` is dropped,
    and every remaining word that is not already all-uppercase is capitalized.
    """
    substitutions = (
        ("_", " "),
        ("dagger", "Dagger"),
        ("mlp", "MLP"),
        ("lstm", "LSTM"),
        ("tuned", "Tuned"),
    )
    name = model_base
    for old, new in substitutions:
        name = name.replace(old, new)
    if name.lower().endswith(" best"):
        name = name[:-5]
    words = []
    for word in name.split():
        # Keep acronyms such as "MLP" intact; capitalize() would lowercase their tail.
        words.append(word if word.isupper() else word.capitalize())
    return " ".join(words)

def _plot_seed_curves(seeded_curves, title, ylabel, save_path):
    """Plot one curve per seed plus the across-seed mean ± std band, then save and show.

    Args:
        seeded_curves: List of ``(seed, curve)`` pairs; ``curve`` is a 1-D sequence.
        title: Figure title.
        ylabel: Label of the y axis.
        save_path: Where the PNG is written.
    """
    plt.figure(figsize=(10, 6))
    for seed, curve in seeded_curves:
        plt.plot(curve, alpha=0.5, label=f"Seed {seed}")
    if len(seeded_curves) > 1:
        # Truncate to the shortest run so curves of unequal length can still be
        # stacked; with equal lengths this is a no-op.
        min_len = min(len(curve) for _, curve in seeded_curves)
        stacked = np.array([np.asarray(curve)[:min_len] for _, curve in seeded_curves])
        curve_mean = stacked.mean(axis=0)
        curve_std = stacked.std(axis=0)
        plt.plot(curve_mean, color='k', label="Mean", linewidth=2)
        plt.fill_between(range(len(curve_mean)), curve_mean - curve_std, curve_mean + curve_std,
                         color='k', alpha=0.2, label="Mean ± std")
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.legend(fontsize=9)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.show()


for s in sorted_summary:
    model_label = pretty_model_name(s['model_base'])
    # The summary entries keep only the curves, not which seed produced them, so the
    # (seed, curve) pairs are rebuilt from the per-run records. This keeps the legend
    # correct when some seed has no saved curve.
    runs = grouped[s['model_base']]
    train_runs = [(g['seed'], g['train_losses']) for g in runs if g['train_losses'] is not None]
    val_runs = [(g['seed'], g['val_mse']) for g in runs if g['val_mse'] is not None]
    if train_runs:
        _plot_seed_curves(
            train_runs,
            f"Training Loss (MSE) for {model_label} (100 Demonstration Episodes)",
            "Loss",
            os.path.join(out_dir, f"{s['model_base']}_training_loss_seeds.png"),
        )
    if val_runs:
        _plot_seed_curves(
            val_runs,
            f"Validation MSE for {model_label} (100 Demonstration Episodes)",
            "MSE",
            os.path.join(out_dir, f"{s['model_base']}_val_mse_seeds.png"),
        )

In [None]:
# Success-rate bar plot with one coloured dot per distinct per-seed value; the seed
# numbers that produced each value are printed just below the dot.
SEED_LIST = [0, 1, 2, 3, 4]
seed_colors = [
    "#7eb6f6",   # Seed 0 - light blue
    "#f3c37c",   # Seed 1 - light orange
    "#a6d9a7",   # Seed 2 - light green
    "#ea9498",   # Seed 3 - pink
    "#b2abd2",   # Seed 4 - light purple
]
# Colour lookup keyed by the actual seed number, so a seed keeps its colour even when
# another seed's metrics are missing; the palette wraps if there are more seeds than colours.
seed_to_color = {seed: seed_colors[k % len(seed_colors)] for k, seed in enumerate(SEED_LIST)}
labels = [s['model_base'] for s in sorted_summary]
means = [s['success_rate_mean'] for s in sorted_summary]
stds = [s['success_rate_std'] for s in sorted_summary]
plt.figure(figsize=(10, max(6, len(labels)*0.5)))
plt.barh(range(len(means)), means, xerr=stds, color='skyblue', alpha=0.7, label="Mean ± std")
for i, s in enumerate(sorted_summary):
    # Map each distinct value to the seeds that reached it, using the seed stored in
    # the per-run record rather than the position in the filtered value list.
    val_to_seeds = defaultdict(list)
    for run in grouped[s['model_base']]:
        if run['eval_success_rate'] is not None:
            val_to_seeds[float(run['eval_success_rate'][0])].append(run['seed'])
    for val, seed_nums in val_to_seeds.items():
        # Seeds sharing a value are drawn as one dot in the first seed's colour.
        dot_color = seed_to_color.get(seed_nums[0], 'k')
        plt.scatter(val, i, color=dot_color, s=100, marker='o', zorder=5)
        label_text = ",".join(str(num) for num in seed_nums)
        plt.text(val, i - 0.17, label_text, color=dot_color, ha='center', va='top',
                 fontsize=10, fontweight='bold', zorder=6)
plt.yticks(range(len(labels)), labels, fontsize=9)
plt.xlabel("Success Rate")
plt.title("Task Success Rate (100 Demonstration Episodes)")
legend_elements = [
    Line2D([0], [0], marker='o', color='w', label=f'Seed {seed}',
           markerfacecolor=seed_to_color[seed], markersize=10)
    for seed in SEED_LIST
]
plt.legend(handles=legend_elements + [Line2D([0], [0], color='skyblue', lw=10, label="Mean ± std")], fontsize=10)
plt.xlim([0, 1])
plt.tight_layout()
plt.savefig(os.path.join(out_dir, "best6_multiseed_success_rate_inline_seed_numbers.png"))
plt.show()

In [None]:
# Episode-return bar plot with one coloured dot per distinct per-seed value; the seed
# numbers that produced each value are printed just below the dot.
means = [s['return_mean'] for s in sorted_summary]
stds = [s['return_std'] for s in sorted_summary]
# Colour lookup keyed by the actual seed number, so a seed keeps its colour even when
# another seed's metrics are missing; the palette wraps if there are more seeds than colours.
seed_to_color = {seed: seed_colors[k % len(seed_colors)] for k, seed in enumerate(SEED_LIST)}
plt.figure(figsize=(10, max(6, len(labels)*0.5)))
plt.barh(range(len(means)), means, xerr=stds, color='salmon', alpha=0.7, label="Mean ± std")
for i, s in enumerate(sorted_summary):
    # Map each distinct value to the seeds that reached it, using the seed stored in
    # the per-run record rather than the position in the filtered value list.
    val_to_seeds = defaultdict(list)
    for run in grouped[s['model_base']]:
        if run['eval_return'] is not None:
            val_to_seeds[float(run['eval_return'][0])].append(run['seed'])
    for val, seed_nums in val_to_seeds.items():
        # Seeds sharing a value are drawn as one dot in the first seed's colour.
        dot_color = seed_to_color.get(seed_nums[0], 'k')
        plt.scatter(val, i, color=dot_color, s=100, marker='o', zorder=5)
        label_text = ",".join(str(num) for num in seed_nums)
        plt.text(val, i - 0.17, label_text, color=dot_color, ha='center', va='top',
                 fontsize=10, fontweight='bold', zorder=6)
plt.yticks(range(len(labels)), labels, fontsize=9)
plt.xlabel("Episode Return")
plt.title("Episode Return (100 Demonstration Episodes)")
legend_elements = [
    Line2D([0], [0], marker='o', color='w', label=f'Seed {seed}',
           markerfacecolor=seed_to_color[seed], markersize=10)
    for seed in SEED_LIST
]
plt.legend(handles=legend_elements + [Line2D([0], [0], color='salmon', lw=10, label="Mean ± std")], fontsize=10)
plt.tight_layout()
plt.savefig(os.path.join(out_dir, "best6_multiseed_episode_return_inline_seed_numbers.png"))
plt.show()