In [14]:
import torch
import torch.nn.functional as F
import torch.utils.data as Data
from thop import profile

import os
import matplotlib.pyplot as plt
import random
import copy
import numpy as np
import math

from MadeData import MadeData
from DNA import DNA
from StructMutation import StructMutation

# import global_var
from global_var import DNA_cnt


In [15]:
class Model(torch.nn.Module):
    def __init__(self, DNA, parent_model=None):
        super(Model, self).__init__()
        self.dna = DNA
        self.layer_vertex = torch.nn.ModuleList()
        # print('init vertex', end='')
        for i, vertex in enumerate(DNA.vertices):
            # print('v{} '.format(i), end='')
            # 默认第一层和最后一层 vertex 非 hidden 层
            if vertex.type == 'bn_relu':
                self.layer_vertex.append(
                    torch.nn.Sequential(torch.nn.BatchNorm2d(vertex.input_channel),
                                        torch.nn.ReLU(inplace=True)))
            elif vertex.type == 'Global Pooling':
                self.layer_vertex.append(
                    torch.nn.Sequential(
                        # torch.nn.AdaptiveAvgPool2d((1, 1)),
                        torch.nn.Linear(vertex.input_channel, DNA.output_size_channel)))
            else:
                self.layer_vertex.append(None)

        self.layer_edge = torch.nn.ModuleList()
        # print('\ninit edges', end='')
        for i, edge in enumerate(DNA.edges):
            # TODO: 默认padding补全
            # print('e{}:'.format(i), end='')
            if edge.type == 'conv':
                # print('{},{},{} |'.format(edge.filter_half_height, edge.filter_half_width,edge.stride_scale),end=' ')
                temp = torch.nn.Conv2d(edge.input_channel,
                                       edge.output_channel,
                                       kernel_size=(edge.filter_half_height * 2 + 1,
                                                    edge.filter_half_width * 2 + 1),
                                       stride=pow(2, edge.stride_scale),
                                       padding=(edge.filter_half_height, edge.filter_half_width))
                if edge.model_id != -1 or parent_model == None:
                    temp.weight = parent_model.layer_edge[i].weight
                self.layer_edge.append(temp)
            else:
                # print(end=' |')
                self.layer_edge.append(None)
        # print('')
        self.batch_size = Evolution_pop.BATCH_SIZE

    def forward(self, input):
        '''
        配置每层的 输入、输出、激活函数
        '''
        block_h = input.shape[0]
        x = {
            0: input,
        }
        for index, layer_vert in enumerate(self.layer_vertex[1:], start=1):
            length = len(x)

            a = torch.empty(block_h, 0, 0, 0)
            for j, edg in enumerate(self.dna.vertices[index].edges_in):
                ind_edg = self.dna.edges.index(edg)
                ind_x = self.dna.vertices.index(edg.from_vertex)
                t = x[ind_x]
                if edg.type == 'conv':
                    t = self.layer_edge[ind_edg](x[ind_x])
                if j == 0:
                    a = torch.empty(block_h, 0, t.shape[2], t.shape[3])
                a = torch.cat((a, t), dim=1)

            if self.dna.vertices[index].type == 'linear':
                x[index] = a
            elif self.dna.vertices[index].type == 'bn_relu':
                x[index] = layer_vert(a)
            elif self.dna.vertices[index].type == 'Global Pooling':
                temp = torch.nn.AdaptiveAvgPool2d((1, 1))
                a = temp(a)
                a = torch.squeeze(a, 3)
                a = torch.squeeze(a, 2)
                x[index] = layer_vert(a)

        return x[len(x) - 1]


class Evolution_pop:
    _population_size_setpoint = 10
    _evolve_time = 100
    fitness_pool = []

    EPOCH = 3  # 训练整批数据多少次
    BATCH_SIZE = 50
    N_CLASSES = 10

    # LR = 0.001          # 学习率

    def __init__(self, data, pop_max=10, evolve_time=100):
        '''
        初始化DNA: 一层hidden(节点数不同); 都为linear
        接收传入的训练数据 data
        初始化 Mutation 类
        '''
        self.population = []
        self.model_stack = {}

        for i in range(self._population_size_setpoint):
            dna_iter = DNA()
            self.population.append(dna_iter)
            dna_iter.calculate_flow()
            self.model_stack[dna_iter.dna_cnt] = Model(dna_iter)

            global DNA_cnt
            DNA_cnt = self._population_size_setpoint

        self.data = data
        self.struct_mutation = StructMutation()

        self._population_size_setpoint = pop_max
        self._evolve_time = evolve_time

        self.fitness_dir = {}

    def decode(self):
        '''
         对当前population队列中的每个未训练过的个体进行训练 
         https://www.cnblogs.com/denny402/p/7520063.html
        '''
        for dna in self.population:
            if dna.fitness != -1.0:
                continue
            # TODO: 新训练的个体将fitness加入fitness_pool
            # dna.calculate_flow()
            # net = Model(dna)

            net = self.model_stack[dna.dna_cnt]
            print("[decode].[", dna.dna_cnt, "]", net)

            optimizer = torch.optim.Adam(net.parameters(), lr=dna.learning_rate)
            # the target label is not one-hotted
            loss_func = torch.nn.CrossEntropyLoss()

            train_loader, testloader = self.data.getData()
            # print("[Evolution_pop].[decode]->test_x: ", test_x.shape)
            accuracy = 0
            # training and testing
            for epoch in range(self.EPOCH):
                step = 0
                # TODO: 用movan的enumerate会报错，why?
                max_tep = int(60000 / train_loader.batch_size)

                train_acc = .0
                len_y = 0
                for step, (b_x, b_y) in enumerate(train_loader):
                    # print("[b_x, b_y].shape: ", b_x.shape, b_y.shape)
                    # 分配 batch data, normalize x when iterate train_loader
                    output = net(b_x)  # cnn output
                    idy = b_y.view(-1, 1)
                    # b_y = torch.zeros(self.BATCH_SIZE, 10).scatter_(1, idy, 1).long()

                    loss = loss_func(output, b_y)  # cross entropy loss
                    # clear gradients for this training step
                    optimizer.zero_grad()
                    loss.backward()  # backpropagation, compute gradients
                    optimizer.step()  # apply gradients

                    if step % 50 == 0:
                        pred = net(b_x)

                        print("\r" + 'Epoch: ' + str(epoch) + ' step: ' + str(step) + '[' +
                              ">>>" * int(step / 50) + ']',
                              end=' ')
                        # print('loss: %.4f' % loss.data.numpy(), '| accuracy: %.4f' % accuracy, end=' ')
                        print('loss: %.4f' % loss.data.numpy(), end=' ')
                print('')

            self.model_stack[dna.dna_cnt] = net
            # evaluation--------------------------------
            accuracy = self.Accuracy(net, testloader)
            input = torch.randn(self.BATCH_SIZE, dna.input_size_channel, dna.input_size_height,
                                dna.input_size_width)
            flops, params = profile(net, inputs=(input, ))
            print('----- Accuracy: {:.6f} Flops: {:.6f}-----'.format(accuracy, flops))
            # dna.fitness = eval_acc / len_y
            dna.fitness = accuracy
            self.fitness_dir[dna.dna_cnt] = accuracy
            print('')

    def Accuracy(self, net, testloader):
        ''' https://blog.csdn.net/Arctic_Beacon/article/details/85068188 '''
        classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

        class_correct = list(0. for i in range(self.N_CLASSES))
        class_total = list(0. for i in range(self.N_CLASSES))
        with torch.no_grad():
            for data in testloader:
                images, labels = data
                outputs = net(images)
                _, predicted = torch.max(outputs, 1)
                c = (predicted == labels).squeeze()
                for i in range(self.BATCH_SIZE):
                    label = labels[i]
                    class_correct[label] += c[i].item()
                    class_total[label] += 1

        # for i in range(self.N_CLASSES):
        #     print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))
        return sum(class_correct) / sum(class_total)

    def choose_varition_dna(self):
        '''
        每次挑选两个体,取fitness,判断要kill还是reproduce
        '''
        while self._evolve_time > 0:
            self._evolve_time -= 1
            self.decode()
            # 每次挑两个个体并提取出训练成绩fitness
            individual_pair = random.sample(list(enumerate(self.population)), 2)
            # TODO: 话说他这样取出来如果删掉的话真的能保证吗
            individual_pair.sort(key=lambda i: i[1].fitness, reverse=True)
            # better_individual = individual_pair[0]
            # worse_individual = individual_pair[1]
            # print("Choice: ",self._evolve_time,end=' ')
            # print("better: ",better_individual[0],'->',better_individual[1].fitness, end=' ')
            # print("worse: ", worse_individual[0],'->', worse_individual[1].fitness, end=' ')
            better_individual = individual_pair[0][0]
            worse_individual = individual_pair[1][0]
            individual_pair = []
            # (population过大->kill不好的)，反之(population过小->reproduce好的)
            if len(self.population) >= self._population_size_setpoint:
                print("--kill worse", worse_individual)
                self._kill_individual(worse_individual)
            elif len(self.population) < self._population_size_setpoint:
                print("--reproduce better", better_individual)
                self._reproduce_and_train_individual(better_individual)
        self.population.sort(key=lambda i: i.fitness, reverse=True)
        print(self.population[0].fitness)
        self.population[0].calculate_flow()
        # self.pop_show()

    def _kill_individual(self, index):
        ''' kill by the index of population '''
        # self._print_population()
        if self.population[index].dna_cnt in self.model_stack:
            del self.model_stack[self.population[index].dna_cnt]
        del self.population[index]

        # debug
        # self._print_population()

    def _reproduce_and_train_individual(self, index):
        ''' 
        inherit the parent, mutate, join the population 
        为了节省时间实际上有 Weight Inheritance
        '''
        # self._print_population()

        # inherit the parent (attention the dna_cnt)
        son = self.inherit_DNA(self.population[index])

        self.struct_mutation.mutate(son)
        son.calculate_flow()
        net = Model(son, self.model_stack[self.population[index].dna_cnt])

        self.model_stack[son.dna_cnt] = net
        self.population.append(son)
        # debug
        # self._print_population()

    def inherit_DNA(self, dna):
        ''' inderit from parent: reset dna_cnt, fitness '''
        son = copy.deepcopy(dna)
        global DNA_cnt
        son.dna_cnt = DNA_cnt
        DNA_cnt += 1
        son.fitness = -1
        return son

    def _print_population(self):
        print("pop sum: ", len(self.population), '|', end=' ')
        index = 0
        for i in self.population:
            print('(', index, '->', i.dna_cnt, ')', end=' ')
            index += 1
        print('')

    def pop_show(self):
        ''' 画出种群变化分布图 '''
        best_individual = self.population[0].dna_cnt
        live_individual = []
        for i in self.population:
            live_individual.append(i.fitness)

        global DNA_cnt
        show_x = []
        show_y = []
        show_color = []
        for i in range(DNA_cnt + 1):
            if i in self.fitness_dir:
                show_x.append(i)
                show_y.append(self.fitness_dir[i])
                if i in live_individual:
                    if i == self.population[0].dna_cnt:
                        show_color.append('red')
                    else:
                        show_color.append('blue')
                else:
                    show_color.append('gray')
        plt.scatter(show_x, show_y, c=show_color, marker='.')
        plt.show()


In [18]:
if __name__ == "__main__":
    data = MadeData()
    # 数据集选择
    train_loader, testloader = data.CIFR10()

    # test = Evolution_pop(train_loader, test_x, test_y)
    test = Evolution_pop(data, pop_max=10, evolve_time=100)
    test.choose_varition_dna()
    test.pop_show()

    print()

vertex [ 1 ].0 , 0.i_s0,N,N 
vertex [ 2 ].0 , 1.i_s0,N,N 
[calculate_flow] finish
vertex [ 1 ].0 , 0.i_s0,N,N 
vertex [ 2 ].0 , 1.i_s0,N,N 
[calculate_flow] finish
vertex [ 1 ].0 , 0.i_s0,N,N 
vertex [ 2 ].0 , 1.i_s0,N,N 
[calculate_flow] finish
vertex [ 1 ].0 , 0.i_s0,N,N 
vertex [ 2 ].0 , 1.i_s0,N,N 
[calculate_flow] finish
vertex [ 1 ].0 , 0.i_s0,N,N 
vertex [ 2 ].0 , 1.i_s0,N,N 
[calculate_flow] finish
vertex [ 1 ].0 , 0.i_s0,N,N 
vertex [ 2 ].0 , 1.i_s0,N,N 
[calculate_flow] finish
vertex [ 1 ].0 , 0.i_s0,N,N 
vertex [ 2 ].0 , 1.i_s0,N,N 
[calculate_flow] finish
vertex [ 1 ].0 , 0.i_s0,N,N 
vertex [ 2 ].0 , 1.i_s0,N,N 
[calculate_flow] finish
vertex [ 1 ].0 , 0.i_s0,N,N 
vertex [ 2 ].0 , 1.i_s0,N,N 
[calculate_flow] finish
vertex [ 1 ].0 , 0.i_s0,N,N 
vertex [ 2 ].0 , 1.i_s0,N,N 
[calculate_flow] finish
DNA [ 7 ]销毁->fitness 0.26
DNA [ 11 ]销毁->fitness 0.34
DNA [ 13 ]销毁->fitness 0.38
DNA [ 15 ]销毁->fitness 0.392
DNA [ 16 ]销毁->fitness 0.38
DNA [ 17 ]销毁->fitness -1
[decode].[ 14 ] Mode

)
Epoch: 0 step: 950[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>]ss: 2.1162  loss: 2.0176  loss: 2.0431  loss: 2.0549  loss: 2.2135  loss: 2.1740  loss: 2.0007  loss: 2.0868  loss: 1.9538  loss: 1.9207  loss: 2.1022  loss: 2.0712  loss: 2.1217  loss: 2.2292  loss: 2.1220  loss: 1.9585 
Epoch: 1 step: 950[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>]ss: 2.0424  loss: 2.0535  loss: 2.0863  loss: 2.1314  loss: 2.0375  loss: 2.0285  loss: 2.0986  loss: 2.0216  loss: 2.0764  loss: 1.9615  loss: 2.0921  loss: 2.0169  loss: 2.1286  loss: 2.0972  loss: 2.1571  loss: 1.9435 
Epoch: 2 step: 950[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>]ss: 2.0535  loss: 1.9760  loss: 2.2977  loss: 2.0447  loss: 2.2558  loss: 1.9403  loss: 2.0696  loss: 2.1295  loss: 2.1873  loss: 2.2811  loss: 2.0478  loss: 2.1239  loss: 2.0360  loss: 2.2074  loss: 2.1438  loss: 1.9728 
----- Accuracy: 0.260000 Flops: 3000.000000-----

[decode].[ 21 ] Model(
  (layer_vertex): ModuleList

Epoch: 0 step: 950[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>]ss: 2.0548  loss: 1.7313  loss: 1.9530  loss: 1.6978  loss: 1.7217  loss: 1.6980  loss: 1.7402  loss: 1.8395  loss: 1.8626  loss: 1.7452  loss: 1.8542  loss: 1.7704  loss: 1.7378  loss: 1.6946  loss: 1.6442  loss: 1.8526 
Epoch: 1 step: 950[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>]ss: 1.6671  loss: 1.9676  loss: 1.5197  loss: 1.7436  loss: 1.6865  loss: 1.8050  loss: 1.6949  loss: 1.6324  loss: 1.8865  loss: 1.7494  loss: 1.6486  loss: 1.7720  loss: 1.7702  loss: 1.6687  loss: 1.8987  loss: 1.6086 
Epoch: 2 step: 950[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>]ss: 1.5841  loss: 1.8754  loss: 1.9675  loss: 1.6141  loss: 1.7802  loss: 1.7503  loss: 1.6384  loss: 1.6249  loss: 1.7904  loss: 1.7774  loss: 1.6263  loss: 1.6203  loss: 2.0349  loss: 1.7367  loss: 1.6570  loss: 1.8937 
----- Accuracy: 0.344000 Flops: 16907000.000000-----

--kill worse 6
DNA [ 21 ]销毁->fitness 0.212
--repr

Epoch: 2 step: 950[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>]ss: 1.6997  loss: 1.6242  loss: 1.5106  loss: 1.5915  loss: 1.8625  loss: 1.7436  loss: 1.7711  loss: 1.7147  loss: 1.7058  loss: 1.8006  loss: 1.8261  loss: 1.7833  loss: 1.6579  loss: 1.5881  loss: 1.5231  loss: 1.7482 
----- Accuracy: 0.376000 Flops: 19469000.000000-----

--kill worse 4
DNA [ 23 ]销毁->fitness 0.264
--reproduce better 5
vertex [ 1 ].0 , 0.i_s0,N,N 
vertex [ 2 ].0 , 1.c_s0,1,1 
vertex [ 3 ].0 , 2.i_s0,N,N 
[calculate_flow] finish
[decode].[ 18 ] Model(
  (layer_vertex): ModuleList(
    (0): None
    (1): None
    (2): None
    (3): Sequential(
      (0): Linear(in_features=10, out_features=10, bias=True)
    )
  )
  (layer_edge): ModuleList(
    (0): None
    (1): None
    (2): Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)
Epoch: 0 step: 950[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>]ss: 2.2726  loss: 2.1902  loss: 2.3269  loss: 2.3860  loss: 2.2160  los

Epoch: 0 step: 950[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>]ss: 2.0416  loss: 2.1362  loss: 2.1820  loss: 2.0616  loss: 2.2326  loss: 1.9326  loss: 1.9712  loss: 2.1452  loss: 2.1729  loss: 2.1606  loss: 2.1061  loss: 2.2984  loss: 2.0570  loss: 2.0363  loss: 1.8804  loss: 2.0098 
Epoch: 1 step: 950[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>]ss: 1.9629  loss: 1.9940  loss: 1.9387  loss: 1.9610  loss: 1.8476  loss: 1.9594  loss: 2.1090  loss: 2.0286  loss: 2.0706  loss: 1.9666  loss: 2.0598  loss: 2.0824  loss: 2.1002  loss: 2.1584  loss: 2.1015  loss: 2.2557 
Epoch: 2 step: 950[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>]ss: 2.0002  loss: 2.1382  loss: 1.9054  loss: 1.9210  loss: 2.1437  loss: 2.1774  loss: 1.7973  loss: 1.9094  loss: 1.9881  loss: 2.0600  loss: 2.2481  loss: 2.0109  loss: 1.9522  loss: 1.9311  loss: 1.7405  loss: 1.8199 
----- Accuracy: 0.256000 Flops: 11476800.000000-----

--kill worse 1


KeyError: 18

In [19]:
%whos

Variable         Type             Data/Info
-------------------------------------------
DNA              type             <class 'DNA.DNA'>
DNA_cnt          int              23
Data             module           <module 'torch.utils.data<...>tils\\data\\__init__.py'>
Evolution_pop    type             <class '__main__.Evolution_pop'>
F                module           <module 'torch.nn.functio<...>orch\\nn\\functional.py'>
MadeData         type             <class 'MadeData.MadeData'>
Model            type             <class '__main__.Model'>
StructMutation   type             <class 'StructMutation.StructMutation'>
copy             module           <module 'copy' from 'D:\\<...>anaconda3\\lib\\copy.py'>
data             MadeData         <MadeData.MadeData object at 0x000001F70CB15908>
math             module           <module 'math' (built-in)>
np               module           <module 'numpy' from 'D:\<...>ges\\numpy\\__init__.py'>
os               module           <module 'os' from 'D:\\