# 1. MPPI - Standing


#### 0. Load scene

In [None]:
import mujoco
import numpy as np

""" LOAD MJCF & CREATE ENVIRONMENT """
model = mj.MjModel.from_xml_path('./asset/scene.xml')
data = mj.MjData(model)

# for mppi class
model_mppi = mj.MjModel.from_xml_path('./asset/scene.xml')
data_mppi = mj.MjData(model_mppi)

In [None]:
""" MUJOCO VIEWER CLASS """
# archive to pp_base_mujoco
# make it as package
import time
import sys
import glfw 

class MUJOCOVIEWERCLASS():
    """GLFW viewer that opens one window per requested camera.

    Every window owns its own ``MjvScene``, ``MjvCamera``, ``MjvOption`` and
    ``MjrContext``. Drawing is throttled so that at most one frame is
    rendered every ``render_interval`` seconds.

    Args:
        model: MuJoCo model to visualise.
        data: MuJoCo data drawn on every call to :meth:`render`.
        camera_names: One name per window. Used as the window title and
            looked up among the model's cameras; when the name is not found
            the window falls back to camera id 0.
        size: ``(width, height)`` pairs indexed in parallel with
            ``camera_names``.

    Raises:
        SystemExit: If GLFW cannot be initialised.
        RuntimeError: If a GLFW window cannot be created.
    """

    def __init__(self, model, data, camera_names, size):
        # Keep the objects handed to the constructor so that rendering and
        # callbacks do not depend on module-level globals.
        self.model = model
        self.data = data

        # initialize with empty lists
        self.windows = []
        self.contexts = []
        self.scenes = []
        self.cameras = []
        self.options = []
        self.viewport = mujoco.MjrRect(0, 0, 0, 0)
        if not glfw.init():
            sys.exit("couldn't initialize glfw")
        self.last_x, self.last_y = 0, 0
        self.mouse_button = None

        # create one window (plus scene/camera/options/context) per camera
        for i, name in enumerate(camera_names):
            window = glfw.create_window(size[i][0], size[i][1], f"Camera: {name}", None, None)
            if not window:
                glfw.terminate()
                raise RuntimeError("GLFW window creation failed")
            # the rendering context below is created for this window
            glfw.make_context_current(window)

            # scene, camera, options
            scene = mujoco.MjvScene(model, maxgeom=1000)
            cam = mujoco.MjvCamera()
            cam.type = mujoco.mjtCamera.mjCAMERA_FIXED
            # Resolve the requested camera by name; mj_name2id returns -1
            # for unknown names, in which case the first camera is used.
            cam_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, name)
            cam.fixedcamid = cam_id if cam_id >= 0 else 0
            opt = mujoco.MjvOption()

            # append to list
            self.scenes.append(scene)
            self.cameras.append(cam)
            self.options.append(opt)
            self.windows.append(window)
            self.contexts.append(mujoco.MjrContext(model, mujoco.mjtFontScale.mjFONTSCALE_150))

            # set callbacks
            glfw.set_mouse_button_callback(window, self.mouse_button_callback)
            glfw.set_cursor_pos_callback(window, self.cursor_pos_callback)
            glfw.set_scroll_callback(window, self.scroll_callback)

        # render interval (seconds between two drawn frames)
        self.render_interval = 0.1
        self.last_render_time = time.time()

    def is_alive(self):
        """Return True while none of the windows has been asked to close."""
        return all(not glfw.window_should_close(win) for win in self.windows)

    def close(self):
        """Destroy every window and shut GLFW down."""
        for window in self.windows:
            glfw.destroy_window(window)
        glfw.terminate()

    def render(self):
        """Draw all windows if at least ``render_interval`` seconds elapsed.

        Window events are polled only on calls that actually draw a frame.
        """
        current_time = time.time()
        if current_time - self.last_render_time < self.render_interval:
            return
        self.last_render_time = current_time

        for i, window in enumerate(self.windows):
            glfw.make_context_current(window)
            width, height = glfw.get_framebuffer_size(window)
            self.viewport.width = width
            self.viewport.height = height

            # update scene & render
            mujoco.mjv_updateScene(self.model, self.data, self.options[i], None, self.cameras[i],
                                   mujoco.mjtCatBit.mjCAT_ALL, self.scenes[i])
            mujoco.mjr_render(self.viewport, self.scenes[i], self.contexts[i])
            glfw.swap_buffers(window)

        glfw.poll_events()
        time.sleep(0.01)

    def update_viewer(self):
        """Placeholder for click/callback driven updates (currently a no-op)."""
        pass

    def mouse_button_callback(self, window, button, action, mods):
        """Remember which mouse button is held down (None when released)."""
        if action == glfw.PRESS:
            self.mouse_button = button
        elif action == glfw.RELEASE:
            self.mouse_button = None

    def cursor_pos_callback(self, window, xpos, ypos):
        """Translate a mouse drag into a camera motion.

        Always acts on the first scene/camera, whichever window raised the
        event.
        NOTE(review): cameras are created as mjCAMERA_FIXED; confirm that
        mjv_moveCamera has any effect on fixed cameras.
        """
        dx = (xpos - self.last_x) / 1000
        dy = (ypos - self.last_y) / 1000
        self.last_x, self.last_y = xpos, ypos

        if self.mouse_button is None:
            return

        action = {
            glfw.MOUSE_BUTTON_LEFT: mujoco.mjtMouse.mjMOUSE_ROTATE_H,
            glfw.MOUSE_BUTTON_RIGHT: mujoco.mjtMouse.mjMOUSE_MOVE_H,
            glfw.MOUSE_BUTTON_MIDDLE: mujoco.mjtMouse.mjMOUSE_ZOOM
        }.get(self.mouse_button, None)

        if action is not None:
            mujoco.mjv_moveCamera(self.model, action, dx, dy, self.scenes[0], self.cameras[0])

    def scroll_callback(self, window, xoffset, yoffset):
        """Zoom the first camera with the scroll wheel."""
        mujoco.mjv_moveCamera(self.model, mujoco.mjtMouse.mjMOUSE_ZOOM, 0.0, -yoffset / 100,
                              self.scenes[0], self.cameras[0])




#### 1. Declare MPPI Controller

In [None]:
""" MPPI CONTROLLER CLASS """

class MPPICONTROLLER:
    """Sampling-based MPPI controller that evaluates rollouts in MuJoCo.

    A control step consists of :meth:`simulate_cost` (draw noise, roll out
    every sample from the current state of ``data`` and accumulate costs),
    :meth:`compute_weights` (turn costs into normalised weights) and
    :meth:`return_action` (weighted average of the sampled sequences).

    Args:
        model: MuJoCo model used for the rollouts.
        data: MuJoCo data used for the rollouts. Its state at the time
            :meth:`simulate_cost` is called is the start of every rollout
            and is restored afterwards.
        cost_function: Callable taking ``(qpos, qvel, xpos, cvel)`` and
            returning a scalar stage cost.
        n_sample: Number of sampled control sequences.
        horizon: Rollout length in MuJoCo steps.
        lambda_: Temperature of the exponential weighting (must be > 0).
        alpha: Stored, and used only to derive ``gamma``.
        sigma: Standard deviation of the Gaussian control noise.
        exploration_rate: Fraction of samples that receive noise.

    Raises:
        ValueError: If ``n_sample`` or ``horizon`` is smaller than one, or
            ``lambda_`` is not positive.
    """

    def __init__(
            self,
            model,
            data,
            cost_function,
            n_sample            = 100,
            horizon             = 20, # mujoco steps
            lambda_             = 1.0,
            alpha               = 0.5,
            sigma               = 0.5,
            exploration_rate    = 0.1
        ):
        if n_sample < 1 or horizon < 1:
            raise ValueError("n_sample and horizon must be at least 1")
        if lambda_ <= 0:
            raise ValueError("lambda_ must be positive")

        self.model              = model
        self.data               = data
        self.cost_function      = cost_function

        # sampling parameters (read by every other method)
        self.n_sample           = n_sample
        self.horizon            = horizon
        self.sigma              = sigma

        self.lambda_            = lambda_
        self.alpha              = alpha
        self.gamma              = lambda_ * (1.0 - alpha)
        self.exploration_rate   = exploration_rate

        self.n_ctrl             = model.nu
        self.ctrl = np.zeros((horizon, model.nu)) # control sequence

        # Results of the latest sampling round; initialised so the methods
        # can be called in any order without raising AttributeError.
        self.epsilon = np.zeros((n_sample, horizon, model.nu))
        self.costs = np.zeros(n_sample)
        self.weights = np.full(n_sample, 1.0 / n_sample)

    def sample_epsilon(self):
        """ SAMPLE RANDOM NOISE

        Draws zero-mean Gaussian noise for every sample and timestep,
        stores it in ``self.epsilon`` and returns it.

        Returns:
            Array of shape ``(n_sample, horizon, n_ctrl)``.
        """
        self.epsilon = np.random.normal(
            loc     = 0.0,
            scale   = self.sigma,
            size    = (self.n_sample, self.horizon, self.n_ctrl)
        )
        return self.epsilon

    def _snapshot_state(self):
        """Copy the parts of ``self.data`` needed to restart a rollout."""
        d = self.data
        return (d.time, d.qpos.copy(), d.qvel.copy(), d.act.copy(),
                d.qacc_warmstart.copy())

    def _restore_state(self, snapshot):
        """Write a snapshot from :meth:`_snapshot_state` back into ``self.data``."""
        sim_time, qpos, qvel, act, warmstart = snapshot
        d = self.data
        d.time = sim_time
        d.qpos[:] = qpos
        d.qvel[:] = qvel
        d.act[:] = act
        d.qacc_warmstart[:] = warmstart

    def simulate_cost(self):
        """ GENERATE CONTROL SIGNAL & SIMULATE COST

        Draws fresh noise, rolls every sample out for ``horizon`` steps
        starting from the current state of ``self.data`` and sums the stage
        costs. The simulator state is restored afterwards.

        Returns:
            Array of shape ``(n_sample,)`` with the accumulated costs
            (also stored in ``self.costs``).
        """
        epsilon = self.sample_epsilon()

        # Only the first `n_sample * exploration_rate` samples are perturbed;
        # the others replay the nominal sequence. Their noise is zeroed so
        # that `self.epsilon` matches what was actually applied.
        # NOTE(review): standard MPPI perturbs every sample -- confirm this
        # split is intended.
        perturbed = np.arange(self.n_sample) < self.n_sample * self.exploration_rate
        epsilon[~perturbed] = 0.0
        controls = self.ctrl[None, :, :] + epsilon

        costs = np.zeros(self.n_sample)
        start = self._snapshot_state()
        try:
            for k in range(self.n_sample):
                # every rollout must start from the same state
                self._restore_state(start)
                for t in range(self.horizon):
                    # simulator step
                    self.data.ctrl[:] = controls[k, t]
                    mujoco.mj_step(self.model, self.data)
                    state = self.data.qpos, self.data.qvel, self.data.xpos, self.data.cvel

                    # calculate cost
                    costs[k] += self.cost_function(state)
        finally:
            # leave the rollout environment where the caller put it
            self._restore_state(start)

        self.costs = costs
        return costs

    def compute_weights(self):
        """ COMPUTE WEIGHTS FROM COSTS

        Applies ``exp(-(cost - min_cost) / lambda_)`` and normalises the
        result to sum to one. Non-finite costs get zero weight; if no cost
        is finite the weights are uniform.

        Returns:
            Array of shape ``(n_sample,)`` (also stored in ``self.weights``).
        """
        costs = np.asarray(self.costs, dtype=float)
        finite = np.isfinite(costs)

        if not finite.any():
            w = np.full(costs.shape, 1.0 / costs.size)
        else:
            # subtracting the minimum keeps the exponentials from underflowing
            rho = costs[finite].min()
            w = np.zeros_like(costs)
            w[finite] = np.exp((-1.0 / self.lambda_) * (costs[finite] - rho))
            w /= w.sum()

        self.weights = w
        return w

    def return_action(self):
        """ WEIGHTED AVERAGE & RETURN FIRST ACTION

        Replaces the control sequence by the weighted average of the
        sampled sequences ``ctrl + epsilon[k]``.

        Returns:
            Tuple ``(first_action, control_sequence)`` with shapes
            ``(n_ctrl,)`` and ``(horizon, n_ctrl)``.
        """
        candidates = self.ctrl[None, :, :] + self.epsilon
        self.ctrl = np.tensordot(self.weights, candidates, axes=(0, 0))
        return self.ctrl[0], self.ctrl

In [None]:
""" COST FUNCTION: define own cost function """
def cost_function_standing(state):
    """Return a scalar stage cost for one simulated state.

    Placeholder baseline: penalises joint velocities only, i.e. it rewards
    holding still. Replace or extend it with task-specific terms.

    Args:
        state: Tuple ``(qpos, qvel, xpos, cvel)`` as passed by the
            controller. Only ``qvel`` is used here.

    Returns:
        Sum of squared joint velocities as a Python float.
    """
    # state: self.data.qpos, self.data.qvel, self.data.xpos, self.data.cvel
    _qpos, qvel, _xpos, _cvel = state
    return float(np.sum(np.square(qvel)))

#### 2. Main Code
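- Outer environment (`model`, `data`): rendered, and stepped with the action chosen by the controller
- Inner environment owned by the MPPI controller (`model_mppi`, `data_mppi`): used only for rollouts that compute the cost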

In [None]:
""" MAIN LOOP """
mujoco.mj_resetData(model, data)
mujoco.mj_resetData(model_mppi, data_mppi)
viewer = MUJOCOVIEWERCLASS(
    model,
    data,
    camera_names = ['camera1'],
    size = [(640, 480)]
    )

controller = MPPICONTROLLER(
    model_mppi,
    data_mppi,
    cost_function_standing,
    n_sample=100,
    )

while viewer.is_alive:
    # 1. get state from mujoco
    current_state = data.qpos, data.qvel, data.xpos, data.xvelp

    # 2. compute control signal from mppi
    controller.sample_epsilon()
    controller.simulate_cost()
    controller.compute_weights()
    action, action_horizon = controller.return_action()

    # 3. step mujoco & render
    data.ctrl[:] = action
    mujoco.mj_step(model, data)
    viewer.render()


# close
viewer.close()