In [33]:
import os
import numpy as np
import nibabel as nib
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import jaccard_score

# UNet Model Definition (Lightweight Version)
class UNet3D(nn.Module):
    def __init__(self):
        super(UNet3D, self).__init__()
        def conv_block(in_channels, out_channels):
            return nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True)
            )
        
        self.encoder1 = conv_block(4, 8)
        self.pool1 = nn.MaxPool3d(2, 2)
        self.encoder2 = conv_block(8, 16)
        self.pool2 = nn.MaxPool3d(2, 2)
        
        self.bottleneck = conv_block(16, 32)
        
        self.upconv2 = nn.ConvTranspose3d(32, 16, kernel_size=2, stride=2)
        self.decoder2 = conv_block(32, 16)
        self.upconv1 = nn.ConvTranspose3d(16, 8, kernel_size=2, stride=2)
        self.decoder1 = conv_block(16, 8)
        
        self.final = nn.Conv3d(8, 1, kernel_size=1)
    
    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        bottleneck = self.bottleneck(self.pool2(enc2))
        dec2 = self.decoder2(torch.cat([self.upconv2(bottleneck), enc2], dim=1))
        dec1 = self.decoder1(torch.cat([self.upconv1(dec2), enc1], dim=1))
        return torch.sigmoid(self.final(dec1))

# Dataset Class
class BrainSegmentationDataset(Dataset):
    def __init__(self, img_paths, mask_paths=None, train=True, transform=None):
        self.img_paths = img_paths
        self.mask_paths = mask_paths
        self.train = train
        self.transform = transform

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        # Load and stack all modalities
        img_paths = self.img_paths[idx]
        images = [nib.load(p).get_fdata() for p in img_paths]
        img = np.stack(images, axis=0)  # Stack along the channel dimension
        img = img.astype(np.float32)

        if self.train:
            # Load the mask
            mask_path = self.mask_paths[idx]
            mask = nib.load(mask_path).get_fdata()
            
            # Ensure mask is single-channel (binary segmentation)
            if mask.ndim > 3:
                mask = mask[..., 0]  # Use the first channel, adjust as needed

            mask = np.expand_dims(mask, axis=0)  # Ensure shape is [1, depth, height, width]

            if self.transform:
                img = self.transform(img)
                mask = self.transform(mask)
            return torch.tensor(img), torch.tensor(mask).float()  # Ensure masks have shape [1, D, H, W] and float type
        else:
            if self.transform:
                img = self.transform(img)
            return torch.tensor(img)

# Training Loop with Dice Score
def train_epoch_with_dice(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    dice_scores = []
    
    for inputs, masks in dataloader:
        inputs, masks = inputs.to(device), masks.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        outputs = outputs.view_as(masks).float()  # Reshape outputs to match masks
        masks = masks.float()  # Ensure masks are float type for BCELoss

        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        # Compute Dice Score
        with torch.no_grad():
            outputs = (outputs > 0.5).float()
            y_true = masks.cpu().numpy().flatten()
            y_pred = outputs.cpu().numpy().flatten()
            intersection = np.sum(y_true * y_pred)
            dice = (2. * intersection) / (np.sum(y_true) + np.sum(y_pred) + 1e-5)
            dice_scores.append(dice)
    
    avg_dice_score = np.mean(dice_scores)
    return running_loss / len(dataloader), avg_dice_score

# Testing Loop
def test_epoch(model, dataloader, device):
    model.eval()
    dice_scores = []
    with torch.no_grad():
        for inputs, masks in dataloader:
            inputs, masks = inputs.to(device), masks.to(device)

            outputs = model(inputs)
            outputs = (outputs > 0.5).float()
            
            # Compute Dice Score
            y_true = masks.cpu().numpy().flatten()
            y_pred = outputs.cpu().numpy().flatten()
            intersection = np.sum(y_true * y_pred)
            dice = (2. * intersection) / (np.sum(y_true) + np.sum(y_pred) + 1e-5)
            dice_scores.append(dice)
    return np.mean(dice_scores)

# Main Script
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Directories (Update these paths to point to your dataset)
train_dir = "E:/Images"

# Collect image and mask paths
train_img_paths = []
train_mask_paths = []
for sample in os.listdir(train_dir):
    sample_path = os.path.join(train_dir, sample)
    if os.path.isdir(sample_path):
        modalities = sorted([os.path.join(sample_path, f) for f in os.listdir(sample_path) if "Segmented" not in f])
        mask_files = [os.path.join(sample_path, f) for f in os.listdir(sample_path) if "Segmented" in f]

        if len(modalities) != 4:
            print(f"Warning: Expected 4 modalities but found {len(modalities)} for sample {sample}")
            continue

        if not mask_files:  # Check if no mask files are found
            print(f"Warning: No mask found for sample {sample}")
            continue  # Skip this sample

        mask = mask_files[0]
        train_img_paths.append(modalities)
        train_mask_paths.append(mask)

# Datasets and Dataloaders
train_dataset = BrainSegmentationDataset(train_img_paths, train_mask_paths, train=True)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)

# Model, Loss, Optimizer
model = UNet3D().to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training with Dice Score
num_epochs = 10
for epoch in range(num_epochs):
    train_loss, train_dice = train_epoch_with_dice(model, train_dataloader, criterion, optimizer, device)
    print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {train_loss:.4f}, Training Dice Score: {train_dice:.4f}")


Epoch 1/10, Training Loss: 0.3929, Training Dice Score: 0.9745
Epoch 2/10, Training Loss: 0.3626, Training Dice Score: 0.9745
Epoch 3/10, Training Loss: 0.3626, Training Dice Score: 0.9745
Epoch 4/10, Training Loss: 0.3626, Training Dice Score: 0.9745
Epoch 5/10, Training Loss: 0.3626, Training Dice Score: 0.9745
Epoch 6/10, Training Loss: 0.3626, Training Dice Score: 0.9745


KeyboardInterrupt: 