### Crossover算子
值得一提的是，DEAP中GP默认实现的Crossover算子不考虑根节点。因此，如果要按照GP的原始论文实现，需要稍作修改。

In [1]:
import time

import numpy as np
from deap import base, creator, tools, gp


# Symbolic regression
def evalSymbReg(individual, pset):
    """Score a GP individual by its mean squared error against ``x**2``.

    Args:
        individual: GP expression tree to evaluate.
        pset: Primitive set used to compile ``individual`` into a callable.

    Returns:
        A one-element tuple ``(mse,)``, the fitness shape DEAP expects.
    """
    # 100 evenly spaced sample points on [-10, 10] and the target curve.
    sample_points = np.linspace(-10, 10, 100)
    target = sample_points ** 2

    # Turn the expression tree into a callable and evaluate it on all points
    # at once.
    candidate = gp.compile(expr=individual, pset=pset)
    residual = candidate(sample_points) - target

    return (np.mean(residual ** 2),)

# Create the fitness and individual classes.
# FitnessMin: single-objective fitness; the negative weight means the
# objective is minimised.
creator.create("FitnessMin", base.Fitness, weights=(-1.0,))
# Individual: a GP primitive tree that carries a FitnessMin `fitness` attribute.
# Must be created after FitnessMin because it references creator.FitnessMin.
creator.create("Individual", gp.PrimitiveTree, fitness=creator.FitnessMin)

具体来说，需要修改交叉点的取值范围：DEAP 默认从索引 1 开始选取交叉点（跳过根节点），这里改为从索引 0 开始，使根节点也可以被选为交叉点。

In [2]:
from collections import defaultdict

# Return type compared against the root's `ret` to detect a loosely typed
# (non-STGP) tree, in which any node can replace any other node.
__type__ = object


def cxOnePoint(ind1, ind2):
    """One-point subtree crossover in which the root may be a crossover point.

    A node is picked at random in each individual and the subtrees rooted at
    those nodes are exchanged in place. Every node, including the root at
    index 0, is a candidate, so a whole tree can be swapped.

    Args:
        ind1: First tree participating in the crossover.
        ind2: Second tree participating in the crossover.

    Returns:
        The tuple ``(ind1, ind2)``; both trees are modified in place. In the
        strongly typed case they are returned unchanged when they share no
        node return type.
    """
    # Map each return type to the indices of the nodes producing it.
    types1 = defaultdict(list)
    types2 = defaultdict(list)
    if ind1.root.ret == __type__:
        # Loosely typed trees: all nodes, root included, are interchangeable.
        types1[__type__] = list(range(len(ind1)))
        types2[__type__] = list(range(len(ind2)))
        common_types = [__type__]
    else:
        # Strongly typed trees: indices must start at 0 so that each index
        # refers to the node it was recorded for and the root is eligible.
        for idx, node in enumerate(ind1):
            types1[node.ret].append(idx)
        for idx, node in enumerate(ind2):
            types2[node.ret].append(idx)
        common_types = set(types1).intersection(types2)

    if common_types:
        # Only subtrees with the same return type can be exchanged.
        type_ = random.choice(list(common_types))

        index1 = random.choice(types1[type_])
        index2 = random.choice(types2[type_])

        slice1 = ind1.searchSubtree(index1)
        slice2 = ind2.searchSubtree(index2)
        ind1[slice1], ind2[slice2] = ind2[slice2], ind1[slice1]

    return ind1, ind2

In [3]:
import random

# Define the function set and the terminal set.
# One input argument (renamed to 'x' below), element-wise numpy arithmetic as
# primitives, and a random integer constant drawn from {-1, 0, 1}.
pset = gp.PrimitiveSet("MAIN", arity=1)
pset.addPrimitive(np.add, 2)
pset.addPrimitive(np.subtract, 2)
pset.addPrimitive(np.multiply, 2)
pset.addPrimitive(np.negative, 1)
# NOTE(review): newer DEAP releases reportedly warn that a lambda generator
# cannot be pickled; confirm against the installed version.
pset.addEphemeralConstant("rand101", lambda: random.randint(-1, 1))
pset.renameArguments(ARG0='x')

# Define the genetic programming operators.
toolbox = base.Toolbox()
# Initial trees are generated with depth between 1 and 2.
toolbox.register("expr", gp.genHalfAndHalf, pset=pset, min_=1, max_=2)
toolbox.register("individual", tools.initIterate, creator.Individual, toolbox.expr)
toolbox.register("population", tools.initRepeat, list, toolbox.individual)
toolbox.register("compile", gp.compile, pset=pset)
toolbox.register("evaluate", evalSymbReg, pset=pset)
toolbox.register("select", tools.selTournament, tournsize=3)
# Use the crossover defined in this notebook rather than gp.cxOnePoint.
toolbox.register("mate", cxOnePoint)
toolbox.register("mutate", gp.mutUniform, expr=toolbox.expr, pset=pset)



In [4]:
import numpy
from deap import algorithms

# Define the statistics to collect: fitness values and tree size (node count),
# each summarised by mean, standard deviation, minimum and maximum.
stats_fit = tools.Statistics(lambda ind: ind.fitness.values)
stats_size = tools.Statistics(len)
mstats = tools.MultiStatistics(fitness=stats_fit, size=stats_size)
mstats.register("avg", numpy.mean)
mstats.register("std", numpy.std)
mstats.register("min", numpy.min)
mstats.register("max", numpy.max)

# Run the default evolutionary algorithm (eaSimple) and time it.
start=time.time()
population = toolbox.population(n=300)
# Keep only the single best individual seen during the run.
hof = tools.HallOfFame(1)
pop, log  = algorithms.eaSimple(population=population,
                           toolbox=toolbox, cxpb=0.5, mutpb=0.2, ngen=50, stats=mstats, halloffame=hof, verbose=True)
end=time.time()
# Report the elapsed wall-clock time in seconds and the best expression found.
print('time:',end-start)
print(str(hof[0]))

   	      	                    fitness                    	                      size                     
   	      	-----------------------------------------------	-----------------------------------------------
gen	nevals	avg    	gen	max   	min	nevals	std    	avg    	gen	max	min	nevals	std    
0  	300   	3301.28	0  	153712	0  	300   	12451.3	4.02667	0  	7  	2  	300   	1.67111
1  	168   	3939.35	1  	159956	0  	168   	17617.1	4.56333	1  	11 	1  	168   	2.21193
2  	178   	159799 	2  	4.69134e+07	0  	178   	2.70387e+06	4.86667	2  	16 	1  	178   	2.7207 
3  	182   	5691.28	3  	159821     	0  	182   	24620.6    	5.23667	3  	15 	1  	182   	2.74845
4  	164   	5353.98	4  	153712     	0  	164   	24489.9    	5.54667	4  	15 	1  	164   	2.99352
5  	170   	9646.87	5  	159956     	0  	170   	35453.2    	5.38667	5  	22 	1  	170   	3.38583
6  	191   	52366.6	6  	1.18774e+07	0  	191   	685076     	4.98667	6  	18 	1  	191   	2.95181
7  	190   	14822.3	7  	608604     	0  	190   	53512.8    	5.32   	7  