In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import pickle
import numpy as np
import pkg_resources
from pathlib import Path
from tqdm import trange

from flygym.envs.nmf_mujoco import NeuroMechFlyMuJoCo
from flygym.util.config import leg_dofs_fused_tarsi

In [81]:
# Load the recorded joint-angle data that ships inside the flygym package.
# NOTE(review): pickle.load can execute arbitrary code; this is acceptable only
# because the file comes from the installed package, not from user input.
data_path = Path(pkg_resources.resource_filename('flygym', 'data'))
behavior_file = data_path / 'behavior' / '210902_pr_fly1.pkl'
with behavior_file.open('rb') as fh:
    data = pickle.load(fh)

In [None]:
# Environment for the step-by-step kinematic-replay debugging run below.
# run_time / timestep gives the number of simulation steps.
out_dir = Path('kin_replay')
run_time = 1
nmf = NeuroMechFlyMuJoCo(
    render_mode='saved',
    output_dir=out_dir,
    timestep=5e-4,
    actuated_joints=leg_dofs_fused_tarsi,
    init_pose='stretch',
)

In [None]:
import matplotlib.pyplot as plt

# Quick look at one recorded trajectory (left-front coxa) before replaying it.
# NOTE(review): this uses the un-prefixed key 'LFCoxa', while the final replay
# cell reads data['joint_LFCoxa'] — confirm which key format the pickle has.
lf_coxa_trace = data['LFCoxa']
plt.plot(lf_coxa_trace)

In [None]:
# Debug replay: feed one recorded sample per simulation step and record both
# the commanded joint positions (in_list) and the returned observations
# (obs_list). Recorded keys are looked up without the 'joint_' prefix.
# NOTE(review): the final replay cell indexes `data` with the full joint name —
# confirm which key format the current pickle uses.
MAX_DEBUG_STEPS = 10  # stop early while debugging; set to None for a full run
n_replay_steps = int(run_time / nmf.timestep)
if MAX_DEBUG_STEPS is not None:
    n_replay_steps = min(n_replay_steps, MAX_DEBUG_STEPS)

obs_list = []    # observations returned by nmf.step
in_list = []     # joint positions sent as actions
try:
    for i in trange(n_replay_steps):
        joint_pos = [data[joint.replace('joint_', '')][i]
                     for joint in nmf.actuated_joints]
        in_list.append(joint_pos)
        action = {'joints': joint_pos}
        try:
            obs, info = nmf.step(action)
        except Exception:
            # Report the offending input, then re-raise. Continuing would
            # render and append a stale `obs` (or hit a NameError on step 0).
            print(f'nmf.step failed at step {i} with joint_pos={joint_pos}')
            raise
        nmf.render()
        obs_list.append(obs)
finally:
    # Release the environment even if a step raised.
    nmf.close()

In [None]:
# Sanity check: every commanded joint position in the debug run is finite.
np.all(np.isfinite(np.asarray(in_list)))

In [None]:
# Overlay observed joint positions on the commanded inputs for the same first
# `num_joints_to_plot` actuated joints. The original plotted only joint 0 of
# the inputs, regardless of this setting.
# NOTE(review): assumes row 0 of obs['joints'] holds joint positions — confirm
# against the NeuroMechFlyMuJoCo observation spec.
num_joints_to_plot = 1
observed = np.array([obs['joints'][0, :num_joints_to_plot] for obs in obs_list])
commanded = np.array(in_list)[:, :num_joints_to_plot]
joint_names = list(nmf.actuated_joints)[:num_joints_to_plot]

fig, ax = plt.subplots()
for j, name in enumerate(joint_names):
    ax.plot(observed[:, j], label=f'{name} (observed)')
    ax.plot(commanded[:, j], linestyle='--', label=f'{name} (commanded)')
ax.set(xlabel='Simulation step', ylabel='Joint position',
       title='Kinematic replay: observed vs. commanded')
ax.legend();

In [85]:
"""Demo script for NeuroMechFlyMuJoCo environment:
Execute an environment where all leg joints of the fly repeat a sinusoidal
motion. The output will be saved as a video."""

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from flygym.envs.nmf_mujoco import NeuroMechFlyMuJoCo
from flygym.util.config import leg_dofs_fused_tarsi

# First, we initialize simulation
run_time = 1
out_dir = Path('mujoco_basic_untethered_sinewave')
nmf = NeuroMechFlyMuJoCo(render_mode='saved', output_dir=out_dir,
                         timestep=1e-4, render_config={'playspeed': 0.1},
                         init_pose='stretch')

# Define the frequency, phase, and amplitude of the sinusoidal waves
freq = 20
phase = 2 * np.pi * np.random.rand(len(nmf.actuators))
amp = 0.9
num_steps = int(run_time / nmf.timestep)
data_block = np.zeros((len(nmf.actuated_joints), num_steps))
src_x = np.arange(len(data['LFCoxa'])) * 5
tgt_x = np.arange(num_steps)
for i, joint in enumerate(nmf.actuated_joints):
    if (key := joint.replace('joint_', '')) in data:
        data_block[i, :] = np.interp(tgt_x, src_x, data[key])

In [86]:
# Interpolated joint targets must be finite before they are sent to the simulator.
assert np.isfinite(data_block).all()

In [87]:
import time

# Replay the resampled joint targets, one column of data_block per step.
STEP_DELAY = 0  # seconds to sleep between steps (e.g. 0.01 to slow down); 0 disables
obs_list = []
try:
    for i in trange(num_steps):
        joint_pos = data_block[:, i]
        action = {'joints': joint_pos}
        obs, info = nmf.step(action)
        nmf.render()
        obs_list.append(obs)
        if STEP_DELAY:
            time.sleep(STEP_DELAY)
finally:
    # Release the environment even if a step raises.
    nmf.close()

100%|██████████| 10000/10000 [00:14<00:00, 708.71it/s]


In [96]:
"""Demo script for NeuroMechFlyMuJoCo environment:
Execute an environment where all leg joints of the fly repeat a sinusoidal
motion. The output will be saved as a video."""

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from flygym.envs.nmf_mujoco import NeuroMechFlyMuJoCo
from flygym.util.config import leg_dofs_fused_tarsi

# Initialize simulation
run_time = 1
out_dir = Path('kin_replay')
nmf = NeuroMechFlyMuJoCo(render_mode='saved',
                         output_dir=out_dir,
                         timestep=1e-4,
                         render_config={'playspeed': 0.1},
                         init_pose='stretch',
                         actuated_joints=leg_dofs_fused_tarsi)

# Load recorded data
data_path = Path(pkg_resources.resource_filename('flygym', 'data'))
with open(data_path / 'behavior' / '210902_pr_fly1.pkl', 'rb') as f:
    data = pickle.load(f)

# Interpolate 5x
num_steps = int(run_time / nmf.timestep)
data_block = np.zeros((len(nmf.actuated_joints), num_steps))
measure_t = np.arange(len(data['joint_LFCoxa'])) * data['meta']['timestep']
interp_t = np.arange(num_steps) * nmf.timestep
for i, joint in enumerate(nmf.actuated_joints):
    data_block[i, :] = np.interp(interp_t, measure_t, data[joint])

In [97]:
# Replay the resampled joint targets, one column of data_block per step.
obs_list = []
try:
    for i in trange(num_steps):
        joint_pos = data_block[:, i]
        action = {'joints': joint_pos}
        obs, info = nmf.step(action)
        nmf.render()
        obs_list.append(obs)
finally:
    # Release the environment even if a step raises.
    nmf.close()

100%|██████████| 10000/10000 [00:13<00:00, 752.69it/s]
