In [None]:
import numpy as np
from matplotlib import pyplot as plt
import random
import math
from tqdm import tqdm
import imageio

# Palette for rendering the map. Entry 0 is used for cells that are not yet
# collapsed (grid value -1); entry k + 1 is used for colour index k.
COLORS = [
    [0, 0, 0],
    # colour indices 0-6: green shades
    [0, 140, 0],
    [0, 170, 0],
    [0, 190, 0],
    [0, 200, 0],
    [0, 210, 0],
    [0, 230, 0],
    [0, 240, 0],
    # colour indices 7-9: yellow shades
    [220, 220, 0],
    [230, 230, 0],
    [240, 240, 0],
    # colour indices 10-14: blue shades, getting darker
    [0, 0, 220],
    [0, 0, 200],
    [0, 0, 140],
    [0, 0, 120],
    [0, 0, 100],
]
# Number the palette starting at -1 so that key -1 is the "uncollapsed" colour.
COLOR_MAP = dict(enumerate(COLORS, start=-1))
# Count of real colours, i.e. every key except the -1 placeholder.
NUM_COLORS = len(COLOR_MAP) - 1

class WFCMap():
    """Wave-function-collapse style generator for a square grid of colour indices.

    Every cell holds either ``-1`` (not yet collapsed) or a colour index in
    ``range(NUM_COLORS)``. Two 4-connected neighbours are compatible when their
    indices differ by at most 1. The grid is filled chunk by chunk in row-major
    order; a chunk that runs into a contradiction is rolled back and redone.

    Attributes:
        size: Side length of the square grid.
        grid: ``(size, size)`` int array of collapsed colours, ``-1`` if open.
        possible: ``(size, size)`` object array; each entry is the set of
            colours still allowed for that cell.
        snapshots: Copies of ``grid`` recorded by :meth:`collapse`.
    """

    class Contradiction(RuntimeError):
        """Raised when an uncollapsed cell has no compatible colour left."""

    # Weights for a cell's candidate colours, applied in ascending colour order.
    # A cell with n <= 3 candidates uses the first n weights; larger candidate
    # sets are drawn uniformly.
    CANDIDATE_WEIGHTS = (0.4, 0.2, 0.4)

    def __init__(self, size: int):
        self.size = size

        self.grid = np.full((self.size, self.size), -1, dtype=int)
        # Build the object array explicitly: each cell gets its own set and the
        # shape is (size, size) even for size == 0.
        self.possible = np.empty((self.size, self.size), dtype=object)
        for idx in np.ndindex(self.possible.shape):
            self.possible[idx] = set(range(NUM_COLORS))
        self.snapshots = []

    def compatible_colors(self, color):
        """Return the colours allowed next to ``color`` (itself and +/-1)."""
        return {c for c in range(NUM_COLORS) if abs(c - color) <= 1}

    def neighbors(self, i, j):
        """4-directional neighbors inside the grid."""
        for di, dj in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
            ni, nj = i + di, j + dj
            if 0 <= ni < self.size and 0 <= nj < self.size:
                yield ni, nj

    def _chunk_slices(self, chunk_size: int, chunk: tuple[int, int]) -> tuple[slice, slice]:
        """Return the (row, col) slices covering ``chunk``, clipped to the grid."""
        row0, col0 = chunk[0] * chunk_size, chunk[1] * chunk_size
        return (slice(row0, min(row0 + chunk_size, self.size)),
                slice(col0, min(col0 + chunk_size, self.size)))

    def propagate(self, i, j):
        """Pin cell (i, j) to its collapsed colour and refresh its open neighbours.

        Each uncollapsed neighbour's options are recomputed from scratch. The
        new options are the intersection of the colours compatible with every
        collapsed neighbour of that cell.
        """
        self.possible[i, j] = {self.grid[i, j]}
        for ni, nj in self.neighbors(i, j):
            if self.grid[ni, nj] != -1:
                continue
            fixed = [self.grid[x, y] for x, y in self.neighbors(ni, nj)
                     if self.grid[x, y] != -1]
            self.possible[ni, nj] = (
                set.intersection(*(self.compatible_colors(c) for c in fixed))
                if fixed else set(range(NUM_COLORS))
            )

    def collapse_next_cell(self, chunk_size: int, chunk: tuple[int, int]):
        """Collapse one minimum-entropy open cell inside ``chunk``.

        The method collects the uncollapsed cells of the chunk that have the
        fewest remaining options and picks one of them uniformly at random. It
        then draws that cell's colour from its sorted candidates, using
        ``CANDIDATE_WEIGHTS`` when there are at most three candidates and
        uniform weights otherwise.

        Args:
            chunk_size: Side length of a chunk.
            chunk: ``(chunk_row, chunk_col)`` index of the chunk.

        Returns:
            ``(i, j)`` of the collapsed cell, or ``(-1, -1)`` if every cell of
            the chunk is already collapsed.

        Raises:
            WFCMap.Contradiction: if the chosen cell has no candidate left.
        """
        rows, cols = self._chunk_slices(chunk_size, chunk)
        min_entropy = float('inf')
        min_cells = []
        for i in range(rows.start, rows.stop):
            for j in range(cols.start, cols.stop):
                if self.grid[i, j] != -1:
                    continue
                entropy = len(self.possible[i, j])
                if entropy < min_entropy:
                    # New minimum: earlier candidates are no longer minimal.
                    min_entropy = entropy
                    min_cells = [(i, j)]
                elif entropy == min_entropy:
                    min_cells.append((i, j))
        if not min_cells:
            return -1, -1

        i, j = random.choice(min_cells)
        # Sort so the weights line up with colour order, not set-hash order.
        candidates = sorted(self.possible[i, j])
        if not candidates:
            raise self.Contradiction(f"cell ({i}, {j}) has no compatible colour left")
        if len(candidates) <= len(self.CANDIDATE_WEIGHTS):
            weights = self.CANDIDATE_WEIGHTS[:len(candidates)]
        else:
            weights = None  # uniform; a 3-weight list cannot cover more options
        self.grid[i, j] = random.choices(candidates, weights=weights, k=1)[0]
        return i, j

    def collapse(self, chunk_size: int = 16, time_between_snapshots: int = 1):
        """Collapse the whole grid, one chunk at a time in row-major order.

        If a chunk hits a contradiction, its ``grid`` and ``possible`` entries
        are restored and the chunk is regenerated. Retries are unbounded.
        Snapshots taken during an abandoned attempt are kept.

        Args:
            chunk_size: Side length of each square chunk. Edge chunks are
                clipped when ``size`` is not a multiple of it.
            time_between_snapshots: A copy of ``grid`` is appended to
                ``self.snapshots`` every this many cell collapses. One more copy
                is appended after each finished chunk.

        Raises:
            ValueError: if ``chunk_size`` is not positive.
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        chunks_per_side = math.ceil(self.size / chunk_size)

        self.snapshots = []
        snapshot_counter = 0

        with tqdm(total=chunks_per_side**2, desc="Collapsing chunks") as prog_bar:
            for chunk_i in range(chunks_per_side):
                for chunk_j in range(chunks_per_side):
                    chunk = (chunk_i, chunk_j)
                    region = self._chunk_slices(chunk_size, chunk)
                    grid_backup = self.grid[region].copy()
                    # A shallow copy is enough: propagate() replaces the sets in
                    # `possible`, it never mutates them in place.
                    possible_backup = self.possible[region].copy()
                    prog_bar.set_postfix({'chunk': f'({chunk_i},{chunk_j})'})
                    while True:
                        try:
                            while (cell := self.collapse_next_cell(chunk_size, chunk)) != (-1, -1):
                                self.propagate(*cell)

                                if snapshot_counter >= time_between_snapshots:
                                    self.snapshots.append(self.grid.copy())
                                    snapshot_counter = 0
                                snapshot_counter += 1
                        except self.Contradiction:
                            # Dead end: roll the chunk back and regenerate it.
                            self.grid[region] = grid_backup
                            self.possible[region] = possible_backup
                        else:
                            break
                    prog_bar.update(1)
                    self.snapshots.append(self.grid.copy())

    def convert_to_rgb(self, img=None):
        """Map an image of colour indices to RGB using ``COLOR_MAP``.

        Args:
            img: Integer array of colour indices (``-1`` for open cells).
                Defaults to the current ``grid``.

        Returns:
            ``uint8`` array of shape ``img.shape + (3,)``.

        Raises:
            KeyError: if ``img`` contains an index that is not in ``COLOR_MAP``.
        """
        img = self.grid if img is None else np.asarray(img)
        # Row k + 1 of the palette holds the colour for index k, so a single
        # fancy-indexing step replaces per-pixel dict lookups.
        palette = np.array([COLOR_MAP[k] for k in range(-1, NUM_COLORS)], dtype=np.uint8)
        out_of_range = (img < -1) | (img >= NUM_COLORS)
        if out_of_range.any():
            # Without this check, negative indices would silently wrap around.
            raise KeyError(int(img[out_of_range][0]))
        return palette[img + 1]

    def view_image(self):
        """Show the current grid with matplotlib, without axes."""
        plt.imshow(self.convert_to_rgb())
        plt.axis('off')
        plt.show()
        
# Build a 64x64 map in 16x16 chunks, export the recorded snapshots as an
# animated GIF, then display the final grid.
m = WFCMap(size=16 * 4)
m.collapse(chunk_size=16, time_between_snapshots=64 * 6)

gif_frames = [m.convert_to_rgb(snapshot) for snapshot in m.snapshots]
imageio.mimwrite('wfc_output.gif', gif_frames, fps=10)
m.view_image()