# Day 7
https://adventofcode.com/2017/day/7

In [1]:
import aocd
# Fetch the raw puzzle input text for Advent of Code 2017, day 7.
data = aocd.get_data(day=7, year=2017)

In [2]:
from dataclasses import dataclass
from typing import Tuple
import regex as re

In [3]:
@dataclass(frozen=True)
class Node:
    """A program in the tower together with the sub-tower it supports.

    Instances are immutable; the whole tree is built eagerly by
    :meth:`from_node_info`.
    """

    # Name of this program.
    name: str
    # Weight of this program alone, excluding anything it supports.
    weight: int
    # Programs directly supported by this one. A variadic tuple: a node may
    # have any number of children, including none.
    children: Tuple["Node", ...]

    @classmethod
    def from_node_info(cls, weights, children, name) -> "Node":
        """Recursively build the node called *name* and everything above it.

        Args:
            weights: mapping of program name to its own weight. A name that
                is missing from the mapping gets weight 0.
            children: mapping of program name to the names of its direct
                children. A name that is missing is treated as a leaf.
            name: name of the program to build.

        Returns:
            The node for *name* with all of its descendants attached.
        """
        directchildren = tuple(
            cls.from_node_info(weights, children, childname)
            for childname in children.get(name, ())
        )
        return cls(name, weights.get(name, 0), directchildren)

    @property
    def child_weights(self) -> list:
        """Total weights of the direct children, in child order."""
        return [child.total_weight for child in self.children]

    @property
    def total_weight(self) -> int:
        """Own weight plus the total weight of every descendant.

        Recomputed recursively on every access; nothing is cached.
        """
        return self.weight + sum(self.child_weights)

In [4]:
def find_bottom(weights, children):
    """Return the name of the program at the bottom of the tower.

    The bottom program is the one that appears in *weights* but is never
    listed as anybody's child.

    Args:
        weights: mapping of program name to weight; its keys are the set of
            all known programs.
        children: mapping of program name to an iterable of the names of its
            direct children.

    Returns:
        The name of a program that is nobody's child. If several programs
        qualify, an arbitrary one of them is returned.

    Raises:
        ValueError: if no such program exists (empty input, or every program
            is listed as a child of another).
    """
    # A single union over all child lists instead of rebuilding the set once
    # per parent.
    all_children = set().union(*children.values())

    roots = set(weights).difference(all_children)
    if not roots:
        raise ValueError(
            'no bottom program found: the tower is empty or every program '
            'is listed as a child'
        )

    return next(iter(roots))

In [5]:
# Matches "name (weight)" and captures the name and the weight.
re_weights = re.compile(r'(\w+) \((\d+)\)')
# Matches "name (weight) -> a, b, c" and captures the name and the child list.
re_children = re.compile(r'(\w+) \(\d+\) -> ([\w, ]+)')


def read_tower(text):
    """Parse the puzzle text and return the tower as a tree of nodes.

    Every "name (weight)" occurrence in *text* contributes a weight entry,
    and every "name (weight) -> a, b, c" occurrence contributes a list of
    child names. The bottom program is located with ``find_bottom`` and the
    tree is built from it with ``Node.from_node_info``.
    """
    weights = {
        program: int(amount)
        for program, amount in re_weights.findall(text)
    }
    children = {
        parent: listing.split(', ')
        for parent, listing in re_children.findall(text)
    }
    root_name = find_bottom(weights, children)
    return Node.from_node_info(weights, children, root_name)

In [6]:
def correct_weight(node, correction=0):
    """Return the weight the single wrong program needs to balance the tower.

    Walks down from *node* into whichever child sub-tower has a total weight
    that differs from its siblings. The walk stops at the first node whose
    children balance each other; that node is the one to fix.

    Args:
        node: object exposing ``weight``, ``children`` and ``child_weights``
            (the total weights of ``children``, in the same order).
        correction: amount by which *node*'s own weight must change, as
            worked out by the caller one level down. 0 at the top.

    Returns:
        ``weight + correction`` of the node that has to be corrected.

    Raises:
        ValueError: if the children of a node all have different totals and
            the inherited correction cannot single out the wrong one.
    """
    from collections import Counter

    child_weights = node.child_weights
    counts = Counter(child_weights)
    # Counter keeps first-seen order, so index 0 below is the first such
    # weight in child order.
    odd_weights = [w for w, seen in counts.items() if seen == 1]

    # No children, a single child, or no unique sub-tower weight: the
    # children balance each other, so this node is the one to correct.
    if len(counts) <= 1 or not odd_weights:
        return node.weight + correction

    common_weights = [w for w, seen in counts.items() if seen > 1]
    if common_weights:
        # The usual case: one child disagrees with a majority of its siblings.
        wrong, target = odd_weights[0], common_weights[0]
    else:
        # Every child total is different, so there is no majority. With
        # exactly two children the correction inherited from the parent still
        # tells which one is wrong: the one that would match its sibling
        # after the correction is applied.
        candidates = [
            w for w in child_weights
            if correction and w + correction in counts
        ]
        if len(child_weights) != 2 or len(candidates) != 1:
            raise ValueError(
                'cannot tell which child of {!r} is unbalanced: '
                'child weights are {}'.format(
                    getattr(node, 'name', node), child_weights)
            )
        wrong = candidates[0]
        target = wrong + correction

    node_to_correct = node.children[child_weights.index(wrong)]
    return correct_weight(node_to_correct, target - wrong)

In [7]:
# Build the tower once, then report both answers in order.
base = read_tower(data)

part_one = 'Part 1: {}'.format(base.name)
print(part_one)

part_two = 'Part 2: {}'.format(correct_weight(base))
print(part_two)

Part 1: ykpsek
Part 2: 1060
