In [None]:
from tabulate import tabulate

# Puzzle input paths; relative paths resolve against the kernel's working directory.
EXAMPLE, INPUT = "../example.txt", "../input.txt"

In [None]:
class Matrix:
    """A grid of single-character cells used for the XMAS word search.

    Rows are indexed by ``i`` (top to bottom) and columns by ``j`` (left to
    right); ``contents[i][j]`` is the character at that cell.
    """

    # The eight straight-line directions (row step, column step) in which a
    # word may be read: horizontal, vertical and both diagonals, each in both
    # senses.
    _DIRECTIONS = (
        (0, 1), (0, -1), (1, 0), (-1, 0),
        (1, 1), (1, -1), (-1, 1), (-1, -1),
    )

    def __init__(self, contents: list[list[str]]):
        """Wrap ``contents``, a list of rows where each row is a list of characters.

        An empty grid is accepted and has height and width 0.
        """
        self.contents = contents
        self.height = len(self.contents)
        # Width comes from the first row; bounds checks below use each row's
        # own length, so a ragged grid cannot cause an IndexError.
        self.width = len(self.contents[0]) if self.contents else 0
        # Cell coordinates per letter, filled by find_letter(). 'X' and 'A'
        # are pre-registered so they can be read before any search runs.
        self.letter_positions: dict[str, set[tuple[int, int]]] = {'X': set(), 'A': set()}

    @classmethod
    def build(cls, input_file_name):
        """Read a grid from a text file, one row per non-blank line.

        Surrounding whitespace (including the newline) is stripped from each
        line. Blank lines, e.g. trailing empty lines at end of file, are
        skipped so they do not become zero-length rows.
        """
        matrix = []
        with open(input_file_name, 'r') as f:
            for line in f:
                row = line.strip()
                if row:
                    matrix.append(list(row))
        return cls(matrix)

    def find_letter(self, letter: str) -> None:
        """Record the (i, j) coordinates of every cell equal to ``letter``.

        Results accumulate in ``self.letter_positions[letter]``, creating the
        entry if needed. They are stored in a set, so repeated calls are
        harmless.
        """
        positions = self.letter_positions.setdefault(letter, set())
        for i, row in enumerate(self.contents):
            for j, cell in enumerate(row):
                if cell == letter:
                    positions.add((i, j))

    def _in_bounds(self, i: int, j: int) -> bool:
        """Return True if (i, j) names an existing cell.

        Negative indices are rejected explicitly because Python indexing
        would otherwise silently wrap them around to the far edge.
        """
        return 0 <= i < self.height and 0 <= j < len(self.contents[i])

    def _read(self, i: int, j: int, di: int, dj: int, length: int) -> str:
        """Return ``length`` characters read from (i, j) stepping by (di, dj).

        Returns "" if any step of the run falls outside the grid.
        """
        chars = []
        for k in range(length):
            ii, jj = i + di * k, j + dj * k
            if not self._in_bounds(ii, jj):
                return ""
            chars.append(self.contents[ii][jj])
        return "".join(chars)

    def check_for_xmas(self, start_i, start_j) -> int:
        """Count the directions in which "XMAS" is spelled starting at (start_i, start_j).

        Returns 0 when that cell is not 'X'; otherwise a value in 0..8.
        """
        if self.contents[start_i][start_j] != 'X':
            return 0
        return sum(
            self._read(start_i, start_j, di, dj, 4) == "XMAS"
            for di, dj in self._DIRECTIONS
        )

    def check_for_mas_x(self, start_i, start_j) -> int:
        """Return 1 if the 'A' at (start_i, start_j) centres two crossing "MAS" words.

        Each diagonal may read "MAS" forwards or backwards. Returns 0 when the
        cell is not 'A' or either diagonal leaves the grid.
        """
        if self.contents[start_i][start_j] != 'A':
            return 0
        # Top-left -> bottom-right diagonal through the centre.
        diag_0 = self._read(start_i - 1, start_j - 1, 1, 1, 3)
        # Bottom-left -> top-right diagonal through the centre.
        diag_1 = self._read(start_i + 1, start_j - 1, -1, 1, 3)
        return 1 if diag_0 in ("MAS", "SAM") and diag_1 in ("MAS", "SAM") else 0

    def count_xmases(self) -> int:
        """Return the total number of "XMAS" occurrences in the grid."""
        self.find_letter('X')
        return sum(self.check_for_xmas(i, j) for i, j in self.letter_positions['X'])

    def count_mas_xes(self) -> int:
        """Return the number of 'A' cells that centre an X of two "MAS" words."""
        self.find_letter('A')
        return sum(self.check_for_mas_x(i, j) for i, j in self.letter_positions['A'])


In [None]:
matrix = Matrix.build(EXAMPLE)
# Render the parsed grid as a plain-text table to eyeball the parse.
example_table = tabulate(matrix.contents)
print(example_table)

In [None]:
# Locate every 'X'. Sorting gives a row-major listing that lines up with the
# table above, instead of the arbitrary iteration order of the underlying set.
matrix.find_letter('X')
print(sorted(matrix.letter_positions['X']))

In [None]:
# Per-'X' XMAS counts, labelled by position. find_letter only adds to a set,
# so re-running it is harmless. It keeps this cell from silently printing
# nothing when the previous cell has not been executed.
matrix.find_letter('X')
for i, j in sorted(matrix.letter_positions['X']):
    print((i, j), matrix.check_for_xmas(i, j))

In [None]:
xmas_total = matrix.count_xmases()
print(xmas_total)

In [None]:
def part_1(input_file_name):
    """Print how many times "XMAS" appears in the grid read from input_file_name."""
    grid = Matrix.build(input_file_name)
    total = grid.count_xmases()
    print(total)

In [None]:
part_1(input_file_name=EXAMPLE)

In [None]:
part_1(input_file_name=INPUT)

In [None]:
def part_2(input_file_name):
    """Print how many X-shaped pairs of "MAS" appear in the grid read from input_file_name."""
    grid = Matrix.build(input_file_name)
    total = grid.count_mas_xes()
    print(total)

In [None]:
part_2(input_file_name=EXAMPLE)

In [None]:
part_2(input_file_name=INPUT)