##  ✨ [Day 8](https://adventofcode.com/2018/day/8)

In [0]:
class Node:
  """A very simple tree node: a name, child nodes and metadata entries."""

  def __init__(self, name):
    """Init a very simple Tree structure.

    Args:
      name: label of the node.
    """
    self.name = name
    self.metadata = None  # set by the caller once the node is fully parsed
    self.children = []

  def sum_metadata(self):
    """Return the sum of the metadata of this node and of all its descendants."""
    return sum(self.metadata) + sum(c.sum_metadata() for c in self.children)

  def get_value(self, acc=None):
    """Get the value of the current node as defined in Part Two of the problem.

    A node without children is worth the sum of its metadata. Otherwise every
    metadata entry m is a 1-based reference to a child, and the node is worth
    the sum of the values of the referenced children. References that do not
    name an existing child (0, negative, or past the last child) are skipped.

    Args:
      acc: optional memo dict mapping child nodes to their value. When omitted,
        a fresh dict is used for this call.

    Returns:
      The value of this node.
    """
    if acc is None:
      # A fresh cache per top-level call: a shared `{}` default would keep every
      # evaluated node alive forever and return stale values after edits.
      acc = {}
    if not self.children:
      return sum(self.metadata)
    total = 0
    for m in self.metadata:
      # Explicit lower bound: without it, m == 0 gives index -1, i.e. the last child.
      if 1 <= m <= len(self.children):
        child = self.children[m - 1]
        if child not in acc:
          acc[child] = child.get_value(acc=acc)
        total += acc[child]
    return total
    
def parse_tree(inputs, depth=0, start_index=0):
  """Parse the inputs into a tree of `Node`s.

  Each node is encoded as a header of two integers (number of children,
  number of metadata entries), then the encodings of its children, then its
  metadata entries.

  Args:
    inputs: flat list of integers.
    depth: despite its name, a pre-order node counter used to name nodes
      (0 -> 'A', 1 -> 'B', ...; via chr(counter + 65), so past 26 nodes the
      names are no longer letters). It is not the depth in the tree.
    start_index: index in `inputs` where the node to parse starts.

  Returns:
    A tuple (node, next_index, next_counter): the parsed node, the index just
    past its encoding, and the counter value for the next node to be named.

  Raises:
    ValueError: if `inputs` ends before the node is completely encoded.
  """
  # Bounds are checked relative to start_index; a bare leaf "0 0" is only
  # two integers long and is valid.
  if start_index + 2 > len(inputs):
    raise ValueError("truncated input: missing node header at index %d" % start_index)
  num_children = inputs[start_index]
  num_metadata = inputs[start_index + 1]
  node = Node(chr(depth + 65))
  index = start_index + 2
  counter = depth + 1
  # collect every child (none for a leaf)
  for _ in range(num_children):
    child, index, counter = parse_tree(inputs, depth=counter, start_index=index)
    node.children.append(child)
  # metadata follows the last child
  end = index + num_metadata
  if end > len(inputs):
    raise ValueError("truncated input: missing metadata at index %d" % index)
  node.metadata = inputs[index:end]
  return node, end, counter

In [2]:
# Puzzle input: space-separated integers.
with open("day8.txt") as f:
  # split() with no argument tolerates the trailing newline and repeated
  # whitespace; split(' ') produces '' tokens that int() rejects.
  inputs = list(map(int, f.read().split()))

# Keep only the node; avoids rebinding IPython's `_` output variable.
tree = parse_tree(inputs)[0]
print("Sum of all metadata:", tree.sum_metadata())
print("Value of the root node:", tree.get_value())

Sum of all metadata: 40701
Value of the root node: 21399


In [6]:
#@title Visualize a tree (toy example)
import collections

def get_header(tree, acc, offset=0, depth=0):
  """Re-encode `tree` as a space-separated string and record its layout.

  For every node, a tuple (name, column where the node's encoding starts,
  number of dashes so that name + dashes spans that encoding) is appended to
  acc[depth]. Children are recorded before their parent.

  Returns:
    The encoding of `tree`.
  """
  encoded = '%d %d ' % (len(tree.children), len(tree.metadata))
  for child in tree.children:
    # the child's encoding starts right after what has been written so far
    child_start = offset + len(encoded)
    encoded += get_header(child, acc, offset=child_start, depth=depth + 1) + ' '
  encoded += ' '.join(str(entry) for entry in tree.metadata)
  dash_count = len(encoded) - len(tree.name)
  acc[depth].append((tree.name, offset, dash_count))
  return encoded
    
def pretty_print(tree):
  """Print the encoding of `tree`, then one line per tree depth in which each
  node's name, followed by dashes, underlines that node's encoding."""
  layout = collections.defaultdict(list)
  encoded = get_header(tree, layout)
  print(encoded)
  for level in sorted(layout):
    line = ''
    # left-to-right by starting column, padding with spaces up to each node
    for name, column, width in sorted(layout[level], key=lambda entry: entry[1]):
      line = line + ' ' * (column - len(line)) + name + '-' * width
    print(line)
      
# Toy example: parse a small encoded tree and draw it level by level.
tree = parse_tree(
    [int(token) for token in "2 3 0 3 10 11 12 1 1 0 1 99 2 1 1 2".split()]
)[0]
pretty_print(tree)

2 3 0 3 10 11 12 1 1 0 1 99 2 1 1 2
A----------------------------------
    B----------- C-----------
                     D-----
