In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import glob
import cv2
import torch.nn.functional as F
import os
import time
from sys import maxsize
import SimpleITK as sitk
import wandb
import optuna
import nibabel as nib
from torchmetrics import JaccardIndex
from torchvision.transforms import CenterCrop
import wandb
# import torch.multiprocessing as mp
# import torch.distributed as dist
# from torch.nn.parallel import DistributedDataParallel as DDP
# from torch.utils.data.distributed import DistributedSampler

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mviraj-patil0911[0m ([33mbrats_2024_3d_unet[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [43]:
# prompt: cuda or cpu check

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [44]:
base_dir = '/data/mpstme-priyanka/training_data1_v2'

In [5]:
def center_crop(tensor, target_dims):
    
        current_dims= tensor.shape
        start_indices = [(curr_dim - target_dim) // 2 for curr_dim, target_dim in zip(current_dims, target_dims)]
        end_indices = [start + target_dim for start, target_dim in zip(start_indices, target_dims)]

        cropped_data = tensor[
            start_indices[0]:end_indices[0],
            start_indices[1]:end_indices[1],
            start_indices[2]:end_indices[2],

        ]

        return cropped_data

In [6]:
new_image_filenames = []
new_label_filenames = []

total_samples = 0
break_flag = 0

for subfolder in os.listdir(base_dir):
    subfolder_path = os.path.join(base_dir, subfolder)
    
    if os.path.isdir(subfolder_path):
        # Find image files with '-t2f.nii.gz'
        image_files = glob.glob(subfolder_path + '/*-t2f.nii.gz')
        # Find corresponding label files with '-seg.nii.gz'
        label_files = glob.glob(subfolder_path + '/*-seg.nii.gz')

        for image_file, label_file in zip(image_files, label_files):

            new_image_filenames.append(image_file)
            new_label_filenames.append(label_file)
        
            total_samples += 1
            print(f'total samples {total_samples}')
            
#             if total_samples == 10:
#                 break_flag = 1
                
            if break_flag:
                break
                
        if break_flag:
            break

print(f"Filtered {len(new_image_filenames)} images and {len(new_label_filenames)} labels.")

total samples 1
total samples 2
total samples 3
total samples 4
total samples 5
total samples 6
total samples 7
total samples 8
total samples 9
total samples 10
total samples 11
total samples 12
total samples 13
total samples 14
total samples 15
total samples 16
total samples 17
total samples 18
total samples 19
total samples 20
total samples 21
total samples 22
total samples 23
total samples 24
total samples 25
total samples 26
total samples 27
total samples 28
total samples 29
total samples 30
total samples 31
total samples 32
total samples 33
total samples 34
total samples 35
total samples 36
total samples 37
total samples 38
total samples 39
total samples 40
total samples 41
total samples 42
total samples 43
total samples 44
total samples 45
total samples 46
total samples 47
total samples 48
total samples 49
total samples 50
total samples 51
total samples 52
total samples 53
total samples 54
total samples 55
total samples 56
total samples 57
total samples 58
total samples 59
total 

In [7]:
class BraTSDatasetII(Dataset):

    def __init__(self, new_image_filenames,new_label_filenames):
        self.image_names = new_image_filenames
        self.label_mask_names = new_label_filenames


    def __len__(self):
        return len(self.image_names)
    
    
    def convert_nonzero_to_one(self, mask_tensor):
        for i in range(mask_tensor.shape[0]):
            for j in range(mask_tensor.shape[1]):
                for k in range(mask_tensor.shape[2]):
                    if int(mask_tensor[i, j, k]) > 1:
                        mask_tensor[i, j, k] = 1.0
                        
        return mask_tensor
        
        
    def __getitem__(self, idx):
        # Reading NIfTI image file
        image = sitk.ReadImage(self.image_names[idx])

        # Reading NIfTI label file
        label_mask = sitk.ReadImage(self.label_mask_names[idx])

        # Desired output size
        desired_size = (128, 128, 128)  # Target shape

        # Get original size
        original_size = image.GetSize()  # Original size in voxels (x, y, z)

        # Calculate the new spacing based on desired size
        original_spacing = image.GetSpacing()  # Current spacing in mm
        new_spacing = [
            original_spacing[i] * (original_size[i] / desired_size[i]) for i in range(3)
        ]

        # Resize the image to the desired size
        resampled_img = sitk.Resample(
            image,
            desired_size,
            sitk.Transform(),
            sitk.sitkLinear,  # Linear interpolation for images
            image.GetOrigin(),
            new_spacing,
            image.GetDirection(),
            0,
            image.GetPixelID()
        )

        # Resize the mask to the desired size
        resampled_mask = sitk.Resample(
            label_mask,
            desired_size,
            sitk.Transform(),
            sitk.sitkNearestNeighbor,  # Nearest-neighbor for label masks
            label_mask.GetOrigin(),
            new_spacing,
            label_mask.GetDirection(),
            0,
            label_mask.GetPixelID()
        )

        # Convert to NumPy arrays
        image_data = sitk.GetArrayFromImage(resampled_img)
        label_mask_data = sitk.GetArrayFromImage(resampled_mask)

        # Convert NumPy arrays to torch tensors
        image_tensor = torch.from_numpy(image_data).unsqueeze(0).float()  # Add channel dimension
        label_tensor = torch.from_numpy(label_mask_data.astype(np.float32))

        # Convert all non-zero values in the mask to 1
        label_tensor = self.convert_nonzero_to_one(label_tensor)

        # Normalize the image tensor
        image_tensor = (image_tensor - torch.mean(image_tensor)) / torch.std(image_tensor)

        return image_tensor, label_tensor

In [8]:
dataset = BraTSDatasetII(new_image_filenames,new_label_filenames)

In [9]:
train_ratio = 0.9
train_split = int(train_ratio * len(dataset))
test_split = int(len(dataset) - train_split)

train_dataset, test_dataset = random_split(dataset, [train_split, test_split])

In [10]:
batch_size = 4
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory = True)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True, pin_memory = True)

In [11]:
wandb.init(
    project="brats 3d unet"

)

In [12]:
# !nvidia-smi

In [13]:
class UNet3DD(nn.Module):
    def __init__(self, verbose):
        super(UNet3DD, self).__init__()
        
        self.verbose = verbose
        
        def conv_block(in_channels, out_channels):
            return nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm3d(out_channels),
                nn.LeakyReLU(negative_slope=0.01),
                nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm3d(out_channels),
                nn.LeakyReLU(negative_slope=0.01)
            )
        
        def up_conv_block(in_channels, out_channels):
            return nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=3, padding=1)
        
        # Downward path (encoding)
        self.conv_downwards = nn.ModuleList([
            conv_block(1, 64),  # Conv block 1
            nn.MaxPool3d(kernel_size=2, stride=2),  # MaxPool after block 1
            conv_block(64, 128),  # Conv block 2
            nn.MaxPool3d(kernel_size=2, stride=2),  # MaxPool after block 2
            conv_block(128, 256),  # Conv block 3
            nn.MaxPool3d(kernel_size=2, stride=2),  # MaxPool after block 3
            conv_block(256, 512),  # Conv block 4
            nn.MaxPool3d(kernel_size=2, stride=2)   # MaxPool after block 4
        ])
        
        # Bottleneck layer
        self.bottleneck_layer = nn.Sequential(
            nn.Conv3d(512, 1024, kernel_size=3, padding=1),
            nn.BatchNorm3d(1024),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Conv3d(1024, 1024, kernel_size=3, padding=1),
            nn.BatchNorm3d(1024),
            nn.LeakyReLU(negative_slope=0.01)
        )
        
        # Upward path (decoding)
        self.conv_upwards = nn.ModuleList([
            up_conv_block(1024, 512),  # UpConv block 1
            conv_block(1024, 512),  # Conv block after UpConv 1
            up_conv_block(512, 256),  # UpConv block 2
            conv_block(512, 256),  # Conv block after UpConv 2
            up_conv_block(256, 128),  # UpConv block 3
            conv_block(256, 128),  # Conv block after UpConv 3
            up_conv_block(128, 64),  # UpConv block 4
            conv_block(128, 64)  # Conv block after UpConv 4
        ])
        
        # Final 1x1 Conv layer
        self.conv_1x1 = nn.Conv3d(64, 1, kernel_size=1, padding=1)
        
        # Sigmoid
        self.sigmoid = nn.Sigmoid()
        
        
        
    def center_crop(self, tensor, target_dims):

        current_dims= tensor.shape
        start_indices = [(curr_dim - target_dim) // 2 for curr_dim, target_dim in zip(current_dims, target_dims)]
        end_indices = [start + target_dim for start, target_dim in zip(start_indices, target_dims)]

        cropped_data = tensor[
            :,
            :, 
            start_indices[2]:end_indices[2],
            start_indices[3]:end_indices[3],
            start_indices[4]:end_indices[4],

        ]

        return cropped_data

    def forward(self, x):
        
        original_x_shape = x.shape
        # Downward path
        skip_connections = []
        for layer in self.conv_downwards:
            if isinstance(layer, nn.MaxPool3d):
                skip_connections.append(x)  # Save skip connection before MaxPool
            
            if self.verbose:
                print(f'applying {layer} => {x.shape}\n\n')
            x = layer(x)

        # Bottleneck
        x = self.bottleneck_layer(x)
        
        if self.verbose:
            print(f'applying bottleneck layer {self.bottleneck_layer} => {x.shape}\n\n')
        
        # Upward path
        for i in range(0, len(self.conv_upwards), 2):
            x = self.conv_upwards[i](x)  # UpConvolution
            if self.verbose:
                print(f'applying {self.conv_upwards[i]} => {x.shape}')
              
            popped_x = skip_connections.pop()
            
            if self.verbose:
                print(f'CONCATENATING')
                print('popped_x shape is ', popped_x.shape)
                print(f'x shape before concatenating is {x.shape}\n\n')
            x = self.center_crop(x, popped_x.shape)
            x = torch.cat((popped_x, x), dim=1)  # Concatenate with skip connection
            x = self.conv_upwards[i + 1](x)  # Conv block after concatenation
            if self.verbose:
                print(f'applying layer {self.conv_upwards[i + 1]} => {x.shape}\n\n')

        # Final 1x1 Conv
        x = self.conv_1x1(x)
        
        if self.verbose:
            print(f'applying conv 1x1 => {x.shape}\n\n')
        x = self.sigmoid(x)
        
        x = self.center_crop(x, original_x_shape)
        
        
        return x


In [14]:
#### DICE LOSS ####

class DiceLoss(nn.Module):
    
    def __init__(self, smooth=1e-6):
        super(DiceLoss, self).__init__()
        self.smooth = smooth
    
    def dice_coeff(self, pred, target, smooth=1e-6):
        """
        Compute the Dice coefficient for binary classification.

        Parameters:
        - pred: Predicted tensor (probabilities).
        - target: Ground truth tensor (binary).
        - smooth: Smoothing factor to avoid division by zero.

        Returns:
        - Dice coefficient (mean over batch).
        """
        # Flatten tensors
        pred = pred.view(-1)
        target = target.view(-1)

        # Compute intersection and union
        intersection = (pred * target).sum()
        union = pred.sum() + target.sum()

        # Compute Dice coefficient
        dice = (2. * intersection + smooth) / (union + smooth)
        return dice

    def dice_loss(self, pred, target):
        """
        Compute the Dice loss for binary classification.

        Parameters:
        - pred: Predicted tensor (logits or probabilities).
        - target: Ground truth tensor (binary).
        - smooth: Smoothing factor to avoid division by zero.

        Returns:
        - Dice loss (1 - Dice coefficient).
        """
        # Apply sigmoid to get probabilities if not already applied
        pred = torch.sigmoid(pred)

        # Compute Dice coefficient
        dice = self.dice_coeff(pred, target, smooth=self.smooth)

        # Dice loss is 1 - Dice coefficient
        return 1 - dice

    def forward(self, pred, target):
        return self.dice_loss(pred, target)


In [15]:
#### Tversky  LOSS ####

class TverskyLoss(nn.Module):
    
    def __init__(self, alpha=0.7, beta=0.3, smooth=1e-6):
        super(TverskyLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.smooth = smooth
    
    def forward(self, pred, target):
        """
        Compute the Tversky loss for binary classification.

        Parameters:
        - pred: Predicted tensor (logits or probabilities).
        - target: Ground truth tensor (binary).

        Returns:
        - Tversky loss.
        """

        # Flatten tensors
        pred = pred.reshape(-1)
        target = target.reshape(-1)

        # Compute true positives, false positives, and false negatives
        true_positive = (pred * target).sum()
        false_positive = ((1 - target) * pred).sum()
        false_negative = (target * (1 - pred)).sum()

        # Compute Tversky coefficient
        tversky_coeff = (true_positive + self.smooth) / (true_positive + self.alpha * false_positive + self.beta * false_negative + self.smooth)
        
        # Tversky loss is 1 - Tversky coefficient
        return 1 - tversky_coeff



In [42]:
unet3d = UNet3DD(verbose=0)

In [None]:
dice_loss_fn = DiceLoss(smooth=1e-6)
tversky_loss_fn = TverskyLoss(alpha=0.4, beta=0.6, smooth=1e-6)
bce_loss_fn = nn.BCELoss()

In [None]:
# data, masks = next(iter(train_dataloader))
# data, masks = data.to(device), masks.to(device)

In [None]:
# data.shape

In [None]:
# output = unet3d(data)

In [None]:
# unet3d(data)

In [None]:
def objective(trial: optuna.Trial):
    lr = trial.suggest_float("lr", 0.05, 0.1)
    weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-4)
    beta1 = trial.suggest_float("beta1", 0.7, 0.99)
    beta2 = trial.suggest_float("beta2", 0.95, 0.99)
    
    optimizer = torch.optim.Adam(unet3d.parameters(), lr=lr, weight_decay=weight_decay, betas=(beta1, beta2))
    
    data, masks = next(iter(train_dataloader))
    data, masks = data.to(device), masks.to(device)
    
    loss = loss_fn(unet3d(data), masks)
    
    del data
    del masks
    
    return loss

In [None]:
# study = optuna.create_study()
# study.optimize(objective, n_trials=100)

In [None]:
# optimizer = torch.optim.Adam(unet3d.parameters(), lr=study.best_params['lr'], weight_decay=study.best_params['weight_decay'], betas=(study.best_params['beta1'], study.best_params['beta2']))
optimizer = torch.optim.Adam(unet3d.parameters(), lr=3e-4, betas=(0.9, 0.999), weight_decay = 1e-4)

In [None]:
epochs = 5

In [None]:
checkpoint_identifier = f'20240928_' + str(time.strftime("%H:%M"))

# Create a directory to save checkpoints
checkpoint_dir = "/data/mpstme-priyanka/" + checkpoint_identifier + "_tversky_checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(epochs):
    epoch_loss = 0

    for batch_idx, (data, masks) in enumerate(train_dataloader):

        data = data.to(device)
        masks = masks.long().to(device)

        optimizer.zero_grad()
    
        pred = unet3d(data)
        target = masks
        
        train_loss = dice_loss_fn(pred, target) + bce_loss_fn(pred.squeeze(1), target.float()) + tversky_loss_fn(pred, target)

        print(f'Batch {batch_idx + 1} : Loss {train_loss}')
        wandb.log({"train_loss": train_loss.item()})
        
        if batch_idx % 5 == 0:
            
            # Getting test data and masks
            test_data, test_masks = next(iter(test_dataloader))
            test_data, test_masks = test_data.to(device), test_masks.to(device)

            test_pred = unet3d(test_data)
            test_loss = loss_fn(test_pred, test_masks)

            wandb.log({"test_loss": test_loss.item()})

            fig, axes = plt.subplots(2, 2, figsize=(10, 10))

            # Train Prediction
            axes[0, 0].imshow(pred[0][0][64].cpu().detach().numpy(), cmap='gray')
            axes[0, 0].set_title("Train Prediction")
            axes[0, 0].axis('off')

            # Train Target
            axes[0, 1].imshow(target[0][64].cpu().detach().numpy(), cmap='gray')
            axes[0, 1].set_title("Train Target")
            axes[0, 1].axis('off')
            
            axes[1, 0].imshow(test_pred[0][0][64].cpu().detach().numpy(), cmap='gray')
            axes[1, 0].set_title("Test Prediction")
            axes[1, 0].axis('off')

            # Train Target
            axes[1, 1].imshow(test_masks[0][64].cpu().detach().numpy(), cmap='gray')
            axes[1, 1].set_title("Test Target")
            axes[1, 1].axis('off')

            # Adjust layout
            plt.tight_layout()
            wandb.log({"Predictions": wandb.Image(fig)})
            plt.close(fig)

        epoch_loss += train_loss.item()
        
        train_loss.backward()
        optimizer.step()
        
    # Save a checkpoint after each epoch
    checkpoint_path = os.path.join(checkpoint_dir, f'{checkpoint_identifier}_checkpoint_epoch_{epoch + 1}.pth')
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': unet3d.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': epoch_loss,
    }, checkpoint_path)
    
    print(f'EPOCH {epoch + 1} ; LOSS {epoch_loss}\n\n')


In [None]:
!nvidia-smi

In [None]:
# output = unet3d(data)

In [None]:
# mask_file = sitk.GetImageFromArray(target[0].cpu().detach().numpy())
# mask_file = sitk.Cast(mask_file, sitk.sitkUInt8)

# pred_file = sitk.GetImageFromArray(output[0][0].cpu().detach().numpy())

# sitk.WriteImage(mask_file, "masks_20240923.nii.gz")
# sitk.WriteImage(pred_file, "pred_20240923.nii.gz")

In [None]:
# torch.save(unet3d.state_dict(), 'brats_3d_unet_epochs_36_20240928.pth')

## VALIDATION

In [None]:
# val_dir = base_dir = '/data/mpstme-priyanka/BraTS2024-BraTS-GLI-ValidationData/validation_data'

In [None]:
# new_image_filenames = []
# total_samples = 0
# break_flag = 0

# for subfolder in os.listdir(val_dir):
#     subfolder_path = os.path.join(val_dir, subfolder)
    
#     if os.path.isdir(subfolder_path):
#         # Find image files with '-t2f.nii.gz'
#         image_files = glob.glob(subfolder_path + '/*-t2f.nii.gz')

#         for image_file in image_files:  # Iterate directly over image_files
#             image = sitk.ReadImage(image_file)
            
#             # Cast the image to float32
#             image = sitk.Cast(image, sitk.sitkFloat32)
            
#             # Convert to a NumPy array
#             image_data = sitk.GetArrayFromImage(image)
#             image_data_shape = image_data.shape

#             # Optionally, check size conditions if needed
#             # if image_data_shape[0] >= 182 and image_data_shape[1] >= 218 and image_data_shape[2] >= 182:
#             new_image_filenames.append(image_file)
           
#             total_samples += 1
#             print(f'Total samples: {total_samples}')
            
#             if total_samples == 4:
#                 break_flag = 1
                
#             if break_flag:
#                 break
                
#         if break_flag:
#             break

# print(f"Filtered {len(new_image_filenames)} images")


In [None]:
# class BraTSDataset_val(Dataset):

#     def __init__(self, new_image_filenames):
#         self.image_names = new_image_filenames

#     def __len__(self):
#         return len(self.image_names)

#     def __getitem__(self, idx):
#         # Load the NIfTI file
#         image = sitk.ReadImage(self.image_names[idx])
        
#         # Downsample the image
#         new_size = [int(dim / 2) for dim in image.GetSize()]
#         downsampled_img = sitk.Resample(
#             image,
#             new_size,
#             sitk.Transform(),
#             sitk.sitkLinear,
#             image.GetOrigin(),
#             [sz * 2 for sz in image.GetSpacing()],
#             image.GetDirection(),
#             0,
#             image.GetPixelID()
#         )

#         # Convert to a NumPy array
#         image_data = sitk.GetArrayFromImage(downsampled_img)

#         # Reshape data to add channel (1) and convert to torch tensor
#         image_tensor = torch.from_numpy(image_data).unsqueeze(0).float()
        
#         # Normalize the image tensor
#         image_tensor = (image_tensor - torch.mean(image_tensor)) / torch.std(image_tensor)

#         return image_tensor

In [None]:
# val_dataset = BraTSDataset_val(new_image_filenames)

In [None]:
# val_dataloader =  DataLoader(val_dataset, batch_size=batch_size, shuffle=True)


In [None]:
# val_data = next(iter(val_dataloader))
# val_data = val_data.to(device)

In [None]:
# val_out = unet3d(val_data)

In [None]:
# masks.shape

In [None]:
# # mask_file = sitk.GetImageFromArray(target[0].cpu().detach().numpy())
# # mask_file = sitk.Cast(mask_file, sitk.sitkUInt8)

# mask_new_shape = sitk.GetImageFromArray(masks[0].cpu().detach().numpy())

# # sitk.WriteImage(mask_file, "masks_20240923.nii.gz")
# sitk.WriteImage(mask_new_shape, "mask_new_shape_20240928.nii.gz")

In [35]:
!nvidia-smi

Sat Sep 28 11:16:53 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.86.10              Driver Version: 535.86.10    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA H100 PCIe               Off | 00000000:CA:00.0 Off |                   On |
| N/A   29C    P0              52W / 350W |                  N/A |     N/A      Default |
|                                         |                      |              Enabled |
+-----------------------------------------+----------------------+----------------------+

+------------------------------------------------------------------