# Define Environment

In [None]:
import ipywidgets
import ipywidgets.widgets as widgets
from IPython.display import display
import time
from IPython.display import clear_output as cls
import matplotlib.pyplot as plt
%matplotlib inline  

In [1]:
class Environment:
    '''
    This wrapper works as an RL-Environment for the Jetson Nano.

    Arguments are:
        model          -> pytorch model trained in Unity (used by ``sample_action``).
        angles         -> list of angles required to measure with the lidar.
        robot          -> robot driver; must provide forward/backward/left/right(speed)
                          and stop().
        FOV            -> How wide the camera lens is. Default = 160
        speed_move     -> motor speed for step_forward / step_backward.
        time_move      -> seconds to drive in step_forward / step_backward.
        speed_rotate   -> motor speed for step_left / step_right.
        time_rotate    -> seconds to rotate in step_left / step_right.
        step_move_time -> seconds of the drive pulse in ``step`` (default 0.2).
        step_pause     -> pause between the drive and turn pulses in ``step`` (default 0.2).
        step_turn_time -> seconds of the turn pulse in ``step`` (default 0.1).

    Several methods rely on notebook-level names defined outside this cell:
    ``execute2``, ``camera``, ``calculate_angle``, ``WIDTH``,
    ``estimating_phase`` and ``read_lidar_wrapper``.
    '''

    def __init__(self, model, angles, robot, FOV=160, speed_move=0.4, time_move=0.5,
                 speed_rotate=0.3, time_rotate=0.3, step_move_time=0.2, step_pause=0.2,
                 step_turn_time=0.1):
        # Last lidar reading per angle; every angle starts at 0 until read_lidar() runs.
        self.previous_readings = {x: 0 for x in angles}
        self.angles = angles
        self.model = model
        self.FOV = FOV
        self.robot = robot
        self.speed_move = speed_move
        self.time_move = time_move
        self.speed_rotate = speed_rotate
        self.time_rotate = time_rotate
        self.step_move_time = step_move_time
        self.step_pause = step_pause
        self.step_turn_time = step_turn_time

    # ------------------------------------------------------------------ motion

    def _pulse(self, command, speed, duration):
        '''Call ``command(speed)``, wait ``duration`` seconds, then stop the robot.

        The stop is in a ``finally`` so an exception or an interrupt during the
        sleep (e.g. KeyboardInterrupt from the notebook) never leaves the motors
        running.
        '''
        try:
            command(speed)
            time.sleep(duration)
        finally:
            self.robot.stop()

    def stop(self):
        '''Stop the robot immediately.'''
        self.robot.stop()

    def step_forward(self):
        '''Drive forward at ``speed_move`` for ``time_move`` seconds, then stop.'''
        self._pulse(self.robot.forward, self.speed_move, self.time_move)

    def step_backward(self):
        '''Drive backward at ``speed_move`` for ``time_move`` seconds, then stop.'''
        self._pulse(self.robot.backward, self.speed_move, self.time_move)

    def step_left(self):
        '''Rotate left at ``speed_rotate`` for ``time_rotate`` seconds, then stop.'''
        self._pulse(self.robot.left, self.speed_rotate, self.time_rotate)

    def step_right(self):
        '''Rotate right at ``speed_rotate`` for ``time_rotate`` seconds, then stop.'''
        self._pulse(self.robot.right, self.speed_rotate, self.time_rotate)

    def step(self, action):
        '''Execute one policy action: a drive pulse followed by a turn pulse.

        action -> two-element action. Index 0 is the drive speed (negative means
                  backward) and index 1 is the turn speed (negative means left).
                  Accepts a torch tensor on any device, or any indexable pair
                  of numbers.
        '''
        if hasattr(action, "detach"):
            action = action.detach().cpu().numpy()
        speed_move, speed_turn = float(action[0]), float(action[1])

        # The sign picks the direction; the motors receive the magnitude.
        drive = self.robot.backward if speed_move < 0 else self.robot.forward
        self._pulse(drive, abs(speed_move), self.step_move_time)

        time.sleep(self.step_pause)

        turn = self.robot.left if speed_turn < 0 else self.robot.right
        self._pulse(turn, abs(speed_turn), self.step_turn_time)

    # ------------------------------------------------------------- perception

    def calculate_angle_and_phase(self):
        '''Run pose estimation on the current camera frame.

        Returns ``(phase, angle)``. Note that the order is the reverse of the
        method name. ``angle`` is whatever the notebook-level ``calculate_angle``
        yields; ``observe`` treats a dict as a detection.
        '''
        keypoints, _image, _counts, _objects, _peaks = execute2({'new': camera.value})
        angle, keypoints = calculate_angle(WIDTH, keypoints, self.FOV)
        phase = self.calculate_phase(keypoints)
        return phase, angle

    def calculate_phase(self, keypoints):
        '''Return the phase estimated from ``keypoints`` by ``estimating_phase``.'''
        return estimating_phase(keypoints)

    def read_lidar(self):
        '''Refresh ``previous_readings`` via the notebook-level ``read_lidar_wrapper``.'''
        self.previous_readings = read_lidar_wrapper(self.angles, self.previous_readings)

    def observe(self):
        '''Build the observation vector fed to the policy.

        Refreshes the lidar, then returns ``phase + angles + lidar readings``.
        ``angles`` holds the values of the detected angle dict. When the detected
        angle is not a dict, it is a single ``-1`` placeholder.
        '''
        self.read_lidar()
        phase, angle = self.calculate_angle_and_phase()

        if isinstance(angle, dict):
            angles = list(angle.values())
        else:
            angles = [-1]

        # NOTE(review): assumes estimating_phase returns a list — confirm.
        return phase + angles + list(self.previous_readings.values())

    # ----------------------------------------------------------------- policy

    def sample_action(self, observation):
        '''Sample an action from the policy and query the critic.

        Returns ``(action, reward)``:
            action -> the first sampled action tensor (on the model's device).
            reward -> the critic's ``'extrinsic'`` output as a numpy array.
        Requires CUDA, because the observation is moved to the GPU.
        '''
        import torch  # normally already imported in the notebook; keeps this method self-contained

        obs = torch.Tensor(observation).cuda()
        # Pure inference: skip autograd bookkeeping (saves memory on the Jetson).
        with torch.no_grad():
            hidden, _ = self.model.network_body(vis_inputs=[0], vec_inputs=[obs])
            distribution = self.model.distribution(hidden)
            action = distribution[0].sample()
            reward = self.model.critic(vis_inputs=[0], vec_inputs=[obs])
        reward = reward[0]['extrinsic'].cpu().detach().numpy()
        return action[0], reward
