Organization:
- Work
  - 1 test: defining functions for part 1, testing on test input
  - 1 run: getting answer for part 1
  - 2 test: defining functions for part 2, testing on test input
  - 2 run: getting answer for part 2
- Utilities: functions I think might help parse general inputs
- Inputs: where I define the test (_t_) and problem (_s_) inputs

# Work

## 1 test

In [8]:
# Parse every block of the test input into a (left, right) packet pair.
# NOTE(review): eval executes arbitrary code; acceptable for trusted puzzle
# input, but ast.literal_eval (or json.loads) would parse these lists safely.
pairs = [
    (eval(lines[0]), eval(lines[1]))  # left packet, right packet
    for lines in split(t)
]
pairs

[([1, 1, 3, 1, 1], [1, 1, 5, 1, 1]),
 ([[1], [2, 3, 4]], [[1], 4]),
 ([9], [[8, 7, 6]]),
 ([[4, 4], 4, 4], [[4, 4], 4, 4, 4]),
 ([7, 7, 7, 7], [7, 7, 7]),
 ([], [3]),
 ([[[]]], [[]]),
 ([1, [2, [3, [4, [5, 6, 7]]]], 8, 9], [1, [2, [3, [4, [5, 6, 0]]]], 8, 9])]

In [25]:
# Recursively compare two packets (ints or nested lists of ints).
def compare(x, y):
    """Decide whether the packet pair (x, y) is in the right order.

    Returns True when x belongs before y (in order), False when it belongs
    after y (out of order), and None when no decision can be made because
    the two values tie.
    """
    x_is_int = isinstance(x, int)
    y_is_int = isinstance(y, int)

    # Two integers: the smaller one comes first; equal values tie.
    if x_is_int and y_is_int:
        if x == y:
            return None
        return x < y

    # Mixed int/list: wrap the lone integer in a list and retry.
    if x_is_int and isinstance(y, list):
        return compare([x], y)
    if isinstance(x, list) and y_is_int:
        return compare(x, [y])

    # Two lists: the first element pair that decides settles the comparison.
    for left, right in zip(x, y):
        verdict = compare(left, right)
        if isinstance(verdict, bool):
            return verdict

    # Every shared element tied: the shorter list comes first.
    if len(x) == len(y):
        return None
    return len(x) < len(y)

In [26]:
# Verdict per pair: True (in order), False (out of order), None (undecided).
results = [compare(left, right) for left, right in pairs]
results

[True, True, False, True, False, True, False, False]

In [28]:
import numpy as np

In [34]:
# Part 1 answer: sum of the 1-based indices of the pairs that are in order.
sum(
    np.array(results, dtype=int)      # 1 where the pair is in order, else 0
    * (np.arange(len(results)) + 1)   # 1-based pair indices
)

13

## 1 run

In [38]:
# Parse the real input into packet pairs and decide the order of each pair.
# NOTE(review): eval executes arbitrary code; ast.literal_eval would be the
# safe parser if this input were ever not fully trusted.
pairs = [(eval(lines[0]), eval(lines[1])) for lines in split(s)]
results = [compare(left, right) for left, right in pairs]

In [37]:
# Check that every pair got a decision. compare() returns None only when the
# two packets tie completely, which the puzzle does not expect.
# Run this after the cell that computes `results` for the real input, so the
# check does not validate stale results from an earlier cell.
for index, result in enumerate(results, start=1):
    if not isinstance(result, bool):
        # Report which pair failed so it can be inspected directly.
        print(f'fail! pair {index} has no decision: {result!r}')

In [39]:
# Part 1 answer for the real input: sum of 1-based indices of in-order pairs.
sum(
    np.array(results, dtype=int)      # 1 where the pair is in order, else 0
    * (np.arange(len(results)) + 1)   # 1-based pair indices
)

5350

## 2 test

In [44]:
# Part 2 ignores the pairing: flatten every packet into one list, then append
# the two divider packets.
# NOTE(review): eval executes arbitrary code; ast.literal_eval is the safe parser.
packets = [eval(line) for block in split(t) for line in block]
packets.extend([[[2]], [[6]]])
packets

[[1, 1, 3, 1, 1],
 [1, 1, 5, 1, 1],
 [[1], [2, 3, 4]],
 [[1], 4],
 [9],
 [[8, 7, 6]],
 [[4, 4], 4, 4],
 [[4, 4], 4, 4, 4],
 [7, 7, 7, 7],
 [7, 7, 7],
 [],
 [3],
 [[[]]],
 [[]],
 [1, [2, [3, [4, [5, 6, 7]]]], 8, 9],
 [1, [2, [3, [4, [5, 6, 0]]]], 8, 9],
 [[2]],
 [[6]]]

In [54]:
# Will use to change the compare function of the built-in sorting function
from functools import cmp_to_key

In [50]:
# Adapts compare() to the integer interface expected by functools.cmp_to_key.
def compare_wrapper(x, y):
    """Translate compare()'s True/False/None verdict into 1/-1/0.

    NOTE: the sign is inverted relative to the usual cmp convention, where a
    negative result means "x sorts first". Here a pair that is in order maps
    to +1, so the sorting cells in this notebook pass reverse=True to
    compensate.
    """
    verdict = compare(x, y)
    if verdict is True:
        return 1
    if verdict is False:
        return -1
    return 0

In [59]:
packets = sorted(
    packets,
    key=cmp_to_key(compare_wrapper),
    reverse=True,  # compare_wrapper gives +1 for "in order"; flip to list packets in order
)
packets

[[],
 [[]],
 [[[]]],
 [1, 1, 3, 1, 1],
 [1, 1, 5, 1, 1],
 [[1], [2, 3, 4]],
 [1, [2, [3, [4, [5, 6, 0]]]], 8, 9],
 [1, [2, [3, [4, [5, 6, 7]]]], 8, 9],
 [[1], 4],
 [[2]],
 [3],
 [[4, 4], 4, 4],
 [[4, 4], 4, 4, 4],
 [[6]],
 [7, 7, 7],
 [7, 7, 7, 7],
 [[8, 7, 6]],
 [9]]

In [60]:
# Report the 1-based positions of the divider packets in the sorted list.
# Exact equality is used on purpose: `compare(packet, [[2]]) is None` would
# also match packets that merely tie with a divider, e.g. [2] or [[[2]]].
for position, packet in enumerate(packets, start=1):
    if packet == [[2]]:
        print(f'[[2]] at {position}')
    if packet == [[6]]:
        print(f'[[6]] at {position}')

[[2]] at 10
[[6]] at 14


## 2 run

In [63]:
# Flatten all real-input packets, add the dividers, and sort them into order.
# NOTE(review): eval executes arbitrary code; ast.literal_eval is the safe parser.
packets = [eval(line) for block in split(s) for line in block]
packets.extend([[[2]], [[6]]])
packets = sorted(
    packets,
    key=cmp_to_key(compare_wrapper),
    reverse=True,  # compare_wrapper gives +1 for "in order"; flip to list packets in order
)

In [64]:
# Report the 1-based positions of the divider packets in the sorted list.
# Exact equality is used on purpose: `compare(packet, [[2]]) is None` would
# also match packets that merely tie with a divider, e.g. [2] or [[[2]]].
for position, packet in enumerate(packets, start=1):
    if packet == [[2]]:
        print(f'[[2]] at {position}')
    if packet == [[6]]:
        print(f'[[6]] at {position}')

[[2]] at 103
[[6]] at 190


In [65]:
# Part 2 answer: product of the 1-based positions of the two divider packets,
# computed from the sorted list instead of copying numbers from printed output
# (which goes stale if the input or the sort changes).
(packets.index([[2]]) + 1) * (packets.index([[6]]) + 1)

19570

# Utilities

In [1]:
# Remove the initial/final newline left by triple-quoted input literals
def clean(s):
    """Return *s* without one leading and one trailing newline, if present.

    Only newline characters are removed. Slicing off the first and last
    characters unconditionally would silently truncate input that does not
    start or end with a newline.
    """
    if s.startswith('\n'):
        s = s[1:]
    if s.endswith('\n'):
        s = s[:-1]
    return s

# Split text into lines; if it contains blank-line separators, into blocks of lines
def split(s, block_char='\n\n', line_char='\n'):
    """Split *s* into blocks of lines.

    After clean() strips the leading/trailing newline, the text is cut at
    *block_char* into blocks and each block is cut at *line_char* into lines.
    A single block is returned as a flat list of lines. Otherwise a list of
    blocks (each a list of lines) is returned.
    """
    blocks = clean(s).split(block_char)
    if len(blocks) == 1:
        return blocks[0].split(line_char)
    return [block.split(line_char) for block in blocks]

# Apply one function, or a chain of functions, to flat or block (2-level) data
def apply_func(data, func, nested=False):
    """Map a function pipeline over *data*.

    *func* is either a single callable or a list of callables applied left to
    right. With nested=True, *data* is treated as a list of blocks and the
    pipeline is applied to every item inside every block.
    """
    funcs = func if isinstance(func, list) else [func]

    def pipeline(value):
        for step in funcs:
            value = step(value)
        return value

    if not nested:
        return [pipeline(item) for item in data]
    return [[pipeline(item) for item in block] for block in data]

# Split, then convert every line; descends into blocks when split() returns them
def _split_convert(s, func):
    """Split *s* with split() and apply *func* to each line.

    split() returns a flat list of lines for single-block input and a list of
    blocks (lists of lines) otherwise. The conversion is applied at the
    matching depth, so block-structured input no longer reaches
    int()/float() as a list and raises TypeError.
    """
    data = split(s)
    nested = bool(data) and isinstance(data[0], list)
    return apply_func(data, func, nested=nested)


# Split, parsing everything as ints
def split_int(s):
    """Split *s* and parse every line with int (flat or block input)."""
    return _split_convert(s, int)


# Split, parsing everything as float
def split_float(s):
    """Split *s* and parse every line with float (flat or block input)."""
    return _split_convert(s, float)

# Inputs

In [2]:
# Test input: one packet per line, pairs separated by a blank line (see the
# parsed pairs printed under "1 test"). The surrounding newlines of the
# triple-quoted literal are removed by clean() before splitting.
# NOTE(review): the "..." line appears to be an elision in this copy of the
# notebook (the parsed output shows eight complete pairs); confirm the real
# literal holds the full input.
t = """
[1,1,3,1,1]
...
[1,[2,[3,[4,[5,6,0]]]],8,9]
"""

In [3]:
# Problem (real) input, in the same format as the test input `t`: one packet
# per line, pairs separated by a blank line.
# NOTE(review): the "..." line appears to be an elision in this copy of the
# notebook; confirm the real literal holds the full input.
s = """
[[[6,10],[4,3,[4]]]]
...
[[[[9],[10,4,7,0],[5,6],[7,3,6,0,10]]],[],[[[9,8,10,4,2],[1,6,1],10],3]]
"""