# Imports

In [127]:
from treelib import Node, Tree
import tqdm

import os
import time
import pickle
import keyboard
import numpy as np
import random
import pandas as pd
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt

import tminterface as tmi
from tminterface.interface import TMInterface, Client

# Useful Functions

In [128]:
def discrete_to_continuous(n):
    """Translate a discrete action index into a TMInterface input-state dict.

    Action table:

    0: no action
    1: left
    2: left + acceleration
    3: acceleration
    4: right + acceleration
    5: right

    Any other value (including ``None``) maps to "no action".

    Returns a new dict with the keys ``sim_clear_buffer``, ``steer``,
    ``accelerate`` and ``brake`` (brake is never pressed), ready to be
    splatted into ``iface.set_input_state(**command)``.
    """
    full_left, full_right = -65536, 65536

    steer, accelerate = 0, False
    # (action index, steer, accelerate) -- first match wins
    for code, code_steer, code_accelerate in (
        (1, full_left, False),
        (2, full_left, True),
        (3, 0, True),
        (4, full_right, True),
        (5, full_right, False),
    ):
        if n == code:
            steer, accelerate = code_steer, code_accelerate
            break

    return {
        'sim_clear_buffer': True,
        "steer": steer,
        "accelerate": accelerate,
        "brake": False,
    }
    

def distance_3D(x, y, z, x0, y0, z0):
    """Euclidean distance between (x, y, z) and (x0, y0, z0).

    Operates element-wise, so NumPy arrays (e.g. a whole centerline compared
    against a single point) work as well as scalars.
    """
    squared = (x - x0) ** 2 + (y - y0) ** 2 + (z - z0) ** 2
    return np.sqrt(squared)

def centerline_objective(track_name, run_name="run-1", smoothing_com=40):
    """Build a smoothed centerline from a recorded run of *track_name*.

    Loads ``track_data/<track_name>/<run_name>/positions.pkl``, a pickled
    list of dicts whose ``'position'`` entry exposes ``to_numpy()``. The
    positions are smoothed with an exponential moving average, exact
    consecutive duplicates are nudged apart, and the result is sampled on a
    normalized [0, 1] progress axis.

    Args:
        track_name: folder name under ``track_data/``.
        run_name: recorded-run sub-folder (default ``"run-1"``, as before).
        smoothing_com: centre of mass passed to ``DataFrame.ewm``
            (default 40, as before).

    Returns:
        ``(curve, alpha)``: ``curve`` is an ``(N, D)`` array of points
        (presumably D == 3 for x, y, z) and ``alpha`` the matching ``N``
        progress values evenly spaced in ``[0, 1]``.
    """
    run_folder = os.path.join("track_data", track_name, run_name)
    # NOTE: pickle.load can execute arbitrary code -- only load trusted recordings.
    with open(os.path.join(run_folder, "positions.pkl"), "rb") as f:
        positions = pickle.load(f)

    raw_points = [list(pos['position'].to_numpy()) for pos in positions]
    smoothed = pd.DataFrame(raw_points).ewm(com=smoothing_com).mean()
    raw_points = smoothed.values.tolist()

    # Nudge exact consecutive duplicates by +0.01 on every axis so that no two
    # neighbouring points coincide (nothing is actually removed).
    points = [raw_points[0]]
    for point in raw_points[1:]:
        if point == points[-1]:
            point = [coord + 0.01 for coord in point]
        points.append(point)
    points = np.array(points)

    # Normalized progress along the recording, one sample per point.
    progress = np.linspace(0, 1, len(points))
    interpolator = interp1d(progress, points, kind='slinear', axis=0)
    # Resampled on the same grid, so curve matches points up to round-off.
    alpha = np.linspace(0, 1, len(points))
    curve = interpolator(alpha)

    return curve, alpha

# Explorer Class

In [142]:
class TreeExplorer():
    """Depth-first search over a tree of discrete driving actions.

    The root node holds the state the search starts from. Every other node
    stores, in ``node.data``, the action tried for one period and its outcome:
    ``action``, ``perf``, ``viable``, ``success``, ``start_state`` (the state
    its own children restart from) and ``_time``.

    The caller alternates two calls:

    * ``explore_node()`` returns the next action to try from
      ``current_position`` (descending / backtracking as needed), or ``None``
      once the root has no viable child left.
    * ``record_leaf(...)`` stores the outcome of that action as a new child
      of ``current_position``.
    """

    def __init__(self, root_start_state):
        self.tree = Tree()
        self.node_count = 0  # suffix that keeps node identifiers unique
        self.tree.create_node(identifier="root", data={"start_state": root_start_state})
        self.current_position = self.tree["root"]
        # Action indices tried at every node, in this order
        # (same encoding as discrete_to_continuous).
        self.possible_actions = [2, 3, 4]
        self.terminated = False
        # explore_node iterations since the last record_leaf call
        # (only used for the debug print in explore_node).
        self.recursion_depth = 0

        self.depth = 0  # depth of current_position below the root

    def explore_node(self):
        """Return the next action to try, moving ``current_position`` as needed.

        Applied repeatedly at ``current_position`` until an action is found:

        1. If an action in ``possible_actions`` has no child yet, return the
           first such action.
        2. Otherwise, if a child is marked ``success``, set ``terminated`` and
           descend into it.
        3. Otherwise descend into the viable child with the highest ``perf``.
        4. If no child is viable: at the root, set ``terminated`` and return
           ``None``; elsewhere, replace the current node in its parent by a
           childless non-viable leaf and step back up.

        Written as a loop instead of tail recursion so that long
        descend/backtrack chains cannot exceed Python's recursion limit.
        """
        while True:
            self.recursion_depth += 1

            explored_nodes = self.tree.children(self.current_position.identifier)
            tried_actions = {node.data["action"] for node in explored_nodes}
            untried_actions = [action for action in self.possible_actions
                               if action not in tried_actions]

            # Unexplored action available at this node
            if untried_actions:
                if self.recursion_depth > 5:
                    print(self.recursion_depth)  # debug output
                return untried_actions[0]

            # All actions have been explored: choose a child to descend into
            best_perf = -np.inf
            future_node = None
            for child in explored_nodes:
                # Check for success
                if child.data["success"]:
                    self.terminated = True
                    future_node = child
                    break

                # Pick the best viable one (`viable` is checked first because
                # perf may be None for non-viable children)
                if child.data["viable"] and child.data["perf"] > best_perf:
                    best_perf = child.data["perf"]
                    future_node = child

            if future_node is not None:
                self.depth += 1
                self.current_position = future_node
                continue

            # Dead end: no child is viable
            if self.current_position.identifier == "root":
                self.terminated = True
                return None

            self._prune_current_position()
            self.depth -= 1

    def _prune_current_position(self):
        """Replace current_position by a childless non-viable leaf and move to its parent."""
        self.current_position.data["viable"] = False

        current_identifier = self.current_position.identifier
        current_data = self.current_position.data
        self.current_position = self.tree.parent(current_identifier)
        # Drop the node (and whatever lies below it), then re-add it under the
        # parent as a leaf carrying the same data, now marked non-viable.
        self.tree.remove_node(current_identifier)
        self.record_leaf(current_data["action"],
                         current_data["perf"],
                         current_data["viable"],
                         current_data["success"],
                         current_data["start_state"],
                         current_data["_time"])

    def record_leaf(self, action, perf, viable, success, start_state, _time):
        """Add a child of ``current_position`` holding the outcome of *action*.

        Args:
            action: action index that was tried.
            perf: objective value of the outcome (may be ``None`` when not viable).
            viable: whether the outcome may be explored further.
            success: whether the outcome reached the goal.
            start_state: state that children of the new node restart from.
            _time: simulation time supplied by the caller; ``_time + 10`` is
                embedded in the node identifier.
        """
        self.recursion_depth = 0
        data = {"action": action,
                "perf": perf,
                "viable": viable,
                "success": success,
                "start_state": start_state,
                "_time": _time}

        sign = "O" if viable else "X"
        self.tree.create_node(identifier=f"{_time+10}_{action}_{sign}_{self.node_count}",
                              parent=self.current_position,
                              data=data)
        self.node_count += 1

    def reconstruct_trajectory(self):
        """Return the actions on the path from the root to current_position, root first."""
        reconstruct_position = self.current_position
        action_list = []

        while reconstruct_position.identifier != "root":
            action_list.append(reconstruct_position.data["action"])
            reconstruct_position = self.tree.parent(reconstruct_position.identifier)

        action_list.reverse()
        return action_list
    

# Client Classes

## Abstract Client

In [143]:
class AbstractClient(Client):
    """Base TMInterface client that drives the car with one discrete action at a time.

    Subclasses set ``selected_action``, ``start_state`` and ``anchor`` (the
    simulation time at which the current period ends). On every step at or
    after time 0, ``action()`` applies ``selected_action``. When a crash is
    detected or the anchor time is reached, it calls ``finish()``, which by
    default rewinds to ``start_state``.
    """

    # Crash-detection thresholds used by reset_detection
    min_height = 9.2               # state.position[1] below this counts as a crash
    speed_grace_period_ms = 500    # the speed check only applies from this _time on
    # local_speed[2] * 3.6 below this counts as a crash (presumably km/h)
    min_forward_speed_kmh = 10

    def __init__(self):
        super().__init__()
        self.period_ms = 1000
        self.final_state = None
        self.start_state = None
        self.is_finished = False
        self.selected_action = False
        self.crashed = False
        # End time of the current period; None never matches a step time,
        # so the base class never triggers an anchor finish by itself.
        self.anchor = None

    def on_registered(self, iface: TMInterface) -> None:
        # Sends the "press delete" console command on registration
        # (presumably restarts the race -- confirm against TMInterface docs).
        iface.execute_command("press delete")
        print(f'Registered to {iface.server_name}')

    def on_run_step(self, iface, _time: int):
        self.action(iface, _time)

    def on_checkpoint_count_changed(self, iface, current: int, target: int):
        # All checkpoints reached: flag completion and keep the simulation alive.
        if current >= 1 and current == target:
            self.is_finished = True
            print(iface.get_simulation_state().position)
            iface.prevent_simulation_finish()

    def reset_detection(self, _time, state):
        """Return True when the run should be treated as crashed.

        A run counts as crashed when the car is below ``min_height``, is slower
        than ``min_forward_speed_kmh`` after the grace period, or has any
        lateral contact.
        """
        if state.position[1] < self.min_height:
            return True

        if _time >= self.speed_grace_period_ms:
            local_velocity = state.scene_mobil.current_local_speed
            local_velocity = np.array(list(local_velocity.to_numpy())) * 3.6
            if local_velocity[2] < self.min_forward_speed_kmh:
                return True

        if state.scene_mobil.has_any_lateral_contact:
            return True

        return False

    def action(self, iface, _time: int):
        """Apply ``selected_action``, ending the period on a crash or at ``anchor``."""
        if _time < 0:
            return

        command = discrete_to_continuous(self.selected_action)
        iface.set_input_state(**command)

        if self.reset_detection(_time, iface.get_simulation_state()):
            self.crashed = True
            self.finish(iface, _time)
            # Stop here: a subclass's finish() may already have started a new
            # period (and updated anchor). Re-checking the anchor in the same
            # step could end that new period at once and record a bogus outcome.
            return
        self.crashed = False

        if _time == self.anchor:
            self.final_state = iface.get_simulation_state()
            self.finish(iface, _time)

    def finish(self, iface, _time):
        """End the current period: by default, rewind to ``start_state``."""
        iface.rewind_to_state(self.start_state)

## Training Client and Replay Client

In [144]:
class TrainingClient(AbstractClient):
    """Client that searches action sequences with a ``TreeExplorer``.

    Each period (``period_ms`` of simulation time) holds one action. At the
    end of a period, or on a crash, the outcome is recorded in the tree. The
    simulation is then rewound to the start state of the next node to explore.
    """

    def __init__(self, period=1000, training_track_name="Deterministic_Proof"):
        super().__init__()

        self.period_ms = period

        # Centerline loading (used by centerline_progress)
        centerline, alpha = centerline_objective(training_track_name)
        self.centerline = centerline
        self.alpha = alpha
        self.centerline_x = self.centerline[:, 0]
        self.centerline_y = self.centerline[:, 1]
        self.centerline_z = self.centerline[:, 2]

        # Explorer parameters.
        # The root state is captured at _time == -10 (see on_run_step), so the
        # first period ends at -10 + period_ms.
        self.anchor = self.period_ms - 10
        self.explorer = None
        self.selected_action = None
        self.save_state = None

    def objective_function(self):
        """Score of a viable period: component [2] of the final local speed.

        ``centerline_progress`` is an alternative progress-based objective.
        """
        return self.final_state.scene_mobil.current_local_speed[2]

    def centerline_progress(self):
        """Alternative objective: ``alpha`` of the centerline point closest to the final position."""
        position = self.final_state.position
        dis = distance_3D(self.centerline_x, self.centerline_y, self.centerline_z,
                          position[0], position[1], position[2])
        return self.alpha[np.argmin(dis)]

    def on_run_step(self, iface, _time: int):
        # IMPORTANT: 1 step offset to prevent missing inputs -- the root state
        # is captured at _time == -10, before the first input step at 0.
        if _time == -10 and self.explorer is None:
            self.explorer = TreeExplorer(iface.get_simulation_state())
            self.selected_action = self.explorer.explore_node()
            self.start_state = self.explorer.current_position.data["start_state"]

        if self.explorer is not None:
            self.action(iface, _time)

    def finish(self, iface, _time):
        """Record the outcome of the period that just ended and start the next one."""
        # Record leaf outcomes
        viable = self.crashed is False
        action = self.selected_action
        if viable:
            perf = self.objective_function()
            # NOTE(review): is_finished is set by on_checkpoint_count_changed,
            # never reset, and also ends the training loop -- confirm intended.
            success = self.is_finished
            start_state = self.final_state
        else:
            perf, success, start_state = None, False, None

        if self.explorer.terminated:
            self.is_finished = True
            return

        self.explorer.record_leaf(action, perf, viable, success, start_state, _time)

        # Explore a new leaf, restarting from the state stored on its parent
        self.selected_action = self.explorer.explore_node()
        self.start_state = self.explorer.current_position.data["start_state"]
        self.anchor = self.start_state.race_time + self.period_ms

        iface.rewind_to_state(self.start_state)

    def reconstruct_trajectory(self):
        """Actions from the root to the current node, plus the action currently selected."""
        trajectory = self.explorer.reconstruct_trajectory()
        trajectory.append(self.selected_action)
        return trajectory
        
class ReplayClient(AbstractClient):
    """Client that replays a fixed action sequence ``dna``, one action per period."""

    def __init__(self, period, dna):
        super().__init__()
        self.period_ms = period
        self.dna = dna

    def on_run_step(self, iface, _time: int):
        if self.start_state is not None:
            self.action(iface, _time)
        # Capture the state at _time == -10 (same offset as TrainingClient)
        # so that finish() can rewind to it.
        if _time == - 10:
            self.start_state = iface.get_simulation_state()

    def _action_at(self, _time):
        """Action for the period containing *_time*; 0 (no action) past the end of dna."""
        index = _time // self.period_ms
        if index < len(self.dna):
            return self.dna[index]
        # The run outlasted the recorded trajectory: coast instead of raising
        # IndexError inside the TMInterface callback.
        return 0

    def action(self, iface, _time: int):
        """Apply the dna action for this step and rewind on a detected crash."""
        if _time < 0:
            return

        command = discrete_to_continuous(self._action_at(_time))
        iface.set_input_state(**command)

        if self.reset_detection(_time, iface.get_simulation_state()):
            self.crashed = True
            self.finish(iface, _time)
        else:
            self.crashed = False

    def finish(self, iface, _time):
        """Store the final state and rewind to the start of the replay."""
        self.final_state = iface.get_simulation_state()
        iface.rewind_to_state(self.start_state)

# TRAINING

In [145]:
interface = TMInterface()
client = TrainingClient(period=900)

interface.register(client)
print("Start")

try:
    while not client.is_finished:
        time.sleep(0.001)

        if keyboard.is_pressed("q"):
            print("Keyboard Interrupt")
            break
finally:
    # Always release the TMInterface connection, even if the loop raises
    # (e.g. a KeyboardInterrupt from the notebook).
    interface.close()

# The explorer is only created once the run reached _time == -10.
if client.explorer is None:
    print("No exploration took place")
    best_trajectory = []
else:
    best_trajectory = client.reconstruct_trajectory()
print('\n', best_trajectory)

Start
Registered to TMInterface0

 [None]


# TESTING

In [126]:
interface = TMInterface()
# Replay with the same period the trajectory was trained with; otherwise each
# action in best_trajectory is held for the wrong duration. `client` is still
# the training client here (or an earlier replay client with the same period).
client = ReplayClient(period=client.period_ms, dna=best_trajectory)

interface.register(client)

print("Start")
try:
    while not client.is_finished:
        time.sleep(0.001)

        if keyboard.is_pressed("q"):
            print("Keyboard Interrupt")
            break
finally:
    # Always release the TMInterface connection.
    interface.close()

Start
Registered to TMInterface0
[374.1940002441406, 9.397768020629883, 456.59576416015625]


In [146]:
# Run top to bottom, the testing cell rebinds `client` to a ReplayClient, which
# has no explorer; only display the tree when the current client carries one.
explorer = getattr(client, "explorer", None)
if explorer is not None:
    explorer.tree.show()
else:
    print("No exploration tree on the current client")

root
├── 700_2_X_0
├── 700_4_X_2
└── 900_3_X_26



In [77]:
aaa = 5

# str() of an int is what an f-string with no format spec produces
a = str(aaa)
a

'5'