In [None]:
import yaml
from lxml import etree


In [4]:
class detector_config_yaml:
    
    def __init__(self, file_path):
        with open(file_path, 'r') as f:
            self.content = yaml.safe_load(f)
            self.file_path = file_path
    
    def get_elem(self, key_path):
        
        path = key_path.split(".")
        curr_path = ""
        curr_node = self.content
        #Traverse tree until node that contains elem is reached.
        for key in path:
            if (key not in curr_node.keys()):
                raise Exception(curr_path + " not found")
            else:
                curr_node = curr_node[key]
                
        #Return value of found node.
        return curr_node
            
                                
    def set_elem(self, key_path, elem):
        path = key_path.split(".")
        curr_path = ""
        curr_node = self.content
        
        #Traverse tree until the parent node of the desired elem is reached.
        for key in path[:-1]:
            print(key)
            curr_path += ("." + key)
            if (key not in curr_node.keys()):
                raise Exception(curr_path + " not found")
            else:
                curr_node = curr_node[key]
                
        #Set value at desired path
        curr_node[path[-1]] = elem
                                
    def add_elem(self, key_path, elem):
        path = key_path.split(".")
        curr_path = ""
        curr_node = self.content
        
        #Traverse tree until parent node of the desired elem is reached.
        for key in path[:-1]:

            curr_path += ("." + key)
            if (key not in curr_node.keys()):
                raise Exception(curr_path + " not found")
            else:
                curr_node = curr_node[key]
                
        #Set value at desired path
        curr_node[path[-1]] = elem
                                
    def save_as(self, new_filepath):
        
        with open(new_filepath, 'w') as writeable:
            yaml.dump(self.content, writeable)
        
    def save(self):
        self.save_as(self.file_path)
       
                                              

In [None]:
def tag_to_xpath(tags=""):
    if len(tags) == 0:
        return ""
    if isinstance(tags, str):
        if not tags.startswith("//") and not tags.startswith(".//"):
            return '//{t}'.format(t=tags)
        else:
            return tags
    elif isinstance(tags, list):
        path=""
        for tag in tags:
            path += tag_to_xpath(tag)
        return path


def attribute_to_xpath(attribute, value):
    if len(attribute) == 0:
        return ""
    if (value == None):
        return '@{a}'.format(a=attribute)
    return '@{a}="{val}"'.format(a=attribute, val=value)

def attribute_dict_to_xpath(attribute_dict={}):
    attribute_str = ""
    attributes = list(attribute_dict.keys())

    if (len(attributes) == 0):
        return ""
    
    for attribute in attributes[:-1]:
        attribute_str += attribute_to_xpath(attribute, attribute_dict[attribute])
        attribute_str += (" and ")
    last_attribute = attributes[-1]
    attribute_str += attribute_to_xpath(last_attribute, attribute_dict[last_attribute])
    
    return "[{all}]".format(all=attribute_str)


class detector_config_xml:

    def __init__(self, filepath, autosave=False):

        try:
            self.tree = etree.parse(filepath)
            self.root = self.tree.getroot()
            self.filepath = filepath
            self.autosave = autosave
        except:
            raise Exception("No XML File found at: " + str(filepath))
        
    def get_root(self):
        return self.root
        
    def autosave(func):
        def wrapper(self, *args, **kw):
            #Call function before post-processing
            out = func(self, *args, **kw)
            #Auto save post processing
            if (self.autosave):
                self.save()
            return out
        return wrapper

    def save(self, filepath=""):
        if (len(filepath) == 0):
            to_save = self.filepath
        else:
            to_save = filepath          
        try:
            self.tree.write(to_save, xml_declaration=True)
        except:
            print("Could not save file at: " + to_save)    

    def get_elements(self, ancestor_tag_="", ancestor_attributes_={}, element_tag_="//*", element_attributes_={}):

        
        #Format attributes for xml.eTree findall function call.

        full_xpath = ""
        full_xpath += (tag_to_xpath(ancestor_tag_))
        full_xpath += (attribute_dict_to_xpath(ancestor_attributes_))
        full_xpath += (tag_to_xpath(element_tag_))
        full_xpath += (attribute_dict_to_xpath(element_attributes_))

        print(full_xpath)
        found_keys = self.root.xpath(full_xpath)
        
        if len(found_keys) == 0:
            print("Could not find path: " + full_xpath)
        elif len(found_keys) == 1:
            return found_keys[0]
        else:
            elements = []
            for element in found_keys:
                elements.append(element)
            return elements

    def get_val(self, ancestor_tag="", ancestor_attributes={}, element_tag="//*", element_attributes={}):

        elems = []
        for element in self.get_elements(ancestor_tag_=ancestor_tag, ancestor_attributes_=ancestor_attributes, element_tag_=element_tag, element_attributes_=element_attributes):
            elems.append(element.text)
        return elems

    @autosave
    def set_text(self, text, ancestor_tag="", ancestor_attributes={}, element_tag="//*", element_attributes={}):

        for element in self.get_elements(ancestor_tag_=ancestor_tag, ancestor_attributes_=ancestor_attributes, element_tag_=element_tag, element_attributes_=element_attributes):
            element.text = text

    @autosave
    def set_attribute(self, attr, val, ancestor_tag="", ancestor_attributes={}, element_tag="//*", element_attributes={}):

        for element in self.get_elements(ancestor_tag_=ancestor_tag, ancestor_attributes_=ancestor_attributes, element_tag_=element_tag, element_attributes_=element_attributes):
            element.set(attr, val)

    @autosave
    def add_element(self, new_tag, new_text, new_attributes={}, parent_tag="//*", parent_attributes={}, ancestor_tag="", ancestor_attributes={}):
        new_element = etree.Element(new_tag)
        new_element.text = new_text
        for attr in new_attributes.keys():
            new_element.set(attr, new_attributes[attr])
        for element in self.get_elements(ancestor_tag_=ancestor_tag, ancestor_attributes_=ancestor_attributes, element_tag_=parent_tag, element_attributes_=parent_attributes):
            element.append(new_element)
            
    

In [27]:
import os

#Simple test 
cwd = os.getcwd()
config_path = os.path.join(cwd, "test.yaml")
config = detector_config_yaml(file_path=config_path)

config.set_elem("one.a.I", "changed to this")
config.add_elem("one.a.III", "adding")
config.save()

one
a


In [46]:
xml_path = os.path.join(cwd, "xml_test.xml")
xml_config = detector_config_xml(xml_path, autosave=True)


for elem in xml_config.get_elements(ancestor_tag_="detector", ancestor_attributes_={"id" : "TrackerBarrel_1_ID"}, element_attributes_={"sensitive" : None , "name" : "Service"}):
    print(elem)
# xml_config.set_attribute(attr="sensitive", val="false", element_attributes={"sensitive" : "true", "name" : "Service"})





//detector[@id="TrackerBarrel_1_ID"]//*[@sensitive and @name="Service"]


In [None]:
class BenchmarkConfig:

    def __init__(self, *args):

        print("example init")

    def set_benchmark_name(self, *args):

        print("Name set")

    def set_git_branch(self, *args):

        print("git branch set")


    
class DetectorConfig:

    def __init__(self, *args):

        print("example init")


    def SetModuleAttribute(module_name, attribute_name, attribute_value):

        print("set attribute at ...")

class SimulationConfig:

    def __init__(self, *args):

        print("example init")

    #in GeV
    def set_max_momentum(p):

        print("Max momentum set")

    def set_particle(particle):

        print("Particle set")

    def set_max_theta(theta):

        print("Max theta set")
    
    def set_max_eta(eta):

        print("Max eta set")


def run_benchmark(benchmark_config, detector_config, simulation_config):

    print("running benchmark")

    #Run benchmark on compute node with snake make

benchmark_config_1 = "benchmark_config_1"
detector_config_1 = "detector_config_1"
simulation_config_1 = "simulation_config_1"
simulation_config_2 = "simulation_config_2"

benchmarks={
    benchmark_config_1 : {
        detector_config_1 : [simulation_config_1, simulation_config_2]
    }
}


def run_multiple_benchmarks(benchmark_dict):

    #use snakemake to run benchmarks on multiple compute nodes
    print("running benchmarks")
    