In [18]:
import os
import nibabel as nib
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch
import torch.nn.functional as F


## Dataset Loading

In [45]:
class MedicalImageDataset(Dataset):
    def __init__(self, image_dir, label_dir, target_shape=(512, 512, 100), transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform
        self.target_shape = target_shape

        self.image_files = sorted([f for f in os.listdir(image_dir) if f.endswith('.nii.gz')])
        self.label_files = sorted([f for f in os.listdir(label_dir) if f.endswith('.nii.gz')])

    def __len__(self):
        return len(self.image_files)

    def resize_volume(self,volume, target_shape):
        volume = torch.tensor(volume, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

        volume = F.interpolate(volume, size=target_shape, mode='trilinear', align_corners=False)

        volume = volume.squeeze(0)

        return volume

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.image_files[idx])
        label_path = os.path.join(self.label_dir, self.label_files[idx])

        image = nib.load(image_path).get_fdata()
        label = nib.load(label_path).get_fdata()

        image = self.resize_volume(image,(128, 128, 128))
        label = self.resize_volume(label,(128, 128, 128))

        label = label.squeeze(0).long()  

        if self.transform:
            image = self.transform(image)

        return image, label

In [46]:
image_dir = 'C:\\Projects\\Hackathons\\5C Fellowship\\FLARE22Train\\FLARE22Train\\images'
label_dir = 'C:\\Projects\\Hackathons\\5C Fellowship\\FLARE22Train\\FLARE22Train\\labels'
dataset = MedicalImageDataset(image_dir, label_dir)
data_loader = DataLoader(dataset, batch_size=1, shuffle=True)

for images, labels in data_loader:
    print(images.shape)  
    print(labels.shape) 

torch.Size([1, 1, 128, 128, 128])
torch.Size([1, 128, 128, 128])
torch.Size([1, 1, 128, 128, 128])
torch.Size([1, 128, 128, 128])
torch.Size([1, 1, 128, 128, 128])
torch.Size([1, 128, 128, 128])
torch.Size([1, 1, 128, 128, 128])
torch.Size([1, 128, 128, 128])
torch.Size([1, 1, 128, 128, 128])
torch.Size([1, 128, 128, 128])
torch.Size([1, 1, 128, 128, 128])
torch.Size([1, 128, 128, 128])
torch.Size([1, 1, 128, 128, 128])
torch.Size([1, 128, 128, 128])
torch.Size([1, 1, 128, 128, 128])
torch.Size([1, 128, 128, 128])
torch.Size([1, 1, 128, 128, 128])
torch.Size([1, 128, 128, 128])
torch.Size([1, 1, 128, 128, 128])
torch.Size([1, 128, 128, 128])
torch.Size([1, 1, 128, 128, 128])
torch.Size([1, 128, 128, 128])
torch.Size([1, 1, 128, 128, 128])
torch.Size([1, 128, 128, 128])
torch.Size([1, 1, 128, 128, 128])
torch.Size([1, 128, 128, 128])
torch.Size([1, 1, 128, 128, 128])
torch.Size([1, 128, 128, 128])
torch.Size([1, 1, 128, 128, 128])
torch.Size([1, 128, 128, 128])
torch.Size([1, 1, 128, 12

In [33]:
import nibabel as nib
import numpy as np
import os

def check_unique_labels(label_file):
    label_data = nib.load(label_file).get_fdata()
    
    unique_labels = np.unique(label_data)
    
    return unique_labels


label_files = sorted([os.path.join(label_dir, f) for f in os.listdir(label_dir) if f.endswith('.nii.gz')])

for label_file in label_files:
    unique_labels = check_unique_labels(label_file)
    print(f"Unique labels in {os.path.basename(label_file)}: {unique_labels}")

all_unique_labels = set()
for label_file in label_files:
    unique_labels = check_unique_labels(label_file)
    all_unique_labels.update(unique_labels)

print(f"Total unique labels across all files: {sorted(all_unique_labels)}")

Unique labels in FLARE22_Tr_0001.nii.gz: [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13.]
Unique labels in FLARE22_Tr_0002.nii.gz: [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13.]
Unique labels in FLARE22_Tr_0003.nii.gz: [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13.]
Unique labels in FLARE22_Tr_0004.nii.gz: [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13.]
Unique labels in FLARE22_Tr_0005.nii.gz: [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13.]
Unique labels in FLARE22_Tr_0006.nii.gz: [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13.]
Unique labels in FLARE22_Tr_0007.nii.gz: [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13.]
Unique labels in FLARE22_Tr_0008.nii.gz: [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13.]
Unique labels in FLARE22_Tr_0009.nii.gz: [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13.]
Unique labels in FLARE22_Tr_0010.nii.gz: [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13.]
Unique lab

In [47]:
train_dataset = MedicalImageDataset(image_dir, label_dir)
train_loader = DataLoader(dataset, batch_size=1, shuffle=True)

val_dataset = MedicalImageDataset(image_dir, label_dir)
val_loader = DataLoader(dataset, batch_size=1, shuffle=False)

In [48]:
def normalize_volume(volume):
    min_val = np.min(volume)
    max_val = np.max(volume)
    volume = (volume - min_val) / (max_val - min_val)
    return volume

def resize_volume(volume, new_shape=(128, 128, 128)):
    resampler = sitk.ResampleImageFilter()
    resampler.SetSize(new_shape)
    resampler.SetOutputSpacing([osz * ospc / nsz for osz, ospc, nsz in zip(volume.GetSize(), volume.GetSpacing(), new_shape)])
    resampler.SetInterpolator(sitk.sitkLinear)
    return sitk.GetArrayFromImage(resampler.Execute(volume))

## Model

In [49]:
import torch.nn as nn
import torch

class VNet(nn.Module):
    def __init__(self, num_classes=4):
        super(VNet, self).__init__()

        self.encoder = nn.Sequential(
            nn.Conv3d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=2, stride=2)
        )

        self.decoder = nn.Sequential(
            nn.Conv3d(32, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(16, num_classes, kernel_size=3, stride=1, padding=1),
            nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

model = VNet(num_classes=14)

In [50]:
import torch.optim as optim
from torch.nn import CrossEntropyLoss
from torch.autograd import Variable

def dice_loss(pred, target, smooth=1.0):
    intersection = (pred * target).sum()
    return 1 - ((2. * intersection + smooth) / (pred.sum() + target.sum() + smooth))

criterion = CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [51]:
class EarlyStopping:
    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss, model, epoch):
        if self.best_score is None:
            self.best_score = val_loss
            self.save_checkpoint(val_loss, model, epoch)
        elif val_loss > self.best_score + self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = val_loss
            self.save_checkpoint(val_loss, model, epoch)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, epoch):
        """Saves model when validation loss decreases."""
        print(f"Validation loss decreased ({self.best_score:.6f} --> {val_loss:.6f}). Saving model...")
        torch.save(model.state_dict(), f'best_model_epoch_{epoch}.pth')

In [52]:
class ModelCheckpoint:
    def __init__(self, save_path='best_model.pth', monitor='val_loss', mode='min'):
        self.save_path = save_path
        self.monitor = monitor
        self.mode = mode
        self.best_value = None

    def __call__(self, value, model):
        if self.best_value is None:
            self.best_value = value
            self.save_model(model)
        elif (self.mode == 'min' and value < self.best_value) or (self.mode == 'max' and value > self.best_value):
            self.best_value = value
            self.save_model(model)

    def save_model(self, model):
        print(f"Saving best model with {self.monitor}: {self.best_value:.6f}")
        torch.save(model.state_dict(), self.save_path)

## Training Script

In [54]:
early_stopping = EarlyStopping(patience=10, min_delta=0.001)
model_checkpoint = ModelCheckpoint(save_path='best_model.pth', monitor='val_dice_score', mode='max')

num_epochs=5

for epoch in range(num_epochs):
    # Training step
    model.train()
    running_loss = 0.0
    for images, masks in train_loader:
        images = images
        masks = masks

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks) + dice_loss(outputs, masks)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)
    
    # Validation step
    model.eval()
    val_dice_scores = []
    with torch.no_grad():
        for images, masks in val_loader:
            images = images
            masks = masks
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            dice_score = dice_loss(preds.float(), masks.float()).item()
            val_dice_scores.append(dice_score)
    
    avg_val_dice = np.mean(val_dice_scores)

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Val Dice: {avg_val_dice:.4f}")

    # Callbacks
    early_stopping(avg_val_dice, model, epoch)
    model_checkpoint(avg_val_dice, model)

    if early_stopping.early_stop:
        print("Early stopping")
        break


Epoch 1/5, Loss: 2.7938, Val Dice: -0.3406
Validation loss decreased (-0.340639 --> -0.340639). Saving model...
Saving best model with val_dice_score: -0.340639
Epoch 2/5, Loss: 2.3552, Val Dice: -0.6728
Validation loss decreased (-0.672837 --> -0.672837). Saving model...
Epoch 3/5, Loss: 1.5974, Val Dice: -0.3970
Epoch 4/5, Loss: 1.3746, Val Dice: 0.2439
Saving best model with val_dice_score: 0.243891
Epoch 5/5, Loss: 1.5423, Val Dice: -0.2383


## Inference Script

In [None]:
model = VNet(num_classes=14)  

model.load_state_dict(torch.load('best_model.pth'))

model.eval()

In [None]:
import torch

import torch

def predict_segmentation(model, input_image):
    model.eval()  
    with torch.no_grad(): 
        input_tensor = torch.tensor(input_image).unsqueeze(0).unsqueeze(0).float()  

        output = model(input_tensor) 

        probabilities = torch.softmax(output, dim=1)
        predicted_segmentation = torch.argmax(probabilities, dim=1)  # Shape: [1, 128, 128, 128]
        predicted_segmentation = predicted_segmentation.squeeze(0).cpu().numpy().astype(np.int32)  # Shape: [128, 128, 128]

    return predicted_segmentation


input_image = nib.load('C:\\Projects\\Hackathons\\5C Fellowship\\FLARE22Train\\FLARE22Train\\images\\FLARE22_Tr_0002_0000.nii.gz').get_fdata()
predicted_segmentation = predict_segmentation(model, input_image)
nib.save(nib.Nifti1Image(predicted_segmentation, np.eye(4)), 'predicted_segmentation.nii.gz')