# In[ ]:
import conf_talos_full_with_locked_joints as conf
from alternating_planner import ComboPlanner
from pin_bullet import SimuProxy
import pinocchio as pin 
import numpy as np
import pybullet 
import time 

# --- Planners ----------------------------------------------------------------
# Combined centroidal + whole-body-dynamics (WBD) planner running in MPC mode.
planner = ComboPlanner(conf, MPC=True)
N_traj, N_mpc_centroidal, N_mpc_wbd = planner.N_traj, planner.N_mpc_centroidal, planner.N_mpc_wbd
robot, rmodel, rdata = conf.robot, conf.rmodel, conf.rdata
# Number of simulation/control ticks executed per planner step.
# Round rather than truncate: e.g. 0.03 / 0.001 evaluates to
# 29.999999999999996, which a bare int() turns into 29, silently dropping one
# control tick per planner step.
N_interpol = int(round(conf.dt / conf.dt_ctrl))
nq = rmodel.nq
centroidal_planner = planner.centroidal_planner
wbd_planner = planner.wbd_planner
croc_state = wbd_planner.whole_body_model.state
# Initial centroidal state: first node of the centroidal planner's initial guess.
hg0 = centroidal_planner.x_init[0]

# --- Open-loop solution buffers ---------------------------------------------
wbd_nx, wbd_nu = len(wbd_planner.x_init[0]), len(wbd_planner.u_init[0])
# NOTE(review): despite the "centroidal" name these arrays are sized with the
# WBD state/control dimensions, and the MPC loop in this cell never writes
# them. Confirm whether they are still needed.
X_sim_centroidal = np.zeros((N_traj, N_mpc_centroidal+1, wbd_nx))
U_sim_centroidal = np.zeros((N_traj, N_mpc_centroidal, wbd_nu))
wbd_sol = []  # one WBD solution dict per MPC step

# --- Closed-loop solution buffers -------------------------------------------
# State: CoM position (3) + centroidal momentum (6) = 9.
# Controls: 12, presumably contact forces. TODO confirm against the centroidal OCP.
centroidal_nx, centroidal_nu = 9, 12
X_sol_centroidal = np.zeros((N_traj, centroidal_nx))
U_sol_centroidal = np.zeros((N_traj, centroidal_nu))

# --- PyBullet simulation -----------------------------------------------------
models = SimuProxy()
models.loadExampleRobot("talos")
models.loadBulletModel(pybullet.GUI)
# Joints locked in the simulator. These presumably must match the joints
# locked in conf_talos_full_with_locked_joints so that the simulator state
# and rmodel have the same dimensions. TODO confirm.
models.freeze(
    ["arm_left_5_joint",
     "arm_left_6_joint",
     "arm_left_7_joint",
     "arm_right_5_joint",
     "arm_right_6_joint",
     "arm_right_7_joint",
     "gripper_left_joint",
     "gripper_right_joint",
     "head_1_joint",
     "head_2_joint"]
)
models.setTalosDefaultFriction()
models.setTorqueControlMode()

def _solve_centroidal_ocp(cplanner):
    """Solve the current centroidal OCP with acados and print timing/statistics.

    For an ``SQP_RTI`` solver only the feedback phase (``rti_phase`` 2) is run
    here. The reported total then adds ``cplanner.elapsed_prep``, which is
    expected to hold the preparation-phase time (presumably set by
    ``update_ocp``; TODO confirm).

    Raises:
        RuntimeError: if acados returns a non-zero status.
    """
    solver = cplanner.acados_solver
    is_rti = cplanner.ocp.solver_options.nlp_solver_type == 'SQP_RTI'
    if is_rti:
        # feedback rti_phase (solving QP)
        print('starting RTI feedback phase ' + '...')
        solver.options_set('rti_phase', 2)
    # perf_counter is monotonic and high-resolution, unlike time.time().
    t_start = time.perf_counter()
    status = solver.solve()
    elapsed = time.perf_counter() - t_start
    if is_rti:
        print('RTI feedback phase took ' + str(elapsed) + " seconds")
    solver.print_statistics()
    if status != 0:
        raise RuntimeError(f'acados returned status {status}.')
    if is_rti:
        elapsed = cplanner.elapsed_prep + elapsed
    print("HOORAY ! found a solution after :", elapsed, " seconds")


for traj_time_idx in range(N_traj):
    # 1) Centroidal MPC step, initialized from the current centroidal state hg0.
    centroidal_planner.update_ocp(traj_time_idx, hg0)
    _solve_centroidal_ocp(centroidal_planner)
    x_sol = np.array([centroidal_planner.acados_solver.get(i, "x") for i in range(N_mpc_centroidal+1)])
    u_sol = np.array([centroidal_planner.acados_solver.get(i, "u") for i in range(N_mpc_centroidal)])

    # 2) WBD MPC step: add tracking costs from the centroidal control sequence.
    wbd_planner.update_ocp(traj_time_idx, centroidalTask=None, forceTask=u_sol)
    if traj_time_idx == 0:
        wbd_planner.solver.solve(wbd_planner.x_init, wbd_planner.u_init)
    else:
        # warm-start from the previous solution shifted by one node (see below)
        wbd_planner.solver.solve(xs, us)
    xs = list(wbd_planner.solver.xs)
    us = list(wbd_planner.solver.us)
    # save open-loop solution
    sol_k = wbd_planner.get_solution_trajectories()
    # NOTE(review): the first-node DDP feedback gains are applied unscaled at
    # every control tick. Whether they should be rescaled for the control
    # rate (previously tried: *(conf.dt/N_interpol)) is still an open question.
    gains = sol_k['gains'][0]
    x_des, tau_ff = wbd_planner.interpolate_one_step(
        q=sol_k['jointPos'][0], q_next=sol_k['jointPos'][1],
        qdot=sol_k['jointVel'][0], qdot_next=sol_k['jointVel'][1],
        tau=sol_k['jointTorques'][0], tau_next=sol_k['jointTorques'][1],
    )

    # 3) Track the interpolated reference over N_interpol simulation ticks.
    for ctrl_time_idx in range(N_interpol):
        x_meas = models.getState()
        # crocoddyl's diff(x0, x1) returns x1 (-) x0, i.e. the state error x_des (-) x_meas
        tau_k = tau_ff[ctrl_time_idx] + gains @ croc_state.diff(x_meas, x_des[ctrl_time_idx])
        # send torques and step simulation
        models.step(tau_k)

    # 4) Measure the resulting state.
    x_meas = models.getState()
    q_k, dq_k = x_meas[:nq], x_meas[nq:]
    # NOTE(review): updates robot.data frame placements. Nothing in this loop
    # reads them, so this may be droppable. TODO confirm.
    robot.framesForwardKinematics(q_k)
    com = pin.centerOfMass(rmodel, rdata, q_k, dq_k)
    # Compute the momentum explicitly on rmodel/rdata. robot.centroidalMomentum
    # writes robot.data, so reading rdata.hg afterwards is stale unless
    # conf.rdata happens to be robot.data.
    hg = pin.computeCentroidalMomentum(rmodel, rdata, q_k, dq_k)
    wbd_sol.append(sol_k)
    # save closed-loop solution (first node of the centroidal horizon)
    X_sol_centroidal[traj_time_idx] = x_sol[0]
    U_sol_centroidal[traj_time_idx] = u_sol[0]
    # warm-start WBD solver: shift by one node, repeating the last node
    xs = xs[1:] + [xs[-1]]
    us = us[1:] + [us[-1]]
    # Update the solvers' initial conditions.
    # OL-MPC alternative (open loop): hg0 = x_sol[1]; x0 = xs[0]
    # CL-MPC (used here): re-initialize from the measured simulator state.
    hg0 = np.concatenate([com, hg.vector])
    x0 = np.concatenate([q_k, dq_k])
    wbd_planner.x0 = x0