In [4]:
import parse

In [81]:
from aocd.models import Puzzle

# Raw puzzle input (a hex string decoded by Buffer) for Advent of Code 2021, day 16.
puzzle = Puzzle(day=16, year=2021)
data = puzzle.input_data

In [185]:
import math
from dataclasses import dataclass
from typing import Sequence
    
@dataclass()
class Packet:
    """Header fields shared by every decoded packet."""

    version: int  # 3-bit version read first from the packet header
    type_: int  # 3-bit type ID; 4 marks a literal, anything else an operator

@dataclass()
class Literal(Packet):
    """A type-4 packet carrying a single integer payload."""

    value: int  # assembled from the packet's 5-bit (continuation + 4 value bits) groups
        
@dataclass()
class Operator(Packet):
    """A packet whose value is derived from the values of its sub-packets.

    ``type_`` selects the reduction from ``OP_FN``: 0 sum, 1 product,
    2 minimum, 3 maximum, 5 greater-than, 6 less-than, 7 equal-to.
    The comparison types inspect only the first two sub-packet values
    and produce 1 or 0.
    """

    subpackets: Sequence[Packet]

    # No type annotation, so dataclass keeps this as a plain class attribute
    # rather than turning it into a field.
    OP_FN = {
        0: sum,
        1: math.prod,
        2: min,
        3: max,
        5: lambda vals: int(vals[0] > vals[1]),
        6: lambda vals: int(vals[0] < vals[1]),
        7: lambda vals: int(vals[0] == vals[1]),
    }

    @property
    def value(self):
        """Apply this packet's ``OP_FN`` reduction to its children's values.

        Raises KeyError for a ``type_`` missing from ``OP_FN``.
        """
        combine = self.OP_FN[self.type_]
        child_values = [child.value for child in self.subpackets]
        return combine(child_values)

In [187]:
# Equality (type 7) sample; samples_b lists its expected value as 0.
# NOTE(review): this cell sits above the cell defining Buffer, so it only
# works if that cell already ran; Restart & Run All fails here.
s = '9C005AC2F8F0'
Buffer(data=s).decode().value

0

In [188]:
# Part B answer: value of the outermost packet in the puzzle input.
# NOTE(review): appears above the cell defining Buffer; needs reordering to
# survive Restart & Run All.
Buffer(data=data).decode().value

144595909277

In [156]:

        
def list2num(nums):
    """Interpret ``nums`` as base-2 digits, most significant first.

    ``[1, 0, 1]`` -> 5; an empty sequence gives 0.
    """
    total = 0
    for power, digit in enumerate(reversed(nums)):
        total += digit * 2 ** power
    return total

class Buffer:
    """Bit-level cursor over a hexadecimal transmission.

    The hex string is expanded into ``self.bits`` (0/1 ints, most significant
    bit of each hex digit first); ``self.p`` is the index of the next unread
    bit and is advanced by every read.
    """

    def __init__(self, data):
        self.p = 0
        # One int in 0..15 per hex digit; surrounding whitespace is ignored.
        nibbles = [int(h, 16) for h in data.strip()]
        self.bits = [int((nib & mask) > 0) for nib in nibbles for mask in (8, 4, 2, 1)]

    def read(self, n) -> int:
        """Consume the next ``n`` bits and return them as an unsigned integer.

        Raises:
            ValueError: if fewer than ``n`` bits remain, rather than returning
                a value built from a silently truncated slice.
        """
        end = self.p + n
        if end > len(self.bits):
            raise ValueError(
                f'cannot read {n} bits at offset {self.p}: '
                f'only {len(self.bits) - self.p} left'
            )
        value = 0
        for bit in self.bits[self.p:end]:
            value = (value << 1) | bit
        self.p = end
        return value

    def read_literal(self) -> int:
        """Read a literal payload made of 5-bit groups.

        Each group is a continuation bit followed by 4 value bits; groups are
        concatenated until one has a continuation bit of 0.
        """
        b, n = 1, 0
        while b == 1:
            b = self.read(1)
            n = (n << 4) | self.read(4)
        return n

    def decode(self) -> 'Packet':
        """Decode one packet at the cursor, recursing into any sub-packets.

        Raises:
            ValueError: if the bits run out mid-packet, or if the sub-packets
                of a length-type-0 operator do not end exactly at the
                declared bit length.
        """
        version = self.read(3)
        type_ = self.read(3)
        if type_ == 4:
            return Literal(version, type_, self.read_literal())

        if self.read(1) == 0:
            # Length type 0: the next 15 bits give the total bit length of the
            # sub-packets, measured from just after this field.
            total_bits = self.read(15)
            end = self.p + total_bits
            subpackets = []
            while self.p < end:
                subpackets.append(self.decode())
            if self.p != end:
                raise ValueError(
                    f'sub-packets overran their declared length: '
                    f'ended at bit {self.p}, expected {end}'
                )
        else:
            # Length type 1: the next 11 bits give the number of sub-packets.
            count = self.read(11)
            subpackets = [self.decode() for _ in range(count)]
        return Operator(version, type_, subpackets)
        


In [157]:
# Short aliases used by the checks and solve_a below.
L, Op = Literal, Operator

In [158]:
# Literal packet example: version 6, type 4, payload 2021.
assert Buffer('D2FE28').decode() == Literal(version=6, type_=4, value=2021)

In [159]:
# Length-type-0 operator (sub-packet size given in bits) holding two literals.
pkt = Buffer('38006F45291200').decode()
assert pkt == Op(1, 6, [L(6, 4, 10), L(2, 4, 20)])
pkt

Operator(version=1, type_=6, subpackets=[Literal(version=6, type_=4, value=10), Literal(version=2, type_=4, value=20)])

In [160]:
# Length-type-1 operator (sub-packet count given) holding three literals.
Buffer(data='EE00D40C823060').decode()

Operator(version=7, type_=3, subpackets=[Literal(version=2, type_=4, value=1), Literal(version=4, type_=4, value=2), Literal(version=1, type_=4, value=3)])

In [161]:
# Total bit length of the sample: 4 bits per hex digit.
4 * len('8A004A801A8002F478')

72

In [162]:
# Three nested operators wrapping a single literal.
Buffer(data='8A004A801A8002F478').decode()

Operator(version=4, type_=2, subpackets=[Operator(version=1, type_=2, subpackets=[Operator(version=5, type_=2, subpackets=[Literal(version=6, type_=4, value=15)])])])

In [163]:
def solve_a(data):
    """Sum the version numbers of every packet in the hex transmission ``data``.

    Walks the decoded packet tree with an explicit stack; operator packets
    push their sub-packets so nested versions are counted too.
    """
    total = 0
    pending = [Buffer(data).decode()]
    while pending:
        pkt = pending.pop()
        total += pkt.version
        if isinstance(pkt, Op):
            pending.extend(pkt.subpackets)
    return total


In [170]:
# Part A examples: (hex transmission, expected sum of all packet versions).
samples_a = [
    ('8A004A801A8002F478', 16),
    ('620080001611562C8802118E34', 12),
    ('C0015000016115A2E0802F182340', 23),
    ('A0016C880162017C3686B18A3D4780', 31),
]

In [None]:
# Part B examples: (hex transmission, expected value of the outermost packet).
samples_b = [
    ('C200B40A82', 3),
    ('04005AC33890', 54),
    ('880086C3E88112', 7),
    ('CE00C43D881120', 9),
    ('D8005AC2A8F0', 1),
    ('F600BC2D8F', 0),
    ('9C005AC2F8F0', 0),
    ('9C0141080250320F1802104A08', 1),
]

# Check every part-B example against the packet evaluator, naming any failure.
for sample, sol in samples_b:
    got = Buffer(sample).decode().value
    assert got == sol, f'{sample}: got {got}, expected {sol}'

In [171]:
# Every part-A example must reproduce its expected version sum; the assertion
# message names the failing sample and the value actually computed.
for sample, sol in samples_a:
    got = solve_a(sample)
    assert got == sol, f'{sample}: got {got}, expected {sol}'

In [172]:
# Part A answer: version sum over the whole puzzle input.
solve_a(data=data)

993