## Evolution

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from typing import Any, Callable, List, Union, Tuple
from numbers import Integral, Real
from copy import deepcopy
import torch

import numpy as np
from copy import deepcopy
import time

from tqdm import tqdm

import sys
sys.path.append('/root/surrogate')

# example code
# prescriptor
from Prescriptor import Prescriptor

# crossover
from crossover._crossovers import UniformCrossover, WeightedSumCrossover
from crossover._base import SkipCrossover, FunctionCrossover

# mutation
from mutation._base import BaseMutation, ChainMutation 
from mutation._add import AddNormalMutation, AddUniformMutation
from mutation._multiply import MultiplyNormalMutation, MultiplyUniformMutation
from mutation._special import FlipSignMutation

# selection
from selection._multi import MultiObjectiveSelection, LexsortSelection, ParetoSelection, ParetoLexsortSelection
from selection._single import SingleObjectiveSelection, TournamentSelection, RouletteSelection



In [2]:
class Evolution:
    """Genetic-algorithm driver that evolves the blocks of a prescriptor.

    Every module in ``prescriptor.layers`` is one individual of the population.
    Its trainable parameters, flattened and concatenated in
    ``named_parameters()`` order, form that individual's chromosome. All
    individuals are assumed to share the parameter layout of ``layers[0]``.
    Buffers (e.g. normalisation running statistics) are not part of the
    chromosome and are left untouched.

    Args:
        prescriptor: ``torch.nn.Module`` exposing a ``layers`` container.
        selection: provides ``select(fitness)``, ``elite_idx()`` and
            ``pick_parents(num_parents, num_offspring)``.
        crossover: callable turning gathered parents into offspring; also
            provides ``get_num_parents()``.
        mutation: callable applied to the offspring tensor.
    """

    def __init__(self, prescriptor, selection, crossover, mutation):
        self.prescriptor = prescriptor
        self.selection = selection
        self.crossover = crossover
        self.mutation = mutation

        # Number of individuals (one per entry of ``prescriptor.layers``),
        # i.e. the population size, not the length of a single chromosome.
        self.chromosome_size = len(self.prescriptor.layers)
        self.num_parents = self.crossover.get_num_parents()

        self.check_model_shape()

    def check_model_shape(self):
        """Record the shape and element count of each parameter of one individual.

        Populates ``shape_each_layer`` and ``num_each_layer`` from
        ``layers[0].named_parameters()``. This is the same ordering used by
        ``flatten_chromosomes`` and ``update_chromosomes``.
        """
        self.shape_each_layer = []
        self.num_each_layer = []
        # Shapes and element counts do not depend on the device, so there is
        # no need to move the model around here.
        for _, param in self.prescriptor.layers[0].named_parameters():
            self.shape_each_layer.append(list(param.size()))
            self.num_each_layer.append(param.numel())

    def update_chromosomes(self, chromosomes, device='cpu'):
        """Write a population of flat chromosomes back into the prescriptor.

        Args:
            chromosomes: 2-D tensor with (at least) one row per individual,
                laid out as produced by ``flatten_chromosomes``.
            device: device the whole prescriptor is moved to afterwards.
        """
        layers = self.prescriptor.layers.cpu()
        with torch.no_grad():
            for idx, layer in enumerate(layers):
                genes = chromosomes[idx]
                start = 0
                # Walk named_parameters() rather than state_dict(): state_dict
                # also contains buffers, which are not in the chromosome and
                # would misalign every slice after them.
                for p_idx, (_, param) in enumerate(layer.named_parameters()):
                    stop = start + self.num_each_layer[p_idx]
                    param.copy_(genes[start:stop].reshape(self.shape_each_layer[p_idx]))
                    start = stop

        self.prescriptor.to(device)

    def flatten_chromosomes(self):
        """Pack each individual's parameters into one row of a 2-D tensor.

        Side effect: the prescriptor is moved to the CPU. ``update_chromosomes``
        (or the caller) is responsible for moving it back.

        Returns:
            ``(chromosomes, device)``: a CPU tensor with one row per entry of
            ``prescriptor.layers``, and the device the prescriptor was on
            (taken from its first parameter) before the call.
        """
        device = next(self.prescriptor.parameters()).device
        self.prescriptor.cpu()
        with torch.no_grad():
            rows = [
                torch.cat([param.flatten() for _, param in layer.named_parameters()])
                for layer in self.prescriptor.layers
            ]
        return torch.stack(rows), device

    def evolve(self, fitness: torch.Tensor):
        """Replace the population with the elites plus mutated crossover offspring.

        Args:
            fitness: scores passed unchanged to ``selection.select``. The
                expected shape is defined by the selection operator.
        """
        chromosomes, device = self.flatten_chromosomes()
        try:
            self.selection.select(fitness)
            elite_idx = self.selection.elite_idx()
            elite_chromosomes = chromosomes[elite_idx]

            # The remaining slots of the population are filled with offspring.
            offspring_size = self.chromosome_size - len(elite_idx)
            parents_idx = self.selection.pick_parents(self.num_parents, offspring_size)
            parents = chromosomes[parents_idx]

            offspring = self.mutation(self.crossover(parents))

            self.update_chromosomes(torch.cat([elite_chromosomes, offspring]), device)
        finally:
            # flatten_chromosomes moved the model to the CPU; restore its
            # original device even if selection/crossover/mutation raise.
            self.prescriptor.to(device)

In [3]:
# --- Evolution hyper-parameters ---------------------------------------------
elite_size = 200    # passed to LexsortSelection as elite_num
parent_size = 200   # passed to LexsortSelection as parents_num
generation = 100    # number of generations run by the training loop
device = 'cpu'

# --- Prescriptor / population dimensions ------------------------------------
batch_size = 1300
input_dim = 41
pop_size = 1000     # passed to Prescriptor as num_blocks
hidden_dim = 64
output_dim = 32

prescriptor = Prescriptor(
    basic_block=None,
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    output_dim=output_dim,
    num_blocks=pop_size,
)
# Module.to / requires_grad_ return the module itself; parameters are frozen
# because they are changed only by the evolutionary operators.
prescriptor = prescriptor.to(device).requires_grad_(False)

# Lexicographic selector from selection._multi (multi-objective, despite
# being the only selector used here).
selection = LexsortSelection(elite_num=elite_size, parents_num=parent_size)

crossover = UniformCrossover()
mutation = ChainMutation([
    AddNormalMutation(mut_prob=0.2),
    MultiplyUniformMutation(mut_prob=0.2),
    FlipSignMutation(0.07),
])

evolution = Evolution(
    prescriptor=prescriptor,
    selection=selection,
    crossover=crossover,
    mutation=mutation,
)

In [4]:
# One fixed random input batch, drawn on the CPU and then moved to `device`.
x = torch.randn((batch_size, input_dim)).to(device)

In [5]:
# Run the evolution; keep a CPU copy of every generation's fitness.
fitness_list = []
for gen in range(generation):
    # NOTE(review): assumes the prescriptor returns a sequence of per-block
    # outputs that stack into (pop_size, batch_size, output_dim) — confirm.
    out = torch.concat(prescriptor(x)).reshape(pop_size, batch_size, -1)

    # Fitness per individual and per output column: |output| summed over the batch.
    fitness = torch.abs(out).sum(dim=1)

    # evolve() returns None; the name is kept for compatibility with later cells.
    parent_chrom = evolution.evolve(fitness.cpu())
    fitness_list.append(fitness.cpu())