In [None]:
!pip install SimpleITK
!pip install wandb

Collecting SimpleITK
  Downloading SimpleITK-2.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.9 kB)
Downloading SimpleITK-2.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (52.4 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 52.4/52.4 MB 20.7 MB/s eta 0:00:00
Installing collected packages: SimpleITK
Successfully installed SimpleITK-2.4.0


In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
import cv2
import glob
import pandas as pd
import SimpleITK as sitk
import random
import wandb
import albumentations as A
from albumentations.pytorch import ToTensorV2
import shutil
import re
import os

  check_for_updates()


In [None]:
# Authenticate with Weights & Biases; the cell displays the returned flag (True once logged in).
logged_in = wandb.login()
logged_in

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33maaryadev[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
# Open the W&B run that later wandb.log(...) calls write losses and figures to.
wandb_run = wandb.init(project="brats 3d tedunet")
wandb_run


In [None]:

# Mount Google Drive so the subject folders under MyDrive can be globbed below.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# input_tensor = torch.randn(1, 1, 91, 109, 91)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
def center_crop(tensor, target_dims):
    """Centre-crop the first three axes of `tensor` to the sizes in `target_dims`.

    Each axis keeps `target_dims[axis]` elements starting at
    `(tensor.shape[axis] - target_dims[axis]) // 2`; any axes after the third are
    left untouched. Returns a slice (view) of the input.
    """
    offsets = [(have - want) // 2 for have, want in zip(tensor.shape, target_dims)]
    window = tuple(slice(offsets[axis], offsets[axis] + target_dims[axis]) for axis in range(3))
    return tensor[window]

In [None]:
class TEDUNet(nn.Module):
    """3-D U-Net style encoder/decoder that outputs a one-channel probability map.

    The encoder alternates conv blocks with 2x2x2 max-pools of stride 1 and 2. The
    decoder mirrors every pool with a transposed convolution, centre-crops the
    upsampled features to the matching skip connection and concatenates them on
    the channel axis. A 1x1x1 conv and a sigmoid produce the output, which is
    centre-cropped to the spatial size of the input.

    Args:
        verbose: if truthy, print the tensor shape after every layer in ``forward``.
        base_channels: width of the first encoder block; the deeper blocks use
            2x, 4x, 8x and 16x this value. The default 64 gives the original widths.
    """

    def __init__(self, verbose, base_channels=64):
        super(TEDUNet, self).__init__()

        self.verbose = verbose
        c = base_channels

        def conv_block(in_channels, out_channels):
            # Two 3x3x3 convs (padding 1 keeps the spatial size), each followed by
            # BatchNorm and LeakyReLU.
            return nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm3d(out_channels),
                nn.LeakyReLU(),
                nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm3d(out_channels),
                nn.LeakyReLU()
            )

        def up_conv_block(in_channels, out_channels, kernel_size, stride):
            # kernel_size >= stride so every output voxel is reached by the kernel.
            # The former kernel_size=2 / stride=3 never reached one position in three
            # per axis, leaving ~70% of the voxels of a 3-D output as bias only.
            return nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride)

        # Downward path (encoding)
        self.conv_downwards = nn.ModuleList([
            conv_block(1, c),
            nn.MaxPool3d(kernel_size=2, stride=1),
            conv_block(c, 2 * c),
            nn.MaxPool3d(kernel_size=2, stride=2),
            conv_block(2 * c, 4 * c),
            nn.MaxPool3d(kernel_size=2, stride=1),
            conv_block(4 * c, 8 * c),
            nn.MaxPool3d(kernel_size=2, stride=2),
        ])

        # Bottleneck layer
        self.bottleneck_layer = nn.Sequential(
            nn.Conv3d(8 * c, 16 * c, kernel_size=3, padding=1),
            nn.BatchNorm3d(16 * c),
            nn.LeakyReLU(),
            nn.Conv3d(16 * c, 16 * c, kernel_size=3, padding=1),
            nn.BatchNorm3d(16 * c),
            nn.LeakyReLU()
        )

        # Upward path (decoding). Each transposed conv inverts the pool it faces:
        #   stride-2 pool: n -> n // 2;  kernel 3 / stride 2: m -> 2m + 1 (>= n)
        #   stride-1 pool: n -> n - 1;   kernel 2 / stride 1: m -> m + 1 (== n)
        # so the upsampled map is never smaller than its skip connection and the
        # centre crop in forward() trims at most one voxel per axis.
        self.conv_upwards = nn.ModuleList([
            up_conv_block(16 * c, 8 * c, kernel_size=3, stride=2),
            conv_block(16 * c, 8 * c),
            up_conv_block(8 * c, 4 * c, kernel_size=2, stride=1),
            conv_block(8 * c, 4 * c),
            up_conv_block(4 * c, 2 * c, kernel_size=3, stride=2),
            conv_block(4 * c, 2 * c),
            up_conv_block(2 * c, c, kernel_size=2, stride=1),
            conv_block(2 * c, c),
        ])

        # Final 1x1x1 conv; a 1x1x1 kernel keeps the spatial size, so no padding.
        self.conv_1x1 = nn.Conv3d(c, 1, kernel_size=1)

        # Sigmoid maps the single output channel to per-voxel probabilities.
        self.sigmoid = nn.Sigmoid()

    def center_crop(self, tensor, target_dims):
        """Centre-crop axes 2-4 of a 5-D ``tensor`` to ``target_dims[2:5]``.

        The first two axes are kept whole; ``target_dims`` is typically another
        tensor's ``.shape``.
        """
        current_dims = tensor.shape
        start_indices = [(curr_dim - target_dim) // 2 for curr_dim, target_dim in zip(current_dims, target_dims)]
        end_indices = [start + target_dim for start, target_dim in zip(start_indices, target_dims)]

        return tensor[
            :,
            :,
            start_indices[2]:end_indices[2],
            start_indices[3]:end_indices[3],
            start_indices[4]:end_indices[4],
        ]

    def forward(self, x):
        """Return per-voxel probabilities with the same spatial size as ``x``.

        ``x`` is expected to be 5-D (batch, 1, D, H, W): the first conv block takes
        a single input channel.
        """
        original_x_shape = x.shape

        # Downward path: each block's output is kept as a skip connection before pooling.
        skip_connections = []
        for layer in self.conv_downwards:
            if isinstance(layer, nn.MaxPool3d):
                skip_connections.append(x)
            x = layer(x)
            if self.verbose:
                print(f'applying {layer} => {x.shape}\n\n')

        # Bottleneck
        x = self.bottleneck_layer(x)
        if self.verbose:
            print(f'applying bottleneck layer {self.bottleneck_layer} => {x.shape}\n\n')

        # Upward path: (transposed conv, conv block) pairs.
        for i in range(0, len(self.conv_upwards), 2):
            x = self.conv_upwards[i](x)
            if self.verbose:
                print(f'applying {self.conv_upwards[i]} => {x.shape}')

            popped_x = skip_connections.pop()
            if self.verbose:
                print('CONCATENATING')
                print('popped_x shape is ', popped_x.shape)
                print(f'x shape before concatenating is {x.shape}\n\n')
            # Trim the upsampled map to the skip's spatial size, then join on channels.
            x = self.center_crop(x, popped_x.shape)
            x = torch.cat((popped_x, x), dim=1)
            x = self.conv_upwards[i + 1](x)
            if self.verbose:
                print(f'applying layer {self.conv_upwards[i + 1]} => {x.shape}\n\n')

        # Final 1x1x1 conv
        x = self.conv_1x1(x)
        if self.verbose:
            print(f'applying conv 1x1 => {x.shape}\n\n')
        x = self.sigmoid(x)

        # With the layer settings above this is already the input's spatial size;
        # the crop is kept as a safety net.
        return self.center_crop(x, original_x_shape)


## Convert Samples to Sliced Spaces (26×26×26 patches)

In [None]:
# One folder per subject; each is expected to hold a segmentation map and a T2F image.
SUBJECTS_GLOB = '/content/drive/MyDrive/learning machine learning/T2F + Segmap Samples/*'
# glob() returns entries in arbitrary order; sorting makes the processing order
# (and therefore which patches are written before the sample cap) reproducible.
subject_folder_names = sorted(glob.glob(SUBJECTS_GLOB))
sliced_spaces_save_folder_name = "Space Slice Dataset"

# Uncomment to delete previously written patches before regenerating them.
# shutil.rmtree('./' + sliced_spaces_save_folder_name)

os.makedirs(sliced_spaces_save_folder_name, exist_ok=True)

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Define transformations.
# Albumentations treats each 3-D patch as an H x W image whose third axis is the
# channel axis, so the flips, rotation and elastic warp act on the first two axes.
# RandomBrightnessContrast was removed: it is an affine intensity change
# (alpha * x + beta), which the dataset's per-patch z-score normalisation cancels
# exactly. It also depends on per-dtype maximum values. NOTE(review): on float input
# albumentations may clip to [0, 1], which would wipe out MRI intensities — check
# that before re-adding it.
transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ElasticTransform(p=0.1),
])


def augment(image, mask, transform=transform):
    """Apply one random spatial augmentation jointly to an image patch and its mask.

    Args:
        image: image patch (any numeric dtype); converted to float32 before augmenting.
        mask: segmentation patch with small non-negative integer labels (0/1 after
            binarisation); converted to uint8 before augmenting.
        transform: an albumentations ``Compose`` called as
            ``transform(image=..., mask=...)``.

    Returns:
        ``(augmented_image, augmented_mask)`` as returned by the transform.
    """
    # Patch writing stopped with ``KeyError: dtype('int16')`` on int16 volumes;
    # float32 images and uint8 masks are dtypes albumentations handles.
    augmented = transform(
        image=np.asarray(image, dtype=np.float32),
        mask=np.asarray(mask, dtype=np.uint8),
    )
    return augmented["image"], augmented["mask"]


In [None]:
SAMPLES_TO_WRITE = 15000        # stop after this many (image, segmap) patch pairs
PATCH_SIZE = (26, 26, 26)       # patch edge length along array axes 0, 1, 2
MIN_FOREGROUND_VOXELS = 10      # keep a patch only if its mask sums to more than this


def binarize_segmap(segmap_array):
    """Return a copy of ``segmap_array`` with every label greater than 1 set to 1.

    Collapses the tumour sub-labels into one foreground class, which turns the task
    into single-class segmentation. Vectorised replacement for a voxel-by-voxel
    Python loop. Assumes integer labels.
    """
    return np.where(segmap_array > 1, 1, segmap_array).astype(segmap_array.dtype, copy=False)


def iter_patch_slices(shape, patch_size=PATCH_SIZE):
    """Yield ``(it0, it1, it2, index)`` for each full patch strictly inside ``shape``.

    ``it0`` counts patches along axis 1 (outermost loop), ``it1`` along axis 2 and
    ``it2`` along axis 0 (innermost). All three start at 1 and are used in the saved
    file names. ``index`` is a tuple of slices selecting the patch. A patch whose end
    falls exactly on the array edge is not produced (strict ``<``); this is the same
    convention ``slice_spaces`` and ``reconstruct_original_array`` use.
    """
    p0, p1, p2 = patch_size
    for it0, end1 in enumerate(range(p1, shape[1], p1), start=1):
        for it1, end2 in enumerate(range(p2, shape[2], p2), start=1):
            for it2, end0 in enumerate(range(p0, shape[0], p0), start=1):
                yield it0, it1, it2, (slice(end0 - p0, end0),
                                      slice(end1 - p1, end1),
                                      slice(end2 - p2, end2))


def write_patch_dataset(subject_dirs, out_dir, max_samples=SAMPLES_TO_WRITE):
    """Cut tumour-containing patches from every subject, augment them and save them.

    For each subject folder:
      1. Read the image and segmentation map with SimpleITK and binarise the map.
      2. Augment every patch whose mask sums to more than ``MIN_FOREGROUND_VOXELS``.
      3. Write the pair to ``out_dir`` as
         ``<subject><it0>_<it1>_<it2>_data.nii.gz`` / ``..._seg.nii.gz``.

    Stops as soon as ``max_samples`` pairs have been written.

    Returns:
        The number of pairs written.
    """
    written = 0
    for subject_idx, subject in enumerate(subject_dirs):
        print(f'\n\nReading subject {subject_idx}')

        # Assumes sorted order puts the segmentation map first and the image second
        # (e.g. '...-seg.nii.gz' before '...-t2f.nii.gz') — TODO confirm per subject.
        file_paths = sorted(glob.glob(subject + '/*'))
        segmap_path, image_path = file_paths[0], file_paths[1]
        sample_name = os.path.basename(subject)

        image_array = sitk.GetArrayFromImage(sitk.ReadImage(image_path))
        segmap_array = binarize_segmap(sitk.GetArrayFromImage(sitk.ReadImage(segmap_path)))
        if image_array.shape != segmap_array.shape:
            print(f'skipping {sample_name}: image {image_array.shape} vs segmap {segmap_array.shape}')
            continue

        written_for_subject = 0
        for it0, it1, it2, index in iter_patch_slices(image_array.shape):
            segmap_patch = segmap_array[index]
            if segmap_patch.sum() <= MIN_FOREGROUND_VOXELS:
                continue  # (almost) no tumour in this patch

            image_patch, segmap_patch = augment(image_array[index], segmap_patch)
            # File-name format must stay as is: SlicedSpacesDataset pairs files by it.
            stem = os.path.join(out_dir, f'{sample_name}{it0}_{it1}_{it2}')
            sitk.WriteImage(sitk.GetImageFromArray(image_patch), f'{stem}_data.nii.gz')
            sitk.WriteImage(sitk.GetImageFromArray(segmap_patch), f'{stem}_seg.nii.gz')
            written += 1
            written_for_subject += 1

            if written >= max_samples:
                print(f'{written} sample/s written (limit reached)')
                return written

        print(f'{written_for_subject} sample/s written for {sample_name}; {written} in total')
    return written


samples_written = write_patch_dataset(subject_folder_names, sliced_spaces_save_folder_name)

Streaming output truncated to the last 5000 lines.
5594 sample/s written
5595 sample/s written
5596 sample/s written
5597 sample/s written
5598 sample/s written
5599 sample/s written
5600 sample/s written
5601 sample/s written
5602 sample/s written
5603 sample/s written
5604 sample/s written
5605 sample/s written
5606 sample/s written
5607 sample/s written
5608 sample/s written
5609 sample/s written
5610 sample/s written
5611 sample/s written
5612 sample/s written
5613 sample/s written
5614 sample/s written
5615 sample/s written
5616 sample/s written
5617 sample/s written
5618 sample/s written
5619 sample/s written


Reading subject 229
5620 sample/s written
5621 sample/s written
5622 sample/s written
5623 sample/s written
5624 sample/s written
5625 sample/s written
5626 sample/s written
5627 sample/s written
5628 sample/s written
5629 sample/s written
5630 sample/s written
5631 sample/s written
5632 sample/s written
5633 sample/s written
5634 sample/s written
5635 sample

KeyError: dtype('int16')

## Create Dataset

In [None]:
# create Dataset for new sliced spaces

class SlicedSpacesDataset(torch.utils.data.Dataset):
  """(image patch, segmentation-edge patch) pairs read from ``dirname``.

  Each item comes from a ``<stem>_data.nii.gz`` / ``<stem>_seg.nii.gz`` pair:
    * image: z-score normalised per patch, float32 tensor with a leading channel axis;
    * target: Canny edges of every axis-0 slice of the segmentation patch, scaled to
      {0, 1}, float32 tensor with a leading channel axis.
  """

  DATA_SUFFIX = '_data.nii.gz'
  SEG_SUFFIX = '_seg.nii.gz'

  def __init__(self, dirname=sliced_spaces_save_folder_name):
    self.data_filenames = sorted(glob.glob(f'./{dirname}/*{self.DATA_SUFFIX}'))
    self.seg_filenames = sorted(glob.glob(f'./{dirname}/*{self.SEG_SUFFIX}'))

    # Both files of a pair share one stem, so pair by stem with a set lookup.
    # This replaces the old all-pairs regex comparison, which was O(n^2) in the
    # number of patches.
    seg_set = set(self.seg_filenames)
    self.file_pairs = []
    for data_file in self.data_filenames:
      seg_file = data_file[:-len(self.DATA_SUFFIX)] + self.SEG_SUFFIX
      if seg_file in seg_set:
        self.file_pairs.append((data_file, seg_file))

  def __len__(self):
    # Only complete pairs are indexable; counting seg files could overrun file_pairs.
    return len(self.file_pairs)

  def __getitem__(self, idx):
    data_path, seg_path = self.file_pairs[idx]

    # float32 for every item: mixed float32/float64 items cannot be stacked into a batch.
    image_array = sitk.GetArrayFromImage(sitk.ReadImage(data_path)).astype(np.float32)
    # Per-patch z-score; the epsilon keeps constant (std == 0) patches finite.
    image_array = (image_array - image_array.mean()) / (image_array.std() + 1e-7)
    image_tensor = torch.from_numpy(image_array).unsqueeze(0)

    seg_array = sitk.GetArrayFromImage(sitk.ReadImage(seg_path)).astype(np.uint8)
    # Edge target: Canny on each axis-0 slice of the mask. With thresholds 0 and 1,
    # label boundaries in a 0/1 mask should register as edges. Canny writes 255 on
    # edges and 0 elsewhere, so dividing by 255 always yields {0, 1}.
    edges = np.stack([cv2.Canny(seg_slice, 0, 1) for seg_slice in seg_array])
    segmap_tensor = torch.from_numpy(edges).unsqueeze(0).float() / 255

    return image_tensor, segmap_tensor

In [None]:
# Index every (data, seg) patch pair in the sliced-spaces folder.
dataset = SlicedSpacesDataset(dirname=sliced_spaces_save_folder_name)

In [None]:
# NOTE(review): the split is per patch, so patches cut from the same subject can land
# in both train and test, which makes the test loss optimistic. A subject-level split
# would need the subject id recovered from the patch file names.
SPLIT_SEED = 0  # fixed so the same patches go to train/test after every kernel restart
n_train = int(0.8 * len(dataset))
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [n_train, len(dataset) - n_train],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)
# Shuffled so the test batches drawn for monitoring during training vary.
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=8, shuffle=True)

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss for binary segmentation: ``1 - Dice(pred, target)``.

    Args:
        smooth: added to the numerator and denominator so the ratio stays finite
            when both ``pred`` and ``target`` are empty.
        from_logits: if True, ``pred`` is passed through a sigmoid first. Defaults
            to False because TEDUNet already ends in a sigmoid. Applying a second
            sigmoid squashed probabilities into [0.5, 0.73], so the loss could never
            approach 0.
    """

    def __init__(self, smooth=1e-6, from_logits=False):
        super(DiceLoss, self).__init__()
        self.smooth = smooth
        self.from_logits = from_logits

    def dice_coeff(self, pred, target, smooth=1e-6):
        """
        Compute the Dice coefficient for binary classification.

        Parameters:
        - pred: Predicted tensor (probabilities).
        - target: Ground truth tensor (binary), same number of elements as pred.
        - smooth: Smoothing factor to avoid division by zero.

        Returns:
        - Scalar Dice coefficient pooled over every element of the batch (not a
          per-sample mean).
        """
        # reshape rather than view: pred may be a non-contiguous crop (e.g. a
        # centre-cropped network output), which view(-1) rejects.
        pred = pred.reshape(-1)
        target = target.reshape(-1)

        intersection = (pred * target).sum()
        union = pred.sum() + target.sum()

        return (2. * intersection + smooth) / (union + smooth)

    def dice_loss(self, pred, target):
        """
        Compute the Dice loss (1 - Dice coefficient) for binary classification.

        Parameters:
        - pred: Probabilities, or logits when the loss was built with from_logits=True.
        - target: Ground truth tensor (binary).
        """
        if self.from_logits:
            pred = torch.sigmoid(pred)
        return 1 - self.dice_coeff(pred, target, smooth=self.smooth)

    def forward(self, pred, target):
        return self.dice_loss(pred, target)


In [None]:
# Model, losses and optimiser used for training.
tedu = TEDUNet(verbose=False)
tedu = tedu.to(device)

bce_loss = nn.BCELoss()  # expects probabilities; TEDUNet's forward ends with a sigmoid
dice_loss = DiceLoss()

optimizer = torch.optim.Adam(
    tedu.parameters(),
    lr=3e-4,
    weight_decay=1e-5,
)

In [None]:
epochs = 200
LOG_EVERY_N_BATCHES = 10  # cadence of test-batch evaluation and W&B figures


def combined_loss(pred, target):
    """BCE + Dice on probability maps; used for train and test alike so the curves compare."""
    return bce_loss(pred, target) + dice_loss(pred, target)


def cycle_batches(loader):
    """Yield batches from ``loader`` forever, starting a new (reshuffled) pass each time.

    Returns immediately for an empty loader, so ``next()`` raises StopIteration
    instead of spinning forever.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            return


def log_test_snapshot(test_batches, train_segmap, train_output, train_loss, epoch, batch_idx):
    """Score one test batch in eval mode and log losses plus a 2x2 slice figure to W&B."""
    test_data, test_segmap = next(test_batches)
    test_data = test_data.to(device).float()
    test_segmap = test_segmap.to(device).float()

    # eval(): BatchNorm must use its running stats and not absorb test-batch statistics.
    tedu.eval()
    try:
        with torch.no_grad():
            test_output = tedu(test_data)
            test_loss = combined_loss(test_output, test_segmap)
    finally:
        tedu.train()

    wandb.log({
        'train loss': train_loss,
        'test loss': test_loss.item()
    })

    # The last batch of an epoch can be shorter than batch_size, so draw indices
    # that are valid for both batches.
    sample_number = random.randrange(min(test_segmap.shape[0], train_segmap.shape[0]))
    slice_number = random.randrange(min(test_segmap.shape[2], train_segmap.shape[2]))

    panels = [
        (test_segmap, 'Test GT Segmap Slice'),
        (test_output, 'Test Pred Segmap Slice'),
        (train_segmap, 'Train GT Segmap Slice'),
        (train_output, 'Train Pred Segmap Slice'),
    ]
    fig, axs = plt.subplots(2, 2, figsize=(12, 6))
    try:
        for ax, (volume, title) in zip(axs.flat, panels):
            ax.imshow(volume[sample_number][0][slice_number].detach().cpu().numpy(), cmap='gray')
            ax.set_title(title)

        # Log the figure to W&B
        wandb.log({
            "Predictions vs Ground Truth": wandb.Image(fig, caption=f"Epoch: {epoch + 1}, Batch: {batch_idx}")
        })
    finally:
        plt.close(fig)


test_batches = cycle_batches(test_dataloader)
tedu.train()
for epoch in range(epochs):
    print(f'\n\n\nEpoch {epoch + 1}')
    for batch_idx, (data, segmap) in enumerate(train_dataloader):
        data = data.to(device).float()
        segmap = segmap.to(device).float()

        optimizer.zero_grad()
        output = tedu(data)
        loss = combined_loss(output, segmap)
        loss.backward()
        optimizer.step()
        print(f'Batch:{batch_idx} ; Loss:{loss.item()}')

        if batch_idx % LOG_EVERY_N_BATCHES == 0:
            try:
                log_test_snapshot(test_batches, segmap, output, loss.item(), epoch, batch_idx)
            except Exception as exc:  # monitoring is best-effort; report it, keep training
                print(f'test logging skipped (epoch {epoch + 1}, batch {batch_idx}): {exc!r}')
        else:
            wandb.log({'train loss': loss.item()})

In [None]:
def slice_spaces(array):
  """Cut a 3-D array into non-overlapping 26x26x26 patches and stack them.

  Patch order: axis 1 varies slowest, then axis 2, then axis 0 fastest. Only
  patches that end strictly before the array edge are taken, so any remainder at
  the far end of an axis is dropped. Returns an array of shape (n_patches, 26, 26, 26),
  or an empty array when no patch fits.
  """
  step0, step1, step2 = 26, 26, 26

  patches = []
  for end1 in range(step1, array.shape[1], step1):
    for end2 in range(step2, array.shape[2], step2):
      for end0 in range(step0, array.shape[0], step0):
        patches.append(array[end0 - step0:end0, end1 - step1:end1, end2 - step2:end2])

  return np.array(patches)





In [None]:
def reconstruct_original_array(slices, original_shape, jump_coords=(26, 26, 26)):
    """Paste patches back into a zero-filled array of ``original_shape``.

    Patches are consumed in the order ``slice_spaces`` produces them: axis 1
    outermost, then axis 2, then axis 0 innermost. A patch is placed only where it
    ends strictly before the array edge, so the remainder at the far end of each
    axis stays 0.

    Args:
        slices: sequence of patches, each of shape ``jump_coords``.
        original_shape: shape of the array to rebuild.
        jump_coords: patch edge length along axes 0, 1, 2.

    Returns:
        float64 ndarray of ``original_shape``.

    Warns:
        UserWarning, once, when there are fewer patches than positions (the rest
        stay 0) or more patches than positions (the extras are ignored).
    """
    import warnings

    reconstructed_array = np.zeros(original_shape)
    step0, step1, step2 = jump_coords

    positions = [
        (slice(e0 - step0, e0), slice(e1 - step1, e1), slice(e2 - step2, e2))
        for e1 in range(step1, original_shape[1], step1)
        for e2 in range(step2, original_shape[2], step2)
        for e0 in range(step0, original_shape[0], step0)
    ]

    if len(slices) < len(positions):
        warnings.warn(
            f'reconstruct_original_array: {len(slices)} slices for {len(positions)} '
            f'positions; the remaining positions are left as 0'
        )
    elif len(slices) > len(positions):
        warnings.warn(
            f'reconstruct_original_array: {len(slices)} slices for {len(positions)} '
            f'positions; {len(slices) - len(positions)} extra slices ignored'
        )

    for region, patch in zip(positions, slices):
        reconstructed_array[region] = patch

    return reconstructed_array