In [1]:
!kaggle datasets download -d debayan20000/brainsegmentation2018-dataset

Dataset URL: https://www.kaggle.com/datasets/debayan20000/brainsegmentation2018-dataset
License(s): apache-2.0
Downloading brainsegmentation2018-dataset.zip to /kaggle/working
 99%|██████████████████████████████████████▌| 1.93G/1.95G [00:07<00:00, 277MB/s]
100%|███████████████████████████████████████| 1.95G/1.95G [00:07<00:00, 268MB/s]


In [2]:
!pip install rasterio

Collecting rasterio
  Downloading rasterio-1.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.1 kB)
Collecting affine (from rasterio)
  Downloading affine-2.4.0-py3-none-any.whl.metadata (4.0 kB)
Downloading rasterio-1.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (22.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m22.2/22.2 MB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[?25hDownloading affine-2.4.0-py3-none-any.whl (15 kB)
Installing collected packages: affine, rasterio
Successfully installed affine-2.4.0 rasterio-1.4.1


In [3]:
import zipfile
import os

# Path to your .zip file
zip_file_path = '/kaggle/working/brainsegmentation2018-dataset.zip'

# Directory where you want to extract the files
extract_to_path = '/kaggle/working/'

# Create the extraction directory if it doesn't exist
os.makedirs(extract_to_path, exist_ok=True)

# Unzipping the file
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
    zip_ref.extractall(extract_to_path)

print("File unzipped successfully!")


File unzipped successfully!


In [4]:
import os
import numpy as np
import nibabel as nib
import torch
from torch.amp import autocast, GradScaler
from skimage.transform import resize
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Resize
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np

In [5]:
class BrainSegmentationDataset(Dataset):
    def __init__(self, root_dir, patch_size=(48, 240, 240), stride=8, transform=None):
        self.root_dir = root_dir
        self.patch_size = (64, 64, 16)
        self.stride = 16
        self.transform = transform
        self.patches_img = []
        self.patches_mask = []

        # Load data and extract patches
        self._load_data()
        print(f"Number of image patches: {len(self.patches_img)}") # Print the number of image patches

    def _load_data(self):
        for subfolder in os.listdir(self.root_dir):
            subfolder_path = os.path.join(self.root_dir, subfolder)
            if os.path.isdir(subfolder_path):
                pre_dir = os.path.join(subfolder_path, 'pre')
                seg_dir = os.path.join(subfolder_path, 'segm.nii')

                # Load MRI modalities
                t1_path = os.path.join(pre_dir, 'T1.nii/T1.nii')
                ir_path = os.path.join(pre_dir, 'IR.nii/IR.nii')
                flair_path = os.path.join(pre_dir, 'FLAIR.nii/FLAIR.nii')
                seg_path = os.path.join(seg_dir, 'segm.nii')

                print(f"Checking for files in subfolder: {subfolder_path}") # Print the subfolder being checked
                print(f" - T1 path: {t1_path}")
                print(f" - IR path: {ir_path}")
                print(f" - FLAIR path: {flair_path}")
                print(f" - Segmentation path: {seg_path}")

                # Check if files exist
                if not os.path.exists(t1_path):
                    print(f" - T1 file not found: {t1_path}")
                if not os.path.exists(ir_path):
                    print(f" - IR file not found: {ir_path}")
                if not os.path.exists(flair_path):
                    print(f" - FLAIR file not found: {flair_path}")
                if not os.path.exists(seg_path):
                    print(f" - Segmentation file not found: {seg_path}")
                    continue  # Skip to the next subfolder

                # Read NIfTI files
                try:
                    t1_img = nib.load(t1_path).get_fdata()
                    ir_img = nib.load(ir_path).get_fdata()
                    flair_img = nib.load(flair_path).get_fdata()
                    seg_img = nib.load(seg_path).get_fdata()

                    # Print shapes after loading images
                    print(f" - T1 shape after loading: {t1_img.shape}")
                    print(f" - IR shape after loading: {ir_img.shape}")
                    print(f" - FLAIR shape after loading: {flair_img.shape}")
                    print(f" - Segmentation shape after loading: {seg_img.shape}")

                except Exception as e:
                    print(f"Error loading images from subfolder {subfolder_path}: {e}")
                    continue

                # Resize the images to the target shape
                t1_img = self._resize(t1_img)
                ir_img = self._resize(ir_img)
                flair_img = self._resize(flair_img)
                seg_img = self._resize(seg_img)

                print(f" - T1 shape: {t1_img.shape}")
                print(f" - IR shape: {ir_img.shape}")
                print(f" - FLAIR shape: {flair_img.shape}")
                print(f" - Segmentation shape: {seg_img.shape}")



                # Extract patches
                img = np.stack([t1_img, ir_img, flair_img], axis=0)  # Shape: (3, D, H, W)
                patches_img, patches_mask = self._extract_patches(img, seg_img)

                # Print number of patches extracted
                print(f" - Number of image patches extracted: {len(patches_img)}")
                print(f" - Number of mask patches extracted: {len(patches_mask)}")


                # Store patches
                self.patches_img.extend(patches_img)
                self.patches_mask.extend(patches_mask)

    def _resize(self, image):
        # Resize image to target shape
        target_shape = (240, 240, 48)  # For simplicity, you can resize to this shape
        return resize(image, target_shape, mode='reflect', anti_aliasing=True)
    def _extract_patches(self, image, mask):
        patches_img = []
        patches_mask = []
        _, depth, height, width = image.shape
        p_depth, p_height, p_width = self.patch_size

        # Check if patch extraction is feasible
        if depth < p_depth or height < p_height or width < p_width:
            print(f"Patch size {self.patch_size} is too large for image dimensions {image.shape}.")
            return patches_img, patches_mask

        print(f"Stride: {self.stride}, Patch size: {self.patch_size}")
        print(f"Depth: {depth}, Height: {height}, Width: {width}")

        # Extract patches
        for i in range(0, depth - p_depth + 1, self.stride):
            for j in range(0, height - p_height + 1, self.stride):
                for k in range(0, width - p_width + 1, self.stride):
                    img_patch = image[:, i:i+p_depth, j:j+p_height, k:k+p_width]
                    mask_patch = mask[i:i+p_depth, j:j+p_height, k:k+p_width]
                    patches_img.append(img_patch)
                    patches_mask.append(mask_patch)
                    print(f"Patch extracted: depth {i}, height {j}, width {k}")

        print(f"Total patches extracted: {len(patches_img)}")
        return np.array(patches_img), np.array(patches_mask)

    def __len__(self):
        return len(self.patches_img)
        print(f"Number of image patches: {len(self.patches_img)}") # Print the number of image patches

    def __getitem__(self, idx):
        img = self.patches_img[idx]
        mask = self.patches_mask[idx]

        if self.transform:
            img = self.transform(img)

        return img, mask

In [6]:
torch.cuda.empty_cache()

In [7]:
print(torch.cuda.memory_summary())

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |      0 B   |      0 B   |      0 B   |      0 B   |
|       from large pool |      0 B   |      0 B   |      0 B   |      0 B   |
|       from small pool |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------------------|
| Active memory         |      0 B   |      0 B   |      0 B   |      0 B   |
|       from large pool |      0 B   |      0 B   |      0 B   |      0 B   |
|       from small pool |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------

In [8]:
# Define the dataset paths
train_data_dir = '/kaggle/working/training_dataset'

In [9]:
dataset = BrainSegmentationDataset(train_data_dir)

Checking for files in subfolder: /kaggle/working/training_dataset/5
 - T1 path: /kaggle/working/training_dataset/5/pre/T1.nii/T1.nii
 - IR path: /kaggle/working/training_dataset/5/pre/IR.nii/IR.nii
 - FLAIR path: /kaggle/working/training_dataset/5/pre/FLAIR.nii/FLAIR.nii
 - Segmentation path: /kaggle/working/training_dataset/5/segm.nii/segm.nii
 - T1 shape after loading: (256, 256, 192)
 - IR shape after loading: (240, 240, 48)
 - FLAIR shape after loading: (240, 240, 48)
 - Segmentation shape after loading: (240, 240, 48)
 - T1 shape: (240, 240, 48)
 - IR shape: (240, 240, 48)
 - FLAIR shape: (240, 240, 48)
 - Segmentation shape: (240, 240, 48)
Stride: 16, Patch size: (64, 64, 16)
Depth: 240, Height: 240, Width: 48
Patch extracted: depth 0, height 0, width 0
Patch extracted: depth 0, height 0, width 16
Patch extracted: depth 0, height 0, width 32
Patch extracted: depth 0, height 16, width 0
Patch extracted: depth 0, height 16, width 16
Patch extracted: depth 0, height 16, width 32
Pat

In [10]:
print(len(dataset.patches_img))
print(len(dataset.patches_mask))

3024
3024


In [11]:

# Validation split from dataset (20% validation, 80% training)
train_dataset, val_dataset = train_test_split(dataset, test_size=0.2, random_state=42)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

In [12]:
for inputs, labels in train_loader:
    # Print shapes of inputs and labels
    print(f"Input shape: {inputs.shape}")
    print(f"Label shape: {labels.shape}")
    break


Input shape: torch.Size([16, 3, 64, 64, 16])
Label shape: torch.Size([16, 64, 64, 16])


In [13]:
# Define the 3D U-Net model with transition layers
class UNet3D(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet3D, self).__init__()

        # Encoder with transition layers
        self.encoder1 = self.contract_block(in_channels, 32, kernel_size=3, padding=1)
        self.trans1 = nn.MaxPool3d(2)  # Downsampling layer
        self.encoder2 = self.contract_block(32, 64, kernel_size=3, padding=1)
        self.trans2 = nn.MaxPool3d(2)  # Downsampling layer
        self.encoder3 = self.contract_block(64, 128, kernel_size=3, padding=1)
        self.trans3 = nn.MaxPool3d(2)  # Downsampling layer

        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv3d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

        # Decoder with transition layers
        self.uptrans3 = nn.ConvTranspose3d(256, 128, kernel_size=2, stride=2)  # Upsampling layer
        self.decoder3 = self.expand_block(256, 128, kernel_size=3, padding=1)
        self.uptrans2 = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2)   # Upsampling layer
        self.decoder2 = self.expand_block(128, 64, kernel_size=3, padding=1)
        self.uptrans1 = nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2)    # Upsampling layer
        self.decoder1 = self.expand_block(64, 32, kernel_size=3, padding=1)

        # Output layer - out_channels = number of classes (no softmax here)
        self.out = nn.Conv3d(32, out_channels, kernel_size=1)

    def contract_block(self, in_channels, out_channels, kernel_size=3, padding=1):
        return nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, padding=padding),
            nn.ReLU(inplace=True)
        )

    def expand_block(self, in_channels, out_channels, kernel_size=3, padding=1):
        return nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, padding=padding),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # Encoder
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.trans1(enc1))
        enc3 = self.encoder3(self.trans2(enc2))

        # Bottleneck
        bottleneck = self.bottleneck(self.trans3(enc3))

        # Decoder
        dec3 = self.uptrans3(bottleneck)
        dec3 = torch.cat((dec3, enc3), dim=1)  # Skip connection
        dec3 = self.decoder3(dec3)
        dec2 = self.uptrans2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)  # Skip connection
        dec2 = self.decoder2(dec2)
        dec1 = self.uptrans1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)  # Skip connection
        dec1 = self.decoder1(dec1)

        # Output layer with raw logits (no activation)
        out = self.out(dec1)
        return out

In [14]:
print(torch.cuda.is_available())

True


In [15]:
# Training settings
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [16]:
# Initialize the model
model = UNet3D(in_channels=3, out_channels=8)  # 8 labels to segment (excluding background)
model.to(device)


UNet3D(
  (encoder1): Sequential(
    (0): Conv3d(3, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): ReLU(inplace=True)
    (2): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (3): ReLU(inplace=True)
  )
  (trans1): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (encoder2): Sequential(
    (0): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): ReLU(inplace=True)
    (2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (3): ReLU(inplace=True)
  )
  (trans2): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (encoder3): Sequential(
    (0): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): ReLU(inplace=True)
    (2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (3): ReLU(inplace=True)
  )
  (trans3): MaxPool3d(kernel_size=2, stride=2, padding

In [17]:
# Define loss function and optimizer

optimizer = torch.optim.Adam(model.parameters(), lr=0.00000001, weight_decay=1e-5)
criterion = torch.nn.CrossEntropyLoss()


In [18]:
for name, param in model.named_parameters():
    if torch.isnan(param).any():
        print(f"NaN detected in weights: {name}")

In [19]:
os.environ['TORCH_USE_CUDA_DSA'] = '1'

In [20]:
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [21]:
epochs = 1000

In [22]:
# Training and validation loop
scaler = GradScaler()  # For mixed precision training


In [23]:
def weights_init(m):
    if isinstance(m, nn.Conv3d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


In [24]:
model.apply(weights_init)

UNet3D(
  (encoder1): Sequential(
    (0): Conv3d(3, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): ReLU(inplace=True)
    (2): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (3): ReLU(inplace=True)
  )
  (trans1): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (encoder2): Sequential(
    (0): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): ReLU(inplace=True)
    (2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (3): ReLU(inplace=True)
  )
  (trans2): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (encoder3): Sequential(
    (0): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): ReLU(inplace=True)
    (2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (3): ReLU(inplace=True)
  )
  (trans3): MaxPool3d(kernel_size=2, stride=2, padding

In [25]:
def calculate_accuracy(preds, labels):
    # Apply softmax to get class probabilities
    preds = torch.argmax(preds, dim=1)  # Get the class with the highest score along the channel dimension (class dimension)
    
    # Ensure labels have the correct shape (if labels are one-hot encoded, convert to class indices)
    if labels.dim() == 4:  # For one-hot encoded labels
        labels = torch.argmax(labels, dim=1)  # Convert one-hot encoded labels to class indices
    
    # Ensure preds and labels have the same shape
    if preds.shape != labels.shape:
        raise RuntimeError(f"Shape mismatch: preds {preds.shape}, labels {labels.shape}")
    
    # Calculate the number of correct predictions
    correct = (preds == labels).float().sum()  # Count the correct pixel predictions
    accuracy = correct / labels.numel()  # Normalize by the total number of pixels
    
    return accuracy



In [None]:
# Define patience for early stopping and the learning rate scheduler
patience = 3  # Number of epochs to wait before stopping if no improvement
best_val_loss = np.inf  # Track the best validation loss
early_stop_counter = 0  # Counter for early stopping
# Define the scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2, verbose=True)

for epoch in range(epochs):
    print(f"Epoch {epoch+1}/{epochs}")

    # Training phase
    model.train()
    train_loss = 0.0
    train_accuracy = 0.0

    for i, (images, masks) in enumerate(train_loader):
        images = images.type(torch.float32).to(device)  # Ensure float32
        masks = masks.to(device)
        images = (images - images.min()) / (images.max() - images.min())

        if torch.isnan(images).any():
            print("NaN detected in input images.")
        if torch.isnan(masks).any():
            print("NaN detected in input masks.")

        # Forward pass
        with autocast(device_type='cuda'):
            outputs = model(images)
            outputs = outputs[:, 0, :, :, :]
            loss = criterion(outputs, masks)

        # Backward pass and optimization
        optimizer.zero_grad()
        scaler.scale(loss).backward()

        scaler.step(optimizer)
        scaler.update()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        train_loss += loss.item()
        train_accuracy += calculate_accuracy(outputs, masks).item()

    avg_train_loss = train_loss / len(train_loader)
    avg_train_acc = train_accuracy / len(train_loader)
    print(f"Training Loss: {avg_train_loss:.4f}, Training Accuracy: {avg_train_acc:.4f}")

    # Validation phase
    model.eval()
    val_loss = 0.0
    val_accuracy = 0.0

    with torch.no_grad():
        for i, (images, masks) in enumerate(val_loader):
            images = images.type(torch.float32).to(device)  # Ensure float32
            masks = masks.to(device)
            images = (images - images.min()) / (images.max() - images.min())

            if torch.isnan(images).any():
                print("NaN detected in input images.")
            if torch.isnan(masks).any():
                print("NaN detected in input masks.")

            # Forward pass
            outputs = model(images)
            outputs = outputs[:, 0, :, :, :]
            loss = criterion(outputs, masks)

            val_loss += loss.item()
            val_accuracy += calculate_accuracy(outputs, masks).item()

    avg_val_loss = val_loss / len(val_loader)
    avg_val_acc = val_accuracy / len(val_loader)
    print(f"Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {avg_val_acc:.4f}")

    # Step the scheduler based on validation loss
    scheduler.step(avg_val_loss)

    # Early Stopping
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        early_stop_counter = 0  # Reset counter if validation loss improves
        # Optionally save the best model
        torch.save(model.state_dict(), 'best_model.pth')
        print(f"New best validation loss: {best_val_loss:.4f}, saving model.")
    else:
        early_stop_counter += 1
        print(f"No improvement in validation loss. Early stop counter: {early_stop_counter}/{patience}")

    if early_stop_counter >= patience:
        print("Early stopping triggered. Stopping training.")
        break
    




Epoch 1/1000
Training Loss: 265.9273, Training Accuracy: 0.4873
Validation Loss: 246.0437, Validation Accuracy: 0.5052
New best validation loss: 246.0437, saving model.
Epoch 2/1000
Training Loss: 266.8014, Training Accuracy: 0.4875
Validation Loss: 246.0408, Validation Accuracy: 0.5066
New best validation loss: 246.0408, saving model.
Epoch 3/1000
Training Loss: 267.0446, Training Accuracy: 0.4896
Validation Loss: 246.0379, Validation Accuracy: 0.5079
New best validation loss: 246.0379, saving model.
Epoch 4/1000
Training Loss: 266.3851, Training Accuracy: 0.4888
Validation Loss: 246.0350, Validation Accuracy: 0.5092
New best validation loss: 246.0350, saving model.
Epoch 5/1000
Training Loss: 268.1013, Training Accuracy: 0.4913
Validation Loss: 246.0321, Validation Accuracy: 0.5105
New best validation loss: 246.0321, saving model.
Epoch 6/1000
Training Loss: 266.2347, Training Accuracy: 0.4936
Validation Loss: 246.0293, Validation Accuracy: 0.5118
New best validation loss: 246.0293, 