In [1]:
import aocd
from collections import namedtuple, deque
from math import prod

In [2]:
# A decoded BITS packet: 3-bit version, 3-bit type id, and a payload that is
# an int for literal packets (type 4) or a list of child Packets otherwise.
Packet = namedtuple('Packet', 'version type payload')

In [3]:
def pop_bits(packet, nbits, to_int=True):
    """Remove the first `nbits` characters from the left of `packet`.

    Args:
        packet: deque of '0'/'1' characters; it is mutated (consumed).
        nbits: how many characters to take.
        to_int: when true, interpret the taken characters as a base-2
            number; otherwise hand them back as a new deque.

    Returns:
        An int, or a deque of the taken characters when `to_int` is false.
    """
    taken = []
    for _ in range(nbits):
        taken.append(packet.popleft())
    if not to_int:
        return deque(taken)
    return int(''.join(taken), 2)

In [4]:
def parse_hex_packet(h):
    """Decode a hexadecimal BITS transmission and parse its outermost packet.

    Args:
        h: hex string. Surrounding whitespace (e.g. a trailing newline from
            a raw input file) is ignored.

    Returns:
        The packet tree produced by ``parse_packet``.

    Raises:
        ValueError: if ``h`` is empty/blank or contains a non-hex character.
    """
    h = h.strip()
    if not h:
        raise ValueError('empty hex packet')
    # Expand each hex digit to exactly 4 bits. This keeps leading zero nibbles
    # and, unlike int(h, 16), rejects '0x', signs and '_' separators. Those
    # would otherwise desynchronise the bit length from len(h) * 4.
    bin_packet = ''.join(f'{int(digit, 16):04b}' for digit in h)
    return parse_packet(deque(bin_packet))

def parse_packet(packet):
    """Consume one BITS packet, including any nested sub-packets, from `packet`.

    Args:
        packet: deque of '0'/'1' characters. Bits are removed from the left
            as they are read, so on return it holds whatever follows this
            packet.

    Returns:
        Packet(version, type_id, payload), where payload is the decoded
        integer for a literal (type 4) or a list of sub-packets otherwise.
    """
    version = pop_bits(packet, 3)
    type_id = pop_bits(packet, 3)

    if type_id != 4:
        return Packet(version, type_id, _parse_operator_payload(packet))

    value = 0
    more = True
    while more:
        group = pop_bits(packet, 5)
        # The top bit of each 5-bit group says whether another group follows;
        # the low 4 bits are the next nibble of the literal value.
        more = bool(group & 0b10000)
        value = (value << 4) | (group & 0b1111)
    return Packet(version, type_id, value)


def _parse_operator_payload(packet):
    """Read the sub-packets of an operator packet whose header was already consumed."""
    if pop_bits(packet, 1):
        # Length type 1: the next 11 bits give the number of sub-packets.
        count = pop_bits(packet, 11)
        return [parse_packet(packet) for _ in range(count)]
    # Length type 0: the next 15 bits give the total bit length of the
    # sub-packets. Carve off exactly that span and parse it until exhausted.
    span = pop_bits(packet, 15)
    region = pop_bits(packet, span, to_int=False)
    children = []
    while region:
        children.append(parse_packet(region))
    return children

In [5]:
def version_count(packet):
    """Return the sum of the version numbers of `packet` and all its descendants.

    Literal packets (type 4) carry an int payload and have no children;
    every other packet's payload is a list of sub-packets to recurse into.
    """
    total = packet.version
    if packet.type != 4:
        total += sum(map(version_count, packet.payload))
    return total

In [6]:
def calculate(packet):
    """Evaluate the expression encoded by a parsed BITS packet tree.

    Type ids:
        0 sum, 1 product, 2 minimum, 3 maximum of the sub-packet values;
        4 literal (payload is the int value);
        5 greater-than, 6 less-than, 7 equal-to: 1 if the relation holds
        between the first and second sub-packet values, else 0.

    Raises:
        ValueError: if ``packet.type`` is not one of the ids above.
            Previously an unknown id silently returned None.
    """
    kind = packet.type
    if kind == 4:
        return packet.payload
    if kind not in (0, 1, 2, 3, 5, 6, 7):
        raise ValueError(f'unknown packet type id: {kind!r}')

    # Evaluate each child exactly once, then combine according to the type id.
    values = [calculate(p) for p in packet.payload]
    if kind == 0:
        return sum(values)
    if kind == 1:
        return prod(values)
    if kind == 2:
        return min(values)
    if kind == 3:
        return max(values)
    if kind == 5:
        return int(values[0] > values[1])
    if kind == 6:
        return int(values[0] < values[1])
    return int(values[0] == values[1])  # kind == 7

In [7]:
# Fetch the puzzle input once and parse it into a packet tree shared by both parts.
data = aocd.get_data(day=16)
packet = parse_hex_packet(data)

task1 = version_count(packet)  # part 1: sum of every version number in the tree
task2 = calculate(packet)      # part 2: value of the outermost expression

print(task1, task2, sep='\n')

# Regression checks against the accepted answers for this particular input.
assert task1 == 873
assert task2 == 402817863665

873
402817863665
