# Day 3: Gear Ratios

https://adventofcode.com/2023/day/3

## Part 1

In [1]:
from typing import List, Tuple
import re
import numpy as np

In [2]:
class Grid:
    """Per-row lookup of column intervals that lie next to a symbol.

    For every symbol at ``(i, j)`` the inclusive column interval
    ``[j - 1, j + 1]`` is recorded on rows ``i - 1``, ``i`` and ``i + 1``
    (clipped to the grid's row range). A number on row ``r`` is adjacent to
    a symbol when any of its digit columns falls inside one of row ``r``'s
    intervals.
    """

    def __init__(self, num_rows: int):
        # One entry per row: None (no symbol nearby), a list of
        # (col_min, col_max) tuples while building, or a (k, 2) ndarray
        # after optimize().
        self.rows = [None] * num_rows

    def add_symbol(self, i: int, j: int) -> None:
        """Register a symbol at row ``i``, column ``j``."""
        row_min, row_max = max(0, i - 1), min(len(self) - 1, i + 1)
        # The right bound is not clipped: columns past the end of a line
        # never contain digits, so overshooting is harmless.
        col_min, col_max = max(0, j - 1), j + 1
        for r in range(row_min, row_max + 1):
            row = self.rows[r]
            if row is None:
                row = []
                self.rows[r] = row
            row.append((col_min, col_max))

    def is_adjacent(self, row_idx: int, span: Tuple[int, int]) -> bool:
        """Return True if the number at ``span`` on ``row_idx`` touches a symbol.

        Args:
            row_idx: row of the number.
            span: half-open ``(start, end)`` column span, as returned by
                ``re.Match.span()``.
        """
        row_intervals = self.rows[row_idx]
        if row_intervals is None:
            return False
        start, end = span
        if end <= start:  # an empty span has no digit columns
            return False
        # Accept both the list form (before optimize) and the ndarray form.
        iv_start, iv_end = np.asarray(row_intervals).reshape(-1, 2).T
        # Digit columns are [start, end - 1]. Two inclusive ranges overlap
        # iff each one starts no later than the other one ends. This checks
        # the whole span in one vectorised pass instead of column by column.
        return bool(np.any((iv_start <= end - 1) & (iv_end >= start)))

    def intersects_intervals(self, intervals: np.ndarray, i: int) -> bool:
        """Return True if column ``i`` lies in any inclusive ``(start, end)`` interval."""
        start, end = np.asarray(intervals).reshape(-1, 2).T
        return bool(np.any((i >= start) & (i <= end)))

    def optimize(self) -> None:
        """Convert each row's interval list into an ndarray for fast lookups."""
        for i, row in enumerate(self.rows):
            if row is None:
                continue
            self.rows[i] = np.array(row)

    def __len__(self) -> int:
        return len(self.rows)

In [3]:
# Example engine schematic from the puzzle statement, one string per row.
test_lines = '''
467..114..
...*......
..35..633.
......#...
617*......
.....+.58.
..592.....
......755.
...$.*....
.664.598..
'''.split()

In [4]:
# Part 1 patterns: a "symbol" is any character that is neither '.' nor a
# digit; a number is a maximal run of digits.
reg_symbol, reg_num = re.compile(r'([^\.\d])'), re.compile(r'(\d+)')

def make_grid(lines: List[str]) -> Grid:
    """Build a Grid recording the neighbourhood of every symbol in *lines*.

    A symbol is any match of the global ``reg_symbol`` pattern. The name is
    looked up when this function is called, so whichever definition of
    ``reg_symbol`` was executed most recently is the one used.
    """
    result = Grid(len(lines))
    for row_idx, text in enumerate(lines):
        for match in reg_symbol.finditer(text):
            result.add_symbol(row_idx, match.start())
    result.optimize()
    return result


def sum_adjacent_nums(lines: List[str], grid: Grid) -> int:
    """Sum every number in *lines* that touches a symbol recorded in *grid*.

    Numbers are the digit runs matched by ``reg_num``. Adjacency is
    delegated to ``grid.is_adjacent`` using each match's half-open span.
    """
    return sum(
        int(match.group())
        for row_idx, text in enumerate(lines)
        for match in reg_num.finditer(text)
        if grid.is_adjacent(row_idx, match.span())
    )

In [5]:
def solve_1(lines: List[str]) -> int:
    """Part 1: sum of all numbers adjacent to at least one symbol."""
    return sum_adjacent_nums(lines, make_grid(lines))

In [6]:
# Run Part 1 on the puzzle's example schematic.
test_lines = '''
467..114..
...*......
..35..633.
......#...
617*......
.....+.58.
..592.....
......755.
...$.*....
.664.598..
'''.split()
solve_1(test_lines)

4361

In [7]:
inputs_path = 'inputs/3.txt'

# Strip each row: otherwise the trailing '\n' would match the symbol
# pattern `[^\.\d]` and be treated as a symbol.
with open(inputs_path, 'r') as f:
    lines = [raw.strip() for raw in f]

solve_1(lines)

532428

## Part 2

In [8]:
class Grid2:
    """Per-row lookup from column intervals to the numbers that touch them.

    For every number occupying half-open columns ``[start, end)`` on row
    ``i``, the inclusive interval ``[start - 1, end]`` is recorded. That is
    the number plus one column of margin on each side, with the left edge
    clipped at 0. It is recorded on rows ``i - 1``, ``i`` and ``i + 1``
    (clipped to the grid), together with the number itself. A cell
    ``(i, j)`` is adjacent to a number when ``j`` lies inside one of that
    number's intervals on row ``i``.
    """

    def __init__(self, num_rows: int):
        # Parallel per-row containers: rows[r][k] is the interval belonging
        # to the number row_nums[r][k]. Each entry is None (nothing nearby),
        # a list while building, or an ndarray after optimize().
        self.rows = [None] * num_rows
        self.row_nums = [None] * num_rows

    def add_num(self, num: int, i: int, span: Tuple[int, int]) -> None:
        """Register ``num`` occupying half-open columns ``span`` on row ``i``."""
        row_min, row_max = max(0, i - 1), min(len(self) - 1, i + 1)
        start, end = span
        # ``end`` is exclusive, so it already names the column just right of
        # the last digit.
        col_min, col_max = max(0, start - 1), end
        for r in range(row_min, row_max + 1):
            if self.rows[r] is None:
                self.rows[r] = []
                self.row_nums[r] = []
            self.rows[r].append((col_min, col_max))
            self.row_nums[r].append(num)

    def get_adjacent_nums(self, i: int, j: int) -> np.ndarray:
        """Return the numbers whose neighbourhood covers cell ``(i, j)``.

        Always returns a 1-D array, which is empty when nothing is adjacent.
        Callers can therefore apply ``len`` to the result unconditionally.
        """
        row_intervals = self.rows[i]
        if row_intervals is None:
            return np.array([], dtype=int)
        inds = self.interval_intersection_inds(row_intervals, j)
        # asarray so this also works before optimize() has been called.
        return np.asarray(self.row_nums[i])[inds]

    def interval_intersection_inds(self, intervals: np.ndarray, i: int) -> np.ndarray:
        """Return indices of the inclusive ``(start, end)`` intervals containing ``i``."""
        start, end = np.asarray(intervals).reshape(-1, 2).T
        mask = (i >= start) & (i <= end)
        return np.where(mask)[0]

    def optimize(self) -> None:
        """Convert the per-row lists into ndarrays for vectorised lookups."""
        for i, row in enumerate(self.rows):
            if row is None:
                continue
            self.rows[i] = np.array(row)
        for i, row in enumerate(self.row_nums):
            if row is None:
                continue
            self.row_nums[i] = np.array(row)

    def __len__(self) -> int:
        return len(self.rows)
    

In [9]:
# Example engine schematic from the puzzle statement, one string per row.
test_lines = '''
467..114..
...*......
..35..633.
......#...
617*......
.....+.58.
..592.....
......755.
...$.*....
.664.598..
'''.split()

In [10]:
# Part 2 only looks at '*' cells (gear candidates). The pattern gets its own
# name because make_grid (Part 1) reads the global `reg_symbol` at call time.
# Rebinding `reg_symbol` to '*' here would make solve_1 silently count only
# '*' symbols if it were re-run after this cell.
reg_gear = re.compile(r'(\*)')
reg_num = re.compile(r'(\d+)')


def make_grid2(lines: List[str]) -> Grid2:
    """Build a Grid2 recording the neighbourhood of every number in *lines*."""
    grid = Grid2(len(lines))
    for i, line in enumerate(lines):
        for m in reg_num.finditer(line):
            grid.add_num(int(m.group()), i, m.span())
    grid.optimize()
    return grid


def calc_gear_ratio(lines: List[str], grid: Grid2) -> int:
    """Sum the gear ratios of all gears in *lines*.

    A gear is a ``'*'`` cell adjacent to exactly two numbers. Its ratio is
    the product of those two numbers.

    Args:
        lines: the schematic, one string per row.
        grid: number neighbourhoods, as built by ``make_grid2(lines)``.

    Returns:
        The total as a plain Python ``int``.
    """
    total = 0
    for i, line in enumerate(lines):
        for m in reg_gear.finditer(line):
            adj_nums = grid.get_adjacent_nums(i, m.start())
            # np.size rather than len: it also accepts a bare False "nothing
            # adjacent" result (size 1, hence skipped). Some versions of
            # get_adjacent_nums return that for rows with no nearby numbers.
            if np.size(adj_nums) == 2:
                a, b = adj_nums
                # Multiply as Python ints so the total cannot overflow a
                # fixed-width numpy integer.
                total += int(a) * int(b)
    return total

In [11]:
def solve_2(lines: List[str]) -> int:
    """Part 2: sum of the products of the two numbers next to each gear '*'."""
    return calc_gear_ratio(lines, make_grid2(lines))

In [12]:
# Run Part 2 on the puzzle's example schematic.
test_lines = '''
467..114..
...*......
..35..633.
......#...
617*......
.....+.58.
..592.....
......755.
...$.*....
.664.598..
'''.split()
solve_2(test_lines)

467835

In [13]:
inputs_path = 'inputs/3.txt'

# Strip each row so line terminators never become part of the grid.
with open(inputs_path, 'r') as f:
    lines = [raw.strip() for raw in f]

solve_2(lines)

84051670