In [1]:
""" Imports """
import csv
from math import floor, ceil
import matplotlib.pyplot as plt
from DPOMDP_Writer.Reader import *
from DPOMDP_Writer.DPOMDPWriterMedium import generate_safety_scenarios

## The following code reads the outputs from the MADP Solver and reasons about their values

In [32]:

#tree_values_s0-ACC.csv

def get_tree_value(scenario,mode):
    tree_file = f"results/tree_values_s{scenario}-ACC.csv"
    with open(tree_file, newline='') as csvfile:
        t_reader = csv.reader(csvfile)
        for row in t_reader:
            if row[0]== mode:
                return float(row[1])

# modes = ["standby", "following", "speedcontrol", "error", "hold", "override"]
# for mode in modes:
#     for scenario in range(0,8):
#         value = get_tree_value(scenario,mode)

"""Finding which scenarios required communication"""

def communicates(mode, scenario, agent):
    if agent == "human":
        suf = "hum"
    elif agent == "machine":
        suf = "mach"
    else:
        print("ERROR. Incorrect agent string input.")
    t_file = f"figs/ACC-{mode}-scen{scenario}_{suf}"
    policy = read_unbalanced_tree(t_file,generate_safety_scenarios()[scenario])
    print(policy)
    
communicates("standby", 0, "human")



## The code below works to merge identical joint policies to make it easier to view policy trees

In [2]:



def check_tree(tree_file1, instance_name, mergefile):
    """ Read through to see if mach + hum trees exist in file """
    """ If policies exist, add this instance name to those policies """
    """ If not, add policy to tree"""
    scenario = int(instance_name[-1])
    trees = open(mergefile, "r")
    reader = trees.readlines()
    trees.close()
    output_lines = ""
    tree_str1 = read_unbalanced_tree(tree_file1,generate_safety_scenarios()[scenario])
    #print(tree_str1)
    new_elem = True
    for elem in reader:
        #print(elem)
        [instance, tree1, newline] = elem.split(":")
        if (tree_str1 == tree1):
            new_instance_name = instance + " + " + instance_name
            output_lines += new_instance_name + ":" + tree1 + ":" + "\n"
            new_elem = False
        else:
            output_lines += elem
    if new_elem:
        #print("Adding to output")
        output_lines += instance_name + ":" + tree_str1 + ":" + "\n"
        #print(output_lines)
    #print(new_elem)
    return(output_lines)


In [3]:
def merge_trees(prefix, modes,agent):
    '''Search all results and combine identical policy trees'''
    if agent == "human":
        suf = "_hum"
    elif agent == "machine":
        suf = "_mach"
    else:
        print("Error! Wrong agent type.")
    merge_name = agent + "merge.csv"
    writer = open(merge_name, "w")
    writer.close()
    for mode in modes:
        for scenario in range(8):
            instance_name = mode + str(scenario)
            #print("Instance Name: " + instance_name)
            filename1 = "figs/" + prefix + mode + "-scen" + str(scenario) + suf
            output = check_tree(filename1, instance_name, merge_name)
            writer = open(merge_name, "w")
            writer.writelines(output)
            writer.close()
            


In [4]:
def make_fig(figname, caption, label):
    """ Output tex fig as string """
    output = ""
    output+="\\begin{figure}\n"
    output+="\\centering\n"
    output+="\\includegraphics[\\textwidth]{\n"
    output+= figname + "}\n"
    output+="\\caption{" + caption + "}\n"
    output+="\\label{fig:" + label + "}\n"
    output += "\\end{figure}\n"
    return output

In [5]:
modes = ["standby", "following", "speedcontrol", "error", "hold", "override"]
merge_trees("ACC", modes, "machine")
#merge_trees("ACC", modes, "human")

In [6]:
""" Read policy and return simplified policy dictionary """
from genericpath import exists


def exists_in_dict(dict, elem):
    # Checks if an element is in a dict. 
    # If yes, return its key
    for x in dict.keys():
        if dict[x] == elem:
            return [True, x]
    return [False, -1]


    
def get_policy_from_tree(tree):
    [nodes,edges] = tree.split(";")
    nodes=nodes.split(">")[:-1]
    edges=edges.split("-")[:-1]
    #print(nodes)
    #print(edges)
    level0 = nodes[0]
    level1 = nodes[1:]
    if len(edges) != len(level1):
        print("ERROR!!")
    policy_dict = {}
    for x in range(len(edges)):
        elem = edges[x]
        obs = elem.split(",")[2]
        # print(obs)
        node = level1[x]
        action = node.split(",")[1]
        #action = analyze_action(action)
        [exists, key] = exists_in_dict(policy_dict, action)
        # print(exists)
        if exists:
            new_key = key + " + " + obs
            del policy_dict[key]
            policy_dict[new_key] = action
        else:
            policy_dict[obs] = action
    print(policy_dict)
    return[level0, policy_dict]    

In [7]:
def shorten_str(inp_str):
    if " + " in inp_str:
        output = ""
        for elem in inp_str.split(" + "):
            output += shorten_str(elem) + "+"
        return output[:-1]
    if inp_str == "following":
        return "\\Foll"
    if inp_str == "following + speedcontrol":
        return "\\FSC"
    if inp_str == "speedcontrol":
        return "\\SC"
    if inp_str == "dontcommunicate":
        return "dont comm"
    if inp_str == "communicate":
        return "comm"
    if inp_str == "dontpushbutton":
        return "dont push"
    if inp_str == "pushbutton":
        return "push"
    if inp_str == "error":
        return "\\Err"
    if inp_str == "standby":
        return "\\Stby"
    if inp_str == "override":
        return "\\OVR"
    if inp_str == "hold":
        return "\\hold"
    return inp_str 


def get_action_str(action):
    phys = action.split("-")[0][1:]
    comm = action.split("-")[1][:-1]
    if phys != "none":
        line = shorten_str(phys) + " + " + shorten_str(comm)
    else:
        line = shorten_str(comm) 
    return line

def read_policy(policy):
    """ Takes output from get_policy_from_tree and makes it more readable """
    [line1,line2] = policy
    init_action = line1.split(",")[1]
    line1_str = get_action_str(init_action)
    #print(line1_str)
    line2_str = ""
    line3_str = ""
    if len(line2.keys()) <= 2:
        for obs in line2.keys():
            line2_str += shorten_str(obs) + ": " + get_action_str(line2[obs]) + "; "
    else:
        counter = 0
        for obs in line2.keys():
            if counter < 2:
                line2_str += shorten_str(obs) + ": " + get_action_str(line2[obs]) + "; "
            else: 
                line3_str += shorten_str(obs) + ": " + get_action_str(line2[obs]) + "; "
            counter += 1
    return [line1_str,line2_str[:-2], line3_str[:-2]]
        
    

In [8]:
def combine_range(inp_list):
    """ combines consecutive numbers in a list 
    e.g. [1,2,3] becomes [1-3] """
    consecutive = False
    range_str = ""
    for x in range(0,len(inp_list)):
        if consecutive:
            #check if last elem
            if (x == len(inp_list) -1):
                last = inp_list[x]
                range_str += str(first) + "-" + str(last)
            #check if still consecutive
            elif inp_list[x] == inp_list[x+1] -1:
                last = inp_list[x+1]
            else:
                range_str += str(first) + "-" + str(last) + "," 
                consecutive = False
        elif (x == len(inp_list) -1):
                #we've reached last elem
                range_str += str(inp_list[x])
        elif inp_list[x] == inp_list[x+1] - 1:
            # number is consecutive, so track it
            consecutive = True
            first = inp_list[x]
            last = inp_list[x+1]
        else:
            #consecutive = False
            range_str += str(inp_list[x]) + ","
    return range_str
    
def combine_scens(scens_list):
    #first, combine by mode
    s_dict = {}
    ret_list = []
    for s in scens_list:
        mode = s[:-1]
        scen = int(s[-1])
        if mode in s_dict.keys():
            temp = s_dict[mode]
            temp.append(scen)
            s_dict[mode] = temp
        else:
            s_dict[mode] = [scen]
    # then, combine by scenario range
    for key in s_dict.keys():
        range_str = combine_range(s_dict[key])
        ret_list.append("\\" + key + "{} " + range_str)
    print(ret_list)
    return ret_list

def list_to_string(inp_list):
    output_str = ""
    for elem in inp_list:
        output_str += str(elem) + " + "
    return output_str[:-2]

def make_table_entry(tree_number, l1, num_per_row):
    [scens, tree, newline] = l1.split(":")
    [hline1,hline2,hline3] = read_policy(get_policy_from_tree(tree))
    
    # print("Human tree:")
    # print(hline1)
    # print(hline2)
    # print("Machine tree:")
    # print(mline1)
    # print(mline2)
    
    l = combine_scens(scens.split(" + "))
    print(l)
    
    num_row = ceil(len(l)/num_per_row)
    i = 0
    if num_row > 1:
        if num_row == 2:
            size = num_row + 1
        else: 
            size = num_row
        output = "\\multirow{" + str(size) + "}{*}{" + str(tree_number) + "} & "
        while len(l) > 0:
            if len(l) > num_per_row:
                line = list_to_string(l[:num_per_row])
                l = l[num_per_row:]
            elif len(l) == num_per_row:
                line = list_to_string(l[:num_per_row])
                l = []
            else:
                line = list_to_string(l)
                l = []
            if (i == 0):
                #top line
                output += line + " + & " + hline1 +  "\\\\\n"
            elif (i == 1) & (i != num_row - 1):
                #secondline but not last
                output += " & " + line + " + & " + hline2 + "\\\\\n"
            elif (i == 1):
                #second and last line
                output += " & " + line + " & " + hline2 +  "\\\\\n"
            elif (i == 2) & (i != num_row - 1):
                #thirdline but not last
                output += " & " + line + " + & " + hline3 + "\\\\\n"
            elif (i == 1):
                #third and last line
                output += " & " + line + " & " + hline3 +  "\\\\\n"
            elif i != num_row - 1: 
                #middle lines
                output += " & " + line + " + & \\\\\n"
            else:
                #last_line
                output += " & " + line + " & \\\\\n"
            i += 1
    else:
        output = "\\multirow{3}{*}{" + str(tree_number) + "} & " 
        output += "\\multirow{3}{*}{" + list_to_string(l) + "} & "
        output += hline1 +  "\\\\\n"
        output += "& & " + hline2 + "\\\\\n"
        output += "& & " + hline3 + "\\\\\n"
    #if num_row == 2:
        #handles case where we need a blank line
        #output += "& & " + hline3 + "\\\\\n"
    return output

def start_table_str():
    output = "\\begin{table}[]\n"
    output += "\\centering\n"
    output += "\\begin{tabular}{c c c}\n"
    output += "\\toprule\n"
    output += "Tree & Applicable Scenarios & Policy  \\\\ \n"
    output += "\\toprule\n"
    return output
    
def end_table_str(caption):
    output = "\\bottomrule"
    output += "\\end{tabular}\n"
    output += "\\caption{" + caption + "}\n"
    output += "\\label{tab:my_label}\n"
    output += "\\end{table}\n"
    output += "\n\n"
    return output

In [9]:
def get_tex_table_from_mergefile(merge_name,caption):
    f = open("machinemerge.csv", "r")
    lines = f.readlines()
    f.close()
    tree_number = 0
    output = start_table_str()
    for l in lines:
        output += make_table_entry(tree_number, l, num_per_row=1)
        if tree_number == 10:
            output += end_table_str(caption)
            output += start_table_str()
        else:
            output += "\\midrule\\\\\n"
        tree_number += 1
    output += end_table_str(caption)
    return output

caption = "Human Policies"
g = open("humantable.tex", "w")
g.writelines(get_tex_table_from_mergefile("humanmerge.csv", caption))
g.close()

caption = "Machine Policies"
g = open("machinetable.tex", "w")
g.writelines(get_tex_table_from_mergefile("machinemerge.csv", caption))
g.close()

    

{'following': '"accel-dontcommunicate"', 'speedcontrol': '"decel-dontcommunicate"', 'hold': '"none-dontcommunicate"'}
['\\standby{} 0', '\\override{} 0']
['\\standby{} 0', '\\override{} 0']
{'following + speedcontrol': '"safest-dontcommunicate"', 'hold': '"none-dontcommunicate"'}
['\\standby{} 1-7']
['\\standby{} 1-7']
{'following': '"accel-dontcommunicate"', 'speedcontrol': '"decel-dontcommunicate"', 'error': '"none-dontcommunicate"'}
['\\following{} 0']
['\\following{} 0']
{'following + speedcontrol': '"safest-dontcommunicate"', 'error': '"none-dontcommunicate"'}
['\\following{} 1-7', '\\speedcontrol{} 1-7']
['\\following{} 1-7', '\\speedcontrol{} 1-7']
{'following': '"decel-dontcommunicate"', 'speedcontrol': '"accel-dontcommunicate"', 'error': '"none-dontcommunicate"'}
['\\speedcontrol{} 0']
['\\speedcontrol{} 0']
{'standby': '"none-dontcommunicate"'}
['\\error{} 0-3,5-6']
['\\error{} 0-3,5-6']
{'error': '"none-dontcommunicate"'}
['\\error{} 4,7']
['\\error{} 4,7']
{'following + spe