In [5]:
import torch
import numpy as np
from surrol.tasks.needle_pick import NeedlePick
from best_seed_train import MLPPolicy, LSTMPolicy  # Adjust if your policy code is elsewhere


In [6]:

# --- CONFIG ---
model_path = "best6_multiseed/lstm_best_seed0.pth"   # Path to your .pth file
model_type = "lstm"  # "mlp" or "lstm"
render = True        # Set True to visualize, False for headless
episode_seed = 6    # Use None for random, or set for reproducibility
max_steps = 200


In [7]:

# --- Observation/Action dims ---
# Load one trajectory to get obs/act dims
import pickle, os
with open(os.path.join("expert_trajectories.pkl"), "rb") as f:
    expert_trajs = pickle.load(f)
obs_example = expert_trajs[0]['observations'][0]
obs_dim = (
    obs_example['observation'].shape[0]
    + obs_example['achieved_goal'].shape[0]
    + obs_example['desired_goal'].shape[0]
)
act_dim = expert_trajs[0]['actions'][0].shape[0]


In [8]:
 
# --- Build policy and load weights ---
device = torch. device('mps' if torch.backends.mps.is_available() else 'cpu')
print("Selected device:", device)

if model_type == "mlp":
    policy = MLPPolicy(obs_dim, act_dim, hidden_sizes=(256, 256)).to(device)
elif model_type == "lstm":
    policy = LSTMPolicy(obs_dim, act_dim, hidden_size=256, num_layers=2).to(device)
else:
    raise ValueError("model_type must be 'mlp' or 'lstm'")

policy.load_state_dict(torch.load(model_path, map_location=device))
policy.eval()

Selected device: mps


RuntimeError: Error(s) in loading state_dict for LSTMPolicy:
	Missing key(s) in state_dict: "lstm.weight_ih_l1", "lstm.weight_hh_l1", "lstm.bias_ih_l1", "lstm.bias_hh_l1". 
	size mismatch for lstm.weight_ih_l0: copying a param with shape torch.Size([512, 25]) from checkpoint, the shape in current model is torch.Size([1024, 25]).
	size mismatch for lstm.weight_hh_l0: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([1024, 256]).
	size mismatch for lstm.bias_ih_l0: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for lstm.bias_hh_l0: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for fc.weight: copying a param with shape torch.Size([5, 128]) from checkpoint, the shape in current model is torch.Size([5, 256]).

In [None]:

# --- Environment Setup ---
if episode_seed is not None:
    np.random.seed(episode_seed)
env = NeedlePick(render_mode="human" if render else None)
obs = env.reset()

def concat_obs(obs):
    return np.concatenate([obs['observation'], obs['achieved_goal'], obs['desired_goal']])

total_reward = 0
success = False
hidden = None

for step in range(max_steps):
    obs_in = concat_obs(obs)
    if model_type == "mlp":
        inp = torch.tensor(obs_in, dtype=torch.float32).unsqueeze(0).to(device)
        with torch.no_grad():
            action = policy(inp).cpu().numpy().squeeze(0)
    else:
        inp = torch.tensor(obs_in, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)
        with torch.no_grad():
            action_tensor, hidden = policy(inp, hidden)
            action = action_tensor.cpu().numpy().squeeze(0).squeeze(0)
    # Optional: clip to action space if needed
    if hasattr(env, 'action_space'):
        action = np.clip(action, env.action_space.low, env.action_space.high)
    obs, reward, done, info = env.step(action)
    total_reward += reward
    if info.get('is_success', False):
        print(f"Success at step {step}")
        success = True
        break
    if done:
        break

print(f"Episode finished. Total reward: {total_reward:.2f}, Success: {success}")
env.close()

Version = 4.1 Metal - 89.3
Vendor = Apple
Renderer = Apple M2
b3Printf: Selected demo: Physics Server
startThreads creating 1 threads.
starting thread 0
started thread 0 
MotionThreadFunc thread started


2025-06-29 20:35:51.878 python[35893:15022927] +[IMKClient subclass]: chose IMKClient_Modern
2025-06-29 20:35:51.878 python[35893:15022927] +[IMKInputSession subclass]: chose IMKInputSession_Modern


Success at step 37
Episode finished. Total reward: -37.00, Success: True
numActiveThreads = 0
stopping threads
Thread with taskId 0 exiting
Thread TERMINATED
destroy semaphore
semaphore destroyed
destroy main semaphore
main semaphore destroyed


: 