# Imports

In [11]:
from treelib import Node, Tree
import tqdm

import os
import time
import pickle
import keyboard
import numpy as np
import random
import pandas as pd
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt

import tminterface as tmi
from tminterface.interface import TMInterface, Client

# Useful Functions

In [12]:
def discrete_to_continuous(n):
    """Translate a discrete action code into TMInterface input-state kwargs.

    Action codes:

    0: no action
    1: left
    2: left + acceleration
    3: acceleration
    4: right + acceleration
    5: right

    Any other value of ``n`` yields the "no action" input. The brake is never
    pressed and ``sim_clear_buffer`` is always True.
    """
    full_steer = 65536
    # code -> (steer, accelerate); codes absent from the table keep the neutral input
    action_table = (
        (1, -full_steer, False),
        (2, -full_steer, True),
        (3, 0, True),
        (4, full_steer, True),
        (5, full_steer, False),
    )

    steer, accelerate = 0, False
    for code, code_steer, code_accelerate in action_table:
        if n == code:
            steer, accelerate = code_steer, code_accelerate

    return {
        'sim_clear_buffer': True,
        "steer": steer,
        "accelerate": accelerate,
        "brake": False,
    }
    

def distance_3D(x, y, z, x0, y0, z0):
    """Return the Euclidean distance between (x, y, z) and (x0, y0, z0).

    Works element-wise when any coordinate is a NumPy array, so a whole set of
    points can be measured against a single reference point in one call.
    """
    squared = (x - x0) ** 2 + (y - y0) ** 2 + (z - z0) ** 2
    return np.sqrt(squared)

# Explorer Class

In [25]:
class TreeExplorer():
    """Greedy depth-first search over a tree of discrete actions.

    The tree starts with a single ``"root"`` node whose data holds
    ``root_start_state``. Every other node records the outcome of playing one
    action from its parent (see ``record_leaf``). ``explore_node`` hands out
    the untried actions of the current node first. Once all of
    ``possible_actions`` have been tried there, it moves into a successful
    child (and terminates) or into the viable child with the highest
    ``perf``, backtracking when no child is viable.
    """

    def __init__(self, root_start_state):
        self.tree = Tree()
        self.tree.create_node(identifier="root", data={"root_start_state": root_start_state})
        self.current_position = self.tree["root"]
        self.possible_actions = [0, 1, 2, 3, 4, 5]
        # True once a successful child is entered or the root has no viable child left.
        self.terminated = False
        # Number of edges between "root" and current_position.
        self.depth = 0

    def explore_node(self):
        """Return the next action to try from ``current_position``, or None when terminated.

        Untried actions are returned in ``possible_actions`` order. When every
        action has been tried at the current node, the explorer moves down or
        back up the tree (updating ``current_position`` and ``depth``) and
        keeps looking. Once ``terminated`` is True (success reached or root
        exhausted), no more actions are produced and None is returned.
        """
        # Iterative rather than recursive. A recursive walk can hit the
        # recursion limit on deep trees, and it never returned once the root
        # was exhausted: terminated was set but the recursion carried on.
        while not self.terminated:
            children = self.tree.children(self.current_position.identifier)
            tried_actions = {child.data["action"] for child in children}
            for action in self.possible_actions:
                if action not in tried_actions:
                    return action

            # Every action has been tried here: move down or back up.
            future_node = self._pick_child(children)
            if future_node is None:
                self._backtrack()
            else:
                self.current_position = future_node
                self.depth += 1
        return None

    def _pick_child(self, children):
        """Return a successful child (setting ``terminated``), else the best viable child, else None."""
        best_perf = -np.inf
        future_node = None
        for child in children:
            # A successful child ends the search immediately.
            if child.data["success"]:
                self.terminated = True
                return child

            if child.data["perf"] > best_perf and child.data["viable"]:
                best_perf = child.data["perf"]
                future_node = child
        return future_node

    def _backtrack(self):
        """Mark ``current_position`` non-viable and step to its parent; terminate at the root."""
        self.current_position.data["viable"] = False
        if self.current_position.identifier == "root":
            self.terminated = True
        else:
            self.current_position = self.tree.parent(self.current_position.identifier)
            self.depth -= 1

    def record_leaf(self, action, perf, viable, success, start_state):
        """Add a child of ``current_position`` describing the outcome of ``action``.

        ``perf`` ranks viable siblings (higher is preferred). A child with
        ``viable`` False is never entered. Entering a child with ``success``
        True terminates the search. ``start_state`` is stored in the node's data.
        """
        data = {"action": action,
                "perf": perf,
                "viable": viable,
                "success": success,
                "start_state": start_state}
        self.tree.create_node(parent=self.current_position.identifier, data=data)

    def reconstruct_trajectory(self):
        """Return the actions from "root" to ``current_position`` once terminated, else None.

        After a success this is the winning action sequence. If the search
        ended because the root ran out of viable children, ``current_position``
        is the root and the list is empty.
        """
        if not self.terminated:
            return None

        action_list = []
        node = self.current_position
        while node.identifier != "root":
            action_list.append(node.data["action"])
            node = self.tree.parent(node.identifier)
        action_list.reverse()
        return action_list
    

# Client Classes

## Abstract Client

In [21]:
class AbstractClient(Client):
    """Base TMInterface client that plays one discrete action per period and then calls ``finish``.

    On every simulation step at race time >= 0 it sends the input for
    ``self.selected_action`` (via ``discrete_to_continuous``) and runs
    ``reset_detection``. ``finish`` is called exactly once per step, either on
    a detected reset or when ``period_ms`` of race time have elapsed since
    ``start_state``. Subclasses are expected to set ``start_state`` and may
    override ``finish``.
    """

    def __init__(self):
        super().__init__()
        self.period_ms = 1000
        self.final_state = None   # state captured by the last finish()
        self.start_state = None   # state each period starts from / is rewound to
        self.is_finished = False  # set once the final checkpoint is reached
        self.selected_action = 0  # discrete action code; 0 = no input
        self.crashed = False      # result of the most recent reset_detection

    def on_registered(self, iface: TMInterface) -> None:
        iface.execute_command("press delete")
        print(f'Registered to {iface.server_name}')

    def on_run_step(self, iface, _time: int):
        self.action(iface, _time)

    def on_checkpoint_count_changed(self, iface, current: int, target: int):
        # Reaching the last checkpoint marks the run as finished; finishing the
        # simulation is prevented so the client can keep rewinding.
        if current >= 1 and current == target:
            self.is_finished = True
            iface.prevent_simulation_finish()

    def reset_detection(self, _time, state):
        """Return True when the current attempt should be abandoned.

        The attempt is abandoned when:
        - ``state.position[1]`` drops below 9.2;
        - from race time 500 on, component 2 of the local speed (scaled by 3.6) is below 1;
        - ``state.scene_mobil.has_any_lateral_contact`` is truthy.
        """
        # NOTE(review): 9.2 looks like a track-specific height threshold — confirm.
        if state.position[1] < 9.2:
            return True

        if _time >= 500:
            # Presumably m/s -> km/h; component 2 assumed to be the forward axis — confirm.
            local_velocity = np.array(list(state.scene_mobil.current_local_speed.to_numpy())) * 3.6
            if local_velocity[2] < 1:
                return True

        if state.scene_mobil.has_any_lateral_contact:
            return True

        return False

    def action(self, iface, _time: int):
        """Send the selected input and call ``finish`` on a reset or at the end of the period."""
        if _time >= 0:
            command = discrete_to_continuous(self.selected_action)
            iface.set_input_state(**command)

            if self.reset_detection(_time, iface.get_simulation_state()):
                self.crashed = True
                self.finish(iface, _time)
                # Already finished for this step; a crash on the period's last
                # step previously triggered a second finish().
                return
            self.crashed = False

        # start_state is None until a subclass captures it.
        if self.start_state is not None and _time == self.start_state.race_time + self.period_ms:
            self.finish(iface, _time)

    def finish(self, iface, _time):
        """Capture the final state and rewind to ``start_state``."""
        self.final_state = iface.get_simulation_state()
        iface.rewind_to_state(self.start_state)

## Training Client and Replay Client

In [22]:
class TrainingClient(AbstractClient):
    """Client that drives a TreeExplorer search, simulating one action per period.

    At race time -10 the current state becomes the explorer's root. Each time
    a period ends (or a reset is detected), the outcome is recorded as a leaf.
    The explorer then picks the next action, and the simulation is rewound to
    the state stored on the explorer's current node.
    """

    def __init__(self, period=500, training_track_name="Deterministic_Proof"):
        super().__init__()

        self.period_ms = period

        # Centerline loading. NOTE(review): centerline_objective is not defined
        # in the cells shown; columns 0-2 of the centerline are used as x, y, z
        # and alpha is indexed per centerline point — confirm against its definition.
        centerline, alpha = centerline_objective(training_track_name)
        self.centerline = centerline
        self.alpha = alpha
        self.centerline_x = self.centerline[:, 0]
        self.centerline_y = self.centerline[:, 1]
        self.centerline_z = self.centerline[:, 2]

        # Explorer parameters
        self.anchor = 0  # NOTE(review): not read anywhere in this class
        self.explorer = None
        self.selected_action = None

    def objective_function(self):
        """Return ``alpha`` at the centerline point closest (3-D Euclidean) to the final position."""
        position = self.final_state.position
        dis = distance_3D(self.centerline_x, self.centerline_y, self.centerline_z,
                          position[0], position[1], position[2])
        return self.alpha[np.argmin(dis)]

    @staticmethod
    def _node_start_state(node):
        """Return the simulation state stored on an explorer node."""
        # TreeExplorer stores the root's state under "root_start_state" and
        # every recorded leaf's state under "start_state" (there is no "state" key).
        if node.identifier == "root":
            return node.data["root_start_state"]
        return node.data["start_state"]

    def on_run_step(self, iface, _time: int):
        if _time == -10 and self.explorer is None:  # IMPORTANT: 1 step offset to prevent missing inputs
            root_start_state = iface.get_simulation_state()
            self.explorer = TreeExplorer(root_start_state)
            # Store the action in selected_action; assigning to self.action
            # used to overwrite the action() method with an int.
            self.selected_action = self.explorer.explore_node()
            self.start_state = self._node_start_state(self.explorer.current_position)

        # Nothing to play (and no start_state for the period check) until the
        # explorer has been created.
        if self.explorer is not None:
            self.action(iface, _time)

    def finish(self, iface, _time):
        """Record the outcome of the action just played, choose the next one, and rewind."""
        self.final_state = iface.get_simulation_state()

        if self.explorer.terminated:
            # Search is over; do not grow the tree any further.
            iface.rewind_to_state(self.start_state)
            return

        # Record leaf outcomes; the final state is where this leaf's children start.
        action = self.selected_action
        perf = self.objective_function()
        viable = self.crashed is False
        success = self.is_finished
        start_state = self.final_state
        self.explorer.record_leaf(action, perf, viable, success, start_state)

        # Explore a new leaf
        self.selected_action = self.explorer.explore_node()
        self.start_state = self._node_start_state(self.explorer.current_position)
        iface.rewind_to_state(self.start_state)
        
class ReplayClient(AbstractClient):
    """Client that captures the state at race time -10 and rewinds to it in ``finish``.

    NOTE(review): ``dna`` is stored but never read in this class, so
    ``selected_action`` keeps whatever value AbstractClient set and no recorded
    action sequence is actually replayed — confirm the intended replay logic.
    """

    def __init__(self, period, dna):
        super().__init__()
        self.period_ms = period
        self.dna = dna

    def on_run_step(self, iface, _time: int):
        # Capture the start state *before* acting. AbstractClient.action reads
        # start_state.race_time, which failed at _time == -10 while start_state
        # was still None.
        if _time == -10:
            self.start_state = iface.get_simulation_state()
        self.action(iface, _time)

    def finish(self, iface, _time):
        """Capture the final state and rewind to ``start_state``."""
        self.final_state = iface.get_simulation_state()
        iface.rewind_to_state(self.start_state)

# Training Client

In [23]:
interface = TMInterface()
client = TrainingClient(period= 500)

interface.register(client)
print("Start")

# Wait for the explorer to terminate rather than for client.is_finished:
# TreeExplorer only inspects a node's children (and so only registers a
# successful leaf) after every action at that node has been tried.
try:
    while client.explorer is None or not client.explorer.terminated:
        time.sleep(0.001)

        if keyboard.is_pressed("q"):
            print("Keybord Interrupt")
            break
finally:
    interface.close()

# TrainingClient defines no `finish_dna` / `memory` attributes; the action
# sequence lives in the explorer's tree. reconstruct_trajectory() returns None
# unless the explorer has terminated.
best_memory = client.explorer.reconstruct_trajectory() if client.explorer is not None else None
print(best_memory)

TypeError: TrainingClient.__init__() got an unexpected keyword argument 'n_steps'

# TESTING

In [None]:
explorer = TreeExplorer(None)

# Feed random outcomes to the explorer until it terminates (success found or
# root exhausted) instead of looping forever.
while not explorer.terminated:
    time.sleep(1)

    action = explorer.explore_node()
    if explorer.terminated:
        # The call that terminates the search does not yield a leaf to record.
        break

    perf = np.random.random()
    viable = np.random.random() < 0.8
    success = np.random.random() < 0.01
    start_state = None

    explorer.record_leaf(action, perf, viable, success, start_state)

    print(explorer.depth)

print(explorer.reconstruct_trajectory())

0
0
0
0
0
0
1
1
1
1
1
1
2
2
2
2
2
2
3
3
3
3
3
3
4
4
4
4
4
