In [1]:
from dataclasses import dataclass
from numpy import product
from operator import lt, gt, eq

In [2]:
@dataclass
class Packet:
    """One decoded BITS packet, as constructed by ``parse_packet``.

    Literal packets (``type_id == 4``) carry their decoded value in ``literal``
    and have an empty ``sub_packets`` list; operator packets store ``0`` in
    ``literal`` and hold their children in ``sub_packets``.
    """

    version: int  # 3-bit version field from the packet header
    type_id: int  # 3-bit type field; 4 = literal, any other value = operator
    literal: int  # decoded value for literal packets, 0 for operator packets
    length: int  # number of bits this packet occupies, including all sub-packets
    sub_packets: list  # child Packet objects, in the order they were parsed


In [11]:
def _read_bits(bits, start, count):
    """Return ``bits[start:start + count]`` as an unsigned int.

    Raises ValueError if fewer than ``count`` bits remain, so truncated input
    fails loudly instead of silently producing a shorter (wrong) number.
    """
    chunk = bits[start : start + count]
    if len(chunk) != count:
        raise ValueError(
            f"truncated packet: needed {count} bit(s) at offset {start}, "
            f"only {len(chunk)} available"
        )
    return int(chunk, 2)


def parse_packet(bits, depth=0):
    """Parse one BITS packet that starts at index 0 of ``bits``.

    Parameters
    ----------
    bits : str
        String of '0'/'1' characters. Any bits after the end of the packet
        (e.g. hex padding) are ignored.
    depth : int, optional
        Nesting level of this packet. Currently unused; kept for backward
        compatibility with existing callers.

    Returns
    -------
    Packet
        The decoded packet. Its ``length`` is the number of bits consumed,
        including every nested sub-packet, so callers can advance past it.

    Raises
    ------
    ValueError
        If ``bits`` ends before the packet (or a declared sub-packet budget)
        is complete, or contains characters other than '0'/'1'. Offsets in
        error messages are relative to the ``bits`` passed to that call.
    """
    version = _read_bits(bits, 0, 3)
    type_id = _read_bits(bits, 3, 3)
    position = 6

    # Literal packet: a run of 5-bit groups, each a continuation flag followed
    # by 4 value bits; the group whose flag bit is 0 is the last one.
    if type_id == 4:
        literal = 0
        while True:
            group = _read_bits(bits, position, 5)
            position += 5
            literal = (literal << 4) | (group & 0b1111)
            if not (group & 0b10000):
                break
        return Packet(version, type_id, literal, position, [])

    # Operator packet: a 1-bit length type ID selects how sub-packets are counted.
    length_type_id = _read_bits(bits, position, 1)
    position += 1
    sub_packets = []

    if length_type_id == 0:
        # Next 15 bits give the total bit length of all sub-packets.
        total_length = _read_bits(bits, position, 15)
        position += 15
        end = position + total_length
        if end > len(bits):
            raise ValueError(
                f"truncated packet: sub-packets declared as {total_length} bits "
                f"at offset {position}, only {len(bits) - position} available"
            )
        # Each sub-parse only sees bits[position:end] and _read_bits refuses
        # short reads, so no sub-packet can run past `end`; `<` therefore stops
        # exactly at the budget and cannot spin past it on malformed input.
        while position < end:
            sub = parse_packet(bits[position:end], depth + 1)
            sub_packets.append(sub)
            position += sub.length
    else:
        # Next 11 bits give the number of immediately contained sub-packets.
        count = _read_bits(bits, position, 11)
        position += 11
        for _ in range(count):
            sub = parse_packet(bits[position:], depth + 1)
            sub_packets.append(sub)
            position += sub.length

    return Packet(version, type_id, 0, position, sub_packets)


def add_versions(packets):
    """Return the sum of ``version`` over every packet in ``packets`` and,
    recursively, over all of their ``sub_packets``. An empty iterable gives 0."""
    return sum(
        node.version + add_versions(node.sub_packets)
        for node in packets
    )


def compute(p):
    """Evaluate the expression tree rooted at packet ``p``.

    Type IDs:
      * 4 -> the packet's own ``literal`` value
      * 0 sum, 1 product, 2 minimum, 3 maximum -> applied to all sub-packet values
      * 5 greater-than, 6 less-than, 7 equal-to -> compare exactly two sub-packet
        values (first vs. second) and return 1 if the relation holds, else 0

    Parameters
    ----------
    p : Packet
        Any object exposing ``type_id``, ``literal`` and ``sub_packets``.

    Returns
    -------
    int
        A Python int (arbitrary precision, so large products never wrap).

    Raises
    ------
    ValueError
        For an unknown type ID, or a comparison packet that does not have
        exactly two sub-packets. ``min``/``max`` on an operator with no
        sub-packets also raise ValueError (from the builtins).
    """
    # functools.reduce + operator.mul multiplies Python ints exactly, whereas
    # numpy.product coerces to fixed-width int64 and silently wraps on overflow
    # (it is also a deprecated alias removed in NumPy 2.0).
    from functools import reduce
    from operator import eq, gt, lt, mul

    if p.type_id == 4:
        return p.literal

    values = [compute(sub) for sub in p.sub_packets]

    reducers = {
        0: sum,
        1: lambda vals: reduce(mul, vals, 1),
        2: min,
        3: max,
    }
    comparisons = {5: gt, 6: lt, 7: eq}

    if p.type_id in reducers:
        return reducers[p.type_id](values)
    if p.type_id in comparisons:
        if len(values) != 2:
            raise ValueError(
                f"comparison packet (type_id {p.type_id}) needs exactly 2 "
                f"sub-packets, got {len(values)}"
            )
        # int() turns the bool result into 0/1
        return int(comparisons[p.type_id](*values))
    raise ValueError(f"unknown packet type_id {p.type_id}")

In [12]:
# Read the hexadecimal transmission and expand it into a string of '0'/'1' characters.
with open('../input/D16.txt', 'r') as f:
    data = f.read().strip()
# Each hex digit is exactly 4 bits; zfill restores the leading zero bits that
# int() -> bin() drops when the input starts with hex digits 0-7.
bits = bin(int(data, 16))[2:].zfill(len(data) * 4)
# Despite the plural name, this is the single outermost Packet (its children
# live in .sub_packets); any trailing padding bits after it are not parsed.
packets = parse_packet(bits)

In [13]:
# Sum of the version numbers of every packet in the tree
add_versions([packets])

891

In [14]:
# Value of the expression encoded by the outermost packet
compute(packets)

673042777597