In [13]:
with open('input.txt', 'r') as f:
    puzzle = f.read().strip()

In [15]:
def create_graph(s):
    """Parse the tower description into an adjacency list and a weight table.

    Each non-blank line has the form ``name (weight)``, optionally followed by
    `` -> child1, child2, ...`` listing the programs it directly holds up.

    Args:
        s: The puzzle text, one program per line.

    Returns:
        ``(graph, weight)``: ``graph`` maps each program name to the list of
        its direct children (empty for leaves), in input order; ``weight`` maps
        each program name to its own integer weight.

    Raises:
        ValueError: if a line's head is not exactly ``name (weight)``.
    """
    graph = {}
    weight = {}
    # splitlines() also removes '\r' from Windows line endings, which would
    # otherwise end up in the weight field or in the last child's name.
    for line in s.splitlines():
        if not line.strip():
            continue  # tolerate blank lines, e.g. a trailing newline
        head, _, tail = line.partition('->')
        # split() with no argument tolerates runs of spaces and trailing spaces.
        node, w = head.split()
        weight[node] = int(w.strip('()'))
        graph[node] = [child.strip() for child in tail.split(',') if child.strip()]
    return graph, weight

In [5]:
# Small example tower used below to sanity-check the parsing, root-finding and
# balancing functions before running them on the real puzzle input.
example="""pbga (66)
xhth (57)
ebii (61)
havc (66)
ktlj (57)
fwft (72) -> ktlj, cntj, xhth
qoyq (66)
padx (45) -> pbga, havc, qoyq
tknk (41) -> ugml, padx, fwft
jptl (61)
ugml (68) -> gyxo, ebii, jptl
gyxo (61)
cntj (57)"""

In [6]:
# Parse the example tower and display the adjacency list and weights for a visual check.
G, W = create_graph(example)
G, W

({'cntj': [],
  'ebii': [],
  'fwft': ['ktlj', 'cntj', 'xhth'],
  'gyxo': [],
  'havc': [],
  'jptl': [],
  'ktlj': [],
  'padx': ['pbga', 'havc', 'qoyq'],
  'pbga': [],
  'qoyq': [],
  'tknk': ['ugml', 'padx', 'fwft'],
  'ugml': ['gyxo', 'ebii', 'jptl'],
  'xhth': []},
 {'cntj': 57,
  'ebii': 61,
  'fwft': 72,
  'gyxo': 61,
  'havc': 66,
  'jptl': 61,
  'ktlj': 57,
  'padx': 45,
  'pbga': 66,
  'qoyq': 66,
  'tknk': 41,
  'ugml': 68,
  'xhth': 57})

In [7]:
from collections import defaultdict

def reverse(graph):
    """Invert a parent -> children adjacency list into child -> parents.

    Every node that appears in ``graph`` (as a key or as a child) becomes a
    key of the result. Nodes that nothing points at map to an empty list, so
    the root of a tree is the key whose parent list is empty.

    Args:
        graph: Maps each node to the list of its direct children.

    Returns:
        A plain dict mapping each node to the list of its direct parents, in
        the order those parents were encountered.
    """
    parents_of = {}
    for node, children in graph.items():
        # Register the node even when nothing points at it (e.g. the root).
        parents_of.setdefault(node, [])
        for child in children:
            parents_of.setdefault(child, []).append(node)
    return parents_of

In [8]:
# Invert the example's edges: each program maps to the program(s) holding it up;
# the one with an empty list is the bottom of the tower.
reverse(G)

{'cntj': ['fwft'],
 'ebii': ['ugml'],
 'fwft': ['tknk'],
 'gyxo': ['ugml'],
 'havc': ['padx'],
 'jptl': ['ugml'],
 'ktlj': ['fwft'],
 'padx': ['tknk'],
 'pbga': ['padx'],
 'qoyq': ['padx'],
 'tknk': [],
 'ugml': ['tknk'],
 'xhth': ['fwft']}

In [9]:
def find_parent(G):
    """Return the root of a tree given its child -> parents mapping.

    ``G`` is expected to be the output of ``reverse``: every node is a key and
    maps to the list of nodes directly holding it up. Starting from an
    arbitrary node, the walk repeatedly steps to the first holder until it
    reaches a node with no holder, which is the root.

    Raises:
        ValueError: if ``G`` is empty, or if the walk revisits a node (the
            mapping contains a cycle, so there is no root to find).
    """
    if not G:
        raise ValueError("cannot find the root of an empty graph")
    node = next(iter(G))
    seen = {node}
    while G[node]:
        node = G[node][0]
        # Without this check a cyclic input would loop forever.
        if node in seen:
            raise ValueError("cycle detected at {!r}; graph has no root".format(node))
        seen.add(node)
    return node

In [10]:
# tknk (directly or transitively) holds up every other program in the example,
# so it must be the root; assert it so Restart & Run All catches regressions.
example_root = find_parent(reverse(G))
assert example_root == 'tknk', example_root
example_root

'tknk'

In [14]:
# Root (bottom program) of the real puzzle tower; only the adjacency list
# (element [0] of create_graph's result) is needed to find it.
find_parent(reverse(create_graph(puzzle)[0]))

'dtacyn'

## P7.2 — Part 2: find the corrected weight that balances the tower

In [16]:
def total_weight(graph, weight):
    """Compute the total weight of every sub-tower.

    A program's total weight is its own weight plus the total weights of all
    programs it holds up, recursively.

    Args:
        graph: Maps each program to the list of programs it directly holds up.
        weight: Maps each program to its own weight.

    Returns:
        A dict mapping every program in ``graph`` to its total weight. Every
        node is covered, so no root search is needed and forests also work.
    """
    totals = {}

    def _compute(node):
        # Memoised post-order sum; recursion depth grows with the tower height.
        if node not in totals:
            totals[node] = weight[node] + sum(_compute(child) for child in graph[node])
        return totals[node]

    for node in graph:
        _compute(node)
    return totals
    

In [58]:
from collections import Counter

def find_unbalanced(graph, weight):
    """Return the weight the single faulty program needs for the tower to balance.

    Starting at the root, the walk descends into the one child whose total
    weight differs from the majority total of its siblings. Once a node's
    children all have equal totals (or it has none), that node itself carries
    the wrong weight; its correct own weight is its expected total minus the
    sum of its children's totals.

    Args:
        graph: Maps each program to the programs it directly holds up.
        weight: Maps each program to its own weight.

    Returns:
        The corrected own weight of the faulty program.

    Raises:
        ValueError: if the tower is already balanced, or if an unbalanced
            node's children split into equally sized groups so the odd one
            out cannot be identified.
    """
    totals = total_weight(graph, weight)
    node = find_parent(reverse(graph))
    expected_total = None  # unknown until some ancestor reveals an imbalance
    while True:
        children = graph[node]
        children_totals = [totals[child] for child in children]
        if len(set(children_totals)) <= 1:
            if expected_total is None:
                # Previously this returned the root's own weight as if it were an answer.
                raise ValueError("tower is already balanced")
            return expected_total - sum(children_totals)
        (majority, majority_count), (_, runner_up_count) = Counter(children_totals).most_common(2)
        if majority_count == runner_up_count:
            raise ValueError(
                "cannot tell which child of {!r} is unbalanced: totals {}".format(node, children_totals))
        expected_total = majority
        node = next(child for child in children if totals[child] != majority)

In [61]:
# In the example, ugml's sub-tower (68 + 3*61 = 251) outweighs its siblings
# (243 each), so ugml's own weight must drop from 68 to 60.
G, W = create_graph(example)
example_fix = find_unbalanced(G, W)
assert example_fix == 60, example_fix
example_fix

60

In [62]:
# Corrected weight of the faulty program in the real puzzle input.
# NOTE(review): this rebinds G, W from the example to the puzzle data, so
# re-running earlier example cells afterwards silently uses puzzle data —
# consider distinct names (e.g. puzzle_graph, puzzle_weights) once it is
# confirmed no later cell relies on G/W holding the puzzle.
G, W = create_graph(puzzle)
find_unbalanced(G, W)

521