In [None]:
!pip3 install nibabel
!pip3 install segmentation-models-3d
!pip install -U segmentation-models-pytorch
!pip install -U torchmetrics

# Import Libraries and Load Data

In [None]:
import numpy as np
import torch
import nibabel as nib
from torch.utils.data import Dataset, DataLoader

# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

In [None]:
import glob
import os

data_path = '/kaggle/input/'

# Volumes and masks are paired purely by position after sorting, so both lists
# must contain the same case identifiers in the same order.
volumes = sorted(glob.glob(os.path.join(data_path, 'volume-*')))
labels = sorted(glob.glob(os.path.join(data_path, 'labels-*')))

# Case id = text after the prefix, up to the first '.' (e.g. 'volume-7.nii' -> '7').
volume_ids = [os.path.basename(p)[len('volume-'):].split('.')[0] for p in volumes]
label_ids = [os.path.basename(p)[len('labels-'):].split('.')[0] for p in labels]
assert volume_ids == label_ids, (
    'volume/label files do not pair up one-to-one; unmatched ids: '
    f'{sorted(set(volume_ids) ^ set(label_ids))}'
)

volumes, labels

# Data Information

In [None]:
# data = torch.Tensor(nib.load('/content/drive/MyDrive/Datasets/volume-0.nii.gz').get_fdata()).to(device)
# print(data.shape)
# patches = pt.patchify(data, (8,8,8), step=8)
# patches.shape
# patches = data.unfold(2, 8, 8).unfold(1, 8, 8).unfold(0, 8, 8)
# patches = patches.contiguous().view(-1, 8, 8, 8)
# print(patches.shape)
# patches = patches.view(-1,1,8,8,8)
# print(patches.shape)

# Custom Dataset Classes

In [None]:
import os
import shutil

class CustomDataset(Dataset):
    """Map-style dataset of (volume, label) pairs, paired by index.

    In this notebook the items are file paths (see the glob cell); the actual
    image loading happens in the training loop, not here.

    Args:
        volumes: indexable sequence of volume items.
        labels: indexable sequence of label items, same length as `volumes`.
        transform: optional callable ``transform(volume, label) -> (volume, label)``
            applied jointly to each pair, so spatial augmentations stay aligned
            between image and mask. ``None`` returns the pair unchanged.

    Raises:
        ValueError: if `volumes` and `labels` differ in length.
    """

    def __init__(self, volumes, labels, transform=None):
        # A length mismatch would otherwise surface only as an IndexError
        # mid-epoch, or silently drop trailing labels.
        if len(volumes) != len(labels):
            raise ValueError(
                'volumes and labels must have the same length, '
                f'got {len(volumes)} and {len(labels)}'
            )
        self.volumes = volumes
        self.labels = labels
        self.transform = transform

    def __getitem__(self, idx):
        volume, label = self.volumes[idx], self.labels[idx]
        if self.transform is not None:
            return self.transform(volume, label)
        return volume, label

    def __len__(self):
        return len(self.volumes)

class CustomPatchDataset(Dataset):
    """Map-style dataset over pre-extracted (volume patch, label patch) pairs.

    Args:
        volume_patches: indexable collection of volume patches; in the training
            loop this is a tensor whose first axis enumerates patches.
        label_patches: matching label patches, same length as `volume_patches`.
        transform: optional callable ``transform(vol_patch, label_patch) ->
            (vol_patch, label_patch)`` applied jointly to each pair. ``None``
            returns the pair unchanged.

    Raises:
        ValueError: if the two collections differ in length.
    """

    def __init__(self, volume_patches, label_patches, transform=None):
        if len(volume_patches) != len(label_patches):
            raise ValueError(
                'volume_patches and label_patches must have the same length, '
                f'got {len(volume_patches)} and {len(label_patches)}'
            )
        self.vol_patches = volume_patches
        self.label_patches = label_patches
        self.transform = transform

    def __getitem__(self, idx):
        vol_patch, label_patch = self.vol_patches[idx], self.label_patches[idx]
        if self.transform is not None:
            return self.transform(vol_patch, label_patch)
        return vol_patch, label_patch

    def __len__(self):
        return len(self.vol_patches)

In [None]:
# One volume (pair of file paths) per batch. A seeded generator makes the
# shuffled per-epoch volume order reproducible across runs.
SHUFFLE_SEED = 42
dataset = CustomDataset(volumes, labels)
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)

# Model Implementation

In [None]:
import torch.nn as nn

class conv_block(nn.Module):
    """Conv3d(kernel 3, padding 1) -> BatchNorm3d -> ReLU.

    Spatial size is preserved; channels go from `in_channels` to `out_channels`.
    The attribute names conv1 / bn1 / relu determine the state_dict keys.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=1,
        )
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, inputs):
        return self.relu(self.bn1(self.conv1(inputs)))

class encoder_block(nn.Module):
    """One U-Net encoder stage: conv_block followed by a 2x max-pool.

    forward returns ``(features, pooled)``: the conv_block output (used as the
    skip connection) and its max-pooled version (kernel 2, default stride = 2).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = conv_block(in_channels, out_channels)
        self.pool = nn.MaxPool3d(kernel_size=2)

    def forward(self, inputs):
        features = self.conv(inputs)
        return features, self.pool(features)

class decoder_block(nn.Module):
    """One U-Net decoder stage: transposed-conv upsample, concat skip, conv_block.

    The ConvTranspose3d (kernel 2, stride 2) maps `in_channels` to
    `out_channels`. Its output is concatenated with `skip` along the channel
    axis, so `skip` is expected to carry `out_channels` channels and have the
    same spatial size as the upsampled tensor.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2, padding=0)
        self.conv = conv_block(2 * out_channels, out_channels)

    def forward(self, inputs, skip):
        upsampled = self.up(inputs)
        merged = torch.cat((upsampled, skip), dim=1)
        return self.conv(merged)

class build_unet(nn.Module):
    """3D U-Net: four encoder stages, a bottleneck, four decoder stages, 1x1x1 head.

    Encoder widths are 16 -> 32 -> 64 -> 256, with a 512-channel bottleneck.
    Each encoder stage halves every spatial dimension, so input spatial sizes
    should be divisible by 16 for the decoder's skip concatenations to line up.

    Args:
        in_channels: channels of the input tensor (default 6).
        num_classes: channels of the output logits (default 6).

    Submodule names (e1..e4, b, d0, d1, d3, d4, outputs) are kept stable
    because they determine the state_dict keys of saved checkpoints.
    """

    def __init__(self, in_channels=6, num_classes=6):
        super().__init__()
        # Encoder
        self.e1 = encoder_block(in_channels, 16)
        self.e2 = encoder_block(16, 32)
        self.e3 = encoder_block(32, 64)
        self.e4 = encoder_block(64, 256)

        # Bottleneck. It stays an encoder_block so checkpoint keys remain
        # b.conv.*; forward only uses its convolution.
        self.b = encoder_block(256, 512)

        # Decoder. There is no d2: this variant has no 128-channel stage, and
        # the remaining names are kept for checkpoint compatibility.
        self.d0 = decoder_block(512, 256)
        self.d1 = decoder_block(256, 64)
        self.d3 = decoder_block(64, 32)
        self.d4 = decoder_block(32, 16)

        # Classifier: per-voxel class logits
        self.outputs = nn.Conv3d(16, num_classes, kernel_size=1, padding=0)

    def forward(self, inputs):
        # Encoder: keep each stage's pre-pool features as skip connections.
        s1, p1 = self.e1(inputs)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)

        # Bottleneck: the pooled output of self.b would never be used, so run
        # only its convolution. MaxPool holds no weights, so state_dict is unaffected.
        b = self.b.conv(p4)

        # Decoder
        d0 = self.d0(b, s4)
        d1 = self.d1(d0, s3)
        d3 = self.d3(d1, s2)
        d4 = self.d4(d3, s1)

        return self.outputs(d4)


# Training the Model on 2 GPUs (30 GB) with PyTorch DataParallel

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from segmentation_models_pytorch.losses import FocalLoss as fl, DiceLoss as dl
from segmentation_models_pytorch.metrics.functional import iou_score
from torchmetrics import JaccardIndex, Dice
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# ---- Configuration ----------------------------------------------------------
NUM_CLASSES = 6          # segmentation classes = logit channels of build_unet()
MODEL_IN_CHANNELS = 6    # build_unet()'s first encoder expects 6 input channels
PATCH_SIZE = 64          # edge length (voxels) of the cubic training patches
PATCH_BATCH_SIZE = 32    # patches per optimisation step
PLOT_EVERY = 16          # save a center-slice figure every N patch batches
CHECKPOINT_PATH = 'unet3d_model.pth'

# Exactly one color per class: with vmin=0 / vmax=NUM_CLASSES-1, class k is then
# drawn with colors[k]. An extra color in the map shifts the class->color mapping.
colors = ['black', 'green', 'red', 'blue', 'purple', 'orange', 'yellow']
custom_cmap = ListedColormap(colors[:NUM_CLASSES])

model = nn.DataParallel(build_unet()).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

loss_obj = fl(mode='multiclass', gamma=2).to(device)
# torchmetrics' Dice is a non-differentiable *score* (higher is better): adding
# it to the loss gave no gradient and made the reported loss grow as the model
# improved. smp's DiceLoss is the differentiable (1 - Dice) counterpart.
loss_dice = dl(mode='multiclass').to(device)
metric_iou = JaccardIndex(num_classes=NUM_CLASSES, task="multiclass").to(device)


def _patchify(tensor, size):
    """Split a 3D tensor into non-overlapping cubes with edge `size`.

    Returns shape (num_patches, 1, size, size, size). Voxels past the last full
    cube along any axis are dropped (Tensor.unfold semantics).
    """
    patches = tensor.unfold(2, size, size).unfold(1, size, size).unfold(0, size, size)
    return patches.contiguous().view(-1, 1, size, size, size)


def _save_center_slice_plot(target_cls, pred_cls, path):
    """Save side-by-side ground-truth / predicted class maps for one slice.

    `target_cls` and `pred_cls` are integer class tensors of shape
    (batch, D, H, W). Shows the middle patch of the batch and its middle
    slice along the first spatial axis.
    """
    sample = target_cls.shape[0] // 2
    depth = target_cls.shape[1] // 2
    actual = target_cls[sample, depth].cpu().numpy()
    predicted = pred_cls[sample, depth].cpu().numpy()

    fig, (ax_actual, ax_pred) = plt.subplots(1, 2, figsize=(12, 5))
    ax_actual.imshow(actual, cmap=custom_cmap, vmin=0, vmax=NUM_CLASSES - 1)
    ax_actual.set_title('Actual Center Slice')
    ax_pred.imshow(predicted, cmap=custom_cmap, vmin=0, vmax=NUM_CLASSES - 1)
    ax_pred.set_title('Predicted Center Slice')
    fig.savefig(path)
    plt.close(fig)


# ---- Training loop -----------------------------------------------------------
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    total_iou = 0.0
    total_loss = 0.0
    total_patches_epoch = 0
    total_steps_epoch = 0
    print('#'*50)
    print('-'*20, f'EPOCH {epoch}', '-'*20)

    # The DataLoader yields length-1 tuples of file paths. These names are kept
    # distinct from the global `data_path` directory so it is not clobbered.
    for batch_idx, (volume_paths, label_paths) in enumerate(dataloader):
        print('='*50)
        print(f'Working on : {volume_paths} {label_paths}')
        print('-'*50)
        # NOTE(review): raw intensities are fed without clipping/normalisation;
        # consider intensity windowing + scaling (TODO).
        volume = torch.as_tensor(nib.load(volume_paths[0]).get_fdata(dtype=np.float32), device=device)
        label = torch.as_tensor(nib.load(label_paths[0]).get_fdata(dtype=np.float32), device=device)

        volume_patches = _patchify(volume, PATCH_SIZE)
        label_patches = _patchify(label, PATCH_SIZE)

        patch_dataset = CustomPatchDataset(volume_patches, label_patches)
        patch_dataloader = DataLoader(patch_dataset, batch_size=PATCH_BATCH_SIZE)

        total_patches = len(patch_dataset)
        patch_loader_size = len(patch_dataloader)
        total_patches_epoch += total_patches
        print('Total Patches', total_patches, 'Dataset Loader', patch_loader_size)

        volume_loss = 0.0
        volume_iou = 0.0
        progress = tqdm(patch_dataloader, total=patch_loader_size)
        for patch_idx, (data, target) in enumerate(progress):
            # The network takes MODEL_IN_CHANNELS channels; replicate the single input channel.
            data = data.expand(-1, MODEL_IN_CHANNELS, -1, -1, -1)
            # squeeze(1), not squeeze(): keep the batch axis when a batch holds a
            # single patch. Losses/metric need integer class indices.
            target_cls = target.squeeze(1).round().long()

            optimizer.zero_grad()
            output = model(data)
            loss = loss_obj(output, target_cls) + loss_dice(output, target_cls)
            loss.backward()
            optimizer.step()

            step_loss = loss.item()
            step_iou = metric_iou(output.detach(), target_cls).item()
            volume_loss += step_loss
            volume_iou += step_iou
            progress.set_postfix(loss=f'{step_loss:.4f}', iou=f'{step_iou:.4f}')

            if patch_idx % PLOT_EVERY == 0:
                _save_center_slice_plot(
                    target_cls,
                    output.detach().argmax(dim=1),
                    f'segment-p{patch_idx}-b{batch_idx}-e{epoch}.png',
                )

        total_loss += volume_loss
        total_iou += volume_iou
        total_steps_epoch += patch_loader_size

        # Averages over this volume's patch batches.
        avg_batch_loss = volume_loss / max(patch_loader_size, 1)
        avg_batch_metric = volume_iou / max(patch_loader_size, 1)
        print(f"Epoch [{epoch+1}/{num_epochs}] Batch [{batch_idx+1}/{len(dataloader)}] Loss: {avg_batch_loss:.4f} IOU: {avg_batch_metric:.4f}")
        print()

    # Loss/IoU were summed once per optimisation step, so average over steps
    # (not over individual patches).
    avg_iou = total_iou / max(total_steps_epoch, 1)
    avg_loss = total_loss / max(total_steps_epoch, 1)
    print(f"Epoch [{epoch+1}/{num_epochs}] Avg Loss: {avg_loss:.4f} Avg IOU: {avg_iou:.4f}")

    # Checkpoint once per epoch rather than after every patch batch.
    torch.save(model.state_dict(), CHECKPOINT_PATH)

# Save the final model
torch.save(model.state_dict(), CHECKPOINT_PATH)

In [None]:
# del model
# del optimizer
# torch.cuda.empty_cache()
# import gc
# gc.collect()

In [None]:
# Archive every file directly under /kaggle/working into data-color.zip.
# NOTE(review): if /kaggle/working is the notebook's cwd (presumably so on Kaggle — confirm),
# this also bundles the segment-*.png plots and the unet3d_model.pth checkpoint.
!zip data-color.zip /kaggle/working/*

In [None]:
import os
import subprocess
from IPython.display import FileLink, display

def download_file(path, download_file_name):
    """Display a clickable download link for an existing zip in /kaggle/working.

    Args:
        path: currently unused; kept so existing calls keep working. The archive
            is presumably created beforehand (e.g. by the `!zip` cell above).
        download_file_name: archive name without the ``.zip`` extension.

    Side effect: changes the process working directory to /kaggle/working/.
    """
    # FileLink checks for the file relative to the current working directory.
    os.chdir('/kaggle/working/')
    display(FileLink(f'{download_file_name}.zip'))


download_file('', 'data-color')