# Import libraries

In [None]:
import os
import sys
import time
import csv
import torch
import shutil
import pandas
import pickle
import numpy as np
from torch import nn
from tqdm import tqdm
from thop import profile
from datetime import datetime
from btlbo_unet.loss import *
import matplotlib.pyplot as plt
from torchsummary import summary
from torch.utils.data import DataLoader
from util.create_dir import create_dirs
from torchvision.utils import save_image
from btlbo_unet.btlbo import BTLBOUNet22
from util.get_optimizer import get_optimizer
from loss.FocalLoss import FocalLossForSigmoid
from metrics.average_meter import AverageMeter
from dataset.util.get_datasets import get_datasets
from torch.nn.utils.clip_grad import clip_grad_norm_
from metrics.calculate_metrics import calculate_metrics

# Reproducibility and optimizer settings.
seed = 7546
np.random.seed(seed)  # drives the BTLBO population updates (np.random.* below)
# Bug fix: torch was never seeded, so weight initialisation and DataLoader
# shuffling differed between runs despite the fixed seed. manual_seed seeds
# every device (CPU and CUDA).
# NOTE(review): cudnn.benchmark = True (set during dataset setup) can still
# introduce run-to-run nondeterminism on GPU.
torch.manual_seed(seed)

optimizer_name = 'Adam'  # alternative tried: 'Lookahead(Adam)'
learning_rate = 0.001
l2_weight_decay = 0
epochs = 80  # NOTE(review): not used in the visible code; runModel trains for a fixed 150 epochs
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

## Dataset Preparation

In [None]:
dataset_name = 'CHASEDB'

# Images are read from ./dataset/<split>/<dataset_name>.
# NOTE(review): the "validation" split used for model selection is 'testset'.
dataset_dir = os.path.join(os.path.abspath('.'), 'dataset')
train_set_root = os.path.join(dataset_dir, 'trainset', dataset_name)
valid_set_root = os.path.join(dataset_dir, 'testset', dataset_name)
batch_size = 2

# The third get_datasets argument is True for the training split and False for
# validation (presumably a train/augmentation flag -- confirm in dataset.util).
train_set, num_return = get_datasets(dataset_name, train_set_root, True)
valid_set, _ = get_datasets(dataset_name, valid_set_root, False)

train_loader = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=3,
)
valid_loader = DataLoader(
    dataset=valid_set,
    batch_size=1,
    shuffle=False,
    num_workers=1,
)
print(len(train_set), len(valid_set))

metrics_name = [
    'flops', 'param',                                   # model cost
    'accuracy', 'recall', 'specificity', 'precision',   # segmentation scores
    'f1_score', 'dice', 'auroc', 'iou',
]
# Best dice observed across every evaluated particle (validation / training),
# updated in place by runModel.
metrics1 = {'whole_best_f1_score': 0, 'train_best_f1_score': 0}

torch.cuda.empty_cache()
torch.backends.cudnn.benchmark = True

# Candidate losses; a particle's gene picks one by index.
# Other options tried: FocalTverskyLoss(), DiceLoss()
LOSS_Functions = [
    FocalLossForSigmoid(reduction='mean').to(device),
    DiceBCELoss(),
]
                 
def todec(b):
    """Return the integer whose base-2 digits are *b*, most significant bit first.

    Each element is passed through ``int`` first, so bools and 0.0/1.0 floats work.
    """
    digits = ''.join(str(int(bit)) for bit in b)
    return int(digits, 2)


create_dirs("results")
file_name = 'results/'+dataset_name+'.csv'
print(file_name)

if not os.path.exists(file_name):
    print("writing new")
    with open(file_name,'a') as fp:
        wr = csv.writer(fp, dialect='excel')
        wr.writerow(['generation', 'Index', 'epoch', 'acc', 'recall', 'spe', 'pre', 'f1_score', 'dice', 'auroc', \
                     'mins', 'stime', 'etime', 'loss', 'Particle', 'flops', 'param'])

mpath = 'results/'+dataset_name+'_model_02.pth'   

# Model Evaluation

In [None]:
# Order of the values returned by calculate_metrics; these are also the
# segmentation-score keys of ``metrics_name``.
_SEG_METRIC_KEYS = ('accuracy', 'recall', 'specificity', 'precision',
                    'f1_score', 'dice', 'iou', 'auroc')


def _new_meters():
    """Return a fresh AverageMeter for each segmentation metric."""
    return {key: AverageMeter() for key in _SEG_METRIC_KEYS}


def _update_meters(meters, values):
    """Feed one calculate_metrics result tuple into ``meters``."""
    values = tuple(values)
    if len(values) != len(_SEG_METRIC_KEYS):
        raise ValueError('calculate_metrics returned {} values, expected {}'
                         .format(len(values), len(_SEG_METRIC_KEYS)))
    for key, value in zip(_SEG_METRIC_KEYS, values):
        meters[key].update(value)


def _binarize(preds):
    """Return a detached copy of ``preds`` thresholded at 0.5 (> 0.5 -> 1, <= 0.5 -> 0)."""
    binary = preds.detach().clone()
    binary[preds <= 0.5] = 0
    binary[preds > 0.5] = 1
    return binary


def _format_scores(meters):
    """Render the score line shared by the training and validation logs."""
    return ('acc:{} | recall:{} | spe:{} | pre:{} | f1_score:{} | dice:{} | auroc:{}'
            .format(meters['accuracy'].val,
                    meters['recall'].val,
                    meters['specificity'].val,
                    meters['precision'].val,
                    meters['f1_score'].val,
                    meters['dice'].val,
                    meters['auroc'].val))


def runModel(g, index, particle, rerun=False):
    """Train the network encoded by ``particle`` and return its best validation dice.

    The network is ``BTLBOUNet22(particle)``, trained for up to 150 epochs with the
    loss chosen by gene ``particle[140]``. Validation runs every epoch from epoch 50
    onwards. One summary row is appended to the results CSV ``file_name``.

    Args:
        g: generation number, written to the 'generation' column.
        index: learner index, written to the 'Index' column.
        particle: sequence of genes. Joined with spaces, it is the key stored in
            the CSV 'Particle' column.
        rerun: when False and the particle is already in the CSV, skip training
            and return the best 'dice' stored for it.

    Returns:
        The best validation dice (0 if no validation epoch scored above 0).

    Side effects:
        Updates the global ``metrics1``. Saves the whole model to ``mpath`` when
        it beats ``metrics1['whole_best_f1_score']``.
    """
    ptcl = ' '.join(map(str, particle))
    print(particle)
    if not rerun:
        history = pandas.read_csv(file_name)
        if ptcl in history['Particle'].values:
            print("already found")
            # Bug fix: return a plain float like the normal path. The original
            # returned ``(dice, None)``, which the BTLBO loop's ``new_fitness >
            # fitness_[i]`` comparison cannot handle.
            return float(np.max(history[history['Particle'] == ptcl]['dice']))

    start = time.time()
    model = BTLBOUNet22(particle)
    model.to(device)
    # Gene 140 is an index into LOSS_Functions.
    lrn = particle[140]
    print(lrn)
    loss_func = LOSS_Functions[lrn]
    print(loss_func)
    optimizer = get_optimizer(optimizer_name, filter(lambda p: p.requires_grad, model.parameters()),
                              learning_rate, l2_weight_decay)

    valid_epoch = 50          # first epoch that runs validation
    best_dice = 0
    flag = 0                  # epoch that produced best_dice
    flops, param = 100, 100   # replaced by thop's measurement at epoch ``valid_epoch``
    # Validation summary of the best epoch; cost metrics start at 100, scores at 0.
    metrics = {name: 100 if name in ('flops', 'param') else 0 for name in metrics_name}

    for i in range(150):
        print("epoch {}".format(i))
        # Bug fix: validation switches the model to eval mode, and the original
        # never switched back. Every epoch after ``valid_epoch`` therefore
        # trained in eval mode (frozen BatchNorm statistics / disabled dropout,
        # if the network has any).
        model.train()
        train_meters = _new_meters()
        # Bug fix: was ``numpy.ceil``, but only ``np`` is imported (NameError).
        train_tqdm_batch = tqdm(iterable=train_loader, total=np.ceil(len(train_set) / batch_size))
        for images, targets in train_tqdm_batch:
            optimizer.zero_grad()
            images = images.to(device)
            targets = targets.to(device)
            preds = model(images)
            loss = loss_func(preds, targets)
            loss.backward()
            clip_grad_norm_(model.parameters(), 0.1)
            optimizer.step()
            _update_meters(train_meters,
                           calculate_metrics(preds=_binarize(preds), targets=targets, device=device))
        train_tqdm_batch.close()

        print('Training  -- ' + _format_scores(train_meters))
        if metrics1['train_best_f1_score'] < train_meters['dice'].val:
            metrics1['train_best_f1_score'] = train_meters['dice'].val

        if i < valid_epoch:
            continue

        print("validation")
        with torch.no_grad():
            model.eval()
            valid_meters = _new_meters()
            valid_tqdm_batch = tqdm(iterable=valid_loader, total=np.ceil(len(valid_set) / 1))
            for images, targets in valid_tqdm_batch:
                images = images.to(device)
                targets = targets.to(device)
                preds = model(images)
                _update_meters(valid_meters,
                               calculate_metrics(preds=_binarize(preds), targets=targets, device=device))

            if i == valid_epoch:
                # Architecture cost, measured once on the last validation batch.
                try:
                    flops, param = profile(model=model, inputs=(images,), verbose=False)
                    flops = flops / 1e11
                    param = param / 1e6
                except Exception:
                    flops = 0
                    param = 0

            print('metr- ' + _format_scores(valid_meters))

            # NOTE(review): the epoch score is AverageMeter ``.val``, as in the
            # original. If AverageMeter follows the common convention (val = last
            # update, avg = running mean), this is the dice of the last validation
            # image only. Confirm, and consider ``.avg``.
            current_dice = valid_meters['dice'].val
            if current_dice > best_dice:
                best_dice = current_dice
                if metrics1['whole_best_f1_score'] < current_dice:
                    metrics1['whole_best_f1_score'] = current_dice
                    torch.save(model, mpath)
                flag = i
                snapshot = {'flops': flops, 'param': param}
                snapshot.update((key, meter.val) for key, meter in valid_meters.items())
                for key in metrics:
                    metrics[key] = snapshot[key]

            # NOTE(review): the original also tracked a 70-epoch patience counter
            # and a "best dice < 0.5 after epoch 65" flag. Their ``return`` was
            # commented out, so neither ever stopped training; they are omitted here.
            if current_dice < 0.1 and i > 85:
                print("closing as f1sore is not increasing")
                valid_tqdm_batch.close()
                break
            valid_tqdm_batch.close()

    print('current best epoch_{} best_f1_score:'.format(flag), best_dice)

    end = time.time()
    # Column order matches the CSV header: generation, Index, epoch, acc, recall,
    # spe, pre, f1_score, dice, auroc, mins, stime, etime, loss, Particle, flops, param.
    # Bug fix: the original row stopped at 'Particle', leaving flops/param empty.
    row = [g, index, flag,
           metrics['accuracy'], metrics['recall'], metrics['specificity'], metrics['precision'],
           metrics['f1_score'], metrics['dice'], metrics['auroc'],
           int((end - start) / 60),
           datetime.fromtimestamp(start).strftime('%Y-%m-%d %H:%M:%S'),
           datetime.fromtimestamp(end).strftime('%Y-%m-%d %H:%M:%S'),
           loss_func, ptcl, metrics['flops'], metrics['param']]
    with open(file_name, 'a', newline='') as fp:
        csv.writer(fp, dialect='excel').writerow(row)

    return best_dice

## BTLBO

In [None]:
# BTLBO search configuration.
n_population = 20                # learners (particles) per generation
pidx = [*range(n_population)]    # learner indices, used to pick learning partners
m = 141                          # genes per particle (columns of the population matrix)
MAX_GEN = 30                     # number of generations

# Population and fitness come from the seed data below rather than a random start.
print("Initializing new population and fitness")
# Random start alternative:
#   population_ = np.random.randint(0, 2, (n_population, m))
#   fitness_ = np.zeros((n_population,), dtype=np.float32)

# Seed population: one binary particle (gene list) per learner, hard-coded
# instead of the random start shown above.
# NOTE(review): presumably the elites of an earlier BTLBO run; provenance is not
# recorded here -- confirm which dataset/run they came from.
# fitness_[k] is the stored fitness of population_[k]. The BTLBO loop never
# re-evaluates these values; it only compares new candidates against them.
population_ = [[1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0], 
[0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0], 
[1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1], 
[0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0], 
[1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0], 
[0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0], 
[1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1], 
[1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0], 
[0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1], 
[1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1], 
[1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0], 
[0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0], 
[1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0], 
[1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1], 
[1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0], 
[0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0], 
[1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1], 
[1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0], 
[1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0], 
[1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0]]
fitness_ = [0.8104, 0.8084, 0.8066, 0.8060, 0.8060, 0.8058, 0.8055, 0.8055, 0.8053, 0.8052, 0.8052, 0.8049, 0.8048, 0.8045, 0.8041, 0.8039, 0.8037, 0.8034, 0.8034, 0.8033] 

population_ = np.array(population_)
fitness_ = np.array(fitness_)

# Fail fast if the seed data does not match the search configuration. The BTLBO
# loop indexes rows 0..n_population-1 and draws m-length multipliers, and
# runModel reads gene 140 of every particle. Without this check, a mistyped row
# only surfaces deep inside a training run.
if population_.shape != (n_population, m):
    raise ValueError('population_ has shape {}, expected {}'
                     .format(population_.shape, (n_population, m)))
if fitness_.shape != (n_population,):
    raise ValueError('fitness_ has shape {}, expected {}'
                     .format(fitness_.shape, (n_population,)))

# custom function: logistic sigmoid used to binarise BTLBO candidates
import math


def sigmoid(x):
    """Return the logistic function 1 / (1 + e**-x) of the scalar ``x``.

    The naive form raises OverflowError in ``math.exp(-x)`` for x below about
    -709. Branching on the sign keeps the exponent non-positive, so every finite
    input maps into [0, 1].
    """
    if x >= 0:
        return 1 / (1 + math.exp(-x))
    z = math.exp(x)
    return z / (1 + z)


# Element-wise version, applied to whole candidate vectors.
sigmoid_v = np.vectorize(sigmoid)


## BTLBO algorithm starts here which runs for MAX_GEN generations


def _to_binary(candidate):
    """Sample a 0/1 particle: gene j is 1 with probability sigmoid(candidate[j])."""
    return np.where(sigmoid_v(candidate) >= np.random.uniform(size=len(candidate)), 1, 0)


def _evaluate(t, i, solution):
    """Return runModel's fitness for ``solution``; any failure (e.g. out of memory) scores 0."""
    try:
        return runModel(t, i, solution)
    except Exception as e:  # deliberate: one broken candidate must not abort the search
        print(e)
        return 0


def _greedy_accept(i, solution, fitness):
    """Replace learner ``i`` in place when ``fitness`` strictly beats its current fitness."""
    if fitness > fitness_[i]:
        population_[i] = solution
        fitness_[i] = fitness


def _pick_partner(i, used):
    """Pick a random learner other than ``i``, preferring ones not yet paired this generation.

    Replaces the original bare ``except:`` fallback, which returned a size-1 array
    instead of an index. Appending that array to the used list then broke the
    next ``set(...)`` call.
    """
    others = [j for j in pidx if j != i]
    fresh = [j for j in others if j not in used]
    return int(np.random.choice(fresh or others))


for t in range(MAX_GEN):
    pu = []  # partners already used in this generation's learning phase
    for i in range(n_population):   # each learner in the population
        print(t, i)

        ## teaching phase: move learner i toward the teacher (best learner),
        ## relative to T_F times the gene-wise population mean
        mean = np.nanmean(population_, axis=0)  # column mean
        teacher = np.argmax(fitness_)           # fitness is maximised
        T_F = np.random.randint(1, 3)           # teaching factor: 1 or 2
        r_i = np.random.rand(m)                 # per-gene random multiplier
        new_solution = _to_binary(population_[i] + r_i * (population_[teacher] - T_F * mean))
        new_fitness = _evaluate(t, i, new_solution)
        _greedy_accept(i, new_solution, new_fitness)

        ## learning phase: interact with a random partner p
        p = _pick_partner(i, pu)
        pu.append(p)
        r_i = np.random.rand(m)
        # Fitness is maximised (argmax teacher, '>' greedy acceptance), so learner i
        # moves away from a worse partner and toward a better one.
        # Bug fix: the original took the "move away" step when
        # fitness_[i] < fitness_[p]. That is the minimisation form of TLBO, and it
        # pushed learners away from better partners.
        if fitness_[i] > fitness_[p]:
            new_solution = population_[i] + r_i * (population_[i] - population_[p])
        else:
            new_solution = population_[i] + r_i * (population_[p] - population_[i])
        new_solution = _to_binary(new_solution)
        new_fitness = _evaluate(t, i, new_solution)
        _greedy_accept(i, new_solution, new_fitness)

    bestidx_ = np.argmax(fitness_)  # update details
    print("After ", t, " iteration: ", population_[bestidx_], fitness_[bestidx_])