In [42]:
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
from torch.nn.init import kaiming_uniform_
from torch.distributions import Normal
from scipy.ndimage.filters import uniform_filter1d
from matplotlib import pyplot as plt

In [2]:
# Scratch check of torch.normal: draws one sample per element, using the
# element-wise means 1..5 and standard deviations 1..5 given below.
torch.normal(mean=torch.arange(1., 6.), std=torch.arange(1., 6.))

tensor([ 0.5581, -1.5493,  5.1617, -0.2754,  5.3021])

In [3]:
import numpy as np
import gym
from tqdm import tqdm
import random as rand
from itertools import count

In [4]:
class ReplayMemory():
    """Fixed-capacity ring buffer of experience records.

    Experiences are appended until ``capacity`` is reached; after that each
    push overwrites the slot ``push_count % capacity``.  Records passed to
    ``update_td_error`` must expose a hashable ``timestep`` attribute.
    """

    def __init__(self,capacity):
        self.capacity = capacity
        self.memory = []        # stored experiences, at most ``capacity`` long
        self.push_count = 0     # total number of pushes ever made

    def push(self, experience):
        """Store ``experience``, overwriting a previous slot once the buffer is full."""
        if len(self.memory) < self.capacity:
            self.memory.append(experience)
        else:
            self.memory[self.push_count%self.capacity] = experience
        self.push_count+=1

    def sample(self, batch_size):
        """Return ``batch_size`` distinct experiences drawn uniformly at random.

        Raises ValueError (from ``random.sample``) when fewer than
        ``batch_size`` experiences are stored.
        """
        return rand.sample(self.memory,batch_size)

    def can_provide_sample(self, batch_size):
        """Return True when at least ``batch_size`` experiences are stored."""
        return len(self.memory)>=batch_size

    def update_td_error(self, sampled_experiences):
        """Replace stored experiences by the given ones, matched on ``timestep``.

        For each sampled experience the first stored record with an equal
        ``timestep`` is overwritten; samples with no match are ignored.
        """
        # Index the buffer once instead of rescanning it for every sample.
        # setdefault keeps the FIRST slot seen for a timestep, which is the
        # slot the original linear search would have stopped at.
        slot_of_timestep = {}
        for mem_idx, mem_exp in enumerate(self.memory):
            slot_of_timestep.setdefault(mem_exp.timestep, mem_idx)

        for sampled_exp in sampled_experiences:
            mem_idx = slot_of_timestep.get(sampled_exp.timestep)
            if mem_idx is not None:
                # The replacement carries the same timestep, so the index
                # built above stays valid for later samples.
                self.memory[mem_idx] = sampled_exp

    def get_memory_values(self):
        """Return the underlying list of stored experiences (not a copy)."""
        return self.memory

In [5]:
def extract_tensors(experiences):
    """Transpose a batch of experiences into one stacked array per field.

    Returns a 7-tuple ``(state, action, next_state, reward, done,
    abs_td_error, timestep)`` where each entry is ``np.stack`` of that field
    across the batch.
    """
    # Xp(*zip(*...)) regroups the batch field by field; iterating the
    # resulting namedtuple yields the fields in declaration order.
    columns = Xp(*zip(*experiences))
    return tuple(np.stack(column) for column in columns)

In [6]:
def rebuild_experiences(state, action, next_state, reward, done, abs_error, timestep):
    """Zip per-field batches back into a list of ``Xp`` records.

    One record is built per index of ``state``; every other argument is
    indexed with the same position.
    """
    return [
        Xp(state[row], action[row], next_state[row], reward[row],
           done[row], abs_error[row], timestep[row])
        for row in range(len(state))
    ]

In [7]:
from collections import namedtuple
# Record type for one stored transition, plus the bookkeeping fields used by
# the replay code (absolute TD error and the timestep it was stored at).
Xp = namedtuple(
    'Experience',
    ['state', 'action', 'next_state', 'reward', 'done', 'abs_td_error', 'timestep'],
)
# Quick sanity example showing the field order.
Xp_points = Xp(state=5, action=6, next_state=7, reward=8, done=9,
               abs_td_error=10, timestep=11)
Xp_points

Experience(state=5, action=6, next_state=7, reward=8, done=9, abs_td_error=10, timestep=11)

In [8]:
def prioritize_samples(experience_samples, alpha, beta):
    """Compute rank-based importance weights for a sampled batch.

    The batch's absolute TD errors are sorted from largest to smallest.  The
    k-th rank (1-based) gets priority ``(1/k)**alpha``; priorities are
    normalised into probabilities, turned into importance weights
    ``(N * p)**-beta`` and scaled by their maximum.

    Returns ``(weights, sort_indices)``: ``weights`` is a NumPy column in
    rank order, ``sort_indices`` is the torch index tensor produced by the
    descending sort, which callers use to put the batch into rank order.
    """
    # Field 5 of the extracted tuple is abs_td_error.
    abs_td_error = extract_tensors(experience_samples)[5]

    td_column = torch.tensor(np.expand_dims(abs_td_error, axis=1))
    _, indices_ = td_column.sort(0, descending=True)  # largest error first

    ranks = np.arange(1, len(td_column) + 1)
    priorities = np.expand_dims((1.0/ranks)**alpha, axis=1)
    probabilities = priorities/np.sum(priorities, axis=0)
    assert np.isclose(probabilities.sum(), 1.0)  # probabilities must sum to 1

    weight_importance_ = (len(probabilities)*probabilities)**-beta
    return weight_importance_/np.max(weight_importance_), indices_

In [9]:
class linearApproximator_FCGSAP(nn.Module):
    """Fully connected Gaussian policy with a tanh squashing step.

    The trunk (``fc1`` + ``hidden_layers``) feeds two linear heads: one for
    the action mean and one for the log standard deviation.  The module also
    owns the entropy temperature ``log_alpha`` (a plain tensor, not a
    registered parameter) together with its own Adam optimizer.

    Args:
        state_shape: number of input features.
        outputs: number of action dimensions.
        hidden_dims: widths of the hidden layers.
        log_entropy_lr: learning rate of the ``log_alpha`` optimizer.
        log_std_dev_min, log_std_dev_max: clamp range for the log-std head.
        epsilon: constant added inside the log of the tanh correction.
    """

    def __init__(self, state_shape, outputs, hidden_dims=(32,32), log_entropy_lr=0.0001,
                 log_std_dev_min=-20, log_std_dev_max=2, epsilon=1e-6):
        super().__init__()
        self.input_size = state_shape
        self.out = outputs
        self.log_std_dev_min = log_std_dev_min
        self.log_std_dev_max = log_std_dev_max
        self.epsilon = epsilon
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Layers are created in the order fc1 -> hidden -> log-std head ->
        # mean head.
        self.fc1 = nn.Linear(self.input_size, hidden_dims[0])
        self.hidden_layers = nn.ModuleList([
            nn.Linear(fan_in, fan_out)
            for fan_in, fan_out in zip(hidden_dims[:-1], hidden_dims[1:])
        ])
        self.output_layer_distribution = nn.Linear(hidden_dims[-1], self.out)
        self.output_layer_mean = nn.Linear(hidden_dims[-1], self.out)

        # Entropy target is minus the number of action dimensions.
        self.target_entropy = -float(self.out)
        # log(alpha) is learned through its own optimizer.
        self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=log_entropy_lr)

        self.to(self.device)

    def forward(self, state_shape):
        """Return ``(mean, log_std)`` for a batch of states.

        Non-tensor input is converted to a float32 tensor; the input is
        always moved to ``self.device``.  ``log_std`` is clamped to
        ``[log_std_dev_min, log_std_dev_max]``.
        """
        features = state_shape
        if not isinstance(features, torch.Tensor):
            features = torch.tensor(features, dtype=torch.float32)
        features = features.to(self.device)

        features = F.relu(self.fc1(features))
        for layer in self.hidden_layers:
            features = F.relu(layer(features))

        log_std = self.output_layer_distribution(features)
        mean = self.output_layer_mean(features)
        log_std = torch.clamp(log_std, self.log_std_dev_min, self.log_std_dev_max)
        return mean, log_std

    def full_pass(self, state):
        """Sample squashed actions and their log-probabilities.

        Returns ``(actions, log_probs, mean)``: ``actions`` are
        tanh-squashed reparameterised samples, ``log_probs`` are summed over
        dim 1 (keepdim) after the tanh change-of-variables correction, and
        ``mean`` is the unsquashed mean head output.
        """
        mean, log_std = self.forward(state)
        gaussian = Normal(mean, log_std.exp())
        raw_actions = gaussian.rsample()
        squashed_actions = torch.tanh(raw_actions)  # actions in [-1, 1]

        # Correct the Gaussian log-density for the tanh squashing.
        squash_correction = torch.log(
            (1 - squashed_actions.pow(2)).clamp(0,1) + self.epsilon)
        log_probs = gaussian.log_prob(raw_actions) - squash_correction
        log_probs = log_probs.sum(dim=1, keepdim=True)
        return squashed_actions, log_probs, mean

In [10]:
class linearApproximator_FCQV(nn.Module):
    """Fully connected critic estimating Q(state, action).

    The state goes through ``fc1``; the action is concatenated to that
    activation and the result goes through the remaining hidden layers and a
    single-unit output layer.

    Args:
        state_shape: number of state features.
        action_outputs_size: number of action dimensions.
        hidden_dims: widths of the hidden layers.  With a single entry there
            are no further hidden layers, so the action is concatenated
            directly in front of the output layer (previously the action was
            silently ignored in that configuration).
    """

    def __init__(self,state_shape,action_outputs_size,hidden_dims=(32,32)):
        super(linearApproximator_FCQV, self).__init__()
        self.input_size = state_shape
        self.action_outputs_size = action_outputs_size
        self.device = torch.device("cuda" if torch.cuda.is_available()
                                   else "cpu")

        self.fc1  = nn.Linear(self.input_size,hidden_dims[0])
        self.hidden_layers = nn.ModuleList()
        for i in range(len(hidden_dims)-1):
            hidden_input_layer = hidden_dims[i]
            if i == 0:
                # First hidden layer also receives the action vector.
                hidden_input_layer += self.action_outputs_size
            self.hidden_layers.append(
                nn.Linear(hidden_input_layer, hidden_dims[i+1]))

        output_input_size = hidden_dims[-1]
        if len(self.hidden_layers) == 0:
            # No hidden layer is available to consume the action, so the
            # output layer takes it instead.
            output_input_size += self.action_outputs_size
        self.output_layer  = nn.Linear(output_input_size,1)
        self.to(self.device)

    def _as_device_tensor(self, value):
        """Convert non-tensors to float32 tensors and move the result to ``self.device``."""
        if not isinstance(value, torch.Tensor):
            value = torch.tensor(value, dtype=torch.float32)
        # No-op when the tensor already lives on the module's device.
        return value.to(self.device)

    def forward(self, state_shape, action_shape):
        """Return the Q-value for each (state, action) row.

        Both inputs must be batched along dim 0 because the action is
        concatenated along dim 1.  Inputs may be tensors or array-likes;
        either way they are moved to the module's device.
        """
        state = self._as_device_tensor(state_shape)
        action = self._as_device_tensor(action_shape)

        x = F.relu(self.fc1(state))
        # Inject the action right after the first state layer.
        x = torch.cat((x, action), dim=1)
        for hidden_layer in self.hidden_layers:
            x = F.relu(hidden_layer(x))

        q_value = self.output_layer(x)
        return q_value

In [11]:
def update_networks(online_q_network_a, online_q_network_b,
                    offline_q_network_a, offline_q_network_b, tau):
    """Blend each online critic into its target critic in place.

    Every target parameter becomes ``(1 - tau) * target + tau * online``.
    Network ``a`` is processed first, then ``b``.  Returns the two (mutated)
    target networks.
    """
    network_pairs = (
        (offline_q_network_a, online_q_network_a),
        (offline_q_network_b, online_q_network_b),
    )
    for target_net, online_net in network_pairs:
        for target_param, online_param in zip(target_net.parameters(),
                                              online_net.parameters()):
            blended = (1.0 - tau)*target_param.data + tau*online_param.data
            target_param.data.copy_(blended)

    return offline_q_network_a, offline_q_network_b

In [12]:
def update_online_model(experience_samples,
                        online_policy_network, online_q_network_a, online_q_network_b,
                        online_policy_optimizer, online_q_optimizer_a, online_q_optimizer_b,
                        offline_q_network_a, offline_q_network_b,
                        gamma, weighted_importance, indices):
    """Run one SAC update (temperature, actor, both critics) on a sampled batch.

    Args:
        experience_samples: list of ``Xp`` records drawn from the replay memory.
        online_policy_network: policy exposing ``full_pass``, ``log_alpha``,
            ``log_alpha_optimizer`` and ``target_entropy``.
        online_q_network_a / online_q_network_b: trained critics.
        online_policy_optimizer / online_q_optimizer_a / online_q_optimizer_b:
            optimizers of the three online networks.
        offline_q_network_a / offline_q_network_b: target critics.
        gamma: discount factor.
        weighted_importance: importance weights in rank order (row k belongs
            to the sample with the k-th largest stored TD error).
        indices: sort indices that put the batch into that rank order.

    Returns:
        The batch rebuilt as ``Xp`` records, in the SAME order as
        ``experience_samples``, with ``abs_td_error`` replaced by the freshly
        computed mean absolute TD error of the two critics.

    Fixes relative to the previous version:
      * the policy loss now back-propagates through the critics (the Q term
        used to be fully detached, leaving only the entropy gradient);
      * importance weights scale the critic loss, not the bootstrap target;
      * the ``(1 - done)`` mask covers the entropy term of the target too;
      * new TD errors are measured at the stored actions;
      * new TD errors are mapped back from rank order to sample order before
        being attached to the experiences.
    """
    states_np, actions_np, next_states_np, rewards_np, done_np, _, timesteps_np = \
        extract_tensors(experience_samples)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Permutation that sorts the batch by descending stored TD error, so row k
    # of every tensor below lines up with row k of ``weighted_importance``.
    order = np.array(indices).reshape(-1)
    batch = len(order)

    states = torch.tensor(states_np[order].reshape(batch, -1)).float().to(device)
    actions = torch.tensor(actions_np[order].reshape(batch, -1)).float().to(device)
    next_states = torch.tensor(next_states_np[order].reshape(batch, -1)).float().to(device)
    rewards = torch.tensor(rewards_np[order].reshape(batch, 1)).float().to(device)
    done = torch.tensor(done_np[order].reshape(batch, 1)).float().to(device)
    weights = torch.tensor(weighted_importance).float().to(device).reshape(batch, 1)

    # --- entropy temperature ------------------------------------------------
    current_actions, log_pi, _ = online_policy_network.full_pass(states)
    target_alpha = (log_pi + online_policy_network.target_entropy).detach()
    target_alpha_loss = -(online_policy_network.log_alpha * target_alpha).mean()
    online_policy_network.log_alpha_optimizer.zero_grad()
    target_alpha_loss.backward()
    online_policy_network.log_alpha_optimizer.step()
    optimized_alpha = online_policy_network.log_alpha.exp().detach()

    # --- actor loss -----------------------------------------------------------
    # The reparameterised actions stay attached so the gradient reaches the
    # policy through the critics.
    q_pi = torch.min(online_q_network_a(states, current_actions),
                     online_q_network_b(states, current_actions))
    policy_loss = (optimized_alpha * log_pi - q_pi).mean()

    # --- critic target --------------------------------------------------------
    with torch.no_grad():
        next_actions, log_pi_ns, _ = online_policy_network.full_pass(next_states)
        q_next = torch.min(offline_q_network_a(next_states, next_actions),
                           offline_q_network_b(next_states, next_actions))
        soft_value = q_next - optimized_alpha * log_pi_ns
        TWIN_target = rewards + gamma * (1 - done) * soft_value

    # --- critic losses ----------------------------------------------------------
    loss_func = torch.nn.SmoothL1Loss(reduction='none')
    q_sa_online_a = online_q_network_a(states, actions)
    q_sa_online_b = online_q_network_b(states, actions)
    q_online_value_loss_a = (weights * loss_func(q_sa_online_a, TWIN_target)).mean()
    q_online_value_loss_b = (weights * loss_func(q_sa_online_b, TWIN_target)).mean()

    # New priorities: mean absolute TD error of both critics (rank order).
    abs_a = (TWIN_target - q_sa_online_a).abs()
    abs_b = (TWIN_target - q_sa_online_b).abs()
    ovr_abs_update = ((abs_a + abs_b)/2).detach().cpu().numpy()

    # --- optimisation steps -----------------------------------------------------
    # The actor step must come before the critic steps: the actor graph reads
    # the critic weights, which the critic optimizers modify in place.
    online_policy_optimizer.zero_grad()
    policy_loss.backward()
    online_policy_optimizer.step()

    # zero_grad also drops the critic gradients left behind by the actor
    # backward pass above.
    online_q_optimizer_a.zero_grad()
    q_online_value_loss_a.backward()
    online_q_optimizer_a.step()
    online_q_optimizer_b.zero_grad()
    q_online_value_loss_b.backward()
    online_q_optimizer_b.step()

    # Undo the rank permutation so error k is attached to original sample k.
    td_in_sample_order = np.empty_like(ovr_abs_update)
    td_in_sample_order[order] = ovr_abs_update

    experiences_rebuilded = rebuild_experiences(states_np, actions_np, next_states_np,
                                                rewards_np, done_np,
                                                td_in_sample_order, timesteps_np)
    return experiences_rebuilded

In [13]:
def query_error(online_policy_network, offline_q_network_a, offline_q_network_b,
                online_q_network_a, online_q_network_b, state, action, next_state, reward, gamma):
    """Estimate the absolute TD error of a single transition.

    The target is ``reward + gamma * (min(Q_target_a, Q_target_b)(s', a')
    - alpha * log_pi(a'|s'))`` with ``a'`` sampled from the online policy; the
    result is the mean of ``|target - Q_online_a(s, a)|`` and
    ``|target - Q_online_b(s, a)|``.

    ``state``, ``action`` and ``next_state`` are unbatched; a leading batch
    axis of size 1 is added here.  Returns a NumPy array.

    The whole computation runs under ``torch.no_grad()`` because only the
    numeric value is returned, and the previous, unused extra policy pass on
    ``state`` has been removed.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    state = torch.tensor(state).float().to(device).unsqueeze(0)
    next_state = torch.tensor(next_state).float().to(device).unsqueeze(0)
    action = np.expand_dims(action, axis=0)

    with torch.no_grad():
        alpha = online_policy_network.log_alpha.exp()

        ns_actions, log_pi_ns, _ = online_policy_network.full_pass(next_state)
        q_target_next_states_action_a = offline_q_network_a(next_state, ns_actions)
        q_target_next_states_action_b = offline_q_network_b(next_state, ns_actions)

        TWIN_target = torch.min(q_target_next_states_action_a,
                                q_target_next_states_action_b)
        TWIN_target = TWIN_target - alpha * log_pi_ns
        TWIN_target = reward + (gamma*TWIN_target)

        q_online_state_action_val_a = online_q_network_a(state, action)
        q_online_state_action_val_b = online_q_network_b(state, action)

        abs_a = abs(TWIN_target - q_online_state_action_val_a)
        abs_b = abs(TWIN_target - q_online_state_action_val_b)
        ovr_abs_update = (abs_a + abs_b)/2

    return ovr_abs_update.detach().cpu().numpy()

In [14]:
def freeze_model(model):
    """Turn off gradient tracking for every parameter of ``model`` and return it."""
    for weight in model.parameters():
        weight.requires_grad_(False)
    return model

In [15]:
def select_action(state, online_policy_network):
    """Sample an action for a single state from the policy.

    The state gets a leading batch axis before being passed to
    ``full_pass``; the sampled action is returned as a NumPy array with all
    size-1 axes removed.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batched_state = torch.tensor(state).float().to(device).unsqueeze(0)
    with torch.no_grad():
        sampled, _, _ = online_policy_network.full_pass(batched_state)
    return sampled.detach().cpu().numpy().squeeze()

In [37]:
def SAC_PER(env,
         gamma=0.99,
         alpha_pr=0.6,
         beta_pr=0.3,
         memory_size = 5000,
         tau = 0.02,
         offline_update = 200,
         min_sample_size=140,
         batch_size = 64,
         n_ep=100,
         max_steps = 100000,
         render = True
         ):
    """Train a Soft Actor-Critic agent with rank-weighted replay on ``env``.

    Args:
        env: environment whose ``reset()`` returns an observation and whose
            ``step(action)`` returns ``(next_state, reward, done, info)``.
        gamma: discount factor.
        alpha_pr, beta_pr: exponents passed to ``prioritize_samples``.
        memory_size: replay memory capacity.
        tau: blending factor used when the target critics are refreshed.
        offline_update: refresh the target critics every this many steps.
        min_sample_size: number of stored transitions required before
            training starts.
        batch_size: number of transitions sampled per update.
        n_ep: number of episodes to run.
        max_steps: total step budget; training returns once it is exceeded.
        render: call ``env.render()`` on every step (default keeps the
            previous behavior).

    Returns:
        List with the accumulated reward of every completed episode.

    The environment is closed on every exit path, including interrupts.
    """
    observation_space = len(env.reset())
    n_actions = len(env.action_space.high)
    hidden_dims = (128,64)

    online_policy_network = linearApproximator_FCGSAP(observation_space,n_actions,
                                     hidden_dims=hidden_dims)
    online_q_network_a = linearApproximator_FCQV(observation_space,
                                     n_actions,hidden_dims=hidden_dims)
    online_q_network_b = linearApproximator_FCQV(observation_space,
                                     n_actions,hidden_dims=hidden_dims)

    offline_q_network_a = linearApproximator_FCQV(observation_space,
                                     n_actions,hidden_dims=hidden_dims)
    offline_q_network_b = linearApproximator_FCQV(observation_space,
                                     n_actions,hidden_dims=hidden_dims)

    # Targets start as exact copies of their online critics; otherwise the
    # first bootstrap targets come from unrelated random networks.
    offline_q_network_a.load_state_dict(online_q_network_a.state_dict())
    offline_q_network_b.load_state_dict(online_q_network_b.state_dict())

    offline_q_network_a.eval()
    offline_q_network_a = freeze_model(offline_q_network_a)
    offline_q_network_b.eval()
    offline_q_network_b = freeze_model(offline_q_network_b)

    online_policy_optimizer    = torch.optim.Adam(online_policy_network.parameters(),lr=0.001)
    online_q_optimizer_a = torch.optim.Adam(online_q_network_a.parameters(),lr=0.0008)
    online_q_optimizer_b = torch.optim.Adam(online_q_network_b.parameters(),lr=0.0008)

    memory = ReplayMemory(memory_size)
    # memory.sample(batch_size) raises if fewer than batch_size items are
    # stored, so never start training before that many are available.
    warmup_size = max(min_sample_size, batch_size)

    t_step = 0 # global step counter; doubles as the unique id of each transition
    reward_per_ep = []

    try:
        for e in tqdm(range(n_ep)):
            state = env.reset()
            reward_accumulated = 0

            while True:
                if render:
                    env.render()
                action = select_action(state, online_policy_network)
                next_state, reward, done, info = env.step(action)
                # Initial priority of the new transition.
                td_error = query_error(online_policy_network, offline_q_network_a, offline_q_network_b,
                    online_q_network_a, online_q_network_b, state, action, next_state, reward, gamma)
                td_error = np.squeeze(td_error, axis = 0)
                reward_accumulated+=reward
                # A time-limit cut-off is not a real terminal state, so it is
                # not stored as "done" for bootstrapping purposes.
                is_truncated = 'TimeLimit.truncated' in info and\
                                    info['TimeLimit.truncated']
                is_failure = done and not is_truncated

                memory.push(Xp(state, action, next_state, reward, is_failure, td_error, t_step))
                state = next_state
                t_step+=1
                if memory.can_provide_sample(warmup_size):
                    experience_samples = memory.sample(batch_size)
                    weighted_importance, indices = prioritize_samples(experience_samples, alpha_pr, beta_pr)
                    rebuilded_exp = update_online_model(experience_samples,
                            online_policy_network, online_q_network_a, online_q_network_b,
                            online_policy_optimizer, online_q_optimizer_a, online_q_optimizer_b,
                            offline_q_network_a, offline_q_network_b,
                            gamma, weighted_importance, indices)
                    # Write the refreshed TD errors back into the memory.
                    memory.update_td_error(rebuilded_exp)

                if t_step%offline_update == 0:
                    offline_q_network_a, offline_q_network_b = update_networks(online_q_network_a, online_q_network_b,
                                                                        offline_q_network_a, offline_q_network_b, tau)
                if done == True:
                    reward_per_ep.append(reward_accumulated)
                    break
                if t_step > max_steps:
                    # Step budget exhausted; the unfinished episode is not recorded.
                    return reward_per_ep
    finally:
        # Runs on normal return, on the early return above, and on
        # KeyboardInterrupt / errors.
        env.close()
    return reward_per_ep

In [38]:
import gym
# Environment instance that is passed to SAC_PER(env) in a later cell.
env = gym.make('BipedalWalker-v3')

In [41]:
# NOTE(review): the execution counters show this cell (In [41]) was run AFTER
# the training cell (In [40]); in file order it closes the env before training.
env.close()

In [40]:
# Train with the default hyper-parameters and keep the per-episode rewards.
# NOTE(review): the recorded output shows this run ending in KeyboardInterrupt
# at 69/100 episodes, in which case `rewards` is not assigned by this cell.
rewards = SAC_PER(env)  

 69%|██████▉   | 69/100 [49:32<22:15, 43.08s/it]   


KeyboardInterrupt: 

In [None]:

# Moving-average filter over the episode rewards.
# NOTE(review): size=1 is a one-sample window, i.e. no smoothing - confirm
# whether a larger window was intended.
arr = uniform_filter1d(rewards, size=1)

In [None]:
# Plot series: `y` holds the episode indices and `x` the (filtered) rewards;
# the names are swapped relative to the axes they end up on in plt.plot(y, x).
y = list(range(len(arr)))
x = list(arr)

In [None]:
# Episode index (`y`) on the horizontal axis, reward (`x`) on the vertical axis.
plt.plot(y, x)

In [None]:
# Scratch check: negative product of the action-space shape; compare with
# target_entropy = -float(outputs) set in linearApproximator_FCGSAP.
-np.prod(env.action_space.high.shape)