In [1]:
from collections import defaultdict

In [2]:
# Raw puzzle input: ordering rules, a blank line, then one update per line.
# An explicit encoding keeps the read independent of the platform's locale default.
with open('input1.txt', 'r', encoding='utf-8') as file:
    input_file = file.read()

## PART 1

In [3]:
# Small sample input in the same format as input1.txt: `X|Y` ordering rules,
# a blank line, then one comma-separated update per line.
test_values = """47|53
97|13
97|61
97|47
75|29
61|13
75|53
29|13
97|29
53|29
61|53
97|53
61|29
47|13
75|47
97|75
47|61
75|61
47|29
75|13
53|13

75,47,61,53,29
97,61,53,29,13
75,29,13
75,97,47,61,53
61,13,29
97,13,75,29,47"""


In [4]:
def _get_downstream_and_update_list(input_values:str) -> tuple[defaultdict[set], list[list[int]]]:

    order_rules, updates = input_values.split("\n\n")
    downstream_dependency_map = defaultdict(set)

    for ordering in order_rules.split("\n"):
        v1, v2 = ordering.split("|")
        downstream_dependency_map[int(v1)].add(int(v2))

    updates_list = [[int(elem) for elem in line.split(',')] for line in updates.split("\n")]

    return downstream_dependency_map, updates_list


def get_mid_elem_safe_updates(input_values: str) -> int:
    """Part 1: sum the middle page of every update that already respects the rules.

    An update is unsafe as soon as a page appears after a page it must precede,
    i.e. one of its downstream pages has already been seen earlier in the update.

    Args:
        input_values: raw puzzle text (ordering rules, blank line, updates).

    Returns:
        Sum of ``update[len(update) // 2]`` over all safe updates.
    """
    downstream_map, updates = _get_downstream_and_update_list(input_values)
    safe_update_list_mid_sum = 0

    for update_list in updates:
        seen = set()  # pages that occur before the current one
        for page in update_list:
            # If something that must come *after* `page` was already seen,
            # the update violates a rule.
            if not seen.isdisjoint(downstream_map[page]):
                break
            seen.add(page)
        else:
            # Loop finished without `break`: no rule was violated.
            safe_update_list_mid_sum += update_list[len(update_list) // 2]

    return safe_update_list_mid_sum

In [5]:
# Sanity check on the sample before trusting the real answer.
part1_sample_answer = get_mid_elem_safe_updates(test_values)
assert part1_sample_answer == 143, part1_sample_answer
part1_sample_answer

143

In [6]:
# Part 1 answer for the full puzzle input read from input1.txt.
get_mid_elem_safe_updates(input_values=input_file)

7074

## PART 2

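Couldn't solve this part on my own. Ouch! The approach below comes from the solution explained in this [video](https://www.youtube.com/watch?v=LA4RiCDPUlI). Each unsafe update is fixed by topologically sorting its pages, using only the rules between pages of that update.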

In [7]:
def _get_downstream_and_update_list(input_values:str) -> tuple[defaultdict[set], list[list[int]]]:

    order_rules, updates = input_values.split("\n\n")
    downstream_dependency_map = defaultdict(set)

    for ordering in order_rules.split("\n"):
        v1, v2 = ordering.split("|")
        downstream_dependency_map[int(v1)].add(int(v2))

    updates_list = [[int(elem) for elem in line.split(',')] for line in updates.split("\n")]

    return downstream_dependency_map, updates_list


def _fix_unsafe_update(unsafe_update, downstream_map) -> list[int]:
    downstream_map_local = defaultdict(int) # meant to hold relations between unsafe elements ONLY to keep indeg limited to only relavant values


    for elem in unsafe_update:
        downstream_map_local[elem] = set(unsafe_update) & downstream_map.get(elem, {})

    indeg = defaultdict(int)
    for k,v_set in downstream_map_local.items():
        for v in v_set:
            indeg[v] += 1

    safe_update = []

    while len(safe_update) < len(unsafe_update):
        for elem in unsafe_update:
            if elem in safe_update: continue
            if indeg[elem] <= 0:
                safe_update.append(elem)
                for downstream_elem in downstream_map_local[elem]:
                    indeg[downstream_elem] -= 1

    return safe_update


def _safeify_unsafe_updates(unsafe_updates: list[list[int]], downstream_map) -> int:
    """Reorder every unsafe update and add up the middle page of each result.

    Args:
        unsafe_updates: updates that violate at least one ordering rule.
        downstream_map: page -> set of pages that must come after it, passed
            straight through to ``_fix_unsafe_update``.

    Returns:
        Sum of ``fixed[len(fixed) // 2]`` over the reordered updates (0 if none).
    """
    reordered = (_fix_unsafe_update(update, downstream_map) for update in unsafe_updates)
    return sum(fixed[len(fixed) // 2] for fixed in reordered)


def get_mid_elem_unsafe_updates(input_values: str) -> int:
    """Part 2: sum the middle page of every unsafe update after reordering it.

    An update is unsafe as soon as a page appears after a page it must precede,
    i.e. one of its downstream pages has already been seen earlier in the update.

    Args:
        input_values: raw puzzle text (ordering rules, blank line, updates).

    Returns:
        Sum of the middle elements of the corrected unsafe updates.
    """
    downstream_map, updates = _get_downstream_and_update_list(input_values)

    unsafe_updates = []
    for update_list in updates:
        seen = set()  # pages that occur before the current one
        for page in update_list:
            # Indexing (rather than .get) keeps the original behaviour of
            # registering every inspected page as a key of downstream_map.
            if not seen.isdisjoint(downstream_map[page]):
                unsafe_updates.append(update_list)
                break
            seen.add(page)

    return _safeify_unsafe_updates(unsafe_updates, downstream_map)

In [8]:
# Sanity check on the sample before trusting the real answer.
part2_sample_answer = get_mid_elem_unsafe_updates(test_values)
assert part2_sample_answer == 123, part2_sample_answer
part2_sample_answer

123

In [9]:
# Part 2 answer for the full puzzle input read from input1.txt.
get_mid_elem_unsafe_updates(input_values=input_file)

4828