diff --git a/search.py b/search.py index 04d5b6c51..dffd5c6a1 100644 --- a/search.py +++ b/search.py @@ -570,30 +570,37 @@ def LRTA_cost(self, s, a, s1, H): # Genetic Algorithm -def genetic_search(problem, fitness_fn, ngen=1000, pmut=0.1, n=20): +def genetic_search(problem, fitness_fn,gene_bound,ngen=1000, optimal_value=10000000,pmut=0.1, n=20,initial_population=None): """ Call genetic_algorithm on the appropriate parts of a problem. This requires the problem to have states that can mate and mutate, - plus a value method that scores states.""" - s = problem.initial_state - states = [problem.result(s, a) for a in problem.actions(s)] - random.shuffle(states) - return genetic_algorithm(states[:n], problem.value, ngen, pmut) - - -def genetic_algorithm(population, fitness_fn, ngen=1000, pmut=0.1): + plus a value method that scores states.These states are passed as initial + population to the search""" + if(initial_population == None) : + raise Exception("Initial population not given in genetic search") + else : + random.shuffle(initial_population) + newfitness_fn = lambda inidividual : fitness_fn(inidividual.genes) + population = [GAState(initial_population[i]) for i in range(len(initial_population))] + best_individual = genetic_algorithm(population[:n],newfitness_fn,gene_bound,optimal_value, ngen, pmut) + return best_individual.genes + +def genetic_algorithm(population, fitness_fn, gene_bound,optimal_value=10000000, ngen=1000, pmut=0.1): "[Figure 4.8]" - for i in range(ngen): + for i in range(int(ngen)): new_population = [] - for i in range(len(population)): + for j in range(len(population)): fitnesses = map(fitness_fn, population) p1, p2 = weighted_sample_with_replacement(population, fitnesses, 2) child = p1.mate(p2) if random.uniform(0, 1) < pmut: - child.mutate() + child.mutate(gene_bound) new_population.append(child) population = new_population - return argmax(population, key=fitness_fn) + current_bestindividual = argmax(population, key=fitness_fn) + if(fitness_fn(current_bestindividual) >= optimal_value) : + return current_bestindividual + return current_bestindividual class GAState: @@ -608,9 +615,10 @@ def mate(self, other): c = random.randrange(len(self.genes)) return self.__class__(self.genes[:c] + other.genes[c:]) - def mutate(self): - "Change a few of my genes." - raise NotImplementedError + def mutate(self,gene_bound) : + "Change one of my genes." + index = random.choice(range(len(self.genes))) + self.genes[index] = random.choice(range(gene_bound[0],gene_bound[1])) # _____________________________________________________________________________ # The remainder of this file implements examples for the search algorithms. @@ -884,6 +892,16 @@ def goal_test(self, state): return not any(self.conflicted(state, state[col], col) for col in range(len(state))) + def value(self,state): + """Returns value corresponding to a state where value is defined as + the number of pairs of non-attacking queens""" + attacking_sum = 0 + for c1 in range(len(state)): + if not state[c1] == None : + for c2 in range(c1+1,len(state)): + if not state[c2] == None : + attacking_sum += self.conflict(state[c1],c1,state[c2],c2) + return (self.N*(self.N - 1))/2 - attacking_sum # ______________________________________________________________________________ # Inverse Boggle: Search for a high-scoring Boggle board. A good domain for # iterative-repair and related search techniques, as suggested by Justin Boyan. diff --git a/tests/test_search.py b/tests/test_search.py index 11d522e94..00c3d5c5b 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -86,6 +86,18 @@ def test_LRTAStarAgent(): my_agent = LRTAStarAgent(LRTA_problem) assert my_agent('State_5') is None +def test_genetic_search(): + N = 5 + nqueens_problem = NQueensProblem(N) + initial_population = [] + gene_bound = (0,N) + for i in range(N * 20) : + population = [random.choice(range(N)) for gene in range(N)] + initial_population.append(population) + result = genetic_search(nqueens_problem,nqueens_problem.value,gene_bound,(N*(N-1))/2,1000,0.1,N * 20,initial_population) + for col1 in range(len(result)) : + for col2 in range(col1+1,len(result)) : + assert nqueens_problem.conflict(result[col1],col1,result[col2],col2) == False # TODO: for .ipynb: """