# BC LSTM Policy: Training, Evaluation, and Plotting Pipeline (100 Episodes)
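This notebook trains behavior-cloning (BC) LSTM policies on expert trajectories for the SurRoL `NeedlePick` task, evaluates them online in the environment, and plots the training and evaluation metrics.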

In [None]:
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
import csv
import matplotlib.pyplot as plt
from surrol.tasks.needle_pick import NeedlePick

In [None]:
# --- Seed every random number generator used below so runs are reproducible ---
SEED = 42
# Python's `random`, NumPy's global RNG, PyTorch's CPU RNG, and the current CUDA device.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(SEED)
del _seed_fn
# Give up cuDNN autotuning in exchange for deterministic kernel selection.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
def concat_obs(obs):
    """Flatten a goal-conditioned observation dict into a single array.

    The parts are joined with ``np.concatenate`` in a fixed order:
    ``obs['observation']``, then ``obs['achieved_goal']``, then
    ``obs['desired_goal']``.
    """
    ordered_keys = ('observation', 'achieved_goal', 'desired_goal')
    return np.concatenate([obs[key] for key in ordered_keys])

In [None]:
# --- Recurrent policy network ---
class LSTMPolicy(nn.Module):
    """Behavior-cloning policy: an LSTM encoder followed by a linear action head.

    Args:
        obs_dim: size of each (flattened) observation vector.
        act_dim: size of each predicted action vector.
        hidden_size: number of LSTM hidden units.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self, obs_dim, act_dim, hidden_size=128, num_layers=1):
        super().__init__()
        # The attribute names `lstm` and `fc` determine the state_dict keys
        # that saved .pth checkpoints are loaded against.
        self.lstm = nn.LSTM(
            input_size=obs_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=act_dim)

    def forward(self, x, hidden=None):
        """Predict an action at every time step of ``x``.

        Args:
            x: tensor laid out (batch, seq_len, obs_dim), since the LSTM is batch_first.
            hidden: optional ``(h, c)`` state returned by a previous call, used to
                continue a sequence step by step; ``None`` starts from zeros.

        Returns:
            ``(actions, hidden)``: ``actions`` has shape (batch, seq_len, act_dim),
            and ``hidden`` is the LSTM's final ``(h, c)`` state.
        """
        features, next_hidden = self.lstm(x, hidden)
        return self.fc(features), next_hidden

In [None]:
# Prefer Apple-silicon MPS, then CUDA, and fall back to the CPU.
if torch.backends.mps.is_available():
    _device_name = 'mps'
elif torch.cuda.is_available():
    _device_name = 'cuda'
else:
    _device_name = 'cpu'
device = torch.device(_device_name)
print("Selected device:", device)

In [None]:
# --- Build one (observations, actions) array pair per expert episode ---
def make_lstm_episode_dataset(trajectories):
    """Convert expert trajectories into per-episode observation/action arrays.

    Args:
        trajectories: iterable of episode dicts. Each dict has an
            ``'observations'`` sequence of observation dicts (flattened with
            ``concat_obs``) and an ``'actions'`` sequence.

    Returns:
        ``(obs_episodes, act_episodes)``: two parallel lists with one entry per
        episode. Each entry is built with ``np.array`` from that episode's
        flattened observations and from its actions, respectively.
    """
    obs_episodes, act_episodes = [], []
    for episode in trajectories:
        obs_episodes.append(np.array([concat_obs(step_obs) for step_obs in episode['observations']]))
        act_episodes.append(np.array(episode['actions']))
    return obs_episodes, act_episodes

In [None]:
# --- Training function on full episodes ---
def _episodes_to_tensors(episodes):
    """Convert each 2-D (seq_len, dim) array into a (1, seq_len, dim) float32 tensor on `device`."""
    return [torch.tensor(ep, dtype=torch.float32).unsqueeze(0).to(device) for ep in episodes]


def _mean_step_mse(policy, loss_fn, obs_tensors, act_tensors):
    """Return `loss_fn` over the given episodes, weighted by episode length (no gradients)."""
    policy.eval()
    total_loss = 0
    total_steps = 0
    with torch.no_grad():
        for obs_t, act_t in zip(obs_tensors, act_tensors):
            act_pred, _ = policy(obs_t)
            seq_len = obs_t.shape[1]
            total_loss += loss_fn(act_pred, act_t).item() * seq_len
            total_steps += seq_len
    return total_loss / total_steps if total_steps > 0 else 0


def train_lstm_full_episodes(obs_episodes, act_episodes, epochs=50, lr=1e-3, hidden_size=128, num_layers=1, val_split=0.1):
    """Train an ``LSTMPolicy`` by behavior cloning, one full episode per gradient step.

    The first ``int(len(obs_episodes) * val_split)`` episodes are held out for
    validation. They are taken without shuffling, so every call on the same data
    (e.g. every grid-search config) is validated on the same episodes. The
    remaining episodes are visited in a new random order (``np.random.shuffle``)
    each epoch.

    Args:
        obs_episodes: list of 2-D (seq_len, obs_dim) observation arrays, one per episode.
        act_episodes: list of matching 2-D (seq_len, act_dim) action arrays.
        epochs: number of passes over the training episodes.
        lr: Adam learning rate.
        hidden_size: LSTM hidden units.
        num_layers: number of stacked LSTM layers.
        val_split: fraction of episodes held out for validation.

    Returns:
        ``(policy, losses, val_mse)``:
        - ``policy``: the trained policy.
        - ``losses``: per-epoch training MSE, weighted by episode length.
        - ``val_mse``: per-epoch validation MSE; empty when no episodes are held out.

    Raises:
        ValueError: if there are no episodes, the two lists differ in length,
            or ``val_split`` leaves no episodes for training.
    """
    if not obs_episodes:
        raise ValueError("obs_episodes is empty; nothing to train on")
    if len(obs_episodes) != len(act_episodes):
        raise ValueError(
            f"obs_episodes ({len(obs_episodes)}) and act_episodes ({len(act_episodes)}) differ in length"
        )
    n_val = max(int(len(obs_episodes) * val_split), 0)
    if n_val >= len(obs_episodes):
        raise ValueError(f"val_split={val_split} leaves no episodes for training")

    obs_dim = obs_episodes[0].shape[1]
    act_dim = act_episodes[0].shape[1]
    policy = LSTMPolicy(obs_dim, act_dim, hidden_size=hidden_size, num_layers=num_layers).to(device)
    optimizer = optim.Adam(policy.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    # Convert to device tensors once, instead of re-converting and re-copying every epoch.
    # NOTE: this keeps every episode resident on `device` for the whole run.
    train_obs = _episodes_to_tensors(obs_episodes[n_val:])
    train_act = _episodes_to_tensors(act_episodes[n_val:])
    val_obs = _episodes_to_tensors(obs_episodes[:n_val])
    val_act = _episodes_to_tensors(act_episodes[:n_val])

    losses = []
    val_mse = []
    for epoch in range(epochs):
        policy.train()
        epoch_loss = 0
        total_steps = 0
        indices = np.arange(len(train_obs))
        np.random.shuffle(indices)
        for idx in indices:
            obs_seq = train_obs[idx]  # (1, seq_len, obs_dim)
            act_seq = train_act[idx]  # (1, seq_len, act_dim)
            optimizer.zero_grad()
            act_pred, _ = policy(obs_seq)
            # MSELoss averages over every element of the episode; multiplying by
            # seq_len below makes the epoch average a per-time-step average.
            loss = loss_fn(act_pred, act_seq)
            loss.backward()
            optimizer.step()
            seq_len = obs_seq.shape[1]
            epoch_loss += loss.item() * seq_len
            total_steps += seq_len
        avg_loss = epoch_loss / total_steps if total_steps > 0 else 0
        losses.append(avg_loss)

        if val_obs:
            val_mse.append(_mean_step_mse(policy, loss_fn, val_obs, val_act))

        if (epoch + 1) % 10 == 0 or epoch == 0 or (epoch + 1) == epochs:
            if val_mse:
                print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.6f}, Val: {val_mse[-1]:.6f}")
            else:
                print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.6f}")
    return policy, losses, val_mse

In [None]:
# --- Evaluate LSTM policy step-by-step ---
def evaluate_lstm_policy_sequential(policy, episodes=10, max_steps=200):
    """Roll out ``policy`` online in ``NeedlePick``, feeding one observation per step.

    The LSTM hidden state is carried from step to step within an episode and
    reset at the start of each episode. An episode ends when
    ``info['is_success']`` is truthy, when ``done`` is set, or after
    ``max_steps`` steps. A fresh environment is created for each episode and is
    always closed, even if a step raises.

    Args:
        policy: an ``LSTMPolicy`` already moved to ``device``.
        episodes: number of evaluation episodes; must be positive.
        max_steps: per-episode step limit.

    Returns:
        ``(avg_return, success_rate)``:
        - ``avg_return``: mean undiscounted reward sum over episodes.
        - ``success_rate``: fraction of episodes that reported success.

    Raises:
        ValueError: if ``episodes`` is not positive.
    """
    if episodes <= 0:
        raise ValueError(f"episodes must be positive, got {episodes}")
    policy.eval()  # inference mode for the whole rollout

    success_count = 0
    returns = []
    for _ in range(episodes):
        env = NeedlePick(render_mode=None)
        try:
            obs = env.reset()
            hidden = None  # fresh recurrent state per episode
            total_reward = 0
            success = False
            for _ in range(max_steps):
                inp = torch.tensor(concat_obs(obs), dtype=torch.float32).view(1, 1, -1).to(device)  # (1, 1, obs_dim)
                with torch.no_grad():
                    out, hidden = policy(inp, hidden)
                action = np.clip(out.cpu().numpy()[0, 0], env.action_space.low, env.action_space.high)
                obs, reward, done, info = env.step(action)
                total_reward += reward
                if info.get('is_success', False):
                    success = True
                    break
                if done:
                    break
        finally:
            # Release simulator resources even if reset/step raised.
            env.close()
        returns.append(total_reward)
        if success:
            success_count += 1
    avg_return = np.mean(returns)
    success_rate = success_count / episodes
    return avg_return, success_rate

# 1. Training: Hyperparameter Grid Search
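Train one LSTM policy per hyperparameter combination (learning rate, hidden size, epochs, number of layers), and save each model together with its training-loss and validation-MSE curves.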

In [None]:
# --- Main pipeline ---
from itertools import product

out_dir = "lstm_bc_models"
os.makedirs(out_dir, exist_ok=True)

# Load expert data.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted, locally generated files.
with open("expert_trajectories.pkl", "rb") as f:
    trajectories = pickle.load(f)

# Prepare full-episode LSTM dataset
obs_episodes, act_episodes = make_lstm_episode_dataset(trajectories)
print(f"Collected {len(obs_episodes)} LSTM episodes for training (full-episode).")
if not obs_episodes:
    raise ValueError("No episodes found in expert_trajectories.pkl; cannot train.")

obs_dim = obs_episodes[0].shape[1]
act_dim = act_episodes[0].shape[1]

# Define hyperparameter grid
learning_rates = [1e-3, 3e-4, 1e-4]
hidden_sizes = [128, 256]
epochs_list = [75, 150]
num_layers_list = [1, 2]

# One entry per trained configuration; consumed by the evaluation cell.
results = []

# Grid search (batch_size dropped since each episode is a batch).
# product() yields configs in the same order as four nested loops would.
for lr, hidden_size, epochs, num_layer in product(learning_rates, hidden_sizes, epochs_list, num_layers_list):
    # Re-seed before every config so its weight init and episode shuffling do not
    # depend on how many random numbers earlier configs consumed; any single
    # config can then be re-run in isolation from the same random state.
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)

    print("\n========================================")
    print(f"Training LSTM (full-episode): lr={lr}, hidden={hidden_size}, epochs={epochs}, layers={num_layer}")
    policy, losses, val_mse = train_lstm_full_episodes(
        obs_episodes, act_episodes, epochs=epochs,
        lr=lr, hidden_size=hidden_size, num_layers=num_layer
    )
    # Save model and metrics
    model_name = f"lstm_bc_lr{lr}_hid{hidden_size}_ep{epochs}_lay{num_layer}_fullseq"
    torch.save(policy.state_dict(), os.path.join(out_dir, f"{model_name}.pth"))
    np.save(os.path.join(out_dir, f"{model_name}_train_losses.npy"), np.array(losses))
    np.save(os.path.join(out_dir, f"{model_name}_val_mse.npy"), np.array(val_mse))
    print(f"Saved model and logs to {out_dir}: {model_name}")
    # Store final val MSE for comparison (None when no validation split was used)
    final_val = val_mse[-1] if val_mse else None
    results.append({
        'lr': lr,
        'hidden_size': hidden_size,
        'epochs': epochs,
        'num_layers': num_layer,
        'final_val_mse': final_val,
        'model_name': model_name
    })

# 2. Evaluation: Policy Performance
Roll out each trained policy in the `NeedlePick` environment, record its success rate and average return, and save the results.

In [None]:
# --- Evaluate every trained model online and persist its metrics ---
EVAL_EPISODES = 10
MAX_STEPS = 200
eval_results = []


def _load_trained_policy(config, weights_path):
    """Rebuild an LSTMPolicy with `config`'s architecture, load its saved weights, and switch to eval mode."""
    net = LSTMPolicy(obs_dim, act_dim, hidden_size=config['hidden_size'], num_layers=config['num_layers']).to(device)
    net.load_state_dict(torch.load(weights_path, map_location=device))
    net.eval()
    return net


for res in results:
    model_name = res['model_name']
    model_path = os.path.join(out_dir, f"{model_name}.pth")
    if os.path.exists(model_path):
        print(f"\nEvaluating {model_name} (sequential LSTM)")
        policy = _load_trained_policy(res, model_path)
        avg_return, success_rate = evaluate_lstm_policy_sequential(
            policy, episodes=EVAL_EPISODES, max_steps=MAX_STEPS)
        print(f"Result: Success rate: {success_rate*100:.1f}%, Avg return: {avg_return:.2f}")
        # Training config keys first, then the online metrics.
        eval_results.append(dict(res, avg_return=avg_return, success_rate=success_rate))
        for metric_suffix, metric_value in (("eval_success_rate", success_rate), ("eval_return", avg_return)):
            np.save(os.path.join(out_dir, f"{model_name}_{metric_suffix}.npy"), np.array([metric_value]))
    else:
        print(f"Model {model_name} not found in {out_dir}. Skipping...")

In [None]:
# --- Write the evaluation table (one row per evaluated model) to CSV ---
_csv_columns = [
    "model_name", "lr", "hidden_size", "epochs", "num_layers",
    "final_val_mse", "avg_return", "success_rate",
]
csv_path = os.path.join(out_dir, "evaluation_results.csv")
with open(csv_path, "w", newline="") as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=_csv_columns)
    writer.writeheader()
    for row in eval_results:
        writer.writerow(row)

In [None]:
# --- Leaderboard: highest success rate first, ties broken by higher average return ---
def _leaderboard_key(row):
    """Sort key putting higher success rate, then higher average return, first."""
    return (-row['success_rate'], -row['avg_return'])


print("\n===== LSTM Grid Evaluation Results (Sequential, Step-by-Step) =====")
for row in sorted(eval_results, key=_leaderboard_key):
    success_pct = row['success_rate'] * 100
    print(f"{row['model_name']}: Success={success_pct:.1f}%, Return={row['avg_return']:.2f}")
print(f"\nAll models and metrics are saved in the folder: {out_dir}")

# 3. Plotting: Visualize Training and Evaluation Metrics
Plot training loss and validation MSE for the best configurations (by validation MSE, by success rate, and by return), and plot success rate and episode return for all hyperparameter configurations.

In [None]:
# ---- Plotting ----
def _best_by(rows, key, *, maximize):
    """Return the first row with the best non-None ``row[key]``, or None if every value is None."""
    candidates = [r for r in rows if r[key] is not None]
    if not candidates:
        return None
    choose = max if maximize else min
    return choose(candidates, key=lambda r: r[key])


def _plot_metric_curves(file_suffix, title, ylabel, png_name, labelled_rows):
    """Overlay the saved per-epoch ``<model_name><file_suffix>`` curves for each (label, row), then save and show."""
    plt.figure(figsize=(10, 6))
    for label, row in labelled_rows:
        curve = np.load(os.path.join(out_dir, f"{row['model_name']}{file_suffix}"))
        plt.plot(curve, label=f"{label} ({row['model_name']})")
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.legend(fontsize=9)
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, png_name))
    plt.show()


def _plot_ranked_bars(values, labels, xlabel, title, color, png_name, xlim=None):
    """Draw a horizontal bar chart of (value, label) pairs sorted in descending order, then save and show."""
    ranked_vals, ranked_labels = zip(*sorted(zip(values, labels), reverse=True))
    plt.figure(figsize=(10, max(6, len(labels) * 0.3)))
    plt.barh(range(len(ranked_vals)), ranked_vals, color=color)
    plt.yticks(range(len(ranked_labels)), ranked_labels, fontsize=7)
    plt.xlabel(xlabel)
    plt.title(title)
    if xlim is not None:
        plt.xlim(xlim)
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, png_name))
    plt.show()


if not eval_results:
    print("No results found, check that your metrics files exist and names match.")
else:
    # Rows trained without a validation split have final_val_mse=None. Skip them
    # when ranking by validation MSE: comparing None with floats raises TypeError.
    best_val = _best_by(eval_results, 'final_val_mse', maximize=False)
    if best_val is None:
        best_val = eval_results[0]
    best_success = _best_by(eval_results, 'success_rate', maximize=True)
    if best_success is None:
        best_success = best_val
    best_return = _best_by(eval_results, 'avg_return', maximize=True)
    if best_return is None:
        best_return = best_val

    print("\nBest by val MSE:", best_val['model_name'])
    print("Best by online success:", best_success['model_name'])
    print("Best by online return:", best_return['model_name'])

    models = [r['model_name'] for r in eval_results]
    successes = [r['success_rate'] if r['success_rate'] is not None else 0.0 for r in eval_results]
    returns = [r['avg_return'] if r['avg_return'] is not None else 0.0 for r in eval_results]
    best_configs = [("Best Val MSE", best_val), ("Best Success", best_success), ("Best Return", best_return)]

    # 1. Training Loss (MSE) for best configs
    _plot_metric_curves("_train_losses.npy", "LSTM Training Loss (MSE) for Best Configs (100 Episodes)",
                        "Loss", "lstm_hyperparam_training_loss.png", best_configs)

    # 2. Validation MSE for best configs
    _plot_metric_curves("_val_mse.npy", "LSTM Validation MSE for Best Configs (100 Episodes)",
                        "MSE", "lstm_hyperparam_val_mse.png", best_configs)

    # 3. Task Success Rate (All Hyperparams)
    _plot_ranked_bars(successes, models, "Success Rate",
                      "LSTM Task Success Rate (All Hyperparams) (100 Episodes)",
                      'skyblue', "lstm_hyperparam_success_rate.png", xlim=[0, 1])

    # 4. Episode Return (All Hyperparams)
    _plot_ranked_bars(returns, models, "Episode Return",
                      "LSTM Episode Return (All Hyperparams)(100 Episodes)",
                      'salmon', "lstm_hyperparam_episode_return.png")

    print(
        "Plotted and saved in {}:\n - Training loss: lstm_hyperparam_training_loss.png\n - Validation MSE: lstm_hyperparam_val_mse.png\n - Success rate: lstm_hyperparam_success_rate.png\n - Episode return: lstm_hyperparam_episode_return.png".format(out_dir)
    )