In [12]:
import sys

import numpy
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataio.loader import get_dataset, get_dataset_path
from dataio.transformation import get_dataset_transformation
from utils.util import json_file_to_pyobj
from utils.visualiser import Visualiser
from utils.error_logger import ErrorLogger

from models import get_model

In [13]:
def test_model(
    json_filename="configs/config_unet_ct_multi_att_dsv.json",
    pretrained_model_path="checkpoints/experiment_unet_ct_multi_att_dsv/095_net_S.pth",
    num_workers=1,
):
    """Evaluate a pretrained segmentation checkpoint on the 'test' split.

    Reads the JSON experiment configuration and builds the dataset class,
    data path and transforms for the configured architecture. It then builds
    the model, restores the weights from ``pretrained_model_path`` (strict
    key matching) and runs one pass over the 'test' split. Losses and
    segmentation statistics are accumulated with an ``ErrorLogger``.

    Args:
        json_filename: Path to the JSON experiment configuration.
        pretrained_model_path: Path to the network checkpoint to evaluate.
        num_workers: Number of ``DataLoader`` worker processes.

    Returns:
        The metrics collected by ``ErrorLogger.get_errors('test')``. They are
        also printed between separator lines.
    """
    json_opts = json_file_to_pyobj(json_filename)
    train_opts = json_opts.training

    # The architecture type selects the dataset class, data path and transforms
    arch_type = train_opts.arch_type

    # Setup Dataset and Augmentation
    ds_class = get_dataset(arch_type)
    ds_path = get_dataset_path(arch_type, json_opts.data_path)
    ds_transform = get_dataset_transformation(arch_type, opts=json_opts.augmentation)

    # Setup the NN Model and load the pretrained weights
    model = get_model(json_opts.model)
    model.load_network_from_path(model.net, pretrained_model_path, strict=True)

    # Setup Data Loader; the test split reuses the validation-time transform
    test_dataset = ds_class(
        ds_path, split='test', transform=ds_transform['valid'], preload_data=train_opts.preloadData
    )
    test_loader = DataLoader(
        dataset=test_dataset, num_workers=num_workers, batch_size=train_opts.batchSize, shuffle=False
    )

    # eval() switches dropout/batch-norm to inference behaviour; it does NOT
    # disable gradient tracking (that is what no_grad below is for).
    model.net.eval()

    error_logger = ErrorLogger()

    print("Starting testing...")
    # Inference only: without no_grad every forward pass would build an
    # autograd graph, wasting memory.
    with torch.no_grad():
        for images, labels in tqdm(test_loader, total=len(test_loader), file=sys.stdout):
            # Move data to the model's first GPU when CUDA is enabled
            if model.use_cuda and model.gpu_ids:
                images, labels = images.cuda(model.gpu_ids[0]), labels.cuda(model.gpu_ids[0])

            # Perform forward pass
            model.set_input(images, labels)
            model.validate()

            # Merge losses and segmentation stats into one record per batch
            errors = model.get_current_errors()
            stats = model.get_segmentation_stats()
            error_logger.update({**errors, **stats}, split='test')

    # Summarize results
    test_errors = error_logger.get_errors('test')
    print("--------------------------------------")
    print(f"Test Results: {test_errors}")
    print("--------------------------------------")
    return test_errors


In [14]:
if __name__ == "__main__":
    # Entry point of the notebook: announce the run, then evaluate the
    # checkpoint on the test split.
    # NOTE(review): inside a Jupyter kernel __name__ is always "__main__", so
    # this guard is always true here. It only has an effect if the notebook is
    # exported to a .py module and imported from elsewhere.
    print("Testing the model...")
    test_model()

Testing the model...


############# Augmentation Parameters #############
{'division_factor': (16, 16, 1),
 'inten_val': (1.0, 1.0),
 'name': 'acdc_sax',
 'patch_size': [160, 160, 96],
 'random_flip_prob': 0.5,
 'rotate_val': 15.0,
 'scale_size': [160, 160, 96],
 'scale_val': (0.7, 1.3),
 'shift_val': (0.1, 0.1)}
###################################################



Initialising model unet_ct_multi_att_dsv
Model [FeedForwardSegmentation] is created
Loading the model 095_net_S.pth - epoch 095
Number of test images: 18 NIFTIs
Starting testing...
100%|██████████| 9/9 [00:17<00:00,  1.89s/it]
--------------------------------------
Test Results: {'Seg_Loss': 0.799999992052714, 'Overall_Acc': 0.7266673674406829, 'Mean_IOU': 0.18577655912286054, 'Class_0': 0.8261760671933492, 'Class_1': 0.024303110523356333, 'Class_2': 0.0, 'Class_3': 0.0}
--------------------------------------
