In [1]:
import gym
import random
import collections
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from matplotlib import pyplot as plt

In [2]:
import CSTRenv_v2

# hyperpararameters

In [3]:
#Hyperparameters
lr_mu        = 0.0005
lr_q         = 0.001
gamma        = 0.99
batch_size   = 32
buffer_limit = 50000
tau          = 0.005 # for target network soft update

# Replay Buffer

In [4]:
class ReplayBuffer():
    def __init__(self):
        self.buffer = collections.deque(maxlen=buffer_limit)

    def put(self, transition):
        self.buffer.append(transition)
    
    def sample(self, n):
        mini_batch = random.sample(self.buffer, n)
        s_lst, a_lst, r_lst, s_prime_lst, done_mask_lst = [], [], [], [], []

        for transition in mini_batch:
            s, a, r, s_prime, done = transition
            s_lst.append(s)
            a_lst.append(a)
            r_lst.append([r])
            s_prime_lst.append(s_prime)
            done_mask = 0.0 if done else 1.0 
            done_mask_lst.append([done_mask])
        
        return torch.tensor(s_lst, dtype=torch.float), torch.tensor(a_lst, dtype=torch.float), \
                torch.tensor(r_lst, dtype=torch.float), torch.tensor(s_prime_lst, dtype=torch.float), \
                torch.tensor(done_mask_lst, dtype=torch.float)
    
    def size(self):
        return len(self.buffer)

# actor critic network

In [5]:
class MuNet(nn.Module):
    def __init__(self):
        super(MuNet, self).__init__()
        self.fc1 = nn.Linear(2, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc_mu = nn.Linear(64, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        mu = torch.tanh(self.fc_mu(x))  # range [-1,1]
        return mu

In [6]:
class QNet(nn.Module):
    def __init__(self):
        super(QNet, self).__init__()
        self.fc_s = nn.Linear(2, 64)
        self.fc_a = nn.Linear(1,64)
        self.fc_q = nn.Linear(128, 32)
        self.fc_out = nn.Linear(32,1)

    def forward(self, x, a):
        h1 = F.relu(self.fc_s(x))
        h2 = F.relu(self.fc_a(a))
        cat = torch.cat([h1,h2], dim=1)
        q = F.relu(self.fc_q(cat))
        q = self.fc_out(q)
        return q

In [7]:
class OrnsteinUhlenbeckNoise:
    def __init__(self, mu):
        self.theta, self.dt, self.sigma = 0.1, 0.01, 0.1
        self.mu = mu
        self.x_prev = np.zeros_like(self.mu)

    def __call__(self):
        x = self.x_prev + self.theta * (self.mu - self.x_prev) * self.dt + \
                self.sigma * np.sqrt(self.dt) * np.random.normal(size=self.mu.shape)
        self.x_prev = x
        return x

# train

In [8]:
def train(mu, mu_target, q, q_target, memory, q_optimizer, mu_optimizer):
    s,a,r,s_prime,done_mask  = memory.sample(batch_size)
    
    target = r + gamma * q_target(s_prime, mu_target(s_prime)) * done_mask
    q_loss = F.smooth_l1_loss(q(s,a), target.detach())
    q_optimizer.zero_grad()
    q_loss.backward()
    q_optimizer.step()
    
    mu_loss = -q(s,mu(s)).mean() # That's all for the policy loss.
    mu_optimizer.zero_grad()
    mu_loss.backward()
    mu_optimizer.step()

In [9]:
def soft_update(net, net_target):
    for param_target, param in zip(net_target.parameters(), net.parameters()):
        param_target.data.copy_(param_target.data * (1.0 - tau) + param.data * tau)

In [10]:
def main():
    env = CSTRenv_v2.CSTRenv()
    memory = ReplayBuffer()

    ## critic 
    q, q_target = QNet(), QNet()
    q_target.load_state_dict(q.state_dict())
    
    ## actor for Kp
    mu_p, mu_target_p = MuNet(), MuNet()
    mu_target_p.load_state_dict(mu_p.state_dict())
    
    ## actor for Ki
    mu_i, mu_target_i = MuNet(), MuNet()
    mu_target_i.load_state_dict(mu_i.state_dict())

    score = []
    print_interval = 20

    mu_p_optimizer = optim.Adam(mu_p.parameters(), lr=lr_mu)
    mu_i_optimizer = optim.Adam(mu_i.parameters(), lr=lr_mu)
    q_optimizer  = optim.Adam(q.parameters(), lr=lr_q)
    ou_noise = OrnsteinUhlenbeckNoise(mu=np.zeros(1))

    for n_epi in range(10000):
        s = env.reset()  # initial state
        done = False

        count = 0
        while count < 200 and not done:
            a = mu_p(torch.from_numpy(np.array(s)).float()) 
            a = [a[i].item() + ou_noise()[0] for i in range(len(a))]
            s_prime, r, done, update_phase = env.step(a)
            memory.put((s,a,r/100.0,s_prime,done))
            score.append(r)
            s = s_prime
            count += 1
                
            if memory.size()>2000:
                if update_phase == 'Kp':
                    for i in range(10):
                        train(mu_p, mu_target_p, q, q_target, memory, q_optimizer, mu_p_optimizer)
                        soft_update(mu_p, mu_target_p)
                        soft_update(q,  q_target)
                elif update_phase == 'Ki':
                    for i in range(10):
                        train(mu_i, mu_target_i, q, q_target, memory, q_optimizer, mu_i_optimizer)
                        soft_update(mu_i, mu_target_i)
                        soft_update(q,  q_target)                    
                
        
        if n_epi%print_interval==0 and n_epi!=0:
            print("# of episode :{}, avg score : {:.1f}".format(n_epi, sum(score[count-print_interval:count])/print_interval))

    env.close()

In [11]:
if __name__ == '__main__':
    main()

  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
  return torch.tensor(s_lst, dtype=torch.float), torch.tensor(a_lst, dtype=torch.float), \


# of episode :20, avg score : -0.4
# of episode :40, avg score : -0.4
# of episode :60, avg score : -0.4
# of episode :80, avg score : -0.4
# of episode :100, avg score : -0.4
# of episode :120, avg score : -0.4
# of episode :140, avg score : -0.4
# of episode :160, avg score : -0.4
# of episode :180, avg score : -0.4
# of episode :200, avg score : -0.4
# of episode :220, avg score : -0.4
# of episode :240, avg score : -0.4
# of episode :260, avg score : -0.4
# of episode :280, avg score : -0.4
# of episode :300, avg score : -0.4
# of episode :320, avg score : -0.4
# of episode :340, avg score : -0.4
# of episode :360, avg score : -0.4
# of episode :380, avg score : -0.4


KeyboardInterrupt: 