In [1]:
import PIL.Image

import numpy as np
import pkg_resources
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from flygym.envs.nmf_mujoco import NeuroMechFlyMuJoCo, MuJoCoParameters
from tqdm import trange
from flygym.util.config import all_leg_dofs, leg_dofs_3_per_leg
from flygym.state import stretched_pose

import cv2


from flygym.util.cpg_controller import (advancement_transfer, phase_oscillator, sine_output, initialize_solver,
                         phase_biases_tripod_measured, phase_biases_tripod_idealized,
                         phase_biases_ltetrapod_idealized, phase_biases_metachronal_idealized,
                         plot_phase_amp_output)

In [2]:
# Figure styling: Arial text, and pdf.fonttype 42 so fonts are embedded as
# TrueType in saved PDFs (keeps text editable in vector editors).
plt.rcParams.update({
    'font.family': 'Arial',
    'pdf.fonttype': 42,
})

In [3]:
# Initialize the simulation
# Duration of the CPG-driven walking phase, in the same time units as `timestep`.
run_time = 1

sim_params = MuJoCoParameters(
    timestep=2e-4,
    render_mode="saved",
    render_camera="Animat/camera_right",
    render_playspeed=0.1,
    enable_adhesion=True,
    adhesion_gain=20,
    draw_adhesion=True,
    draw_contacts=True,
    force_arrow_scaling=0.5,
)
nmf = NeuroMechFlyMuJoCo(
    sim_params=sim_params,
    init_pose=stretched_pose,
    actuated_joints=all_leg_dofs,
    spawn_pos=[0.0, 0.0, 0.2],
)

# Number of physics steps covering `run_time`. round() rather than int():
# a float ratio such as 0.3 / 1e-4 evaluates slightly below 3000, and int()
# would truncate it to one step too few. The outer int() guarantees a Python
# int even if nmf.timestep is a NumPy scalar.
num_steps_base = int(round(run_time / nmf.timestep))

In [4]:
# Load the recorded single-step joint kinematics from flygym's packaged data dir.
# NOTE: pickle.load can execute arbitrary code; only use it on this packaged
# file, never on pickles from untrusted sources.
data_path = Path(pkg_resources.resource_filename('flygym', 'data'))
single_steps_file = data_path / 'behavior' / 'single_steps.pkl'
with single_steps_file.open('rb') as pkl_file:
    data = pickle.load(pkl_file)

In [5]:
# Resample the recorded step from the recording timestep onto the simulation
# timestep; the upsampling factor is data['meta']['timestep'] / nmf.timestep.
step_duration = len(data['joint_LFCoxa'])
data_timestep = data['meta']['timestep']
# round() instead of int(): the float ratio can land just below an integer,
# and int() would silently drop the last interpolated sample.
interp_step_duration = int(round(step_duration * data_timestep / nmf.timestep))
measure_t = np.arange(step_duration) * data_timestep
interp_t = np.arange(interp_step_duration) * nmf.timestep

# One row per actuated joint, one column per simulation timestep of the step.
data_block = np.zeros((len(nmf.actuated_joints), interp_step_duration))
for i, joint in enumerate(nmf.actuated_joints):
    data_block[i, :] = np.interp(interp_t, measure_t, data[joint])

# Swing/stance onset times converted to (fractional) column indices of
# data_block; they are compared against per-joint replay indices later on.
leg_swing_starts = {k: v / nmf.timestep for k, v in data["swing_stance_time"]["swing"].items()}
leg_stance_starts = {k: v / nmf.timestep for k, v in data["swing_stance_time"]["stance"].items()}

In [6]:
legs = ["RF", "RM", "RH", "LF", "LM", "LH"]
n_oscillators = len(legs)  # one CPG oscillator per leg

t = np.arange(0, run_time, nmf.timestep)

n_joints = len(nmf.actuated_joints)
leg_ids = np.arange(len(legs)).astype(int)
joint_ids = np.arange(n_joints).astype(int)

# Map each actuated joint to the index (in `legs`) of the leg it belongs to,
# so per-leg oscillator values can be broadcast to that leg's joints.
match_leg_to_joints = np.array([
    i
    for joint in nmf.actuated_joints
    for i, leg in enumerate(legs)
    if leg in joint
])
# The substring match must yield exactly one leg per joint, otherwise
# amp[match_leg_to_joints] will not line up with the joint-angle vector.
assert len(match_leg_to_joints) == n_joints, \
    "could not map every actuated joint to a single leg"

# For each leg (in nmf.last_tarsalseg_names order), the index of its Coxa_roll
# joint; that joint's replay index tracks the leg's advancement through the step.
joints_to_leg = np.array([
    i
    for ts in nmf.last_tarsalseg_names
    for i, joint in enumerate(nmf.actuated_joints)
    if f"{ts[:2]}Coxa_roll" in joint
])
assert len(joints_to_leg) == len(nmf.last_tarsalseg_names), \
    "expected exactly one Coxa_roll joint per leg"
stance_starts_in_order = np.array([leg_stance_starts[ts[:2]] for ts in nmf.last_tarsalseg_names])
swing_starts_in_order = np.array([leg_swing_starts[ts[:2]] for ts in nmf.last_tarsalseg_names])

In [7]:
# CPG parameters: aim for `n_steps` full oscillations per leg over `run_time`.
n_steps = 5
frequencies = np.full(n_oscillators, n_steps / run_time)

# Every oscillator shares the same target amplitude and convergence rate.
target_amplitudes = np.full(n_oscillators, 1.0)
rates = np.full(n_oscillators, 20.0)

# Idealized tripod phase offsets, converted from cycles to radians.
phase_biases = phase_biases_tripod_idealized * 2 * np.pi

# Couple only oscillator pairs that have a non-zero phase bias (weight 10).
coupling_weights = np.where(np.abs(phase_biases) > 0, 10.0, 0.0)

In [8]:
# For the first `n_stabilisation_steps` steps the legs are held at the first
# frame of the recorded step while the CPG runs at `start_ampl`; at that step
# the oscillators are switched to `target_amplitudes`.
n_stabilisation_steps = 2000
num_steps = n_stabilisation_steps + num_steps_base

# Initialize the simulation and the CPG ODE solver
np.random.seed(0)
# One starting amplitude per oscillator (previously hard-coded to 6).
start_ampl = np.ones(n_oscillators) * 0.2
obs, info = nmf.reset()
solver = initialize_solver(phase_oscillator, "dopri5", nmf.curr_time,
                           n_oscillators, frequencies,
                           coupling_weights, phase_biases,
                           start_ampl, rates,
                           int_params={"atol": 1e-6, "rtol": 1e-6, "max_step": 100000})

# Initialize storage
obs_list = []
phases = np.zeros((num_steps, n_oscillators))
amplitudes = np.zeros((num_steps, n_oscillators))

joint_angles = np.zeros((num_steps, n_joints))
input_joint_angles = np.zeros(n_joints)

# Per-joint column index into data_block; stays 0 during stabilisation and is
# also what the adhesion signal is derived from.
indices = np.zeros(n_joints, dtype=np.int64)

for i in trange(num_steps):
    # Advance the CPG ODE to the current simulation time.
    res = solver.integrate(nmf.curr_time)
    phase = res[:n_oscillators]
    amp = res[n_oscillators:2*n_oscillators]

    phases[i, :] = phase
    amplitudes[i, :] = amp

    if i == n_stabilisation_steps:
        # Now set the amplitudes to their real (target) values
        solver.set_f_params(n_oscillators, frequencies,
                            coupling_weights, phase_biases,
                            target_amplitudes, rates)
    if i > n_stabilisation_steps:
        # Per-joint columns of data_block derived from the leg oscillator phases.
        indices = advancement_transfer(phase, interp_step_duration, match_leg_to_joints)
        # Scale by amplitude: interpolate between the resting values (column 0)
        # and the recorded angle at each joint's current index.
        input_joint_angles = data_block[joint_ids, 0] + \
            (data_block[joint_ids, indices] - data_block[joint_ids, 0]) * amp[match_leg_to_joints]
    else:
        input_joint_angles = data_block[joint_ids, 0]

    joint_angles[i, :] = input_joint_angles
    # Adhesion is on unless the leg's Coxa_roll index lies between swing onset
    # and stance onset (assuming swing starts before stance in the recording).
    adhesion_signal = np.logical_or(indices[joints_to_leg] < swing_starts_in_order,
                                    indices[joints_to_leg] > stance_starts_in_order)

    action = {"joints": input_joint_angles, "adhesion": adhesion_signal}

    obs, reward, terminated, truncated, info = nmf.step(action)
    obs_list.append(obs)
    nmf.render()

100%|██████████| 7000/7000 [00:29<00:00, 239.71it/s]


In [9]:
# Write the rendered frames to an mp4 file.
# NOTE(review): stabilization_time presumably trims the initial settling period
# from the video. The CPG stabilisation phase lasts
# n_stabilisation_steps * timestep (2000 * 2e-4 = 0.4 with the settings used in
# this notebook), so 0.5 also cuts some walking; confirm this is intended.
video_path = "outputs/force_visualization.mp4"
nmf.save_video(video_path, stabilization_time=0.5)