In [1]:
!pip install deap
!pip install tensorboardX
!pip install thop

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting deap
  Downloading deap-1.3.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (139 kB)
[K     |████████████████████████████████| 139 kB 14.3 MB/s 
Installing collected packages: deap
Successfully installed deap-1.3.3
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tensorboardX
  Downloading tensorboardX-2.5.1-py2.py3-none-any.whl (125 kB)
[K     |████████████████████████████████| 125 kB 7.8 MB/s 
Installing collected packages: tensorboardX
Successfully installed tensorboardX-2.5.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting thop
  Downloading thop-0.1.1.post2209072238-py3-none-any.whl (15 kB)
Installing collected packages: thop
Successfully installed thop-0.1.1.post2209072238


# Import modules

In [2]:
# Mount Google Drive so that the project folder under MyDrive becomes reachable
# from this Colab runtime; force_remount re-attaches it if it is already mounted.
from google.colab import drive

_GDRIVE_MOUNT_POINT = '/content/gdrive/'
drive.mount(_GDRIVE_MOUNT_POINT, force_remount=True)

Mounted at /content/gdrive/


In [3]:
import sys

# Directory that presumably contains the project's own `metrics` package and
# `genetic_model` module, which are imported from the imports cell.
path_to_module = '/content/gdrive/MyDrive/IT402'
# Guard so that re-running this cell does not append duplicate sys.path entries.
if path_to_module not in sys.path:
    sys.path.append(path_to_module)

In [4]:
# Project folder on the mounted Drive. Datasets are read from <data_path>/dataset and
# experiment artefacts (checkpoints, pickles, CSVs, TensorBoard runs) go under output_path.
_project_root = "/content/gdrive/MyDrive/IT402/"
data_path = _project_root
output_path = _project_root + "output/"

In [5]:
from scipy.special import comb
from deap import base, creator, tools
from tensorboardX import SummaryWriter
from thop import profile
from os import path
from PIL import Image, ImageFilter
from tqdm import tqdm
from torch import nn
from torch.optim.optimizer import Optimizer
from torch.nn.utils.clip_grad import clip_grad_norm_
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import Subset
from torchvision.transforms import functional as TF
from torchvision.datasets.utils import list_files

from metrics.average_meter import AverageMeter
from metrics.calculate_metrics import calculate_metrics
from genetic_model import UnetBlock, check_active, count_param, Net

import torch.multiprocessing
import os
import sys
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import pickle
import numpy as np
import torch.multiprocessing
import matplotlib.pyplot as plt
import multiprocessing.pool
import itertools as it
import math
import shutil
import multiprocessing as mp

# Dataset

In [6]:
class DRIVE_dataset(Dataset):
    """DRIVE segmentation dataset yielding ``(image, annot)`` tensor pairs.

    Wraps ``DRIVEPILDataset`` (which yields PIL images) and turns each pair into
    tensors, either with the caller-supplied ``transforms(image, annot)`` or with
    ``_default_trans``.

    Args:
        data_root: directory holding the ``images/`` and ``labels/`` sub-folders
            read by ``DRIVEPILDataset``.
        train: when True the default transform applies random flips/rotations.
        transforms: optional callable ``(image, annot) -> (image, annot)`` used
            instead of ``_default_trans``.
    """

    def __init__(self, data_root, train=True, transforms=None):
        super().__init__()
        self.data_root = data_root
        self.transforms = transforms
        # Number of items per sample (image, annot); read by get_datasets.
        self.num_return = 2
        self.dataset = DRIVEPILDataset(self.data_root)
        self.train = train

    def __getitem__(self, index):
        image, annot = self.dataset[index]

        if self.transforms is None:
            image, annot = self._default_trans(image, annot, self.train)
        else:
            image, annot = self.transforms(image, annot)

        return image, annot

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def _binarize(annot, threshold=0.5):
        """Map values > ``threshold`` to 1 and all other values to 0, keeping the dtype."""
        return (annot > threshold).to(annot.dtype)

    @staticmethod
    def _default_trans(image, annot, train):
        """Default PIL -> tensor transform.

        The label is converted to one grayscale channel. In training mode image and
        label receive the same random horizontal flip (p=0.5), vertical flip (p=0.5)
        and rotation by ``random.random() * 360`` degrees (p=0.6). The image tensor is
        normalised with mean=std=0.5 per channel (mapping [0, 1] to [-1, 1]) and the
        label tensor is binarised at 0.5.
        """
        annot = TF.to_grayscale(annot, num_output_channels=1)
        if train:
            if random.random() < 0.5:
                image = TF.hflip(image)
                annot = TF.hflip(annot)
            if random.random() < 0.5:
                image = TF.vflip(image)
                annot = TF.vflip(annot)
            if random.random() < 0.6:
                angle = random.random() * 360
                image = TF.rotate(img=image, angle=angle)
                annot = TF.rotate(img=annot, angle=angle)

        image = TF.to_tensor(image)
        # Three-channel statistics: DRIVEPILDataset converts images to RGB.
        image = TF.normalize(image, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))

        annot = TF.to_tensor(annot)
        # Two separate in-place assignments (> 0.5 -> 1, < 0.5 -> 0) would leave
        # values of exactly 0.5 untouched; one comparison yields a strict 0/1 mask.
        annot = DRIVE_dataset._binarize(annot)
        return image, annot

In [7]:
class DRIVEPILDataset(Dataset):
    """Pairs DRIVE images with their annotations as PIL images.

    Reads ``.tif`` images from ``<data_root>/images`` and ``.gif`` annotations from
    ``<data_root>/labels``. Both lists are sorted by path and paired by position, so
    the i-th image must correspond to the i-th label.

    Raises:
        ValueError: if the number of images and labels differs.
    """

    def __init__(self, data_root):
        self.data_root = path.expanduser(data_root)
        self._image_dir = path.join(self.data_root, 'images')
        self._annot_dir = path.join(self.data_root, 'labels')

        self._image_paths = sorted(list_files(self._image_dir, suffix=('.tif', '.TIF'), prefix=True))
        self._annot_paths = sorted(list_files(self._annot_dir, suffix=('.gif', '.GIF'), prefix=True))
        # Pairing is purely positional: a count mismatch would silently misalign
        # images and labels, or raise IndexError late inside a DataLoader worker.
        if len(self._image_paths) != len(self._annot_paths):
            raise ValueError(
                'Found {} images in {} but {} labels in {}'.format(
                    len(self._image_paths), self._image_dir,
                    len(self._annot_paths), self._annot_dir))

    def __getitem__(self, index):
        """Return ``(image, annot)``: the image converted to RGB, the label to mode '1'."""
        # convert() returns an independent copy, so the files can be closed right away.
        with Image.open(self._image_paths[index], mode='r') as raw_image:
            image = raw_image.convert('RGB')
        with Image.open(self._annot_paths[index], mode='r') as raw_annot:
            annot = raw_annot.convert('1')
        return image, annot

    def __len__(self):
        return len(self._image_paths)

def get_datasets(dataset_name, data_root, train, transforms=None):
    """Build a dataset by name.

    Args:
        dataset_name: dataset identifier; only 'DRIVE' is supported.
        data_root: root directory of the dataset split.
        train: whether the training-time default augmentation is used.
        transforms: optional custom ``(image, annot)`` transform.

    Returns:
        (dataset, num_return): the dataset and the number of items per sample.

    Raises:
        NotImplementedError: for an unsupported ``dataset_name``.
    """
    if dataset_name != 'DRIVE':
        raise NotImplementedError(
            'Unsupported dataset {!r}; supported: {!r}'.format(dataset_name, ['DRIVE']))
    dataset = DRIVE_dataset(data_root=data_root, train=train, transforms=transforms)
    return dataset, dataset.num_return

# Model

In [8]:
# Smoke test: build one UnetBlock from a random 10-bit connection gene
# (10 = C(5, 2) connections for 5 nodes, cf. get_gene_len), run a forward pass on a
# random input on cuda:0 and report the block's parameter count.
active, pre_index, out_index = check_active(node_num=5, connect_gene=list(np.random.randint(0, 2, size=[10])))
model = UnetBlock(base_ch=36, active=active, pre_index=pre_index, out_index=out_index,
                  node_func_type='conv_in_mish_3').cuda(0)
x = torch.rand(1, 36, 64, 64).cuda(0)
# Inference only: no_grad avoids keeping the autograd graph (and its activations)
# alive in the global `y` for the rest of the session.
with torch.no_grad():
    y = model(x)
param = count_param(model)
print('Total parameters: ', param)

Total parameters:  81900


# Training

In [20]:
class NoDaemonProcess(mp.Process):
    """Process whose ``daemon`` flag is always False.

    ``multiprocessing.pool.Pool`` marks its workers as daemonic, and daemonic processes
    may not start children. The training workers create ``DataLoader`` worker processes
    (``num_workers`` > 0 in train_one_model), so they must not be daemonic.
    """

    @property
    def daemon(self):
        return False

    @daemon.setter
    def daemon(self, value):
        # Silently ignore the pool's attempt to set daemon=True.
        pass


class NoDaemonProcessPool(multiprocessing.pool.Pool):
    """``Pool`` whose worker processes are ``NoDaemonProcess`` instances.

    Assigning ``Process = NoDaemonProcess`` on the pool class only works up to
    Python 3.7: since 3.8 the pool calls ``Process(ctx, ...)``, which makes that
    override fail ("group argument must be None for now"). Handing the pool a context
    whose ``Process`` is ``NoDaemonProcess`` works on both old and new versions.

    NOTE(review): with the 'fork' start method a worker cannot use CUDA if the parent
    has already initialised it (an earlier cell runs a model on cuda:0) — confirm how
    this notebook is launched.
    """

    def __init__(self, processes=None, initializer=None, initargs=(), maxtasksperchild=None, context=None):
        base_context = context if context is not None else mp.get_context()

        class _NoDaemonContext(type(base_context)):
            Process = NoDaemonProcess

        super().__init__(processes, initializer, initargs, maxtasksperchild, _NoDaemonContext())

In [31]:
def func_try(population, ind_num, device, model_settings):
    """Release cached GPU memory, mutate individual ``ind_num`` and rebuild its network.

    Called by ``help_func`` after a training attempt failed.

    Args:
        population: list of DEAP bit-list individuals; ``population[ind_num]`` is
            replaced by its mutated version.
        ind_num: index of the individual to mutate.
        device: CUDA device whose memory statistics are polled.
        model_settings: forwarded to ``Net``.

    Returns:
        (model, device): the rebuilt ``Net`` and the unchanged ``device``.
    """
    seed = 12

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)

    # Poll up to 6 times. While the peak reserved memory exceeds 9 GB and more than
    # 1 GB is still reserved, release the cache and wait 3 s. Peak statistics are
    # reset after each reading. (The *_reserved / reset_peak_memory_stats calls are
    # the current names of the deprecated memory_cached APIs.)
    for attempt in range(6):
        peak_reserved_gb = torch.cuda.max_memory_reserved(device=device) / 1000 ** 3
        reserved_gb = torch.cuda.memory_reserved(device=device) / 1000 ** 3
        torch.cuda.reset_peak_memory_stats(device=device)
        if attempt == 5 or not (peak_reserved_gb > 9 and reserved_gb > 1):
            break
        torch.cuda.empty_cache()
        time.sleep(3)

    # Flip bits (p=0.3 each), then restore the first 150 genes so only the tail changes.
    # NOTE(review): with this notebook's settings get_gene_len(...) returns 98 (< 150),
    # so the restore undoes the whole mutation and the rebuilt network is unchanged.
    protected_genes = population[ind_num][0:150]
    population[ind_num] = tools.mutFlipBit(population[ind_num], indpb=0.3)[0]
    population[ind_num][0:150] = protected_genes
    model = Net(gene=population[ind_num][:], model_settings=model_settings)
    print('Have changed the channel number!')

    return model, device


def help_func(optimizer_name, learning_rate, l2_weight_decay, gen_num, ind_num, model, batch_size, epochs, device,
              train_set_name, valid_set_name,
              train_set_root, valid_set_root, exp_name,
              population, model_settings, max_retries=10):
    """Train one individual, retrying with a rebuilt network when training fails.

    ``train_one_model`` reports failure (e.g. after a caught RuntimeError) with a False
    flag. Each retry calls ``func_try`` to mutate ``population[ind_num]`` and rebuild the
    model, then trains again.

    Args:
        max_retries: maximum number of retries after the first failure. ``None``
            restores the previous unbounded behaviour. The bound exists because
            ``func_try`` reseeds every call and may not change the gene at all, so a
            persistent failure could otherwise loop forever.
        (all other arguments are forwarded to ``train_one_model`` / ``func_try``.)

    Returns:
        The metrics dict from the last ``train_one_model`` call. If every attempt
        failed, these are the default metrics that ``train_one_model`` returns.
    """

    def _train(current_model, current_device):
        return train_one_model(optimizer_name, learning_rate, l2_weight_decay, gen_num, ind_num, current_model,
                               batch_size, epochs, current_device,
                               train_set_name, valid_set_name,
                               train_set_root, valid_set_root, exp_name)

    metrics, ok = _train(model, device)
    retries = 0
    while not ok and (max_retries is None or retries < max_retries):
        retries += 1
        model, device = func_try(population, ind_num, device, model_settings)
        metrics, ok = _train(model, device)

    if not ok:
        print('gens_{} individual_{} still failing after {} retries; keeping default metrics'
              .format(gen_num, ind_num, retries))
    return metrics


def util_function(i):
    """Unpack one task tuple built by train_population_parr and forward it to help_func.

    ``Pool.map`` hands each task a single object, so the 16 positional arguments of
    ``help_func`` travel together as one tuple.
    """
    positional_args = [i[k] for k in range(16)]
    return help_func(*positional_args)

# Optimizer

In [22]:
class Lookahead(Optimizer):
    """Lookahead wrapper around an existing optimizer.

    The wrapped ``base_optimizer`` updates the ("fast") parameters on every step.
    Every ``k`` steps of a param group, each of its parameters that has a gradient is
    synchronised with its "slow" copy: ``slow += alpha * (fast - slow)``, then
    ``fast = slow``.

    Args:
        base_optimizer: an already constructed optimizer (e.g. ``Adam``).
        alpha: slow-weight step size in [0, 1].
        k: number of fast steps between synchronisations (>= 1).

    Raises:
        ValueError: if ``alpha`` or ``k`` is out of range.
    """

    def __init__(self, base_optimizer, alpha=0.5, k=6):
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        self.optimizer = base_optimizer
        self.param_groups = self.optimizer.param_groups
        # Optimizer.__init__ is not called, so share the base optimizer's defaults and
        # state; inherited helpers (e.g. state_dict) read these attributes.
        self.defaults = self.optimizer.defaults
        self.state = self.optimizer.state
        self.alpha = alpha
        self.k = k
        for group in self.param_groups:
            group["step_counter"] = 0
        self.slow_weights = [[p.clone().detach() for p in group['params']]
                             for group in self.param_groups]

        for w in it.chain(*self.slow_weights):
            w.requires_grad = False

    def zero_grad(self, *args, **kwargs):
        """Delegate to the base optimizer.

        The inherited ``Optimizer.zero_grad`` relies on attributes that only
        ``Optimizer.__init__`` sets up, which this wrapper never calls.
        """
        self.optimizer.zero_grad(*args, **kwargs)

    def step(self, closure=None):
        """Take one base-optimizer step, then synchronise groups whose counter hits k.

        Returns:
            Whatever the base optimizer's ``step(closure)`` returns (the closure's loss).
        """
        loss = self.optimizer.step(closure)
        for group, slow_weights in zip(self.param_groups, self.slow_weights):
            group['step_counter'] += 1
            if group['step_counter'] % self.k != 0:
                continue
            for p, q in zip(group['params'], slow_weights):
                if p.grad is None:
                    continue
                q.data.add_(p.data - q.data, alpha=self.alpha)
                p.data.copy_(q.data)
        return loss

def get_optimizer(optimizer_name, params, learning_rate, l2_weight_decay):
    """Create an optimizer by name.

    Args:
        optimizer_name: 'SGD', 'Adam', 'RMS' (RMSprop) or 'Lookahead(Adam)'.
        params: iterable of parameters (may be a one-shot iterator such as ``filter``).
        learning_rate: learning rate.
        l2_weight_decay: weight decay passed to the optimizer.

    Returns:
        The constructed optimizer.

    Raises:
        NotImplementedError: for an unknown ``optimizer_name``.
    """
    from torch.optim import SGD, Adam, RMSprop

    if optimizer_name == 'SGD':
        return SGD(params=params, lr=learning_rate, weight_decay=l2_weight_decay)
    if optimizer_name == 'Adam':
        return Adam(params=params, lr=learning_rate, weight_decay=l2_weight_decay)
    if optimizer_name == 'RMS':
        return RMSprop(params=params, lr=learning_rate, weight_decay=l2_weight_decay)
    if optimizer_name == 'Lookahead(Adam)':
        base_optimizer = Adam(params=params, lr=learning_rate, weight_decay=l2_weight_decay)
        return Lookahead(base_optimizer=base_optimizer)
    raise NotImplementedError(
        'Unknown optimizer {!r}; expected one of {!r}'.format(
            optimizer_name, ['SGD', 'Adam', 'RMS', 'Lookahead(Adam)']))

# Loss function

In [23]:
class FocalLossForSigmoid(nn.Module):
    """Binary focal loss on probabilities (i.e. after a sigmoid).

    Per element: ``w * |target - p| ** gamma * BCE(p, target)``, where
    ``w = alpha * target + (1 - alpha) * (1 - target)`` if ``alpha`` is not None, else 1.

    Args:
        gamma: focusing exponent.
        alpha: positive-class weight in [0, 1], or None to disable class weighting.
        reduction: 'mean', 'sum', or any other value for the unreduced per-element loss.
    """

    def __init__(self, gamma=2, alpha=0.55, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        # alpha=None is allowed: forward() has a dedicated no-weighting branch for it.
        assert alpha is None or 0 <= alpha <= 1, 'The value of alpha must in [0,1]'
        self.alpha = alpha
        self.reduction = reduction
        # reduction='none' is the non-deprecated spelling of reduce=False.
        self.bce = nn.BCELoss(reduction='none')

    def forward(self, input_, target):
        # Keep probabilities away from exactly 0/1 before the BCE logarithm.
        input_ = torch.clamp(input_, min=1e-7, max=(1 - 1e-7))
        modulating = torch.pow(torch.abs(target - input_), self.gamma)

        if self.alpha is not None:
            class_weight = self.alpha * target + (1 - target) * (1 - self.alpha)
            loss = class_weight * modulating * self.bce(input_, target)
        else:
            loss = modulating * self.bce(input_, target)

        if self.reduction == 'mean':
            return torch.mean(loss)
        if self.reduction == 'sum':
            return torch.sum(loss)
        return loss

# Train model

In [35]:
def train_one_model(optimizer_name, learning_rate, l2_weight_decay, gen_num, ind_num, model, batch_size, epochs, device,
                    train_set_name, valid_set_name,
                    train_set_root, valid_set_root, exp_name,
                    mode='train', valid_epoch=80, patience=70):
    """Train one candidate network and return the metrics of its best validated epoch.

    The model trains for up to ``epochs`` epochs with the focal loss and gradient-norm
    clipping at 0.1. From epoch ``valid_epoch`` (0-based) on, it is validated after
    every epoch. Whenever the validation F1 improves, the metrics are recorded and
    written to ``exps/<exp_name>/csv/gens_<g> individual_<i> performance.csv`` under
    ``output_path``. Training stops early after ``patience`` validated epochs without
    improvement, or if the best F1 is still below 0.5 more than 15 epochs after
    validation started.

    Args:
        mode: unused; kept for backward compatibility.
        valid_epoch: first epoch index that is validated (FLOPs/params are profiled there).
        patience: validated epochs without F1 improvement before stopping early.
        (other arguments select the optimizer, data, device and logging names.)

    Returns:
        (metrics, ok). ``metrics`` maps 'flops', 'param', 'accuracy', 'recall',
        'specificity', 'precision', 'f1_score', 'auroc' and 'iou' to the best epoch's
        values. If no epoch improved, flops/param stay 100 and the others 0.
        ``ok`` is False when a RuntimeError (e.g. CUDA out-of-memory) aborted training.
    """
    import pandas as pd

    seed = 12

    torch.cuda.empty_cache()
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.benchmark = True

    loss_func = FocalLossForSigmoid(reduction='mean').to(device)
    optimizer = get_optimizer(optimizer_name, filter(lambda p: p.requires_grad, model.parameters()), learning_rate, l2_weight_decay)

    train_set, num_return = get_datasets(train_set_name, train_set_root, True)
    valid_set, _ = get_datasets(valid_set_name, valid_set_root, False)
    train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True, num_workers=2)
    valid_loader = DataLoader(dataset=valid_set, batch_size=1, shuffle=False, num_workers=1)

    best_f1_score = 0
    best_epoch = 0
    count = 0  # validated epochs since the last F1 improvement

    metrics_name = ['flops', 'param', 'accuracy', 'recall', 'specificity', 'precision', 'f1_score', 'auroc', 'iou']
    metrics = {name: (100 if name in ('flops', 'param') else 0) for name in metrics_name}
    flops, param = metrics['flops'], metrics['param']

    try:
        # Batches are moved to `device`, so the model must live there too; otherwise
        # every batch raises a device-mismatch RuntimeError (reported as a failure).
        # nn.Module.to() moves parameters in place, so the optimizer stays valid.
        model.to(device)

        for i in range(epochs):
            # Validation switches to eval mode; switch back before every training epoch.
            model.train()
            train_tqdm_batch = tqdm(iterable=train_loader, total=np.ceil(len(train_set) / batch_size))

            for images, targets in train_tqdm_batch:
                images, targets = images.to(device), targets.to(device)
                optimizer.zero_grad()
                preds = model(images)
                loss = loss_func(preds, targets)
                loss.backward()
                clip_grad_norm_(model.parameters(), 0.1)
                optimizer.step()
            train_tqdm_batch.close()

            print('gens_{} individual_{}_epoch_{} train end'.format(gen_num, ind_num, i))

            if i < valid_epoch:
                continue

            epoch_acc = AverageMeter()
            epoch_recall = AverageMeter()
            epoch_precision = AverageMeter()
            epoch_specificity = AverageMeter()
            epoch_f1_score = AverageMeter()
            epoch_iou = AverageMeter()
            epoch_auroc = AverageMeter()

            with torch.no_grad():
                model.eval()
                valid_tqdm_batch = tqdm(iterable=valid_loader, total=np.ceil(len(valid_set) / 1))

                for images, targets in valid_tqdm_batch:
                    images = images.to(device)
                    targets = targets.to(device)
                    preds = model(images)

                    (acc, recall, specificity, precision,
                     f1_score, iou, auroc) = calculate_metrics(preds=preds, targets=targets, device=device)
                    epoch_acc.update(acc)
                    epoch_recall.update(recall)
                    epoch_precision.update(precision)
                    epoch_specificity.update(specificity)
                    epoch_f1_score.update(f1_score)
                    epoch_iou.update(iou)
                    epoch_auroc.update(auroc)

                if i == valid_epoch:
                    # Cost is profiled once, on the last validation input.
                    flops, param = profile(model=model, inputs=(images,), verbose=False)
                    flops = flops / 1e11
                    param = param / 1e6

                # NOTE(review): `.val` is reported and recorded below. If AverageMeter
                # follows the usual val/avg convention, this is the last image only and
                # `.avg` may have been intended — confirm.
                print('gens_{} individual_{}_epoch_{} validate end'.format(gen_num, ind_num, i))
                print('acc:{} | recall:{} | spe:{} | pre:{} | f1_score:{} | auroc:{}'
                      .format(epoch_acc.val,
                              epoch_recall.val,
                              epoch_specificity.val,
                              epoch_precision.val,
                              epoch_f1_score.val,
                              epoch_auroc.val))

                if epoch_f1_score.val > best_f1_score:
                    best_f1_score = epoch_f1_score.val
                    best_epoch = i
                    count = 0
                    metrics.update({
                        'flops': flops,
                        'param': param,
                        'accuracy': epoch_acc.val,
                        'recall': epoch_recall.val,
                        'specificity': epoch_specificity.val,
                        'precision': epoch_precision.val,
                        'f1_score': epoch_f1_score.val,
                        'auroc': epoch_auroc.val,
                        'iou': epoch_iou.val,
                    })

                    # NOTE(review): the 'epoch' column actually holds gen_num; the
                    # header is kept unchanged for existing CSV readers.
                    performance_df = pd.DataFrame(
                        data=[[gen_num, ind_num, epoch_acc.val, epoch_recall.val, epoch_specificity.val,
                               epoch_precision.val,
                               epoch_f1_score.val, epoch_iou.val, epoch_auroc.val]],
                        columns=['epoch', 'individual', 'acc', 'recall',
                                 'specificity', 'precision', 'f1_score', 'iou',
                                 'auroc', ]
                    )
                    performance_csv_path = os.path.join(os.path.abspath(output_path), 'exps/{}/csv'.format(exp_name),
                                                        'gens_{} individual_{} performance.csv'.format(gen_num, ind_num))
                    performance_df.to_csv(performance_csv_path)
                else:
                    count += 1

                hopeless = i > valid_epoch + 15 and best_f1_score < 0.50
                if count >= patience or hopeless:
                    print('current best epoch_{} best_f1_score:'.format(best_epoch), best_f1_score)
                    print('gens_{} individual_{} train early stop'.format(gen_num, ind_num))
                    print('=======================================================================')
                    valid_tqdm_batch.close()
                    return metrics, True
                print('current best epoch_{} best_f1_score:'.format(best_epoch), best_f1_score)
                valid_tqdm_batch.close()
        print('current best epoch_{} best_f1_score:'.format(best_epoch), best_f1_score)
        print('=======================================================================')
    except RuntimeError as exception:
        # Typically CUDA out-of-memory; the caller may retry with a rebuilt model.
        # Rebinding (rather than `del`) the batch locals is safe even if the error
        # struck before the first batch, and lets cached memory be released.
        print('gens_{} individual_{} training failed: {}'.format(gen_num, ind_num, exception))
        images = targets = preds = loss = None
        torch.cuda.empty_cache()
        return metrics, False
    return metrics, True

# Train population in parallel

In [24]:
def train_population_parr(train_list, gen_num, population, batch_size, devices, epochs, exp_name, train_set_name,
                          valid_set_name, train_set_root, valid_set_root,
                          optimizer_name, learning_rate, l2_weight_decay, model_settings):
    """Train all individuals of ``population`` in parallel, one worker process per device.

    First pickles every individual's genes to
    ``exps/<exp_name>/pickle/gens_<gen_num>individuals_code.pkl`` under ``output_path``
    (one dict per individual, keyed ``'gens_<g>_individual_<label>'``) and builds a
    ``Net`` for each. Individuals are then trained in batches of ``len(devices)``; the
    j-th individual of a batch runs ``util_function`` on ``devices[j]``.

    Args:
        train_list: label of each individual, parallel to ``population``.
        devices: non-empty list of devices.
        (remaining arguments are forwarded to the per-individual training.)

    Returns:
        List of metric dicts, one per individual, in ``population`` order.

    Raises:
        AssertionError: if ``train_list`` and ``population`` differ in length.
        ValueError: if ``devices`` is empty.
    """
    seed = 12

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)

    assert len(train_list) == len(population)
    if not devices:
        raise ValueError('devices must contain at least one device')

    model_list = []
    codes_path = os.path.join(os.path.abspath(output_path),
                              'exps/{}/pickle/gens_{}individuals_code.pkl'.format(exp_name, gen_num))
    # `with` closes the file even if building a Net raises part-way.
    with open(codes_path, 'wb') as pickle_file:
        for individual, inds in zip(population, train_list):
            pickle.dump({'gens_{}_individual_{}'.format(gen_num, inds): individual[:]}, pickle_file)
            model_list.append(Net(gene=individual[:], model_settings=model_settings))

    gpu_num = len(devices)
    metrics_ = []
    for start in range(0, len(population), gpu_num):
        process_num = min(start + gpu_num, len(population)) - start
        args = [
            (optimizer_name, learning_rate, l2_weight_decay, gen_num, train_list[start + j], model_list[start + j],
             batch_size, epochs, devices[j],
             train_set_name, valid_set_name,
             train_set_root, valid_set_root, exp_name, population, model_settings)
            for j in range(process_num)
        ]
        # The context manager terminates the workers even if map() raises.
        with NoDaemonProcessPool(process_num) as pool:
            metrics_.extend(pool.map(util_function, args))
    return metrics_

# Evolve

In [26]:
import matplotlib as mpl

# Share tensors between processes through the file system instead of file
# descriptors (presumably to avoid exhausting open-fd limits with many workers).
torch.multiprocessing.set_sharing_strategy('file_system')
# Non-interactive backend: no display is needed to render figures.
mpl.use('Agg')

In [27]:
def find_train_inds(population):
    """Return the positions of the individuals whose fitness is not yet valid (untrained)."""
    return [position for position, individual in enumerate(population)
            if not individual.fitness.valid]


def special_initialization(population, code_list):
    """Overwrite individuals' genes in place with the matching entries of ``code_list``.

    Only as many individuals as there are codes (or vice versa) are touched; the same
    ``population`` object is returned.
    """
    pairs = zip(population, code_list)
    for individual, genes in pairs:
        individual[:] = genes
    return population


def save_evolution_stat_ckpt(evolution_stat_dict, exp_name, g):
    """Pickle ``evolution_stat_dict`` to exps/<exp_name>/pickle/gens<g>_evolution_stat_dict.pkl under output_path."""
    import pickle
    ckpt_path = os.path.join(os.path.abspath(output_path),
                             'exps/{}/pickle/gens{}_evolution_stat_dict.pkl'.format(exp_name, g))
    # `with` closes the file even if pickling fails part-way.
    with open(ckpt_path, 'wb') as pickle_file:
        pickle.dump(evolution_stat_dict, pickle_file)


def reload_evolution_stat_ckpt(exp_name, g):
    """Load the object saved by save_evolution_stat_ckpt for experiment ``exp_name``, generation ``g``.

    Only load checkpoints you wrote yourself: unpickling can execute arbitrary code.
    """
    import pickle
    ckpt_path = os.path.join(os.path.abspath(output_path),
                             'exps/{}/pickle/gens{}_evolution_stat_dict.pkl'.format(exp_name, g))
    # `with` closes the file even if unpickling raises.
    with open(ckpt_path, 'rb') as pickle_file:
        evolution_stat_dict = pickle.load(pickle_file)
    return evolution_stat_dict


def save_population_ckpt(population, exp_name, g):
    """Pickle ``population`` to exps/<exp_name>/pickle/gens<g>_ckpt.pkl under output_path."""
    import pickle
    ckpt_path = os.path.join(os.path.abspath(output_path), 'exps/{}/pickle/gens{}_ckpt.pkl'.format(exp_name, g))
    # `with` closes the file even if pickling fails part-way.
    with open(ckpt_path, 'wb') as pickle_file:
        pickle.dump(population, pickle_file)


def reload_population_ckpt(exp_name, g):
    """Load the population saved by save_population_ckpt for experiment ``exp_name``, generation ``g``.

    Only load checkpoints you wrote yourself: unpickling can execute arbitrary code.
    """
    import pickle
    ckpt_path = os.path.join(os.path.abspath(output_path), 'exps/{}/pickle/gens{}_ckpt.pkl'.format(exp_name, g))
    # `with` closes the file even if unpickling raises.
    with open(ckpt_path, 'rb') as pickle_file:
        population = pickle.load(pickle_file)
    return population

In [28]:
def check_dir(exp_name):
    """Ensure the ckpt/, runs/, pickle/ and csv/ folders of experiment ``exp_name`` exist under output_path."""
    exps_root = os.path.abspath(output_path)
    for sub_dir in ('ckpt', 'runs', 'pickle', 'csv'):
        target = os.path.join(exps_root, 'exps/{}/{}'.format(exp_name, sub_dir))
        if not os.path.exists(target):
            os.makedirs(target, exist_ok=True)


def get_gene_len(de_func_type, en_func_type, de_node_num_list, en_node_num_list, only_en=False):
    """Return the length of the binary genome describing all encoder/decoder cells.

    Each cell contributes ceil(log2(number of function types)) bits for its node
    function plus C(n, 2) connection bits for its n nodes. Decoder cells are
    included unless ``only_en`` is True.
    """
    def cell_lengths(func_types, node_nums):
        func_bits = int(np.ceil(np.log2(len(func_types))))
        return [func_bits + int(comb(nodes, 2)) for nodes in node_nums]

    decoder_lengths = cell_lengths(de_func_type, de_node_num_list)
    encoder_lengths = cell_lengths(en_func_type, en_node_num_list)

    if only_en:
        return sum(encoder_lengths)
    return sum(decoder_lengths) + sum(encoder_lengths)


def bin(n):
    """Return the base-2 digits of a non-negative integer ``n`` (no '0b' prefix).

    ``bin(0)`` returns '' (behaviour kept from the original recursive version).
    NOTE: this definition shadows Python's built-in ``bin`` for the rest of the notebook.

    Raises:
        ValueError: if ``n`` is negative (the recursive version recursed until
            RecursionError).
    """
    if n < 0:
        raise ValueError('bin() expects a non-negative integer, got {!r}'.format(n))
    if not n:
        return ''
    return format(n, 'b')


def cxMultiPoint(ind1, ind2, num_points=10):
    """Multi-point crossover that swaps alternate segments of two individuals in place.

    ``num_points`` distinct cut points are drawn uniformly from [0, size], where
    size = min(len(ind1), len(ind2)), and sorted. The slices between points
    (p0, p1), (p2, p3), ... are then exchanged between the two individuals. With an
    odd ``num_points`` the last point is unused.

    Args:
        ind1, ind2: mutable sequences (e.g. DEAP individuals), modified in place.
        num_points: number of cut points (default 10, the original behaviour).

    Returns:
        (ind1, ind2), the same objects after crossover.

    Raises:
        ValueError: if fewer than ``num_points`` distinct points exist (size + 1 <
            num_points). The draw loop would otherwise never terminate.
    """
    size = min(len(ind1), len(ind2))
    if num_points > size + 1:
        raise ValueError('need at least {} distinct cut points but individuals have length {}'
                         .format(num_points, size))

    # Redraw on duplicates; keeps the same random stream as the original version.
    cxpoints = []
    for _ in range(num_points):
        point = random.randint(0, size)
        while point in cxpoints:
            point = random.randint(0, size)
        cxpoints.append(point)
    cxpoints.sort()

    for start, stop in zip(cxpoints[0::2], cxpoints[1::2]):
        ind1[start:stop], ind2[start:stop] = ind2[start:stop], ind1[start:stop]

    return ind1, ind2

In [29]:
# ---- Reproducibility ---------------------------------------------------------
seed = 12
random.seed(seed)
np.random.seed(seed)

# ---- Fitness: maximise validation F1 (positive DEAP weight = maximise) ---------
optimization_objects = ['f1_score']
optimization_weights = [1]

# ---- Architecture search space -----------------------------------------------
channel = 20
en_node_num = 5
de_node_num = 5
sample_num = 3

# ---- Evolution ----------------------------------------------------------------
exp_name = 'test'
crossover_rate = 0.9
mutation_rate = 0.7
flipping_rate = 0.05
gens = 50
parents_num = 20
offsprings_num = 20

# ---- Per-individual training ----------------------------------------------------
epochs = 130
batch_size = 1
optimizer_name = 'Lookahead(Adam)'
learning_rate = 0.001
l2_weight_decay = 0

# ---- Devices: one training worker per GPU ---------------------------------------
# NOTE(review): the list is not checked against torch.cuda.device_count(); on a
# machine with fewer GPUs, training on the missing devices presumably fails at run time.
gpu_num = 4
devices = [torch.device('cuda', idx) for idx in range(gpu_num)]

print(devices)

[device(type='cuda', index=0), device(type='cuda', index=1), device(type='cuda', index=2), device(type='cuda', index=3)]


In [None]:
resume_train = False
train_set_name = 'DRIVE'
valid_set_name = 'DRIVE'
train_set_root = os.path.join(os.path.abspath(data_path), 'dataset', 'trainset', train_set_name)
valid_set_root = os.path.join(os.path.abspath(data_path), 'dataset', 'validset', valid_set_name)

# sample_num + 1 encoder cells and sample_num decoder cells.
en_node_num_list = [en_node_num for _ in range(sample_num + 1)]
de_node_num_list = [de_node_num for _ in range(sample_num)]

# Candidate node operations (16 entries -> 4 function bits per cell in get_gene_len).
func_type = ['conv_relu_3', 'conv_mish_3', 'conv_in_relu_3',
             'conv_in_mish_3', 'p_conv_relu_3', 'p_conv_mish_3',
             'p_conv_in_relu_3', 'p_conv_in_mish_3', 'conv_relu_5',
             'conv_mish_5', 'conv_in_relu_5', 'conv_in_mish_5', 'p_conv_relu_5',
             'p_conv_mish_5', 'p_conv_in_relu_5', 'p_conv_in_mish_5']

gene_len = get_gene_len(de_func_type=func_type, en_func_type=func_type, de_node_num_list=de_node_num_list,
                        en_node_num_list=en_node_num_list, only_en=False)

model_settings = {'channel': channel, 'en_node_num_list': en_node_num_list, 'de_node_num_list': de_node_num_list,
                  'sample_num': sample_num, 'en_func_type': func_type, 'de_func_type': func_type}

# DEAP setup: individuals are random bit lists of length gene_len; "mutateL" flips
# each bit with probability flipping_rate.
creator.create("FitnessMax", base.Fitness, weights=optimization_weights)
creator.create("Individual", list, fitness=creator.FitnessMax)
toolbox = base.Toolbox()
toolbox.register("attr_bool", random.randint, 0, 1)
toolbox.register("individual", tools.initRepeat, creator.Individual, toolbox.attr_bool, gene_len)
toolbox.register("population", tools.initRepeat, list, toolbox.individual)
toolbox.register("mutateL", tools.mutFlipBit, indpb=flipping_rate)

check_dir(exp_name)
sum_writer = SummaryWriter(log_dir=os.path.join(os.path.abspath(output_path), 'exps/{}/runs'.format(exp_name)))


def _ensure_individual_dirs(n_individuals):
    """Create exps/<exp_name>/ckpt/individual_<k> for k < n_individuals and exps/<exp_name>/pickle/."""
    exps_root = os.path.abspath(output_path)
    for idx in range(n_individuals):
        os.makedirs(os.path.join(exps_root, 'exps/{}/ckpt/individual_{}'.format(exp_name, idx)), exist_ok=True)
    os.makedirs(os.path.join(exps_root, 'exps/{}/pickle/'.format(exp_name)), exist_ok=True)


if resume_train:
    g = 1
    # Experiment to resume from; defaults to the current one. (The former None
    # placeholder was formatted into the checkpoint path as 'exps/None/...'.)
    exp_name_load = exp_name
    population = reload_population_ckpt(exp_name_load, g=g)
    _ensure_individual_dirs(len(population))
    offspring = None

else:
    population = toolbox.population(n=parents_num)
    print('==========Sucessfully initialize population==========')
    _ensure_individual_dirs(len(population))

    train_list = find_train_inds(population)
    print('gens_{} train individuals is:'.format(0), train_list)

    metrics = train_population_parr(train_list=train_list, gen_num=0, population=population, batch_size=batch_size,
                                    devices=devices, epochs=epochs, exp_name=exp_name,
                                    train_set_name=train_set_name,
                                    valid_set_name=valid_set_name, train_set_root=train_set_root,
                                    valid_set_root=valid_set_root, optimizer_name=optimizer_name,
                                    learning_rate=learning_rate,
                                    model_settings=model_settings, l2_weight_decay=l2_weight_decay)

    # Every initial individual is untrained, so metrics[k] belongs to population[k].
    assert len(metrics) == len(population)
    for individual, individual_metrics in zip(population, metrics):
        individual.fitness.values = [individual_metrics[opt_obj] for opt_obj in optimization_objects]

    print('evaluate gens_{} successfully'.format(0))
    save_population_ckpt(population=population, exp_name=exp_name, g=0)

    print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
    g = 0
    sum_writer.add_scalar('best_fitness', tools.selBest(population, k=1)[0].fitness.values[0], g)
    offspring = None

In [None]:
# Imported once, up front: deepcopy is also needed after the loop, which may run
# zero times (e.g. when resuming at the last generation).
from copy import deepcopy

for n in range(g + 1, gens):
    parents = deepcopy(population)
    new_parents = list(map(toolbox.clone, parents))

    # Children are produced in pairs; build an even count, then trim to offsprings_num.
    n_pairs = int(np.ceil(offsprings_num / 2))
    offspring = toolbox.population(n=2 * n_pairs)

    if len(new_parents) >= 2:
        for pair in range(n_pairs):
            if random.random() < crossover_rate:
                # Up to 10 tournament draws, stopping early once the two parents
                # differ in more than 20% of their genes.
                for _ in range(10):
                    new_parents_list = deepcopy(tools.selTournament(new_parents, 2, tournsize=2))
                    n_genes = len(new_parents_list[0])
                    hamming = sum(int(a) ^ int(b) for a, b in zip(new_parents_list[0], new_parents_list[1]))
                    if hamming / n_genes > 0.2:
                        break
                off1, off2 = cxMultiPoint(new_parents_list[0], new_parents_list[1])
            else:
                new_parents_list = deepcopy(tools.selTournament(new_parents, 2, tournsize=2))
                off1, off2 = new_parents_list[0], new_parents_list[1]

            # Each pair fills its own two slots. The former offspring[i] / offspring[i + 1]
            # indexing overlapped consecutive pairs and left the second half of the
            # offspring as untouched random individuals.
            first, second = 2 * pair, 2 * pair + 1
            offspring[first][:] = off1[:]
            offspring[second][:] = off2[:]
            del offspring[first].fitness.values
            del offspring[second].fitness.values

            del new_parents_list
        offspring = offspring[:offsprings_num]

        for i in range(offsprings_num):
            if random.random() < mutation_rate:
                offspring[i][:] = toolbox.mutateL(offspring[i])[0]
                del offspring[i].fitness.values
    else:
        offspring = offspring[:offsprings_num]
        for i in range(len(offspring)):
            new_parents_list = deepcopy(tools.selRandom(new_parents, 1))
            off = toolbox.mutateL(new_parents_list[0])

            offspring[i][:] = off[0]
            del offspring[i].fitness.values

    print('gens_{} crossover and mutation successfully'.format(n))
    print('gens_{} mutation successfully'.format(n))

    invalid_ind = [ind for ind in offspring if not ind.fitness.valid]
    train_list = find_train_inds(invalid_ind)
    print('gens_{} train individuals is:'.format(n), train_list)
    print('train individuals code are:', invalid_ind[:])
    metrics = train_population_parr(train_list=train_list, gen_num=n, population=invalid_ind, batch_size=batch_size,
                                    devices=devices, epochs=epochs, exp_name=exp_name,
                                    train_set_name=train_set_name,
                                    valid_set_name=valid_set_name, train_set_root=train_set_root,
                                    valid_set_root=valid_set_root, optimizer_name=optimizer_name,
                                    learning_rate=learning_rate,
                                    model_settings=model_settings, l2_weight_decay=l2_weight_decay)
    print('fitness of all trained model:', metrics)

    # metrics is aligned with invalid_ind (the individuals actually trained).
    for individual, individual_metrics in zip(invalid_ind, metrics):
        individual.fitness.values = [individual_metrics[opt_obj] for opt_obj in optimization_objects]

    # Elitism: keep the 5 best of parents + offspring, fill the rest by tournament.
    cad_pop = population + offspring
    best5_pop = tools.selBest(cad_pop, 5)
    for ind in best5_pop:
        cad_pop.remove(ind)
    other_pop = tools.selTournament(cad_pop, k=parents_num - 5, tournsize=2)
    new_offspring = best5_pop + other_pop

    # Log against the current generation n (logging at g put every point on one step).
    sum_writer.add_scalar('best_fitness', tools.selBest(new_offspring, k=1)[0].fitness.values[0], n)
    population[:] = new_offspring
    save_population_ckpt(population=population, exp_name=exp_name, g=n)

    print('evaluate gens_{} successfully'.format(n))
    print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')

best_ind = tools.selBest(population, parents_num)
best_inddividuals = deepcopy(best_ind[:])
best_codes_path = os.path.join(os.path.abspath(output_path),
                               'exps/{}/pickle/gens_{} best_individuals_code.pkl'.format(exp_name, gens))
with open(best_codes_path, 'wb') as pickle_file:
    pickle.dump(best_inddividuals, pickle_file)