# ~~~~~ Day 18 ~~~~~ 

In [112]:
from functools import reduce
from itertools import permutations, product
from math import ceil, floor

# Raw puzzle input: one snailfish number per line (newlines still attached).
with open('d18.txt', mode='r') as f:
    content = list(f)


def parse(term):
    """Flatten a snailfish number string into a ``(vals, lvls)`` pair.

    ``vals`` lists the regular numbers in left-to-right order. ``lvls`` lists,
    for every comma, the bracket depth at which that comma occurs, so
    ``lvls[i]`` is the depth of the pair whose comma separates ``vals[i]``
    from ``vals[i + 1]``. A well-formed term therefore has
    ``len(lvls) == len(vals) - 1``.

    Consecutive decimal digits are read as one number. Values >= 10 (as in
    an unreduced term such as ``"[10,2]"``) are parsed whole instead of
    being split into one value per digit. Any other character (e.g. a
    space) is ignored.

    >>> parse("[[1,2],3]")
    ([1, 2, 3], [2, 1])
    """
    vals = []
    lvls = []
    lvl = 0
    digits = ''
    for c in term:
        if c.isdecimal():
            # Keep accumulating: a number may span several characters.
            digits += c
            continue
        if digits:
            # Any non-digit terminates the number currently being read.
            vals.append(int(digits))
            digits = ''
        if c == '[':
            lvl += 1
        elif c == ']':
            lvl -= 1
        elif c == ',':
            lvls.append(lvl)
    if digits:
        # A bare number (no brackets) ends without a terminator.
        vals.append(int(digits))
    return vals, lvls


def assemble(term):
    """Render a flattened ``(vals, lvls)`` term back into bracket notation.

    Walking the commas left to right, a rise in depth opens brackets before
    the value that precedes the comma, and a drop in depth closes brackets
    after it. Brackets still open after the last value are closed at the end.
    """
    vals, lvls = term
    pieces = []
    depth = 0
    for pos, comma_depth in enumerate(lvls):
        token = str(vals[pos])
        if comma_depth > depth:
            pieces.append('[' * (comma_depth - depth))
            pieces.append(token)
        else:
            pieces.append(token)
            pieces.append(']' * (depth - comma_depth))
        pieces.append(',')
        depth = comma_depth
    pieces.append(str(vals[-1]))
    pieces.append(']' * depth)
    return ''.join(pieces)


def explode(term):
    """Explode the leftmost pair whose comma sits at depth 5, if there is one.

    Returns ``(new_term, True)`` after one explosion, or ``(term, False)``
    (the very same object) when no comma is at depth 5. Only depth exactly 5
    is looked for. NOTE(review): this assumes no term ever nests deeper than
    that.
    """
    vals, lvls = term
    if 5 not in lvls:
        return term, False

    pos = lvls.index(5)
    merged = vals[:]
    # The pair's left value flows into the regular number on its left and
    # its right value into the one on its right, when those neighbours exist.
    if pos:
        merged[pos - 1] += vals[pos]
    if pos + 2 < len(merged):
        merged[pos + 2] += vals[pos + 1]
    # The pair collapses into a single 0, and its comma disappears.
    merged[pos:pos + 2] = [0]
    remaining = lvls[:pos] + lvls[pos + 1:]
    return (merged, remaining), True


def split(term):
    """Split the leftmost regular number that is 10 or more, if any.

    The number becomes a new pair ``[floor(n/2), ceil(n/2)]``. The new pair
    sits one level deeper than the number's parent pair, which is the deeper
    of the two commas adjacent to it. Returns ``(new_term, True)`` after one
    split, or ``(term, False)`` (the same object) if nothing needed splitting.
    """
    vals, lvls = term
    pos = next((i for i, v in enumerate(vals) if v >= 10), None)
    if pos is None:
        return term, False

    big = vals[pos]
    left_comma = lvls[pos - 1] if pos > 0 else 0
    right_comma = lvls[pos] if pos < len(lvls) else 0
    pair_depth = max(left_comma, right_comma) + 1
    halves = [floor(big / 2), ceil(big / 2)]
    new_vals = vals[:pos] + halves + vals[pos + 1:]
    new_lvls = lvls[:pos] + [pair_depth] + lvls[pos:]
    return (new_vals, new_lvls), True


def magnitude(term):
    """Return the magnitude of a flattened term: 3*left + 2*right, recursively.

    Values are pushed onto a stack in order. Each time the comma depth drops,
    the pairs closed by that drop are folded, so the top two stack entries
    are combined once per closed pair. Whatever is still open after the last
    value is folded at the end.
    """
    vals, lvls = term
    stack = []

    def fold(times):
        # Combine the top two entries ``times`` times (no-op if times <= 0).
        for _ in range(times):
            right = stack.pop()
            left = stack.pop()
            stack.append(3 * left + 2 * right)

    depth = 0
    for val, comma_depth in zip(vals, lvls):
        stack.append(val)
        fold(depth - comma_depth)
        depth = comma_depth
    stack.append(vals[-1])
    fold(depth)
    return stack[0]


def add(left, right):
    """Snailfish-add two flattened terms and fully reduce the result.

    Both operands are nested one level deeper under a new root pair, whose
    comma sits at depth 1. Reduction then explodes whenever possible and
    splits only when nothing can explode, one action at a time.
    """
    left_vals, left_lvls = left
    right_vals, right_lvls = right
    combined = (
        left_vals + right_vals,
        [d + 1 for d in left_lvls] + [1] + [d + 1 for d in right_lvls],
    )
    while True:
        combined, exploded = explode(combined)
        if exploded:
            continue
        combined, was_split = split(combined)
        if not was_split:
            return combined


# Parse each non-blank input line. A stray blank line would otherwise become
# the empty term ([], []) and corrupt any sum it takes part in.
terms = [parse(line) for line in map(str.strip, content) if line]

# Part 1: add every number in order, reducing after each addition.
print(f'Magnitude: {magnitude(reduce(add, terms))}')

# Part 2: addition is not commutative, so try every ordered pair of two
# different entries. Pairs are chosen by position rather than by value, so
# duplicate input lines still count as two distinct numbers.
best = max(magnitude(add(left, right)) for left, right in permutations(terms, 2))
print(f'Max magnitude: {best}')


Magnitude: 4057
Max magnitude: 4683
