# In [None]:
import random
import numpy as np
import sys
import os
import multiprocessing
from collections import Counter
from deap import base, creator, tools
import torch

# Set multiprocessing start method to 'spawn' for CUDA compatibility, but only
# when no start method has been fixed yet (setting it a second time raises).
if __name__ == "__main__" and multiprocessing.get_start_method(allow_none=True) is None:
    multiprocessing.set_start_method('spawn')

# --- Device setup: use CUDA when it is available, otherwise fall back to CPU ---
# (querying the GPU name without a CUDA device would raise at import time)
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {DEVICE}")
if DEVICE == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# --- Grid constants ---
IND_ROWS, IND_COLS = 8, 14           # grid height x width
IND_SIZE = IND_ROWS * IND_COLS       # genome length: one gene per cell, row-major
INT_MIN, INT_MAX = 0, 9              # inclusive range of the digit in each cell
FILENAME = 'best_output_double.txt'  # text file that saved grids are appended to

# Two objectives, both maximised (weights +1.0): the evaluation function returns
# the pair (current_score, formable_count).
creator.create("FitnessMax", base.Fitness, weights=(1.0, 1.0))
# An individual is a flat list of IND_SIZE digits, reshaped row-major to
# IND_ROWS x IND_COLS wherever a grid view is needed.
creator.create("Individual", list, fitness=creator.FitnessMax)
toolbox = base.Toolbox()

# --- PyTorch helpers ---
# (d_row, d_col) offsets of the 8 neighbouring cells, listed row by row with the
# centre (0, 0) left out: a step must always move to a different cell.
DELTAS = torch.tensor(
    [[d_row, d_col]
     for d_row in (-1, 0, 1)
     for d_col in (-1, 0, 1)
     if (d_row, d_col) != (0, 0)],
    dtype=torch.long,
    device=DEVICE,
)

def positions_per_digit(grid_t):
    """Map each digit 0-9 present in a 2-D grid tensor to its cell coordinates.

    Returns a dict ``{digit: LongTensor of shape (k, 2)}`` whose rows are
    ``(row, col)`` pairs in row-major order. Digits that do not occur in the
    grid have no key.
    """
    found = {}
    for digit in range(10):
        row_idx, col_idx = torch.nonzero(grid_t == digit, as_tuple=True)
        if row_idx.numel() == 0:
            continue
        found[digit] = torch.stack((row_idx, col_idx), dim=1)
    return found

def batch_has_path_torch(grid_np, numbers, precomputed_pos_dict=None):
    """Return, for each number, whether its digits can be read along a walk on the grid.

    A number is readable when some sequence of cells spells its decimal digits
    and each cell is one of the 8 neighbours (orthogonal or diagonal) of the
    previous one. Cells may be revisited, but every step moves to a different
    cell, so ``11`` needs two adjacent ``1`` cells.
    NOTE(review): these are the 8x14 digit-grid puzzle rules as I read them
    (revisits allowed, no staying in place) -- confirm against the statement.

    Args:
        grid_np: 2-D array-like of digits; its shape gives the grid bounds.
        numbers: sequence of ints to test (may be empty).
        precomputed_pos_dict: optional ``positions_per_digit`` result for the
            same grid, reused instead of being recomputed.

    Returns:
        ``np.ndarray`` of 0/1 int32 flags aligned with *numbers* (an empty int
        array for empty input). ``0`` counts as readable iff digit 0 is present;
        negative numbers are never readable.
    """
    if not numbers:
        return np.array([], dtype=int)

    grid_t = torch.from_numpy(np.asarray(grid_np).astype(np.int32)).to(DEVICE)
    n_rows, n_cols = grid_t.shape
    pos_dict = precomputed_pos_dict if precomputed_pos_dict is not None else positions_per_digit(grid_t)

    # Filled on the host: per-element writes into a device tensor each cost a kernel launch.
    results = np.zeros(len(numbers), dtype=np.int32)

    for i, n_val in enumerate(numbers):
        if n_val <= 0:
            results[i] = 1 if (n_val == 0 and 0 in pos_dict) else 0
            continue

        digits = [int(ch) for ch in str(n_val)]
        front = pos_dict.get(digits[0])
        if front is None or len(front) == 0:
            continue

        # `front` holds every cell where a walk spelling the digits read so far can
        # end. There is deliberately no global "visited" mask: revisiting is allowed,
        # and such a mask made any number repeating its first digit (11, 101, 121, ...)
        # unreadable.
        for next_d in digits[1:]:
            if next_d not in pos_dict:
                break
            # All 8 neighbours of every frontier cell, clipped to the grid.
            r = (front[:, 0:1] + DELTAS[:, 0]).flatten()
            c = (front[:, 1:2] + DELTAS[:, 1]).flatten()
            inside = (r >= 0) & (r < n_rows) & (c >= 0) & (c < n_cols)
            r, c = r[inside], c[inside]

            hit = grid_t[r, c] == next_d
            if not bool(hit.any()):
                break

            # Deduplicate through a cell mask so the frontier never exceeds the grid size.
            reach = torch.zeros_like(grid_t, dtype=torch.bool)
            reach[r[hit], c[hit]] = True
            front = torch.nonzero(reach)
        else:
            # Every digit was matched (single-digit numbers land here directly).
            results[i] = 1

    return results

# --- Global precomputation of representatives (1~10000, reverse duplicates removed) ---
REPRESENTATIVES = []
PAIR_DICT = {}  # rep -> list of covered numbers (forward + valid reverse)

seen = set()
for n in range(1, 10001):
    if n in seen:
        continue

    rev_str = str(n)[::-1]
    rev_n = int(rev_str) if not rev_str.startswith('0') else None

    covered = [n]
    if rev_n is not None and rev_n != n:
        covered.append(rev_n)
        seen.add(rev_n)

    seen.add(n)
    REPRESENTATIVES.append(n)
    PAIR_DICT[n] = covered

print(f"Precomputed {len(REPRESENTATIVES)} representatives for 1~10000 (reverse duplicates eliminated)")

# --- Evaluation function ---
def _reverse_number(value):
    """Return *value* with its decimal digits reversed, or None if that has a leading zero."""
    rev_str = str(value)[::-1]
    return None if rev_str.startswith('0') else int(rev_str)


def _extend_score(grid_np, pos_dict, found_set, start, max_n, batch_size):
    """Continue a consecutive run known to cover 1..start-1 and return its last number.

    Numbers start..max_n-1 are checked in order, in chunks of *batch_size*; the
    first unreadable one ends the run. *found_set* is updated in place with every
    number confirmed readable and its reverse. Returns max_n - 1 if nothing below
    max_n fails (start - 1 if the range is empty).
    """
    score = start - 1
    for lo in range(start, max_n, batch_size):
        candidates = range(lo, min(lo + batch_size, max_n))
        # A number is readable iff its reverse is, so numbers whose reverse is already
        # known need no path search. (A None reverse is never in found_set.)
        unknown = [cand for cand in candidates
                   if cand not in found_set and _reverse_number(cand) not in found_set]
        flags = batch_has_path_torch(grid_np, unknown, precomputed_pos_dict=pos_dict)
        # Flags are matched to candidates by zip, so they cannot drift out of alignment.
        unreadable = {cand for cand, flag in zip(unknown, flags) if flag != 1}
        for cand in candidates:
            if cand in unreadable:
                return score
            found_set.add(cand)
            rev = _reverse_number(cand)
            if rev is not None:
                found_set.add(rev)
            score = cand
    return score


def eval_814_heuristic(individual, max_n=50000, batch_size=1000):
    """Score one individual (a flat digit list) on the two maximised objectives.

    Args:
        individual: IND_SIZE digits, reshaped row-major to IND_ROWS x IND_COLS.
        max_n: exclusive upper bound for the consecutive-run search.
        batch_size: how many numbers above 10000 to path-check per batch.

    Returns:
        ``(current_score, formable_count)`` as floats, where current_score is
        the largest X (at most max_n - 1) such that every integer 1..X is
        readable, and formable_count is how many of 1000..9999 are readable.
    """
    grid_np = np.array(individual).reshape(IND_ROWS, IND_COLS)
    grid_t = torch.from_numpy(grid_np.astype(np.int32)).to(DEVICE)
    pos_dict = positions_per_digit(grid_t)  # computed once per individual

    # One batch for all representatives of 1..10000; each readable representative
    # also marks its reverse (PAIR_DICT) as readable.
    has_array = batch_has_path_torch(grid_np, REPRESENTATIVES, precomputed_pos_dict=pos_dict)
    found_set = set()
    for rep, h in zip(REPRESENTATIVES, has_array):
        if h == 1:
            found_set.update(PAIR_DICT[rep])

    # Longest run 1..k inside 1..10000; only when all of it is readable do we look higher.
    current_score = 0
    for k in range(1, 10001):
        if k not in found_set:
            break
        current_score = k
    else:
        current_score = _extend_score(grid_np, pos_dict, found_set, 10001, max_n, batch_size)

    # Exact: every number in 1000..9999 was covered by the representatives batch.
    formable_count = sum(1 for num in range(1000, 10000) if num in found_set)

    return float(current_score), float(formable_count)

# Register eval_814_heuristic as the fitness function; it returns the pair
# (current_score, formable_count) that FitnessMax maximises.
toolbox.register("evaluate", eval_814_heuristic)

# ────────────────────────────────────────────────
# Genetic operators, crowding helpers, persistence and the main loop
# ────────────────────────────────────────────────
def custom_mate(ind1, ind2):
    """Swap a random rectangular sub-grid (same position in both parents) in place.

    The rectangle is at least 1x1 and may cover the whole grid. Both individuals
    are modified in place and returned, as DEAP's mate operators do.
    """
    left_grid = np.array(ind1).reshape(IND_ROWS, IND_COLS)
    right_grid = np.array(ind2).reshape(IND_ROWS, IND_COLS)

    top = random.randint(0, IND_ROWS - 1)
    bottom = random.randint(top + 1, IND_ROWS)
    first_col = random.randint(0, IND_COLS - 1)
    last_col = random.randint(first_col + 1, IND_COLS)
    window = (slice(top, bottom), slice(first_col, last_col))

    # Both right-hand copies are taken before either assignment happens.
    left_grid[window], right_grid[window] = right_grid[window].copy(), left_grid[window].copy()

    ind1[:] = left_grid.ravel().tolist()
    ind2[:] = right_grid.ravel().tolist()
    return ind1, ind2

def custom_mutate(individual, indpb=0.05):
    """Mutate a grid in place and return it as a 1-tuple (DEAP convention).

    With probability 0.5 each gene is independently reset to a random digit
    0-9 with probability *indpb*. Otherwise two distinct random digits swap
    places everywhere in the grid.
    """
    if random.random() >= 0.5:
        palette = list(range(10))
        random.shuffle(palette)
        swap = {palette[0]: palette[1], palette[1]: palette[0]}
        for pos, value in enumerate(individual):
            if value in swap:
                individual[pos] = swap[value]
        return (individual,)

    tools.mutUniformInt(individual, low=0, up=9, indpb=indpb)
    return (individual,)

# Variation operators: crossover swaps a random rectangular sub-grid; mutation is
# either a per-cell uniform reset (probability indpb) or a swap of two digit values.
toolbox.register("mate", custom_mate)
toolbox.register("mutate", custom_mutate, indpb=0.05)

def _assign_crowding_dist_fallback(front):
    if not front:
        return
    if len(front) <= 2:
        for ind in front:
            ind.crowding_dist = float("inf")
        return
    for ind in front:
        ind.crowding_dist = 0.0
    nobj = len(front[0].fitness.values)
    for m in range(nobj):
        front.sort(key=lambda ind: ind.fitness.values[m])
        front[0].crowding_dist = float("inf")
        front[-1].crowding_dist = float("inf")
        fmin = front[0].fitness.values[m]
        fmax = front[-1].fitness.values[m]
        denom = fmax - fmin
        if denom == 0:
            continue
        for i in range(1, len(front) - 1):
            prev_f = front[i - 1].fitness.values[m]
            next_f = front[i + 1].fitness.values[m]
            front[i].crowding_dist += (next_f - prev_f) / denom

def update_crowding(population):
    """Assign crowding distances to every non-dominated front of *population*.

    Uses DEAP's ``tools.emo.assignCrowdingDist`` when it is reachable and falls
    back to ``_assign_crowding_dist_fallback`` otherwise.
    """
    fronts = tools.sortNondominated(population, k=len(population), first_front_only=False)
    deap_assign = getattr(getattr(tools, "emo", None), "assignCrowdingDist", None)
    assign = deap_assign if deap_assign is not None else _assign_crowding_dist_fallback
    for front in fronts:
        assign(front)

# Parent selection: DEAP's dominance / crowding-distance tournament. It needs
# crowding distances, which main() refreshes via update_crowding after every
# evaluation round.
toolbox.register("select", tools.selTournamentDCD)
# Gene = random digit in [INT_MIN, INT_MAX]; individual = IND_SIZE genes;
# population = list of individuals.
toolbox.register("attr_int", random.randint, INT_MIN, INT_MAX)
toolbox.register("individual", tools.initRepeat, creator.Individual, toolbox.attr_int, IND_SIZE)
toolbox.register("population", tools.initRepeat, list, toolbox.individual)

def load_previous_best():
    loaded = []
    protected = None
    if os.path.exists(FILENAME):
        with open(FILENAME, 'r') as f:
            valid_lines = [line.strip() for line in f if len(line.strip()) == 14 and line.strip().isdigit()]
        for block_start in range(0, len(valid_lines), 8):
            if block_start + 7 >= len(valid_lines): break
            block = valid_lines[block_start:block_start + 8]
            ind_list = []
            for row in block:
                ind_list.extend(int(d) for d in row)
            if len(ind_list) == IND_SIZE:
                ind = creator.Individual(ind_list)
                if block_start == 0:
                    protected = ind
                else:
                    loaded.append(ind)
    return protected, loaded

def _grid_of(ind):
    """Return the flat individual *ind* as an IND_ROWS x IND_COLS numpy grid."""
    return np.array(ind).reshape(IND_ROWS, IND_COLS)


def _digit_distribution(ind):
    """Return space-separated 'digit:count' pairs for digits 0-9 in *ind*."""
    counts = Counter(ind)
    return ' '.join(f"{d}:{counts[d]}" for d in range(10))


def _append_grid(ind):
    """Append *ind* to FILENAME as one digit string per row plus a blank separator line."""
    with open(FILENAME, 'a') as f:
        for row in _grid_of(ind):
            f.write(''.join(map(str, row)) + '\n')
        f.write('\n')


def _update_best(pop, best_ind, best_score):
    """Return (best_ind, best_score) updated with *pop*'s highest current_score, if higher."""
    leader = max(pop, key=lambda ind: ind.fitness.values[0])
    if leader.fitness.values[0] > best_score:
        return leader, leader.fitness.values[0]
    return best_ind, best_score


def main(ngen=100, target_pop=300, cxpb=0.5, mutpb=0.2):
    """Run the genetic search, print the best grids and append them to FILENAME.

    Args:
        ngen: number of generations.
        target_pop: population size.
            NOTE(review): DEAP's selTournamentDCD expects a multiple of 4 here.
        cxpb: per-pair crossover probability.
        mutpb: per-individual mutation probability.

    NOTE(review): offspring replace the whole population each generation (there
    is no parent+offspring survival step), so the best grid can disappear from
    ``pop``; ``best_ind`` keeps the record of the best one seen.
    """
    protected, loaded = load_previous_best()
    pop = [protected] if protected is not None else []
    # FILENAME is only ever appended to, so the newest grids are at the end of
    # `loaded`; keep those rather than the oldest when there are too many.
    room = target_pop - len(pop)
    if room > 0:
        pop.extend(loaded[-room:])
    if len(pop) < target_pop:
        pop.extend(toolbox.population(n=target_pop - len(pop)))
    pop = pop[:target_pop]

    # Serial evaluation
    fitnesses = list(map(toolbox.evaluate, pop))
    for ind, fit in zip(pop, fitnesses):
        ind.fitness.values = fit
    update_crowding(pop)

    # Start the record from the initial (possibly loaded) population, so a saved
    # grid that no offspring beats is still reported and kept.
    best_ind, best_current_score = _update_best(pop, None, 0.0)

    for g in range(1, ngen + 1):
        offspring = [toolbox.clone(ind) for ind in toolbox.select(pop, len(pop))]
        for child1, child2 in zip(offspring[::2], offspring[1::2]):
            if random.random() < cxpb:
                toolbox.mate(child1, child2)
                del child1.fitness.values, child2.fitness.values
        for mutant in offspring:
            if random.random() < mutpb:
                toolbox.mutate(mutant)
                del mutant.fitness.values
        invalid_ind = [ind for ind in offspring if not ind.fitness.valid]
        fitnesses = list(map(toolbox.evaluate, invalid_ind))
        for ind, fit in zip(invalid_ind, fitnesses):
            ind.fitness.values = fit
        pop[:] = offspring
        update_crowding(pop)
        best_ind, best_current_score = _update_best(pop, best_ind, best_current_score)

        pop_max_v = [ind.fitness.values[0] for ind in pop]
        formable = [ind.fitness.values[1] for ind in pop]
        current_max = max(pop_max_v)
        sys.stdout.write(f"--Generation {g}--\n")
        sys.stdout.write(
            f"Current Score (1+)\tMax: {current_max:>7.0f}\tAvg: {sum(pop_max_v) / len(pop):>7.1f}\n")
        sys.stdout.write(
            f"Formable Count\tMax: {max(formable):>7.0f}\tAvg: {sum(formable) / len(pop):>7.1f}\n")
    print("\n" + "=" * 60)
    print(f"Evolution finished after {ngen} generations")

    if best_ind is not None:
        print("\n" + "=" * 80)
        print("GRID WITH THE HIGHEST CURRENT SCORE (consecutive from 1)")
        print("=" * 80)
        print(f"Current Score : {best_current_score:.0f}")
        print(f"Formable Count: {best_ind.fitness.values[1]:.0f}")
        for row in _grid_of(best_ind):
            print(''.join(map(str, row)))
        print("\nDigit distribution:", _digit_distribution(best_ind))
        _append_grid(best_ind)
        print("Highest current_score grid saved to file.")
    else:
        print("No valid individuals found to save.")

    top3 = tools.selBest(pop, k=3)
    for rank, ind in enumerate(top3, 1):
        curr, form = ind.fitness.values
        print(f"\nTop #{rank} Individual:")
        print(f" Current consecutive score : {curr:>7.0f}")
        print(f" Formable count (1000-9999): {form:>7.0f}")
        print(" Grid:")
        for row in _grid_of(ind):
            print(' ' + ''.join(map(str, row)))
        print(" Digit distribution:", _digit_distribution(ind))
        print("-" * 50)

    _append_grid(top3[0])

# Entry point: run the GA only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()