# In [None]:
import random
import numpy as np
import torch
import time
from collections import Counter

# ────────────────────────────────────────────────
# Device & constants
# ────────────────────────────────────────────────
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(f"Using PyTorch device: {DEVICE}")
if DEVICE == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")

# Grid dimensions (rows x columns).
IND_ROWS, IND_COLS = 8, 14

# The eight (d_row, d_col) neighbour steps in row-major order, skipping the
# centre cell (0, 0); stored as a long tensor on DEVICE.
DELTAS = torch.tensor(
    [
        [d_row, d_col]
        for d_row in (-1, 0, 1)
        for d_col in (-1, 0, 1)
        if (d_row, d_col) != (0, 0)
    ],
    dtype=torch.long,
    device=DEVICE,
)

# ────────────────────────────────────────────────
# Helper: precompute positions per digit on GPU
# ────────────────────────────────────────────────
def positions_per_digit(grid_t):
    """Map every digit 0-9 that occurs in ``grid_t`` to its cell coordinates.

    grid_t: 2-D tensor with one digit per cell.
    Returns: dict[int, torch.Tensor]; each value has shape [N, 2] holding the
             (row, col) pairs of the cells equal to that digit, in row-major
             order. Digits that do not occur in the grid are left out.
    """
    cells_by_digit = {}
    for digit in range(10):
        row_idx, col_idx = torch.nonzero(grid_t == digit, as_tuple=True)
        if row_idx.numel() == 0:
            # Digit absent: leave it out so callers can test with `in`.
            continue
        cells_by_digit[digit] = torch.stack((row_idx, col_idx), dim=1)
    return cells_by_digit

# ────────────────────────────────────────────────
# Batched path check: frontier search over simple paths
# (a cell may be used at most once within a single path)
# ────────────────────────────────────────────────

# 8-neighbourhood as plain Python data so that a tensor can be built on
# whichever device a call runs on (the module-level DELTAS tensor is pinned
# to DEVICE).
_NEIGHBOR_OFFSETS = (
    (-1, -1), (-1, 0), (-1, 1),
    (0, -1), (0, 1),
    (1, -1), (1, 0), (1, 1),
)


def _simple_path_exists(grid_t, start_cells, digits, offsets):
    """Return True if ``digits`` can be traced through distinct adjacent cells.

    grid_t: 2-D integer tensor, one digit per cell.
    start_cells: [N, 2] long tensor of (row, col) cells holding ``digits[0]``.
    digits: list[int] with at least two entries.
    offsets: [8, 2] long tensor of neighbour (d_row, d_col) steps.

    Every partial path keeps its own trail of visited cells, so two paths
    never block each other; only re-entering a cell of the *same* path is
    rejected.
    """
    n_rows, n_cols = grid_t.shape
    tip_r, tip_c = start_cells[:, 0], start_cells[:, 1]
    # trail[p, k] = flat index of the k-th cell of partial path p.
    trail = (tip_r * n_cols + tip_c).unsqueeze(1)

    last_depth = len(digits) - 1
    for depth in range(1, len(digits)):
        # Candidate neighbours of every path tip: shape [P, 8].
        cand_r = tip_r.unsqueeze(1) + offsets[:, 0]
        cand_c = tip_c.unsqueeze(1) + offsets[:, 1]
        inside = (
            (cand_r >= 0) & (cand_r < n_rows) &
            (cand_c >= 0) & (cand_c < n_cols)
        )
        # Clamp so the gather below is always legal; out-of-grid candidates
        # stay excluded through `inside`.
        cand_r = cand_r.clamp(0, n_rows - 1)
        cand_c = cand_c.clamp(0, n_cols - 1)
        cand_id = cand_r * n_cols + cand_c

        # [P, k, 1] vs [P, 1, 8] -> [P, 8]: candidate not yet on its own path.
        unused = (trail.unsqueeze(2) != cand_id.unsqueeze(1)).all(dim=1)
        keep = inside & (grid_t[cand_r, cand_c] == digits[depth]) & unused

        path_idx, step_idx = torch.nonzero(keep, as_tuple=True)
        if path_idx.numel() == 0:
            return False  # no path can be extended to this digit
        if depth == last_depth:
            return True   # at least one path reached the final digit

        tip_r = cand_r[path_idx, step_idx]
        tip_c = cand_c[path_idx, step_idx]
        trail = torch.cat(
            [trail[path_idx], cand_id[path_idx, step_idx].unsqueeze(1)],
            dim=1,
        )
    return False


def batch_has_path_torch(grid_np, numbers, precomputed_pos_dict=None, device=None):
    """
    Check, for each number, whether its decimal digits can be read off the
    grid along a path of 8-connected cells that never uses a cell twice.

    grid_np: 2-D np.ndarray of digits 0-9 (e.g. shape (8, 14), int32/uint8);
             bounds are taken from the array's own shape.
    numbers: list[int] or np.array
    precomputed_pos_dict: optional dict[int → torch.Tensor] as returned by
             positions_per_digit, for reuse (faster in GA). It must describe
             the same grid as grid_np.
    device: optional torch device (or device string) to run on; defaults to
            the module-level DEVICE. Works on CPU as well as CUDA.
    Returns: np.ndarray[int32] (len(numbers),)  1 if path exists, 0 otherwise.
             Negative numbers always yield 0; 0 yields 1 iff the grid
             contains a 0.

    Note: the search keeps one row per live partial path, so memory grows
    with the number of distinct partial paths (worst case: long numbers of
    repeated digits on a grid dominated by that digit).
    """
    if device is None:
        device = DEVICE

    grid_t = torch.from_numpy(np.asarray(grid_np).astype(np.int32)).to(device)

    # Use precomputed if provided (e.g., in GA per individual)
    if precomputed_pos_dict is not None:
        pos_dict = precomputed_pos_dict
    else:
        pos_dict = positions_per_digit(grid_t)

    offsets = torch.tensor(_NEIGHBOR_OFFSETS, dtype=torch.long, device=device)

    # Collect answers on the host: one device write per number would cost a
    # kernel launch each and the caller wants a NumPy array anyway.
    results = np.zeros(len(numbers), dtype=np.int32)

    for i, n in enumerate(numbers):
        if n < 0:
            continue

        digits = [int(ch) for ch in str(n)]

        # Cheap rejection: every digit must occur somewhere in the grid.
        if any(d not in pos_dict or len(pos_dict[d]) == 0 for d in digits):
            continue

        # Single digit (including 0): present in the grid, so the answer is yes.
        if len(digits) == 1:
            results[i] = 1
            continue

        start_cells = pos_dict[digits[0]].to(device)
        if _simple_path_exists(grid_t, start_cells, digits, offsets):
            results[i] = 1

    return results

# ────────────────────────────────────────────────
# Test function (adapted: precompute pos_dict once, test reverse numbers too)
# ────────────────────────────────────────────────
def run_test():
    """Smoke-test and time batch_has_path_torch on a fixed random grid.

    Builds a seeded IND_ROWS x IND_COLS digit grid, prints it, reports
    YES/no for a handful of hand-picked numbers (plus some digit-reversed
    variants), then times one batched call over a reproducible mix of
    consecutive and random 4-digit numbers. Returns nothing; all output goes
    to stdout.
    """
    np.random.seed(42)
    grid_np = np.random.randint(0, 10, size=(IND_ROWS, IND_COLS), dtype=np.int32)

    rule = "=" * 70
    print("\n" + rule)
    print("TEST GRID (8x14):")
    print(rule)
    for row in grid_np:
        print(' '.join(f"{x:1d}" for x in row))

    # Precompute positions once for speed (useful in GA too)
    grid_t = torch.from_numpy(grid_np.astype(np.int32)).to(DEVICE)
    pos_dict = positions_per_digit(grid_t)

    test_numbers = [1, 12, 123, 1234, 5678, 9999, 42, 100, 314159]
    # Digit-reversed variants of 12, 123 and 5678; 9999 reads the same both ways.
    test_numbers += [21, 321, 8765, 9999]

    print("\nTesting individual numbers (including reverses):")
    results = batch_has_path_torch(grid_np, test_numbers, precomputed_pos_dict=pos_dict)
    for n, res in zip(test_numbers, results):
        print(f"{n:7d} → {'YES' if res == 1 else 'no'}")

    # Performance test: larger batch
    print("\nPerformance test:")
    # Private, seeded generator: the workload is identical on every run and
    # the global `random` state is left untouched.
    rng = random.Random(42)
    consecutive_nums = list(range(1, 1000))  # larger for meaningful timing
    random_4digit = [rng.randint(1000, 9999) for _ in range(1000)]
    all_nums = consecutive_nums + random_4digit

    t0 = time.perf_counter()
    _ = batch_has_path_torch(grid_np, all_nums, precomputed_pos_dict=pos_dict)
    elapsed = time.perf_counter() - t0

    # Label the timing with the device that actually ran the check.
    device_label = 'GPU' if DEVICE == 'cuda' else 'CPU'
    print(f"Time for {len(all_nums)} numbers ({device_label}): {elapsed:.4f} seconds")
    print(f"Approx. checks/sec: {len(all_nums) / elapsed:.0f}" if elapsed > 0 else "N/A")

# Script entry point: run the demo/benchmark only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    print("Starting full test...")
    run_test()
    print("\nTest complete. Ready for GA integration (use precomputed_pos_dict per individual).")