In [1]:
from collections import defaultdict

In [11]:
class SparseImg:
    """Infinite binary image stored sparsely as the set of non-background pixels.

    Pixels are addressed by ``(row, col)`` tuples.  ``grid`` holds only the
    pixels whose value differs from the infinite background:

    * ``inversed=False`` -- background is dark; ``grid`` holds the lit pixels.
    * ``inversed=True``  -- background is lit; ``grid`` holds the dark pixels.

    ``min_row``/``max_row``/``min_col``/``max_col`` bound every key ever stored
    in ``grid`` (all ``None`` while nothing has been stored).  They can be
    wider than the current ``grid`` if a stored pixel is later reset to the
    background value; that only enlarges the area ``enhance`` scans.
    """

    def __init__(self, inversed: bool = False):
        self.grid = set()
        self.inversed = inversed
        self.min_row = None
        self.max_row = None
        self.min_col = None
        self.max_col = None

    def __setitem__(self, key, value):
        """Set pixel ``key`` to lit (truthy ``value``) or dark (falsy ``value``)."""
        if bool(value) == bool(self.inversed):
            # Same value as the background: nothing to store, but drop any
            # earlier non-background value so later reads see the new one.
            self.grid.discard(key)
            return
        r, c = key
        if self.min_row is None:
            self.min_row, self.max_row = r, r
            self.min_col, self.max_col = c, c
        else:
            self.min_row = min(self.min_row, r)
            self.max_row = max(self.max_row, r)
            self.min_col = min(self.min_col, c)
            self.max_col = max(self.max_col, c)
        self.grid.add(key)

    def __getitem__(self, key):
        """Return True if pixel ``key`` is lit."""
        # Membership in ``grid`` means "differs from the background".
        return (key in self.grid) != bool(self.inversed)

    def enhance(self, algo: str, enhanced_is_inversed: 'bool | None' = None) -> 'SparseImg':
        """Apply one step of the image-enhancement algorithm.

        Each output pixel ``(r, c)`` looks up ``algo`` at the index formed by
        reading its 3x3 neighbourhood row by row, top-left first, as a 9-bit
        binary number (lit = 1, most significant bit first).  The output pixel
        is lit when that character is ``'#'``.

        Args:
            algo: enhancement lookup string; must have at least 512 characters.
            enhanced_is_inversed: whether the result's infinite background is
                lit.  When ``None`` (default) it is derived from ``algo``: the
                background's neighbourhood is all-lit (index 511) if it is
                currently lit, otherwise all-dark (index 0).  An explicit value
                is honoured as given; one that disagrees with ``algo`` yields
                an incorrect image.

        Returns:
            A new ``SparseImg``; ``self`` is not modified.

        Raises:
            ValueError: if ``algo`` has fewer than 512 characters.
        """
        if len(algo) < 512:
            raise ValueError(
                f'enhancement algorithm needs 512 entries, got {len(algo)}')
        if enhanced_is_inversed is None:
            enhanced_is_inversed = algo[511 if self.inversed else 0] == '#'
        new_img = self.__class__(inversed=enhanced_is_inversed)
        if self.min_row is None:
            # Nothing but background: the result is uniform, nothing to store.
            return new_img
        # Pixels two or more steps outside the stored bounds see only
        # background, so only a 1-pixel margin around the bounds can differ.
        for r in range(self.min_row - 1, self.max_row + 2):
            for c in range(self.min_col - 1, self.max_col + 2):
                idx = 0
                for rr in range(r - 1, r + 2):
                    for cc in range(c - 1, c + 2):
                        idx = (idx << 1) | self[(rr, cc)]
                new_img[(r, c)] = algo[idx] == '#'
        return new_img

    @classmethod
    def parse(cls, lines, *args, **kwargs):
        """Build an image from rows of ``'#'`` (lit) / other (dark) characters.

        Row index follows the order of ``lines`` and column index the character
        position, both starting at 0.  Extra arguments are forwarded to the
        constructor (e.g. ``inversed=True``).
        """
        input_img = cls(*args, **kwargs)
        for r, line in enumerate(lines):
            for c, char in enumerate(line):
                input_img[(r, c)] = char == '#'
        return input_img

    def __repr__(self):
        return str(self.grid)

    def __str__(self):
        """Render ``'normal'``/``'inversed'`` then the image with a 1-pixel border."""
        iv = 'inversed' if self.inversed else 'normal'
        if self.min_row is None:
            # No stored pixels, so there is no bounding box to draw.
            return iv
        rows = (
            ''.join('#' if self[(r, c)] else '.'
                    for c in range(self.min_col - 1, self.max_col + 2))
            for r in range(self.min_row - 1, self.max_row + 2)
        )
        return iv + '\n' + '\n'.join(rows)

    def lit_count(self):
        """Number of lit pixels, or ``float('inf')`` when the background is lit."""
        return float('inf') if self.inversed else len(self.grid)

    def __len__(self):
        """Number of lit pixels.

        Raises:
            OverflowError: if the background is lit (infinitely many lit
                pixels); ``len()`` only accepts ints, so use ``lit_count()``
                to get ``float('inf')`` instead.
        """
        if self.inversed:
            raise OverflowError(
                'image has infinitely many lit pixels; use lit_count()')
        return len(self.grid)

# Example

In [4]:

with open('./day20_example_input.txt', 'r') as f:
    lines = [line.strip() for line in f if line.strip()]

enhancement_algo, *input_img_str = lines
# Parse the same rows twice: once on a dark background, once on a lit one.
# `inversed` is passed explicitly so this cell works on a fresh kernel.
img = SparseImg.parse(input_img_str, inversed=False)
img_i = SparseImg.parse(input_img_str, inversed=True)
print('img')
print(str(img), end='\n\n')
print('img_i')
print(str(img_i), end='\n\n')

img
normal
.......
.#..#..
.#.....
.##..#.
...#...
...###.
.......

img
inversed
#######
##..#.#
##....#
###..##
#..#..#
#..####
#######



In [5]:
# One enhancement step of the dark-background example image; the result stays
# dark-background (assumes enhancement_algo[0] == '.' for the example — confirm).
img.enhance(algo=enhancement_algo, enhanced_is_inversed=False)

{(-1, 4), (-1, 0), (3, 0), (2, 1), (2, 5), (1, -1), (1, 2), (-1, 1), (3, 3), (1, 5), (2, 2), (0, 4), (4, 1), (5, 4), (4, 5), (2, -1), (5, 2), (4, 2), (1, 0), (-1, 3), (2, 0), (0, -1), (3, 4), (0, 2)}

In [6]:
# Exploration: enhance the lit-background parse of the example.
# NOTE(review): with a lit background the consistent flag is
# enhancement_algo[511] == '#'; False is kept as in the original experiment.
img_i.enhance(algo=enhancement_algo, enhanced_is_inversed=False)

{(3, 0), (2, 1), (0, 3), (1, -1), (1, 2), (-1, 1), (3, 3), (5, 5), (1, 5), (2, 2), (-1, -1), (3, 4), (4, 1), (5, -1), (5, 4), (4, 5), (5, 3), (0, 1), (3, 5), (0, -1), (5, 2), (0, 2)}

In [7]:
# Optional experiment: uncomment to force enhancement_algo[0] == '#', i.e. an
# all-dark 3x3 window (index 0) turns lit, so the example exercises the case
# where the infinite dark background flips to lit after one step.
# enhancement_algo = '#' + enhancement_algo[1:]
# enhancement_algo

In [30]:
def ex(path='./day20_example_input.txt', steps=2):
    """Print the example image and the image after each enhancement step.

    Args:
        path: puzzle-format input file (algorithm line, then image rows).
        steps: number of enhancement steps to apply; results are printed as
            ``img2``, ``img3``, ...
    """
    with open(path, 'r') as f:
        lines = [line.strip() for line in f if line.strip()]

    enhancement_algo, *input_img_str = lines
    img = SparseImg.parse(input_img_str, inversed=False)
    print('img')
    print(str(img), end='\n\n')

    # Track whether the infinite background is lit: it maps through the
    # algorithm at index 0 when dark and index 511 when lit.
    background_lit = False
    for step in range(2, steps + 2):
        background_lit = enhancement_algo[511 if background_lit else 0] == '#'
        img = img.enhance(enhancement_algo, enhanced_is_inversed=background_lit)
        print(f'img{step}')
        print(str(img), end='\n\n')

ex()

img
normal
.......
.#..#..
.#.....
.##..#.
...#...
...###.
.......

img2
normal
.........
..##.##..
.#..#.#..
.##.#..#.
.####..#.
..#..##..
...##..#.
....#.#..
.........

img3
normal
...........
........#..
..#..#.#...
.#.#...###.
.#...##.#..
.#.....#.#.
..#.#####..
...#.#####.
....##.##..
.....###...
...........



# Part 1

In [27]:
with open('./day20_input.txt', 'r') as f:
    lines = [line.strip() for line in f if line.strip()]

enhancement_algo, *input_img_str = lines
img = SparseImg.parse(input_img_str, inversed=False)

PART1_STEPS = 2
# Derive each step's background from the algorithm instead of hard-coding it:
# the infinite background reads index 0 when dark and index 511 when lit.
background_lit = False
for _ in range(PART1_STEPS):
    background_lit = enhancement_algo[511 if background_lit else 0] == '#'
    img = img.enhance(enhancement_algo, enhanced_is_inversed=background_lit)
# Counting lit pixels is only finite when the final background is dark.
assert not img.inversed, 'final image has infinitely many lit pixels'

In [28]:
# Part 1 answer: number of lit pixels in `img` (len() only works while img is not inversed).
len(img)

5571

# Part 2

In [25]:
with open('./day20_input.txt', 'r') as f:
    lines = [line.strip() for line in f if line.strip()]

enhancement_algo, *input_img_str = lines
img = SparseImg.parse(input_img_str, inversed=False)

PART2_STEPS = 50
# Derive each step's background from the algorithm rather than assuming it
# alternates: the infinite background reads index 0 when dark, 511 when lit.
background_lit = False
for _ in range(PART2_STEPS):
    background_lit = enhancement_algo[511 if background_lit else 0] == '#'
    img = img.enhance(enhancement_algo, enhanced_is_inversed=background_lit)
# Counting lit pixels is only finite when the final background is dark.
assert not img.inversed, 'final image has infinitely many lit pixels'


In [26]:
# Part 2 answer: number of lit pixels in `img` (len() only works while img is not inversed).
len(img)

17965