In [None]:
!pip install nibabel
!pip install SimpleITK

Collecting SimpleITK
  Downloading SimpleITK-2.4.1-cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.9 kB)
Downloading SimpleITK-2.4.1-cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (52.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.3/52.3 MB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: SimpleITK
Successfully installed SimpleITK-2.4.1


In [None]:
import cv2
import torch
import numpy as np
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
import random
from matplotlib import pyplot as plt
import os
import nibabel as nib
from sklearn.model_selection import train_test_split
import SimpleITK as sitk
from tqdm import tqdm

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:
# Utility functions
def pad_to_shape(this, shp):
    """
    Pad the spatial dimensions of a tensor so they match a target shape.

    Padding is added at the end of every dimension from index 2 onwards;
    the first two dimensions are never touched.

    Args:
        this: tensor to pad, indexed with the same number of dimensions as ``shp``
        shp: target shape with 4 or 5 entries; only entries from index 2 on are used

    Returns:
        The result of ``F.pad`` applied to ``this`` (default constant padding).

    Raises:
        ValueError: if ``shp`` does not have exactly 4 or 5 entries.
    """
    rank = len(shp)
    if rank not in (4, 5):
        raise ValueError(f"pad_to_shape expects a 4D or 5D target shape, got {rank} dimensions")

    # F.pad takes (before, after) pairs starting from the LAST dimension.
    pad = []
    for axis in range(rank - 1, 1, -1):
        pad.extend((0, shp[axis] - this.shape[axis]))
    return F.pad(this, tuple(pad))

def calculate_dice_score(pred_mask, gt_mask):
    """
    Dice coefficient between a predicted mask and a ground-truth mask.

    Computed as 2 * sum(pred * gt) / (sum(pred) + sum(gt) + 1e-8).
    Because the epsilon is only in the denominator, two all-zero masks score 0.
    """
    overlap = torch.sum(pred_mask * gt_mask)
    mask_mass = torch.sum(pred_mask) + torch.sum(gt_mask)
    # The epsilon keeps the ratio finite when both masks are empty.
    return (2.0 * overlap) / (mask_mass + 1e-8)

def dice_score(y_pred_bin, y_true):
    """
    Per-sample Dice scores for a batch.

    Args:
        y_pred_bin: batch of predicted masks, iterated along its first dimension
        y_true: batch of ground-truth masks, iterated in lockstep with ``y_pred_bin``

    Returns:
        A Python list holding one ``calculate_dice_score`` result per sample.
    """
    return [
        calculate_dice_score(sample_pred, sample_true)
        for sample_pred, sample_true in zip(y_pred_bin, y_true)
    ]

In [None]:
# Data Loading Functions
def load_nii_data(images_dir, labels_dir, max_samples=None):
    """
    Load .nii / .nii.gz volumes and their labels as channel-first NumPy arrays.

    Files in the two directories are paired by sorted file name, so both
    directories must contain the same number of NIfTI files.

    Args:
        images_dir: directory holding the image volumes
        labels_dir: directory holding the label volumes
        max_samples: if truthy, only the first ``max_samples`` pairs are loaded

    Returns:
        (images, labels): two lists of float32 arrays, each with a leading
        channel axis of size 1. Labels are binarised (any value > 0 becomes 1.0).

    Raises:
        ValueError: if the number of image files and label files differ.
    """
    def _list_nii(directory):
        return sorted(f for f in os.listdir(directory) if f.endswith(('.nii', '.nii.gz')))

    image_files = _list_nii(images_dir)
    label_files = _list_nii(labels_dir)

    # Pairing relies on sort order only; a count mismatch would silently misalign
    # images and labels (zip just truncates), so fail loudly instead.
    if len(image_files) != len(label_files):
        raise ValueError(
            f"Found {len(image_files)} image files in {images_dir!r} but "
            f"{len(label_files)} label files in {labels_dir!r}; cannot pair them"
        )

    if max_samples:
        image_files = image_files[:max_samples]
        label_files = label_files[:max_samples]

    images = []
    labels = []

    print("Loading NII data...")
    for img_file, lbl_file in tqdm(zip(image_files, label_files), total=len(image_files)):
        img_nii = nib.load(os.path.join(images_dir, img_file))
        lbl_nii = nib.load(os.path.join(labels_dir, lbl_file))

        # float32 halves the memory of get_fdata()'s default float64 output.
        img_data = img_nii.get_fdata(dtype=np.float32)
        lbl_data = lbl_nii.get_fdata(dtype=np.float32)

        # Ensure label is binary (liver segmentation)
        lbl_data = (lbl_data > 0).astype(np.float32)

        # Add a leading channel dimension (the arrays stay NumPy arrays).
        images.append(np.expand_dims(img_data, axis=0))
        labels.append(np.expand_dims(lbl_data, axis=0))

    return images, labels

def _crop_or_pad(volume, target_shape):
    """Crop (keeping the low-index corner) and zero-pad a (C, D, H, W) array to ``target_shape``."""
    extent = tuple(min(have, want) for have, want in zip(volume.shape[1:], target_shape[1:]))
    # Keep the input dtype instead of NumPy's float64 default.
    # NOTE(review): padding uses 0; for CT intensities a background value may be
    # more appropriate - confirm against the normalisation used downstream.
    fitted = np.zeros(target_shape, dtype=volume.dtype)
    fitted[:, :extent[0], :extent[1], :extent[2]] = volume[:, :extent[0], :extent[1], :extent[2]]
    return fitted


def preprocess_data(images, labels, target_shape=(1, 128, 128, 128)):
    """
    Bring every image/label pair to a common shape by cropping and zero-padding.

    No interpolation is performed: volumes larger than the target are cut at the
    high-index end of each spatial axis, smaller ones are padded there with zeros.
    Arrays whose spatial shape already matches are passed through unchanged.

    Args:
        images: sequence of (C, D, H, W) arrays
        labels: sequence of (C, D, H, W) arrays, paired with ``images``
        target_shape: desired (C, D, H, W) shape of every output sample

    Returns:
        (processed_images, processed_labels) as stacked NumPy arrays.
    """
    target_spatial = tuple(target_shape[1:])

    def _fit(volume):
        if tuple(volume.shape[1:]) == target_spatial:
            return volume
        return _crop_or_pad(volume, target_shape)

    processed_images = []
    processed_labels = []

    print("Preprocessing data...")
    for img, lbl in tqdm(zip(images, labels), total=len(images)):
        processed_images.append(_fit(img))
        processed_labels.append(_fit(lbl))

    return np.array(processed_images), np.array(processed_labels)

In [None]:
# Training and Evaluation Functions
def one_epoch(model, loader, criterion, optimizer, scheduler, device, samples_count, phase):
    """
    Run one pass over ``loader`` in either training or evaluation mode.

    Args:
        model: network mapping a float input batch to per-class scores on dim 1
        loader: iterable of (inputs, labels, indices) batches that supports ``len``
        criterion: loss called as ``criterion(outputs, labels[:, 0])``
        optimizer: optimizer stepped after every batch when ``phase == 'train'``
        scheduler: optional LR scheduler, stepped after every training batch
        device: device the batches are moved to
        samples_count: mapping from phase name to number of samples in that phase
        phase: ``'train'`` enables training; any other value runs evaluation

    Returns:
        (loss, dice): mean loss per batch as a float, and mean Dice per sample
        as a 0-dim CPU tensor.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    is_train = phase == 'train'

    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError(f"The data loader for phase '{phase}' is empty")

    model.train(is_train)  # train(False) is equivalent to eval()

    running_loss = 0.0
    running_dice = 0.0

    for inputs, labels, _ in loader:
        inputs = inputs.to(device=device, dtype=torch.float32)
        labels = labels.to(device=device, dtype=torch.long)

        # Gradients only need resetting when a backward pass will follow.
        if is_train:
            optimizer.zero_grad()

        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            preds = outputs.argmax(dim=1, keepdim=True)
            # The loss expects class-index targets without the channel axis.
            loss = criterion(outputs, labels[:, 0])

            if is_train:
                loss.backward()
                optimizer.step()

        # Accumulate plain floats so no per-batch tensors are kept alive.
        running_loss += loss.item()
        running_dice += torch.stack(dice_score(preds, labels)).sum().item()

        # NOTE(review): the scheduler is stepped once per batch, not once per
        # epoch - confirm this matches the scheduler being used.
        if is_train and scheduler:
            scheduler.step()

    loss = running_loss / num_batches
    dice = torch.tensor(running_dice / samples_count[phase])

    return loss, dice

def train(model, loaders, criterion, optimizer, num_epochs, device, model_path, samples_count, scheduler=None):
    """
    Train ``model`` for ``num_epochs`` epochs and checkpoint the best weights.

    Each epoch runs ``one_epoch`` on ``loaders['train']`` and then on
    ``loaders['valid']``. Whenever the validation Dice improves, the model's
    ``state_dict`` is written to ``model_path``.

    Args:
        model: network to optimise
        loaders: mapping with ``'train'`` and ``'valid'`` data loaders
        criterion: loss function forwarded to ``one_epoch``
        optimizer: optimizer forwarded to ``one_epoch``
        num_epochs: number of epochs to run
        device: device forwarded to ``one_epoch``
        model_path: file the best ``state_dict`` is saved to
        samples_count: per-phase sample counts forwarded to ``one_epoch``
        scheduler: optional LR scheduler forwarded to ``one_epoch``

    Returns:
        (loss_dic, dice_dic): dicts with ``'train'`` and ``'valid'`` lists
        holding one value per epoch.
    """
    # Start below any reachable score so the first epoch always produces a
    # checkpoint, even when the validation Dice is exactly 0.
    best_valid_dice = float('-inf')

    loss_dic = {'train': [], 'valid': []}
    dice_dic = {'train': [], 'valid': []}

    # torch.save does not create missing parent directories.
    model_dir = os.path.dirname(model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    for epoch in range(num_epochs):
        train_loss, train_dice = one_epoch(model, loaders['train'], criterion, optimizer, scheduler, device, samples_count, phase='train')
        val_loss, val_dice = one_epoch(model, loaders['valid'], criterion, optimizer, scheduler, device, samples_count, phase='valid')

        loss_dic['train'].append(train_loss)
        loss_dic['valid'].append(val_loss)
        dice_dic['train'].append(train_dice)
        dice_dic['valid'].append(val_dice)

        if val_dice > best_valid_dice:
            best_valid_dice = val_dice
            torch.save(model.state_dict(), model_path)

        print(f'Epoch [{epoch+1}/{num_epochs}] - '
              f'Train Loss: {train_loss:.4f} - '
              f'Train Dice: {train_dice:.4f} - '
              f'Valid Loss: {val_loss:.4f} - '
              f'Valid Dice {val_dice:.4f}')

    return loss_dic, dice_dic

def evaluate(model, loaders, criterion, optimizer, device, samples_count, phase, scheduler=None):
    """
    Run ``one_epoch`` on ``loaders[phase]``, print the loss and Dice, and return the Dice.

    The printed line is always labelled "Test", whatever ``phase`` is passed.
    """
    phase_loader = loaders[phase]
    test_loss, test_dice = one_epoch(
        model, phase_loader, criterion, optimizer, scheduler, device, samples_count, phase
    )
    print(f'Test Loss: {test_loss:.4f} - Test Dice {test_dice:.4f}')
    return test_dice

def show_plots(num_epochs, data, metric):
    """
    Plot the training and validation curves of one metric, save the figure, and show it.

    Args:
        num_epochs: number of points on the x axis (0 .. num_epochs - 1)
        data: mapping with ``'train'`` and ``'valid'`` sequences of per-epoch values
        metric: metric name, used for the y label, the legend and the
            output file name ``'<metric>_plot.png'``
    """
    epoch_axis = np.arange(num_epochs)
    plt.figure(figsize=(10, 6))
    for split, legend_prefix in (('train', 'train '), ('valid', 'validation ')):
        plt.plot(epoch_axis, data[split], label=legend_prefix + metric)
    plt.xlabel('epoch')
    plt.ylabel(metric)
    plt.legend()
    plt.savefig(f'{metric}_plot.png')
    plt.show()


In [None]:
# Model definition
class First3D(nn.Module):
    """
    First block of the network: two (3x3x3 conv -> batch norm -> ReLU) stages
    with no down-sampling, optionally followed by channel dropout.

    ``dropout`` is either falsy (no dropout) or a probability in [0, 1].
    """

    def __init__(self, in_channels, middle_channels, out_channels, dropout=False):
        super().__init__()

        stages = []
        for c_in, c_out in ((in_channels, middle_channels), (middle_channels, out_channels)):
            stages += [
                nn.Conv3d(c_in, c_out, kernel_size=3, padding=1),
                # No running statistics are tracked for the normalisation layers.
                nn.BatchNorm3d(c_out, track_running_stats=False),
                nn.ReLU(inplace=True),
            ]

        if dropout:
            assert 0 <= dropout <= 1, 'dropout must be between 0 and 1'
            stages.append(nn.Dropout3d(p=dropout))

        self.first = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the block to ``x``."""
        return self.first(x)

class Encoder3D(nn.Module):
    """
    Encoder block: max-pooling by ``downsample_kernel`` followed by two
    (3x3x3 conv -> batch norm -> ReLU) stages and optional channel dropout.

    ``dropout`` is either falsy (no dropout) or a probability in [0, 1].
    """

    def __init__(
            self, in_channels, middle_channels, out_channels,
            dropout=False, downsample_kernel=2
    ):
        super().__init__()

        stages = [nn.MaxPool3d(kernel_size=downsample_kernel)]
        for c_in, c_out in ((in_channels, middle_channels), (middle_channels, out_channels)):
            stages.append(nn.Conv3d(c_in, c_out, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm3d(c_out, track_running_stats=False))
            stages.append(nn.ReLU(inplace=True))

        if dropout:
            assert 0 <= dropout <= 1, 'dropout must be between 0 and 1'
            stages.append(nn.Dropout3d(p=dropout))

        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Down-sample ``x`` and apply the convolution stages."""
        return self.encoder(x)

class Center3D(nn.Module):
    """
    Bottleneck block: 2x max-pooling, two (3x3x3 conv -> batch norm -> ReLU)
    stages, then a stride-2 transposed convolution producing ``deconv_channels``
    channels. Optional channel dropout is applied after the transposed convolution.
    """

    def __init__(self, in_channels, middle_channels, out_channels, deconv_channels, dropout=False):
        super().__init__()

        stages = [nn.MaxPool3d(kernel_size=2)]
        for c_in, c_out in ((in_channels, middle_channels), (middle_channels, out_channels)):
            stages.extend((
                nn.Conv3d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm3d(c_out, track_running_stats=False),
                nn.ReLU(inplace=True),
            ))
        # Up-sample back by a factor of two.
        stages.append(nn.ConvTranspose3d(out_channels, deconv_channels, kernel_size=2, stride=2))

        if dropout:
            assert 0 <= dropout <= 1, 'dropout must be between 0 and 1'
            stages.append(nn.Dropout3d(p=dropout))

        self.center = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the bottleneck to ``x``."""
        return self.center(x)

class Decoder3D(nn.Module):
    """
    Decoder block: two (3x3x3 conv -> batch norm -> ReLU) stages followed by a
    stride-2 transposed convolution producing ``deconv_channels`` channels.
    Optional channel dropout is applied after the transposed convolution.
    """

    def __init__(self, in_channels, middle_channels, out_channels, deconv_channels, dropout=False):
        super().__init__()

        stages = []
        for c_in, c_out in ((in_channels, middle_channels), (middle_channels, out_channels)):
            stages.extend((
                nn.Conv3d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm3d(c_out, track_running_stats=False),
                nn.ReLU(inplace=True),
            ))
        # Doubles every spatial dimension.
        stages.append(nn.ConvTranspose3d(out_channels, deconv_channels, kernel_size=2, stride=2))

        if dropout:
            assert 0 <= dropout <= 1, 'dropout must be between 0 and 1'
            stages.append(nn.Dropout3d(p=dropout))

        self.decoder = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the decoder block to ``x``."""
        return self.decoder(x)

class Last3D(nn.Module):
    """
    Output head: two (3x3x3 conv -> batch norm -> ReLU) stages followed by a
    1x1x1 convolution that maps to ``out_channels`` per-voxel scores.

    Args:
        in_channels: channels of the incoming feature map
        middle_channels: channels used by both 3x3x3 convolutions
        out_channels: number of output channels
        softmax: when True, a softmax over the channel dimension is appended so
            the block returns probabilities. The default (False) returns raw
            scores, which is what ``nn.CrossEntropyLoss`` expects.
    """

    def __init__(self, in_channels, middle_channels, out_channels, softmax=False):
        super().__init__()

        layers = [
            nn.Conv3d(in_channels, middle_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(middle_channels, track_running_stats=False),
            nn.ReLU(inplace=True),
            nn.Conv3d(middle_channels, middle_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(middle_channels, track_running_stats=False),
            nn.ReLU(inplace=True),
            nn.Conv3d(middle_channels, out_channels, kernel_size=1),
        ]

        # Previously this flag was accepted but had no effect.
        if softmax:
            layers.append(nn.Softmax(dim=1))

        self.first = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the output head to ``x``."""
        return self.first(x)

class UNet3D(nn.Module):
    """
    3D U-Net built from First3D / Encoder3D / Center3D / Decoder3D / Last3D blocks.

    Args:
        in_channels: channels of the input volume
        out_channels: channels of the output (one score map per class)
        conv_depths: channel width of each resolution level, from full
            resolution down to the bottleneck; needs at least 3 entries
    """

    def __init__(self, in_channels, out_channels, conv_depths=(16, 32, 64, 128, 256)):
        assert len(conv_depths) > 2, 'conv_depths must have at least 3 members'

        super(UNet3D, self).__init__()

        levels = len(conv_depths) - 2

        # Encoder path: one full-resolution block, then one down-sampling block per level.
        encoder_layers = [First3D(in_channels, conv_depths[0], conv_depths[0])]
        encoder_layers.extend(
            Encoder3D(conv_depths[i], conv_depths[i + 1], conv_depths[i + 1])
            for i in range(levels)
        )

        # Decoder path, deepest level first. Every decoder input is the
        # concatenation of the up-sampled features and the matching skip
        # connection, hence the doubled channel counts.
        decoder_layers = [
            Decoder3D(2 * conv_depths[i + 1], 2 * conv_depths[i], 2 * conv_depths[i], conv_depths[i])
            for i in reversed(range(levels))
        ]
        decoder_layers.append(Last3D(conv_depths[1], conv_depths[0], out_channels))

        # encoder, center and decoder layers
        self.encoder_layers = nn.ModuleList(encoder_layers)
        self.center = Center3D(conv_depths[-2], conv_depths[-1], conv_depths[-1], conv_depths[-2])
        self.decoder_layers = nn.ModuleList(decoder_layers)

    def forward(self, x, return_all=False):
        """
        Run the network on ``x``.

        Args:
            x: input batch of shape (N, in_channels, D, H, W)
            return_all: accepted for interface compatibility; currently unused

        Returns:
            Output of the final Last3D block.
        """
        # Encode, remembering every level's output for the skip connections.
        skip_connections = []
        for encoder in self.encoder_layers:
            x = encoder(x)
            skip_connections.append(x)

        # Bottleneck
        x = self.center(x)

        # Decode: the deepest skip connection is consumed first; the last
        # decoder entry is the output head, paired with the first encoder output.
        for decoder, skip in zip(self.decoder_layers, reversed(skip_connections)):
            # Max-pooling floors odd sizes, so the up-sampled tensor can be one
            # voxel short of its skip connection; pad it so torch.cat succeeds.
            # When the sizes already match this pads by zero.
            x = pad_to_shape(x, skip.shape)
            x = torch.cat((x, skip), dim=1)
            x = decoder(x)

        return x

In [None]:
class Dataset3D(Dataset):
    """
    Dataset over paired volumes ``x`` and targets ``y``.

    Each item is ``(volume, target, index)``. With ``normalization=True`` the
    volume is min-max scaled using its own minimum and maximum.
    """

    def __init__(self, x, y, normalization=True):
        self.x, self.y = x, y
        self.normalization = normalization

    def __len__(self):
        # Number of samples.
        return len(self.x)

    def __getitem__(self, index):
        # Called by the DataLoader; the index is returned alongside the data.
        volume, target = self.x[index], self.y[index]
        if not self.normalization:
            return volume, target, index

        # Per-volume scaling; the epsilon guards against constant volumes.
        lowest, highest = volume.min(), volume.max()
        return (volume - lowest) / (highest - lowest + 1e-8), target, index

In [None]:
def predict(model, image, device, normalize=True):
    """
    Predict the segmentation of a 3D image.

    Args:
        model: network returning per-class scores on dimension 1
        image: array of shape (C, D, H, W), or (N, C, D, H, W) with a batch axis
        device: device the input is moved to before the forward pass
        normalize: when True (default), min-max scale each volume in the same
            way ``Dataset3D`` does during training, so the model sees inputs in
            the range it was trained on. Pass False to feed the array as given.

    Returns:
        Integer NumPy array of shape (N, 1, D, H, W) with the arg-max class per voxel.
    """
    image = np.asarray(image)

    # Add batch dimension if not present
    if image.ndim == 4:
        image = image[np.newaxis]

    if normalize:
        # Scale every volume of the batch independently.
        reduce_axes = tuple(range(1, image.ndim))
        lowest = image.min(axis=reduce_axes, keepdims=True)
        highest = image.max(axis=reduce_axes, keepdims=True)
        image = (image - lowest) / (highest - lowest + 1e-8)

    input_tensor = torch.from_numpy(image).to(device=device, dtype=torch.float32)
    model.eval()

    with torch.no_grad():
        pred = model(input_tensor).argmax(dim=1, keepdim=True)

    return pred.cpu().numpy()

def save_prediction_as_nii(prediction, reference_nii_path, output_path):
    """
    Write ``prediction`` to ``output_path`` as a NIfTI image, reusing the affine
    and header of the file at ``reference_nii_path``.

    A leading batch axis (when the array is 5-D) and a leading axis of size 1
    are stripped before saving; the data is stored as int32.
    """
    # The reference image supplies the affine transformation and header.
    reference = nib.load(reference_nii_path)

    volume = prediction
    # Drop the batch axis of a 5-D array.
    if volume.ndim == 5:
        volume = volume[0]
    # Drop a singleton leading (channel) axis.
    if volume.shape[0] == 1:
        volume = volume[0]

    nib.save(
        nib.Nifti1Image(volume.astype(np.int32), reference.affine, reference.header),
        output_path,
    )
    print(f"Saved prediction to {output_path}")

def get_slice_indices(img, k):
    """
    Pick ``k`` random slice indices (along axis 1) whose slice is not all zeros.

    Indices are drawn with replacement, so the result may contain duplicates.

    Args:
        img: array indexed as ``img[:, idx]``, e.g. (C, D, H, W)
        k: number of indices to draw

    Returns:
        List of ``k`` Python ints, or an empty list when no slice contains a
        non-zero value (or ``k`` is not positive). Returning early here avoids
        the endless retry loop that rejection sampling would hit on an
        all-zero volume.
    """
    img = np.asarray(img)

    # Reduce over every axis except the slice axis (axis 1).
    other_axes = (0,) + tuple(range(2, img.ndim))
    informative = np.flatnonzero(np.any(img != 0, axis=other_axes)).tolist()

    if not informative or k <= 0:
        return []
    return random.choices(informative, k=k)

def visualize_results(image, mask, prediction, num_slices=4):
    """
    Show prediction, ground truth and input side by side for random slices.

    Slices are taken along the first spatial axis, and are chosen among slices
    where ``mask`` is non-empty. The figure is saved to
    'visualization_results.png' and then displayed.

    Args:
        image: input volume; any leading axes must have size 1, e.g. (1, D, H, W)
        mask: ground-truth volume with the same layout as ``image``
        prediction: predicted volume, e.g. (1, 1, D, H, W) as returned by ``predict``
        num_slices: number of columns (slices) to draw
    """
    def _as_volume(array):
        # Collapse leading singleton batch/channel axes to a plain (D, H, W)
        # volume; reshape raises if those leading axes are not all of size 1.
        array = np.asarray(array)
        return array.reshape(array.shape[-3:])

    image_vol = _as_volume(image)
    mask_vol = _as_volume(mask)
    pred_vol = _as_volume(prediction)

    # get_slice_indices expects a leading channel axis.
    slice_indices = get_slice_indices(mask_vol[np.newaxis], num_slices)

    # squeeze=False keeps axs two-dimensional even when num_slices == 1.
    fig, axs = plt.subplots(3, num_slices, figsize=(4*num_slices, 12), squeeze=False)

    rows = (
        ("Prediction", pred_vol),
        ("Ground truth", mask_vol),
        ("Original image", image_vol),
    )
    for col, s in enumerate(slice_indices):
        for row, (title, volume) in enumerate(rows):
            axs[row][col].title.set_text(f"{title} slice {s}")
            axs[row][col].imshow(volume[s], cmap='bone')

    fig.tight_layout()
    plt.savefig('visualization_results.png')
    plt.show()

In [None]:
# Main execution
# Set paths
images_dir = "/content/drive/MyDrive/LITS/imagesTr"
labels_dir = "/content/drive/MyDrive/LITS/labelsTr"
model_path = "/content/drive/MyDrive/LITS/Models/model3d.pt"
output_dir = "/content/drive/MyDrive/LITS/predictions"

# Create the output and checkpoint directories if they don't exist
# (torch.save does not create missing parent directories).
os.makedirs(output_dir, exist_ok=True)
os.makedirs(os.path.dirname(model_path), exist_ok=True)

# Load and preprocess data
# NOTE(review): every raw volume is held in memory before preprocessing starts;
# pass max_samples to load_nii_data if the full dataset does not fit in RAM.
raw_images, raw_labels = load_nii_data(images_dir, labels_dir)
processed_images, processed_labels = preprocess_data(raw_images, raw_labels)

# Split data into train, validation and test sets:
# 15% is held out for testing, then 20% of the remainder for validation.
X_train_val, X_test, y_train_val, y_test = train_test_split(
    processed_images, processed_labels, test_size=0.15, random_state=42
)

train_X, valid_X, train_Y, valid_Y = train_test_split(
    X_train_val, y_train_val, test_size=0.2, random_state=42
)

test_X, test_Y = X_test, y_test

print(f"Training samples: {len(train_X)}")
print(f"Validation samples: {len(valid_X)}")
print(f"Test samples: {len(test_X)}")


def make_loader3d(volumes, masks, shuffle):
    """Build a batch-size-1 DataLoader over per-volume normalised samples."""
    return DataLoader(
        Dataset3D(volumes, masks, normalization=True),
        batch_size=1,
        shuffle=shuffle,
        num_workers=2
    )


# Create data loaders (only the training data is shuffled)
train_loader3d = make_loader3d(train_X, train_Y, shuffle=True)
print('Train Loader Done')

valid_loader3d = make_loader3d(valid_X, valid_Y, shuffle=False)
print('Validation Loader Done')

test_loader3d = make_loader3d(test_X, test_Y, shuffle=False)
print('Test Loader Done')

Loading NII data...


 17%|█▋        | 21/123 [02:20<11:24,  6.71s/it]


KeyboardInterrupt: 

In [None]:
# Group the loaders by phase name
dataloaders3d = dict(train=train_loader3d, valid=valid_loader3d, test=test_loader3d)

# Number of samples per phase (used to average the Dice score)
train_samples_count3d, val_samples_count3d, test_samples_count3d = (
    len(dataloaders3d[phase_name].dataset) for phase_name in ('train', 'valid', 'test')
)
samples_count3d = dict(
    train=train_samples_count3d,
    valid=val_samples_count3d,
    test=test_samples_count3d,
)

# Model (one input channel, two output classes), optimizer and loss
model3d = UNet3D(in_channels=1, out_channels=2).to(device).float()
optimizer = torch.optim.Adam(model3d.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
epochs = 30

# Train; the returned dicts hold the per-epoch loss and Dice histories
loss_dic, dice_dic = train(
    model3d, dataloaders3d, criterion, optimizer, epochs, device, model_path, samples_count3d
)

In [None]:
# Plot results
# One figure per metric, each showing the train and validation curves;
# show_plots also saves every figure as '<metric>_plot.png'.
show_plots(epochs, loss_dic, 'loss')
show_plots(epochs, dice_dic, 'dice score')

In [None]:
# Load best model and evaluate
# map_location lets a checkpoint saved on the GPU be loaded on a CPU-only machine.
model3d.load_state_dict(torch.load(model_path, map_location=device))
test_dice = evaluate(model3d, dataloaders3d, criterion, optimizer, device, samples_count3d, 'test')

# Make prediction on a test sample and save as .nii
test_idx = random.randint(0, len(test_X)-1)
test_image = test_X[test_idx]
test_mask = test_Y[test_idx]

# Get original file path for reference.
# test_idx indexes the shuffled test split, not the sorted file list, so the
# split is replayed on the file indices to find the matching source file.
# The test_size / random_state values must match the ones used for the data split.
all_image_files = sorted([f for f in os.listdir(images_dir) if f.endswith('.nii') or f.endswith('.nii.gz')])
_, test_file_indices = train_test_split(
    np.arange(len(all_image_files)), test_size=0.15, random_state=42
)
if len(test_file_indices) != len(test_X):
    raise RuntimeError(
        "Cannot map the test sample back to its source file: the image directory "
        "no longer matches the data that was split (was max_samples used?)"
    )
test_file = all_image_files[test_file_indices[test_idx]]
reference_path = os.path.join(images_dir, test_file)

# Predict and save
prediction = predict(model3d, test_image, device)
output_path = os.path.join(output_dir, f"prediction_{test_file}")
save_prediction_as_nii(prediction, reference_path, output_path)

# Visualize results
visualize_results(test_image, test_mask, prediction)

print(f"Final Test Dice Score: {test_dice:.4f}")
print(f"Prediction saved to {output_path}")