In [1]:
import sys
# Add the relative ./gym directory to the module search path before `gym` is
# imported in the next cell (presumably a local copy of gym -- confirm).
# The path is relative, so it depends on the notebook's working directory.
sys.path.append('./gym')

In [2]:
import gym
import numpy as np
import cvxpy as cp
import math

In [3]:
class cartpole_collect:
    """Record a CartPole rollout and derive per-step dynamics quantities.

    Usage: call ``add_state`` once for the initial state, then ``add_state`` /
    ``add_action`` after every environment step, and finally
    ``get_derivatives``.  After that the following lists hold one entry per
    recorded action:

    * ``x_dots`` -- ``[xacc, thetaacc]`` computed from the analytic model,
    * ``s_dots`` -- finite-difference derivative of the full state,
    * ``As`` / ``Bs`` -- per-step least-squares fit of ``sd ~= A @ s + B @ u``,
    * ``ABs`` -- ``A`` and ``B`` flattened and concatenated into one vector.
    """

    def __init__(self, env):
        """Copy the physical parameters from ``env`` and start empty records.

        Parameters
        ----------
        env : object
            Environment exposing ``gravity``, ``masscart``, ``masspole``,
            ``total_mass``, ``length``, ``polemass_length``, ``force_mag`` and
            ``tau``.  If it has an ``unwrapped`` attribute, the parameters are
            read from that object so wrapped environments work too.
        """
        # Raw trajectory: len(states) is expected to be len(actions) + 1.
        self.states = []
        self.actions = []
        # Derived quantities, rebuilt by get_derivatives().
        self.s_dots = []
        self.x_dots = []
        self.As = []
        self.Bs = []
        self.ABs = []
        self.num = 0

        # Wrappers do not always forward attribute access; fall back to the
        # object itself when there is no `unwrapped`.
        params = getattr(env, 'unwrapped', env)
        self.gravity = params.gravity
        self.masscart = params.masscart
        self.masspole = params.masspole
        self.total_mass = params.total_mass
        self.length = params.length
        self.polemass_length = params.polemass_length
        self.force_mag = params.force_mag
        self.tau = params.tau
        # Dimensions of the fitted model: A is (Ns, Ns), B is (Ns, Nc).
        self.Ns = 2
        self.Nc = 1

    def add_state(self, state):
        """Append one observed state to the trajectory."""
        self.states.append(state)

    def add_action(self, action):
        """Append one action, stored as a one-element list ``[action]``."""
        self.actions.append([action])

    def _action_to_force(self, action):
        """Map a discrete action to a signed force: 1 -> +force_mag, else -force_mag."""
        return self.force_mag if action == 1 else -self.force_mag

    def get_derivatives(self):
        """Compute ``x_dots``, ``s_dots``, ``As``, ``Bs`` and ``ABs`` for every action.

        All derived lists are rebuilt from scratch, so calling this method
        again (e.g. after re-running a notebook cell) does not duplicate
        entries.

        Raises
        ------
        IndexError
            If fewer than ``len(actions) + 1`` states were recorded; each step
            needs the state before and after the action.
        """
        self.num = len(self.actions)
        if len(self.states) < self.num + 1:
            raise IndexError(
                "need len(actions) + 1 states, got %d states for %d actions"
                % (len(self.states), self.num)
            )

        self.x_dots = []
        self.s_dots = []
        self.As = []
        self.Bs = []

        for i in range(self.num):
            self.x_dots.append(self.get_x_dot(self.states[i], self.actions[i]))
            self.s_dots.append(self.get_s_dot(self.states[i], self.states[i + 1]))

            # Indices 1 and 3 are x_dot and theta_dot (see the unpacking in
            # get_x_dot); the matching s_dot entries are their derivatives.
            velocities = np.asarray(self.states[i])[[1, 3]]
            accelerations = self.s_dots[i][[1, 3]]
            A, B = self.get_AB(velocities, accelerations, self.actions[i])

            self.As.append(A)
            self.Bs.append(B)

        self.alined_AB()

    def show(self):
        """Print every flattened ``[A, B]`` vector followed by a blank line.

        Prints nothing when ``get_derivatives`` has not produced results yet.
        """
        for ab in self.ABs:
            print(ab)
            print()

    def get_x_dot(self, state, action):
        """Return ``np.array([xacc, thetaacc])`` from the analytic cart-pole model.

        Parameters
        ----------
        state : sequence of 4 floats
            ``(x, x_dot, theta, theta_dot)``.
        action : sequence
            One-element container holding the discrete action; it is not
            modified.
        """
        x, x_dot, theta, theta_dot = state
        force = self._action_to_force(action[0])
        costheta = math.cos(theta)
        sintheta = math.sin(theta)

        temp = (force + self.polemass_length * theta_dot ** 2 * sintheta) / self.total_mass
        thetaacc = (self.gravity * sintheta - costheta * temp) / (
            self.length * (4.0 / 3.0 - self.masspole * costheta ** 2 / self.total_mass)
        )
        xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass

        return np.array([xacc, thetaacc])

    def get_s_dot(self, s0, s1):
        """Forward-difference state derivative ``(s1 - s0) / tau`` as an ndarray."""
        return (np.asarray(s1) - np.asarray(s0)) / self.tau

    def get_AB(self, s, sd, a):
        """Least-squares fit of ``sd ~= A @ s + B @ u`` for a single sample.

        ``u`` is the signed force derived from the discrete action ``a[0]``.
        The caller's ``a`` is left untouched: it is the same list object that
        is stored in ``self.actions``, and overwriting it would corrupt the
        recorded actions for any later use.

        NOTE(review): one sample gives ``Ns`` residuals for
        ``Ns*Ns + Ns*Nc`` unknowns (2 vs. 6 with the defaults), so the
        minimiser is not unique and the returned values depend on the solver.

        Returns
        -------
        (A, B) : the ``.value`` of the two cvxpy variables after solving.
        """
        Ns = self.Ns
        Nc = self.Nc

        u = [self._action_to_force(a[0])]

        A = cp.Variable((Ns, Ns))
        B = cp.Variable((Ns, Nc))

        cost = cp.Minimize(cp.sum_squares(A @ s + B @ u - sd))
        prob = cp.Problem(cost)
        prob.solve()

        return A.value, B.value

    def alined_AB(self):
        """Rebuild ``ABs``: each entry is ``A`` and ``B`` flattened and concatenated."""
        self.ABs = [
            np.hstack((A.flatten(), B.flatten()))
            for A, B in zip(self.As, self.Bs)
        ]

In [4]:
# Roll out one episode with uniformly sampled actions and record every
# state/action pair in `collector`.
# NOTE(review): the action sampling is unseeded, so each run produces a
# different trajectory (and therefore different printed results).
env = gym.make('CartPole-v0')
num_steps = 1000

try:
    state = env.reset()

    collector = cartpole_collect(env)
    # The initial state goes in first so that states has one more entry than
    # actions, which get_derivatives() relies on (it reads states[i + 1]).
    collector.add_state(state)

    for i in range(num_steps):
        action = env.action_space.sample()
        state, reward, done, info = env.step(action)

        collector.add_state(state)
        collector.add_action(action)
        if done:
            break
finally:
    # Release the environment even if reset()/step() raises.
    env.close()

collector.get_derivatives()

In [6]:
# Print one flattened [A, B] vector per recorded step (one per action).
collector.show()

[  68.31251856 -120.40390653 -104.64751165  184.44597675    0.32809758
   -0.5026106 ]

[ 13.37140235  -9.96154752 -20.49415477  15.26791965   0.32722869
  -0.50153868]

[-7.38594375  5.17726788 10.75667856 -7.5400258   0.32576879 -0.4744404 ]

[ 13.34906629  -9.51003803 -20.82907586  14.83888829   0.32764149
  -0.51123196]

[-7.3585861   4.99613521 10.49292243 -7.12420273  0.32527732 -0.46382683]

[ 13.29693635  -8.80080252 -21.12107501  13.97934119   0.32804234
  -0.52106791]

[-7.32144664  4.73971527 10.20076952 -6.60371447  0.32474364 -0.45245634]

[-13.05944865   7.84577807  17.90449017 -10.75655337   0.32457181
  -0.4449876 ]

[-60.81894762  22.17126862  82.55644311 -30.09557298   0.32486835
  -0.44098059]

[-23.16837572  27.72946516  37.82365759 -45.26988894   0.32906141
  -0.537211  ]

[-57.95897768  15.88807321  78.62114312 -21.55211371   0.32477537
  -0.44055661]

[-23.61726342  54.40798145  38.61050414 -88.94847621   0.32924363
  -0.53826146]

[ 56.07791397 -12.51529941 -91.