In [None]:
import numpy as np
import jupyter_black
from tqdm import trange


# Activate the jupyter-black extension for this kernel session.
# NOTE(review): presumably this auto-formats code cells when they are run, and
# lab=False targets the classic Notebook front end rather than JupyterLab —
# confirm against the jupyter-black documentation.
jupyter_black.load(lab=False)

In [None]:
# Read the whole puzzle input into one string; later cells split it into lines.
# The encoding is pinned so the result does not depend on the platform locale.
with open("day15.txt", "r", encoding="utf-8") as f:
    data = f.read()

In [None]:
class Range:
    """Half-open interval ``[start, end)`` of positions along one axis.

    ``start`` is included and ``end`` is excluded, so ``Range(3, 4)`` holds
    exactly one position and two ranges that merely touch share nothing.
    """

    def __init__(self, start, end):
        """Store the bounds; ``end`` must not be smaller than ``start``."""
        assert end >= start, f"end ({end}) must be >= start ({start})"
        self.start = start
        self.end = end

    def _does_overlap(self, other):
        """Return True when both ranges share at least one position.

        The comparisons are strict because ``end`` is exclusive: ranges that
        only touch (``a.end == b.start``) do not overlap.
        """
        return (self.start < other.end) and (other.start < self.end)

    def __sub__(self, other):
        """Return the parts of ``self`` that are not covered by ``other``.

        Returns:
            A list with zero, one or two ``Range`` objects, ordered left to
            right. ``self`` itself is returned (not copied) when the two
            ranges do not overlap.
        """
        # Fully swallowed by `other`: nothing is left.
        if (self.start >= other.start) and (self.end <= other.end):
            return []
        if not self._does_overlap(other):
            return [self]

        remaining = []
        if self.start < other.start:
            # Part of self sticks out to the left of `other`.
            remaining.append(Range(self.start, other.start))
        if self.end > other.end:
            # Part of self sticks out to the right of `other`.
            remaining.append(Range(other.end, self.end))
        return remaining

    def __repr__(self):
        return f"Range({self.start}, {self.end})"

    def __eq__(self, other):
        """Two ranges are equal when both bounds match."""
        if not isinstance(other, Range):
            # Let Python fall back to its default handling instead of raising
            # AttributeError on objects without start/end.
            return NotImplemented
        return (self.start == other.start) and (self.end == other.end)

    def __hash__(self):
        # Defining __eq__ alone would make instances unhashable; hash on the
        # same fields that equality compares.
        return hash((self.start, self.end))


# Sanity checks for Range subtraction (ranges are half-open: [start, end)).
assert (Range(1, 10) - Range(1, 8)) == [Range(8, 10)]  # cut from the left edge
assert (Range(1, 10) - Range(2, 10)) == [Range(1, 2)]  # cut up to the right edge
assert (Range(1, 10) - Range(3, 8)) == [Range(1, 3), Range(8, 10)]  # hole inside
assert (Range(1, 10) - Range(-3, 4)) == [Range(4, 10)]  # overhang on the left
assert (Range(1, 10) - Range(5, 20)) == [Range(1, 5)]  # overhang on the right
assert (Range(1, 10) - Range(11, 20)) == [Range(1, 10)]  # disjoint
assert (Range(1, 10) - Range(-7, 1)) == [Range(1, 10)]  # touching on the left
assert (Range(1, 10) - Range(10, 20)) == [Range(1, 10)]  # touching on the right
assert (Range(1, 10) - Range(1, 10)) == []  # identical
assert (Range(1, 10) - Range(0, 11)) == []  # fully swallowed

In [None]:
# First line of the input, apparently kept for trying out parse_line by hand;
# no other visible cell reads this module-level variable.
line = data.split("\n")[0]


def parse_line(line):
    """Parse one report line into sensor and beacon coordinates.

    Accepts lines such as
    ``Sensor at x=2, y=18: closest beacon is at x=-2, y=15``. Every
    whitespace-separated token containing ``=`` contributes one integer, so
    leading/trailing whitespace and repeated spaces are tolerated.

    Args:
        line: A single line of the puzzle input.

    Returns:
        A tuple ``(sensor, beacon)`` of two NumPy arrays, each ``[x, y]``.

    Raises:
        ValueError: If the line does not hold exactly four ``name=value``
            fields or one of the values is not an integer.
    """
    coordinates = [
        # Keep what follows the last "=" and drop the trailing "," or ":".
        int(token.split("=")[-1].rstrip(",:"))
        for token in line.split()
        if "=" in token
    ]
    if len(coordinates) != 4:
        raise ValueError(f"expected 4 coordinates in report line, got: {line!r}")

    x, y, b_x, b_y = coordinates
    return np.array([x, y]), np.array([b_x, b_y])


def manhattan(a, b):
    """Return the Manhattan (L1) distance between paired points.

    Args:
        a: NumPy array of points; coordinates lie along the last axis.
        b: NumPy array of points with a shape compatible with ``a``.

    Returns:
        The sum of absolute coordinate differences over the last axis: one
        value per row for 2-D inputs, a scalar for two single points.
    """
    # Work on the arguments, not on same-named notebook globals.
    return np.sum(np.abs(a - b), axis=-1)


def get_points_within_manh_dist_at_mod_y(dist, mod_y, x):
    """Return the x positions reachable on a row ``mod_y`` rows off-centre.

    Args:
        dist: Manhattan radius around the centre.
        mod_y: Vertical offset of the row relative to the centre.
        x: x coordinate of the centre.

    Returns:
        A set of x coordinates; empty when the row lies beyond ``dist``.
    """
    # Horizontal budget left after spending |mod_y| steps vertically.
    half_width = dist - abs(mod_y)
    return {offset + x for offset in range(-half_width, half_width + 1)}


def get_range_within_manh_dist_at_mod_y(dist, mod_y, x):
    """Return the span reachable on a row ``mod_y`` rows off-centre.

    Args:
        dist: Manhattan radius around the centre.
        mod_y: Vertical offset of the row relative to the centre.
        x: x coordinate of the centre.

    Returns:
        A list holding one half-open ``Range`` of reachable x positions, or an
        empty list when the row lies beyond ``dist``.
    """
    # Horizontal budget left after spending |mod_y| steps vertically.
    half_width = dist - abs(mod_y)
    if half_width < 0:
        return []
    # A budget of exactly 0 still reaches the single cell at x (the tip of the
    # diamond), so it must yield Range(x, x + 1) rather than nothing.
    return [Range(x - half_width, x + half_width + 1)]


def remove_range(target_ranges, rng):
    """Subtract ``rng`` from every entry of ``target_ranges``.

    Args:
        target_ranges: Iterable of objects supporting ``entry - rng``, where
            the subtraction yields an iterable of leftover pieces.
        rng: The range to remove.

    Returns:
        A new flat list with the leftover pieces of all entries, in order.
    """
    return [
        leftover
        for target_range in target_ranges
        for leftover in target_range - rng
    ]


def remove_entries(sensors, beacons, dist, target_ranges, target_row_idx):
    """Remove every occupied or sensor-covered span from ``target_ranges``.

    Args:
        sensors: Sequence of sensor positions, indexed as ``[x, y]``.
        beacons: Sequence of beacon positions, indexed as ``[x, y]``.
        dist: Per-sensor reach, paired with ``sensors`` by position.
        target_ranges: Ranges on the row that are still candidates.
        target_row_idx: y coordinate of the row being examined.

    Returns:
        The list of ranges left after all removals.
    """
    blocked = []
    for position, reach in zip(sensors, dist):
        # The sensor's own cell is blocked when it sits on the row.
        if position[1] == target_row_idx:
            blocked.append(Range(position[0], position[0] + 1))
        blocked.extend(
            get_range_within_manh_dist_at_mod_y(
                reach, target_row_idx - position[1], position[0]
            )
        )

    # Beacons located on the row block their own cell as well.
    blocked.extend(
        Range(beacon[0], beacon[0] + 1)
        for beacon in beacons
        if beacon[1] == target_row_idx
    )

    remaining = target_ranges
    for span in blocked:
        remaining = remove_range(remaining, span)
    return remaining


def pre_filter_sensors(sensors, dist, search_space_x, search_space_y):
    """Drop sensors whose reach cannot touch the search space.

    Args:
        sensors: 2-D NumPy array of sensor positions, one ``[x, y]`` per row.
        dist: 1-D NumPy array with the reach of each sensor.
        search_space_x: ``[low, high]`` bounds of the search space in x.
        search_space_y: ``[low, high]`` bounds of the search space in y.

    Returns:
        ``(sensors, dist)`` restricted to the sensors that were kept.
    """
    # A sensor reaches the columns x - dist .. x + dist inclusive (same for
    # rows). The comparisons are inclusive so that a sensor whose reach only
    # touches a border cell is kept; keeping one sensor too many is harmless,
    # dropping one would make a covered cell look free.
    fltr = (
        (sensors[:, 0] - dist <= search_space_x[1])
        & (sensors[:, 0] + dist >= search_space_x[0])
        & (sensors[:, 1] - dist <= search_space_y[1])
        & (sensors[:, 1] + dist >= search_space_y[0])
    )

    return sensors[fltr], dist[fltr]

In [None]:
# Parse every non-blank input line; skipping blanks keeps a trailing newline at
# the end of the file from reaching parse_line.
locations = [parse_line(line) for line in data.splitlines() if line.strip()]
sensors = np.array([sensor for sensor, _ in locations])
beacons = np.array([beacon for _, beacon in locations])
# dist[i] is the Manhattan distance between sensors[i] and beacons[i].
dist = manhattan(sensors, beacons)

In [None]:
target_row_idx = 2000000
target_row_entries = {}

# Paint the target row: "#" for every column within a sensor's reach. A sensor
# sitting on the row is tagged "S" first, but its own column is part of its
# reach, so the "#" written right afterwards replaces that tag.
for sensor, d in zip(sensors, dist):
    if sensor[1] == target_row_idx:
        target_row_entries[sensor[0]] = "S"
    reached_columns = get_points_within_manh_dist_at_mod_y(
        d, target_row_idx - sensor[1], sensor[0]
    )
    target_row_entries.update(dict.fromkeys(reached_columns, "#"))

# Beacons on the row are labelled last so they are left out of the count.
for beacon in beacons:
    if beacon[1] == target_row_idx:
        target_row_entries[beacon[0]] = "B"

# Number of columns on the row still marked as covered.
sum(1 for mark in target_row_entries.values() if mark == "#")

In [None]:
search_size = 4_000_000
# Factor applied to the x coordinate when combining x and y into the result.
x_multiplier = 4_000_000

# NOTE(review): Range and trange are both half-open, so column and row
# `search_size` itself are never examined — confirm this matches the intended
# bounds of the search.
search_space_x = np.array([0, search_size])
search_space_y = np.array([0, search_size])

# Bind the filtered arrays to new names: overwriting `sensors` / `dist` would
# make the earlier cells give different results when re-run.
sensors_nearby, dist_nearby = pre_filter_sensors(
    sensors, dist, search_space_x, search_space_y
)


# Scan row by row and stop at the first row with an uncovered column.
for target_row_idx in trange(search_space_y[0], search_space_y[1]):
    target_row_entries = [Range(search_space_x[0], search_space_x[1])]

    final_range = remove_entries(
        sensors_nearby, beacons, dist_nearby, target_row_entries, target_row_idx
    )

    if len(final_range) > 0:
        # Convert to Python ints first: the product can exceed a 32-bit NumPy
        # integer on platforms where that is the default integer type.
        free_x = int(final_range[0].start)
        print(f"result: {free_x * x_multiplier + int(target_row_idx)}")
        break