In [None]:
from tabulate import tabulate

# Data file paths; being relative, they are resolved against the current working directory.
EXAMPLE = "../example.txt"  # presumably the small sample from the puzzle statement - TODO confirm
INPUT = "../input.txt"  # presumably the full puzzle input - TODO confirm

In [None]:
def parse_input(input_file_name):
    """Read key and lock schematics from a text file.

    The file is a sequence of schematics (runs of consecutive non-blank
    lines) separated by blank lines. A schematic whose first row contains a
    '.' is collected as a key; otherwise a first row containing a '#' makes
    it a lock. Rows are stored with surrounding whitespace removed.

    Lines made only of whitespace are treated as separators, and no trailing
    blank line is required after the last schematic.

    Args:
        input_file_name: Path of the file to read.

    Returns:
        A ``(keys, locks)`` tuple. Each element is a list of schematics in
        file order, a schematic being a list of row strings.

    Raises:
        ValueError: If the first row of a schematic contains neither '.'
            nor '#'.
    """
    keys = []
    locks = []
    # Rows of the schematic currently being read. The list is registered in
    # `keys` or `locks` as soon as it is created, so nothing needs flushing
    # at a separator or at end of file.
    current_schematic = None
    with open(input_file_name, 'r') as f:
        for line in f:
            row = line.strip()
            if not row:
                # Blank (or whitespace-only) line: the current schematic is over.
                current_schematic = None
                continue
            if current_schematic is None:
                # The first row of a schematic decides whether it is a key or a lock.
                if '.' in row:
                    current_schematic = []
                    keys.append(current_schematic)
                elif '#' in row:
                    current_schematic = []
                    locks.append(current_schematic)
                else:
                    raise ValueError(
                        f"Cannot tell key from lock: first row {row!r} "
                        "contains neither '.' nor '#'"
                    )
            current_schematic.append(row)
    return keys, locks

In [None]:
# Load the example schematics and render each one as a table, keys first, then locks.
keys, locks = parse_input(EXAMPLE)
for schematic in keys + locks:
    print(tabulate(schematic))

In [None]:
def get_pin_heights(schematic):
    """Count the '#' cells of each column, ignoring the first and last rows.

    Args:
        schematic: Non-empty sequence of row strings; the column count is
            taken from the first row.

    Returns:
        A tuple holding one count per column, in column order.
    """
    nb_columns = len(schematic[0])
    # The first and last rows are left out of the count.
    inner_rows = schematic[1:-1]
    return tuple(
        sum(1 for row in inner_rows if row[column] == '#')
        for column in range(nb_columns)
    )

In [None]:
# Show the pin heights of every schematic, keys first, then locks.
for schematic in keys + locks:
    print(get_pin_heights(schematic))

For each column and each possible pin height, we build a set of all the keys (stored as their tuples of pin heights) that have that height in that column, indexed as `key_set[column][height]`.

In [None]:
def build_key_set(nb_columns=5, nb_heights=6):
    """Create an empty lookup structure indexed by column, then by pin height.

    Args:
        nb_columns: Number of columns in a schematic (defaults to 5).
        nb_heights: Number of distinct pin heights, i.e. maximum height + 1
            (defaults to 6, for heights 0 through 5).

    Returns:
        A list of ``nb_columns`` lists, each holding ``nb_heights`` distinct
        empty sets, so that ``result[column][height]`` is a set.
    """
    # A fresh set() per cell: the sets must not be shared between cells.
    return [[set() for _ in range(nb_heights)] for _ in range(nb_columns)]

In [None]:
# Create the empty lookup structure as a global and display it.
key_set = build_key_set()
print(key_set)

In [None]:
def fill_key_set(key_set, keys):
    """Register the pin heights of every key in the lookup structure.

    For each of the first five columns, the key's full tuple of pin heights
    is added to the set found at ``key_set[column][height]``, where
    ``height`` is the key's pin height in that column. Because sets of
    tuples are used, keys with identical pin heights are stored only once.

    Args:
        key_set: Structure indexed as ``key_set[column][height]`` whose
            cells are sets; it is modified in place.
        keys: Iterable of key schematics accepted by ``get_pin_heights``.

    Returns:
        None.
    """
    for schematic in keys:
        heights = get_pin_heights(schematic)
        for column in range(5):
            key_set[column][heights[column]].add(heights)

In [None]:
# fill_key_set adds to the global key_set in place; the cells are sets, so
# re-running this cell with the same keys leaves key_set unchanged.
fill_key_set(key_set, keys)
print(key_set)

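Finding the pairs is now only a matter of set operations for each lock: in each column, we take the union of the key sets whose height, added to the lock's pin height, stays below 6; the intersection of these five unions gives the keys that fit the lock.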

In [None]:
def find_nb_of_pairs(key_set, locks):
    """Count the (lock, key) combinations that fit together.

    A key fits a lock when, in each of the five columns, the lock's pin
    height plus the key's pin height is strictly below 6.

    Args:
        key_set: Structure indexed as ``key_set[column][height]`` whose
            cells are sets of key pin-height tuples.
        locks: Iterable of lock schematics accepted by ``get_pin_heights``.

    Returns:
        The total, over all locks, of the number of distinct key pin-height
        tuples that fit the lock.
    """
    nb_pairs = 0
    for lock in locks:
        lock_heights = get_pin_heights(lock)
        candidates_per_column = []
        for column in range(5):
            lock_height = lock_heights[column]
            # Key heights 0 .. 5 - lock_height keep the sum below 6; the
            # range is empty when the lock's pin height exceeds 5.
            candidates = set()
            for key_height in range(6 - lock_height):
                candidates.update(key_set[column][key_height])
            candidates_per_column.append(candidates)
        # Only keys that are candidates in every column fit the lock.
        nb_pairs += len(set.intersection(*candidates_per_column))
    return nb_pairs

In [None]:
def part_1(input_file_name):
    """Print the number of fitting lock/key pairs found in a schematics file.

    Args:
        input_file_name: Path of the file holding the schematics.

    Returns:
        None; the count is written to standard output.
    """
    parsed_keys, parsed_locks = parse_input(input_file_name)
    lookup = build_key_set()
    fill_key_set(lookup, parsed_keys)
    print(find_nb_of_pairs(lookup, parsed_locks))

In [None]:
# Run part 1 on the file referenced by EXAMPLE.
part_1(EXAMPLE)

In [None]:
# Run part 1 on the file referenced by INPUT.
part_1(INPUT)