In [1]:
from sklearn.metrics import accuracy_score
from torchvision import transforms
import os
import glob
import nibabel as nib
import numpy as np
import random
from scipy.ndimage import zoom
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from scipy.spatial.distance import directed_hausdorff
from sklearn.metrics import mean_squared_error
import segmentation_models_pytorch_3d as smp
import re

In [2]:
# Dice coefficient for binary segmentation
def dice_score(pred, target, smooth=1e-6):
    """
    Compute the Dice coefficient between a prediction and its ground truth.

    Both tensors are binarized with a ``> 0.0`` threshold before they are
    compared. For raw logits this is the same as thresholding the sigmoid at
    0.5; if ``pred`` already holds probabilities, every strictly positive
    value counts as foreground.

    Args:
        pred: Prediction tensor (same number of elements as ``target``).
        target: Ground-truth tensor.
        smooth: Constant added to numerator and denominator so that two empty
            masks score 1.0 instead of producing 0/0.

    Returns:
        The Dice coefficient as a Python float.
    """
    pred_mask = (pred > 0.0).float()
    target_mask = (target > 0.0).float()

    # reshape (unlike view) also flattens permuted / non-contiguous tensors
    pred_flat = pred_mask.reshape(-1)
    target_flat = target_mask.reshape(-1)

    # Voxels that are foreground in both masks
    intersection = (pred_flat * target_flat).sum()

    dice = (2.0 * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth)

    return dice.item()

In [3]:
def hausdorff_distance(pred, target):
    """
    Symmetric Hausdorff distance between the foreground voxels of two tensors.

    Foreground is every element strictly greater than 0; distances are measured
    in index (voxel) coordinates over all axes of the tensors.

    Args:
        pred: Prediction tensor.
        target: Ground-truth tensor with the same number of dimensions.

    Returns:
        The larger of the two directed Hausdorff distances. When both masks
        are empty the distance is 0.0; when exactly one is empty it is
        ``float('inf')``, because no finite distance is defined.
    """
    pred_points = np.argwhere(pred.detach().cpu().numpy() > 0)
    target_points = np.argwhere(target.detach().cpu().numpy() > 0)

    # Handle empty point sets explicitly instead of handing them to SciPy
    pred_empty = pred_points.shape[0] == 0
    target_empty = target_points.shape[0] == 0
    if pred_empty and target_empty:
        return 0.0
    if pred_empty or target_empty:
        return float('inf')

    forward = directed_hausdorff(pred_points, target_points)[0]
    backward = directed_hausdorff(target_points, pred_points)[0]
    return max(forward, backward)

In [4]:
# NIfTI saving function
def save_nifti(data_tensor, label_tensor, save_dir, batch_idx):
    """
    Write every instance of a data/label batch to ``save_dir`` as NIfTI files.

    Both tensors are moved to the CPU, cast to float32 and have axis 1 squeezed
    out (so axis 1 must have size 1). Each instance is stored with an identity
    affine under a name built from ``batch_idx`` and its position in the batch.

    Args:
        data_tensor: Batch of data volumes.
        label_tensor: Batch of label volumes, indexed in step with the data.
        save_dir: Output directory; created if it does not exist.
        batch_idx: Batch number used in the file names.
    """
    os.makedirs(save_dir, exist_ok=True)

    def _as_volumes(tensor):
        # CPU float32 array with axis 1 removed
        return tensor.cpu().numpy().astype(np.float32).squeeze(1)

    volumes = _as_volumes(data_tensor)
    masks = _as_volumes(label_tensor)

    for i, volume in enumerate(volumes):
        data_filename = os.path.join(save_dir, f'test_data_batch{batch_idx}_instance_{i}.nii.gz')
        label_filename = os.path.join(save_dir, f'test_label_batch{batch_idx}_instance_{i}.nii.gz')

        data_image = nib.Nifti1Image(volume, np.eye(4))
        nib.save(data_image, data_filename)

        label_image = nib.Nifti1Image(masks[i], np.eye(4))
        nib.save(label_image, label_filename)


In [5]:
def extract_ribfrac_number(filename):
    """
    Return the case number embedded in a file name.

    Args:
        filename: Path or file name; only its base name is examined.

    Returns:
        The integer that follows ``RibFrac`` in the base name, or None when
        the base name contains no such pattern.
    """
    match = re.search(r'RibFrac(\d+)', os.path.basename(filename))
    return int(match.group(1)) if match else None


def find_matching_files(data_files, label_files):
    """
    Pair data files with label files by their ``RibFrac`` case number.

    A data file numbered N is paired with the label numbered N; when that
    label is missing, the label numbered N - 1 is used as a fallback. Files
    whose name carries no case number are ignored, so they can neither be
    paired with each other by accident nor break the N - 1 arithmetic. When
    several files share a number, the last one listed wins.

    Args:
        data_files: Iterable of data file paths.
        label_files: Iterable of label file paths.

    Returns:
        List of ``(data_path, label_path)`` tuples, in the order the case
        numbers first appear in ``data_files``.
    """
    def _index_by_case(paths):
        # Map case number -> path, leaving out files without a case number
        indexed = {}
        for path in paths:
            case = extract_ribfrac_number(path)
            if case is not None:
                indexed[case] = path
        return indexed

    data_by_case = _index_by_case(data_files)
    label_by_case = _index_by_case(label_files)

    matched_pairs = []
    for case, data_path in data_by_case.items():
        if case in label_by_case:
            matched_pairs.append((data_path, label_by_case[case]))
        elif case - 1 in label_by_case:  # off-by-one fallback
            matched_pairs.append((data_path, label_by_case[case - 1]))

    return matched_pairs

def is_valid_pair(data_file, label_file):
    """
    Decide whether a (data, label) pair is usable, judging by the label only.

    The label volume is rejected when it is entirely zero, when it contains
    NaN or Inf, or when loading it raises any exception. ``data_file`` is not
    inspected.

    Args:
        data_file: Path of the data volume (unused by the check).
        label_file: Path of the label volume to load and examine.

    Returns:
        True if the label loads and holds at least one non-zero value and only
        finite values; False otherwise.
    """
    try:
        label = nib.load(label_file).get_fdata()
        has_foreground = not np.all(label == 0)
        # isfinite is False exactly for NaN and +/-Inf entries
        return bool(has_foreground and np.isfinite(label).all())
    except Exception:
        # Any loading problem marks the pair as unusable
        return False

def check_for_invalid_values(tensor, tensor_name="tensor"):
    """
    Print a line for each kind of non-finite value found in ``tensor``.

    Args:
        tensor: Tensor to scan for NaN and Inf entries.
        tensor_name: Label used in the printed messages.

    Returns:
        None. The function only reports; it never raises or modifies the tensor.
    """
    has_nan = bool(torch.isnan(tensor).any())
    has_inf = bool(torch.isinf(tensor).any())

    if has_nan:
        print(f"NaN detected in {tensor_name}")
    if has_inf:
        print(f"Inf detected in {tensor_name}")

In [11]:
# Custom Dataset Class with Dynamic Filtering
class MedicalDataset(Dataset):
    """
    Dataset of (data, label) NIfTI volume pairs, filtered at construction time.

    Files are paired with ``find_matching_files`` and only the pairs accepted
    by ``is_valid_pair`` are kept, so every label file is examined once while
    the dataset is being built.

    Each item is a dict with the keys ``'data'``, ``'label'``, ``'data_file'``
    and ``'label_file'``. The two volumes are float32 tensors with one leading
    axis of size 1 added; the optional ``transform`` receives and returns the
    whole dict.
    """

    def __init__(self, data_list, label_list, transform=None):
        """
        Args:
            data_list: Paths of the data volumes.
            label_list: Paths of the label volumes.
            transform: Optional callable applied to each sample dict.
        """
        self.matched_pairs = find_matching_files(data_list, label_list)
        self.valid_pairs = [
            candidate
            for candidate in self.matched_pairs
            if is_valid_pair(candidate[0], candidate[1])
        ]
        self.transform = transform

    def __len__(self):
        """Number of valid (data, label) pairs."""
        return len(self.valid_pairs)

    def __getitem__(self, idx):
        """
        Load the pair at position ``idx`` and return it as a sample dict.

        Raises:
            IndexError: If ``idx`` is at or beyond the number of valid pairs.
        """
        if idx >= len(self.valid_pairs):
            raise IndexError(f"Index {idx} out of range for valid_pairs.")

        data_file, label_file = self.valid_pairs[idx]

        # Read both volumes from disk before converting either of them
        data_array = nib.load(data_file).get_fdata()
        label_array = nib.load(label_file).get_fdata()

        # float32 tensors with a leading axis of size 1
        data_tensor = torch.from_numpy(data_array).float().unsqueeze(0)
        label_tensor = torch.from_numpy(label_array).float().unsqueeze(0)

        # Report (but do not reject) NaN / Inf values
        check_for_invalid_values(data_tensor, "data_tensor")
        check_for_invalid_values(label_tensor, "label_tensor")

        sample = {
            'data': data_tensor,
            'label': label_tensor,
            'data_file': data_file,
            'label_file': label_file,
        }

        if self.transform:
            return self.transform(sample)
        return sample

# Transform to resize the data
class ResizeTransform:
    """
    Resample the ``'data'`` and ``'label'`` volumes of a sample to a fixed size.

    Data is resampled with trilinear interpolation; the label uses
    nearest-neighbour interpolation so that no new label values are created.
    """

    def __init__(self, target_shape=(256, 256, 128)):
        """
        Args:
            target_shape: Spatial size of the output volumes (three integers).
        """
        self.target_shape = target_shape

    def __call__(self, sample):
        """
        Return a new sample dict with resized ``'data'`` and ``'label'``.

        A leading batch axis is added for ``F.interpolate`` and removed again
        afterwards. ``'data_file'`` and ``'label_file'`` are passed through
        unchanged; any other key of the input is not carried over.
        """
        size = self.target_shape
        resized_data = F.interpolate(
            sample['data'][None], size=size, mode='trilinear', align_corners=False
        )[0]
        resized_label = F.interpolate(sample['label'][None], size=size, mode='nearest')[0]

        return {
            'data': resized_data,
            'label': resized_label,
            'data_file': sample['data_file'],
            'label_file': sample['label_file'],
        }

In [12]:
# DataLoader creation function
def create_dataloader(data_list, label_list, transform=None, batch_size=2, shuffle=True, num_workers=8):
    """
    Build a ``MedicalDataset`` from the given files and wrap it in a DataLoader.

    Args:
        data_list: Paths of the data volumes.
        label_list: Paths of the label volumes.
        transform: Optional callable applied to each sample dict.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the samples.
        num_workers: Number of loader worker processes.

    Returns:
        A DataLoader that pins memory and drops the last incomplete batch.
    """
    loader_options = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=True,
        pin_memory=True,
    )
    return DataLoader(MedicalDataset(data_list, label_list, transform=transform), **loader_options)

def _list_nifti(directory):
    """Return the sorted ``*.nii`` and ``*.nii.gz`` paths found directly in ``directory``."""
    found = []
    for pattern in ('*.nii', '*.nii.gz'):
        found.extend(glob.glob(os.path.join(directory, pattern)))
    return sorted(found)


# Input ("defected") and target ("implants") folders for the test and validation splits
test_data_dir = '/workspace/RibCage/test-ribfrac-defected'
test_label_dir = '/workspace/RibCage/test-ribfrac-implants'
val_data_dir = '/workspace/RibCage/val-ribfrac-defected'
val_label_dir = '/workspace/RibCage/val-ribfrac-implants'

test_data_list = _list_nifti(test_data_dir)
test_label_list = _list_nifti(test_label_dir)
val_data_list = _list_nifti(val_data_dir)
val_label_list = _list_nifti(val_label_dir)

# Every volume is resampled to this fixed size
resize_transform = ResizeTransform(target_shape=(256, 256, 128))

# NOTE(review): these two datasets are not referenced in the cells shown here, and
# create_dataloader builds its own MedicalDataset, so every label file is validated
# twice per split — confirm whether they are still needed.
temp_test_dataset = MedicalDataset(test_data_list, test_label_list, transform=resize_transform)
temp_val_dataset = MedicalDataset(val_data_list, val_label_list, transform=resize_transform)

# Loaders for evaluation: fixed order, two volumes per batch
test_loader = create_dataloader(test_data_list, test_label_list, transform=resize_transform, batch_size=2, shuffle=False)
val_loader = create_dataloader(val_data_list, val_label_list, transform=resize_transform, batch_size=2, shuffle=False)

In [13]:
# Inference and metric calculation function
def inference_and_evaluate_dice_model(model, test_loader, device, save_dir):
    """
    Run ``model`` over ``test_loader`` and print per-batch and average metrics.

    For every batch the Dice score, the symmetric Hausdorff distance and the
    mean squared error between the raw model output and the label are
    computed. Each batch contributes exactly once to the reported averages.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        test_loader: Iterable of batches, each a dict with ``'data'`` and
            ``'label'`` tensors of five dimensions.
        device: Device the batch tensors are moved to.
        save_dir: Directory that is created if missing. Nothing is written to
            it by this function as it stands.

    Returns:
        None. All results are printed.
    """
    model.eval()
    dice_scores = []
    hausdorff_distances = []
    mse_scores = []

    os.makedirs(save_dir, exist_ok=True)

    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            inputs = batch['data'].to(device)
            labels = batch['label'].to(device)

            # Swap the first two axes of inputs and labels.
            # NOTE(review): presumably this turns a batch of two single-channel
            # volumes into one sample with two channels, matching a model that
            # expects 2 input channels — confirm against training.
            inputs = inputs.permute(1, 0, 2, 3, 4)
            labels = labels.permute(1, 0, 2, 3, 4)

            outputs = model(inputs)

            dice = dice_score(outputs, labels)
            hausdorff = hausdorff_distance(outputs, labels)
            # MSE is taken on the raw (un-thresholded) model output
            mse = mean_squared_error(labels.cpu().numpy().ravel(), outputs.cpu().numpy().ravel())

            # Record each batch once. Appending a second time for batches with
            # dice > 0.30 would weight those batches double and bias the means.
            dice_scores.append(dice)
            hausdorff_distances.append(hausdorff)
            mse_scores.append(mse)

            print(f'Batch {batch_idx} - Dice: {dice:.4f}, Hausdorff: {hausdorff:.4f}, MSE: {mse:.4f}')

    if not dice_scores:
        # An empty loader has no meaningful averages (np.mean([]) would be NaN)
        print('No batches were evaluated; average metrics are undefined.')
        return

    print(f'Average Dice Score: {np.mean(dice_scores):.4f}, len: {len(dice_scores)}')
    print(f'Average Hausdorff Distance: {np.mean(hausdorff_distances):.4f}')
    print(f'Average MSE: {np.mean(mse_scores):.4f}')

In [14]:
# Define the convolutional layer with initialization
class Conv3dLayer(nn.Module):
    """
    3D convolution with "same"-style padding and truncated-normal weights.

    Padding is ``(kernel_size - 1) // 2``, so with stride 1 and an odd kernel
    the spatial size is preserved.

    Args:
        input_chn: Number of input channels.
        output_chn: Number of output channels.
        kernel_size: Integer kernel size used along every spatial axis.
        stride: Convolution stride.
        bias: Whether the convolution has a bias term; when present it is
            initialised to zero.
    """

    def __init__(self, input_chn, output_chn, kernel_size, stride, bias=False):
        super(Conv3dLayer, self).__init__()
        padding = (kernel_size - 1) // 2
        # Use the `bias` parameter itself; there is no `use_bias` name in scope
        self.conv = nn.Conv3d(input_chn, output_chn, kernel_size, stride, padding=padding, bias=bias)
        nn.init.trunc_normal_(self.conv.weight, std=0.01)
        if bias:
            nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        """Apply the convolution to ``x``."""
        return self.conv(x)

# Define the block with convolution, batch normalization, and ReLU
class ConvBnReLU(nn.Module):
    """
    3D convolution followed by batch normalisation and ReLU.

    Padding is ``(kernel_size - 1) // 2``.

    Args:
        input_chn: Number of input channels.
        output_chn: Number of output channels.
        kernel_size: Integer kernel size used along every spatial axis.
        stride: Convolution stride.
        use_bias: Whether the convolution has a bias term. Defaults to True,
            which keeps the parameters (and state-dict keys) of existing
            four-argument uses unchanged; callers that pass a bias flag as
            the fifth positional argument are now accepted.
    """

    def __init__(self, input_chn, output_chn, kernel_size, stride, use_bias=True):
        super(ConvBnReLU, self).__init__()
        padding = (kernel_size - 1) // 2
        self.conv = nn.Conv3d(input_chn, output_chn, kernel_size, stride, padding=padding, bias=use_bias)
        # NOTE(review): in PyTorch, momentum is the weight given to the newest
        # batch statistic; 0.9 may have been carried over from a framework that
        # uses the opposite convention — confirm before retraining.
        self.bn = nn.BatchNorm3d(output_chn, momentum=0.9)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply convolution, batch normalisation and ReLU in that order."""
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

# Define the deconvolutional layer with initialization
class Deconv3dLayer(nn.Module):
    """
    Transposed 3D convolution (kernel 4, stride 2, padding 1).

    With these settings each spatial dimension of the input is doubled.
    Weights are drawn from a normal distribution with std 0.01 and the bias
    starts at zero.

    Args:
        input_chn: Number of input channels.
        output_chn: Number of output channels.
    """

    def __init__(self, input_chn, output_chn):
        super().__init__()
        upsample = nn.ConvTranspose3d(input_chn, output_chn, kernel_size=4, stride=2, padding=1)
        nn.init.normal_(upsample.weight, std=0.01)
        nn.init.zeros_(upsample.bias)
        self.deconv = upsample

    def forward(self, x):
        """Apply the transposed convolution to ``x``."""
        return self.deconv(x)

# Define the block with deconvolution, batch normalization, and ReLU
class DeconvBnReLU(nn.Module):
    """
    ``Deconv3dLayer`` followed by batch normalisation and ReLU.

    Args:
        input_chn: Number of input channels.
        output_chn: Number of output channels.
    """

    def __init__(self, input_chn, output_chn):
        super().__init__()
        self.deconv = Deconv3dLayer(input_chn, output_chn)
        self.bn = nn.BatchNorm3d(output_chn, momentum=0.9)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply deconvolution, batch normalisation and ReLU in that order."""
        upsampled = self.deconv(x)
        normalised = self.bn(upsampled)
        return self.relu(normalised)

# Define the block with repeated convolution, batch normalization, and ReLU
class ConvBnReLUX3(nn.Module):
    """
    Three ``ConvBnReLU`` stages with a skip connection around the last two.

    The output is the sum of the first stage's result and the result of
    passing it through the second and third stages. All three stages receive
    the same ``kernel_size``, ``stride`` and ``use_bias``.

    Relies on ``ConvBnReLU`` accepting a bias flag as its fifth positional
    argument.

    Args:
        input_chn: Number of input channels of the first stage.
        output_chn: Number of output channels of every stage.
        kernel_size: Kernel size forwarded to each stage.
        stride: Stride forwarded to each stage.
        use_bias: Bias flag forwarded to each stage.
    """

    def __init__(self, input_chn, output_chn, kernel_size, stride, use_bias):
        super().__init__()
        shared_args = (output_chn, kernel_size, stride, use_bias)
        self.conv1 = ConvBnReLU(input_chn, *shared_args)
        self.conv2 = ConvBnReLU(output_chn, *shared_args)
        self.conv3 = ConvBnReLU(output_chn, *shared_args)

    def forward(self, x):
        """Return ``conv1(x) + conv3(conv2(conv1(x)))``."""
        first = self.conv1(x)
        refined = self.conv3(self.conv2(first))
        return first + refined


class AutoEncoder(nn.Module):
    """
    3D convolutional encoder-decoder with 2 input and 2 output channels.

    The encoder is four stride-2 ``ConvBnReLU`` stages followed by one
    stride-1 stage; the decoder is four ``DeconvBnReLU`` stages. A
    ``ConvBnReLU`` and two plain kernel-5 convolutions then produce the
    2-channel output. ``forward`` returns the output of the last convolution;
    the ``softmax`` attribute is defined but not applied.
    """

    def __init__(self):
        super().__init__()
        # Encoder: 2 -> 64 -> 128 -> 256 -> 512 -> 512 channels
        self.conv1 = ConvBnReLU(2, 64, kernel_size=5, stride=2)
        self.conv2 = ConvBnReLU(64, 128, kernel_size=5, stride=2)
        self.conv3 = ConvBnReLU(128, 256, kernel_size=5, stride=2)
        self.conv4 = ConvBnReLU(256, 512, kernel_size=5, stride=2)
        self.conv5 = ConvBnReLU(512, 512, kernel_size=5, stride=1)
        # Decoder: 512 -> 256 -> 128 -> 64 -> 32 channels
        self.deconv1 = DeconvBnReLU(512, 256)
        self.deconv2 = DeconvBnReLU(256, 128)
        self.deconv3 = DeconvBnReLU(128, 64)
        self.deconv4 = DeconvBnReLU(64, 32)
        # Prediction head: 32 -> 2 -> 2 -> 2 channels
        self.pred_prob1 = ConvBnReLU(32, 2, kernel_size=5, stride=1)
        self.pred_prob2 = nn.Conv3d(2, 2, kernel_size=5, stride=1, padding='same', bias=True)
        self.pred_prob3 = nn.Conv3d(2, 2, kernel_size=5, stride=1, padding='same', bias=True)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Pass ``x`` through encoder, decoder and prediction head in order."""
        stages = (
            self.conv1, self.conv2, self.conv3, self.conv4, self.conv5,
            self.deconv1, self.deconv2, self.deconv3, self.deconv4,
            self.pred_prob1, self.pred_prob2, self.pred_prob3,
        )
        features = x
        for stage in stages:
            features = stage(features)
        return features

In [15]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [16]:
# Show which device was selected
print(device)

cuda


In [17]:
# Instantiate the network, then move its parameters and buffers to the selected device
model = AutoEncoder()
model = model.to(device)

In [18]:
# Load the checkpoint
# map_location remaps the stored tensors onto the selected device, so a checkpoint
# saved during a GPU run can also be loaded when `device` fell back to the CPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load('/workspace/RibCage/RibCageImplant/src/checkpoint_epoch_100.pth.tar', map_location=device)
model.load_state_dict(checkpoint['state_dict'])
model = model.to(device)

In [20]:
# Directory passed to the evaluation function as its save location
# NOTE(review): the directory name says "test" but the loader passed below is val_loader — confirm which split is intended.
inference_save_dir = '/workspace/RibCage/saved_test_results_dice1/'

# Run inference on the validation loader (val_loader) and print the metrics
inference_and_evaluate_dice_model(model, val_loader, device, inference_save_dir)

  return F.conv3d(


Batch 0 - Dice: 0.0006, Hausdorff: 198.0227, MSE: 365856.2188
Batch 1 - Dice: 0.0004, Hausdorff: 264.0417, MSE: 395059.1875
Batch 2 - Dice: 0.0015, Hausdorff: 266.5914, MSE: 371341.9375
Batch 3 - Dice: 0.0000, Hausdorff: 261.0843, MSE: 343632.7812
Batch 4 - Dice: 0.0006, Hausdorff: 211.6743, MSE: 373413.0000
Batch 5 - Dice: 0.0006, Hausdorff: 180.0278, MSE: 402193.8125
Batch 6 - Dice: 0.0000, Hausdorff: 244.4974, MSE: 379551.5938
Batch 7 - Dice: 0.0000, Hausdorff: 251.8432, MSE: 328940.1562
Batch 8 - Dice: 0.0011, Hausdorff: 226.8127, MSE: 362976.2812
Batch 9 - Dice: 0.0002, Hausdorff: 197.0330, MSE: 381851.5000
Batch 10 - Dice: 0.0001, Hausdorff: 235.9428, MSE: 289615.6250
Batch 11 - Dice: 0.0004, Hausdorff: 279.1863, MSE: 445965.8125
Batch 12 - Dice: 0.0006, Hausdorff: 199.2963, MSE: 390492.1250
Batch 13 - Dice: 0.0006, Hausdorff: 181.3422, MSE: 333328.9375
Batch 14 - Dice: 0.0002, Hausdorff: 237.0253, MSE: 380598.9688
Batch 15 - Dice: 0.0003, Hausdorff: 197.9217, MSE: 318811.1875
Ba