In [1]:
import logging
import os
import re
import glob
import nibabel as nib
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import time
from monai.networks.nets import DynUNet
from monai.networks.layers import Norm
from torch.cuda.amp import autocast, GradScaler

In [2]:
# Define function to check for zero-only labels
def is_zero_only(label_file):
    label = nib.load(label_file).get_fdata()
    return np.all(label == 0)

class MedicalDataset(Dataset):
    def __init__(self, data_list, label_list, transform=None):
        # Filter the data_list to include only the main files (those without additional suffixes)
        self.data_list = self.filter_data_files(data_list)
        self.label_list = label_list
        self.transform = transform

        # Ensure the number of filtered data files matches the number of label files
        assert len(self.data_list) == len(self.label_list), "Mismatch between data files and labels"

    def filter_data_files(self, data_list):
        # Filter data files to keep only those without the suffix "_1", "_2", etc.
        filtered_data = [f for f in data_list if re.search(r"_defected\.nii\.gz$", f)]
        return filtered_data

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        data_file = self.data_list[idx]
        label_file = self.label_list[idx]

        if not is_zero_only(label_file):
            data = nib.load(data_file).get_fdata()
            label = nib.load(label_file).get_fdata()

            # Binarize the label (values become 0 or 1)
            #label = np.where(label > 0, 1, 0).astype(np.float32)

            # Convert to tensor (data shape: (128, 128, 64))
            data_tensor = torch.from_numpy(data).float().unsqueeze(0)
            label_tensor = torch.from_numpy(label).float().unsqueeze(0)

            sample = {'data': data_tensor, 'label': label_tensor}

            # Apply any transforms (e.g., resizing)
            if self.transform:
                sample = self.transform(sample)

            return sample

        # Raise error if the label is zero-only (shouldn't happen in practice)
        raise IndexError("No valid samples available")

# Transform to resize the data
class ResizeTransform:
    def __init__(self, target_shape=(256, 256, 128)):
        self.target_shape = target_shape

    def __call__(self, sample):
        data, label = sample['data'], sample['label']
        data = F.interpolate(data.unsqueeze(0), size=self.target_shape, mode='trilinear', align_corners=False).squeeze(0)
        label = F.interpolate(label.unsqueeze(0), size=self.target_shape, mode='trilinear', align_corners=False).squeeze(0)
        return {'data': data, 'label': label}

# DataLoader creation
def create_dataloader(data_list, label_list, transform=None, batch_size=2, shuffle=True, num_workers=32):
    dataset = MedicalDataset(data_list, label_list, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, drop_last=True, pin_memory=True)
    return dataloader
    
# Define directories
train_data_dir = '/workspace/RibCage/train-ribfrac-defected'
train_label_dir = '/workspace/RibCage/train-segmented_ribfrac'
val_data_dir = '/workspace/RibCage/val-ribfrac-defected'
val_label_dir = '/workspace/RibCage/val-segmented_ribfrac'

# Get list of files
train_data_list = sorted(glob.glob(os.path.join(train_data_dir, '*.nii')) + glob.glob(os.path.join(train_data_dir, '*.nii.gz')))
train_label_list = sorted(glob.glob(os.path.join(train_label_dir, '*.nii')) + glob.glob(os.path.join(train_label_dir, '*.nii.gz')))
val_data_list = sorted(glob.glob(os.path.join(val_data_dir, '*.nii')) + glob.glob(os.path.join(val_data_dir, '*.nii.gz')))
val_label_list = sorted(glob.glob(os.path.join(val_label_dir, '*.nii')) + glob.glob(os.path.join(val_label_dir, '*.nii.gz')))

# Ensure data and labels are paired properly
#assert len(train_data_list) == len(train_label_list), "Training data and labels are not of the same length"
#assert len(val_data_list) == len(val_label_list), "Validation data and labels are not of the same length"

# Define the transform
resize_transform = ResizeTransform(target_shape=(256, 256, 128))

# Create DataLoader for training and validation
train_loader = create_dataloader(train_data_list, train_label_list, transform=resize_transform, batch_size=2, shuffle=True)
val_loader = create_dataloader(val_data_list, val_label_list, transform=resize_transform, batch_size=2, shuffle=False)

In [3]:
# Iterate through the DataLoader
print("\nTesting DataLoader for training data...")
for i, batch in enumerate(train_loader):
    data_tensor = batch['data']
    label_tensor = batch['label']
    print(f'Batch {i + 1}: Data shape: {data_tensor.shape}, Label shape: {label_tensor.shape}')
    if i == 2:  # Load a few batches for demonstration
        break

# Iterate through the DataLoader
print("\nTesting DataLoader for testing data...")
for i, batch in enumerate(val_loader):
    data_tensor = batch['data']
    label_tensor = batch['label']
    print(f'Batch {i + 1}: Data shape: {data_tensor.shape}, Label shape: {label_tensor.shape}')
    if i == 2:  # Load a few batches for demonstration
        break


Testing DataLoader for training data...
Batch 1: Data shape: torch.Size([2, 1, 256, 256, 128]), Label shape: torch.Size([2, 1, 256, 256, 128])
Batch 2: Data shape: torch.Size([2, 1, 256, 256, 128]), Label shape: torch.Size([2, 1, 256, 256, 128])
Batch 3: Data shape: torch.Size([2, 1, 256, 256, 128]), Label shape: torch.Size([2, 1, 256, 256, 128])

Testing DataLoader for testing data...
Batch 1: Data shape: torch.Size([2, 1, 256, 256, 128]), Label shape: torch.Size([2, 1, 256, 256, 128])
Batch 2: Data shape: torch.Size([2, 1, 256, 256, 128]), Label shape: torch.Size([2, 1, 256, 256, 128])
Batch 3: Data shape: torch.Size([2, 1, 256, 256, 128]), Label shape: torch.Size([2, 1, 256, 256, 128])


In [11]:
print("Started Importing Necessary Libraries...")
import os
import glob
import nibabel as nib
import numpy as np
import random
from scipy.ndimage import zoom
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

print("Necessary Libraries imported")

print("\nData Loading Started...")

# Function to resize data
def resizing(data, target_shape=(128, 128, 64)):
    """Resize the data to the target shape."""
    a, b, c = data.shape
    return zoom(data, (target_shape[0] / a, target_shape[1] / b, target_shape[2] / c), order=2, mode='constant')


# Custom Dataset Class
class MedicalDataset(Dataset):
    def __init__(self, data_list, label_list, transform=None):
        """
        Args:
            data_list (list): List of paths to the data files.
            label_list (list): List of paths to the label files.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        assert len(data_list) == len(label_list), "Data and label lists must be the same length"
        self.data_list = data_list
        self.label_list = label_list
        self.transform = transform

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        data_file = self.data_list[idx]
        label_file = self.label_list[idx]

        data = nib.load(data_file).get_fdata()
        label = nib.load(label_file).get_fdata()

        data_resized = resizing(data)
        label_resized = resizing(label)

        #data_resized = np.expand_dims(data_resized, axis=(0, 1))
        #label_resized = np.expand_dims(label_resized, axis=(0, 1))

        data_tensor = torch.from_numpy(data_resized).float()
        label_tensor = torch.from_numpy(label_resized).float()

        sample = {'data': data_tensor, 'label': label_tensor}

        if self.transform:
            sample = self.transform(sample)

        return sample

# Create DataLoader
def create_dataloader(data_list, label_list, batch_size=4, shuffle=True, num_workers=2):
    """Create DataLoader for the dataset."""
    dataset = MedicalDataset(data_list, label_list)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
    return dataloader


# Define directories
train_data_dir = '/workspace/RibCage/train-ribfrac-defected'
train_label_dir = '/workspace/RibCage/train-segmented_ribfrac'

# Get list of files
train_data_list = sorted(glob.glob(os.path.join(train_data_dir, '*.nii')) + glob.glob(os.path.join(train_data_dir, '*.nii.gz')))
train_label_list = sorted(glob.glob(os.path.join(train_label_dir, '*.nii')) + glob.glob(os.path.join(train_label_dir, '*.nii.gz')))
#est_data_list = sorted(glob.glob(os.path.join(test_data_dir, '*.nii')) + glob.glob(os.path.join(test_data_dir, '*.nii.gz')))

# Ensure data and labels are paired properly
assert len(train_data_list) == len(train_label_list), "Training data and labels are not of the same length"

# Create DataLoader for training
train_loader = create_dataloader(train_data_list, train_label_list, batch_size=2, shuffle=True)

print("Train Loader has been created...")

# Iterate through the DataLoader
print("\nTesting DataLoader for training data...")
for i, batch in enumerate(train_loader):
    data_tensor = batch['data']
    label_tensor = batch['label']
    print(f'Batch {i + 1}: Data shape: {data_tensor.shape}, Label shape: {label_tensor.shape}')
    if i == 2:  # Load a few batches for demonstration
        break

Started Importing Necessary Libraries...
Necessary Libraries imported

Data Loading Started...


AssertionError: Training data and labels are not of the same length

In [2]:
import nibabel as nib
import os

def save_sample_from_loader(dataloader, output_dir, sample_index=0):
    """ Save a sample from the DataLoader to a file after resizing. """
    os.makedirs(output_dir, exist_ok=True)

    # Iterate through the DataLoader to get a batch
    for batch in dataloader:
        # Extract the sample (assuming sample_index < batch size)
        data_img = batch['data'][sample_index].cpu().numpy()
        label_img = batch['label'][sample_index].cpu().numpy()

        # Create NIfTI images
        data_img_nifti = nib.Nifti1Image(data_img.squeeze(), np.eye(4))
        label_img_nifti = nib.Nifti1Image(label_img.squeeze(), np.eye(4))

        # Define file paths
        data_file_path = os.path.join(output_dir, 'resized_data_sample2.nii.gz')
        label_file_path = os.path.join(output_dir, 'resized_label_sample2.nii.gz')

        # Save the images
        nib.save(data_img_nifti, data_file_path)
        nib.save(label_img_nifti, label_file_path)
        
        print(f"Saved resized data sample to {data_file_path}")
        print(f"Saved resized label sample to {label_file_path}")
        
        break  # Only save one sample

# Define output directory
output_dir = '/workspace/RibCage/resized_samples'

# Save a sample from the training DataLoader
save_sample_from_loader(train_loader, output_dir, sample_index=0)

Saved resized data sample to /workspace/RibCage/resized_samples/resized_data_sample2.nii.gz
Saved resized label sample to /workspace/RibCage/resized_samples/resized_label_sample2.nii.gz


In [10]:
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact

# Load the NIfTI file
nifti_file = '/workspace/RibCage/val-ribfrac-implants/RibFrac421_implant.nii.gz'
nifti_image = nib.load(nifti_file)
ct_volume = nifti_image.get_fdata()  # Get the image data as a NumPy array

# Function to display a single axial slice
def show_axial_slice(slice_index):
    plt.imshow(ct_volume[:, :, slice_index], cmap='gray')
    plt.title(f'Axial Slice {slice_index}')
    plt.axis('off')
    plt.show()

# Interactive slider to browse through axial slices
interact(show_axial_slice, slice_index=(0, ct_volume.shape[2] - 1));

interactive(children=(IntSlider(value=162, description='slice_index', max=324), Output()), _dom_classes=('widg…