In [None]:
import glob
import inspect
import os
import shutil
import traceback

# Restrict the process to GPU 0; set before torch is imported so CUDA sees only this device.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import numpy as np
import torch

import utils.common_utils as common_utils
from quick_oct import QuickOct as mclass
from settings import Settings
from utils.evaluator import evaluate, evaluate2view, evaluate_dice_score, compute_vol_bulk, evaluate3view
from utils.log_utils import LogWriter

In [None]:
# Separator printed around each checkpoint evaluation in the console log.
_SEPARATOR = "######################################################################################"


def _evaluate_checkpoint(eval_params, net_params, data_params, common_params, train_params, model_chkpt):
    """Load one checkpoint into a fresh ``mclass`` model and score it with ``evaluate_dice_score``.

    Shared implementation of :func:`evaluate` and :func:`evaluate_save_best`.

    Args:
        eval_params, net_params, data_params, common_params, train_params: the
            ``EVAL`` / ``NETWORK`` / ``DATA`` / ``COMMON`` / ``TRAINING`` sections of ``Settings``.
        model_chkpt: a single checkpoint path loadable by ``torch.load`` whose dict
            contains a ``'state_dict'`` entry.

    Returns:
        ``(model, avg_dice_score, class_dist)``: the model with the checkpoint weights
        loaded, plus the two values returned by ``evaluate_dice_score``.

    Raises:
        TypeError: if ``model_chkpt`` is a list/tuple of paths instead of one path.
    """
    if isinstance(model_chkpt, (list, tuple)):
        raise TypeError(
            f"model_chkpt must be a single checkpoint path, got a {type(model_chkpt).__name__} "
            f"of {len(model_chkpt)} items")

    num_classes = net_params['num_class']
    labels = data_params['labels']
    data_dir = eval_params['data_dir']
    label_dir = eval_params['label_dir']
    volumes_txt_file = eval_params['volumes_txt_file']
    remap_config = eval_params['remap_config']
    device = common_params['device']
    log_dir = common_params['log_dir']
    exp_dir = common_params['exp_dir']
    exp_name = train_params['exp_name']
    prediction_path = os.path.join(exp_dir, exp_name, eval_params['save_predictions_dir'])
    orientation = eval_params['orientation']
    data_id = eval_params['data_id']
    multi_channel = data_params['use_3channel']
    use_2channel = data_params['use_2channel']
    thick_channel = data_params['thick_channel']

    logWriter = LogWriter(num_classes, log_dir, exp_name, labels=labels)
    print(_SEPARATOR)
    print(model_chkpt)
    try:
        model = mclass(net_params)
        # Deserialize on CPU: load_state_dict copies the tensors into the model's own
        # parameters anyway, and the checkpoint may also carry optimizer/scheduler state
        # that should not be materialized on the GPU just to be thrown away.
        cp = torch.load(model_chkpt, map_location='cpu')
        model.load_state_dict(cp['state_dict'])
        avg_dice_score, class_dist = evaluate_dice_score(model,
                                                         num_classes,
                                                         data_dir,
                                                         label_dir,
                                                         volumes_txt_file,
                                                         remap_config,
                                                         orientation,
                                                         prediction_path,
                                                         data_id,
                                                         device,
                                                         logWriter,
                                                         multi_channel=multi_channel,
                                                         use_2channel=use_2channel,
                                                         thick_ch=thick_channel)
    finally:
        # Release the writer even if loading or evaluation fails.
        logWriter.close()
    return model, avg_dice_score, class_dist


def evaluate(eval_params, net_params, data_params, common_params, train_params, model_chkpt):
    """Evaluate a single checkpoint with ``evaluate_dice_score`` and log the results.

    NOTE: this definition shadows ``utils.evaluator.evaluate`` imported in the first cell.

    Args:
        eval_params, net_params, data_params, common_params, train_params: the
            ``EVAL`` / ``NETWORK`` / ``DATA`` / ``COMMON`` / ``TRAINING`` sections of ``Settings``.
        model_chkpt: path of the checkpoint to evaluate.

    Returns:
        ``(avg_dice_score, class_dist)`` as returned by ``evaluate_dice_score``.
    """
    _, avg_dice_score, class_dist = _evaluate_checkpoint(
        eval_params, net_params, data_params, common_params, train_params, model_chkpt)
    print(_SEPARATOR)
    return avg_dice_score, class_dist


def evaluate_save_best(eval_params, net_params, data_params, common_params, train_params, model_chkpt):
    """Evaluate ``model_chkpt``, snapshot the architecture sources, and export the whole model.

    The source file defining ``mclass`` (plus its companion files) is copied into
    ``<exp_dir>/<exp_name>/architecture``. The full model object is written with ``torch.save``
    to ``<save_model_dir>/<final_model_file>``, in whatever state ``evaluate_dice_score``
    leaves it (device / train-eval mode). TODO(review): confirm that state is the one wanted.

    Returns:
        ``(avg_dice_score, class_dist)`` as returned by ``evaluate_dice_score``.
    """
    arch_file_path = inspect.getfile(mclass)
    model, avg_dice_score, class_dist = _evaluate_checkpoint(
        eval_params, net_params, data_params, common_params, train_params, model_chkpt)
    exp_dir_path = os.path.join(common_params['exp_dir'], train_params['exp_name'])
    save_architectural_files(exp_dir_path, arch_file_path)
    best_model_path = os.path.join(common_params['save_model_dir'], train_params['final_model_file'])
    torch.save(model, best_model_path)
    print(_SEPARATOR)
    return avg_dice_score, class_dist
    
def save_architectural_files(exp_dir_path, arch_file_path):
    """Copy the model definition and its companion source/config files into the experiment dir.

    Everything is written to ``<exp_dir_path>/architecture``. The model file is stored as
    ``model.py``; the companion files listed below are looked up relative to the directory
    that contains ``arch_file_path``.

    Args:
        exp_dir_path: experiment directory; the ``architecture`` sub-directory is created in it.
        arch_file_path: path of the file defining the model class (e.g. from
            ``inspect.getfile``). If ``None``, nothing is copied and a message is printed.

    Raises:
        OSError (e.g. FileNotFoundError): propagated from ``shutil.copy`` when a file is missing.
    """
    ARCHITECTURE_DIR = 'architecture'
    # (path parts relative to the model file's directory, file name inside ARCHITECTURE_DIR).
    # Files from sub-packages are stored flat, with the package name as a '-'-separated prefix.
    companion_files = (
        (('run.py',), 'run.py'),
        (('solver.py',), 'solver.py'),
        (('utils', 'evaluator.py'), 'utils-evaluator.py'),
        (('nn_common_modules', 'losses.py'), 'nn_common_modules-losses.py'),
        (('nn_common_modules', 'modules.py'), 'nn_common_modules-modules.py'),
        (('settings_merged_jj.ini',), 'settings_merged_jj.ini'),
    )

    if arch_file_path is None:
        print('No Architectural file!!!')
        return

    destination = os.path.join(exp_dir_path, ARCHITECTURE_DIR)
    common_utils.create_if_not(destination)
    # os.path.dirname rather than splitting on '/': for a bare file name it yields '' and the
    # joins below stay relative, whereas f'{arch_base}/run.py' turned that into '/run.py'.
    arch_base = os.path.dirname(arch_file_path)
    model_dst = os.path.join(destination, 'model.py')
    print(arch_file_path, arch_base, model_dst)
    shutil.copy(arch_file_path, model_dst)
    for src_parts, dst_name in companion_files:
        shutil.copy(os.path.join(arch_base, *src_parts), os.path.join(destination, dst_name))

In [None]:
# Evaluate every checkpoint from FIRST_CHKPT_INDEX onwards and print its stored metadata.
model_path = "/mnt/nas/Abhijit/Jyotirmay/abdominal_segmentation/experiments/pp2_axial_kora_do_cw_swa_transform_csse_seed_octave_inn_concat_thick32/checkpoints/*"
SETTINGS_PATH = '/home/abhijit/Jyotirmay/abdominal_segmentation/quickNAT_pytorch/settings_merged_jj.ini'
# Position in the sorted checkpoint listing (not an epoch number).
FIRST_CHKPT_INDEX = 94

# glob returns paths in filesystem-dependent order; sort so the positional slice below
# selects the same checkpoints on every run/machine.
chkpts = sorted(glob.glob(model_path))

# The settings file is the same for every checkpoint, so parse it once.
settings = Settings(SETTINGS_PATH)
common_params, data_params, net_params, train_params, eval_params = (
    settings['COMMON'], settings['DATA'], settings['NETWORK'], settings['TRAINING'], settings['EVAL'])

for ch in chkpts[FIRST_CHKPT_INDEX:]:
    print(ch)
    # Only metadata is inspected here, so keep the tensors on CPU; evaluate() reloads the file.
    checkpoint = torch.load(ch, map_location='cpu')
    print(checkpoint.keys())
    if 'best_ds_mean' in checkpoint:
        print(checkpoint['best_ds_mean_epoch'], checkpoint['best_ds_mean'], checkpoint['epoch'])
    evaluate(eval_params, net_params, data_params, common_params, train_params, ch)

In [None]:
# 0.72223794, 0.688439, 114
# 0.7281057, 0.6950686, 120
# 0.73640484, 0.7043711, 123
# 0.7369785, 0.7049849, 175
# 0.74033165, 0.7088444, 190
# 0.745792, 0.71495265, 209
# Output pasted from earlier evaluation runs:
# (1568, 192, 176) (1568, 192, 176) (1568, 192, 176) (1568, 9)
# (224, 192, 176) (224, 192, 176) (224, 192, 176) (224, 9)
#
# dict_keys(['epoch', 'start_iteration', 'best_ds_mean', 'best_ds_mean_epoch', 'arch', 'state_dict', 'optimizer', 'scheduler'])
# 114 0.7440068125724792 172
#
# 96,
# Mean dice score:  0.73158914
# Mean dice score without background:  0.69894266
# all dice scores:  [[0.9902677  0.87998533 0.7953528  0.8738815  0.88923365 0.20258947
#   0.48780474 0.42834452 0.5258064 ]
#  [0.9952544  0.94266886 0.8597339  0.8494878  0.8527255  0.623044
#   0.6442166  0.6286432  0.6995638 ]]
# class wise mean dice scores:  [0.992761   0.9113271  0.8275434  0.8616847  0.87097955 0.41281673
#  0.56601065 0.5284939  0.6126851 ]
# 114,
# Mean dice score:  0.72194374
# Mean dice score without background:  0.68809366
# all dice scores:  [[0.98947644 0.86812973 0.7180221  0.82245517 0.88225234 0.1769398
#   0.45190832 0.29078966 0.70641786]
#  [0.9960126  0.9550314  0.90285546 0.889374   0.90793043 0.6460673
#   0.5649717  0.5924293  0.6339235 ]]
# class wise mean dice scores:  [0.99274457 0.91158056 0.81043875 0.8559146  0.8950914  0.41150355
#  0.50844    0.44160947 0.67017066]
#
# 171
# Mean dice score:  0.7212828
# Mean dice score without background:  0.6873605
# all dice scores:  [[0.9893841  0.8671018  0.69115525 0.85846704 0.88969386 0.22851744
#   0.31768942 0.2562432  0.57367826]
#  [0.9959385  0.9534169  0.8895712  0.9120816  0.9075551  0.74326736
#   0.6867255  0.6203689  0.60223496]]
# class wise mean dice scores:  [0.9926613  0.91025937 0.7903632  0.8852743  0.8986245  0.48589242
#  0.50220746 0.43830603 0.5879566 ]
# 0.74397445, 0.7126837, 129
# 0.7492841, 0.71869576, 139
# 0.7515822, 0.72124577, 173
# 0.7521182, 0.72184324, 209
# 0.7530587, 0.7228937, 232
# 0.755678, 0.7258475, 240
# 0.76108134, 0.7319298, 249

# 0.7500159, 0.71961, 156
# 0.7508582, 0.7204968, 166
# 0.7552803, 0.72552896, 182
# 0.75755465, 0.72799754, 231
# 0.75888705, 0.7295543, 243
# 0.7596982, 0.7304461, 259
# 0.76043683, 0.73126954, 285


# 0.75182295, 0.7218148, 58
# 0.7562244, 0.7266662, 125
# 0.7578031, 0.7284702, 187      0.75785315, 0.72836435, 208
# 0.7635112, 0.73478234, 214


# 0.7439982, 0.71291935, 56
# 0.7509596, 0.7207745, 58
# 0.7524023, 0.72232246, 67
# 0.7582579, 0.72889125, 69
# 0.7626364, 0.73383045, 81
# 0.76475966, 0.7362199, 87
# 0.76554954, 0.7370368, 201
# 0.7671693, 0.7387961, 246
# 0.77319837, 0.7456038, 254
# 0.77607036, 0.7488848, 258


# 73.636, 53
# 76.68003, 73.856914, 135
# 77.701074, 74.99354, 147
# 77.76387,  75.05958, 294

In [None]:
# Evaluate one chosen checkpoint, snapshot the architecture files and export the full model.
model_path = "/mnt/nas/Abhijit/Jyotirmay/abdominal_segmentation/experiments/pp2_axial_kora_do_cw_swa_transform_csse_seed_octave_inn_concat_thick32/checkpoints/*"
SETTINGS_PATH = '/home/abhijit/Jyotirmay/abdominal_segmentation/quickNAT_pytorch/settings_merged_jj.ini'
# evaluate_save_best() takes ONE checkpoint path. This cell used to pass the slice
# chkpts[110:120] (a list), which cannot be loaded as a checkpoint, so the export always failed.
# TODO(review): confirm which checkpoint (position in the sorted listing) should be exported.
BEST_CHKPT_INDEX = 110

# Sorted so the positional index is reproducible (glob order is filesystem-dependent).
chkpts = sorted(glob.glob(model_path))
try:
    ch = chkpts[BEST_CHKPT_INDEX]
    settings = Settings(SETTINGS_PATH)
    common_params, data_params, net_params, train_params, eval_params = (
        settings['COMMON'], settings['DATA'], settings['NETWORK'], settings['TRAINING'], settings['EVAL'])
    evaluate_save_best(eval_params, net_params, data_params, common_params, train_params, ch)
except Exception:
    # Keep the cell best-effort, but show the full traceback rather than only the message.
    traceback.print_exc()

In [None]:
# 74.88857, 31
# 77.898, 75.224966, 36
# 78.504807, 75.90687, 60
# 78.71778,  76.13684, 257

In [None]:
# 71, 26
# 71.24, 41
# 71.96, 43
# 72.17, 44
# 73.4,  52
# 74.1,  201  76.9058
# 74.05, 299  76.86589