In [74]:
import copy
import random
from collections import namedtuple


# Evolution settings read by let_it_grow().
NUM_GENERATIONS = 100  # upper bound on the number of generations per run
OFFSPRING_SIZE = 3  # children created in every generation

# Seed handed to problem() so the generated instances are reproducible.
SEED = 42

# One candidate solution: the chosen subsets and the score w() gave them.
Individual = namedtuple("Individual", ("genome", "fitness"))

In [75]:
def problem(N, seed=None):
    """Generate a random set-covering instance over the integers 0 .. N-1.

    The global ``random`` generator is re-seeded with ``seed`` before any
    number is drawn, so a given non-None ``seed`` always produces the same
    instance for the same ``N``.

    Args:
        N: size of the universe; every generated integer lies in 0 .. N-1.
        seed: value passed straight to ``random.seed``.

    Returns:
        A list of between N and 5*N lists. Each inner list contains the
        distinct values obtained from between N // 5 and N // 2 random draws.
    """
    random.seed(seed)

    list_count = random.randint(N, N * 5)
    instance = []
    for _ in range(list_count):
        draw_count = random.randint(N // 5, N // 2)
        members = set()
        for _ in range(draw_count):
            members.add(random.randint(0, N - 1))
        instance.append(list(members))
    return instance


In [76]:
def goal_test(state, goal=None):
    """Tell whether the subsets in ``state`` together cover the goal exactly.

    Args:
        state: iterable of subsets (the notebook passes a list of tuples);
            each subset is an iterable of hashable elements.
        goal: set the union of the subsets must equal. When None, the
            module-level ``GOAL`` is used, which must be defined before the call.

    Returns:
        True when the union of all subsets equals the goal set, else False.
        An empty ``state`` has an empty union.
    """
    target = GOAL if goal is None else goal
    # A single n-ary union avoids rebuilding an ever-growing tuple for each
    # subset and works for any iterable subset type, not only tuples.
    return set().union(*state) == target

In [77]:
def w(genome):
    """Score a genome; larger values are better.

    A genome accepted by ``goal_test`` scores the negated total number of
    elements over all of its subsets (repeated subsets are counted each
    time). Any other genome scores the fixed penalty -1000000000.
    """
    if goal_test(genome):
        return -sum(len(subset) for subset in genome)
    return -1000000000


def tournament(population, tournament_size=5):
    """Pick a parent by tournament selection.

    ``tournament_size`` individuals are drawn from ``population`` with
    replacement and the one with the highest ``fitness`` is returned; on a
    tie the earliest drawn individual wins.
    """
    def fitness_of(individual):
        return individual.fitness

    contenders = random.choices(population, k=tournament_size)
    return max(contenders, key=fitness_of)


def cross_over(g1, g2, all_lists):
    """Recombine two genomes, patching the cut point with an unused subset.

    One cut position is drawn in ``[0, min(len(g1), len(g2))]``. If the plain
    recombination ``g1[:cut] + g2[cut:]`` holds fewer than ``N`` distinct
    subsets (``N`` is read from module scope), up to 100 children of the form
    ``g1[:cut] + [extra] + g2[cut + 1:]`` are tried, where ``extra`` is drawn
    at random from the subsets of ``all_lists`` that the plain recombination
    does not contain. The first child scored by ``w`` as a valid cover wins.

    Args:
        g1: first parent genome, a list of subsets.
        g2: second parent genome, a list of subsets.
        all_lists: every subset available to the problem.

    Returns:
        The first valid child (a new list). ``g1`` itself is returned when the
        plain recombination already has ``N`` or more distinct subsets, when
        no unused subset exists, or when none of the 100 attempts is valid.
    """
    cut = random.randint(0, min(len(g1), len(g2)))

    # Everything below depends only on `cut`, so it is computed once rather
    # than on each of the 100 attempts.
    head, tail = g1[:cut], g2[cut + 1:]
    used = set(head + g2[cut:])
    if len(used) >= N:
        return g1

    candidates = list(set(all_lists) - used)
    if not candidates:
        # Every known subset is already in use; random.choice would raise
        # IndexError on an empty list, so fall back to the first parent.
        return g1

    for _ in range(100):
        child = head + [random.choice(candidates)] + tail
        # w() scores a non-cover as -1000000000, so any score above this
        # threshold means the child is a valid cover.
        if w(child) > -100000000:
            return child

    return g1


def mutation(g, all_lists):
    """Try to improve a genome by swapping one subset for an unused one.

    A single position of ``g`` is drawn at random. Up to 100 times, the
    subset at that position is replaced by a random subset of ``all_lists``
    that does not occur in ``g``; the first replacement whose ``w`` score is
    strictly higher than that of ``g`` is returned.

    Args:
        g: genome to mutate, a list of subsets. It is never modified.
        all_lists: every subset available to the problem.

    Returns:
        The improved genome (a new list), or ``g`` itself when it is empty,
        when no unused subset exists, or when no attempt improves the score.
    """
    if not g:
        # No position to replace; random.randint(0, -1) would raise ValueError.
        return g

    candidates = list(set(all_lists) - set(g))
    if not candidates:
        # Every subset is already in the genome; random.choice would raise
        # IndexError on an empty list.
        return g

    point = random.randint(0, len(g) - 1)
    baseline = w(g)  # unchanged across attempts, so scored once
    for _ in range(100):
        mutant = g[:point] + [random.choice(candidates)] + g[point + 1:]
        if w(mutant) > baseline:
            return mutant

    return g

In [78]:
def generate_population(N):
    """Create the problem instance for ``N`` and a random initial population.

    The instance comes from ``problem(N, SEED)``. Each of its lists is
    sorted and turned into a tuple, and duplicates are dropped. Then
    ``20 * N`` genomes are built, each a random sample (without replacement)
    of the unique subsets, and scored with ``w``.

    Args:
        N: problem size passed to ``problem``.

    Returns:
        A pair ``(population, all_lists)``: the list of ``Individual`` tuples
        and the list of unique subsets (tuples) of the instance.
    """
    problem_out = problem(N, SEED)
    all_lists = list(set(tuple(sorted(subset)) for subset in problem_out))

    # random.sample raises ValueError when asked for more items than exist,
    # which can happen if deduplication leaves fewer than N - 1 subsets.
    max_genome_len = min(N - 1, len(all_lists))

    population = []
    for _ in range(N * 20):
        genome = random.sample(all_lists, random.randint(0, max_genome_len))
        population.append(Individual(genome, w(genome)))
    return population, all_lists


In [79]:
def let_it_grow(N, INITIAL_POPULATION, all_lists):
    """Evolve a population for up to NUM_GENERATIONS and report the best genome.

    Works on a deep copy of ``INITIAL_POPULATION``. In each generation
    OFFSPRING_SIZE children are created, by ``mutation`` with probability 0.3
    and by ``cross_over`` otherwise, with parents chosen by ``tournament``.
    The population plus the children is then sorted by decreasing fitness
    and cut down to its best ``2 * N`` individuals.

    The search stops as soon as an individual with fitness ``-N`` exists.
    That is the best possible score if the goal has ``N`` elements, since a
    cover cannot contain fewer elements than the goal. The check is repeated
    after the final generation, so an optimum produced by the last batch of
    children is still reported as such.

    Args:
        N: problem size.
        INITIAL_POPULATION: list of ``Individual`` tuples; left untouched.
        all_lists: every subset available to the problem.

    Returns:
        The individual that was printed: the optimum if one was found,
        otherwise the fittest individual of the final population (None if
        that population is empty).
    """
    def first_optimum(individuals):
        return next((ind for ind in individuals if ind.fitness == -N), None)

    population = copy.deepcopy(INITIAL_POPULATION)

    for _generation in range(NUM_GENERATIONS):
        optimum = first_optimum(population)
        if optimum is not None:
            break

        offspring = []
        for _ in range(OFFSPRING_SIZE):
            if random.random() < 0.3:
                parent = tournament(population)
                child = mutation(parent.genome, all_lists)
            else:
                first_parent = tournament(population)
                second_parent = tournament(population)
                child = cross_over(first_parent.genome, second_parent.genome, all_lists)
            offspring.append(Individual(child, w(child)))

        population = sorted(
            population + offspring, key=lambda ind: ind.fitness, reverse=True
        )[:2 * N]
    else:
        # The loop ran out of generations; the children of the last one have
        # not been checked yet.
        optimum = first_optimum(population)

    if optimum is not None:
        print(f"OPTIMUM solution: {optimum.genome}")
        print(f"Total cost: {-optimum.fitness}")
        return optimum

    best = max(population, key=lambda ind: ind.fitness, default=None)
    if best is not None:
        print(f"solution: {best.genome}")
        print(f"total cost: {-best.fitness}")
    return best


In [80]:
# Solve one instance per problem size. N and GOAL are assigned at module level
# on purpose: cross_over() reads N and goal_test() reads GOAL as globals.
for N in (5, 10, 20, 100):
    GOAL = {*range(N)}
    INITIAL_POPULATION, all_lists = generate_population(N)

    print(f"Searching solution for N = {N}")
    let_it_grow(N, INITIAL_POPULATION, all_lists)
    print("------------------------------------------------------------")

Searching solution for N = 5
OPTIMUM solution: [(1, 3), (2, 4), (0,)]
Total cost: 5
------------------------------------------------------------
Searching solution for N = 10
solution: [(1, 6), (2, 7, 8), (3, 9), (0, 4), (5, 6)]
total cost: 11
------------------------------------------------------------
Searching solution for N = 20
solution: [(4, 7, 11, 12, 15, 16, 18), (0, 1, 2, 3, 6, 13, 17, 18), (5, 7, 8, 13, 14), (6, 9, 16, 19), (1, 7, 10, 17)]
total cost: 28
------------------------------------------------------------
Searching solution for N = 100
solution: [(16, 23, 26, 30, 33, 34, 35, 43, 46, 47, 51, 53, 54, 57, 60, 63, 70, 73, 78, 87, 89, 96), (2, 3, 5, 6, 8, 10, 13, 14, 16, 19, 20, 21, 25, 27, 28, 30, 31, 32, 39, 44, 47, 53, 59, 60, 69, 72, 77, 79, 83, 85, 89, 91, 95, 96, 98, 99), (2, 5, 9, 16, 29, 40, 45, 48, 50, 51, 53, 54, 55, 58, 64, 68, 71, 73, 81, 85, 90, 94), (0, 3, 5, 15, 16, 23, 24, 26, 29, 37, 42, 44, 52, 53, 59, 63, 66, 71, 72, 78, 80, 81, 90, 93, 94), (1, 8, 10, 