In [3]:
import numpy as np
from scipy.integrate import odeint


class PIDController:
    """
        Discrete-time PID controller.

        v = latest measured plant output (the controlled variable)
        u = control signal to be fed into the plant

        Gains must be supplied through set_k_constants() before the first
        call to get_pid_output().
    """

    def __init__(self, dt, initial_error, set_point):
        self.dt = dt
        self.set_point = set_point
        # Error treated as the "previous" sample by the first derivative term.
        self.prev_error = initial_error
        # Running rectangle-rule integral of the error.
        self.sum_error = 0

    def set_set_point(self, sp):
        """Change the reference value the controller tracks."""
        self.set_point = sp

    def set_k_constants(self, Ki, Kp, Kd):
        """Store the gains; note the (Ki, Kp, Kd) order used by get_pid_output."""
        self.params = (Ki, Kp, Kd)

    def get_pid_output(self, v):
        """
            Compute the next control signal u(t) from the newest measurement v(t).

            Returns (u, error, previous_error, d_error):
              error          = set_point - v
              previous_error = error remembered from the preceding call
                               (initial_error on the first call)
              d_error        = (error - previous_error) / dt
        """
        ki, kp, kd = self.params
        step = self.dt

        err = self.set_point - v
        self.sum_error += err*step
        slope = (err - self.prev_error)/step

        control = kp*err + ki*self.sum_error + kd*slope

        # Hand back the old error and remember the new one for the next call.
        last_err, self.prev_error = self.prev_error, err
        return control, err, last_err, slope
    

class Model:
    """
        Second-order linear plant integrated with scipy's odeint.

        The plant obeys
            y'' = -a1*y' - a0*y + b0*x
        where x is the plant input (held constant at self.u during each
        integration call) and (a1, a0, b0) = plant_coeffs.

        u      = Plant input
        y_caps = history of plant states [y, dy/dt]; y (index 0) is the
                 plant output. The first entry is the initial state.
    """

    # (a1, a0, b0) of the originally hard-coded plant. Written with the same
    # expressions as before so the default results are bit-identical.
    DEFAULT_PLANT_COEFFS = (831.6, 2.133*10**7, 2.109*10**6)

    def __init__(self, initial_y_caps, plant_coeffs=DEFAULT_PLANT_COEFFS):
        """
            initial_y_caps : initial state (y, dy/dt).
            plant_coeffs   : (a1, a0, b0) for y'' = -a1*y' - a0*y + b0*x.
                             Defaults to the original plant. An alternative
                             previously kept as commented-out code corresponds
                             to (4.826*10**5, 4.826*10**5, 2.109*10**6).
        """
        coeffs = tuple(plant_coeffs)
        if len(coeffs) != 3:
            raise ValueError("plant_coeffs must be a 3-tuple (a1, a0, b0)")
        self.y_caps = [initial_y_caps]
        self.u = 0
        self.plant_coeffs = coeffs

    def de_solver(self, y_cap, t, params):
        """
            Right-hand side for odeint.

            Given the state <y, y1> (y1 = dy/dt), returns <y1, y2> with
                y2 = -a1*y1 - a0*y + b0*x
            where x = params is the plant input. t is not used (the plant
            coefficients are constant) but odeint passes it.
        """
        y, y1 = y_cap  # y, dy/dt
        x = params
        a1, a0, b0 = self.plant_coeffs
        y2 = -a1*y1 - a0*y + b0*x
        return y1, y2

    def set_model_input(self, u_):
        """Set the input applied during subsequent integrations."""
        self.u = u_

    def get_model_output(self, t_step):
        """
            Integrate the plant over the time points in t_step.

            Integration starts from the most recent stored state, with the
            input held at self.u. The final state is appended to y_caps, and
            its y component (the plant output) is returned.
        """
        trajectory = odeint(self.de_solver, self.y_caps[-1], t_step, args=(self.u,))
        self.y_caps.append(trajectory[-1])
        return self.y_caps[-1][0]


class PIDModel:
    """
        RL-style environment wrapping a PID controller around the plant `Model`.

        Action space is (Kd', Kp', alpha). Gains are applied as
            kp = Kp',  kd = Kd',  ki = alpha*kd
        kpmin/kpmax and kdmin/kdmax are computed in __init__ from ku and tu.
        The de-normalisation that would map Kp'/Kd' into those ranges is
        currently disabled in step(), so the action values are used as raw
        gains.
    """
    def __init__(self, ku, tu, t, SP):
        """
            ku, tu : presumably the plant's ultimate gain and period
                     (Ziegler-Nichols naming) - confirm with the caller.
                     Only used to compute the kp/kd bounds.
            t      : sequence of time points. Each step() integrates over
                     [t[count], t[count+1]], and dt = t[1] - t[0].
            SP     : set-point sequence indexed by step count; needs at
                     least len(t) entries.
        """
        self.ku = ku
        self.tu = tu
        self.kpmax = 0.6*ku
        self.kpmin = 0.32*ku
        self.kdmax = 0.15*ku*tu
        self.kdmin = 0.08*ku*tu

        self.setpoint = SP
        self.t = t
        self._start_episode()

    def _start_episode(self):
        """Reset the step counter and build a fresh plant and PID controller."""
        self.count = 0
        self.model = Model(initial_y_caps=(0, 0))
        self.pid = PIDController(dt=self.t[1]-self.t[0], initial_error=0,
                                 set_point=self.setpoint[0])

    def get_next_count(self):
        """Used to update the time-step through count"""
        self.count = self.count + 1
        return self.count

    def get_reward(self, e_t, prev_e_t, de_t):
        """
            Reward = alpha1*r1 + alpha2*r2 (alpha1 = alpha2 = 1) with
              r1 = exp(-|e_t - prev_e_t|)          in [0, 1]
              r2 = 1 - tanh(de_t)   if de_t >= 0   in [0, 1]
                   tanh(de_t)       if de_t <  0   in (-1, 0)

            NOTE(review): r2 jumps from ~0 to 1 at de_t == 0 and is negative
            for every falling error. If a reward peaked at de_t == 0 was
            intended, 1 - tanh(|de_t|) would be the symmetric form. Behaviour
            is deliberately left unchanged here.
        """
        alpha1 = 1
        alpha2 = 1
        r1 = np.exp(-np.abs(e_t-prev_e_t))
        if de_t >= 0:
            r2 = 1 - np.tanh(de_t)
        else:
            r2 = np.tanh(de_t)

        r = alpha1*r1 + alpha2*r2
        return r

    def step(self, action):
        """
            Advance the simulation by one time step.

            Supports:  new_state, reward, done = env.step(action)

            action    : (kd_, kp_, alpha)
            new_state : (kd_, kp_, alpha, e_t, de_t)
            done      : True once the last interval of `t` has been simulated.

            Raises RuntimeError if called after the episode is done; call
            reset() first.
        """
        if self.count + 1 >= len(self.t):
            # Without this guard the plant would get a truncated time slice
            # and the PID state would be updated before the set-point lookup
            # failed with IndexError, leaving the episode half-advanced.
            raise RuntimeError("episode is done; call reset() before step()")

        kd_, kp_, alpha = action
        # De-normalisation into [kpmin, kpmax] / [kdmin, kdmax] is disabled;
        # the raw action values are the gains. Disabled mapping:
        #   kp = (self.kpmax - self.kpmin)*kp_ + self.kpmin
        #   kd = (self.kdmax - self.kdmin)*kd_ + self.kdmin
        kp = kp_
        kd = kd_
        ki = alpha*kd
        self.pid.set_k_constants(ki, kp, kd)

        # Integrate the plant across [t[count], t[count+1]] using the input
        # produced by the previous step.
        v = self.model.get_model_output(t_step=self.t[self.count: self.count+2])

        # PID output, current error, previous error and error derivative for
        # the newest plant output.
        u, e_t, prev_e_t, de_t = self.pid.get_pid_output(v)

        # u drives the plant on the next step; advance the set-point (and the
        # step counter) for that step as well.
        self.model.set_model_input(u)
        self.pid.set_set_point(self.setpoint[self.get_next_count()])

        reward = self.get_reward(e_t, prev_e_t, de_t)
        new_state = kd_, kp_, alpha, e_t, de_t
        done = (self.count + 1) == len(self.t)
        return new_state, reward, done

    def reset(self):
        """
            Restart the episode: step counter back to 0, a fresh plant with
            state (0, 0) and input 0, and a fresh PID controller. Gains are
            set again by the next step().
        """
        self._start_episode()

    def output(self):
        """
            Return the plant state history: a list of [y, dy/dt] entries,
            where y (index 0) is the plant output.
        """
        return self.model.y_caps


if __name__ == "__main__":
    # Demo: ku=1, tu=0.1, 1000 time points on [0, 200], constant set-point 1000.
    time_grid = np.linspace(0, 200, num=1000)
    target = np.ones(1000)*1000
    env = PIDModel(1, 0.1, t=time_grid, SP=target)

    # (kd', kp', alpha) held fixed for every step of the demo.
    action = (0.1, 0.1, 3)
    for i in range(500):
        new_state, reward, done = env.step(action)
        print(new_state, reward)
        if done:
            break




(0.1, 0.1, 3, 1000.0, 4995.0) 0.0
(0.1, 0.1, 3, 931.986478326346, -339.72754075990196) -1.0
(0.1, 0.1, 3, 980.9966323755813, 244.80571947593043) 5.189918514815713e-22
(0.1, 0.1, 3, 970.1259038080957, -54.29928919459042) -0.9999809934816672
(0.1, 0.1, 3, 967.8604121287112, -11.316130938525415) -0.8962210037018901
(0.1, 0.1, 3, 962.3676708932622, -27.436242471067697) -0.9958834557236067
(0.1, 0.1, 3, 956.1362703331093, -31.125845797963763) -0.9980333045085036
(0.1, 0.1, 3, 950.1429718182104, -29.936526081920128) -0.9975045807177491
(0.1, 0.1, 3, 944.3711750546992, -28.83012483373868) -0.9968858429194664
(0.1, 0.1, 3, 938.9999932910355, -26.829052909500106) -0.9953513655243658
(0.1, 0.1, 3, 933.2601910450647, -28.670312218623828) -0.9967845959284499
(0.1, 0.1, 3, 927.901382895106, -26.767246709043658) -0.9952934877771175
(0.1, 0.1, 3, 922.4106014282675, -27.426453426858796) -0.9958753803389925
(0.1, 0.1, 3, 917.211834032445, -25.96784314213296) -0.9944766316464717
(0.1, 0.1, 3, 912.236789

(0.1, 0.1, 3, 496.6182995685814, -16.139817324657244) -0.9604897307396857
(0.1, 0.1, 3, 493.8912192142991, -13.621766369640055) -0.9345900148629969
(0.1, 0.1, 3, 490.8006379107565, -15.437453611195247) -0.9545244883562107
(0.1, 0.1, 3, 488.075522060665, -13.611953671207175) -0.934461390372716
(0.1, 0.1, 3, 485.01158106362243, -15.304385280227539) -0.9532967254088337
(0.1, 0.1, 3, 482.3580961177297, -13.254157304734177) -0.9295945740620193
(0.1, 0.1, 3, 479.34814750942473, -15.034693298483308) -0.9507057879913473
(0.1, 0.1, 3, 477.1017618913214, -11.220696162426073) -0.8942191327450525
(0.1, 0.1, 3, 473.82604199043885, -16.362220904908405) -0.9622103453711138
(0.1, 0.1, 3, 470.95987203827553, -14.316518911055773) -0.9430834974210618
(0.1, 0.1, 3, 468.198061288556, -13.79524469484905) -0.9368227334916969
(0.1, 0.1, 3, 465.4408907288048, -13.772066945957326) -0.9365288977737789
(0.1, 0.1, 3, 462.6943751178168, -13.718845476885122) -0.9358490007295721
(0.1, 0.1, 3, 460.547626046308, -10.72

(0.1, 0.1, 3, 250.5942107278155, -7.412065072587497) -0.7732469398810715
(0.1, 0.1, 3, 249.12134804802702, -7.356949085543443) -0.7707309621875124
(0.1, 0.1, 3, 247.86202840719181, -6.29030160597183) -0.7161460451223992
(0.1, 0.1, 3, 246.15637720711413, -8.519727744388021) -0.8183458658881615
(0.1, 0.1, 3, 244.7826405536772, -6.861814583917494) -0.7468385858370863
(0.1, 0.1, 3, 243.30568746543133, -7.37738067578809) -0.7716668800961322
(0.1, 0.1, 3, 241.88773085628634, -7.082693262679245) -0.7577901523654732
(0.1, 0.1, 3, 240.47236019792865, -7.06977643849665) -0.7571629658365292
(0.1, 0.1, 3, 239.09110868648668, -6.899351299652659) -0.7487340695408331
(0.1, 0.1, 3, 237.72269802209019, -6.835211268660479) -0.7454865458236821
(0.1, 0.1, 3, 236.32931420537977, -6.959952164468504) -0.7517642944310379
(0.1, 0.1, 3, 234.92344121751728, -7.02233557437316) -0.7548454621379619
(0.1, 0.1, 3, 233.50537558513633, -7.083237833742817) -0.7578165588081307
(0.1, 0.1, 3, 232.10356398044974, -7.0020489

(0.1, 0.1, 3, 128.69562968389243, -4.007973146865451) -0.5510867060376314
(0.1, 0.1, 3, 128.1041877581779, -2.954252418944168) -0.4410536638871134
(0.1, 0.1, 3, 127.16982849623162, -4.667124513421617) -0.6069858326620006
(0.1, 0.1, 3, 126.46887944286323, -3.501240521575109) -0.5020681672159432
(0.1, 0.1, 3, 125.75784485405188, -3.5516177711126646) -0.5072206616446528
(0.1, 0.1, 3, 125.06029393090989, -3.484266861094236) -0.5003167361617853
(0.1, 0.1, 3, 124.28266853789899, -3.8842388380894843) -0.5396588055933995
(0.1, 0.1, 3, 123.52987467215132, -3.760205359409574) -0.5278681012539157
(0.1, 0.1, 3, 122.78386687380942, -3.7263089527178246) -0.5245847248381106
(0.1, 0.1, 3, 122.0476025247741, -3.6776404234313698) -0.5198228210932048
(0.1, 0.1, 3, 121.44741864974264, -2.997918455782152) -0.44632344188491324
(0.1, 0.1, 3, 120.59398504127466, -4.262900874297563) -0.5736536341849923
(0.1, 0.1, 3, 119.95812449050027, -3.176123451118108) -0.4670407626891726
(0.1, 0.1, 3, 119.21142353469781, -

(0.1, 0.1, 3, 64.9779599719053, -1.7321108768831777) -0.23233623176694007
(0.1, 0.1, 3, 64.5722927269768, -2.0263078884179015) -0.2993079510740506
(0.1, 0.1, 3, 64.19988919424316, -1.8601556460045434) -0.26361710207617184
(0.1, 0.1, 3, 63.79309361265746, -2.0319439300205593) -0.3004360252851086
(0.1, 0.1, 3, 63.44498852584411, -1.738784908632673) -0.2340608820097051
(0.1, 0.1, 3, 63.04360303205533, -2.004920541474956) -0.29498161150180635
(0.1, 0.1, 3, 62.69513007899616, -1.7406224005305648) -0.23453377681212317
(0.1, 0.1, 3, 62.32589541167863, -1.844327163251052) -0.2599456225223261
(0.1, 0.1, 3, 61.96894393657283, -1.7829726181534822) -0.2452075124204135
(0.1, 0.1, 3, 61.602554385419126, -1.8301158080127418) -0.2566047049879413
(0.1, 0.1, 3, 61.266275607100056, -1.6797124977037532) -0.2184003222964489
(0.1, 0.1, 3, 60.88779097720612, -1.8905307263201945) -0.2705209135349681
(0.1, 0.1, 3, 60.5298160082732, -1.7880849698199448) -0.24646751891330798
(0.1, 0.1, 3, 60.17763150824078, -1.7