In [1]:
%load_ext autoreload
%autoreload 2
import sys
# Put the parent directory first on the module search path.
# NOTE(review): presumably so project-local modules can be imported from the
# notebook; no such import appears in the cells visible here — confirm it is
# still needed.
sys.path.insert(0, "../")

In [2]:
import math
import pydot

In [3]:
class Color:
    """An RGB color whose channels are integers in the range 0..255.

    Channels are stored privately and exposed through r()/g()/b() and
    set_r()/set_g()/set_b(). Every value written to a channel passes through
    clamp(), so it is rounded and limited to 0..255.
    """

    def __init__(self, r=0, g=0, b=0):
        """Create a color; each channel is rounded and clamped to 0..255."""
        self._r = Color.clamp(r)
        self._g = Color.clamp(g)
        self._b = Color.clamp(b)

    def r(self):
        """Return the red channel."""
        return self._r

    def g(self):
        """Return the green channel."""
        return self._g

    def b(self):
        """Return the blue channel."""
        return self._b

    def set_r(self, val):
        """Set the red channel (rounded and clamped to 0..255)."""
        self._r = Color.clamp(val)

    def set_g(self, val):
        """Set the green channel (rounded and clamped to 0..255)."""
        self._g = Color.clamp(val)

    def set_b(self, val):
        """Set the blue channel (rounded and clamped to 0..255)."""
        self._b = Color.clamp(val)

    def as_hex(self):
        """Return the color as a lowercase '#rrggbb' string."""
        return "#{0:02x}{1:02x}{2:02x}".format(self.r(), self.g(), self.b())

    @staticmethod
    def mix(one, other, mix_fac=0.5):
        """Blend two colors channel by channel and return a new Color.

        Each channel is computed as sqrt((1 - mix_fac) * a**2 + mix_fac * b**2),
        where a comes from `one` and b from `other`, then rounded and clamped
        to 0..255.

        Args:
            one: First color, or a non-Color placeholder such as None.
            other: Second color, or a non-Color placeholder such as None.
            mix_fac: Weight of `other`; 0 yields `one`, 1 yields `other`.
                Values outside 0..1 extrapolate; a channel whose weighted sum
                becomes negative is set to 0.

        Returns:
            `other` unchanged if `one` is not a Color, `one` unchanged if
            `other` is not a Color, otherwise a new blended Color.

        Raises:
            AssertionError: if neither argument is a Color.
        """
        # isinstance (not an exact type check) so subclasses of Color are mixed
        # instead of being silently discarded.
        assert isinstance(one, Color) or isinstance(other, Color)
        if not isinstance(one, Color):
            return other
        if not isinstance(other, Color):
            return one

        def blend(a, b):
            weighted = (1 - mix_fac) * a * a + mix_fac * b * b
            # With mix_fac outside 0..1 the weighted sum can be negative;
            # floor it at 0 so math.sqrt does not raise a domain error.
            return math.sqrt(max(0.0, weighted))

        # The constructor rounds and clamps each channel.
        return Color(
            r=blend(one.r(), other.r()),
            g=blend(one.g(), other.g()),
            b=blend(one.b(), other.b()),
        )

    @staticmethod
    def clamp(val):
        """Limit val to the range 0..255 and round it to the nearest integer."""
        return round(max(0, min(val, 255)))

In [4]:
def load_dot_files(file_dir, prefix, max_epoch, min_epoch=0):
    """Load one DOT file per epoch and return the first graph parsed from each.

    The path for epoch i is built by plain string concatenation,
    file_dir + prefix + str(i), so file_dir must already end with a path
    separator. Epochs run from min_epoch to max_epoch inclusive; an empty
    range yields an empty list.
    """
    epochs = range(min_epoch, max_epoch + 1)
    return [
        pydot.graph_from_dot_file(file_dir + prefix + str(epoch))[0]
        for epoch in epochs
    ]

In [23]:
def format_node(node):
    """Color a graph node according to markers found in its label.

    Nodes without a 'label' attribute, or whose label is exactly 'node', are
    left untouched. A label containing 'RTrue' contributes a green tone and
    one containing 'RFalse' a red tone; if both occur the tones are mixed.
    The resulting color is applied with node.set_color() as a hex string.
    """
    attrs = node.get_attributes()
    if 'label' not in attrs or attrs['label'] == 'node':
        return

    label = attrs['label']
    tint = None
    # Colors for the skip attribute, applied in this order.
    # Colors for additional ENAS attributes could be added to this table.
    marker_tones = (
        ('RTrue', (50, 255, 50)),
        ('RFalse', (255, 50, 50)),
    )
    for marker, (red, green, blue) in marker_tones:
        if marker in label:
            tint = Color.mix(tint, Color(r=red, g=green, b=blue))

    # Only touch the node when at least one marker matched.
    if tint:
        node.set_color(tint.as_hex())

In [14]:
def format_graphs(graphs):
    """Run format_node on every node of every graph, in order.

    Each graph's nodes are obtained via get_node_list(). Returns nothing.
    """
    every_node = (
        member
        for current in graphs
        for member in current.get_node_list()
    )
    for member in every_node:
        format_node(member)

In [15]:
def write_graphs_to_pdf(graphs, out_dir='./', prefix=''):
    """Write each graph to its own PDF via graph.write_pdf().

    Files are named out_dir + prefix + <index> + '.pdf', where the index is
    the graph's position in `graphs`, starting at 0. The path is built by
    plain concatenation, so out_dir must end with a path separator.
    """
    for index, current in enumerate(graphs):
        target = out_dir + prefix + str(index) + '.pdf'
        current.write_pdf(target)

In [24]:
# Directory holding the per-epoch DOT files; the trailing slash is required
# because load_dot_files builds paths by plain string concatenation.
dot_file_dir = '../data/logs/meliusnet22/architectures/'
# Loads epoch_0, epoch_1 and epoch_2 (max_epoch is inclusive).
graphs = load_dot_files(dot_file_dir, 'epoch_', min_epoch=0, max_epoch=2)
# Colors the nodes of the loaded graphs in place; nothing is returned.
format_graphs(graphs)

In [25]:
# Writes ./test_graph_0.pdf, ./test_graph_1.pdf, ... (out_dir defaults to './').
write_graphs_to_pdf(graphs, prefix='test_graph_')

In [39]:
# Pure red and pure blue; channels that are not given default to 0.
red = Color(r=255)
blue = Color(b=255)

In [45]:
# Equal-weight blend (mix_fac defaults to 0.5).
mixed = Color.mix(red, blue)

In [46]:
# Expected output: #b400b4
print(mixed.as_hex())

#b400b4


False