In [1]:
import numpy as np

class OUNoise(object):
    """Ornstein-Uhlenbeck style exploration noise with a linearly annealed sigma.

    The internal state follows ``x <- x + theta * (mu - x) + sigma * N(0, I)``.
    ``sigma`` starts at ``max_sigma`` and is moved linearly towards
    ``min_sigma`` as the step counter ``t`` passed to :meth:`get_noise`
    approaches ``decay_period``.

    Args:
        action_space: Passed straight to ``np.ones`` / ``np.random.randn`` as
            the noise dimension, so an int is expected (e.g. ``2``).
        mu: Mean the process reverts to.
        theta: Mean-reversion rate.
        max_sigma: Noise scale at ``t = 0``.
        min_sigma: Noise scale once ``t >= decay_period``.
        decay_period: Number of steps over which sigma is annealed.
    """

    def __init__(self, action_space, mu=0.0, theta=0.15, max_sigma=0.7, min_sigma=0.4, decay_period=600_000):
        self.mu = mu
        self.theta = theta
        self.sigma = max_sigma
        self.max_sigma = max_sigma
        self.min_sigma = min_sigma
        self.decay_period = decay_period
        # Despite the parameter name, this value is used as the noise dimension.
        self.action_dim = action_space
        self.reset()

    def reset(self):
        """Reset the process state to ``mu`` in every dimension (sigma is untouched)."""
        self.state = np.ones(self.action_dim) * self.mu

    def evolve_state(self):
        """Advance the process by one step using the current sigma and return the new state."""
        x = self.state
        dx = self.theta * (self.mu - x) + self.sigma * np.random.randn(self.action_dim)
        self.state = x + dx
        return self.state

    def get_noise(self, t=0):
        """Return the next noise sample and update sigma for step ``t``.

        The sample is drawn with the sigma left by the previous call; sigma is
        then recomputed for ``t`` and applies to the next sample.

        Args:
            t: Current global step used to position sigma on the decay schedule.

        Returns:
            numpy array of length ``action_dim`` holding the new process state.
        """
        ou_state = self.evolve_state()
        # Fraction of the decay schedule completed, clamped to [0, 1].
        progress = min(1.0, max(0.0, float(t) / self.decay_period))
        # Anneal from max_sigma (not from the current sigma): subtracting from
        # self.sigma compounded the decay on every call, so sigma collapsed to
        # min_sigma long before decay_period was reached.
        self.sigma = max(self.max_sigma - (self.max_sigma - self.min_sigma) * progress, self.min_sigma)
        print('sigma:', self.sigma, 'state:', ou_state)
        return ou_state

In [2]:
# Two-dimensional noise source: sigma starts at max_sigma=0.9 and is floored at min_sigma=0.1.
noise = OUNoise(action_space=2, max_sigma=0.9, min_sigma=0.1, decay_period=500_000)

In [3]:
# Sample once at step 6_000 (previously written as 60_00, which is the same value but
# reads like 60,000). On a freshly constructed `noise` this reports
# sigma = 0.9 - (0.9 - 0.1) * 6_000 / 500_000 = 0.8904.
noise.get_noise(t=6_000)

sigma: 0.8904 state: [-0.04716897 -0.49914117]


array([-0.04716897, -0.49914117])

In [4]:
from env_utils import GoalManager

In [5]:
# Instantiate the project's GoalManager; as the captured output below shows, construction
# logs every obstacle's name, base pose and corner coordinates.
GM = GoalManager()

[92mObstacle name: wall_outler, base pose: (0.0, 0.0, 0.0)[0m
Coordinates: [[[11.5, -8.425], [11.5, -11.575], [-11.5, -11.575], [-11.5, -8.425]], [[21.5, 1.575], [21.5, -1.575], [-1.5, -1.575], [-1.5, 1.575]], [[11.5, 11.575], [11.5, 8.425], [-11.5, 8.425], [-11.5, 11.575]], [[1.5, 1.575], [1.5, -1.575], [-21.5, -1.575], [-21.5, 1.575]]]

[92mObstacle name: wall_single_5m_1, base pose: (3.5, 7.5, 1.5708)[0m
Coordinates: [[[5.125, 11.5], [5.125, 3.5], [1.875, 3.5], [1.875, 11.5]]]

[92mObstacle name: wall_single_5m_2, base pose: (7.5, 1.5, 0.0)[0m
Coordinates: [[[11.5, 3.125], [11.5, -0.125], [3.5, -0.125], [3.5, 3.125]]]

[92mObstacle name: wall_single_5m_3, base pose: (-2.0, -7.5, 1.5708)[0m
Coordinates: [[[-0.375, -3.5], [-0.375, -11.5], [-3.625, -11.5], [-3.625, -3.5]]]

[92mObstacle name: wall_single_5m_4, base pose: (-6.0, -1.5, 1.5708)[0m
Coordinates: [[[-4.375, 2.5], [-4.375, -5.5], [-7.625, -5.5], [-7.625, 2.5]]]

[92mObstacle name: wall_Lshape_2_1, base pose: (4.5, 

In [8]:
#!/usr/bin/env python

from settings.constparams import ENABLE_VISUAL, NUM_SCAN_SAMPLES

if ENABLE_VISUAL:
    from PyQt5 import QtWidgets, QtCore
    import pyqtgraph as pg
    import numpy as np
    import time

    import torch

    pg.setConfigOptions(antialias=False)

    class DrlVisual(QtWidgets.QWidget):
        def __init__(self, state_size, hidden_size):
            """Create the visualizer window, its three tabs, and the state-tab plots.

            Args:
                state_size: Length of the state vector; sizes the "All States"
                    bar chart and anchors the indexing done in update_layers.
                hidden_size: Stored twice in ``self.hidden_sizes``
                    (presumably one entry per hidden layer - TODO confirm).
            """
            super().__init__(None)
            # The window is shown first and only afterwards resized and populated.
            self.show()
            self.resize(1980, 1200)

            self.state_size = state_size
            self.hidden_sizes = [hidden_size] * 2

            # Root layout containing a single tab widget.
            self.mainLayout = QtWidgets.QVBoxLayout()
            self.setLayout(self.mainLayout)
            self.tab_widget = QtWidgets.QTabWidget()
            self.mainLayout.addWidget(self.tab_widget)

            # One empty page per visualizer, registered in display order.
            tab_specs = (
                ("tab_state", "DRL State Visualizer"),
                ("tab_actor", "Actor Visualizer"),
                ("tab_critic", "Critic Visualizer"),
            )
            for attr_name, tab_title in tab_specs:
                page = QtWidgets.QWidget()
                setattr(self, attr_name, page)
                self.tab_widget.addTab(page, tab_title)

            # Only the state tab is populated; the actor and critic pages stay empty.
            self.init_tab_state()
            # self.init_tab_actor()
            # self.init_tab_critic()

            # Incremented once per update_layers call.
            self.iteration = 0

        def init_tab_state(self):
            """Build every plot shown on the "DRL State Visualizer" tab.

            Grid layout inside the graphics widget:
              * row 0, columns 0-4: bar chart with one bar per state entry;
              * rows 1-3, columns 0-2: XY plane with robot/obstacle markers and
                one line per velocity component;
              * rows 1-3, columns 3-4: six single-value bar plots.
            All created items are stored on ``self`` for update_layers to refresh.
            """
            # Layout of the tab page and the pyqtgraph canvas inside it.
            self.tab_state_layout = QtWidgets.QVBoxLayout(self.tab_state)
            self.tab_state_graph_layout = pg.GraphicsLayoutWidget()
            self.tab_state_layout.addWidget(self.tab_state_graph_layout)

            # -------- All State plots -------- #
            self.plot_all_states_item = self.tab_state_graph_layout.addPlot(
                title="All States", row=0, col=0, colspan=5
            )
            self.plot_all_states_item.setXRange(-1, self.state_size, padding=0)
            self.plot_all_states_item.setYRange(-1, 1, padding=0)
            self.bar_graph_all_states = pg.BarGraphItem(
                x=range(self.state_size), height=np.zeros(self.state_size), width=0.8
            )
            self.plot_all_states_item.addItem(self.bar_graph_all_states)

            # -------- Details State plots -------- #
            # XY Plane plot
            self.xy_plane_item = self.tab_state_graph_layout.addPlot(
                title="XY Plane", row=1, col=0, colspan=3, rowspan=3
            )
            self.xy_plane_item.setXRange(-1, 1, padding=0)
            self.xy_plane_item.setYRange(-1, 1, padding=0)
            self.xy_plane_item.showGrid(x=True, y=True)

            def position_marker(rgba):
                # Single borderless dot, initially at the origin.
                return pg.ScatterPlotItem(x=[0], y=[0], size=10, pen=pg.mkPen(None), brush=pg.mkBrush(*rgba))

            def velocity_line(color):
                # Zero-length line; update_layers stretches it along one axis.
                return pg.PlotDataItem([0, 0], [0, 0], pen=pg.mkPen({'color': color, 'width': 2}))

            # Position (robot in green, obstacle in red)
            self.scatter_robot_xy = position_marker((0, 255, 0, 255))
            self.scatter_obstacle_xy = position_marker((255, 0, 0, 255))
            # Velocity (one line per component, coloured like its owner)
            self.arrow_robot_vx = velocity_line("#00FF00")
            self.arrow_robot_vy = velocity_line("#00FF00")
            self.arrow_obstacle_vx = velocity_line("#FF0000")
            self.arrow_obstacle_vy = velocity_line("#FF0000")

            for xy_item in (
                self.scatter_robot_xy,
                self.scatter_obstacle_xy,
                self.arrow_robot_vx,
                self.arrow_robot_vy,
                self.arrow_obstacle_vx,
                self.arrow_obstacle_vy,
            ):
                self.xy_plane_item.addItem(xy_item)

            # Single-value bar plots (rotate=True turns the bar by -90 degrees).
            self.dtg_item, self.bar_graph_dtg = self._add_unit_bar_plot(
                "Distance to Goal", row=1, col=3
            )
            self.atg_item, self.bar_graph_atg = self._add_unit_bar_plot(
                "Angle to Goal", row=1, col=4, rotate=True
            )
            self.theta_item, self.bar_graph_theta = self._add_unit_bar_plot(
                "Theta", row=2, col=3, rotate=True
            )
            self.angular_item, self.bar_graph_angular = self._add_unit_bar_plot(
                "Angular Velocity", row=2, col=4, rotate=True
            )
            self.last_action_linear_item, self.bar_graph_last_action_linear = self._add_unit_bar_plot(
                "Last Action Linear", row=3, col=3
            )
            self.last_action_angular_item, self.bar_graph_last_action_angular = self._add_unit_bar_plot(
                "Last Action Angular", row=3, col=4, rotate=True
            )

        def _add_unit_bar_plot(self, title, row, col, rotate=False):
            """Add a one-bar plot with both axes fixed to [-1, 1] and return ``(plot, bar)``.

            Args:
                title: Plot title.
                row: Grid row in the state graphics layout.
                col: Grid column in the state graphics layout.
                rotate: When true the bar item is rotated by -90 degrees.
            """
            plot_item = self.tab_state_graph_layout.addPlot(title=title, row=row, col=col, colspan=1)
            plot_item.setXRange(-1, 1, padding=0)
            plot_item.setYRange(-1, 1, padding=0)
            # NOTE(review): the bar is centred at x=1, i.e. on the edge of the
            # [-1, 1] x-range set above - confirm this placement is intended.
            bar_item = pg.BarGraphItem(x=[1], height=[0], width=0.5)
            if rotate:
                bar_item.setRotation(-90)
            plot_item.addItem(bar_item)
            return plot_item, bar_item
    

        def prepare_data(self, tensor: torch.Tensor):
            """Squeeze size-1 dims, reverse along dim 0, and return a detached CPU tensor."""
            squeezed = torch.squeeze(tensor)
            reversed_first_dim = torch.flip(squeezed, dims=(0,))
            return reversed_first_dim.detach().cpu()

        def update_layers(self, states, actions, hidden, biases):
            """Refresh all state-tab plots from the latest state tensor.

            ``states`` is passed through prepare_data, which reverses it, so the
            individual fields are read by counting down from ``anchor``.
            ``actions``, ``hidden`` and ``biases`` are accepted but currently
            unused (the code that consumed them is disabled below).
            """
            flipped = self.prepare_data(states)

            # Full state vector (in reversed order, as returned by prepare_data).
            self.bar_graph_all_states.setOpts(height=flipped)

            anchor = self.state_size - NUM_SCAN_SAMPLES - 1

            def field(offset):
                # Entry `offset` positions below the anchor in the reversed state.
                return flipped[anchor - offset]

            # Reads and widget updates stay interleaved in offset-lookup order.
            self.bar_graph_dtg.setOpts(height=[field(0)])
            self.bar_graph_atg.setOpts(height=[field(1)])
            self.bar_graph_theta.setOpts(height=[field(4)])
            robot_x = field(2)
            robot_y = field(3)
            robot_vx = field(5)
            robot_vy = field(6)
            self.bar_graph_angular.setOpts(height=[field(7)])
            obstacle_x = field(8)
            obstacle_y = field(9)
            obstacle_vx = field(10)
            obstacle_vy = field(11)
            self.bar_graph_last_action_linear.setOpts(height=[field(12)])
            self.bar_graph_last_action_angular.setOpts(height=[field(13)])

            # Markers for the robot and obstacle positions.
            self.scatter_robot_xy.setData(x=[robot_x], y=[robot_y])
            self.scatter_obstacle_xy.setData(x=[obstacle_x], y=[obstacle_y])

            # Each velocity component is a line from the position along one axis.
            self.arrow_robot_vx.setData([robot_x, robot_x + robot_vx], [robot_y, robot_y])
            self.arrow_robot_vy.setData([robot_x, robot_x], [robot_y, robot_y + robot_vy])
            self.arrow_obstacle_vx.setData([obstacle_x, obstacle_x + obstacle_vx], [obstacle_y, obstacle_y])
            self.arrow_obstacle_vy.setData([obstacle_x, obstacle_x], [obstacle_y, obstacle_y + obstacle_vy])

            # Disabled: action / hidden-layer / bias plots and the explicit event pump.
            # actions = actions.detach().cpu().numpy().tolist()
            # self.bar_graph_action_linear.setOpts(height=[actions[0]])
            # self.bar_graph_action_angular.setOpts(height=[actions[1]])
            # for i in range(len(hidden)):
            #     self.hidden_bar_graphs[i].setOpts(height=self.prepare_data(hidden[i]))
            # pg.QtGui.QGuiApplication.processEvents()
            # if self.iteration % 100 == 0:
            #     self.update_bias(biases)
            self.iteration += 1

        def update_bias(self, biases):
            """No-op placeholder: bias plotting is disabled, so ``biases`` is ignored.

            Disabled implementation, kept for reference::

                for i in range(len(biases)):
                    self.hidden_line_plots[i].setData(y=self.prepare_data(biases[i]))
            """
            return None

        def update_reward(self, acc_reward):
            """No-op placeholder: the reward bar is disabled, so ``acc_reward`` is ignored.

            Disabled implementation, kept for reference::

                self.bar_graph_reward.setOpts(height=[acc_reward])
                if acc_reward > 0:
                    self.bar_graph_reward.setOpts(brush='g')
                else:
                    self.bar_graph_reward.setOpts(brush='r')
            """
            return None

In [None]:
if ENABLE_VISUAL:
    # DrlVisual and the Qt imports only exist when ENABLE_VISUAL is set, and Qt
    # requires a QApplication before any QWidget is constructed. mkQApp() creates
    # one if needed and otherwise returns the existing instance.
    qt_app = pg.mkQApp()
    # NOTE(review): update_layers indexes from state_size - NUM_SCAN_SAMPLES - 1,
    # so confirm that state_size=14 is consistent with NUM_SCAN_SAMPLES.
    Visual = DrlVisual(state_size=14, hidden_size=64)
else:
    # Visualisation disabled: keep the name defined instead of raising NameError.
    Visual = None