In [1]:
from typing import Union, Optional
import math


class SnailValue:
    """Leaf of a snailfish-number tree: a single regular (integer) number.

    Attributes:
        value: the integer stored in this leaf.
        parent: the pair node that holds this leaf as its ``left`` or ``right``
            child, or ``None`` when the leaf is detached.
    """

    def __init__(self, value: int, parent=None):
        self.value = value
        self.parent = parent

    def split(self):
        """Replace this leaf, inside its parent, by the pair ``[floor(v/2), ceil(v/2)]``.

        The new pair is a ``SnailMathPair`` whose two children are fresh leaves of
        this leaf's class. The parent's matching child slot is rewired to the new
        pair; this leaf itself is left untouched.

        Raises:
            AssertionError: if this leaf has no parent to be replaced in.
        """
        assert self.parent is not None
        # Pure integer arithmetic: float division would lose precision for very
        # large values and make the two halves no longer sum to the original.
        lower_half = self.value // 2
        upper_half = self.value - lower_half
        replacement_pair = SnailMathPair(parent=self.parent)
        replacement_pair.left = self.__class__(lower_half, parent=replacement_pair)
        replacement_pair.right = self.__class__(upper_half, parent=replacement_pair)
        # Identity, not equality, decides which slot of the parent we occupy.
        if self.parent.left is self:
            self.parent.left = replacement_pair
        else:
            self.parent.right = replacement_pair

    def magnitude(self):
        """Return the magnitude of a regular number, which is the number itself."""
        return self.value

    def copy(self, parent=None):
        """Return a new leaf with the same value, attached to ``parent``."""
        return self.__class__(self.value, parent=parent)

    def __repr__(self):
        return str(self.value)


class SnailMathPair:
    """Internal node of a snailfish-number tree: an ordered pair ``[left, right]``.

    Each child is either another ``SnailMathPair`` or a ``SnailValue`` leaf.
    Every node keeps a ``parent`` back-reference (``None`` at the root), which
    ``explode`` uses to locate the neighbouring regular numbers.
    """

    # A simple pair whose depth (root pair = 1) exceeds this value explodes.
    EXPLODE_DEPTH = 4
    # A regular number at or above this value splits.
    SPLIT_THRESHOLD = 10

    def __init__(self,
                 left = None,
                 right = None,
                 parent = None):
        self.parent = parent
        self.left = left
        self.right = right

    @classmethod
    def parse(cls, expr: str, parent = None):
        """Build a tree from text such as ``'[[1,2],3]'``.

        Surrounding whitespace is ignored. The outer brackets are optional
        (``'1,2'`` parses like ``'[1,2]'``).

        Args:
            expr: textual form of one pair.
            parent: node to record as the parent of the returned pair.

        Returns:
            The root ``SnailMathPair`` of the parsed tree.

        Raises:
            ValueError: if an opening bracket is not closed, if there is no
                top-level comma, or if a leaf is not an integer literal.
        """
        expr = expr.strip()
        if expr.startswith('['):
            if not expr.endswith(']'):
                raise ValueError(f'unbalanced brackets in snailfish number: {expr!r}')
            expr = expr[1:-1]

        # Find the comma that separates the two children of *this* pair, i.e.
        # the first comma that is not nested inside any bracket.
        bracket_count = 0
        comma_idx = None
        for i, c in enumerate(expr):
            if c == ',' and bracket_count == 0:
                comma_idx = i
                break
            elif c == '[':
                bracket_count += 1
            elif c == ']':
                bracket_count -= 1
        if comma_idx is None:
            raise ValueError(f'no top-level comma in snailfish pair: {expr!r}')

        left, right = expr[:comma_idx].strip(), expr[comma_idx + 1:].strip()

        me = cls(parent=parent)
        me.left = cls.parse(left, me) if left.startswith('[') else SnailValue(int(left), me)
        me.right = cls.parse(right, me) if right.startswith('[') else SnailValue(int(right), me)
        return me

    def _find_left(self):
        """Return the nearest ``SnailValue`` to the left of this pair, or ``None``."""
        child, parent = self, self.parent
        # Climb until we arrive from a right-hand branch; the left sibling's
        # rightmost leaf is then the closest regular number on the left.
        while parent is not None:
            if parent.left is not child:
                node = parent.left
                while not isinstance(node, SnailValue):
                    node = node.right
                return node
            child, parent = parent, parent.parent
        return None

    def _find_right(self):
        """Return the nearest ``SnailValue`` to the right of this pair, or ``None``."""
        child, parent = self, self.parent
        # Mirror image of _find_left: climb until we arrive from a left-hand
        # branch, then take the right sibling's leftmost leaf.
        while parent is not None:
            if parent.right is not child:
                node = parent.right
                while not isinstance(node, SnailValue):
                    node = node.left
                return node
            child, parent = parent, parent.parent
        return None

    def explode(self):
        """Explode this simple pair in place.

        The left value is added to the nearest regular number on the left (if
        any), the right value to the nearest on the right (if any), and this
        pair is replaced in its parent by a ``SnailValue`` of 0.

        Raises:
            AssertionError: if this pair is the root or is not a simple pair.
        """
        assert self.parent is not None and isinstance(self.left, SnailValue) and isinstance(self.right, SnailValue)
        # Look both neighbours up before modifying anything.
        left = self._find_left()
        right = self._find_right()
        if left is not None:
            left.value = self.left.value + left.value
        if right is not None:
            right.value = self.right.value + right.value

        # Replace this pair with the regular number 0.
        replacement_val = SnailValue(0, self.parent)
        if self.parent.left is self:
            self.parent.left = replacement_val
        else:
            self.parent.right = replacement_val

    def in_order_trav(self, depth = 1):
        """Yield ``(node, depth)`` for every node below and including this pair.

        Nodes come in left-to-right (in-order) sequence. A pair is reported with
        its own depth (``depth`` for this pair, one more per nesting level); a
        leaf is reported with the depth of the pair that directly holds it.
        """
        if isinstance(self.left, self.__class__):
            yield from self.left.in_order_trav(depth + 1)
        else:
            yield self.left, depth
        yield self, depth
        if isinstance(self.right, self.__class__):
            yield from self.right.in_order_trav(depth + 1)
        else:
            yield self.right, depth

    def is_simple_pair(self):
        """Return True when both children are regular numbers."""
        return isinstance(self.left, SnailValue) and isinstance(self.right, SnailValue)

    def _first_explodable(self):
        """Return the leftmost simple pair nested deeper than EXPLODE_DEPTH, or None."""
        for node, node_depth in self.in_order_trav():
            if (isinstance(node, self.__class__)
                    and node_depth > self.EXPLODE_DEPTH
                    and node.is_simple_pair()):
                return node
        return None

    def reduce_one(self, depth: int = 1):
        """Perform one round of reduction, treating this pair as the root.

        A round is either *every* pending explosion (explosions always take
        priority, so they are chained until none is left) or, when nothing
        exploded, a single split of the leftmost number >= SPLIT_THRESHOLD.

        Args:
            depth: unused; accepted only so existing callers keep working.

        Returns:
            True if the tree was changed, False if it is already reduced.
        """
        exploded_any = False
        # Each explosion mutates the tree, so the search restarts from scratch.
        target = self._first_explodable()
        while target is not None:
            target.explode()
            exploded_any = True
            target = self._first_explodable()

        if exploded_any:
            return True

        # Only one split per round: a split may create a new explodable pair.
        for node, _ in self.in_order_trav():
            if isinstance(node, SnailValue) and node.value >= self.SPLIT_THRESHOLD:
                node.split()
                return True
        return False

    def reduce(self):
        """Reduce this tree in place until no explosion or split applies."""
        while self.reduce_one():
            pass

    def __add__(self, other):
        """Return a new, unreduced pair ``[self, other]``.

        Both operands are deep-copied into the result, so neither their
        structure nor their ``parent`` links are touched and they can be
        reused in further additions.
        """
        new_pair = self.__class__()
        new_pair.left = self.copy(new_pair)
        new_pair.right = other.copy(new_pair)
        return new_pair

    def copy(self, parent=None):
        """Return a deep copy of this subtree, attached to ``parent``."""
        new_me = self.__class__(parent=parent)
        new_me.left = self.left.copy(new_me)
        new_me.right = self.right.copy(new_me)
        return new_me

    def magnitude(self):
        """Return 3 * magnitude(left) + 2 * magnitude(right)."""
        return 3 * self.left.magnitude() + 2 * self.right.magnitude()

    def __repr__(self):
        return f'[{self.left},{self.right}]'

In [2]:
# Parsing smoke test: the repr of the parsed tree should reproduce the input text.
expr = SnailMathPair.parse('[[6,[5,[4,[3,2]]]],1]')
expr

[[6,[5,[4,[3,2]]]],1]

In [3]:
# Call explode() directly on the innermost pair [9,8] (four .left hops from the
# root) and print the tree before and after to check the in-place rewrite.
expr = SnailMathPair.parse('[[[[[9,8],1],2],3],4]')
print('before', expr)
print('left', expr.left.left.left.left)
expr.left.left.left.left.explode()
print('after', expr)

before [[[[[9,8],1],2],3],4]
left [9,8]
after [[[[0,9],2],3],4]


In [4]:
# Same input as above, but let reduce() locate and perform the explosion itself.
expr = SnailMathPair.parse('[[[[[9,8],1],2],3],4]')
expr.reduce()
expr

[[[[0,9],2],3],4]

In [5]:
# Addition only wraps the two operands in a new pair; reduce() must be called
# afterwards. Print the raw sum, then reduce it in place and display the result.
expr = SnailMathPair.parse('[[[[4,3],4],4],[7,[[8,4],9]]]') + SnailMathPair.parse('[1,1]')
print('after addition: ', expr)
expr.reduce()
expr

after addition:  [[[[[4,3],4],4],[7,[[8,4],9]]],[1,1]]


[[[[0,7],4],[[7,8],[6,0]]],[8,1]]

### Explosion examples

In [6]:
# [[[[[9,8],1],2],3],4] becomes [[[[0,9],2],3],4]
# (there is no regular number to the left of [9,8], so only the 8 is carried over, into the 1).
e = SnailMathPair.parse('[[[[[9,8],1],2],3],4]')
e.reduce()
e

[[[[0,9],2],3],4]

In [7]:
# [7,[6,[5,[4,[3,2]]]]] becomes [7,[6,[5,[7,0]]]]
# (there is no regular number to the right of [3,2], so only the 3 is carried over, into the 4).
e = SnailMathPair.parse('[7,[6,[5,[4,[3,2]]]]]')
e.reduce()
e

[7,[6,[5,[7,0]]]]

In [8]:
# [[6,[5,[4,[3,2]]]],1] becomes [[6,[5,[7,0]]],3]
# (the 3 is added to the 4 on the left, the 2 to the 1 on the right).
e = SnailMathPair.parse('[[6,[5,[4,[3,2]]]],1]')
e.reduce()
e

[[6,[5,[7,0]]],3]

In [9]:
# Puzzle text: [[3,[2,[1,[7,3]]]],[6,[5,[4,[3,2]]]]] becomes [[3,[2,[8,0]]],[9,[5,[4,[3,2]]]]]
# after ONE explosion (the pair [7,3] is further to the left; [3,2] would explode on the next action).
# NOTE: reduce_one() keeps exploding until no explodable pair is left, so the value
# displayed below already includes that second explosion of [3,2].
e = SnailMathPair.parse('[[3,[2,[1,[7,3]]]],[6,[5,[4,[3,2]]]]]')
e.reduce_one()
e

[[3,[2,[8,0]]],[9,[5,[7,0]]]]

In [10]:
# [[3,[2,[8,0]]],[9,[5,[4,[3,2]]]]] becomes [[3,[2,[8,0]]],[9,[5,[7,0]]]]
# (the remaining deep pair [3,2] explodes; the 3 is added to the 4 on its left).
e = SnailMathPair.parse('[[3,[2,[8,0]]],[9,[5,[4,[3,2]]]]]')
e.reduce()
e

[[3,[2,[8,0]]],[9,[5,[7,0]]]]

### Example

In [11]:
def add_parsed_list(some_str):
    """Parse one snailfish number per line and return their reduced running sum.

    Blank lines and surrounding whitespace (including ``\\r`` from Windows line
    endings) are ignored. Progress is printed as a side effect: first the raw,
    unreduced sum of the first two numbers (when there are at least two), then
    one ``+ <reduced sum>`` line per addition.

    Args:
        some_str: text holding one snailfish number per line.

    Returns:
        The reduced sum of all numbers; for a single number, that number as
        parsed.

    Raises:
        ValueError: if ``some_str`` contains no snailfish number at all.
    """
    expressions = [
        SnailMathPair.parse(line.strip())
        for line in some_str.splitlines()
        if line.strip()
    ]
    if not expressions:
        raise ValueError('add_parsed_list needs at least one snailfish number')

    final_expr = expressions[0]
    if len(expressions) > 1:
        # Preview of the first unreduced sum, rendered textually so that no
        # throwaway pair is built from (and wired into) the parsed operands.
        print(f'[{expressions[0]},{expressions[1]}]')
    for expr in expressions[1:]:
        final_expr = final_expr + expr
        final_expr.reduce()
        print('+', final_expr)
    return final_expr

In [12]:
# Four small numbers: nothing nests deep enough to explode, so each sum is just the nested pair.
add_parsed_list('''
[1,1]
[2,2]
[3,3]
[4,4]
''')

[[1,1],[2,2]]
+ [[1,1],[2,2]]
+ [[[1,1],[2,2]],[3,3]]
+ [[[[1,1],[2,2]],[3,3]],[4,4]]


[[[[1,1],[2,2]],[3,3]],[4,4]]

In [13]:
# Adding a fifth number pushes the innermost pairs past the explode depth, so reduction kicks in.
add_parsed_list('''
[1,1]
[2,2]
[3,3]
[4,4]
[5,5]
''')

[[1,1],[2,2]]
+ [[1,1],[2,2]]
+ [[[1,1],[2,2]],[3,3]]
+ [[[[1,1],[2,2]],[3,3]],[4,4]]
+ [[[[3,0],[5,3]],[4,4]],[5,5]]


[[[[3,0],[5,3]],[4,4]],[5,5]]

In [14]:
# Same list extended with [6,6]; expected final sum is [[[[5,0],[7,4]],[5,5]],[6,6]].
add_parsed_list('''
[1,1]
[2,2]
[3,3]
[4,4]
[5,5]
[6,6]
''')

[[1,1],[2,2]]
+ [[1,1],[2,2]]
+ [[[1,1],[2,2]],[3,3]]
+ [[[[1,1],[2,2]],[3,3]],[4,4]]
+ [[[[3,0],[5,3]],[4,4]],[5,5]]
+ [[[[5,0],[7,4]],[5,5]],[6,6]]


[[[[5,0],[7,4]],[5,5]],[6,6]]

In [15]:
# First step of the larger worked example: a single addition that needs both explosions and splits.
add_parsed_list('''
[[[0,[4,5]],[0,0]],[[[4,5],[2,6]],[9,5]]]
[7,[[[3,7],[4,3]],[[6,3],[8,8]]]]
''')

[[[[0,[4,5]],[0,0]],[[[4,5],[2,6]],[9,5]]],[7,[[[3,7],[4,3]],[[6,3],[8,8]]]]]
+ [[[[4,0],[5,4]],[[7,7],[6,0]]],[[8,[7,7]],[[7,9],[5,0]]]]


[[[[4,0],[5,4]],[[7,7],[6,0]]],[[8,[7,7]],[[7,9],[5,0]]]]

In [16]:
# Full larger worked example: ten numbers summed left to right, reducing after every addition.
add_parsed_list('''
[[[0,[4,5]],[0,0]],[[[4,5],[2,6]],[9,5]]]
[7,[[[3,7],[4,3]],[[6,3],[8,8]]]]
[[2,[[0,8],[3,4]]],[[[6,7],1],[7,[1,6]]]]
[[[[2,4],7],[6,[0,5]]],[[[6,8],[2,8]],[[2,1],[4,5]]]]
[7,[5,[[3,8],[1,4]]]]
[[2,[2,2]],[8,[8,1]]]
[2,9]
[1,[[[9,3],9],[[9,0],[0,7]]]]
[[[5,[7,4]],7],1]
[[[[4,2],2],6],[8,7]]
''')

[[[[0,[4,5]],[0,0]],[[[4,5],[2,6]],[9,5]]],[7,[[[3,7],[4,3]],[[6,3],[8,8]]]]]
+ [[[[4,0],[5,4]],[[7,7],[6,0]]],[[8,[7,7]],[[7,9],[5,0]]]]
+ [[[[6,7],[6,7]],[[7,7],[0,7]]],[[[8,7],[7,7]],[[8,8],[8,0]]]]
+ [[[[7,0],[7,7]],[[7,7],[7,8]]],[[[7,7],[8,8]],[[7,7],[8,7]]]]
+ [[[[7,7],[7,8]],[[9,5],[8,7]]],[[[6,8],[0,8]],[[9,9],[9,0]]]]
+ [[[[6,6],[6,6]],[[6,0],[6,7]]],[[[7,7],[8,9]],[8,[8,1]]]]
+ [[[[6,6],[7,7]],[[0,7],[7,7]]],[[[5,5],[5,6]],9]]
+ [[[[7,8],[6,7]],[[6,8],[0,8]]],[[[7,7],[5,0]],[[5,5],[5,6]]]]
+ [[[[7,7],[7,7]],[[8,7],[8,7]]],[[[7,0],[7,7]],9]]
+ [[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]


[[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]

### Part 1

In [17]:
# Part 1: read the puzzle input from ./day18.txt (path is relative to the notebook's
# working directory), sum all numbers in order, and report the magnitude of the total.
# NOTE(review): presumably the file holds one snailfish number per line — confirm.
with open('./day18.txt') as f:
    part1 = f.read()

answer = add_parsed_list(part1)
print('=', answer)
print('mag:', answer.magnitude())

[[[[6,[8,3]],[2,0]],[[[9,5],[9,1]],3]],[[[9,[2,2]],[5,4]],[[[2,2],[9,6]],[7,7]]]]
+ [[[[6,0],[7,7]],[[8,7],9]],[[[6,6],[6,0]],[[6,6],[6,7]]]]
+ [[[[6,6],[8,7]],[[8,8],[8,0]]],[[[8,7],8],[[2,5],[0,5]]]]
+ [[[[6,7],[7,8]],[[7,0],[8,7]]],[[[7,8],[5,6]],[[0,5],[3,1]]]]
+ [[[[6,6],[6,0]],[[7,7],[8,8]]],[[[6,7],4],1]]
+ [[[[7,7],[8,0]],[[8,9],[8,9]]],[[[5,0],[7,7]],[[9,1],9]]]
+ [[[[7,7],[7,7]],[[0,7],[7,8]]],[[[8,8],[8,8]],[[8,8],[7,7]]]]
+ [[[[7,6],[7,7]],[[7,7],[0,7]]],[[[7,7],[8,8]],[[8,9],[9,8]]]]
+ [[[[6,6],[6,6]],[[7,7],[7,8]]],[[[7,0],[8,7]],[[7,7],[7,7]]]]
+ [[[[6,7],[7,7]],[[7,7],[8,8]]],[[[8,8],[8,8]],[[8,0],[8,8]]]]
+ [[[[6,6],[0,7]],[[7,7],[7,7]]],[[[7,9],[8,7]],[5,[0,9]]]]
+ [[[[6,6],[6,6]],[[6,6],[7,7]]],[[[7,0],[7,7]],[[7,8],[8,7]]]]
+ [[[[7,7],[7,7]],[[7,7],[7,0]]],[[[8,9],[5,6]],[[8,7],[0,8]]]]
+ [[[[7,7],[7,8]],[[8,7],[7,7]]],[[[0,8],[8,8]],[[8,8],[8,8]]]]
+ [[[[6,6],[7,7]],[[7,7],[8,8]]],[[[8,8],[7,7]],[[8,0],[9,9]]]]
+ [[[[7,7],[7,0]],[[7,7],[7,7]]],[[[7,7],[7,7]],8]]
+ 

### Part 2
(which is somehow easier than Part 1)

In [18]:
def parse_and_find_max_magnitude_permutation(lines):
    """Return the largest magnitude obtainable by adding two distinct numbers.

    Every non-empty line of ``lines`` is parsed as a snailfish number. All
    ordered pairs of different entries are summed (order matters, so both
    ``a + b`` and ``b + a`` are tried) and reduced, and the greatest magnitude
    wins. The winning operands are printed after ``max:``.

    With fewer than two numbers nothing is tried: ``max: None None`` is printed
    and ``float('-inf')`` is returned.
    """
    numbers = []
    for text in lines.split('\n'):
        if text:
            numbers.append(SnailMathPair.parse(text))

    best_magnitude = float('-inf')
    best_operands = (None, None)
    for left_pos, left_operand in enumerate(numbers):
        for right_pos, right_operand in enumerate(numbers):
            if left_pos != right_pos:
                # Work on copies so the parsed originals stay pristine for later pairings.
                total = left_operand.copy() + right_operand.copy()
                total.reduce()
                candidate = total.magnitude()
                if candidate > best_magnitude:
                    best_magnitude = candidate
                    best_operands = (left_operand, right_operand)

    print('max:', *best_operands)
    return best_magnitude

In [19]:
# Worked example for Part 2: the expected largest magnitude of any two-number sum is 3993.
example = '''
[[[0,[5,8]],[[1,7],[9,6]]],[[4,[1,2]],[[1,4],2]]]
[[[5,[2,8]],4],[5,[[9,9],0]]]
[6,[[[6,2],[5,6]],[[7,6],[4,7]]]]
[[[6,[0,7]],[0,9]],[4,[9,[9,0]]]]
[[[7,[6,4]],[3,[1,3]]],[[[5,5],1],9]]
[[6,[[7,3],[3,2]]],[[[3,8],[5,7]],4]]
[[[[5,4],[7,7]],8],[[8,3],8]]
[[9,3],[[9,9],[6,[4,9]]]]
[[2,[[7,7],7]],[[5,8],[[9,3],[0,2]]]]
[[[[5,2],5],[8,[3,7]]],[[5,[7,5]],[4,4]]]
'''

parse_and_find_max_magnitude_permutation(example)

max: [[2,[[7,7],7]],[[5,8],[[9,3],[0,2]]]] [[[0,[5,8]],[[1,7],[9,6]]],[[4,[1,2]],[[1,4],2]]]


3993

In [20]:
# Part 2: re-read the same puzzle input and search all ordered pairs for the largest magnitude.
with open('./day18.txt') as f:
    part2 = f.read()
parse_and_find_max_magnitude_permutation(part2)

max: [[[3,7],[[9,8],8]],[[[8,4],7],[3,[1,7]]]] [[[2,[8,6]],[[9,8],2]],[[9,5],[1,[9,8]]]]


4680