In [15]:
import sys
import numpy as np
import math
import operator
import functools
import json
from pprint import pprint

In [317]:
import copy
class Node:
    """Binary-tree representation of a snailfish number.

    A regular number is a leaf: ``value`` is an ``int`` and ``left``/``right``
    are ``None``. A pair is an inner node: ``value`` is ``None`` and
    ``left``/``right`` hold the two child nodes. ``level`` is the node's depth,
    with the root at level 0.
    """

    def __init__(self, value, level, left=None, right=None):
        self.value = value  # int for a regular number, None for a pair
        self.level = level  # depth in the tree; the root is 0
        self.left = left
        self.right = right

    @staticmethod
    def createFromList(lst, level=0, addLeft=0, addRight=0):
        """Build a tree from a nested two-element list such as ``[[1, 2], 3]``.

        Args:
            lst: a two-element list whose items are ints or nested lists.
            level: depth assigned to the returned pair node.
            addLeft, addRight: unused; kept only so existing callers still work.

        Returns:
            The root ``Node`` (a pair) of the new tree.
        """
        def build_child(item):
            if type(item) == int:
                return Node(item, level + 1)
            return Node.createFromList(item, level + 1)

        return Node(None, level, build_child(lst[0]), build_child(lst[1]))

    def toList(self):
        """Return the tree as nested lists (a plain int for a regular number)."""
        if type(self.value) == int:
            return self.value
        return [self.left.toList(), self.right.toList()]

    def addRight(self, value):
        """Add ``value`` to the rightmost regular number in this subtree."""
        if self.right:
            self.right.addRight(value)
        else:
            self.value += value

    def addLeft(self, value):
        """Add ``value`` to the leftmost regular number in this subtree."""
        if self.left:
            self.left.addLeft(value)
        else:
            self.value += value

    def addLevel(self):
        """Increase the level of every node in this subtree by one.

        ``__add__`` uses this after placing both operands under a new root.
        """
        self.level += 1
        if type(self.value) == int:
            return
        self.left.addLevel()
        self.right.addLevel()

    def explode(self):
        """Explode the leftmost pair at level >= 4, if any; mutates in place.

        The exploding pair is replaced by the regular number 0. Its left value
        is added to the nearest regular number to its left, and its right
        value to the nearest regular number to its right. A value that has no
        such neighbour inside the current subtree is passed up to the caller.

        Returns:
            ``None`` if nothing exploded. Otherwise a ``(carry_left,
            carry_right)`` tuple: each entry is a value still waiting to be
            added further left/right, or ``None`` once it has been placed.
        """
        if type(self.value) != int and self.level >= 4:
            # Assumes both children are regular numbers; a nested pair child
            # would contribute its value, which is None.
            carry = (self.left.value, self.right.value)
            self.value = 0
            self.left = None
            self.right = None
            return carry

        if self.left is not None:
            result = self.left.explode()
            if result is not None:
                carry_left, carry_right = result
                # The right-going value lands in the leftmost leaf of our right subtree.
                if carry_right is not None:
                    self.right.addLeft(carry_right)
                return (carry_left, None)

        if self.right is not None:
            result = self.right.explode()
            if result is not None:
                carry_left, carry_right = result
                # The left-going value lands in the rightmost leaf of our left subtree.
                if carry_left is not None:
                    self.left.addRight(carry_left)
                return (None, carry_right)

        return None

    def split(self):
        """Split the leftmost regular number >= 10 into a pair; mutates in place.

        The new pair holds the value halved rounded down and halved rounded up.

        Returns:
            True if a split happened, False otherwise.
        """
        if type(self.value) == int:
            if self.value < 10:
                return False
            # Exact integer halving; float division loses precision for very large ints.
            half_down = self.value // 2
            self.left = Node(half_down, self.level + 1)
            self.right = Node(self.value - half_down, self.level + 1)
            self.value = None
            return True
        return self.left.split() or self.right.split()

    def reduce(self):
        """Repeatedly explode (preferred) or split until neither applies; mutates in place."""
        while self.explode() is not None or self.split():
            pass

    @property
    def magnitude(self):
        """3 * magnitude(left) + 2 * magnitude(right); a regular number is its own value."""
        if type(self.value) == int:
            return self.value
        return 3 * self.left.magnitude + 2 * self.right.magnitude

    def __add__(self, other):
        """Return the reduced sum ``self + other`` as a new tree.

        Both operands are deep-copied, so neither is modified.
        """
        node = Node(None, -1, copy.deepcopy(self), copy.deepcopy(other))
        node.addLevel()  # new root goes from -1 to 0; every operand node moves one level deeper
        node.reduce()
        return node

    def __repr__(self):
        """Debug representation that shows each node's level."""
        if type(self.value) == int:
            return f"{self.value}[{self.level}]"
        return f"{self.level}[{self.left}, {self.right}]"
    
    
def magnitude(pair):
    """Return the magnitude of a snailfish number given as nested lists.

    A regular number (int) is its own magnitude. A pair's magnitude is three
    times the magnitude of its first element plus twice that of its second.
    """
    if type(pair) is not int:
        first, second = pair[0], pair[1]
        return 3 * magnitude(first) + 2 * magnitude(second)
    return pair


In [318]:
def part1(arr):
    """Add every snailfish number in ``arr`` in order and return the magnitude.

    Args:
        arr: lines of text, each a JSON-encoded snailfish number.

    Returns:
        The magnitude of the final, reduced sum.

    Raises:
        ValueError: if ``arr`` is empty, since there is nothing to sum.
    """
    numbers = [Node.createFromList(json.loads(line)) for line in arr]
    if not numbers:
        raise ValueError("part1 needs at least one snailfish number")
    # Node.__add__ returns a fresh reduced tree, so a left fold yields the running sum.
    total = functools.reduce(operator.add, numbers)
    return total.magnitude

In [324]:
def part2(arr):
    """Return the largest magnitude of any sum of two different numbers in ``arr``.

    Both ``a + b`` and ``b + a`` are tried for every pair of distinct
    positions. Returns 0 when ``arr`` has fewer than two numbers.
    """
    numbers = [Node.createFromList(json.loads(line)) for line in arr]
    # Brute force over every ordered pair of distinct positions; Node.__add__
    # deep-copies its operands, so the parsed numbers can be reused.
    candidate_magnitudes = (
        (lhs + rhs).magnitude
        for i, lhs in enumerate(numbers)
        for j, rhs in enumerate(numbers)
        if i != j
    )
    return max([0, *candidate_magnitudes])

In [326]:
if __name__ == "__main__":
    # Use a .txt path given as the last command-line argument; otherwise
    # fall back to input.txt in the working directory.
    if sys.argv[-1].endswith(".txt"):
        file = sys.argv[-1]
    else:
        file = "input.txt"
    with open(file, "r") as f:
        arr = f.read().splitlines()

    # Solve and print each part in order.
    for solve in (part1, part2):
        print(solve(arr))

4116
4638
