In [1]:
import numpy as np
from scipy.linalg import *
# from drl.env import TwoLinkArm
from drl.env.arm import TwoLinkArm
import time

%matplotlib notebook

In [2]:
epsilon = 1e-5
Ts = 10
env = TwoLinkArm(g=0.)

Q = np.eye(env.state_dim)*100.
R = np.eye(env.action_dim)*1.


## Finite difference derivatives

In [3]:
def calc_derivatives(x, u):
    A = np.zeros((env.state_dim, env.state_dim))
    for i in range(env.state_dim):
        x_tmp = x.copy()
        x_tmp[i] += epsilon
        _, f_1 = env.dynamics_func(x_tmp, u)
        x_tmp = x.copy()
        x_tmp[i] -= epsilon
        _, f_2 = env.dynamics_func(x_tmp, u)
        fxdx = (f_1 - f_2) / (2*epsilon)
        A[:, i] = fxdx
        
    B = np.zeros((env.state_dim, env.action_dim))
    for i in range(env.action_dim):
        u_tmp = u.copy()
        u_tmp[i] += epsilon
        _, f_1 = env.dynamics_func(x, u_tmp)
        u_tmp = u.copy()
        u_tmp[i] -= epsilon
        _, f_2 = env.dynamics_func(x, u_tmp)
        fxdu = (f_1 - f_2) / (2*epsilon)
        B[:, i] = fxdu
        
    return A, B

In [4]:
def run_experiment():
    x = env.reset()
    env.render()
    u = [0.]*env.action_dim
    
    V = 0.
    for _ in range(int(Ts/env.dt)):
        # Calculate optimal feedback gain K
        error = x - env.get_goal()
        A, B = calc_derivatives(error, u)

        P = solve_continuous_are(A, B, Q, R)
        K = np.dot(np.linalg.pinv(R), np.dot(B.T, P))
        
        u = -np.dot(K, error)

        x, _, _, _ = env.step(u)
        
        env.render()
        
        V += np.dot(error.T, np.dot(P, error))
    
    print('Episode %s - Total Cost: %s' % (str(i), V))

    return env.goal - x

In [5]:
for i in range(5):
    error = run_experiment()

Episode 0 - Total Cost: 114.951192596
Episode 1 - Total Cost: 10569.4571968
Episode 2 - Total Cost: 11714.8571677
Episode 3 - Total Cost: 7700.81839004
Episode 4 - Total Cost: 23161.6375303


In [6]:
env.render(close=True)

## Run with changing goal once goal is reached

In [7]:
num_targets = 5
def run_experiment2():
    x = env.reset()
    env.render()
    u = [0.]*env.action_dim
    
    V_tot = 0.
    for i in range(num_targets):
        env.set_goal()
        goal_reached = False
        
        while not goal_reached:
            error = x - env.get_goal()
            A, B = calc_derivatives(error, u)

            P = solve_continuous_are(A, B, Q, R)
            K = np.dot(np.linalg.pinv(R), np.dot(B.T, P))

            u = -np.dot(K, error)

            x, _, _, _ = env.step(u)

            env.render()
            
            V = np.dot(error.T, np.dot(P, error))
            V_tot += V
            if V < 1e-4:
                goal_reached = True
                
        print('Target %s - Total Cost: %s' % (str(i), V_tot))
        

In [8]:
run_experiment2()

[ 27.17042667  14.54708701]
[ 16.22124329   6.88922912]
[ 9.53941952  4.04340974]
[ 8.43374525  1.73213507]
[ 7.04641963  0.62729518]
[ 5.55547115  0.13092439]
[ 4.06681929 -0.07840503]
[ 2.62197922 -0.16056959]
[ 1.23301392 -0.19015856]
[ 0.36114205 -0.2709959 ]
[-0.09744045 -0.35013993]
[-0.33951896 -0.39959032]
[-0.46829024 -0.42454146]
[-0.53798261 -0.43384624]
[-0.57648096 -0.43431455]
[-0.59793785 -0.43011509]
[-0.60948491 -0.42351433]
[-0.61473843 -0.41564985]
[-0.61558862 -0.40706036]
[-0.61309156 -0.39799355]
[-0.60790333 -0.38856738]
[-0.60048587 -0.37884798]
[-0.5912021  -0.36888389]
[-0.58035846 -0.35871939]
[-0.56822349 -0.34839846]
[-0.55503594 -0.33796501]
[-0.5410085  -0.32746212]
[-0.52632998 -0.31693105]
[-0.51116701 -0.30641054]
[-0.49566565 -0.2959364 ]
[-0.4799531 -0.2855413]
[-0.46413935 -0.27525477]
[-0.44831887 -0.26510322]
[-0.43257224 -0.25511012]
[-0.4169677  -0.24529611]
[-0.40156254 -0.23567917]
[-0.38640445 -0.22627477]
[-0.37153271 -0.21709603]
[-0.356979

[ 0.00199001  0.00039892]
[ 0.00189001  0.00037886]
[ 0.00179503  0.00035982]
[ 0.00170482  0.00034173]
[ 0.00161915  0.00032455]
[ 0.00153779  0.00030823]
[ 0.00146052  0.00029274]
[ 0.00138712  0.00027803]
[ 0.00131742  0.00026405]
[ 0.00125122  0.00025078]
[ 0.00118835  0.00023818]
[ 0.00112864  0.00022621]
[ 0.00107193  0.00021484]
[ 0.00101807  0.00020404]
[ 0.00096691  0.00019379]
Target 1 - Total Cost: 47577.8440421
[ 44.64486613  23.3366271 ]
[ 36.46256254   9.07501748]
[ 32.47231862   5.7825675 ]
[ 32.26718375   3.89255752]
[ 31.55104621   2.55828779]
[ 30.16079134   1.49852718]
[ 28.23554544   0.64711136]
[  2.59975391e+01   1.87372635e-02]
[ 23.63791302  -0.37560772]
[ 21.29802465  -0.5742268 ]
[ 19.05897124  -0.64329897]
[ 16.9521032   -0.64349358]
[ 14.97928196  -0.61518023]
[ 13.12946414  -0.57978044]
[ 11.38807582  -0.54646106]
[ 9.74109656 -0.51800324]
[ 8.17638256 -0.49442397]
[ 6.68382329 -0.47482108]
[ 5.25509944 -0.45817541]
[ 3.88335779 -0.44363262]
[ 2.5629124  -0

[ 0.00853976 -0.00487048]
[ 0.00811069 -0.00462839]
[ 0.00770317 -0.0043982 ]
[ 0.00731614 -0.00417935]
[ 0.00694854 -0.00397128]
[ 0.00659942 -0.00377348]
[ 0.00626785 -0.00358545]
[ 0.00595293 -0.00340671]
[ 0.00565383 -0.00323682]
[ 0.00536977 -0.00307533]
[ 0.00509998 -0.00292185]
[ 0.00484374 -0.00277598]
[ 0.00460038 -0.00263734]
[ 0.00436924 -0.00250559]
[ 0.00414972 -0.00238039]
[ 0.00394123 -0.00226141]
[ 0.00374322 -0.00214834]
[ 0.00355515 -0.0020409 ]
[ 0.00337653 -0.00193882]
[ 0.00320689 -0.00184181]
[ 0.00304577 -0.00174964]
[ 0.00289275 -0.00166207]
[ 0.00274741 -0.00157886]
[ 0.00260938 -0.00149981]
[ 0.00247828 -0.0014247 ]
[ 0.00235377 -0.00135334]
[ 0.00223552 -0.00128554]
[ 0.0021232  -0.00122113]
[ 0.00201653 -0.00115994]
[ 0.00191522 -0.00110181]
[ 0.001819   -0.00104659]
[ 0.00172761 -0.00099412]
[ 0.00164082 -0.00094429]
[ 0.00155838 -0.00089694]
[ 0.00148009 -0.00085196]
[ 0.00140573 -0.00080924]
[ 0.00133511 -0.00076865]
[ 0.00126803 -0.0007301 ]
[ 0.00120433

In [9]:
env.render(close=True)