# Part 1

In [1]:
def parse_separated(text, sep):
    a, _, b = text.strip().partition(sep)
    return a, b

In [2]:
from functools import partial

In [3]:
parse_line = partial(parse_separated, sep="->")
parse_node = partial(parse_separated, sep=" ")

In [4]:
node, children = parse_separated("a (x) -> b, c, d", "->")

In [5]:
name, value = parse_separated(node, " ")

In [6]:
name, value

('a', '(x)')

In [7]:
# the same can be achieved with str.split, I just wanted recursion :)
def parse_kids(items, sep=", "):
    a, b = parse_separated(items, sep)
    if b == "":
        return [a]
    return [a] + parse_kids(b, sep)

In [8]:
parse_kids(children, ", ")

['b', 'c', 'd']

In [9]:
parse_kids("", ", ")

['']

In [10]:
from more_itertools import with_iter

In [11]:
lines = [line.strip() for line in with_iter(open('day/7/input'))]

In [12]:
nodes = set()
children = set()
for l in lines:
    node, kinds = parse_line(l)
    name, _ = parse_node(node)
    nodes.add(name)
    if kinds:
        children |= set(parse_kids(kinds))

In [13]:
# The only node that's not listed as anyone's child is the root
nodes - children

{'dgoocsw'}

# Part 2

In [14]:
from collections import namedtuple

In [15]:
Node = namedtuple("Node", "weight, parent, children")

In [16]:
nodes = {"ROOT": Node(0, None, set())}
for line in lines:
    node, kids = parse_line(line)
    name, weight = parse_node(node)
    children = parse_kids(kids)
    if name not in nodes:
        nodes[name] = Node(eval(weight), "ROOT", set(children))
        nodes['ROOT'].children.add(name)
    else:
        nodes[name] = nodes[name]._replace(children=set(children), weight=eval(weight))
    for c in children:
        if c in nodes:
            old_parent = nodes[c].parent
            nodes[old_parent].children.discard(c)
            nodes[c] = nodes[c]._replace(parent=name)
        else:
            nodes[c] = Node(0, name, set())
        

In [17]:
from more_itertools import all_equal

def check_weights(node):
    if not node.children:
        return node.weight, node.weight
    else:
        cweights, totals = zip(*[check_weights(nodes[cname]) for cname in node.children])
        if not all_equal(totals):
            raise ValueError("Unbalanced disk! {}".format(", ".join(map(str, zip(totals, cweights)))))
        return node.weight, node.weight + sum(totals)

In [18]:
check_weights(nodes["ROOT"])

ValueError: Unbalanced disk! (1815, 15), (1815, 987), (1823, 1283), (1815, 184), (1815, 1284)

In [19]:
1283 - abs(1815 - 1823)

1275