In [3]:
""" Learn a policy using DDPG for the reach task"""
import numpy as np
import time
import copy

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.multivariate_normal import MultivariateNormal

import gym
import pybullet
import pybulletgym.envs

from collections import deque
from operator import itemgetter

import matplotlib.pyplot as plt

# One seed for every RNG the notebook draws from, so Restart & Run All is repeatable.
SEED = 1000
np.random.seed(SEED)     # replay-buffer sampling uses np.random
torch.manual_seed(SEED)  # nn.Linear weight initialisation uses torch's RNG


def weighSync(target_model, source_model, tau=0.001):
    """
    Soft-update (Polyak-average) the target network towards the source network.

    Every target parameter is overwritten in place with
    ``tau * source + (1 - tau) * target``.

    param: target_model : network whose parameters are updated in place
    param: source_model : network whose parameters are read (left untouched)
    param: tau : interpolation factor, a float > 0
    return: the tuple (target_model, source_model)
    """
    assert isinstance(tau, float) and tau>0

    # The parameters have requires_grad=True; no_grad() permits the in-place write
    # and keeps the update out of the autograd graph.
    with torch.no_grad():
        for param_target, param_source in zip(target_model.parameters(), source_model.parameters()):
            # copy_ writes into the existing tensor. A plain assignment would only
            # rebind the loop variable and leave the target network unchanged.
            param_target.copy_(tau*param_source + (1-tau)*param_target)

    return target_model, source_model


class Replay():
    """
    Fixed-capacity FIFO store of transitions, pre-filled with random-action steps.

    Each transition is a dict with exactly the keys 's', 'a', 'r' and 's_prime'.
    """

    def __init__(self, buffer_size, init_length, state_dim, action_dim, env):
        """
        A function to initialize the replay buffer.

        param: buffer_size : Maximum number of transitions kept in the buffer
        param: init_length : Initial number of transitions to collect
        param: state_dim : Size of the state space
        param: action_dim : Size of the action space
        param: env : gym environment object (old-style API: reset() -> state,
                     step(a) -> (state, reward, done, info))
        """
        self.buffer_size = buffer_size
        self.state_dim = state_dim
        self.action_dim = action_dim

        # A bounded deque drops its oldest (left-most) element automatically once
        # full, so the buffer can never grow past buffer_size.
        self.buffer = deque(maxlen=buffer_size)

        s = env.reset()
        for _ in range(init_length):
            a = env.action_space.sample()
            s_prime, r, done, _ = env.step(a)
            self.buffer.append({
                's': s,
                'a': a,
                'r': r,
                's_prime': s_prime
            })
            # Advance the rollout; without this every stored transition would start
            # from the very first state. Start a new episode when the old one ends.
            s = env.reset() if done else s_prime

    def __len__(self):
        ''' Return number of elements in buffer'''
        return len(self.buffer)

    def buffer_add(self, exp):
        """
        A function to add a dictionary to the buffer
        param: exp : A dictionary with exactly four entries: state 's', action 'a',
                     reward 'r' and next state 's_prime'
        """
        assert isinstance(exp, dict) and len(exp) == 4
        # When the buffer is full the bounded deque evicts the oldest transition.
        self.buffer.append(exp)

    def buffer_sample(self, N):
        """
        A function to sample N points (uniformly, with replacement) from the buffer
        param: N : Number of samples to obtain from the buffer
        return: a tuple of N transition dicts (always a tuple, also for N == 1)
        raises: ValueError if N > 0 and the buffer is empty
        """
        if N == 0:
            return ()
        if len(self.buffer) == 0:
            raise ValueError('Cannot sample from an empty replay buffer')

        indices = np.random.randint(low=0, high=len(self.buffer), size=N, dtype='int')
        # Build the tuple explicitly: itemgetter(*indices) would return a bare dict
        # (not a 1-tuple) when N == 1.
        return tuple(self.buffer[int(i)] for i in indices)
        

class Actor(nn.Module):
    """Deterministic policy network: state -> action, every component in [-1, 1]."""

    def __init__(self, state_dim, action_dim):
        """
        Initialize the network
        param: state_dim : Size of the state space
        param: action_dim: Size of the action space
        """
        assert isinstance(state_dim, int) and state_dim>0
        assert isinstance(action_dim, int) and action_dim>0
        
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, 400)
        self.hidden1 = nn.Linear(400, 300)
        self.fc2 = nn.Linear(300, action_dim)
        self.tanh = nn.Tanh()
        self.relu = nn.ReLU()

    def forward(self, state):
        """
        Define the forward pass
        param: state: The state of the environment; a torch tensor or anything
                      torch.as_tensor accepts (e.g. a numpy array), with last
                      dimension of size state_dim. Leading batch dimensions are kept.
        return: tensor of actions squashed into [-1, 1] by tanh
        """
        # Accept raw environment observations (numpy) as well as tensors, and match
        # the dtype/device of the network weights.
        weight = self.fc1.weight
        x = torch.as_tensor(state, dtype=weight.dtype, device=weight.device)

        x = self.relu(self.fc1(x))
        x = self.relu(self.hidden1(x))  # the layer is registered as `hidden1`
        x = self.tanh(self.fc2(x))

        return x


class Critic(nn.Module):
    """Action-value network: (state, action) -> scalar Q estimate."""

    def __init__(self, state_dim, action_dim):
        """
        Initialize the critic
        param: state_dim : Size of the state space
        param: action_dim : Size of the action space
        """
        assert isinstance(state_dim, int) and state_dim>0
        assert isinstance(action_dim, int) and action_dim>0
        
        super(Critic, self).__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        
        # NN layers and activations
        self.fc1 = nn.Linear(state_dim + action_dim, 400)
        self.hidden1 = nn.Linear(400, 300)
        self.fc2 = nn.Linear(300, 1)
        self.relu = nn.ReLU()

    def forward(self, state, action):
        """
        Define the forward pass of the critic

        param: state : torch tensor or numpy array whose last dimension is state_dim
        param: action : torch tensor or numpy array whose last dimension is action_dim;
                        leading (batch) dimensions must match those of `state`
        return: tensor of Q-values with a trailing dimension of size 1
        """
        # Tensors with matching dtype/device are passed through unchanged, so
        # gradients can still flow back through `action`; numpy inputs are converted.
        weight = self.fc1.weight
        state = torch.as_tensor(state, dtype=weight.dtype, device=weight.device)
        action = torch.as_tensor(action, dtype=weight.dtype, device=weight.device)

        assert state.dim() >= 1 and state.shape[-1] == self.state_dim, \
            'last dimension of state must be of size %d'%self.state_dim
        assert action.dim() >= 1 and action.shape[-1] == self.action_dim, \
            'last dimension of action must be of size %d'%self.action_dim

        x = torch.cat((state, action), dim=-1) #concatenating to form input
        x = self.relu(self.fc1(x))
        x = self.relu(self.hidden1(x))  # the layer is registered as `hidden1`
        # Linear output: a ReLU here would make negative Q-values unrepresentable.
        x = self.fc2(x)

        return x


class DDPG():
    """
    Container for the DDPG components: actor/critic networks, their target
    copies, one Adam optimizer per online network, and a replay buffer.

    NOTE: `update_network` and `train` are still unimplemented stubs (no-ops).
    """

    def __init__(
            self,
            env,
            action_dim,
            state_dim,
            critic_lr=3e-4,
            actor_lr=3e-4,
            gamma=0.99,
            batch_size=100,
            buffer_size=10000,
            init_length=1000,
    ):
        """
        param: env: An gym environment
        param: action_dim: Size of action space (NOTE: comes before state_dim)
        param: state_dim: Size of state space
        param: critic_lr: Learning rate of the critic
        param: actor_lr: Learning rate of the actor
        param: gamma: The discount factor
        param: batch_size: The batch size for training
        param: buffer_size: Capacity of the replay buffer
        param: init_length: Number of random-action transitions used to pre-fill
                            the replay buffer
        """
        assert isinstance(state_dim, int) and state_dim>0
        assert isinstance(action_dim, int) and action_dim>0
        assert isinstance(batch_size, int) and batch_size>0
        assert isinstance(critic_lr, (int, float)) and critic_lr>0
        assert isinstance(actor_lr, (int, float)) and actor_lr>0
        assert isinstance(gamma, (int, float)) and gamma>0
        assert isinstance(buffer_size, int) and buffer_size>0
        assert isinstance(init_length, int) and init_length>=0

        self.gamma = gamma
        self.batch_size = batch_size
        self.env = env

        # Create a actor and actor_target with same initial weights
        self.actor = Actor(state_dim, action_dim)
        self.actor_target = copy.deepcopy(self.actor) #both networks have the same initial weights 

        # Create a critic and critic_target with same initial weights
        self.critic = Critic(state_dim, action_dim)
        self.critic_target = copy.deepcopy(self.critic) #both networks have the same initial weights 

        # Define optimizer for actor and critic
        self.optimizer_actor = optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.optimizer_critic = optim.Adam(self.critic.parameters(), lr=critic_lr)

        # Define a replay buffer (pre-filled by stepping `env` with random actions)
        self.ReplayBuffer = Replay(buffer_size, init_length, state_dim, action_dim, self.env)

    def update_target_networks(self):
        """
        A function to update the target networks: each target is moved towards its
        online network via weighSync (using weighSync's default tau).
        """
        weighSync(self.actor_target, self.actor)
        weighSync(self.critic_target, self.critic)

    # TODO: Complete the function
    def update_network(self):
        """
        A function to update the function just once

        Not implemented yet: calling it currently does nothing.
        """
        pass

    # TODO: Complete the function
    def train(self, num_steps):
        """
        Train the policy for the given number of iterations
        :param num_steps:The number of steps to train the policy for

        Not implemented yet: calling it currently does nothing.
        """
        pass

In [4]:
if __name__ == "__main__":
    # Define the environment
    env = gym.make("modified_gym_env:ReacherPyBulletEnv-v1", rand_init=False)

    # DDPG's signature is (env, action_dim, state_dim, ...), so the former positional
    # call DDPG(env, 8, 2) bound action_dim=8 and state_dim=2. Take the sizes from the
    # environment itself and pass them by keyword so they cannot be swapped.
    # NOTE(review): assumes observations and actions are 1-D arrays - confirm.
    state_dim = int(np.asarray(env.reset()).shape[-1])
    action_dim = int(env.action_space.shape[0])

    ddpg_object = DDPG(
        env,
        action_dim=action_dim,
        state_dim=state_dim,
        critic_lr=1e-3,
        actor_lr=1e-3,
        gamma=0.99,
        batch_size=100,
    )
    # # Train the policy
    # ddpg_object.train(100)

    # # Evaluate the final policy
    # state = env.reset()
    # done = False
    # while not done:
    #     action = ddpg_object.actor(state).detach().squeeze().numpy()
    #     next_state, r, done, _ = env.step(action)
    #     env.render()
    #     time.sleep(0.1)
    #     state = next_state


options= 


In [11]:
# Sanity check: number of transitions currently stored (Replay.__len__ delegates to the deque).
len(ddpg_object.ReplayBuffer)

1000