# In [1]:
import numpy as np

# Path to the puzzle input: one grid row per line.
input_file = "data/input.txt"

# Cycle-detection table for the spin loop: grid key -> round it was first seen.
memory = dict()

def tilt_north(matrix):
    """Slide every 'O' cell as far north (towards row 0) as it can go.

    Each column is split into segments by '#' cells. Inside a segment the
    'O' cells are moved to the top, and every other cell keeps its relative
    order (``sorted`` is stable).

    The array is modified in place and also returned. Callers therefore pass
    ``np.rot90`` views and have the tilt land in the underlying grid.

    Args:
        matrix: 2-D numpy array of single-character strings.

    Returns:
        The same array object that was passed in, now tilted north.
    """
    n_rows, n_cols = matrix.shape
    for col in range(n_cols):
        column = matrix[:, col]
        # Segment boundaries: every '#' plus a sentinel at the end of the column.
        stops = [i for i, cell in enumerate(column) if cell == '#']
        start = 0
        for stop in [*stops, n_rows]:
            # Stable sort with key False for 'O' puts all 'O' cells first.
            matrix[start:stop, col] = sorted(
                column[start:stop], key=lambda cell: cell != 'O'
            )
            start = stop + 1  # next segment begins just past the '#'
    return matrix

def tilt_west(matrix):
    """Tilt the grid west so 'O' cells slide toward column 0.

    A clockwise quarter turn (k=-1, equivalent to k=3) puts the west edge
    on top. ``tilt_north`` then handles it, and a counter-clockwise quarter
    turn restores the orientation.
    """
    west_on_top = np.rot90(matrix, k=-1)
    return np.rot90(tilt_north(west_on_top), k=1)

def tilt_south(matrix):
    """Tilt the grid south so 'O' cells slide toward the last row.

    A half turn puts the south edge on top, ``tilt_north`` does the work,
    and a second half turn restores the original orientation.
    """
    upside_down = np.rot90(matrix, k=2)
    tilted = tilt_north(upside_down)
    return np.rot90(tilted, k=2)

def tilt_east(matrix):
    """Tilt the grid east so 'O' cells slide toward the last column.

    A counter-clockwise quarter turn puts the east edge on top. ``tilt_north``
    handles it, and a clockwise quarter turn (k=-1, equivalent to k=3)
    undoes the rotation.
    """
    east_on_top = np.rot90(matrix, k=1)
    return np.rot90(tilt_north(east_on_top), k=-1)

def spin(matrix):
    """Run one spin cycle: tilt north, then west, then south, then east.

    Each tilt's result feeds the next one; the final grid is returned.
    """
    for tilt in (tilt_north, tilt_west, tilt_south, tilt_east):
        matrix = tilt(matrix)
    return matrix

def get_key(matrix):
    """Flatten the grid into a hashable string used as a cycle-detection key.

    Cells within a row, and consecutive rows, are each separated by a
    single space.
    """
    return " ".join(map(" ".join, matrix))

def get_load(matrix):
    """Return the total load of the 'O' cells in a 2-D grid.

    An 'O' in row ``r`` contributes ``n_rows - r``. So a top-row (row 0)
    cell counts ``n_rows`` and a bottom-row cell counts 1.
    """
    n_rows, _ = matrix.shape
    weights = np.arange(n_rows, 0, -1)  # n_rows, n_rows - 1, ..., 1
    per_row = np.count_nonzero(matrix == 'O', axis=1)
    # Cast so callers get a plain Python int, as with a summed loop.
    return int(np.dot(per_row, weights))

# Read the grid, then compute both answers.
with open(input_file, 'r', encoding="utf-8") as f:
    # Drop blank lines (e.g. a trailing empty line) so every row has the same
    # width and np.array can build a 2-D grid instead of a ragged one.
    lines = [line.strip() for line in f if line.strip()]
matrix = np.array([list(line) for line in lines], dtype=str)

# Part 1: load after a single north tilt. tilt_north works in place, so tilt
# a copy and leave `matrix` untouched as the starting grid for part 2.
ans1 = get_load(tilt_north(matrix.copy()))

# Part 2: load after MAX_ROUNDS spin cycles. Once a grid repeats, the
# sequence is periodic, so only the leftover spins need to be run.
MAX_ROUNDS = 1000000000
# Start from an empty table. Stale entries from an earlier execution could
# report a bogus period (0 on an identical re-run -> ZeroDivisionError).
memory.clear()
for spin_round in range(1, MAX_ROUNDS + 1):  # rounds 1..MAX_ROUNDS inclusive
    matrix = spin(matrix)
    mat_key = get_key(matrix)
    if mat_key in memory:
        recurrence_start = memory[mat_key]
        recurrence_interval = spin_round - recurrence_start
        n_cycles_left = MAX_ROUNDS - spin_round
        # Whole periods leave the grid unchanged; run only the remainder.
        for _ in range(n_cycles_left % recurrence_interval):
            matrix = spin(matrix)
        break
    memory[mat_key] = spin_round

ans2 = get_load(matrix)

print(f"{ans1 = }")
print(f"{ans2 = }")


# Recorded output: ans1 = 107430
# Recorded output: ans2 = 96317
