# day 18

https://adventofcode.com/2021/day/18

In [None]:
import ast
import itertools
import logging
import logging.config
import os

import yaml

In [None]:
# Load the shared YAML logging configuration and apply it to the logging module.
with open('../logging.yaml') as config_file:
    logging_config = yaml.load(config_file, Loader=yaml.FullLoader)
logging.config.dictConfig(logging_config)

In [None]:
# Puzzle input path, relative to the current working directory.
FNAME = os.path.join('data', 'day18.txt')
LOGGER = logging.getLogger(name='day18')

## part 1

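### problem statement

Add up every snailfish number in the input from left to right. Reduce the running sum after each addition: explode first, then split. Report the magnitude of the final sum.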

#### loading data

In [None]:
# Worked example from the puzzle statement: one snailfish number per entry.
# Part 1 expects a final magnitude of 4140; part 2 expects 3993.
test_data = [
    '[[[0,[5,8]],[[1,7],[9,6]]],[[4,[1,2]],[[1,4],2]]]',
    '[[[5,[2,8]],4],[5,[[9,9],0]]]',
    '[6,[[[6,2],[5,6]],[[7,6],[4,7]]]]',
    '[[[6,[0,7]],[0,9]],[4,[9,[9,0]]]]',
    '[[[7,[6,4]],[3,[1,3]]],[[[5,5],1],9]]',
    '[[6,[[7,3],[3,2]]],[[[3,8],[5,7]],4]]',
    '[[[[5,4],[7,7]],8],[[8,3],8]]',
    '[[9,3],[[9,9],[6,[4,9]]]]',
    '[[2,[[7,7],7]],[[5,8],[[9,3],[0,2]]]]',
    '[[[[5,2],5],[8,[3,7]]],[[5,[7,5]],[4,4]]]',
]

In [None]:
def load_data(fname=FNAME):
    """Read the puzzle input: one snailfish number per line.

    Returns the whitespace-stripped lines of ``fname`` as a list of strings.
    Blank lines are skipped: an empty string would otherwise reach the
    snailfish-number parser and fail there.
    """
    with open(fname) as fp:
        stripped = (line.strip() for line in fp)
        return [line for line in stripped if line]

In [None]:
import networkx as nx

def parse_snailfish_number(sn):
    """Convert a nested-list snailfish number into a directed tree.

    The root node is named ``''``. Every child is named after its parent plus
    ``'l'`` or ``'r'``, so a node's name is its path from the root. Leaves
    carry their regular number in the ``'value'`` attribute, and each edge
    carries ``direction='left'`` or ``direction='right'``.
    """
    tree = nx.DiGraph()
    tree.add_node('')
    first, second = sn
    add_branches(tree, '', first, second)
    return tree

def add_branches(g, node_name, left, right):
    """Attach ``left`` and ``right`` to ``g`` as the children of ``node_name``.

    Integer children become leaves holding a ``'value'`` attribute. List
    children are expanded recursively. Both edges are added before either
    child is expanded.
    """
    children = (
        (node_name + 'l', 'left', left),
        (node_name + 'r', 'right', right),
    )
    for child_name, direction, _ in children:
        g.add_edge(node_name, child_name, direction=direction)

    for child_name, _, child in children:
        if isinstance(child, int):
            g.nodes[child_name]['value'] = child
        else:
            grandchild_left, grandchild_right = child
            add_branches(g, child_name, grandchild_left, grandchild_right)

In [None]:
# literal_eval parses the nested-list literal without executing arbitrary code
sn = ast.literal_eval(test_data[0])
sn

In [None]:
g = parse_snailfish_number(sn)
# draw the tree; each node is labelled with its 'l'/'r' path name
nx.draw_kamada_kawai(g, node_size=1000, with_labels=True)

In [None]:
def add_graphs(g_left, g_right):
    """Return a new tree representing the snailfish pair ``[g_left, g_right]``.

    Every node of ``g_left`` is renamed with an ``'l'`` prefix and every node
    of ``g_right`` with an ``'r'`` prefix. The two are composed under a fresh
    root ``''``. Relabelled copies are used, so neither input graph is mutated.
    """
    shifted_left = nx.relabel_nodes(g_left, {name: 'l' + name for name in g_left.nodes})
    shifted_right = nx.relabel_nodes(g_right, {name: 'r' + name for name in g_right.nodes})
    combined = nx.compose(shifted_left, shifted_right)
    combined.add_node('')
    for child, direction in (('l', 'left'), ('r', 'right')):
        combined.add_edge('', child, direction=direction)
    return combined

In [None]:
def graphs_match(g0, g1):
    """Return True when both the node views and the edge views of two graphs compare equal."""
    if g0.nodes != g1.nodes:
        return False
    return g0.edges == g1.edges

In [None]:
# [1,2] + [[3,4],5] must equal the single number [[1,2],[[3,4],5]]
g_left, g_right = parse_snailfish_number([1, 2]), parse_snailfish_number([[3, 4], 5])
g_answer = parse_snailfish_number([[1, 2], [[3, 4], 5]])
g = add_graphs(g_left, g_right)
assert graphs_match(g, g_answer)

In [None]:
def get_value_nodes(g):
    """Return the sorted names of all nodes that carry a non-None ``'value'``.

    Node names are 'l'/'r' paths from the root, so the lexicographic sort
    lists the leaves from left to right.
    """
    leaves = []
    for name in g:
        if g.nodes[name].get('value') is not None:
            leaves.append(name)
    return sorted(leaves)

In [None]:
sn = [[[[[9, 8], 1], 2], 3], 4]
g = parse_snailfish_number(sn)
# leaf names in left-to-right order
get_value_nodes(g)

In [None]:
def get_left_neighbor(node_name, g):
    """Return the leaf immediately left of leaf ``node_name``, or None if it is leftmost.

    Raises ValueError if ``node_name`` is not a value node of ``g``.
    """
    leaves = get_value_nodes(g)
    position = leaves.index(node_name)
    if position == 0:
        return None
    return leaves[position - 1]

def get_right_neighbor(node_name, g):
    """Return the leaf immediately right of leaf ``node_name``, or None if it is rightmost.

    Raises ValueError if ``node_name`` is not a value node of ``g``.
    """
    leaves = get_value_nodes(g)
    position = leaves.index(node_name)
    if position + 1 < len(leaves):
        return leaves[position + 1]
    return None

In [None]:
sn = [[[[[9, 8], 1], 2], 3], 4]
g = parse_snailfish_number(sn)
# the outermost leaves have no neighbour on their outer side
assert get_left_neighbor('lllll', g) is None
assert get_right_neighbor('r', g) is None
# interior leaves see their immediate in-order neighbours
assert get_left_neighbor('llllr', g) == 'lllll'
assert get_right_neighbor('llllr', g) == 'lllr'

In [None]:
def should_explode(node_name, g):
    """Return True when the node named ``node_name`` is at depth 5 or deeper.

    A node's name is its 'l'/'r' path from the root, so its length is its
    depth. A leaf at depth 5 belongs to a pair nested inside four pairs.
    ``g`` is not consulted.
    """
    depth = len(node_name)
    return depth > 4

def explode(left_node_name, g):
    """Explode the pair whose left element is the leaf ``left_node_name``.

    The pair's left value is added to the nearest leaf on its left, and its
    right value to the nearest leaf on its right, when those leaves exist.
    The pair's own node then becomes a leaf with value 0, and both former
    children are removed from ``g``.
    """
    pair_name = left_node_name[:-1]
    right_node_name = pair_name + 'r'
    assert right_node_name in g

    # look up both neighbours before the tree is modified
    left_neighbor = get_left_neighbor(left_node_name, g)
    right_neighbor = get_right_neighbor(right_node_name, g)

    for neighbor, source in ((left_neighbor, left_node_name),
                             (right_neighbor, right_node_name)):
        if neighbor is not None:
            g.nodes[neighbor]['value'] += g.nodes[source]['value']

    g.nodes[pair_name]['value'] = 0
    g.remove_nodes_from([left_node_name, right_node_name])

In [None]:
sn_before_after = [
    ([[[[[9,8],1],2],3],4], [[[[0,9],2],3],4]),
    ([7,[6,[5,[4,[3,2]]]]], [7,[6,[5,[7,0]]]]),
    ([[6,[5,[4,[3,2]]]],1], [[6,[5,[7,0]]],3]),
    ([[3,[2,[1,[7,3]]]],[6,[5,[4,[3,2]]]]], [[3,[2,[8,0]]],[9,[5,[4,[3,2]]]]]),
    ([[3,[2,[8,0]]],[9,[5,[4,[3,2]]]]], [[3,[2,[8,0]]],[9,[5,[7,0]]]])
]

for sn_before, sn_after in sn_before_after:
    g = parse_snailfish_number(sn_before)
    g_answer = parse_snailfish_number(sn_after)
    # Every example must contain something to explode. Previously, if no leaf
    # qualified, the comparison was skipped and the check passed vacuously.
    node_name = next((n for n in get_value_nodes(g) if should_explode(n, g)), None)
    assert node_name is not None, f'nothing to explode in {sn_before}'
    explode(node_name, g)
    assert graphs_match(g, g_answer)

In [None]:
import math

def should_split(node_name, g):
    """Return True when the leaf ``node_name`` holds a regular number of 10 or more."""
    node_value = g.nodes[node_name]['value']
    return node_value >= 10

def split(node_name, g):
    """Replace the leaf ``node_name`` with a pair made of its two halves.

    The left child gets the value divided by two, rounded down. The right
    child gets the value divided by two, rounded up. The node itself loses
    its ``'value'`` attribute and becomes an internal node.
    """
    node_val = g.nodes[node_name]['value']
    # Integer halving is exact for every int. math.floor/ceil(node_val / 2)
    # goes through a float and is wrong for values above 2**53.
    left_val, remainder = divmod(node_val, 2)
    right_val = left_val + remainder

    # add the halves as new leaves under this node
    g.add_node(node_name + 'l', value=left_val)
    g.add_node(node_name + 'r', value=right_val)
    g.add_edge(node_name, node_name + 'l', direction='left')
    g.add_edge(node_name, node_name + 'r', direction='right')

    # the node is no longer a leaf
    del g.nodes[node_name]['value']

In [None]:
g0 = parse_snailfish_number([[[[0,7],4],[15,[0,13]]],[1,1]])
g1 = parse_snailfish_number([[[[0,7],4],[[7,8],[0,13]]],[1,1]])
g2 = parse_snailfish_number([[[[0,7],4],[[7,8],[0,[6,7]]]],[1,1]])

# apply one split per pass and compare against each expected stage in turn
for g_expected in (g1, g2):
    for node_name in get_value_nodes(g0):
        if should_split(node_name, g0):
            split(node_name, g0)
            break
    assert graphs_match(g0, g_expected)

In [None]:
def _apply_first(g, predicate, action):
    """Apply ``action`` to the leftmost leaf for which ``predicate`` holds.

    Returns True if such a leaf was found and ``action`` was applied,
    otherwise False.
    """
    for node_name in get_value_nodes(g):
        if predicate(node_name, g):
            action(node_name, g)
            return True
    return False


def reduce(g):
    """Reduce the snailfish tree ``g`` in place.

    Each round explodes the leftmost leaf for which ``should_explode`` holds.
    Only when nothing can explode does it split the leftmost leaf for which
    ``should_split`` holds. Rounds repeat until neither action applies.
    """
    while True:
        if _apply_first(g, should_explode, explode):
            continue
        if _apply_first(g, should_split, split):
            continue
        return

In [None]:
g_answer = parse_snailfish_number([[[[0,7],4],[[7,8],[6,0]]],[8,1]])
g = parse_snailfish_number([[[[[4,3],4],4],[7,[[8,4],9]]],[1,1]])
reduce(g)
assert graphs_match(g, g_answer)

In [None]:
def get_magnitude(node_name, g):
    """Return the magnitude of the subtree rooted at ``node_name``.

    A leaf's magnitude is its ``'value'``. An internal node's magnitude is
    3 times its left child's magnitude plus 2 times its right child's.
    Children are identified by the ``direction`` attribute of their edge.
    """
    node_data = g.nodes[node_name]
    if 'value' in node_data:
        return node_data['value']

    weight_by_direction = {'left': 3, 'right': 2}
    total = 0
    for child_name, edge_data in g[node_name].items():
        weight = weight_by_direction.get(edge_data['direction'])
        if weight is not None:
            total += weight * get_magnitude(child_name, g)
    return total

In [None]:
sn_mag = [
    ([9,1], 29),
    ([1, 9], 21),
    ([[1,2],[[3,4],5]], 143),
    ([[[[0,7],4],[[7,8],[6,0]]],[8,1]], 1384),
    ([[[[1,1],[2,2]],[3,3]],[4,4]], 445),
    ([[[[3,0],[5,3]],[4,4]],[5,5]], 791),
    ([[[[5,0],[7,4]],[5,5]],[6,6]], 1137),
    ([[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]], 3488),
]

for sn, mag in sn_mag:
    g = parse_snailfish_number(sn)
    magnitude = get_magnitude(node_name='', g=g)
    assert magnitude == mag

In [None]:
def get_final_sum(graphs):
    """Add the trees in ``graphs`` left to right, reducing after every addition.

    When ``graphs`` has a single element, that tree is returned as-is.
    """
    total = graphs[0]
    for addend in graphs[1:]:
        total = add_graphs(total, addend)
        reduce(total)
    return total

In [None]:
input_sns = [[k, k] for k in range(1, 5)]
answer_sn = [[[[1,1],[2,2]],[3,3]],[4,4]]
graphs = [parse_snailfish_number(sn) for sn in input_sns]
g_sum = get_final_sum(graphs)
assert graphs_match(g_sum, parse_snailfish_number(answer_sn))

In [None]:
input_sns = [[k, k] for k in range(1, 6)]
answer_sn = [[[[3,0],[5,3]],[4,4]],[5,5]]
graphs = [parse_snailfish_number(sn) for sn in input_sns]
g_sum = get_final_sum(graphs)
assert graphs_match(g_sum, parse_snailfish_number(answer_sn))

In [None]:
input_sns = [[k, k] for k in range(1, 7)]
answer_sn = [[[[5,0],[7,4]],[5,5]],[6,6]]
graphs = [parse_snailfish_number(sn) for sn in input_sns]
g_sum = get_final_sum(graphs)
assert graphs_match(g_sum, parse_snailfish_number(answer_sn))

In [None]:
input_sns = [
    [[[0,[4,5]],[0,0]],[[[4,5],[2,6]],[9,5]]],
    [7,[[[3,7],[4,3]],[[6,3],[8,8]]]],
    [[2,[[0,8],[3,4]]],[[[6,7],1],[7,[1,6]]]],
    [[[[2,4],7],[6,[0,5]]],[[[6,8],[2,8]],[[2,1],[4,5]]]],
    [7,[5,[[3,8],[1,4]]]],
    [[2,[2,2]],[8,[8,1]]],
    [2,9],
    [1,[[[9,3],9],[[9,0],[0,7]]]],
    [[[5,[7,4]],7],1],
    [[[[4,2],2],6],[8,7]]
]
answer_sn = [[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]
graphs = [parse_snailfish_number(sn) for sn in input_sns]
g_sum = get_final_sum(graphs)
assert graphs_match(g_sum, parse_snailfish_number(answer_sn))

#### function def

In [None]:
def q_1(data):
    """Part 1: magnitude of the reduced sum of every snailfish number in ``data``.

    ``data`` is an iterable of strings, each a snailfish number written as a
    nested list literal such as ``'[[1,2],3]'``.
    """
    # literal_eval only accepts Python literals, so an input line cannot
    # execute code the way eval() could
    sns = [ast.literal_eval(line) for line in data]
    graphs = [parse_snailfish_number(sn) for sn in sns]
    g_sum = get_final_sum(graphs)
    return get_magnitude(node_name='', g=g_sum)

#### tests

In [None]:
def test_q_1():
    """Check part 1 against the worked example (expected magnitude 4140).

    LOGGER is raised to DEBUG for the duration of the check. It is set back
    to INFO afterwards, even if the assertion fails.
    """
    LOGGER.setLevel(logging.DEBUG)
    try:
        assert q_1(test_data) == 4140
    finally:
        # without try/finally a failing check would leave LOGGER stuck at DEBUG
        LOGGER.setLevel(logging.INFO)

In [None]:
# validate part 1 on the worked example before using the real input
test_q_1()

#### answer

In [None]:
# part 1 answer for the puzzle input
q_1(data=load_data())

## part 2

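### problem statement

Find the largest magnitude obtainable by adding any two different snailfish numbers from the input. Addition is not commutative, so both orders of every pair are tried.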

#### function def

In [None]:
# itertools.permutations yields every ordered pair of distinct elements, so both
# a + b and b + a are covered; snailfish addition is order-dependent.
example_pairs = list(itertools.permutations([1, 2, 3], 2))
example_pairs

In [None]:
import itertools

def q_2(data):
    """Part 2: largest magnitude from adding any two distinct snailfish numbers in ``data``.

    Both orders of every pair are tried via ``itertools.permutations``,
    because ``add_graphs`` places its first argument on the left branch, so
    ``a + b`` and ``b + a`` can differ. Raises ValueError if ``data`` has
    fewer than two lines.
    """
    # literal_eval only accepts Python literals, so an input line cannot
    # execute code the way eval() could
    sns = [ast.literal_eval(line) for line in data]
    graphs = [parse_snailfish_number(sn) for sn in sns]
    return max(
        get_magnitude('', get_final_sum([g_a, g_b]))
        for g_a, g_b in itertools.permutations(graphs, 2)
    )

#### tests

In [None]:
def test_q_2():
    """Check part 2 against the worked example (expected magnitude 3993).

    LOGGER is raised to DEBUG for the duration of the check. It is set back
    to INFO afterwards, even if the assertion fails.
    """
    LOGGER.setLevel(logging.DEBUG)
    try:
        assert q_2(test_data) == 3993
    finally:
        # without try/finally a failing check would leave LOGGER stuck at DEBUG
        LOGGER.setLevel(logging.INFO)

In [None]:
# validate part 2 on the worked example before using the real input
test_q_2()

#### answer

In [None]:
# part 2 answer for the puzzle input
q_2(data=load_data())

fin