# Day 21

## Read input

In [19]:
class Node:
    """A node in the monkey expression tree.

    A node is either a leaf that holds a number in ``value``, or an inner
    node whose result is ``left <operation> right``.
    """

    def __init__(self, name):
        self.name = name
        self.value = None      # number for leaf nodes; None for inner nodes
        self.left = None       # left operand Node (inner nodes only)
        self.right = None      # right operand Node (inner nodes only)
        self.operation = None  # operator symbol such as '+' (inner nodes only)

    def add_value(self, value):
        """Set this node's numeric value."""
        self.value = value

    def add_left(self, node):
        """Set the left operand node."""
        self.left = node

    def add_right(self, node):
        """Set the right operand node."""
        self.right = node

    def add_operation(self, operation):
        """Set the operator combining the left and right operands."""
        self.operation = operation

    def __repr__(self):
        # Compare against None explicitly: a leaf whose value is 0 must still
        # show its value rather than falling through to "[empty]".
        if self.value is not None:
            return f'<{self.name}: {self.value}>'
        if self.left is not None and self.right is not None:
            return f'<{self.name}: {self.left.name} {self.operation} {self.right.name}>'
        return f'<{self.name}: [empty]>'

In [96]:
from utils import read_input
import re

def transformer(line):
    """Parse one ``"name: job"`` input line.

    Returns ``(name, job)``, where *job* is an ``int`` when the right-hand
    side is a plain number. Otherwise *job* is the list ``[left, op, right]``
    produced by splitting on the operator (the capture group keeps the
    operator in the result).
    """
    monkey, job = line.split(': ')
    try:
        return monkey, int(job)
    except ValueError:
        return monkey, re.split(r' (\+|-|\/|\*) ', job)

data = read_input(21, transformer)
# Map each monkey name to its job: an int (a number) or the
# [left, op, right] list produced by ``transformer``.
temp = {name: job for name, job in data}

# Build Node objects bottom-up: a monkey can only be built once both of its
# operand monkeys already exist in ``tree``. Every pass over the still-unbuilt
# monkeys should resolve at least one of them. If a pass resolves none, some
# referenced name is missing or the references are cyclic, so fail loudly
# instead of looping forever.
tree = {}
unfound = dict(temp)
while unfound:
    progressed = False
    for name, op in list(unfound.items()):
        if isinstance(op, int):
            node = Node(name)
            node.value = op
        else:
            left, oper, right = op
            if left not in tree or right not in tree:
                continue  # operands not built yet; retry on the next pass
            node = Node(name)
            node.left = tree[left]
            node.right = tree[right]
            node.operation = oper
        tree[name] = node
        del unfound[name]
        progressed = True
    if not progressed:
        raise ValueError(
            f'cannot resolve monkeys (missing or cyclic references): {sorted(unfound)}'
        )

In [54]:
def calculate(node):
    if node.value:
        return node.value
    
    match node.operation:
        case '+':
            return calculate(node.left) + calculate(node.right)
        case '-':
            return calculate(node.left) - calculate(node.right)
        case '/':
            return calculate(node.left) // calculate(node.right)
        case '*':
            return calculate(node.left) * calculate(node.right)

## Part 1

In [56]:
result = calculate(tree['root'])
print(f'Part 1: {result}')
assert result == 31017034894002

Part 1: 31017034894002


## Part 2

In [99]:
def is_child(root):
    """Report whether the node named 'humn' occurs in the subtree at *root*.

    Leaves (nodes with a non-None ``value``) are not descended into.
    Traversal is depth-first, left branch before right.
    """
    stack = [root]
    while stack:
        current = stack.pop()
        if current.name == 'humn':
            return True
        if current.value is None:
            # Push right first so the left branch is explored first.
            stack.append(current.right)
            stack.append(current.left)
    return False
    

def correct(node, target):
    if is_child(node.left):
        value = calculate(node.right)
        match node.operation:
            case '+':
                new_target = abs(target - value)
            case '-':
                new_target = target + value
            case '/':
                new_target = target * value
            case '*':
                new_target = target // value
        new_path = node.left
        if node.left.name == 'humn':
            return new_target
        else:
            return correct(new_path, new_target)
    else:
        value = calculate(node.left)
        match node.operation:
            case '+':
                new_target = abs(value - target)
            case '-':
                new_target = value - target
            case '/':
                new_target = value // target
            case '*':
                new_target = target // value
        new_path = node.right
        if node.right.name == 'humn':
            return new_target
        else:
            return correct(new_path, new_target)

def solve2(root):
    """Return the value 'humn' needs so that both children of *root* are equal.

    The root's own operation is ignored. The child without 'humn' is
    evaluated, and that value becomes the target for the child that
    contains 'humn'.
    """
    humn_side, other_side = (
        (root.left, root.right) if is_child(root.left) else (root.right, root.left)
    )
    return correct(humn_side, calculate(other_side))

In [100]:
part2 = solve2(tree['root'])
print(f'Part 2: {part2}')
assert part2 == 3555057453229

Part 2: 3555057453229
