In [1]:
!pip install torch nibabel torchinfo 'monai[all]'

Collecting monai[all]
  Downloading monai-1.4.0-py3-none-any.whl.metadata (11 kB)
Collecting clearml>=1.10.0rc0 (from monai[all])
  Downloading clearml-1.16.5-py2.py3-none-any.whl.metadata (17 kB)
Collecting einops (from monai[all])
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Collecting fire (from monai[all])
  Downloading fire-0.7.0.tar.gz (87 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.2/87.2 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l- \ done
[?25hCollecting gdown>=4.7.3 (from monai[all])
  Downloading gdown-5.2.0-py3-none-any.whl.metadata (5.8 kB)
Collecting imagecodecs (from monai[all])
  Downloading imagecodecs-2024.9.22-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Collecting itk>=5.2 (from monai[all])
  Downloading itk-5.4.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (22 kB)
Collecting lmdb (from monai[all])
  Downloading lmdb-1.

In [2]:
pip install torchmetrics


Note: you may need to restart the kernel to use updated packages.


In [3]:
from monai.networks.nets import AttentionUnet
from monai.networks.layers import Norm
import torchinfo

import zipfile
import random
import os

import torch
import numpy as np
import torch.nn.functional as F
from torch.utils.data import Dataset

from torch.utils.data import Dataset, DataLoader, random_split

from tqdm import tqdm
import matplotlib.pyplot as plt
from IPython.display import clear_output
from torchmetrics.functional import structural_similarity_index_measure
from monai.metrics import DiceMetric
from tqdm import tqdm
import matplotlib.pyplot as plt

> # **Model : Attention Unet**



In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CustomAttentionUnet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    kernel_size=3
).to(device)

NameError: name 'CustomAttentionUnet' is not defined

In [None]:
torchinfo.summary(model, input_size=(8, 1, 16,  128,  128))

> # Dataset : openBHB Dataset

In [None]:
def split_data(source_dir, test_ratio=0.2, max_volumes=500):
    # Get a list of all volume files in the source directory
    volumes = [f for f in os.listdir(source_dir) if os.path.isfile(os.path.join(source_dir, f))]

    # Limit to the first 400 volumes
    volumes = volumes[:max_volumes]

    # Shuffle volumes randomly
    random.shuffle(volumes)

    # Calculate the split index
    split_index = int(len(volumes) * (1 - test_ratio))

    # Split volumes into train and test sets
    train_volumes = volumes[:split_index]
    test_volumes = volumes[split_index:]

    return train_volumes, test_volumes

# Usage example:
source_dir = "/kaggle/input/openbhb/val_quasiraw"
test_ratio = 0.1  # Adjust this to change the test/train split

train_volumes, test_volumes = split_data(source_dir, test_ratio, max_volumes=600)


print("Training set volumes:", len(train_volumes))
print("Testing set volumes:", len(test_volumes))


In [None]:
import os
import numpy as np
import torch

class BrainMRIDataset(Dataset):
    def __init__(self, base_path, volume_list, target_size=(88, 128, 128), slice_depth=16, transform=None):
        self.base_path = base_path
        self.volume_list = volume_list
        self.target_size = target_size
        self.slice_depth = slice_depth
        self.transform = transform

    def __len__(self):
        return len(self.volume_list)

    def crop(self, volume, start_y=20, end_y=160, start_x=20, end_x=196, start_z=50, end_z=130):
        cropped_volume = volume[start_z:end_z, start_y:end_y, start_x:end_x]
        return cropped_volume

    def get_slices(self, mri_volume):
        if len(mri_volume.shape) != 3:
            mri_volume = mri_volume.squeeze()

        patches = []
        num_slices = mri_volume.shape[0] // self.slice_depth

        for i in range(num_slices):
            start = i * self.slice_depth
            end = start + self.slice_depth
            patch = mri_volume[start:end, :, :]
            patches.append(patch)

        remainder = mri_volume.shape[0] % self.slice_depth
        if remainder > 0:
            patch = mri_volume[-self.slice_depth:, :, :]
            patches.append(patch)

        return patches

    def __getitem__(self, index):
        # Get the filename for the current volume
        volume_file = self.volume_list[index]

        # Construct the file path
        file_path = os.path.join(self.base_path, volume_file)

        # Load the .npy file
        mri_volume = np.load(file_path)  # Shape: (1, 1, 182, 218, 182)

        # Convert to tensor and remove redundant dimensions
        mri_volume = torch.tensor(mri_volume).float().squeeze().squeeze()  # Shape: [182, 218, 182]

        # (C, H, W, D ) - > (C, D, H, W)
        mri_volume = mri_volume.permute(2, 0, 1)

        # Crop to (80, 140, 176)
        mri_crop = self.crop(mri_volume)

        # Resize to target shape (128x128x128)
        mri_resize = F.interpolate(
            mri_crop.unsqueeze(0).unsqueeze(0),
            size=(mri_crop.shape[0], 128, 128),
            mode='trilinear',
            align_corners=False
        ).squeeze()

        # Normalize the volume
        mri_volume = (mri_resize - mri_resize.min()) / (mri_resize.max() - mri_resize.min() + 1e-8)

        # Slice the volume along the depth axis
        mri_slices = self.get_slices(mri_volume)

        # Add channel dimension back to each slice
        mri_slices = [slice.unsqueeze(0) for slice in mri_slices]  # Shape: [1, 128, 128, 8]

        return mri_slices


In [None]:
# Initialize dataset
dataset = BrainMRIDataset(base_path='/kaggle/input/openbhb/val_quasiraw', volume_list=train_volumes)

# Test the dataset
volume_slices = dataset[0]
print(f"Dataset length : {len(dataset)}, Number of slices: {len(volume_slices)}, Slice shape: {volume_slices[0].shape}")

In [None]:
class SliceDatasetFromList(Dataset):
    def __init__(self, patch_list):
        self.patch_list = [patch for sublist in patch_list for patch in sublist]

    def __len__(self):
        return len(self.patch_list)

    def __getitem__(self, index):
        patch = self.patch_list[index]
        patch_tensor = torch.tensor(patch).float()

        return patch_tensor, patch_tensor


Slices = SliceDatasetFromList(dataset)

train_ratio = 0.85
val_ratio = 0.15

train_size = int(train_ratio * len(Slices))
val_size = len(Slices) - train_size

train_dataset, val_dataset = random_split(Slices, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

print("Train set size:", len(train_dataset))
print("Validation set size:", len(val_dataset))


> # **Training**

In [None]:
class SSIM_MSE_Loss(torch.nn.Module):
    def __init__(self):
        """
        Initialize the combined SSIM + MSE loss without weighting.
        """
        super(SSIM_MSE_Loss, self).__init__()

    def forward(self, predicted, target):
        # Calculate SSIM loss
        ssim_loss = 1 - structural_similarity_index_measure(predicted, target, data_range=1.0)

        # Calculate MSE loss
        mse_loss = F.mse_loss(predicted, target)

        # Combined loss (simple addition)
        combined_loss = ssim_loss + mse_loss
        return combined_loss

# Usage:
criterion = SSIM_MSE_Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)


In [None]:
criterion = SSIM_MSE_Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

In [None]:
print(device)

In [None]:
def save_checkpoint(state, is_best, checkpoint_dir="/kaggle/working/checkpoint_attUNET", filename="checkpoint_MSE_SSIM.pth"):
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, filename)
    torch.save(state, checkpoint_path)
    if is_best:
        best_model_path = os.path.join(checkpoint_dir, "best_model_MSE_SSIM.pth")
        torch.save(state, best_model_path)
        print(f"Best model saved to {best_model_path}")


In [None]:
import torch
import monai
from monai.losses import SSIMLoss


train_loss_values = []
val_loss_values = []

num_epochs = 50
best_val_loss = float('inf')
patience = 10
early_stop_counter = 0

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0

    # Training loop
    for imgs, targets in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} (Training)'):
        imgs, targets = imgs.to(device), targets.to(device)
        optimizer.zero_grad()  # Clear gradients
        outputs = model(imgs)  # Forward pass
        loss = criterion(outputs, targets)  # Compute loss
        loss.backward()  # Backward pass
        optimizer.step()  # Update weights

        train_loss += loss.item()

    # Average training loss for this epoch
    train_loss /= len(train_loader)
    train_loss_values.append(train_loss)

    # Validation loop
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for imgs, targets in tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} (Validation)'):
            imgs, targets = imgs.to(device), targets.to(device)
            outputs = model(imgs)  # Forward pass
            loss = criterion(outputs, targets)  # Compute validation loss
            val_loss += loss.item()  # Accumulate validation loss

    # Average validation loss for this epoch
    val_loss /= len(val_loader)
    val_loss_values.append(val_loss)

    # Print the losses for this epoch
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    is_best = val_loss < best_val_loss
    if is_best:
        best_val_loss = val_loss
        early_stop_counter = 0

        # Save the best model checkpoint
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss
        }
        save_checkpoint(checkpoint, is_best)  # Ensure save_checkpoint function is defined
    else:
        early_stop_counter += 1

    # Early stopping condition
    if early_stop_counter >= patience:
        print(f"Early stopping at epoch {epoch+1}. No improvement in validation loss for {patience} consecutive epochs.")
        break

# Print the best validation loss
print(f"Best Validation Loss: {best_val_loss:.4f}")

# Plotting the loss
plt.plot(train_loss_values, label='Training Loss')
plt.plot(val_loss_values, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss Over Epochs')
plt.legend()
plt.grid()
plt.show()


**Restarting traning from checkpoint**

In [None]:
def load_checkpoint(filepath):
    if os.path.isfile(filepath):
        print(f"Loading checkpoint from '{filepath}'...")
        checkpoint = torch.load(filepath)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        best_val_loss = checkpoint['val_loss']
        train_loss_values = checkpoint.get('train_loss_values', [])
        val_loss_values = checkpoint.get('val_loss_values', [])
        print(f"Resuming from epoch {start_epoch} with best validation loss {best_val_loss:.4f}")
        return start_epoch, best_val_loss, train_loss_values, val_loss_values
    else:
        print("No checkpoint found, starting from scratch.")
        return 0, float('inf'), [], []


# Load the checkpoint (if it exists)
checkpoint_path = '/kaggle/input/checkpoint/checkpoint_MSE_SSIM.pth'
start_epoch, best_val_loss, train_loss_values, val_loss_values = load_checkpoint(checkpoint_path)

# Training parameters
num_epochs = 50
patience = 10
early_stop_counter = 0

for epoch in range(start_epoch, num_epochs):
    model.train()
    train_loss = 0.0

    # Training loop
    for imgs, targets in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} (Training)'):
        imgs, targets = imgs.to(device), targets.to(device)
        optimizer.zero_grad()  # Clear gradients
        outputs = model(imgs)  # Forward pass
        loss = criterion(outputs, targets)  # Compute loss
        loss.backward()  # Backward pass
        optimizer.step()  # Update weights

        train_loss += loss.item()

    # Average training loss for this epoch
    train_loss /= len(train_loader)
    train_loss_values.append(train_loss)

    # Validation loop
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for imgs, targets in tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} (Validation)'):
            imgs, targets = imgs.to(device), targets.to(device)
            outputs = model(imgs)  # Forward pass
            loss = criterion(outputs, targets)  # Compute validation loss
            val_loss += loss.item()  # Accumulate validation loss

    # Average validation loss for this epoch
    val_loss /= len(val_loader)
    val_loss_values.append(val_loss)

    # Print the losses for this epoch
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    # Check if the current validation loss is the best we've seen
    is_best = val_loss < best_val_loss
    if is_best:
        best_val_loss = val_loss
        early_stop_counter = 0

        # Save the best model checkpoint
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'train_loss_values': train_loss_values,
            'val_loss_values': val_loss_values
        }
        save_checkpoint(checkpoint, is_best)  # Save checkpoint with current best model
    else:
        early_stop_counter += 1

    # Early stopping condition
    if early_stop_counter >= patience:
        print(f"Early stopping at epoch {epoch+1}. No improvement in validation loss for {patience} consecutive epochs.")
        break

# Print the best validation loss
print(f"Best Validation Loss: {best_val_loss:.4f}")

# Plotting the loss
plt.plot(train_loss_values, label='Training Loss')
plt.plot(val_loss_values, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss Over Epochs')
plt.legend()
plt.grid()
plt.show()


In [None]:
torch.save(model.state_dict(), '3d_AttUnet_slices_spatial_SSIM_MSE.pth')

In [None]:
dataset_test = BrainMRIDataset(base_path='/kaggle/input/openbhb/val_quasiraw', volume_list=test_volumes)

Slices_test = SliceDatasetFromList(dataset_test)

test_loader = DataLoader(Slices_test, batch_size=8)


In [None]:
import matplotlib.pyplot as plt

# Test loop
test_loss = 0.0

model.eval()

with torch.no_grad():  # Disable gradient calculations for testing
    for imgs, targets in tqdm(test_loader, desc='Testing'):
        imgs, targets = imgs.to(device), targets.to(device)  # Move data to GPU
        outputs = model(imgs)  # Forward pass
        loss = criterion(outputs, targets)  # Compute test loss
        test_loss += loss.item()  # Accumulate test loss

        # Visualize the input and reconstructed volume at depth 8
        depth = 8
        fig, ax = plt.subplots(1, 2, figsize=(12, 6))

        # Original Volume (Target)
        ax[0].imshow(targets[0, 0, depth, :, :].cpu().numpy(), cmap='gray')
        ax[0].set_title('Original Volume (Target)')
        ax[0].axis('off')

        # Reconstructed Volume (Output)
        ax[1].imshow(outputs[0, 0, depth, :, :].cpu().numpy(), cmap='gray')
        ax[1].set_title('Reconstructed Volume (Output)')
        ax[1].axis('off')

        plt.show()  # Display the plot

# Average test loss
test_loss /= len(test_loader)
print(f"Average Test Loss: {test_loss:.4f}")


In [None]:
import nibabel as nib

# Define the root directory where all subject subdirectories are stored
data_dir = '/kaggle/input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/'

# Define functions
def load_nifti_file(file_path):
    return nib.load(file_path).get_fdata()

def crop(mri_volume, crop_size=(160, 130, 170)):
    d, h, w = mri_volume.shape
    new_d, new_h, new_w = crop_size
    start_d = (d - new_d) // 2
    start_h = (h - new_h) // 2
    start_w = (w - new_w) // 2
    return mri_volume[:, start_h:start_h+new_h, start_w:start_w+new_w]

def resize_volume(mri_volume):
    mri_tensor = torch.tensor(mri_volume, dtype=torch.float32)
    mri_resized = F.interpolate(
        mri_tensor.unsqueeze(0).unsqueeze(0),  # Add batch and channel dimensions
        size=(mri_tensor.shape[0], 128, 128),  # Only resize height and width
        mode='trilinear',
        align_corners=False
    )
    return mri_resized

def get_slices(mri_volume):
    slices = []
    for i in range(8):
        start = i * 16
        end = start + 16
        slice_chunk = mri_volume[:, :, start:end, :, :]
        slices.append(slice_chunk)
    return slices

# Metric function for residual map evaluation
def calculate_residuals(true, predicted):
    return torch.abs(true - predicted)

# Dice score calculation function
def dice_score(true, predicted, threshold=0.5):
    true = true.cpu().numpy().flatten()
    predicted = (predicted.cpu().numpy() > threshold).astype(int).flatten()
    intersection = (true * predicted).sum()
    return (2. * intersection) / (true.sum() + predicted.sum())

# Loop through all subject subdirectories
subjects = [os.path.join(data_dir, subj) for subj in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, subj))]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

subject_count = 10
curr_count = 0

for subject_dir in subjects:
    if curr_count < subject_count:
        # Construct paths for the T2-weighted MRI and segmentation files
        t2_path = os.path.join(subject_dir, [f for f in os.listdir(subject_dir) if '_t2.nii' in f][0])
        seg_path = t2_path.replace('_t2.nii', '_seg.nii')

        # Load the MRI and segmentation volumes
        t2_vol = load_nifti_file(t2_path)
        seg_vol = load_nifti_file(seg_path)
        


        # Preprocess the T2 MRI volume
        transpose_vol = np.transpose(t2_vol, (2, 0, 1))  # Reorient the volume
        cropped_vol = crop(transpose_vol)  # Crop
        resized_vol = resize_volume(cropped_vol)  # Resize

        # Preprocess the segmentation volume in the same way
        transpose_seg_vol = np.transpose(seg_vol, (2, 0, 1))  # Reorient the segmentation
        cropped_seg_vol = crop(transpose_seg_vol)  # Crop
        resized_seg_vol = resize_volume(cropped_seg_vol)  # Resize the segmentation to match the MRI

        # Split the volume into slices (both MRI and segmentation)
        slices_vol = get_slices(resized_vol)
        slices_seg = get_slices(resized_seg_vol)

        # Perform inference using the model
        with torch.no_grad():
            tensor_vol_slice = slices_vol[4].to(device)  # Taking the 5th chunk as an example (index 4)
            output = model(tensor_vol_slice)

            # Calculate the residuals
            residuals = calculate_residuals(tensor_vol_slice, output)

            # Extract the corresponding segmentation slice
            tensor_seg_slice = slices_seg[4].to(device)

            # Calculate Dice score as a metric
            dice = dice_score(tensor_seg_slice, residuals)

        # Plot example slices
        depth = 8  # Slice depth to visualize
        fig, ax = plt.subplots(1, 4, figsize=(24, 6))

        # Plot the original MRI volume slice
        ax[0].imshow(tensor_vol_slice[0, 0, depth, :, :].cpu().numpy(), cmap='gray')
        ax[0].set_title('Original Volume')

        # Plot the reconstructed volume slice
        ax[1].imshow(output[0, 0, depth, :, :].cpu().numpy(), cmap='gray')
        ax[1].set_title('Reconstructed Volume')

        # Plot the residual map slice
        ax[2].imshow(residuals[0, 0, depth, :, :].cpu().numpy(), cmap='hot')
        ax[2].set_title(f'Residuals (Dice: {dice:.4f})')

        # Plot the corresponding segmentation map slice
        ax[3].imshow(tensor_seg_slice[0, 0, depth, :, :].cpu().numpy(), cmap='gray')
        ax[3].set_title('Segmentation Map')

        plt.show()

        print(f"Processed {subject_dir}: Dice Score: {dice:.4f}")
    
    curr_count += 1