# Day 18 - Snailfish

## Read input

In [101]:
from utils import read_input

ABORT = [False]
num_input = read_input(18, eval)

## Part 1

## Let's build up our snailfish calculus

We start by adding `+` operation via `add` function. It takes two lists and returns a list with parameters as its items.

In [2]:
# ADD MATH

def add(a, b):
    return [a, b]

def is_int(n):
    return type(n) == int

In [3]:
# SPLIT MATH
import math

def should_split(number):
    if is_int(number):
        return number >= 10
    
    left, right = number
    return should_split(left) or should_split(right)

def split(number):
    if ABORT[0]:
        return number
    if is_int(number):
        if number >= 10:
            ABORT[0] = True
            return num_split(number)
        else:
            return number
        
    left, right = number
    return [split(left), split(right)]

def num_split(natural_number):
    left = math.floor(natural_number / 2)
    right = math.ceil(natural_number / 2)
    return [left, right]

Explosions all around.

In [67]:
# EXPLODE MATH
from collections import namedtuple

E = namedtuple('E', ["number", "left", "right"])
Extras = namedtuple('EX', ['fromleft', "fromright"])

def is_regular_number_pair(number):
    match number:
        case E() as e:
            number = e.number
    if is_int(number):
        return False
    left, right = number
    return is_int(left) and is_int(right)
    
    
def replace_leftmost(tree, value):
    if is_int(tree[0]):
        tree[0] += value
    else:
        tree[0] = replace_leftmost(tree[0], value)
    return tree

def replace_rightmost(tree, value):
    if is_int(tree[1]):
        tree[1] += value
    else:
        tree[1] = replace_rightmost(tree[1], value)
    return tree

def explode(number, d):
    if is_int(number):
        return number
    
    if ABORT[0]:
        return E(number, left=None, right=None)
    
    if is_regular_number_pair(number) and d >= 4:
        ABORT[0] = True
        match number:
            case E(left=left, right=right) as e:
                return E(0, left, right)
            case _:
                return E(0, *number)
            
    # These are the explosion bits that need to carry
    new_left, new_right = None, None
    
    left = explode2(number[0], d+1)
    match left:

        # Something to our left exploded
        case E() as e:
            # We came from left so we replace our left child with the explosion's number
            number[0] = e.number
            # Do we have a number to carry?
            if e.right:
                if is_int(number[1]): # Is our right child a number?
                    number[1] += e.right
                else: # It's not so we need to find its leftmost value 
                    number[1] = replace_leftmost(number[1], e.right)
            
            new_left = e.left
        # Nothing exploded, let's just keep it like it is
        case _:
            number[0] = left
            
    right = explode2(number[1], d+1)
    match right:
        # Something to our right exploded
        case E() as e:
            # We came from right so we replace right child with explosion's number
            number[1] = e.number
            if e.left:
                if is_int(number[0]):
                    number[0] += e.left
                else:
                    number[0] = replace_rightmost(number[0], e.left)
            
            new_right = e.right
        case _:
            number[1] = right
    
    return E(number, left=new_left, right=new_right)
        
            
        

def depth(number):
    left, right = number
    nl = is_int(left)
    nr = is_int(right)
    if nl and nr:
        return 1
    
    if nl and not nr:
        return 1 + max(1, depth(right))
    if not nl and nr:
        return 1 + max(depth(left), 1)
    
    return 1+ max(depth(left), depth(right))

def should_explode(number):
    d = depth(number) - 1
    return d >= 4

Reducing

In [58]:
# REDUCER
    
def reduce(number):
    keep_reducing = True
    # If any nested in four pairs, leftmost pair explodes
    # If any number >= 10 leftmost number splits
    # Repeat until no such thing happens
    # Always go back to start after applying any operation
    while keep_reducing:
        if should_explode(number):
            number = explode(number, 0).number
            ABORT[0] = False
            continue
        if should_split(number):
            number = split(number)
            ABORT[0] = False
            continue
        keep_reducing = False

    return number

## Part 1

In [102]:
def main_loop(num_input):
    ABORT[0] = False
    number = num_input[0]
    for new_number in num_input[1:]:
        number = add(number, new_number)
        number = reduce(number)
    return number

final_sum = main_loop(num_input)

magnitude = consume(final_sum)
print(magnitude)

4173


> To check whether it's the right answer, the snailfish teacher only checks the magnitude of the final sum. The magnitude of a pair is 3 times the magnitude of its left element plus 2 times the magnitude of its right element. The magnitude of a regular number is just that number.

In [98]:
def consume(tree):
    if is_int(tree):
        return tree
    
    left, right = tree
    return consume(left) * 3 + consume(right) * 2

In [100]:
import unittest
import inspect


class ExampleTests(unittest.TestCase):


    def test_single_explodes(self):
        print("::TEST::", inspect.stack()[0][3])
        examples = [
            ([[[[[9,8],1],2],3],4], [[[[0,9],2],3],4]),
            ([7,[6,[5,[4,[3,2]]]]], [7,[6,[5,[7,0]]]]),
            ([[6,[5,[4,[3,2]]]],1], [[6,[5,[7,0]]],3]),
            ([[3,[2,[1,[7,3]]]],[6,[5,[4,[3,2]]]]], [[3,[2,[8,0]]],[9,[5,[4,[3,2]]]]]),
            ([[3,[2,[8,0]]],[9,[5,[4,[3,2]]]]], [[3,[2,[8,0]]],[9,[5,[7,0]]]])
        ]
        
        case = 1
        for inp, out in examples:
            case += 1
            ABORT[0] = False
            res = explode(inp, 0).number
            self.assertEqual(res, out)
     
    def test_first_full_example(self):
        print("::TEST::", inspect.stack()[0][3])
        numbers = [
            [[[[4,3],4],4],[7,[[8,4],9]]],
            [1,1]
        ]
        
        add_result = add(*numbers)
        self.assertEqual(add_result, [[[[[4,3],4],4],[7,[[8,4],9]]],[1,1]])
        
        ABORT[0] = False
        first_explode = explode(add_result, 0).number
        fe_out = [[[[0,7],4],[7,[[8,4],9]]],[1,1]]
        self.assertEqual(first_explode, fe_out)
        
        
        ABORT[0] = False
        second_explode = explode(first_explode, 0).number
        se_out = [[[[0,7],4],[15,[0,13]]],[1,1]]
        self.assertEqual(second_explode, se_out)
        
        self.assertFalse(should_explode(se_out))
        self.assertTrue(should_split(second_explode))
        
        ABORT[0] = False
        
        
        first_split = split(second_explode)
        fs_out = [[[[0,7],4],[[7,8],[0,13]]],[1,1]]
        
        self.assertEqual(first_split, fs_out)
        
        self.assertTrue(should_split(second_explode))
        
        ABORT[0] = False
        
        second_split = split(first_split)
        ss_out = [[[[0,7],4],[[7,8],[0,[6,7]]]],[1,1]]
        
        self.assertEqual(second_split, ss_out)
    

        res = main_loop(numbers)
        out = [[[[0,7],4],[[7,8],[6,0]]],[8,1]]
        
        self.assertEqual(res, out)
    
    def test_next_examples(self):
        print("::TEST::", inspect.stack()[0][3])
        
        examples = [
            ([[1,1], [2,2], [3,3], [4,4]], [[[[1,1], [2,2]],[3,3]],[4,4]]),
            ([[1,1], [2,2], [3,3], [4,4], [5,5]], [[[[3,0],[5,3]],[4,4]],[5,5]]),
            ([[1,1], [2,2], [3,3], [4,4], [5,5], [6,6]], [[[[5,0],[7,4]],[5,5]],[6,6]]),
        ]
        
        for t_case, (inp, out) in enumerate(examples):
            res = main_loop(inp)
            self.assertEqual(res, out)
            
    def test_slightly_larger_example(self):
        print("::TEST::", inspect.stack()[0][3])
        inp = [
            [[[0,[4,5]],[0,0]],[[[4,5],[2,6]],[9,5]]],
            [7,[[[3,7],[4,3]],[[6,3],[8,8]]]],
            [[2,[[0,8],[3,4]]],[[[6,7],1],[7,[1,6]]]],
            [[[[2,4],7],[6,[0,5]]],[[[6,8],[2,8]],[[2,1],[4,5]]]],
            [7,[5,[[3,8],[1,4]]]],
            [[2,[2,2]],[8,[8,1]]],
            [2,9],
            [1,[[[9,3],9],[[9,0],[0,7]]]],
            [[[5,[7,4]],7],1],
            [[[[4,2],2],6],[8,7]],
        ]
        
        out = [[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]
       
        res = main_loop(inp)
        self.assertEqual(res, out)
        
    def test_last_example(self):
        print("::TEST::", inspect.stack()[0][3])

        inp = [
            [[[0,[5,8]],[[1,7],[9,6]]],[[4,[1,2]],[[1,4],2]]],
            [[[5,[2,8]],4],[5,[[9,9],0]]],
            [6,[[[6,2],[5,6]],[[7,6],[4,7]]]],
            [[[6,[0,7]],[0,9]],[4,[9,[9,0]]]],
            [[[7,[6,4]],[3,[1,3]]],[[[5,5],1],9]],
            [[6,[[7,3],[3,2]]],[[[3,8],[5,7]],4]],
            [[[[5,4],[7,7]],8],[[8,3],8]],
            [[9,3],[[9,9],[6,[4,9]]]],
            [[2,[[7,7],7]],[[5,8],[[9,3],[0,2]]]],
            [[[[5,2],5],[8,[3,7]]],[[5,[7,5]],[4,4]]]
        ]
        
        out = [[[[6,6],[7,6]],[[7,7],[7,0]]],[[[7,7],[7,7]],[[7,8],[9,9]]]]
        
        res = main_loop(inp)
        self.assertEqual(res, out)

class MagnitudeTest(unittest.TestCase):
    
    def test_simple(self):
        print("::TEST::", inspect.stack()[0][3])
        out = consume([9,1])
        self.assertEqual(out, 29)
        
    def test_simple_2(self):
        print("::TEST::", inspect.stack()[0][3])
        
        out = consume([[9,1],[1,9]])
        self.assertEqual(out, 129)
        
    def test_few_more_examples(self):
        print("::TEST::", inspect.stack()[0][3])
        
        self.assertEqual(consume([[1,2],[[3,4],5]]), 143)
        self.assertEqual(consume([[[[0,7],4],[[7,8],[6,0]]],[8,1]]), 1384)
        self.assertEqual(consume([[[[1,1],[2,2]],[3,3]],[4,4]]), 445)
        self.assertEqual(consume([[[[3,0],[5,3]],[4,4]],[5,5]]), 791)
        self.assertEqual(consume([[[[5,0],[7,4]],[5,5]],[6,6]]), 1137)
        self.assertEqual(consume([[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]), 3488)
            
unittest.main(argv=[''], verbosity=0, exit=False)

::TEST:: test_first_full_example
::TEST:: test_last_example
::TEST:: test_next_examples
::TEST:: test_single_explodes
::TEST:: test_slightly_larger_example
::TEST:: test_few_more_examples
::TEST:: test_simple
::TEST:: test_simple_2


----------------------------------------------------------------------
Ran 8 tests in 0.248s

OK


<unittest.main.TestProgram at 0x114d1ff70>

## Part 2

>You notice a second question on the back of the homework assignment:
>
>What is the largest magnitude you can get from adding only two of the snailfish numbers?
>
>Note that snailfish addition is not commutative - that is, x + y and y + x can produce different results.
>
>Again considering the last example homework assignment above:
>
>```
[[[0,[5,8]],[[1,7],[9,6]]],[[4,[1,2]],[[1,4],2]]]
[[[5,[2,8]],4],[5,[[9,9],0]]]
[6,[[[6,2],[5,6]],[[7,6],[4,7]]]]
[[[6,[0,7]],[0,9]],[4,[9,[9,0]]]]
[[[7,[6,4]],[3,[1,3]]],[[[5,5],1],9]]
[[6,[[7,3],[3,2]]],[[[3,8],[5,7]],4]]
[[[[5,4],[7,7]],8],[[8,3],8]]
[[9,3],[[9,9],[6,[4,9]]]]
[[2,[[7,7],7]],[[5,8],[[9,3],[0,2]]]]
[[[[5,2],5],[8,[3,7]]],[[5,[7,5]],[4,4]]]
>```
>
>The largest magnitude of the sum of any two snailfish numbers in this list is 3993. This is the magnitude of `[[2,[[7,7],7]],[[5,8],[[9,3],[0,2]]]] + [[[0,[5,8]],[[1,7],[9,6]]],[[4,[1,2]],[[1,4],2]]]`, which reduces to `[[[[7,8],[6,6]],[[6,0],[7,7]]],[[[7,8],[8,8]],[[7,9],[0,6]]]]`.
>
>**What is the largest magnitude of any sum of two different snailfish numbers from the homework assignment?**



In [113]:
example_input = [
    [[[0,[5,8]],[[1,7],[9,6]]],[[4,[1,2]],[[1,4],2]]],
    [[[5,[2,8]],4],[5,[[9,9],0]]],
    [6,[[[6,2],[5,6]],[[7,6],[4,7]]]],
    [[[6,[0,7]],[0,9]],[4,[9,[9,0]]]],
    [[[7,[6,4]],[3,[1,3]]],[[[5,5],1],9]],
    [[6,[[7,3],[3,2]]],[[[3,8],[5,7]],4]],
    [[[[5,4],[7,7]],8],[[8,3],8]],
    [[9,3],[[9,9],[6,[4,9]]]],
    [[2,[[7,7],7]],[[5,8],[[9,3],[0,2]]]],
    [[[[5,2],5],[8,[3,7]]],[[5,[7,5]],[4,4]]],
    ]


largest_magnitude = 0
for pair in permutations(example_input, 2):
    final_sum = main_loop(pair)
    magnitude = consume(final_sum)
    
    if magnitude > largest_magnitude:
        largest_magnitude = magnitude
        
print(largest_magnitude)


KeyboardInterrupt: 

In [105]:
from itertools import permutations

largest_magnitude = 0
for pair in permutations(num_input, 2):
    final_sum = main_loop(pair)
    magnitude = consume(final_sum)
    if magnitude > largest_magnitude:
        print(magnitude)
        largest_magnitude = magnitude
        
print(largest_magnitude)

2971
4037
4096
4137
4311
4366


KeyboardInterrupt: 