In [135]:
import math
import random
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count
import sys
import timeit

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# MCTS code imports
sys.path.append("..")  # Adds higher directory to python modules path.
from main import ToyMeasurementControl
from utils import rotate_about_point

class PolicyValueNetwork(nn.Module):
    """
    Neural net combining policy and value network (state -> (policy, value))
    params:
        state_dims: Number of dimensions in the state space
        action_dims: Number of dimensions in the action space
    """
    def __init__(self, state_dims, action_dims):
        super(PolicyValueNetwork, self).__init__()
        self.layer1 = nn.Linear(state_dims, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, action_dims+1) # +1 for the value output

    # Called with either one element to determine next action, or a transitions
    # during optimization. Returns tensor([[left0exp,right0exp]...]).
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)


# Named tuple for transitions
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward', 'done'))

class ReplayMemory:
    """
    Replay memory for storing transitions
    Creates a deque with a maximum length of capacity. Transitions are stored as named tuples.
    
    methods:
        push(*args): Save a transition
        sample(batch_size): Sample a batch of transitions
    """
    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

class MCTSRLWrapper:
    def __init__(self, q_network: PolicyValueNetwork, target_q_network: PolicyValueNetwork,
                 replay_memory: ReplayMemory, gamma: float=0.99, batch_size: int=64, 
                 lr: float=0.001, tau: float=0.005):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print("Using device:", self.device)
        
        self.q_network = q_network.to(self.device) # Q Network which is changed only after full batch is processed
        self.target_q_network = target_q_network.to(self.device) # Q network which is updated during batch training
        self.replay_memory = replay_memory # Replay memory for storing transitions
        self.gamma = gamma # Discount factor
        self.batch_size = batch_size # Number of transitions to sample for training
        self.tau = tau # Soft update parameter for target network (target = tau * q_network + (1 - tau) * target_q_network)
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr, amsgrad=True) # Adam optimizer for training the Q network
        
    def loss(self, transitions):
        """
        Calculate the loss for a batch of transitions
        """
        # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
        # detailed explanation). This converts batch-array of Transitions
        # to Transition of batch-arrays.
        transitions = Transition(*zip(*transitions))
        
        # Pull out the components of the batch and concatinate them (concatenate scalar lists and stack tensor lists)
        state_transitions = torch.stack(transitions.state)
        action_transitions = torch.cat(transitions.action)
        next_state_transitions = torch.stack(transitions.next_state)
        reward_transitions = torch.cat(transitions.reward)
        done_transitions = torch.cat(transitions.done)
        
        # Get max Q target values of next state from target network (max_a' Q_target(s', a'))
        max_next_q_value = self.target_q_network(next_state_transitions).max(dim=1).values
        
        # Calculate the target (r + γ * max_a' Q_target(s', a')), what we want the Q network to predict
        y_targets = reward_transitions + self.gamma * max_next_q_value * ~done_transitions # If done, the target is just the reward
        
        # Get the current q_values, pick the values that are from the actions taken and then squeeze the tensor (remove the extra dimension)
        q_values = self.q_network(state_transitions).gather(1, action_transitions).squeeze()

        # Calculate the loss (difference between the target and the current q_value prediction)
        loss = F.mse_loss(y_targets, q_values)
        
        return loss
    
    def optimize_model(self):
        # Check if there are enough transitions in the replay memory to optimize
        if len(self.replay_memory) < self.batch_size:
            return
        
        # Sample a batch of transitions
        transitions = self.replay_memory.sample(self.batch_size)
        
        # Set the network to training mode
        self.q_network.train()
        
        # Zero the gradients
        self.optimizer.zero_grad()
        
        # Calculate the loss
        loss = self.loss(transitions)
        
        # Backpropagate the loss
        loss.backward()
        
        # Perform a step of optimization
        self.optimizer.step()

        # Soft update of the target network
        for target_param, param in zip(self.target_q_network.parameters(), self.q_network.parameters()):
            target_param.data.copy_(self.tau * param.data + (1.0 - self.tau) * target_param.data)
    

In [90]:
# Create state image from car position, OOI corners, car width, car_length and obstacles
def get_image_based_state(env: ToyMeasurementControl, state: tuple, width_pixels=200, width_meters=50) -> tuple():
    # Get car collision length and width
    car_width, car_length = env.car.width, env.car.length
    
    # Get obstacle means and radii
    obstacle_means, obstacle_radii = env.eval_kd_tree.get_obstacle_points(), env.eval_kd_tree.get_obstacle_radii()

    # Pull out the state components
    car_state, corner_means, corner_covariance, horizon = state
    corner_means = corner_means.reshape(-1, 2) # Reshape to 2D array where each row is a corner point
    
    # Get normalized point covariances
    pt_traces = env.get_normalized_cov_pt_traces(state)
    
    # Since image is body frame representation of car, obstacles and OOIs. The neural net only needs [vx, delta, delta_dot] as input
    # These are the components of the state which will determine how actions effect the car state, the rest of the state is used to generate the image
    nn_car_state = car_state[[2, 4, 5]]
    
    # Make the image
    image = np.zeros((width_pixels, width_pixels), dtype=np.float32)
    
    # Calculate the scaling factor from meters to pixels
    scale = width_pixels / width_meters
    
    # Rotate the obstacle and corner points to the car's yaw angle
    car_pos, car_yaw = car_state[:2], car_state[3]
    rotated_corners = rotate_about_point(corner_means, np.pi/2-car_yaw, car_pos) # Negative to rotate into a coordinate system where the car is facing up
    rotated_obstacles = rotate_about_point(obstacle_means, np.pi/2-car_yaw, car_pos)
    
    # Subtract the car's position from the rotated points to get the points relative to the car
    rotated_corners -= car_state[:2]
    rotated_obstacles -= car_state[:2]
    
    # Find which points are within the image bounds
    in_bounds_corners = (-width_meters/2 <= rotated_corners[:, 0]) & (rotated_corners[:, 0] <= width_meters/2) & \
                        (-width_meters/2 <= rotated_corners[:, 1]) & (rotated_corners[:, 1] <= width_meters/2)

    in_bounds_obstacles = (-width_meters/2 <= rotated_obstacles[:, 0]) & (rotated_obstacles[:, 0] <= width_meters/2) & \
                          (-width_meters/2 <= rotated_obstacles[:, 1]) & (rotated_obstacles[:, 1] <= width_meters/2)

    # Convert the car frame in bounds points to pixel coordinates
    in_bounds_corner_pixels = (rotated_corners[in_bounds_corners] * scale + width_pixels / 2).astype(int)
    in_bounds_obstacle_pixels = (rotated_obstacles[in_bounds_obstacles] * scale + width_pixels / 2).astype(int)
    in_bounds_obstacle_radii_pixels = (obstacle_radii[in_bounds_obstacles] * scale).astype(int)
    
    # First place obstacles so that rewards and car overlay them
    for i, point in enumerate(in_bounds_obstacle_pixels):
        x_pixel, y_pixel = point
        radius_pixel = in_bounds_obstacle_radii_pixels[i]
        x, y = np.ogrid[-x_pixel:width_pixels-x_pixel, -y_pixel:width_pixels-y_pixel]
        mask = x*x + y*y <= radius_pixel*radius_pixel
        image[mask] = -1.0
        
    # Place the car (draw a rectangle at the center given length and width), car is always facing up (positive x axis)
    car_width_pixels = int(car_width * scale)
    car_length_pixels = int(car_length * scale)
    car_max_x_index = car_width_pixels + width_pixels // 2
    car_max_y_index = car_length_pixels + width_pixels // 2
    image[-car_max_x_index:car_max_x_index, -car_max_y_index:car_max_y_index] = -0.5
    
    # Place the corners
    image[in_bounds_corner_pixels[:, 0], in_bounds_corner_pixels[:, 1]] = pt_traces[in_bounds_corners]
        
    # Return the neural net state and the image
    return nn_car_state, image

def get_nn_state(env: ToyMeasurementControl, state: tuple, device: torch.device) -> torch.Tensor:
    """
    Convert full state into image and then combine car state and flattened image into a single tensor
    params:
        env: ToyMeasurementControl environment
        state: Full state tuple
        device: Device to put the tensor on
    """
    nn_car_state, image = get_image_based_state(env, state)
    
    nn_state = torch.cat((torch.tensor(nn_car_state, dtype=torch.float32, device=device),
                          torch.tensor(image.flatten(), dtype=torch.float32, device=device)))
    
    return nn_state

def plot_state_image(image, title):
    plt.imshow(image.T, cmap='gray', origin='lower')
    plt.colorbar(label='Value')
    plt.title(title)
    plt.show()

In [136]:
# Make a toy measurement control object (environment for MCTS)
tmc = ToyMeasurementControl(no_flask_server=True, enable_ui=False)

# Get the initial state sizes using the NN state space
nn_car_state, image = get_image_based_state(tmc, tmc.get_state())

# Get the observation (NN state space) length
nn_state_dims = nn_car_state.size + image.size
num_actions = len(tmc.all_actions) # Number of actions (rows)

# Create the Q network and target Q network
q_network = PolicyValueNetwork(nn_state_dims, num_actions)
target_q_network = PolicyValueNetwork(nn_state_dims, num_actions)

# Create the replay memory
replay_memory = ReplayMemory(10000)

# Create the MCTSRL wrapper
mcts_rl = MCTSRLWrapper(q_network, target_q_network, replay_memory, gamma=0.99, batch_size=10, lr=0.001, tau=0.005)

# Create some random transitions to test the loss function
state = tmc.get_state()
for i in range(10):
    # Get the image based state
    nn_car_state = get_nn_state(tmc, state, mcts_rl.device)
    
    # Get the action
    action = torch.tensor([[random.randint(0, num_actions-1)]], dtype=torch.int64, device=mcts_rl.device)
    
    # Take the action
    next_state, reward, done = tmc.step(state, action.item())
    done = bool(done) # Convert to python boolean from np.bool_ (removes warning)
    
    # Get the next image based state
    next_nn_car_state = get_nn_state(tmc, next_state, mcts_rl.device)
    
    # Save the transition
    replay_memory.push(nn_car_state, action, next_nn_car_state,
                       torch.tensor([reward], dtype=torch.float32, device=mcts_rl.device),
                       torch.tensor([done],   dtype=torch.uint8,   device=mcts_rl.device))
    
    # Update the current state
    state = next_state
    
# Optimize the model
mcts_rl.optimize_model()
print("Optimization complete")

Toy Measurement Control Initialized
Using device: cuda
Q values after .gather(1, action_transitions): tensor([ 0.0408,  0.0436,  0.0913,  0.0691, -0.0834,  0.0582,  0.0296,  0.0599,
         0.0638, -0.0413], device='cuda:0', grad_fn=<SqueezeBackward0>)
Shape of y_targets: torch.Size([10])
Shape of q_values: torch.Size([10])
Optimization complete


In [73]:
# grab a random state
state = tmc.get_state()
nn_state = get_nn_state(tmc, state, mcts_rl.device)
print(f'NN State shape: {nn_state.shape}')

# Try inference
start_time = timeit.default_timer()
print(q_network(nn_state))
print(f'Inference time: {timeit.default_timer() - start_time}')

NN State shape: torch.Size([40003])
NN State shape: torch.Size([40003])
tensor([ 0.0074, -0.0027,  0.0533, -0.0953,  0.0834,  0.0507,  0.0026,  0.0062,
         0.0528, -0.0193, -0.0772,  0.0472, -0.0458,  0.0681,  0.0185, -0.0296,
         0.0027, -0.0566,  0.0678, -0.0526, -0.0072,  0.0164,  0.0632, -0.0376,
         0.0647,  0.0160], device='cuda:0', grad_fn=<ViewBackward0>)
Inference time: 0.0024198060000344412


In [132]:
tuple_of_tensors = (torch.tensor([1, 2, 3], device=mcts_rl.device), torch.tensor([4, 5, 6], device=mcts_rl.device))
print(f'Tuple of tensors: {tuple_of_tensors}')
print(f'Concatenated tensors: {torch.stack(tuple_of_tensors).shape}')
tensor_stack = torch.stack(tuple_of_tensors)
tensor_stack.gather(1, torch.tensor([[2], [1]], device=mcts_rl.device)).squeeze()

Tuple of tensors: (tensor([1, 2, 3], device='cuda:0'), tensor([4, 5, 6], device='cuda:0'))
Concatenated tensors: torch.Size([2, 3])


tensor([3, 5], device='cuda:0')

In [106]:
tuple_of_tensors = (torch.tensor([1], device=mcts_rl.device), torch.tensor([4], device=mcts_rl.device))
print(f'Tuple of tensors: {tuple_of_tensors}')
print(f'Concatenated tensors: {torch.cat(tuple_of_tensors).shape}')

Tuple of tensors: (tensor([1], device='cuda:0'), tensor([4], device='cuda:0'))
Concatenated tensors: torch.Size([2])
