In [1]:
import copy

import numpy as np
import cv2
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity
%matplotlib inline

In [2]:
# Load the target image that the population is evolved towards
TARGET_PATH = '../Database/edgeLogo.png'
edgeLogo = cv2.imread(TARGET_PATH)
# cv2.imread signals failure by returning None rather than raising; stop here
# with a clear message instead of an AttributeError on `.shape` further down.
if edgeLogo is None:
    raise FileNotFoundError(f"could not read target image: {TARGET_PATH}")
def ImageShow(windowsName, img):
    '''Show ``img`` in an OpenCV window titled ``windowsName``.

    Blocks until a key is pressed (waitKey(0)), then closes that window.
    '''
    title = windowsName
    cv2.imshow(title, img)
    cv2.waitKey(0)
    cv2.destroyWindow(title)
# Sanity check: shape of the loaded target image
print(tuple(edgeLogo.shape))
# Optional previews (disabled):
#   ImageShow("sample", edgeLogo)
#   plt.imshow(edgeLogo); plt.axis('off'); plt.show()

(800, 800, 3)


In [3]:
# Chromosome: a random bit matrix that encodes one polygon
class Chromosome:
    '''Random binary genome for a single polygon.

    geneLength: number of bits in each row (the chromosome length)
    geneSize:   number of rows
    ``gene`` holds the bits as an integer array of shape (geneSize, geneLength)
    whose entries are all 0 or 1.
    '''
    def __init__(self, geneLength, geneSize):
        self.geneSize = geneSize
        self.geneLength = geneLength
        bitShape = (geneSize, geneLength)
        self.gene = np.random.randint(0, 2, size=bitShape)

# Individual ("creation"): a list of chromosomes, each decoded into one filled polygon
class Creation:
    '''One candidate image built from ``cellCount`` filled convex polygons.

    cellCount:       number of polygons (one Chromosome each)
    geneLength:      bits per gene row; the colour row is read 8 bits per channel
    geneSize:        rows per chromosome -- ``geneSize - 1`` coordinate rows,
                     consumed as (x, y) pairs, followed by one colour row
    coorRange:       (low, high) interval that decoded coordinates are mapped into
    canvasSize:      shape of the rendered canvas, e.g. (800, 800, 3)
    backgroundColor: fill value of the blank canvas
    '''
    def __init__(self, cellCount, geneLength, geneSize, coorRange, canvasSize, backgroundColor):
        self.creationSet = []
        self.translated = []
        self.cellCount = cellCount
        self.geneLength = geneLength
        self.geneSize = geneSize
        self.coorRange = coorRange
        self.canvasSize = canvasSize
        self.backgroundColor = backgroundColor

        for _ in range(cellCount):
            self.creationSet.append(Chromosome(geneLength, geneSize))

    def _DecodeChromosome(self, chromosome):
        '''Decode one chromosome into ``(points, color)`` for cv2.fillConvexPoly.

        Column j of a row carries bit weight 2**j (least significant bit first).
        Each coordinate row is scaled from [0, 2**geneLength) onto
        [coorRange[0], coorRange[1]) and truncated to int; consecutive rows
        form (x, y) vertices. ``points`` is an int32 array of shape
        (1, n_vertices, 2); ``color`` is a tuple with one int per 8-bit chunk
        of the last (colour) row.
        '''
        bits = chromosome.gene
        low, high = self.coorRange
        rowValues = bits[:-1].dot(2 ** np.arange(self.geneLength))
        # shift by the lower bound so a non-zero coorRange[0] is honoured
        scaled = rowValues / (1 << self.geneLength) * (high - low) + low
        vertices = [(int(scaled[i]), int(scaled[i + 1]))
                    for i in range(0, self.geneSize - 1, 2)]
        # fillConvexPoly expects 32-bit integer points
        points = np.array([vertices], dtype=np.int32)

        byteWeights = 2 ** np.arange(8)
        colorRow = bits[-1]
        color = []
        for i in range(0, self.geneLength, 8):
            chunk = colorRow[i:i + 8]
            color.append(int(chunk.dot(byteWeights[:chunk.size])))
        return points, tuple(color)

    def Decoder(self):
        '''Render every chromosome onto a fresh canvas (decode genotype -> phenotype).

        The canvas is a uint8 array of shape ``canvasSize`` filled with
        ``backgroundColor``. The result is stored in ``self.canvas`` and returned.
        '''
        canvas = np.full(self.canvasSize, self.backgroundColor, dtype=np.uint8)
        for chromosome in self.creationSet:
            points, color = self._DecodeChromosome(chromosome)
            canvas = cv2.fillConvexPoly(canvas, points=points, color=color)
        self.canvas = canvas
        return canvas

    def GetFitness(self, realImg):
        '''SSIM between the rendered canvas and ``realImg`` (channels on axis 2).

        Side effect: re-renders and stores ``self.canvas``.
        '''
        return structural_similarity(self.Decoder(), realImg, channel_axis=2)

# Quick manual check (opens an OpenCV window):
# test = Creation(5, 24, 7, (0, 800), (800, 800, 3), 255)
# ImageShow("test", test.Decoder())

In [4]:
# Model setup

# Population and encoding parameters
SEED = 0                    # RNG seed so a re-run reproduces the same evolution
SIZE = 200                  # number of individuals in the population
CELLCOUNT = 5               # fit the Edge logo with 5 triangles
GENELENGTH = 24             # bits per gene row; the colour row is split into 8-bit channels
GENESIZE = 7                # 6 coordinate rows (3 vertices as x, y) + 1 colour row
COORSIZE = (0, 800)         # (low, high) range that decoded coordinates map into
CANVASSIZE = (800, 800, 3)
BACKGROUNDCOLOR = 255

EPOCHS = 50                 # maximum number of generations
MAXSIM = 0.99               # stop early once the best SSIM reaches this value

VariationRate = 0.03        # per-individual probability of mutation each generation
CrossRate = 0.8             # per-individual probability of crossover each generation

# The coordinate rows (GENESIZE - 1) are consumed as (x, y) pairs.
assert (GENESIZE - 1) % 2 == 0, "GENESIZE must be odd"
# The colour row is decoded 8 bits at a time.
assert GENELENGTH % 8 == 0, "GENELENGTH must be a multiple of 8"

np.random.seed(SEED)

# Initial population: SIZE independent random individuals
population = [
    Creation(CELLCOUNT, GENELENGTH, GENESIZE, COORSIZE, CANVASSIZE, BACKGROUNDCOLOR)
    for _ in range(SIZE)
]

In [5]:
def ChoosePopulation(populationSize, fitnessAll, chooseCount):
    '''Fitness-proportional (roulette-wheel) selection without replacement.

    populationSize: number of candidates; must equal len(fitnessAll)
    fitnessAll:     fitness score per candidate (list or array)
    chooseCount:    number of distinct indices to draw (<= populationSize)
    Returns an integer array of ``chooseCount`` distinct indices.

    Negative scores are clipped to 0 because they are not valid probabilities.
    A tiny floor keeps every candidate drawable, so sampling without
    replacement cannot fail with "fewer non-zero entries in p than size".
    '''
    weights = np.clip(np.asarray(fitnessAll, dtype=float), 0.0, None) + 1e-12
    return np.random.choice(np.arange(populationSize), p=weights / weights.sum(),
                            size=chooseCount, replace=False)

def Variation(creation):
    '''Mutate ``creation`` in place by inverting a contiguous run of bits.

    One chromosome and one of its rows are picked at random. The bits in
    columns [left, right) of that row are flipped, where ``left`` is drawn from
    the first half of the row and ``right`` from the second half, up to and
    including the row end, so the most significant bit can change too.
    Returns the same (mutated) object so calls can be chained.
    '''
    chromosome = creation.creationSet[np.random.randint(len(creation.creationSet))]
    rows, cols = chromosome.gene.shape
    # Column bounds come from the row length (bits per gene), not the number of
    # rows; drawing them from GENESIZE only ever reached the 6 low-order bits.
    half = max(cols // 2, 1)
    left = np.random.randint(half)
    right = np.random.randint(half, cols + 1)
    line = np.random.randint(rows)
    chromosome.gene[line, left:right] ^= 1
    return creation

def Cross(a, b):
    '''Crossover: copy a random run of bits from individual ``b`` into ``a``.

    One chromosome index and one row are picked at random. Columns
    [left, right) of that row in ``a`` are overwritten with ``b``'s bits, with
    ``left`` in the first half of the row and ``right`` in the second half
    (inclusive of the row end). ``a`` is modified in place and returned;
    ``b`` is left untouched. Assumes ``a`` and ``b`` share the same structure.
    '''
    switchGene = np.random.randint(len(a.creationSet))
    target = a.creationSet[switchGene].gene
    source = b.creationSet[switchGene].gene
    rows, cols = target.shape
    # Column bounds come from the row length, not the number of rows.
    half = max(cols // 2, 1)
    line = np.random.randint(rows)
    left = np.random.randint(half)
    right = np.random.randint(half, cols + 1)
    target[line, left:right] = source[line, left:right]
    return a

def GetFitnessAll(population, realImg=None):
    '''Return the fitness of every individual, in population order.

    population: iterable of objects with a ``GetFitness(realImg)`` method
    realImg:    target image; defaults to the global ``edgeLogo`` loaded above
    '''
    target = edgeLogo if realImg is None else realImg
    return [single.GetFitness(target) for single in population]

In [6]:
# Evolve the population. Each generation: record the best individual, breed
# offspring on copies (crossover with a random partner, then occasional
# mutation), and keep the fittest parents plus the fittest offspring.
trainTime = 0
maxSSIM = -1                           # best SSIM of the current generation
resCan = []                            # best canvas of each generation
resSIM = []                            # best SSIM of each generation
fitness = GetFitnessAll(population)    # also sets each individual's .canvas
while trainTime < EPOCHS and maxSSIM < MAXSIM:
    idx = int(np.argmax(fitness))
    maxSSIM = fitness[idx]
    resSIM.append(maxSSIM)
    resCan.append(population[idx].canvas)
    trainTime += 1

    # Cross/Variation modify their first argument in place, so breed on copies.
    # Otherwise parents and offspring are the same objects and the survivor
    # selection below can keep one individual twice.
    newPopulation = []
    for single in population:
        child = copy.copy(single)      # shallow is enough: Decoder replaces .canvas, never mutates it
        child.creationSet = [copy.deepcopy(chromosome) for chromosome in single.creationSet]
        if np.random.rand() < CrossRate:
            parent = population[np.random.randint(len(population))]
            for _ in range(3):
                child = Cross(child, parent)
        if np.random.rand() < VariationRate:
            for _ in range(3):
                child = Variation(child)
        newPopulation.append(child)
    newFitness = GetFitnessAll(newPopulation)

    # Survivor selection: the best parents plus the best offspring, keeping the
    # population at SIZE. Scores are deterministic, so survivors carry theirs
    # forward instead of being re-evaluated. The stable sort keeps the lowest
    # index among ties, like a repeated max()/index() scan.
    childCount = SIZE // 2
    parentCount = SIZE - childCount
    parentOrder = np.argsort(-np.asarray(fitness), kind='stable')[:parentCount]
    childOrder = np.argsort(-np.asarray(newFitness), kind='stable')[:childCount]
    population = [population[i] for i in parentOrder] + [newPopulation[i] for i in childOrder]
    fitness = [fitness[i] for i in parentOrder] + [newFitness[i] for i in childOrder]