In [None]:
# Print GPU status (nvidia-smi) so the notebook log records which GPU(s) the runs used.
!nvidia-smi

# Import all necessary libraries

In [None]:
import os
import time
import torch
import shutil
import tempfile
import numpy as np
import pandas, csv
from glob import glob
from datetime import datetime
import matplotlib.pyplot as plt
from torchsummary import summary
from monai.losses import DiceLoss
from monai.utils import set_determinism
from monai.handlers.utils import from_engine
from monai.utils.enums import MetricReduction
from monai.transforms import Activations, AsDiscrete, Compose
from monai.data import CacheDataset, DataLoader, decollate_batch
from monai.metrics import DiceMetric, HausdorffDistanceMetric, MeanIoU

from moed.ega import *
from moed.utils import *
from moed.brats2021 import *
from moed.searchspace import *
from moed.moed3d import MOED3D
from moed.task01_braintumour import *

# Dataset Loading and Preprocessing

In [None]:
set_determinism(seed=0)
task_name = "Task01_BrainTumour"
root_dir = "files/" + task_name  # best-metric checkpoints are saved under this directory
data_dir = "Datasets/" + task_name
dataset_name = task_name.split("_")[1]
json_filename = data_dir + "/" + task_name + ".json"
file_name_res = f"results/{dataset_name}_single_res.csv"
print(file_name_res)
create_dir(root_dir)
create_dir('results/')
create_res_file(file_name_res)


# Cache rate: fraction of each split that CacheDataset keeps pre-transformed in memory
cr = 0.1
batchsize = 4
# Number of worker processes used by CacheDataset caching and by the DataLoaders
now = batchsize * 2
# Training epochs per candidate architecture
max_epochs = 80
# Validation (and best-checkpoint selection) runs only on epochs whose 0-based index is >= n_epoch
n_epoch = 60
# Spatial size passed to the transform builders and to calc_pff
img_size = 96
# NOTE(review): val_interval is not referenced by the training loop in this notebook
val_interval = 2
img_shape = (img_size, img_size, img_size)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Sigmoid followed by a 0.5 threshold turns raw network outputs into binary per-channel masks.
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
# Loss scaling only applies to CUDA mixed precision; disable it explicitly on CPU
# instead of relying on GradScaler's "CUDA is not available" warning.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
# Speed over reproducibility: set_determinism() above turns cuDNN benchmarking off and this
# turns it back on, so the convolution algorithms chosen may differ between runs.
torch.backends.cudnn.benchmark = True

(train_images, train_labels), (val_images, val_labels), (test_images, test_labels) = datafold_read(data_dir, json_filename, 0)
train_files = [{"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)]
val_files = [{"image": image_name, "label": label_name} for image_name, label_name in zip(val_images, val_labels)]
test_files = [{"image": image_name, "label": label_name} for image_name, label_name in zip(test_images, test_labels)]
print(len(train_files), len(val_files), len(test_files))
# Fail fast: an empty split would otherwise only surface much later, inside the training loop.
assert train_files and val_files and test_files, "datafold_read returned an empty train/val/test split"

train_transform = transform_train(img_size)
train_ds = CacheDataset(data=train_files, transform=train_transform, cache_rate=cr, num_workers=now)
train_loader = DataLoader(train_ds, batch_size=batchsize, shuffle=True, num_workers=now)

val_transform = transform_val(img_shape)
val_ds = CacheDataset(data=val_files, transform=val_transform, cache_rate=cr, num_workers=now)
val_loader = DataLoader(val_ds, batch_size=batchsize, shuffle=False, num_workers=now)

# Test split, plus the post-transforms derived from its transform pipeline
test_org_transforms = transform_test_org(img_shape)
test_org_ds = CacheDataset(data=test_files, transform=test_org_transforms, cache_rate=cr, num_workers=now)
test_org_loader = DataLoader(test_org_ds, batch_size=batchsize, shuffle=False, num_workers=now)
post_transforms = transform_post(test_org_transforms)

# Visualize a test image with its corresponding label

In [None]:
# Show the first test case: every input channel and every label channel at the SAME
# slice index along the last axis, so the label panels correspond to the image panels.
vis_slice = 60
if len(test_org_ds) > 0:
    sample = test_org_ds[0]

    n_img_ch = sample["image"].shape[0]
    plt.figure("image", (6 * n_img_ch, 6))
    for i in range(n_img_ch):
        plt.subplot(1, n_img_ch, i + 1)
        plt.title(f"image channel {i}")
        plt.imshow(sample["image"][i, :, :, vis_slice].detach().cpu(), cmap="gray")

    n_lbl_ch = sample["label"].shape[0]
    plt.figure("label", (6 * n_lbl_ch, 6))
    for i in range(n_lbl_ch):
        plt.subplot(1, n_lbl_ch, i + 1)
        plt.title(f"label channel {i}")
        plt.imshow(sample["label"][i, :, :, vis_slice].detach().cpu())
    plt.show()

# Model Training, Validation and Testing

In [None]:

def _evaluate_loader(model, loader, dice_func, hdm_func):
    """Run ``model`` over every batch of ``loader`` and return per-channel Dice and HD95.

    The MONAI metric objects accumulate across all batches and are aggregated once at the
    end. The result is therefore the mean over every case in the loader (the MEAN_BATCH
    reduction skips NaN entries), not the value of whichever batch happened to come last.

    Args:
        model: network already moved to ``device`` and switched to eval mode.
        loader: DataLoader yielding dicts with "image" and "label" tensors.
        dice_func: DiceMetric constructed with ``get_not_nans=True``.
        hdm_func: HausdorffDistanceMetric.

    Returns:
        (dice, hd): numpy arrays with one value per label channel.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    dice_func.reset()
    hdm_func.reset()
    n_batches = 0
    with torch.no_grad():
        for batch in loader:
            n_batches += 1
            inputs = batch["image"].to(device)
            labels = batch["label"].to(device)
            outputs = [post_trans(o) for o in decollate_batch(inference(inputs, model))]
            dice_func(y_pred=outputs, y=labels)
            hdm_func(y_pred=outputs, y=labels)
    if n_batches == 0:
        raise ValueError("cannot compute metrics on an empty loader")
    dice, _ = dice_func.aggregate()
    hd = hdm_func.aggregate()
    dice_func.reset()
    hdm_func.reset()
    return dice.cpu().numpy(), hd.cpu().numpy()


def runModel(model_name, model, optimizer, loss_function, ptcl, file_name_res):
    """Train, validate and test one candidate network, then append its results to a CSV.

    If ``checkCache`` already finds ``ptcl`` in ``file_name_res``, the cached values are
    returned and nothing is trained.

    Otherwise the model trains for ``max_epochs`` epochs with mixed precision (on CUDA)
    and a cosine-annealing LR schedule. From 0-based epoch index ``n_epoch`` onward, the
    model is validated on the whole validation set each epoch, and the weights with the
    best mean Dice are saved under ``root_dir``. The best weights are then scored on the
    test set, and one result row is written with ``saveValues``.

    Args:
        model_name: run identifier, used in the checkpoint file name and the CSV row.
        model: network to train.
        optimizer: optimizer for ``model``.
        loss_function: callable ``loss_function(outputs, labels)``.
        ptcl: candidate encoding passed to ``checkCache`` and stored in the CSV row.
        file_name_res: path of the results CSV.

    Returns:
        ``(1 - mean test Dice, parameter count / 1e6, flops)`` for a fresh run, or the
        triple returned by ``checkCache`` for a cached one.
    """
    exists, dsl, nop, fps = checkCache([file_name_res], ptcl)
    if exists:
        # NOTE(review): the cached third value is named `fps`, while a fresh run returns
        # `flops`. Confirm which column checkCache actually reads back.
        return dsl, nop, fps

    model = model.to(device)
    Total_params, flops, fps = calc_pff(model, img_size, 4, 4)
    print(f"No.of params - {Total_params}")

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)
    checkpoint_name = model_name + "_" + dataset_name + "_best_metric.pth"
    checkpoint_path = os.path.join(root_dir, checkpoint_name)
    print(checkpoint_name)

    # Fresh scaler per candidate, so loss-scale state adapted to one architecture does
    # not carry over into the next. AMP is enabled only when training on CUDA.
    use_amp = device.type == "cuda"
    grad_scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    dice_func = DiceMetric(include_background=True, reduction=MetricReduction.MEAN_BATCH, get_not_nans=True)
    hdm_func = HausdorffDistanceMetric(include_background=True, percentile=95, reduction=MetricReduction.MEAN_BATCH)

    best_metric = -1
    best_metric_epoch = -1
    last_epoch = -1
    total_start = time.time()

    ## Training
    for epoch in range(max_epochs):
        last_epoch = epoch
        epoch_start = time.time()
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0.0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs = batch_data["image"].to(device)
            labels = batch_data["label"].to(device)
            optimizer.zero_grad()
            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
            grad_scaler.scale(loss).backward()
            grad_scaler.step(optimizer)
            grad_scaler.update()
            epoch_loss += loss.item()
        lr_scheduler.step()
        epoch_loss /= max(step, 1)  # guard against an empty train_loader
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        ## Validation on the whole validation set, from epoch index n_epoch onward
        mean_val_dice = None
        if epoch >= n_epoch:
            model.eval()
            val_dice, val_hd = _evaluate_loader(model, val_loader, dice_func, hdm_func)
            mean_val_dice = float(np.mean(val_dice))
            print(
                f"current epoch: {epoch + 1}, Dice_Avg: {mean_val_dice}, dice_tc: {val_dice[0]}, "
                f"dice_wt: {val_dice[1]}, dice_et: {val_dice[2]}, HDM_Avg: {np.mean(val_hd)}, "
                f"hdm_tc: {val_hd[0]}, hdm_wt: {val_hd[1]}, hdm_et: {val_hd[2]}"
            )
            if mean_val_dice > best_metric:
                best_metric = mean_val_dice
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), checkpoint_path)
                print(f"saved new best metric model at {best_metric_epoch}")

        print(f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}")
        # Early-stop only on a measured Dice. Before validation starts there is no score,
        # and stopping then would leave no checkpoint to test.
        if mean_val_dice is not None and epoch > 20 and best_metric < 0.15 and mean_val_dice < 0.15:
            print(f"Stopping training after {epoch}th epoch as {mean_val_dice} dice not increasing")
            break

    total_end = time.time()
    total_time = ((total_end - total_start) / 60) / 60  # hours
    print(f"Train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}, total time: {total_time}.")

    ## Testing
    if best_metric_epoch != -1:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    else:
        # Nothing was saved this run (validation never ran, or Dice was NaN). Test the
        # final weights rather than loading a missing or stale checkpoint.
        print(f"No best checkpoint saved for {model_name}; testing final weights")
    model.eval()
    test_dice, test_hd = _evaluate_loader(model, test_org_loader, dice_func, hdm_func)
    mean_test_dice = np.mean(test_dice)
    print(
        f"Final test stats {last_epoch}/{max_epochs - 1}, Dice_Avg: {mean_test_dice}, "
        f"dice_tc: {test_dice[0]}, dice_wt: {test_dice[1]}, dice_et: {test_dice[2]}, "
        f"HDM_Avg: {np.mean(test_hd)}, hdm_tc: {test_hd[0]}, hdm_wt: {test_hd[1]}, hdm_et: {test_hd[2]}, "
        f"time {time.time() - total_end}s"
    )

    row = [
        model_name, best_metric, Total_params, flops, fps, total_time,
        mean_test_dice, test_dice[0], test_dice[1], test_dice[2],
        np.mean(test_hd), test_hd[0], test_hd[1], test_hd[2],
        ptcl, convert_Time(total_start), convert_Time(total_end), best_metric_epoch, task_name,
    ]
    saveValues(file_name_res, row)

    return 1 - mean_test_dice, Total_params / 1e6, flops

# EGA-NAS Evolution

In [None]:
# Seed NumPy's global RNG (presumably consumed by the E-GA operators) for reproducible searches.
np.random.seed(4345)

# --- E-GA hyper-parameters ---------------------------------------------------------
ch_len = 56            # genes per individual (chromosome length)
max_gen = 20           # generations, counting the initial population as generation 0
z = 2                  # threshold argument passed to parent_selection
pop_size = 20          # individuals per generation
r_cross = 0.9          # crossover probability
r_mut = 1.0 / ch_len   # mutation rate, the reciprocal of the chromosome length

# Population evaluation
def evaluation(t, pop):
    """Build, train and score every individual in ``pop``.

    Args:
        t: generation index; combined with the individual's index as the run name ``f"{t}_{i}"``.
        pop: sequence of encoded individuals (chromosomes).

    Returns:
        A fresh container from ``init_obj(len(pop))`` whose i-th entry is the
        ``runModel`` result for ``pop[i]``.
    """
    # Fill a fresh container. Writing into the module-level `objective_values` would
    # overwrite the parents' scores while a child population is being evaluated.
    values = init_obj(len(pop))
    for i, ind in enumerate(pop):
        print("----\n\n\n{} {} {}".format(t, i, ind))
        model, optimizer, loss_function, ptcl = encoding_ch(ind, file_name_res)
        values[i] = runModel(f"{t}_{i}", model, optimizer, loss_function, ptcl, file_name_res)
        print(values[i])
    return values

In [None]:
# Parents are paired two at a time for crossover, so the population size must be even.
assert pop_size % 2 == 0, "pop_size must be even: parents are paired for crossover"

# Population and objective-value initialization
pop, objective_values = init_pop(pop_size, ch_len), init_obj(pop_size)

# Initial population evaluation (generation 0)
objective_values = evaluation(0, pop)

# Enumerate generations
for gen in range(1, max_gen):
    print(f"\n\n\n{gen} Generation:")

    # Snapshot the parents' fitness before the children are evaluated. This way the
    # parent scores cannot be overwritten if `evaluation` fills a shared container in place.
    parent_objective_values = np.copy(objective_values)

    # E-GA operations
    selected = np.asarray([parent_selection(pop, objective_values, z) for _ in range(pop_size)])
    print(selected)

    child = []
    for i in range(0, pop_size, 2):
        # Selected parents are consumed in consecutive pairs
        p1, p2 = selected[i], selected[i + 1]
        for c in crossover(p1, p2, r_cross):
            # Mutate each offspring and keep it for the next generation
            child.append(mutation(c, r_mut))
    child = np.asarray(child)

    # Evaluate all candidates in the child population
    child_objective_values = evaluation(gen, child)

    # Choose the next population from parents and children combined
    pop, objective_values = selection(
        np.concatenate([pop, child]),
        np.concatenate([parent_objective_values, child_objective_values]),
        pop_size,
    )

print(pop, objective_values)