In [6]:
import os
import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from tqdm import tqdm


#VNet
def passthrough(x, **kwargs):
    """Identity op: hand back *x* untouched, ignoring any keyword arguments.

    Used below as the no-op stand-in for ``nn.Dropout3d`` when a transition
    is built with dropout disabled.
    """
    del kwargs  # accepted only so the call signature is permissive
    return x

def ELUCons(elu, nchan):
    """Build the activation module used throughout the VNet.

    Args:
        elu: truthy -> in-place ``nn.ELU``; falsy -> ``nn.PReLU``.
        nchan: channel count; only used by PReLU (one learnable slope per
            channel).
    """
    return nn.ELU(inplace=True) if elu else nn.PReLU(nchan)

class ContBatchNorm3d(nn.modules.batchnorm._BatchNorm):
    """Batch normalisation over 5-D input ``(N, C, D, H, W)``.

    Behaves like ``nn.BatchNorm3d``: batch statistics are used (and the
    running estimates updated) in training mode, running estimates in eval
    mode. Normalisation itself is delegated to ``_BatchNorm.forward`` so that
    ``momentum=None`` (cumulative moving average), ``track_running_stats=False``
    and the ``num_batches_tracked`` counter are handled exactly as in PyTorch,
    instead of crashing as a hand-rolled ``F.batch_norm`` call would.
    """

    def _check_input_dim(self, input):
        # Called by _BatchNorm.forward before normalising.
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))

    def forward(self, input):
        """Normalise *input*; raises ValueError unless it is 5-D."""
        return super().forward(input)

class LUConv(nn.Module):
    """One VNet conv unit: 5x5x5 conv (padding 2) -> batch norm -> activation.

    Channel count is preserved (``nchan`` in, ``nchan`` out), and with kernel
    5 / padding 2 the spatial size is preserved as well.
    """

    def __init__(self, nchan, elu):
        super().__init__()
        # Registration order (relu1, conv1, bn1) matches the original so
        # parameter ordering is unchanged.
        self.relu1 = ELUCons(elu, nchan)
        self.conv1 = nn.Conv3d(nchan, nchan, kernel_size=5, padding=2)
        self.bn1 = ContBatchNorm3d(nchan)

    def forward(self, x):
        convolved = self.conv1(x)
        normalised = self.bn1(convolved)
        return self.relu1(normalised)

def _make_nConv(nchan, depth, elu):
    """Stack ``depth`` LUConv units of width ``nchan`` into an ``nn.Sequential``.

    ``depth == 0`` yields an empty Sequential, which acts as identity.
    """
    return nn.Sequential(*(LUConv(nchan, elu) for _ in range(depth)))

class InputTransition(nn.Module):
    """First VNet stage: lift a single-channel volume to ``outChans`` channels.

    conv(5x5x5, padding 2) -> batch norm -> activation; spatial size unchanged.
    """

    def __init__(self, outChans, elu):
        super().__init__()
        self.conv1 = nn.Conv3d(1, outChans, kernel_size=5, padding=2)
        self.bn1 = ContBatchNorm3d(outChans)
        self.relu1 = ELUCons(elu, outChans)

    def forward(self, x):
        return self.relu1(self.bn1(self.conv1(x)))

class DownTransition(nn.Module):
    """Encoder stage: halve D/H/W with a stride-2 conv, double the channels,
    then add a residual stack of LUConv units.

    Args:
        inChans: input channels; the output has ``2 * inChans``.
        nConvs: number of LUConv units in the residual branch.
        elu: activation choice forwarded to ELUCons.
        dropout: if true, apply ``nn.Dropout3d`` before the conv stack.
    """

    def __init__(self, inChans, nConvs, elu, dropout=False):
        super().__init__()
        outChans = inChans * 2
        self.down_conv = nn.Conv3d(inChans, outChans, kernel_size=2, stride=2)
        self.bn1 = ContBatchNorm3d(outChans)
        self.relu1 = ELUCons(elu, outChans)
        self.relu2 = ELUCons(elu, outChans)
        # Real dropout module when requested, otherwise a plain identity
        # function (registered after relu2, as in the original ordering).
        self.do1 = nn.Dropout3d() if dropout else passthrough
        self.ops = _make_nConv(outChans, nConvs, elu)

    def forward(self, x):
        downsampled = self.relu1(self.bn1(self.down_conv(x)))
        residual = self.ops(self.do1(downsampled))
        return self.relu2(residual + downsampled)

class UpTransition(nn.Module):
    """Decoder stage: upsample, concatenate the skip connection, then add a
    residual stack of LUConv units.

    The transposed conv doubles D/H/W and yields ``outChans // 2`` channels.
    The conv stack and residual add operate on ``outChans`` channels, so the
    skip tensor must supply ``outChans - outChans // 2`` channels.

    Args:
        inChans: channels of the incoming (coarser) feature map.
        outChans: channels after concatenation with the skip tensor.
        nConvs: number of LUConv units in the residual branch.
        elu: activation choice forwarded to ELUCons.
        dropout: if true, also apply ``nn.Dropout3d`` to the incoming map
            (the skip tensor always goes through dropout).
    """

    def __init__(self, inChans, outChans, nConvs, elu, dropout=False):
        super().__init__()
        half = outChans // 2
        self.up_conv = nn.ConvTranspose3d(inChans, half, kernel_size=2, stride=2)
        self.bn1 = ContBatchNorm3d(half)
        self.do2 = nn.Dropout3d()
        self.relu1 = ELUCons(elu, half)
        self.relu2 = ELUCons(elu, outChans)
        self.do1 = nn.Dropout3d() if dropout else passthrough
        self.ops = _make_nConv(outChans, nConvs, elu)

    def forward(self, x, skipx):
        # Keep the dropout call order (incoming map first, then skip) so the
        # random masks drawn in training mode are unchanged.
        dropped = self.do1(x)
        skip = self.do2(skipx)
        upsampled = self.relu1(self.bn1(self.up_conv(dropped)))

        # Pad (negative amounts crop) so D/H/W match the skip tensor.
        # F.pad takes the last dimension first: X (4), Y (3), Z (2).
        pad = []
        for dim in (4, 3, 2):
            gap = skip.size(dim) - upsampled.size(dim)
            pad += [gap // 2, gap - gap // 2]
        if any(pad):
            upsampled = F.pad(upsampled, pad)

        merged = torch.cat((upsampled, skip), 1)
        return self.relu2(self.ops(merged) + merged)

class OutputTransition(nn.Module):
    """Final VNet stage: map features to ``num_classes`` score channels.

    conv(5x5x5, padding 2) -> batch norm -> activation. The result is returned
    as-is; no softmax is applied in ``forward``.
    """

    def __init__(self, inChans, elu, nll, num_classes=3):
        super().__init__()
        self.conv1 = nn.Conv3d(inChans, num_classes, kernel_size=5, padding=2)
        self.bn1 = ContBatchNorm3d(num_classes)
        self.relu1 = ELUCons(elu, num_classes)
        # NOTE(review): stored but never used by forward(), so ``nll`` has no
        # effect on the output. Kept for interface compatibility.
        self.softmax = F.log_softmax if nll else F.softmax

    def forward(self, x):
        scores = self.conv1(x)
        scores = self.bn1(scores)
        return self.relu1(scores)

class VNet(nn.Module):
    """V-Net for segmenting single-channel 3-D volumes.

    Encoder widths 16 -> 32 -> 64 -> 128 -> 256 (each down step halves D/H/W).
    The decoder mirrors this with skip connections back to each encoder
    output. The final map has ``num_classes`` channels at the resolution of
    the first stage. Attribute names are part of the checkpoint format
    (``state_dict`` keys) and must not change.
    """

    def __init__(self, elu=True, nll=False, num_classes=3):
        super().__init__()
        self.in_tr = InputTransition(16, elu)
        self.down_tr32 = DownTransition(16, 1, elu)
        self.down_tr64 = DownTransition(32, 2, elu)
        self.down_tr128 = DownTransition(64, 3, elu, dropout=True)
        self.down_tr256 = DownTransition(128, 2, elu, dropout=True)
        self.up_tr256 = UpTransition(256, 256, 2, elu, dropout=True)
        self.up_tr128 = UpTransition(256, 128, 2, elu, dropout=True)
        self.up_tr64 = UpTransition(128, 64, 1, elu)
        self.up_tr32 = UpTransition(64, 32, 1, elu)
        self.out_tr = OutputTransition(32, elu, nll, num_classes)

    def forward(self, x):
        encoders = (self.down_tr32, self.down_tr64, self.down_tr128, self.down_tr256)
        decoders = (self.up_tr256, self.up_tr128, self.up_tr64, self.up_tr32)

        # skips ends up as [16ch, 32ch, 64ch, 128ch, 256ch] feature maps.
        skips = [self.in_tr(x)]
        for encode in encoders:
            skips.append(encode(skips[-1]))

        # Start from the deepest map and consume the skips in reverse.
        features = skips.pop()
        for decode in decoders:
            features = decode(features, skips.pop())
        return self.out_tr(features)

def load_model(model_path, device, num_classes=3):
    """Build a VNet, load its weights and return it in eval mode on *device*.

    Args:
        model_path: file containing a state dict for ``VNet`` (it is passed to
            ``load_state_dict``, so a fully pickled model will not work).
        device: ``torch.device`` (or device string) used both as
            ``map_location`` and as the target of ``model.to``.
        num_classes: output channels of the network; must match the
            checkpoint. Defaults to 3, the previously hard-coded value.

    NOTE(security): ``torch.load`` unpickles the file; only load trusted
    checkpoints.
    """
    model = VNet(num_classes=num_classes)
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model

def predict_image(model, image_path, device):
    """Segment one NIfTI volume and return per-voxel class indices.

    The volume is read as float32 and given two leading axes of size 1
    (batch, channel). It is then run through *model* without gradient
    tracking and reduced with argmax over the class axis (dim 1).

    Returns:
        NumPy int64 array: the argmax map with the batch axis removed.
    """
    volume = nib.load(image_path).get_fdata().astype(np.float32)
    batch = torch.from_numpy(volume[np.newaxis, np.newaxis]).to(device)
    with torch.no_grad():
        scores = model(batch)
        labels = scores.argmax(dim=1).cpu().numpy()
    return labels[0]

def save_prediction(pred, original_image_path, output_dir):
    """Write a predicted label map as NIfTI, reusing the source image geometry.

    The file goes to *output_dir* under the same file name as
    *original_image_path*, with the source affine and header. The labels are
    stored as an integer type: uint8 when every value fits, otherwise int32.
    This replaces inheriting the source header's on-disk data type, which for
    an intensity image may be floating point.

    Args:
        pred: array of class indices (e.g. from ``predict_image``).
        original_image_path: NIfTI file whose affine/header are reused.
        output_dir: destination directory; created if missing.

    Returns:
        Path of the written file.
    """
    original_image = nib.load(original_image_path)

    labels = np.asarray(pred)
    fits_uint8 = labels.size == 0 or (
        labels.min() >= 0 and labels.max() <= np.iinfo(np.uint8).max
    )
    out_dtype = np.uint8 if fits_uint8 else np.int32
    labels = labels.astype(out_dtype)

    pred_img = nib.Nifti1Image(labels, original_image.affine, original_image.header)
    # With a header supplied, nibabel would otherwise save using that header's
    # data type; force the integer label type explicitly.
    pred_img.set_data_dtype(out_dtype)

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, os.path.basename(original_image_path))
    nib.save(pred_img, output_path)
    return output_path


def compute_iou(pred, label, num_classes):
    """Per-class intersection-over-union for classes ``0..num_classes-1``.

    Args:
        pred, label: NumPy arrays or torch tensors of class indices with
            matching shapes.
        num_classes: number of classes to score.

    Returns:
        List of length *num_classes*. An entry is NaN when that class occurs
        in neither *pred* nor *label* (empty union).
    """
    def _iou(cls):
        in_pred = pred == cls
        in_label = label == cls
        union = (in_pred | in_label).sum().item()
        if union == 0:
            return float('nan')
        return (in_pred & in_label).sum().item() / union

    return [_iou(cls) for cls in range(num_classes)]


def compute_recall(pred, label, classes):
    """Per-class recall ``TP / (TP + FN)``.

    Args:
        pred, label: NumPy arrays or torch tensors of class indices with
            matching shapes.
        classes: iterable of class values to score.

    Returns:
        Dict mapping each class (in iteration order) to its recall. The value
        is NaN when the class never occurs in *label*.
    """
    result = {}
    for cls in classes:
        actual = label == cls
        tp = (actual & (pred == cls)).sum().item()
        fn = (actual & (pred != cls)).sum().item()
        positives = tp + fn
        result[cls] = float('nan') if positives == 0 else tp / positives
    return result


def _format_metric(value):
    """Format a metric to 4 decimals, or 'N/A' when it is missing (None)."""
    if value is None:
        return 'N/A'
    return f"{value:.4f}"


def main():
    """Run the VNet on every test volume, save predictions and report metrics.

    Reads ``*.nii.gz`` images from ``test/img`` and labels from
    ``test/label``, paired by sorted order. Each predicted label map is
    written to ``test/outcome`` under the image's file name. Per-class IoU
    (all classes) and recall (classes 1..num_classes-1) are printed, first
    averaged over all samples with NaN entries ignored, then per sample.

    Raises:
        ValueError: if the image and label counts differ, or a prediction's
            shape does not match its label.
    """
    import warnings

    model_path = 'best_model.pth'
    test_img_dir = 'test/img'
    test_label_dir = 'test/label'
    output_dir = 'test/outcome'
    num_classes = 3
    # Recall is reported only for classes 1..num_classes-1 (i.e. [1, 2]).
    recall_classes = list(range(1, num_classes))

    os.makedirs(output_dir, exist_ok=True)
    image_files = sorted(Path(test_img_dir).glob('*.nii.gz'))
    label_files = sorted(Path(test_label_dir).glob('*.nii.gz'))

    if len(image_files) != len(label_files):
        raise ValueError("The number of image files and label files do not match.")
    # Pairs are matched purely by sorted position; warn if names disagree so a
    # missing/extra file can't silently shift every pair.
    name_mismatches = sum(img.name != lbl.name for img, lbl in zip(image_files, label_files))
    if name_mismatches:
        warnings.warn(f"{name_mismatches} image/label pairs have different file names; "
                      "pairs are matched by sorted order only.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model(model_path, device)

    all_ious = {cls: [] for cls in range(num_classes)}
    all_recalls = {cls: [] for cls in recall_classes}
    detailed_metrics = []

    for img_path, lbl_path in tqdm(zip(image_files, label_files), total=len(image_files), desc="Predicting and calculating metrics"):
        pred = predict_image(model, str(img_path), device)
        save_prediction(pred, str(img_path), output_dir)
        label = nib.load(lbl_path).get_fdata().astype(np.int64)

        if pred.shape != label.shape:
            raise ValueError(f"Shape mismatch between prediction and label for {img_path}")

        ious = compute_iou(pred, label, num_classes=num_classes)
        recalls = compute_recall(pred, label, recall_classes)

        detailed_metrics.append((os.path.basename(img_path), ious, recalls))
        for cls, iou in enumerate(ious):
            all_ious[cls].append(iou)
        for cls in recall_classes:
            all_recalls[cls].append(recalls.get(cls, float('nan')))

    print("\nSummary of average metrics:\n")
    for cls in range(num_classes):
        line = f'Class {cls} Average IoU: {np.nanmean(all_ious[cls]):.4f}'
        if cls in all_recalls:
            line += f', Average Recall: {np.nanmean(all_recalls[cls]):.4f}'
        print(line)

    print("\nDetailed metrics for each sample:\n")
    for sample, ious, recalls in detailed_metrics:
        print(f"Sample: {sample}")
        for cls in range(num_classes):
            line = f"    Class {cls} IoU: {ious[cls]:.4f}"
            if cls in recall_classes:
                # A missing recall prints 'N/A' instead of crashing the
                # float format spec.
                line += f", Recall: {_format_metric(recalls.get(cls))}"
            print(line)
        print()

# Run the evaluation when executed as a script or notebook cell, not on import.
if __name__ == "__main__":
    main()

Predicting and calculating metrics: 100%|██████████| 60/60 [00:00<00:00, 93.14it/s]


Summary of average metrics:

Class 0 Average IoU: 0.9752
Class 1 Average IoU: 0.5559, Average Recall: 0.6747
Class 2 Average IoU: 0.5763, Average Recall: 0.7142

Detailed metrics for each sample:

Sample: hippocampus_001.nii.gz
    Class 0 IoU: 0.9573
    Class 1 IoU: 0.2846, Recall: 0.4456
    Class 2 IoU: 0.2282, Recall: 0.3405

Sample: hippocampus_002.nii.gz
    Class 0 IoU: 0.9760
    Class 1 IoU: 0.5463, Recall: 0.7426
    Class 2 IoU: 0.5371, Recall: 0.6739

Sample: hippocampus_003.nii.gz
    Class 0 IoU: 0.9804
    Class 1 IoU: 0.6207, Recall: 0.7718
    Class 2 IoU: 0.6162, Recall: 0.7529

Sample: hippocampus_004.nii.gz
    Class 0 IoU: 0.9796
    Class 1 IoU: 0.6629, Recall: 0.7360
    Class 2 IoU: 0.7297, Recall: 0.8271

Sample: hippocampus_005.nii.gz
    Class 0 IoU: 0.9715
    Class 1 IoU: 0.3511, Recall: 0.4104
    Class 2 IoU: 0.7203, Recall: 0.8366

Sample: hippocampus_006.nii.gz
    Class 0 IoU: 0.9840
    Class 1 IoU: 0.7303, Recall: 0.8887
    Class 2 IoU: 0.6926, Re


