In [157]:
import json
import math

In [134]:
# Read the puzzle input. Only the first line is used: it holds the
# hex-encoded transmission, and surrounding whitespace is stripped.
fname = 'example.txt'
with open(fname) as f:
    lines = list(f)  # one entry per line of the file, line endings kept
data = lines[0].strip()

In [183]:
def parse_packet(packet_string, start_idx=0):
    """Decode one packet, including any nested sub-packets, from a bit string.

    Layout read by this function, beginning at ``start_idx``:

    * 3 bits of version followed by 3 bits of type ID.
    * Type ID 4 is a literal: a run of 5-bit groups where the first bit of
      each group is '0' on the last group and the remaining 4 bits are
      payload. The payload bits are concatenated and read as one binary int.
    * Any other type ID is an operator: 1 bit of length type, then either a
      15-bit total length of the sub-packets in bits (length type 0) or an
      11-bit number of sub-packets (otherwise), then the sub-packets.

    Args:
        packet_string: sequence of '0'/'1' characters.
        start_idx: index in ``packet_string`` of the packet's first bit.

    Returns:
        A ``(packet, bits_consumed)`` tuple. ``packet`` is a dict with keys
        ``type``, ``type_ID``, ``version`` and either ``literal`` (for
        literals) or ``sub_packets`` (for operators). ``bits_consumed`` is
        the number of bits read from ``start_idx``, sub-packets included.
    """
    cursor = start_idx  # absolute index of the next unread bit

    version = int(packet_string[cursor:cursor + 3], 2)
    type_ID = int(packet_string[cursor + 3:cursor + 6], 2)
    cursor += 6

    # Literal packet: gather 4-bit payload chunks until the terminating group.
    if type_ID == 4:
        nibbles = []
        has_more = True
        while has_more:
            group = packet_string[cursor:cursor + 5]
            has_more = group[0] != '0'
            nibbles.append(group[1:])
            cursor += 5

        literal_packet = {
            "type": "literal",
            "type_ID": type_ID,
            "version": version,
            "literal": int("".join(nibbles), 2),
        }
        return literal_packet, cursor - start_idx

    # Operator packet: a length-type bit selects how the payload is delimited.
    length_type_ID = int(packet_string[cursor], 2)
    cursor += 1

    # 15-bit field = total sub-packet bits; 11-bit field = sub-packet count.
    field_width = 15 if length_type_ID == 0 else 11
    field_value = int(packet_string[cursor:cursor + field_width], 2)
    cursor += field_width

    payload_start = cursor
    sub_packets = []
    while True:
        if length_type_ID == 0:
            progress = cursor - payload_start
        else:
            progress = len(sub_packets)
        if progress >= field_value:
            break
        child, child_bits = parse_packet(packet_string, start_idx=cursor)
        cursor += child_bits
        sub_packets.append(child)

    operator_packet = {
        "type": "operator",
        "type_ID": type_ID,
        "version": version,
        "sub_packets": sub_packets,
    }
    return operator_packet, cursor - start_idx

def sum_versions(parsed):
    """Add up the ``version`` field of a packet and of every nested sub-packet.

    Args:
        parsed: packet dict with a ``version`` key and, optionally, a
            ``sub_packets`` key holding a list of packet dicts.

    Returns:
        The packet's own version plus the versions of all its descendants.
    """
    total = parsed['version']
    if 'sub_packets' not in parsed:
        return total
    for child in parsed['sub_packets']:
        total += sum_versions(child)
    return total

def hex_to_bin(hex_string):
    """Expand a hexadecimal string into its binary representation.

    Each hex digit becomes exactly four '0'/'1' characters, so leading zeros
    are preserved. Both uppercase and lowercase digits are accepted.

    Args:
        hex_string: string made only of hex digits (0-9, A-F, a-f).

    Returns:
        A string of '0'/'1' characters, four per input digit; an empty input
        yields an empty string.

    Raises:
        KeyError: if ``hex_string`` contains a character that is not a hex
            digit.
    """
    char_lookup = {
        digit: format(value, "04b")
        for value, digit in enumerate("0123456789ABCDEF")
    }
    # Accept lowercase digits as well; they map to the same bit patterns.
    char_lookup.update(
        {digit.lower(): bits for digit, bits in list(char_lookup.items())}
    )
    # Join once instead of growing the result string one digit at a time.
    return "".join(char_lookup[char] for char in hex_string)

def eval_exprs(packet):
    """Recursively evaluate a parsed packet tree.

    A packet whose ``type`` is not ``"operator"`` evaluates to its
    ``literal`` value. An operator packet first evaluates all of its
    ``sub_packets`` and then combines the results according to ``type_ID``:

    * 0 sum, 1 product, 2 minimum, 3 maximum;
    * 5 / 6 / 7 compare the first value against the second (greater than,
      less than, equal) and yield 1 when true, 0 otherwise.

    Args:
        packet: packet dict with a ``type`` key plus either ``literal`` or
            ``type_ID`` and ``sub_packets``.

    Returns:
        The value of the expression rooted at ``packet``.
    """

    def first_exceeds_second(operands):
        return int(operands[0] > operands[1])

    def first_below_second(operands):
        return int(operands[0] < operands[1])

    def first_equals_second(operands):
        return int(operands[0] == operands[1])

    # Dispatch table keyed by type ID; each entry takes the list of operands.
    handlers = {
        0: sum,
        1: math.prod,
        2: min,
        3: max,
        5: first_exceeds_second,
        6: first_below_second,
        7: first_equals_second,
    }

    if packet["type"] != "operator":
        return packet['literal']

    # Children are evaluated before the handler is looked up.
    operands = [eval_exprs(child) for child in packet['sub_packets']]
    handler = handlers[packet["type_ID"]]
    return handler(operands)

In [184]:
# Decode the hex transmission into bits, parse the packet tree, and print
# both puzzle answers.
packet_string = hex_to_bin(data)
# Only the packet tree is needed here. The consumed-bit count is bound to a
# named throwaway rather than `_`: assigning `_` in an IPython kernel shadows
# its "last output" variable for the rest of the session.
parsed, _bits_consumed = parse_packet(packet_string)
# Debug aid: uncomment to inspect the parsed tree.
# print(json.dumps(parsed, indent=2))
print(f"part 1: {sum_versions(parsed)}")
print(f"part 2: {eval_exprs(parsed)}")

part 1: 957
part 2: 744953223228
