In [1]:
import torch
import numpy as np
from mppi import MPPI
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython import display
from tqdm.notebook import tqdm
from celluloid import Camera
import os
import imageio

In [2]:
%load_ext autoreload
%autoreload 2

In [24]:
from DDPController import *
from cartpoleDynamics import *
import autograd.numpy as np

if not os.path.exists('cartpole'):
    os.makedirs('cartpole')

state_dim = 4
action_dim = 1
x_final = np.array([.0, .0, .0, .0])
Q = np.diag([10, .1, 100, 10.])
R = np.array([[.3]])
terminal_scale = 100.0
cost = CartpoleCost(x_final, terminal_scale, Q, R)
DDP_dynamic = dynamics

controller = DDPcontroller(DDP_dynamic, cost, tolerance = 1e-3, max_iter = 100, T = 10, state_dim = state_dim, control_dim = action_dim, rho = 0.9, max_dc_iter = 10, dt = 0.05)

# initial_state = np.random.randn(state_dim)
initial_state = np.array([0,0,np.pi,0])

state = initial_state

target = x_final

num_steps = 100
pbar = tqdm(range(num_steps))

for i in pbar:
    
    action = controller.command(state)
    
    state = DDP_dynamic(state, action)
    state = state.squeeze()
    # print(state)
    dx = (state-target)
    d_theta = np.mod(dx[2] + np.pi, 2 * np.pi) - np.pi
    error = np.array([dx[0], dx[1], d_theta, dx[3]]) @ np.diag([0.1, 0.1, 1, 0.1])
    error_i = np.linalg.norm(error)
    pbar.set_description(f'Goal Error: {error_i:.4f}')

    # --- Start plotting
    fig, ax = plt.subplots()
    ax = plt.axes(xlim=(state[0]-10, state[0]+10), ylim=(-2, 2))
    ax.set_aspect('equal')
    ax.grid()
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_title('cartpole at t={:.2f}'.format(i*0.05))
    x = state[0]
    theta1 = state[2]
    L1 = 0.5
    x1 = x + L1*np.sin(theta1)
    y1 = L1*np.cos(theta1)
    plt.plot([x,x1],[0,y1],color='black')   
    filename = os.path.join('cartpole', 'plot_{:03d}.png'.format(i))
    # plt.show()
    plt.savefig(filename)
    plt.close()
    if error_i < 0.1 and i > 50:
        num_steps = i
        break
    # --- End plotting

images = []
for i in range(num_steps):
    filename = os.path.join('cartpole', 'plot_{:03d}.png'.format(i))
    images.append(imageio.imread(filename))
imageio.mimsave('cartpole.gif', images, duration=0.1)
    # --- End plotting
plt.show()
plt.close()



  0%|          | 0/100 [00:00<?, ?it/s]

  images.append(imageio.imread(filename))
