In [1]:
import numpy as np
from numba import jit
import time

# Step 1: find the longest cyclic runs of zeros in every row and column

In [2]:
@jit(nopython=True)
def cleanup(result, n):
        """Turn the zero-run boundaries of one line into ``[start, length]`` pairs.

        Parameters
        ----------
        result : 2-D array of ``(start, end)`` index pairs, one per run
            (``end - start`` is used as the run length below).
        n : length of the line the runs were found in.

        Returns
        -------
        A list of ``[start, length]`` lists. ``[[0, 0]]`` is returned as a
        placeholder when there are no runs, or when a single run spans the
        whole line. A run touching index 0 and a run ending at ``n`` are merged
        into one run that keeps the start of the last run (cyclic wrap-around).
        """
        if len(result) == 0:
            return [[0, 0]] # No runs: return a placeholder pair (original note: an empty list is a problem for Numba)

        if len(result) == 1 and result[0][1] - result[0, 0] == n:
            return [[0, 0]] # A single run covering the entire line is reported as "no run" so the line is skipped (original note: otherwise later steps break)

        # Convert (start, end) into [start, length].
        to_return = [[entry[0], entry[1] - entry[0]] for entry in result]

        if len(to_return) > 1:
            # First run starts at index 0 and last run ends at n: on a cyclic
            # line they are the same run, so fold the first into the last.
            if (to_return[0][0] == 0) and (to_return[-1][0] + to_return[-1][1] == n):
                to_return[-1][1] += to_return[0][1]
                to_return.pop(0)
        return to_return

@jit(nopython=True)
def count(a):
    """Locate the runs of zeros in every row and every column of a square matrix.

    Parameters
    ----------
    a : square 2-D array (asserted); entries that are falsy count as zeros.

    Returns
    -------
    (rows, cols) : two lists with one entry per row / per column, each entry
        being the output of ``cleanup`` for that line.
    """
    assert a.shape[0] == a.shape[1]
    n = a.shape[0]

    row_results = []
    col_results = []

    # Indicator matrix (1 where `a` is zero) surrounded by a border of 0s, so
    # every run of zeros has both an opening and a closing transition.
    padded = np.zeros((n+2, n+2))
    padded[1:-1, 1:-1] = np.logical_not(a)

    # 1 wherever two neighbouring cells of the indicator differ, i.e. at the
    # start and at the (exclusive) end of each run of zeros.
    row_switch = np.abs(padded[1:-1, 1:] - padded[1:-1, :-1])
    col_switch = np.abs(padded[1:, 1:-1] - padded[:-1, 1:-1])

    for i in range(n):
        # Transitions alternate start/end, so consecutive indices pair up
        # into (start, end) rows.
        row_results.append(np.where(row_switch[i] == 1)[0].reshape(-1, 2))
        col_results.append(np.where(col_switch[:, i] == 1)[0].reshape(-1, 2))

    row_results_cleanup = [] # Separate output lists (original note: Numba cannot overwrite the entries of the lists above)
    col_results_cleanup = []

    for i in range(n):
        row_results_cleanup.append(cleanup(row_results[i], n))
        col_results_cleanup.append(cleanup(col_results[i], n))
    
    return row_results_cleanup, col_results_cleanup


def find_largest_number_of_zeros(count_results):
    """Pick the lines that contain the globally longest run of zeros.

    Parameters
    ----------
    count_results : sequence with one entry per line (row or column); each
        entry is a non-empty sequence of ``[start, length]`` pairs.

    Returns
    -------
    list of ``(line_index, starts)`` tuples, one for every line whose longest
    run is as long as the longest run found anywhere. ``starts`` is an array
    with the start index of each run of that length in the line.
    """
    longest_per_line = []
    overall_longest = 0

    for runs in count_results:
        runs = np.array(runs)
        lengths = runs[:, 1]
        # Keep every run of this line that ties for the line's maximum length.
        line_longest = runs[lengths == np.max(lengths)]
        longest_per_line.append(line_longest)
        if line_longest[0, 1] > overall_longest:
            overall_longest = line_longest[0, 1]

    return [
        (line_index, line_longest[:, 0])
        for line_index, line_longest in enumerate(longest_per_line)
        if line_longest[0, 1] == overall_longest
    ]

# Step 2: among those runs, keep the ones whose line reads as the smallest binary number

In [3]:
def array_to_number(array):
    """Read a 1-D sequence of digits as a base-2 number, first element most significant.

    The value is accumulated with Python integers, so it stays exact for lines
    of any length. (Building the powers of two with ``2**np.arange(len(array))``
    wraps around silently once the line is longer than the integer dtype can
    hold, which produced wrong results for wide matrices.)

    Parameters
    ----------
    array : 1-D array-like of 0/1 values (bools are accepted as well).

    Returns
    -------
    The number ``sum(array[k] * 2**(len(array) - 1 - k))``; ``0`` for an empty
    input.
    """
    number = 0
    # Horner's scheme: shift what we have one bit to the left, then add the next digit.
    for digit in np.asarray(array).tolist():
        number = 2 * number + digit
    return number

def take_lowest_largest_number_of_zeros(a, largest_number_of_zeros, row_col="row"):
    """Among the candidate runs, keep those whose line reads as the smallest number.

    Each candidate line is rolled so that the candidate run starts at index 0
    and is then converted with ``array_to_number``.

    Parameters
    ----------
    a : 2-D array the candidates refer to.
    largest_number_of_zeros : iterable of ``(line_index, starts)`` tuples.
    row_col : ``"row"`` to treat ``line_index`` as a row index; any other
        value treats it as a column index.

    Returns
    -------
    list of ``(line_index, start)`` tuples that tie for the smallest number,
    in the order they were encountered.
    """
    use_rows = row_col == "row"
    # Largest value an all-ones line of length a.shape[0] can take.
    best_value = 2 ** a.shape[0] - 1
    winners = []

    for line_index, starts in largest_number_of_zeros:
        line = a[line_index] if use_rows else a[:, line_index]
        for start in starts:
            value = array_to_number(np.roll(line, -start))
            if value > best_value:
                continue
            if value < best_value:
                # Strictly better: everything collected so far is obsolete.
                best_value = value
                winners = []
            winners.append((line_index, start))

    return winners

# Steps 3 and 4: break remaining ties by comparing neighbouring rows/columns

In [12]:
def test_neighbours(a, lowest_largest_number_results, row_col="row"):
    """Break ties between candidates by comparing the lines next to them.

    Every candidate ``(line_index, start)`` gets its own copy of ``a`` rolled
    by ``-start`` (along axis 1 for rows, axis 0 for columns). Round ``i``
    converts line ``line_index - i`` of each copy with ``array_to_number``
    (negative indices wrap around) and keeps only the candidates with the
    smallest value. As soon as a single candidate remains it is returned.

    A candidate eliminated in one round never takes part in a later round.
    Previously a candidate that had been the provisional best and was then
    beaten within the same round stayed in the pool, so the outcome depended
    on the order of the candidates.

    Parameters
    ----------
    a : square 2-D array.
    lowest_largest_number_results : list of ``(line_index, start)`` tuples.
    row_col : ``"row"`` for row candidates; any other value means columns.

    Returns
    -------
    A one-element list holding the winning candidate, or the unchanged input
    list when the tie is still unresolved after ``n - 1`` rounds.
    """
    n = a.shape[0]
    use_rows = row_col == "row"
    rolled_as = [
        np.roll(a, -entry[1], axis=1 if use_rows else 0)
        for entry in lowest_largest_number_results
    ]

    # Indices (into the candidate list) that are still in the running.
    surviving = list(range(len(rolled_as)))

    for i in range(n - 1):
        lowest = None
        results = []
        for j in surviving:
            rolled = rolled_as[j]
            row_col_index = lowest_largest_number_results[j][0]
            array = rolled[row_col_index - i] if use_rows else rolled[:, row_col_index - i]
            number = array_to_number(array)
            if lowest is None or number < lowest:
                lowest = number
                results = [j]
            elif number == lowest:
                results.append(j)
        if len(results) == 1:
            return [lowest_largest_number_results[results[0]]]
        # Only the candidates tied for this round's minimum go on.
        surviving = results

    return lowest_largest_number_results

# Step 5: brute-force fallback over all cyclic shifts

In [5]:
def bruteforce(a):
    """Try every cyclic shift of ``a`` and return the smallest one.

    "Smallest" means the matrix whose flattened entries, concatenated into a
    string, compare lowest. ``a.shape[0]`` shifts are tried along each axis.
    The input array is not modified; a copy of the best shift is returned.
    """
    size = a.shape[0]

    def flat_key(matrix):
        # Row-major concatenation of the entries, used for string comparison.
        return "".join(str(value) for value in matrix.flatten())

    best = a.copy()
    best_key = flat_key(a)
    current = a

    for _ in range(size):
        for _ in range(size):
            current = np.roll(current, 1, axis=1)
            key = flat_key(current)
            if key < best_key:
                best, best_key = current.copy(), key
        # After a full cycle of column shifts, move on to the next row shift.
        current = np.roll(current, 1, axis=0)

    return best

# Combined

In [6]:
def preprocess(a):
    """Shift ``a`` cyclically into a position meant to be the same for all shifts of ``a``.

    For the rows and then for the columns, the candidates from
    ``find_largest_number_of_zeros`` are narrowed by
    ``take_lowest_largest_number_of_zeros`` and, if more than one remains, by
    ``test_neighbours``. A unique candidate gives the shift along that
    direction. If either direction stays ambiguous, the result of
    ``bruteforce(a)`` is returned instead.

    Parameters
    ----------
    a : square 2-D array (``count`` asserts this).

    Returns
    -------
    A rolled copy of ``a``.
    """
    zero_runs = dict(zip(("row", "col"), count(a)))
    shifts = {}

    for direction in ("row", "col"):
        longest = find_largest_number_of_zeros(zero_runs[direction])
        candidates = take_lowest_largest_number_of_zeros(a, longest, direction)
        if len(candidates) > 1:
            candidates = test_neighbours(a, candidates, direction)
            if len(candidates) > 1:
                # Still tied: give up on the fast path.
                return bruteforce(a)
        # Roll backwards so the winning run starts at index 0.
        shifts[direction] = -candidates[0][1]

    rolled_rows = np.roll(a, shifts["col"], axis=0)
    return np.roll(rolled_rows, shifts["row"], axis=1)


# Testing

In [13]:
n = 4

def random_row_col(a):
    """Return ``a`` rolled cyclically by a random row offset and a random column offset."""
    shift_row = np.random.randint(0, a.shape[0])
    shift_col = np.random.randint(0, a.shape[1])
    rolled_rows = np.roll(a, shift_row, axis=0)
    return np.roll(rolled_rows, shift_col, axis=1)

# Invariance check: for 1000 random 0/1 matrices, 1000 randomly shifted copies
# of each must be mapped by preprocess to the same matrix as the original.
done = 0
for o in range(1000):
    a = np.random.choice([0, 1], (n, n))
    min_a = preprocess(a)
    for p in range(1000):
        b = random_row_col(a)
        # On failure the message shows the expected matrix, the shifted input and its result.
        assert np.all(min_a == preprocess(b)), "\n" + str(min_a) +"\n" + str(b) + "\n" + str(preprocess(b))
        done += 1
    print(f"Progress: {done/1000**2*100:0.2f}%\r", end="")


Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skipped lol
Skip

KeyboardInterrupt: 

In [None]:
# Average wall-clock time per call on random 10x10 matrices, first for
# preprocess and then for bruteforce. Matrix generation is part of the timed loop.
# NOTE(review): if preprocess has not been called before this cell, the first
# number presumably also includes Numba's compilation of count/cleanup — confirm.
for benchmarked in (preprocess, bruteforce):
    start = time.time()

    for _ in range(10000):
        a = np.random.choice([0, 1], (10, 10))
        benchmarked(a)

    print(f"{(time.time() - start)*1000/10000:0.3f} ms per run")

0.516 ms per run
3.899 ms per run
