In [1]:
from tqdm import tqdm

In [2]:
# Input selection: the worked example from the puzzle text, or the real puzzle input.
USE_SAMPLE = True
filename = "sample.txt" if USE_SAMPLE else "input.txt"
with open(filename, encoding="utf-8") as f:
    data = f.read()

# One string per row, so grid[y][x] is the letter at column x of row y.
grid = data.strip().split("\n")
# Every later cell takes the grid width from grid[0]; a ragged row (e.g. stray
# trailing whitespace on one line) would silently skew the counts, so fail loudly.
assert all(len(row) == len(grid[0]) for row in grid), "grid rows must all be the same length"

Puzzle: [Advent of Code 2024 - Day 4: Ceres Search](https://adventofcode.com/2024/day/4)

In [3]:
## Part 1
# Count every occurrence of "XMAS" in the word search. Matches may run
# forwards, backwards, vertically or diagonally.
# Positions are complex numbers x + y*1j (real = column, imag = row), so the
# 8 compass directions (N, E, S, W, NE, SE, SW, NW) are every non-zero sum of
# a horizontal step from {1, 0, -1} and a vertical step from {1j, 0, -1j}.
step_directions = [dx + dy for dx in (1, 0, -1) for dy in (1j, 0, -1j) if dx or dy]
step_directions

[(1+1j), 1, (1-1j), 1j, -1j, (-1+1j), -1, (-1-1j)]

In [4]:
def is_xmas(s: str, *, xmas="XMAS") -> bool:
    """Return True when ``s`` equals ``xmas`` read either forwards or backwards."""
    accepted = {xmas, xmas[::-1]}
    return s in accepted

def line_from(pos: complex, direction: complex, line_length: int = 4) -> list[complex]:
    """Return ``line_length`` consecutive positions starting at ``pos``.

    Element ``i`` of the result is ``pos + i * direction``. No bounds checking
    is done here; callers validate positions against the grid.
    """
    offsets = range(line_length)
    return [pos + offset * direction for offset in offsets]

def lines_from(pos: complex, line_length: int = 4) -> list[list[complex]]:
    """Return one candidate line starting at ``pos`` for each entry of ``step_directions``.

    The result has one list of ``line_length`` positions per direction, in the
    same order as ``step_directions``.
    """
    candidates = []
    for direction in step_directions:
        candidates.append(line_from(pos, direction, line_length))
    return candidates

def letter_at(pos: complex, grid):
    """Return the entry of ``grid`` at ``pos`` (real part = column, imag part = row).

    Negative components index from the end, as normal Python indexing does;
    callers that need strict bounds must check them first.
    """
    row, col = int(pos.imag), int(pos.real)
    return grid[row][col]

def line_is_xmas(line: list[complex], grid, *, xmas="XMAS") -> bool:
    """Return True if the letters of ``grid`` along ``line`` spell ``xmas`` (either way round).

    Positions are complex numbers (real = column, imag = row). A line that
    leaves the grid on any side is never a match.
    """
    # Python would silently wrap negative indices to the opposite edge, so
    # reject them up front instead of relying on IndexError.
    if any(p.real < 0 or p.imag < 0 for p in line):
        return False

    letters = []
    for p in line:
        try:
            letters.append(letter_at(p, grid))
        except IndexError:
            # Past the bottom or right edge of the grid.
            return False
    return is_xmas("".join(letters), xmas=xmas)

In [5]:
def print_xmas_map(grid, keep_map: set[complex], fill_char="."):
    """Print ``grid`` showing only the cells listed in ``keep_map``.

    Cells are addressed as ``x + y*1j`` (real = column, imag = row). Every cell
    not in ``keep_map`` is printed as ``fill_char``. The column count is taken
    from the first row, so the grid is expected to be rectangular.
    """
    for y in range(len(grid)):
        cells = []
        for x in range(len(grid[0])):
            here = x + y * 1j
            cells.append(letter_at(here, grid) if here in keep_map else fill_char)
        print(*cells, sep="")

In [6]:
# Every XMAS starts on an X, so scan for X's and, from each one, test the
# 4-letter line in all 8 directions. Each matching line adds one to the score
# and its cells are recorded in xmas_map for the visualisation below.
xmas_score = 0
xmas_map = set()  # positions that belong to at least one XMAS

for y in tqdm(range(len(grid))):
    for x in range(len(grid[0])):
        if grid[y][x] != "X":
            continue
        pos = x + y * 1j
        for line in lines_from(pos):
            if not line_is_xmas(line, grid):
                continue
            xmas_score += 1
            xmas_map.update(line)

xmas_score

100%|██████████| 10/10 [00:00<00:00, 14716.86it/s]


18

In [7]:
# Show only the letters that belong to some XMAS; every other cell is printed as ".".
print_xmas_map(grid, keep_map=xmas_map)

....XXMAS.
.SAMXMS...
...S..A...
..A.A.MS.X
XMASAMX.MM
X.....XA.A
S.S.S.S.SS
.A.A.A.A.A
..M.M.M.MM
.X.X.XMASX


In [8]:
## Part 2
# Count X-shaped MAS patterns: two "MAS" words crossing on a shared centre A.
# Each diagonal may read forwards ("MAS") or backwards ("SAM"), e.g.
# M.S
# .A.
# M.S
# Offsets from the centre A for each diagonal of the X (rows grow downwards):
cross_lines = [
    (-1 - 1j, 0, 1 + 1j),  # top-left -> centre -> bottom-right
    (1 - 1j, 0, -1 + 1j),  # top-right -> centre -> bottom-left
]

def cross_at(pos: complex) -> list[list[complex]]:
    """Return the two diagonals of the 3x3 X centred on ``pos`` as lists of positions.

    Each diagonal is ``pos`` shifted by the offsets in the matching entry of
    ``cross_lines``.
    """
    diagonals = []
    for offsets in cross_lines:
        diagonals.append([pos + offset for offset in offsets])
    return diagonals

# Every X-MAS is centred on an A, so check both diagonals through each A.
mas_score = 0
mas_map = set()  # positions that belong to at least one X-MAS

for y in tqdm(range(len(grid))):
    for x in range(len(grid[0])):
        if grid[y][x] != "A":
            continue
        pos = x + y * 1j
        cross = cross_at(pos)
        # Both diagonals must independently read MAS or SAM.
        if not all(line_is_xmas(diagonal, grid, xmas="MAS") for diagonal in cross):
            continue
        mas_score += 1
        mas_map.update(*cross)

mas_score

100%|██████████| 10/10 [00:00<00:00, 30885.89it/s]


9

In [9]:
# Show only the letters that belong to some X-MAS; every other cell is printed as ".".
print_xmas_map(grid, keep_map=mas_map)

.M.S......
..A..MSMS.
.M.S.MAA..
..A.ASMSM.
.M.S.M....
..........
S.S.S.S.S.
.A.A.A.A..
M.M.M.M.M.
..........
