In [40]:
import minari
import os
import h5py
import numpy as np
import torch
import matplotlib.pyplot as plt
import gymnasium as gym
import shimmy
from custom_dmc_tasks import point_mass_maze
import ot
from tqdm import tqdm

from stable_baselines3 import TD3
from stable_baselines3.common.vec_env import DummyVecEnv, VecMonitor

from src.model.base_models.unet_mlp import TDFlowUnet
from src.model.flow_matching import ConditionalFlowMatching
# os.environ['MUJOCO_GL'] = 'glfw'

In [41]:
task = 'reach_top_left'
epoch = 95

model_path = f"td3_point_mass_expert_{task}"

td3 = TD3.load(model_path)


In [42]:
fm = ConditionalFlowMatching(TDFlowUnet(), obs_dim=(4,))

In [43]:
state_dict = torch.load(f"checkpoints/td2_cfm_model_{task}_epoch_{epoch}.pth", map_location='cpu')

In [44]:
fm.model.load_state_dict(state_dict)

<All keys matched successfully>

In [45]:
def make_env():
    raw_env = point_mass_maze.make(task=task)
    return shimmy.DmControlCompatibilityV0(raw_env)

vec_env = DummyVecEnv([make_env for _ in range(1)])
vec_env = VecMonitor(vec_env)

In [46]:
from collections import OrderedDict
def force_set_state(vec_env, target_state):
    """
    target_state: np.array([x, y, vx, vy])
    """
    new_obs_list = []
    
    for e in vec_env.envs:
        dmc_env = e.unwrapped._env 
        dmc_env.task.set_state(dmc_env.physics, target_state)
        obs_dict = dmc_env.task.get_observation(dmc_env.physics)
        flat_obs = np.concatenate([obs_dict['position'], obs_dict['velocity']])
        new_obs_list.append(flat_obs)
    new_obs_list = np.array(new_obs_list)
    return OrderedDict({'position':new_obs_list[:, :2], 'velocity': new_obs_list[:, 2:]})

def get_true_physics_state(vec_env, env_idx=0):
    physics = vec_env.envs[env_idx].unwrapped.physics
    
    true_pos = physics.data.qpos[:2].copy()
    true_vel = physics.data.qvel[:2].copy()
    
    return OrderedDict({'position':true_pos, 'velocity': true_vel})

# obs = force_set_state(vec_env, target_state=np.array([-0.22865181,  0.23067307, 0.0, 0.0]))

In [47]:
states = []
actions = []
trajectories = []
rewards = []
stopping_times = []
initial_cond = []
for n in tqdm(range(64)):
    x = np.random.uniform(-0.29, -0.15)
    y = np.random.uniform(0.15, 0.29)
    obs = force_set_state(vec_env, target_state=np.array([x,  y, 0.0, 0.0]))
    state = [np.concat([obs['position'], obs['velocity']], axis=-1)] # set initial state
    actions = [np.array([[0.0, 0.0]])] # set initial action
    obs, r, terminated, _ = vec_env.step(actions[0])
    state.append(np.concat([obs['position'], obs['velocity']], axis=-1))
    reward = [r]
    cond = np.concatenate([np.array([[x,  y, 0.0, 0.0]]), actions[0]], axis=-1)
    initial_cond.append(cond)
    for t in range(1000): 
        # frame = raw_env.physics.render(height=480, width=480, camera_id=0)
        # frames.append(frame)
        
        action, _ = td3.predict(obs, deterministic=True)
        obs, r, terminated, _ = vec_env.step(action)
        reward.append(r)
        if terminated:
            state.append(np.concat([obs['position'], obs['velocity']], axis=-1))
            break
        state.append(np.concat([obs['position'], obs['velocity']], axis=-1))
    
    T = np.minimum(np.random.geometric(1-0.99, 2048), len(state)-1)
    states.append(np.array(state)[T].squeeze())
    trajectories.append(state)
    rewards.append(reward)
    stopping_times.append(T)

100%|██████████| 64/64 [00:06<00:00,  9.40it/s]


In [48]:
initial_cond = torch.from_numpy(np.array(initial_cond)).squeeze().to(torch.float32)

In [49]:
initial_cond = initial_cond[:,None, :].repeat(1, 2048, 1)

In [50]:
t_batch = torch.ones(2048)
gen_samples = []
for i in tqdm(range(64)):
    gen_samples.append(
    fm.sample(2048, initial_cond[i], t_batch)
    )

100%|██████████| 64/64 [00:32<00:00,  1.96it/s]


In [51]:
# EMD 
all_emds = []
for i in tqdm(range(64)):
    X_true = states[i] # sT
    X_pred = gen_samples[i].numpy() # samples from m
    M = ot.dist(X_true, X_pred, metric='euclidean')
    a, b = np.ones((2048,)) / 2048, np.ones((2048,)) / 2048
    value = ot.emd2(a, b, M)
    all_emds.append(value)

final_score = np.mean(all_emds)

100%|██████████| 64/64 [00:12<00:00,  4.99it/s]


In [52]:
final_score

np.float64(0.029807528934109047)

In [55]:
torch.tensor(0.).item()

0.0