In [1]:
import math
from itertools import product

# Read the whole grid file as text.
# NOTE: the name `input` shadows the Python builtin of the same name; it is kept
# because the later cells refer to the parsed grid by this name.
with open("input.txt", "r") as f:
    input = f.read()

# Small hard-coded grid in the same format as the file; not referenced by the
# cells visible in this notebook.
input2 = """............
........0...
.....0......
.......0....
....0.......
......A.....
............
............
........A...
.........A..
............
............"""

# Turn the text into a list of rows, each row a list of single characters.
# splitlines() is used instead of split("\n") so that a trailing newline does not
# add an empty last row and Windows line endings do not leave a "\r" in every row
# (any character other than "." is later treated as an antenna).
input = [list(line) for line in input.splitlines()]

In [2]:
# get all antennas and their locations
def get_antenna_dict(input: list[list[str]]) -> dict:
    """Group the positions of all non-"." cells by the character they hold.

    Args:
        input: Grid given as a list of rows, each row a list of characters.

    Returns:
        A dict mapping each character to the list of ``(row, col)`` index
        pairs where it occurs, in the order the grid is scanned (row by row,
        left to right).
    """
    positions = {}
    for row_idx, line in enumerate(input):
        for col_idx, symbol in enumerate(line):
            if symbol != ".":
                # First sighting creates the list; later ones extend it.
                positions.setdefault(symbol, []).append((row_idx, col_idx))
    return positions

def get_distance(node1: tuple, node2: tuple) -> float:
    """Return the Euclidean distance between two 2-element points.

    Only the first two components of each argument are read.
    """
    delta_first = node1[0] - node2[0]
    delta_second = node1[1] - node2[1]
    squared_length = delta_first**2 + delta_second**2
    return math.sqrt(squared_length)

def get_dir(node1: tuple, node2: tuple) -> tuple:
    """Return the unit-length direction from ``node1`` towards ``node2``.

    The result is ordered as (difference of index 1, difference of index 0),
    each divided by the distance between the points. Because of that division
    a ZeroDivisionError is raised when the two points coincide.
    """
    length = get_distance(node1, node2)
    second_axis = (node2[1] - node1[1]) / length
    first_axis = (node2[0] - node1[0]) / length
    return second_axis, first_axis

def tuples_equals(tup1, tup2):
    """Return True when the first two components of both tuples agree.

    Each pair of components must differ by strictly less than 1e-6 in
    absolute value; the second pair is only examined if the first matches.
    """
    tolerance = 1e-6
    if not (abs(tup1[0] - tup2[0]) < tolerance):
        return False
    return abs(tup1[1] - tup2[1]) < tolerance

def is_antinode(antennas: dict, node: tuple, include_distance):
    """Report whether ``node`` is an antinode of some same-frequency antenna pair.

    A pair of two *distinct* antennas of one frequency makes ``node`` an
    antinode when ``node`` lies on the straight line through both antennas
    and, if ``include_distance`` is truthy, one antenna is exactly twice as
    far from ``node`` as the other. With ``include_distance`` truthy a pair is
    skipped when ``node`` sits on one of its two antennas.

    A frequency with a single antenna can never produce an antinode, so a lone
    antenna is not reported (not even at its own position).

    All checks use exact integer arithmetic (cross product and squared
    distances), so coordinates are expected to be integer grid indices.

    Args:
        antennas: Mapping of frequency -> list of ``(row, col)`` positions.
        node: ``(row, col)`` position to test.
        include_distance: Whether to enforce the 2:1 distance rule.

    Returns:
        True if at least one pair qualifies, otherwise False.
    """
    for locs in antennas.values():
        # Visit every unordered pair once; all conditions are symmetric in the pair.
        for idx, first in enumerate(locs):
            for second in locs[idx + 1:]:
                if first == second:
                    # A pair needs two different antennas.
                    continue

                # Vectors from the candidate node to each antenna.
                d_row1, d_col1 = first[0] - node[0], first[1] - node[1]
                d_row2, d_col2 = second[0] - node[0], second[1] - node[1]

                # Zero cross product <=> node and both antennas are collinear
                # (this also holds when node coincides with one antenna).
                if d_row1 * d_col2 - d_col1 * d_row2 != 0:
                    continue

                if not include_distance:
                    return True

                sq_dist1 = d_row1 * d_row1 + d_col1 * d_col1
                sq_dist2 = d_row2 * d_row2 + d_col2 * d_col2
                if sq_dist1 == 0 or sq_dist2 == 0:
                    # Node lies on an antenna of this pair: no 2:1 ratio possible.
                    continue

                # far == 2 * near, compared on squared lengths to stay in integers.
                near, far = sorted((sq_dist1, sq_dist2))
                if 4 * near == far:
                    return True
    return False

def all_anitnodes(input, antennas, include_distance=True) -> list:
    """Collect every grid position that ``is_antinode`` accepts.

    Args:
        input: Grid as a list of rows; only its dimensions are used here.
        antennas: Mapping of frequency -> list of ``(row, col)`` positions.
        include_distance: Forwarded unchanged to ``is_antinode``.

    Returns:
        List of ``(row, col)`` tuples in row-by-row, left-to-right scan order.
    """
    return [
        (row_idx, col_idx)
        for row_idx, line in enumerate(input)
        for col_idx, _ in enumerate(line)
        if is_antinode(antennas, (row_idx, col_idx), include_distance=include_distance)
    ]
            
# Part 1: index the antennas by frequency, then collect every grid cell that is
# an antinode with the distance rule enabled; the cell displays how many there are.
antennas = get_antenna_dict(input=input)
antinodes = all_anitnodes(input=input, antennas=antennas, include_distance=True)
len(antinodes)

280

### Part 2: antinodes without the distance rule

In [3]:
# Scan the grid again with the distance rule switched off, reusing the `antennas`
# mapping built in the previous cell; the cell displays the resulting count.
antinodes_ext = all_anitnodes(input=input, antennas=antennas, include_distance=False)
len(antinodes_ext)

958