#### Import Libraries & Params

In [None]:
import gym
import pybulletgym
import numpy as np
import collections
import random
import torch
from torch._C import Size
from torch.distributions import Normal
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
import time
from copy import deepcopy
import ray

from IPython import display
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Params for SAC Algorithm
env = gym.make('AntPyBulletEnv-v0')

# Replay buffer and optimisation hyper-parameters
buffer_limit = 10 ** 6          # maximum number of stored transitions
batch_size = 256
lr_q, lr_pi, lr_alpha = 3e-4, 3e-4, 3e-4
gamma = 0.99                    # discount factor
tau = 0.005                     # soft target-update rate

# Entropy temperature: starting value and target entropy (-|action dims|)
init_alpha = 0.1
target_entropy = -env.action_space.shape[0]

# Parallelism (Ray)
n_workers = 6
n_cpu = 6
gradient_step = n_workers * 1

# Params for run & save model
folder = 0
device = torch.device('cpu')

#### ReplayBuffer & Actor & Critic Class & RAY

In [None]:
class ReplayBuffer:
    """FIFO experience-replay memory of ``(s, a, r, s_prime, done)`` tuples.

    Capacity is the module-level ``buffer_limit``; once it is reached the
    oldest transitions are evicted. Sampled batches are float tensors placed
    on the module-level ``device``.
    """

    def __init__(self):
        self.buffer = collections.deque(maxlen=buffer_limit)

    def put(self, item):
        """Append one transition tuple ``(s, a, r, s_prime, done)``."""
        self.buffer.append(item)

    def sample(self, n):
        """Draw ``n`` distinct transitions uniformly at random.

        Returns a 5-tuple of float tensors: states, actions, rewards (one
        column), next states, and a "not done" mask (one column; 0.0 where
        ``done`` is truthy, otherwise 1.0).
        """
        batch = random.sample(self.buffer, n)
        columns = (
            [state for state, _, _, _, _ in batch],
            [action for _, action, _, _, _ in batch],
            [[reward] for _, _, reward, _, _ in batch],
            [next_state for _, _, _, next_state, _ in batch],
            [[0.0 if terminal else 1.0] for _, _, _, _, terminal in batch],
        )
        return tuple(
            torch.tensor(np.array(column), dtype=torch.float).to(device)
            for column in columns
        )

    def size(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

In [None]:
class Actor(nn.Module):
    """Tanh-squashed Gaussian policy for SAC with a learnable temperature.

    The network maps a state to the mean and standard deviation of a diagonal
    Gaussian. A reparameterised sample is squashed by ``tanh`` into (-1, 1).
    The entropy temperature ``alpha`` is stored as ``log_alpha`` and tuned
    toward the module-level ``target_entropy``.

    Args:
        learning_rate: Adam learning rate for the policy network.
        state_dim: observation size (default 28, the original hard-coded value).
        action_dim: action size (default 8, the original hard-coded value).
    """

    def __init__(self, learning_rate, state_dim=28, action_dim=8):
        super(Actor, self).__init__()
        # Gaussian policy head
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc_mean = nn.Linear(256, action_dim)
        self.fc_std = nn.Linear(256, action_dim)
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

        # Autotuning alpha. NOTE: log_alpha is a plain tensor rather than an
        # nn.Parameter, so it is not part of state_dict() and not moved by .to().
        self.log_alpha = torch.tensor(np.log(init_alpha))
        self.log_alpha.requires_grad = True
        self.log_alpha_optimizer = optim.Adam([self.log_alpha], lr=lr_alpha)

    def forward(self, x):
        """Sample an action for state(s) ``x``.

        Returns:
            (action, log_prob). ``action`` is ``tanh`` of a reparameterised
            Gaussian sample, so gradients flow through it. ``log_prob`` is the
            log-density of the squashed action summed over the action
            dimensions, keeping a trailing dim of size 1: ``(batch, 1)`` for a
            batched input, ``(1,)`` for a single state.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        mean = self.fc_mean(x)
        # softplus keeps std positive; the floor stops it underflowing to 0,
        # which Normal rejects as an invalid scale.
        std = F.softplus(self.fc_std(x)).clamp(min=1e-6)
        gaussian = Normal(mean, std)
        action = gaussian.rsample()
        # action range : -1 ~ 1
        real_action = torch.tanh(action)
        # tanh change of variables: log pi(a) = log N(u) - log(1 - tanh(u)^2)
        log_prob = gaussian.log_prob(action) - torch.log(1 - real_action.pow(2) + 1e-7)
        # The joint density of the action vector is the product over dims, so
        # its log is the sum. SAC's entropy terms and target_entropy (-|A|)
        # are defined on this joint quantity, not per dimension.
        real_log_prob = log_prob.sum(dim=-1, keepdim=True)
        return real_action, real_log_prob

    def train_p(self, q1, q2, mini_batch):
        """Take one gradient step on the policy and one on the temperature.

        Args:
            q1, q2: critics called as ``q(s, a)``, each returning one Q-value
                column per sample.
            mini_batch: tuple ``(s, a, r, s_prime, done)``; only ``s`` is used.
        """
        s, _, _, _, _ = mini_batch
        a, log_prob = self.forward(s)
        # alpha is a constant for the policy loss; it is optimised below.
        alpha = self.log_alpha.exp().detach()
        entropy = -alpha * log_prob

        # Clipped double-Q: use the smaller of the two critic estimates.
        min_q = torch.min(q1(s, a), q2(s, a))

        loss = -min_q - entropy
        self.optimizer.zero_grad()
        loss.mean().backward()
        self.optimizer.step()

        # Temperature loss: its gradient raises alpha while the policy entropy
        # (-log_prob) is below target_entropy and lowers it otherwise.
        self.log_alpha_optimizer.zero_grad()
        alpha_loss = -(self.log_alpha.exp() * (log_prob + target_entropy).detach()).mean()
        alpha_loss.backward()
        self.log_alpha_optimizer.step()

# Module-level policy instance; play_ant_render() acts with this global.
pi = Actor(learning_rate=lr_pi)
print('Actor Loaded')

In [None]:
class Critic(nn.Module):
    """State-action value network Q(s, a), used as one of SAC's twin critics.

    State and action are embedded separately (128 units each), concatenated,
    and passed through a 256-unit hidden layer to a single Q-value.

    Args:
        learning_rate: Adam learning rate for this critic.
        state_dim: observation size (default 28, the original hard-coded value).
        action_dim: action size (default 8, the original hard-coded value).
    """

    def __init__(self, learning_rate, state_dim=28, action_dim=8):
        super(Critic, self).__init__()
        self.fc_s = nn.Linear(state_dim, 128)
        self.fc_a = nn.Linear(action_dim, 128)
        self.fc_cat = nn.Linear(256, 256)
        self.fc_out = nn.Linear(256, 1)
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

    def forward(self, x, a):
        """Return Q(x, a); for 2-D batched inputs the result is ``(batch, 1)``."""
        x = F.relu(self.fc_s(x))
        a = F.relu(self.fc_a(a))
        # Embeddings are joined along the feature dimension (expects batched input).
        cat = torch.cat([x, a], dim=1)
        q = F.relu(self.fc_cat(cat))
        return self.fc_out(q)

    def train_q(self, target, mini_batch):
        """Take one gradient step regressing Q(s, a) onto ``target`` (Huber loss).

        ``target`` is treated as a constant and should carry no gradient
        history. ``mini_batch`` is ``(s, a, r, s_prime, done)``; only ``s``
        and ``a`` are used.
        """
        s, a, _, _, _ = mini_batch
        # smooth_l1_loss already reduces to a scalar mean.
        loss = F.smooth_l1_loss(self.forward(s, a), target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    # DDPG soft_update
    def soft_update(self, net_target):
        """Polyak-average this network into ``net_target`` in place.

        target <- (1 - tau) * target + tau * online, with the module-level ``tau``.
        """
        with torch.no_grad():
            for param_target, param in zip(net_target.parameters(), self.parameters()):
                param_target.mul_(1.0 - tau).add_(param, alpha=tau)

In [None]:
def get_target(pi, q1, q2, mini_batch):
    """Compute the soft Bellman target for the twin critics.

        y = r + gamma * mask * (min(Q1(s', a'), Q2(s', a')) - alpha * log pi(a'|s'))

    where ``a'`` is sampled from ``pi`` at ``s'`` and ``mask`` is the fifth
    element of ``mini_batch``. ReplayBuffer.sample fills it with 0.0 for
    terminal and 1.0 for non-terminal transitions. Everything runs under
    ``torch.no_grad()``, so the result can be used as a regression constant.

    Returns:
        One target per sample; ``(batch, 1)`` when rewards, mask and critic
        outputs are ``(batch, 1)``.
    """
    s, a, r, s_prime, done = mini_batch
    with torch.no_grad():
        a_prime, log_prob = pi(s_prime)
        # Joint log-density over action dims (a no-op if pi already returns
        # a single column per sample).
        log_prob = log_prob.sum(dim=-1, keepdim=True)
        entropy = -pi.log_alpha.exp() * log_prob
        min_q = torch.min(q1(s_prime, a_prime), q2(s_prime, a_prime))
        # Entropy bonus is per sample, not averaged over the batch.
        target = r + gamma * done * (min_q + entropy)
    return target

In [None]:
class RolloutWorkerClass(object):
    """
    Worker without RAY (for update purposes)

    Holds the learner-side SAC networks: twin critics ``q1``/``q2``, their
    target copies ``q1_target``/``q2_target``, and the policy ``pi``.
    """
    def __init__(self, device, seed=2000):
        self.seed = seed
        self.env = env

        # Seed every RNG *before* building the networks so that their random
        # initialisation is reproducible for a given seed.
        torch.manual_seed(self.seed)
        np.random.seed(self.seed)
        random.seed(self.seed)

        # Create SAC model and target networks
        self.q1 = Critic(lr_q).to(device)
        self.q2 = Critic(lr_q).to(device)
        self.q1_target = Critic(lr_q).to(device)
        self.q2_target = Critic(lr_q).to(device)
        self.pi = Actor(lr_pi).to(device)
        # Target critics start as exact copies of the online critics.
        self.q1_target.load_state_dict(self.q1.state_dict())
        self.q2_target.load_state_dict(self.q2.state_dict())

    def get_weights(self):
        """Return critic state dicts as ``(q1, q2, q1_target, q2_target)``.

        The policy's weights are not included; use ``self.pi.state_dict()``.
        """
        return self.q1.state_dict(), self.q2.state_dict(), self.q1_target.state_dict(), self.q2_target.state_dict()

    def set_weights(self, q1_weight, q2_weight, q1_target_weight, q2_target_weight):
        """Load critic state dicts in the order returned by ``get_weights``."""
        self.q1.load_state_dict(q1_weight)
        self.q2.load_state_dict(q2_weight)
        self.q1_target.load_state_dict(q1_target_weight)
        self.q2_target.load_state_dict(q2_target_weight)

@ray.remote
class RayRolloutWorkerClass(object):
    """
    Worker with RAY (for rollout)

    Each Ray actor owns its own copies of the SAC networks and collects one
    episode per ``rollout`` call, acting with its local policy ``self.pi``.
    """
    def __init__(self, device, worker_id=0, seed=1):
        self.seed = seed
        self.worker_id = worker_id
        self.env = env

        # Seed before building the networks, so both initialisation and the
        # policy's sampling noise depend on ``seed``.
        torch.manual_seed(self.seed)
        np.random.seed(self.seed)
        random.seed(self.seed)

        # Create SAC model and target networks
        self.q1 = Critic(lr_q).to(device)
        self.q2 = Critic(lr_q).to(device)
        self.q1_target = Critic(lr_q).to(device)
        self.q2_target = Critic(lr_q).to(device)
        self.pi = Actor(lr_pi).to(device)
        self.q1_target.load_state_dict(self.q1.state_dict())
        self.q2_target.load_state_dict(self.q2.state_dict())

    def set_weights(self, q1_weight, q2_weight, q1_target_weight, q2_target_weight):
        """Load critic state dicts (the policy is refreshed through ``rollout``)."""
        self.q1.load_state_dict(q1_weight)
        self.q2.load_state_dict(q2_weight)
        self.q1_target.load_state_dict(q1_target_weight)
        self.q2_target.load_state_dict(q2_target_weight)

    def get_env(self):
        return self.env

    def rollout(self, memory, rollout):
        """Run one episode of at most 1000 steps with the local policy.

        Args:
            memory: unused; kept for call compatibility.
            rollout: optional policy source used to refresh ``self.pi`` before
                acting. Either a policy ``state_dict`` or an object exposing a
                ``pi`` module (such as the learner's RolloutWorkerClass).
                ``None`` keeps the current local policy.

        Returns:
            ``(score, step, small_buffer)``: the summed reward, the number of
            env steps taken, and a deque of ``(s, a, reward, s_prime, done)``
            transitions.
        """
        # Without this sync the workers would keep acting with their randomly
        # initialised policy and never see the learner's updates.
        self._sync_policy(rollout)

        s = self.env.reset()
        score = 0
        step = 0
        small_buffer = collections.deque(maxlen=1000)

        for _ in range(1000):
            # Acting only: no autograd graph is needed.
            with torch.no_grad():
                a, _ = self.pi(torch.from_numpy(s).float().to(device))
            a_ = a.tolist()
            # self.env.render()
            s_prime, reward, done, info = self.env.step(a_)
            small_buffer.append((s, a_, reward, s_prime, done))
            score += reward
            step += 1
            s = s_prime

            # Truthiness rather than `is True`: envs may return numpy.bool_.
            if done:
                break

        return score, step, small_buffer

    def _sync_policy(self, source):
        """Copy policy weights from ``source`` into ``self.pi`` when provided."""
        if source is None:
            return
        if isinstance(source, dict):
            self.pi.load_state_dict(source)
        elif hasattr(source, 'pi'):
            self.pi.load_state_dict(source.pi.state_dict())

#### Function to render pybulletgym & SAC algorithm

In [None]:
def play_ant_render():
    """Run one AntPyBulletEnv episode with the global policy ``pi``.

    Each frame is displayed inline in the notebook. At the end, the total
    episode reward is printed. The environment is closed even if rendering
    is interrupted.
    """
    global pi
    env = gym.make('AntPyBulletEnv-v0')
    s = env.reset()
    score = 0
    done = False
    plt.figure(figsize=(9, 9))
    img = plt.imshow(env.render(mode='rgb_array'))
    try:
        # `not done` instead of `is not True`: envs may return numpy.bool_,
        # which never compares identical to True and would loop forever.
        while not done:
            img.set_data(env.render(mode='rgb_array'))
            display.display(plt.gcf())
            display.clear_output(wait=True)
            with torch.no_grad():
                a, _ = pi(torch.from_numpy(s).float())
            s, r, done, info = env.step(a.tolist())
            score += r
    finally:
        env.close()
    print("Rendering is Finished")
    print("Final Ant Score : {}".format(score))

In [None]:
def main():
    """Train SAC on AntPyBulletEnv with parallel Ray rollout workers.

    Each iteration:
      1. Every worker collects one episode with the learner's current policy.
      2. The transitions are appended to the replay buffer.
      3. Once the buffer holds more than 30000 transitions,
         ``gradient_update`` SAC updates are performed.

    Every ``print_interval`` iterations, statistics are logged and the policy
    weights are saved under ``weights_<folder>/``. Training stops after 2500
    iterations. The global ``pi`` is then rebound to the trained policy so
    that ``play_ant_render()`` shows it.
    """
    global pi

    # For RAY; ignore_reinit_error lets this cell be re-run in the same kernel.
    ray.init(num_cpus=n_cpu,
             _memory=5 * 1024 * 1024 * 1024,
             object_store_memory=10 * 1024 * 1024 * 1024,
             _driver_object_store_memory=1 * 1024 * 1024 * 1024,
             ignore_reinit_error=True)

    R = RolloutWorkerClass(device, seed=0)
    # Distinct seeds per worker; with the shared default seed every worker
    # would draw the same policy noise.
    workers = [RayRolloutWorkerClass.remote(device, worker_id=i, seed=i + 1)
               for i in range(n_workers)]
    print("RAY initialized with [%d] cpus and [%d] workers." % (n_cpu, n_workers))

    memory = ReplayBuffer()

    score = 0        # accumulated worker-averaged episode return
    bestsc = 0       # best single-worker episode return seen so far
    print_interval = 100
    gradient_update = 30
    step = 0         # accumulated worker-averaged episode length

    for episodes in range(1, 100000):
        q1_weight, q2_weight, q1_target_weight, q2_target_weight = R.get_weights()
        for worker in workers:
            worker.set_weights.remote(q1_weight, q2_weight, q1_target_weight, q2_target_weight)

        # Send the learner's current policy weights (stored once, shared by
        # all workers) instead of the replay buffer, which workers never read.
        pi_ref = ray.put(R.pi.state_dict())
        ops = [worker.rollout.remote(None, pi_ref) for worker in workers]
        rollout_vals = ray.get(ops)

        for score_, step_, small_buffer in rollout_vals:
            memory.buffer.extend(small_buffer)
            score += score_ / n_workers
            step += step_ / n_workers
            bestsc = max(bestsc, score_)  # best among all workers

        if memory.size() > 30000:
            for _ in range(gradient_update):
                mini_batch = memory.sample(batch_size)
                td_target = get_target(R.pi, R.q1_target, R.q2_target, mini_batch)
                R.q1.train_q(td_target, mini_batch)
                R.q2.train_q(td_target, mini_batch)
                R.pi.train_p(R.q1, R.q2, mini_batch)
                R.q1.soft_update(R.q1_target)
                R.q2.soft_update(R.q2_target)

        if episodes % print_interval == 0:
            # Report and save the policy actually being trained (R.pi).
            alpha = R.pi.log_alpha.exp().item()
            print("number of episode :{}, avg score :{:.1f}, best score :{:.1f}, avg step :{:.1f}, alpha:{:.4f}".format(
                episodes, score / print_interval, bestsc, step / print_interval, alpha))
            os.makedirs("weights_{}".format(folder), exist_ok=True)
            torch.save(R.pi.state_dict(), 'weights_{}/model_weights_{}.pth'.format(folder, episodes))
            score = 0
            step = 0

        if episodes >= 2500:
            break

    env.close()
    ray.shutdown()
    # play_ant_render() acts with the global `pi`; point it at the trained policy.
    pi = R.pi

#### Main loop and render

In [None]:
# Train SAC; this call blocks until main() stops (at most 2500 iterations).
main()

In [None]:
# Render one episode inline using the module-level policy `pi`.
play_ant_render()