In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from utils.log_utils import LogWriter
import torch
import glob
from quick_oct import QuickOct as mclass
import utils.common_utils as common_utils
import inspect
import shutil
from utils.evaluator import evaluate, evaluate2view, evaluate_dice_score, compute_vol_bulk, evaluate3view
from settings import Settings

In [2]:
def evaluate(eval_params, net_params, data_params, common_params, train_params, model_chkpt):
    """Compute Dice scores on the evaluation set for a single checkpoint.

    Note: this definition shadows the ``evaluate`` imported from
    ``utils.evaluator`` in the imports cell, so later cells calling
    ``evaluate`` use this version.

    Args:
        eval_params: 'EVAL' settings section; reads data_dir, label_dir,
            volumes_txt_file, remap_config, save_predictions_dir,
            orientation and data_id.
        net_params: 'NETWORK' settings section; reads num_class and is passed
            to the model constructor.
        data_params: 'DATA' settings section; reads labels, use_3channel and
            use_2channel.
        common_params: 'COMMON' settings section; reads device, log_dir and
            exp_dir.
        train_params: 'TRAINING' settings section; reads exp_name.
        model_chkpt: path to a checkpoint loaded with ``torch.load`` that must
            contain a 'state_dict' entry.

    Returns:
        The ``(avg_dice_score, class_dist)`` pair from ``evaluate_dice_score``.
        (The original returned None; callers that ignore the result are
        unaffected.)
    """
    num_classes = net_params['num_class']
    labels = data_params['labels']
    data_dir = eval_params['data_dir']
    label_dir = eval_params['label_dir']
    volumes_txt_file = eval_params['volumes_txt_file']
    remap_config = eval_params['remap_config']
    device = common_params['device']
    log_dir = common_params['log_dir']
    exp_dir = common_params['exp_dir']
    exp_name = train_params['exp_name']
    save_predictions_dir = eval_params['save_predictions_dir']
    prediction_path = os.path.join(exp_dir, exp_name, save_predictions_dir)
    orientation = eval_params['orientation']
    data_id = eval_params['data_id']
    multi_channel = data_params['use_3channel']
    use_2channel = data_params['use_2channel']

    logWriter = LogWriter(num_classes, log_dir, exp_name, labels=labels)
    print("######################################################################################")
    print(model_chkpt)
    try:
        model = mclass(net_params)
        cp = torch.load(model_chkpt)
        model.load_state_dict(cp['state_dict'])
        avg_dice_score, class_dist = evaluate_dice_score(model,
                                                         num_classes,
                                                         data_dir,
                                                         label_dir,
                                                         volumes_txt_file,
                                                         remap_config,
                                                         orientation,
                                                         prediction_path,
                                                         data_id,
                                                         device,
                                                         logWriter,
                                                         multi_channel=multi_channel,
                                                         use_2channel=use_2channel)
    finally:
        # Always close the writer, even if loading or evaluation raises: the
        # checkpoint loop catches exceptions and moves on, so a failed
        # checkpoint would otherwise leave its writer open (LogWriter exposes
        # close(), presumably releasing log resources — confirm).
        logWriter.close()
    print("######################################################################################")
    return avg_dice_score, class_dist

    
def evaluate_save_best(eval_params, net_params, data_params, common_params, train_params, model_chkpt):
    """Evaluate one checkpoint, archive the source files, and save it as the final model.

    Runs the same Dice evaluation as ``evaluate``. It then copies the model's
    source file (plus companion project files) into
    ``<exp_dir>/<exp_name>/architecture``. Finally it pickles the whole model
    to ``<save_model_dir>/<final_model_file>``.

    Args:
        eval_params: 'EVAL' settings section; reads data_dir, label_dir,
            volumes_txt_file, remap_config, save_predictions_dir,
            orientation and data_id.
        net_params: 'NETWORK' settings section; reads num_class and is passed
            to the model constructor.
        data_params: 'DATA' settings section; reads labels, use_3channel and
            use_2channel.
        common_params: 'COMMON' settings section; reads device, log_dir,
            exp_dir and save_model_dir.
        train_params: 'TRAINING' settings section; reads exp_name and
            final_model_file.
        model_chkpt: path to a checkpoint loaded with ``torch.load`` that must
            contain a 'state_dict' entry.

    Returns:
        The ``(avg_dice_score, class_dist)`` pair from ``evaluate_dice_score``.
        (The original returned None; callers that ignore the result are
        unaffected.)
    """
    num_classes = net_params['num_class']
    labels = data_params['labels']
    data_dir = eval_params['data_dir']
    label_dir = eval_params['label_dir']
    volumes_txt_file = eval_params['volumes_txt_file']
    remap_config = eval_params['remap_config']
    device = common_params['device']
    log_dir = common_params['log_dir']
    exp_dir = common_params['exp_dir']
    exp_name = train_params['exp_name']
    save_predictions_dir = eval_params['save_predictions_dir']
    prediction_path = os.path.join(exp_dir, exp_name, save_predictions_dir)
    orientation = eval_params['orientation']
    data_id = eval_params['data_id']
    multi_channel = data_params['use_3channel']
    use_2channel = data_params['use_2channel']

    logWriter = LogWriter(num_classes, log_dir, exp_name, labels=labels)
    print("######################################################################################")
    print(model_chkpt)
    # Source file that defines the model class; archived with the experiment below.
    arch_file_path = inspect.getfile(mclass)
    try:
        model = mclass(net_params)
        cp = torch.load(model_chkpt)
        model.load_state_dict(cp['state_dict'])
        avg_dice_score, class_dist = evaluate_dice_score(model,
                                                         num_classes,
                                                         data_dir,
                                                         label_dir,
                                                         volumes_txt_file,
                                                         remap_config,
                                                         orientation,
                                                         prediction_path,
                                                         data_id,
                                                         device,
                                                         logWriter,
                                                         multi_channel=multi_channel,
                                                         use_2channel=use_2channel)
    finally:
        # Close the writer even if loading or evaluation raises.
        logWriter.close()

    exp_dir_path = os.path.join(exp_dir, exp_name)
    save_architectural_files(exp_dir_path, arch_file_path)

    save_model_dir = common_params['save_model_dir']
    # torch.save fails if the target directory is missing; create it the same
    # way the architecture destination is created.
    common_utils.create_if_not(save_model_dir)
    best_model_path = os.path.join(save_model_dir, train_params['final_model_file'])
    # Saving the module object pickles the whole model (not only its
    # state_dict), so loading it later requires the model class to be importable.
    torch.save(model, best_model_path)
    print("######################################################################################")
    return avg_dice_score, class_dist
    
def save_architectural_files(exp_dir_path, arch_file_path):
    """Copy the model source and related project files into ``<exp_dir_path>/architecture``.

    The model file is stored as ``model.py``. Companion files are looked up
    relative to the model file's directory. Files from subpackages are
    flattened into ``<package>-<file>`` names.

    Args:
        exp_dir_path: experiment directory; an 'architecture' subfolder is
            created inside it.
        arch_file_path: path to the file defining the model class (e.g. from
            ``inspect.getfile``). If None, nothing is copied and a message is
            printed.
    """
    ARCHITECTURE_DIR = 'architecture'
    if arch_file_path is None:
        print('No Architectural file!!!')
        return

    destination = os.path.join(exp_dir_path, ARCHITECTURE_DIR)
    common_utils.create_if_not(destination)
    # os.path.dirname instead of splitting on '/': for a path without '/'
    # (bare filename, Windows separators) the old split produced an empty base,
    # and the copies below then read '/run.py' etc. from the filesystem root.
    arch_base = os.path.dirname(arch_file_path)
    model_dest = os.path.join(destination, 'model.py')
    print(arch_file_path, arch_base, model_dest)
    shutil.copy(arch_file_path, model_dest)

    # (path relative to arch_base, file name inside destination)
    companion_files = [
        ('run.py', 'run.py'),
        ('solver.py', 'solver.py'),
        (os.path.join('utils', 'evaluator.py'), 'utils-evaluator.py'),
        (os.path.join('nn_common_modules', 'losses.py'), 'nn_common_modules-losses.py'),
        (os.path.join('nn_common_modules', 'modules.py'), 'nn_common_modules-modules.py'),
        ('settings_merged_jj.ini', 'settings_merged_jj.ini'),
    ]
    for rel_src, dest_name in companion_files:
        shutil.copy(os.path.join(arch_base, rel_src), os.path.join(destination, dest_name))

In [None]:
model_path = "/mnt/nas/Abhijit/Jyotirmay/abdominal_segmentation/experiments/pp2_axial_kora_do_cw_swa_CSSE_seed_transform_octave_inn_concat2/checkpoints/*"
chkpts = glob.glob(model_path)
for ch in chkpts:
    try: 
        settings = Settings('/home/abhijit/Jyotirmay/abdominal_segmentation/quickNAT_pytorch/settings_merged_jj.ini')
        common_params, data_params, net_params, train_params, eval_params = settings['COMMON'], settings['DATA'], \
                                                                        settings[
                                                                            'NETWORK'], settings['TRAINING'], \
                                                                        settings['EVAL']
        evaluate(eval_params, net_params, data_params, common_params, train_params, ch)
    except Exception as e:
        print(e)
        continue

######################################################################################
/mnt/nas/Abhijit/Jyotirmay/abdominal_segmentation/experiments/pp2_axial_kora_do_cw_swa_CSSE_seed_transform_octave_inn_concat2/checkpoints/checkpoint_epoch_1.pth.tar
NUMBER OF CHANNEL 1
CSSE
False
CSSE
False
CSSE
False
CSSE
True
CSSE
True
data id  KORA
['/mnt/nas/Abhijit/Jyotirmay/abdominal_segmentation/dataset/KORA/test/volume/KORA2453666.nii.gz', '/mnt/nas/Abhijit/Jyotirmay/abdominal_segmentation/dataset/KORA/test/label9/KORA2453666.nii.gz']
evaluator here
['/mnt/nas/Abhijit/Jyotirmay/abdominal_segmentation/dataset/KORA/test/volume/KORA2460408.nii.gz', '/mnt/nas/Abhijit/Jyotirmay/abdominal_segmentation/dataset/KORA/test/label9/KORA2460408.nii.gz']
evaluator here
Mean dice score:  0.0871981
Mean dice score without background:  0.010007762
all dice scores:  [[0.7830899  0.10249846 0.         0.         0.         0.
  0.         0.         0.        ]
 [0.6263517  0.05762573 0.         0.         0.  

evaluator here
Mean dice score:  0.205053
Mean dice score without background:  0.10734004
all dice scores:  [[0.9829969  0.8037588  0.         0.         0.         0.
  0.         0.         0.        ]
 [0.99051654 0.9136818  0.         0.         0.         0.
  0.         0.         0.        ]]
class wise mean dice scores:  [0.9867567 0.8587203 0.        0.        0.        0.        0.
 0.        0.       ]
######################################################################################
######################################################################################
/mnt/nas/Abhijit/Jyotirmay/abdominal_segmentation/experiments/pp2_axial_kora_do_cw_swa_CSSE_seed_transform_octave_inn_concat2/checkpoints/checkpoint_epoch_8.pth.tar
NUMBER OF CHANNEL 1
CSSE
False
CSSE
False
CSSE
False
CSSE
True
CSSE
True
data id  KORA
['/mnt/nas/Abhijit/Jyotirmay/abdominal_segmentation/dataset/KORA/test/volume/KORA2453666.nii.gz', '/mnt/nas/Abhijit/Jyotirmay/abdominal_segmentation/dataset/

In [None]:
# Re-evaluate the selected best checkpoint, archive the source files and save it as the final model.
model_path = "/mnt/nas/Abhijit/Jyotirmay/abdominal_segmentation/experiments/pp2_axial_kora_do_cw_swa_CSSE_seed_transform_octave_inn_concat2/checkpoints/*"
# NOTE(review): this cell used to pick `chkpts[200]` from an unsorted glob.glob()
# list. That order depends on the filesystem, so the checkpoint saved as the
# final model was effectively arbitrary. Index 200 of an epoch-ordered list
# starting at epoch 1 is epoch 201, which also matches the "74.1, 201" entry in
# the scratch notes. Confirm this is the intended best epoch.
BEST_EPOCH = 201
ch = os.path.join(os.path.dirname(model_path), f"checkpoint_epoch_{BEST_EPOCH}.pth.tar")
try:
    if not os.path.isfile(ch):
        raise FileNotFoundError(f"Checkpoint not found: {ch}")
    settings = Settings('/home/abhijit/Jyotirmay/abdominal_segmentation/quickNAT_pytorch/settings_merged_jj.ini')
    common_params, data_params, net_params, train_params, eval_params = (
        settings['COMMON'], settings['DATA'], settings['NETWORK'],
        settings['TRAINING'], settings['EVAL'])
    evaluate_save_best(eval_params, net_params, data_params, common_params, train_params, ch)
except Exception as e:
    print(e)

In [None]:
# Scratch notes recorded by hand. These are bare tuple expressions: running
# the cell has no effect beyond displaying the last tuple.
# NOTE(review): column meanings are not recorded; the last value looks like an
# epoch number and the others like Dice scores in percent — confirm.
74.88857, 31
77.898, 75.224966, 36
78.504807, 75.90687, 60
78.71778,  76.13684, 257

In [None]:
# Scratch notes recorded by hand, kept as comments. As bare code, the last two
# rows (e.g. `74.1,  201  76.9058`) are a SyntaxError and break
# Restart & Run All.
# NOTE(review): column meanings are not recorded; presumably Dice score (%)
# and epoch, with an extra score on the last two rows — confirm.
# 71, 26
# 71.24, 41
# 71.96, 43
# 72.17, 44
# 73.4,  52
# 74.1,  201  76.9058
# 74.05, 299  76.86589