In [1]:
# Mount Google Drive into the Colab runtime at /content/drive; later cells read
# the requirements file and the experiment config from paths under this mount.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install -r "/content/drive/MyDrive/Personal/MS/Brain_Tumor_segmentaion3D/environment.txt"

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


# Data Description:
 
### Image types: 
* Native T1-weighted (T1): This scan is obtained using a standard T1-weighted imaging sequence, which uses a short TR (repetition time) and a short TE (echo time) to provide high-resolution images of the brain tissue. This sequence highlights the differences in tissue types based on their contrast with the surrounding tissues.

* Post-contrast T1-weighted (T1Gd): This scan is obtained using a T1-weighted imaging sequence after the administration of a contrast agent such as Gadolinium. The contrast agent is injected intravenously and is taken up by cells with a disrupted blood-brain barrier, which is a common characteristic of brain tumors. This sequence highlights the regions of the brain with a disrupted blood-brain barrier, such as enhancing tumor regions.

* T2-weighted (T2): This scan is obtained using a T2-weighted imaging sequence, which uses a long TR and a long TE to provide a more detailed view of the brain tissue. This sequence highlights subtle differences in tissue types that are not visible on T1 scans.

* T2 Fluid Attenuated Inversion Recovery (T2-FLAIR): This scan is obtained using a T2-weighted imaging sequence that is modified to suppress the signal from cerebrospinal fluid (CSF). This is achieved by using an inversion recovery pulse before the T2-weighted acquisition. This sequence is useful for distinguishing between edema and other types of brain tissue because the CSF signal is suppressed.

### Segmentation Classes:
* label 0: No tumor
* label 1: necrotic tumor core (Visible in T2): This class represents the core of the tumor, which is composed of necrotic tissue and non-enhancing tumor cells.
* label 2: the peritumoral edematous/invaded tissue (Visible in flair):  This class represents the edema, or swelling, that occurs around the tumor due to the accumulation of fluid in the surrounding brain tissue.
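* label 4: Gd-enhancing tumor (Visible in T1ce): This class represents the region of the tumor that enhances with the administration of contrast agent. The raw annotations skip the value 3, so label 4 is remapped to 3 when loading the data, giving contiguous class ids (0-3) for one-hot encoding.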

In [None]:
import tarfile

In [None]:
# List the archive members; the context manager closes the file handle as soon
# as the names have been read instead of leaving it open for the whole session.
with tarfile.open('/kaggle/input/brats-2021-task1/BraTS2021_Training_Data.tar') as my_tarfile:
    index = my_tarfile.getnames()

In [None]:
index[: 8]

['.',
 './.DS_Store',
 './BraTS2021_00000',
 './BraTS2021_00000/BraTS2021_00000_flair.nii.gz',
 './BraTS2021_00000/BraTS2021_00000_seg.nii.gz',
 './BraTS2021_00000/BraTS2021_00000_t1.nii.gz',
 './BraTS2021_00000/BraTS2021_00000_t1ce.nii.gz',
 './BraTS2021_00000/BraTS2021_00000_t2.nii.gz']

# Utilities

In [3]:
import os
import yaml
import importlib


def check_path(path):
    """Ensure that the directory ``path`` exists.

    Missing parent directories are created as well, and an already existing
    directory is left untouched, so the call is idempotent and free of the
    check-then-create race of ``os.path.exists`` + ``os.mkdir``.

    Args:
        path: directory path to create if it does not exist yet.
    """
    os.makedirs(path, exist_ok=True)


def check_path_recursively(save_folder, config,
                           params=("model_name", "pretrained_name", "dataset", "similarity_measure", "prompt")):
    """Create a nested output folder derived from selected config values.

    For every key in ``params`` whose value in ``config`` is not ``None``, the
    value is appended to ``save_folder`` as one more path level and the
    resulting directory is created. The final path is stored back into
    ``config["save_folder"]``.

    Args:
        save_folder: base folder; it is created if it does not exist yet.
        config: mapping that is read for the keys in ``params`` and updated
            in place with the ``"save_folder"`` entry.
        params: ordered keys that define the folder hierarchy. The default is
            an immutable tuple so it cannot be shared and mutated across calls.
    """
    for key in params:
        value = config.get(key)
        if value is None:
            continue
        # str() lets numeric config values form a path level as well.
        save_folder = save_folder + "/" + str(value)
        os.makedirs(save_folder, exist_ok=True)
    config["save_folder"] = save_folder


def yaml_writer(path, contents):
    """Serialize ``contents`` as YAML into the file at ``path`` (overwriting it)."""
    with open(path, "w") as stream:
        yaml.dump(contents, stream)


def text_file_reader(path):
    """Read a text file and return its lines without the trailing newline.

    Args:
        path: path of the text file to read.

    Returns:
        A list with one string per line of the file.
    """
    with open(path, "r") as source:
        # Each line carries at most one newline, at its end.
        return [raw_line.rstrip("\n") for raw_line in source]


def text_file_writer(path, contents):
    """Write ``contents`` to ``path``, one entry per line.

    Args:
        path: destination file (overwritten).
        contents: iterable of entries; a ``list`` entry is written as its
            items joined by single spaces, any other entry is written as-is.
    """
    with open(path, "w") as sink:
        for entry in contents:
            text = " ".join(entry) if isinstance(entry, list) else entry
            sink.write(text + "\n")


def instantiate_attribute(path):
    """Resolve a dotted path such as ``"package.module.Name"`` to the object it names.

    Args:
        path: dotted path whose last component is an attribute of the module
            named by everything before the final dot.

    Returns:
        The attribute looked up on the imported module.
    """
    pieces = path.rsplit(".", 1)
    module_path, attribute_name = pieces
    return getattr(importlib.import_module(module_path), attribute_name)


def instantiate_class(path, params):
    """Look up the callable at dotted ``path`` and call it with ``params``.

    ``params`` is passed through as a single positional argument.
    """
    return instantiate_attribute(path)(params)


# Basic Visualization

In [None]:
!pip install itkwidgets

In [None]:
import glob
import os
import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib
from ipywidgets import interact
from tqdm import tqdm
from mpl_toolkits.mplot3d import Axes3D
from scipy.ndimage.measurements import label, center_of_mass

In [None]:
class VisualizePatientData:
    """Interactive slice browser for one image type of a patient folder.

    The image type is chosen by index into ``image_types``
    (``["flair", "seg", "t1", "t1ce", "t2"]``); the colormap with the same
    index in ``cmap_list`` is used for display.
    """

    def __init__(self, patient_data_folder, img_type_id):
        """
        Args:
            patient_data_folder: folder whose entries are the per-patient folders.
            img_type_id: index into ``image_types`` selecting what to display.
        """
        self.patient_data_list = sorted(glob.glob(os.path.join(patient_data_folder, "*")))
        self.image_types = ["flair", "seg", "t1", "t1ce", "t2"]
        self.cmap_list = ["gray", "BuPu", "gray", "gray", "gray"]
        self.i = img_type_id
        self.fig = plt.figure(figsize=(1, 1))

    def visualize_brain_scans(self, cube_path):
        """Load the volume at ``cube_path`` and show a slider over its third axis."""
        def create_display(layer):
            self.fig.add_subplot(3, 2, self.i + 1)
            plt.imshow(self.scans[:, :, layer], cmap=self.cmap_list[self.i])
            plt.axis('off')
            return layer
        self.scans = np.asarray(nib.load(cube_path).get_fdata())
        # Fixed: this line referenced the undefined name `seld`, which raised a
        # NameError before the widget could be created.
        print(self.scans.shape)
        interact(create_display, layer=(0, self.scans.shape[2] - 1))

    def __call__(self, idx):
        """Visualize the selected image type for the ``idx``-th patient folder.

        The file name is built from ``idx + 1``, i.e. the sorted folder at
        position ``idx`` is expected to hold patient number ``idx + 1``.
        """
        data_path = os.path.join(self.patient_data_list[idx], "BraTS20_Training_%03d_%s.nii" % (idx + 1, self.image_types[self.i]))
        self.visualize_brain_scans(data_path)

In [None]:
patient_data_folder = "/kaggle/input/brain-tumor-segmentation-in-mri-brats-2015/MICCAI_BraTS2020_TrainingData"
example_patient_id = 0

In [None]:
visualizer_flair = VisualizePatientData(patient_data_folder, 0)
visualizer_flair(example_patient_id)

In [None]:
visualizer_seg = VisualizePatientData(patient_data_folder, 1)
visualizer_seg(example_patient_id)

In [None]:
visualizer_t1 = VisualizePatientData(patient_data_folder, 2)
visualizer_t1(example_patient_id)

In [None]:
visualizer_t1ce = VisualizePatientData(patient_data_folder, 3)
visualizer_t1ce(example_patient_id)

In [None]:
visualizer_t2 = VisualizePatientData(patient_data_folder, 4)
visualizer_t2(example_patient_id)

# Segmentation Classes study

In [None]:
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
import glob
import os
from tqdm import tqdm
from mpl_toolkits.mplot3d import Axes3D
from scipy.ndimage.measurements import label, center_of_mass

In [None]:
patient_data_folder = "/kaggle/input/brain-tumor-segmentation-in-mri-brats-2015/MICCAI_BraTS2020_TrainingData"
patient_data_list = sorted(glob.glob(os.path.join(patient_data_folder, "*")))

In [None]:
# Per-patient share (in percent of the 240x240x155 volume) of each tumor label.
core_tumor, peritumoral_tissue, enhancing_tumor = [], [], []
cube_size = 240 * 240 * 155
for i in tqdm(range(len(patient_data_list) - 2)):
    # The folder at index 354 stores its segmentation under a non-standard file name.
    patient_label_data_path = (
        "/kaggle/input/brain-tumor-segmentation-in-mri-brats-2015/MICCAI_BraTS2020_TrainingData/BraTS20_Training_355/W39_1998.09.19_Segm.nii"
        if i == 354
        else os.path.join(patient_data_list[i], "BraTS20_Training_%03d_seg.nii" % (i + 1))
    )
    patient_label_data = nib.load(patient_label_data_path).get_fdata()
    core_tumor.append(np.count_nonzero(patient_label_data == 1) / cube_size * 100)
    peritumoral_tissue.append(np.count_nonzero(patient_label_data == 2) / cube_size * 100)
    enhancing_tumor.append(np.count_nonzero(patient_label_data == 4) / cube_size * 100)

In [None]:
# 2x2 grid: one histogram per tumor class, then a bar chart of the class averages.
data = [core_tumor, peritumoral_tissue, enhancing_tumor]
data_string = ["core_tumor", "peritumoral_tissue", "enhancing_tumor"]
fig = plt.figure(figsize=(12, 12))
for i in range(4):
    plt.subplot(2, 2, i + 1)
    if i < 3:
        plt.hist(data[i])
        plt.title("Distribution of volume of %s" % data_string[i])
        continue
    plt.bar(x=[0, 1, 2], height=[np.average(core_tumor), np.average(peritumoral_tissue), np.average(enhancing_tumor)])
    plt.title("Average volume of %s, %s, %s" % tuple(data_string))

In [None]:
def visualize_brain_scans(scans):
    """Show an interactive slider that browses ``scans`` along its third axis.

    Args:
        scans: array indexed as ``scans[:, :, layer]``; each layer is drawn
            with the "BuPu" colormap and without axes.
    """
    last_layer = scans.shape[2] - 1

    def create_display(layer):
        plt.imshow(scans[:, :, layer], cmap="BuPu")
        plt.axis('off')
        return layer

    interact(create_display, layer=(0, last_layer))

In [None]:
def compute_centroid_volume_largest_component(seg_labels, label_id):
    """Centroid and voxel count of the largest connected component of a label.

    Args:
        seg_labels: segmentation array; voxels equal to ``label_id`` form the mask.
        label_id: label value to analyse.

    Returns:
        ``(centroid, volume)`` where ``centroid`` is the centre of mass of the
        largest connected component (a tuple with one coordinate per array
        axis) and ``volume`` is its number of voxels, or ``(None, None)`` when
        the label does not occur in ``seg_labels``.
    """
    seg_labels_core = np.asarray(seg_labels == label_id, dtype=np.uint8)
    labels, num_labels = label(seg_labels_core)
    if num_labels == 0:
        return None, None
    # One pass over the labelled volume yields every component size; entry 0
    # counts the background and is dropped. This replaces one full-volume
    # comparison per component.
    volumes = np.bincount(labels.ravel())[1:]
    largest_components_id = np.argmax(volumes) + 1
    centroid_i = center_of_mass(seg_labels_core, labels=labels, index=largest_components_id)
    volume_i = volumes[largest_components_id - 1]
    return centroid_i, volume_i

In [None]:
def centroid_volume_correlation(label_id):
    """Collect centroid and volume of the largest ``label_id`` component per patient.

    Iterates over all but the last two entries of the module-level
    ``patient_data_list``; patients in which the label does not occur are skipped.

    Returns:
        ``(centroids, volumes)`` as two lists of equal length.
    """
    # The folder at index 354 stores its segmentation under a non-standard file name.
    irregular_seg_path = "/kaggle/input/brain-tumor-segmentation-in-mri-brats-2015/MICCAI_BraTS2020_TrainingData/BraTS20_Training_355/W39_1998.09.19_Segm.nii"
    centroids, volumes = [], []
    for i in tqdm(range(len(patient_data_list) - 2)):
        if i != 354:
            patient_label_data_path = os.path.join(patient_data_list[i], "BraTS20_Training_%03d_seg.nii" % (i + 1))
        else:
            patient_label_data_path = irregular_seg_path
        patient_label_data = nib.load(patient_label_data_path).get_fdata()
        centroid_i, volume_i = compute_centroid_volume_largest_component(patient_label_data, label_id)
        if volume_i is not None:
            centroids.append(centroid_i)
            volumes.append(volume_i)
    return centroids, volumes

In [None]:
# Largest-component centroids and volumes for label 1, converted to arrays for plotting.
label_id = 1
centroids, volumes = centroid_volume_correlation(label_id)
centroids, volumes = np.array(centroids), np.array(volumes)

In [None]:
# 3D scatter of the component centroids collected in the previous cell.
# Both the marker area (s) and the marker colour (c) are driven by the
# component volume; the very large figure size goes with the fontsize of 100.
fig = plt.figure(figsize=(120, 120))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(centroids[:,0], centroids[:,1], centroids[:,2], s=volumes, c=volumes, cmap='BuPu')
ax.set_xlabel('X Centroid', fontsize=100)
ax.set_ylabel('Y Centroid', fontsize=100)
ax.set_zlabel('Z Centroid', fontsize=100)
ax.set_title('Correlation between Centroid and Volume', fontsize=100)
# Colorbar is currently disabled:
# cbar = plt.colorbar()
# cbar.set_label('Volume')
plt.show()

In [None]:
# Largest-component centroids and volumes for label 2, converted to arrays for plotting.
label_id = 2
centroids, volumes = centroid_volume_correlation(label_id)
centroids, volumes = np.array(centroids), np.array(volumes)

In [None]:
# 3D scatter of the component centroids collected in the previous cell.
# Both the marker area (s) and the marker colour (c) are driven by the
# component volume; the very large figure size goes with the fontsize of 100.
fig = plt.figure(figsize=(120, 120))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(centroids[:,0], centroids[:,1], centroids[:,2], s=volumes, c=volumes, cmap='BuPu')
ax.set_xlabel('X Centroid', fontsize=100)
ax.set_ylabel('Y Centroid', fontsize=100)
ax.set_zlabel('Z Centroid', fontsize=100)
ax.set_title('Correlation between Centroid and Volume', fontsize=100)
# Colorbar is currently disabled:
# cbar = plt.colorbar()
# cbar.set_label('Volume')
plt.show()

In [None]:
# Same analysis for label 4: collect the centroid and volume of the largest
# connected component per patient, then draw them as a 3D scatter.
label_id = 4
centroids, volumes = centroid_volume_correlation(label_id)
centroids = np.array(centroids)
volumes = np.array(volumes)
fig = plt.figure(figsize=(120, 120))
ax = fig.add_subplot(111, projection='3d')
# Marker area (s) and marker colour (c) are both driven by the component volume.
ax.scatter(centroids[:,0], centroids[:,1], centroids[:,2], s=volumes, c=volumes, cmap='BuPu')
ax.set_xlabel('X Centroid', fontsize=100)
ax.set_ylabel('Y Centroid', fontsize=100)
ax.set_zlabel('Z Centroid', fontsize=100)
ax.set_title('Correlation between Centroid and Volume', fontsize=100)
# Colorbar is currently disabled:
# cbar = plt.colorbar()
# cbar.set_label('Volume')
plt.show()

# Generating two splits from a dataset

In [None]:
import random
import glob
import os
from tqdm import tqdm
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt

In [None]:
def compute_class_volumes(patient_folder_list):
    """Per-patient volume share of labels 1, 2 and 4.

    Args:
        patient_folder_list: patient folder paths; the patient number is read
            from the text after the last underscore of each path.

    Returns:
        ``(core_tumor, peritumoral_tissue, enhancing_tumor)``: three lists
        holding, per patient, the percentage of a 240x240x155 volume occupied
        by label 1, 2 and 4 respectively.
    """
    cube_size = 240 * 240 * 155
    tracked_labels = (1, 2, 4)
    percentages = {tracked: [] for tracked in tracked_labels}
    for i in tqdm(patient_folder_list):
        if not i.endswith("355"):
            patient_label_data_path = os.path.join(i, "BraTS20_Training_%03d_seg.nii" % int(i.split("_")[-1]))
        else:
            # This patient stores its segmentation under a non-standard file name.
            patient_label_data_path = "/kaggle/input/brain-tumor-segmentation-in-mri-brats-2015/MICCAI_BraTS2020_TrainingData/BraTS20_Training_355/W39_1998.09.19_Segm.nii"
        patient_label_data = nib.load(patient_label_data_path).get_fdata()
        for tracked in tracked_labels:
            voxel_count = np.count_nonzero(patient_label_data == tracked)
            percentages[tracked].append((voxel_count / cube_size) * 100)
    return percentages[1], percentages[2], percentages[4]

In [None]:
def visualize_class_distributions(data):
    """Plot per-class volume histograms plus a bar chart of the class averages.

    Args:
        data: ``(core_tumor, peritumoral_tissue, enhancing_tumor)`` sequences
            of per-patient volume percentages.

    Returns:
        The three class averages, in the same order as ``data``.
    """
    core_tumor, peritumoral_tissue, enhancing_tumor = data
    data_string = ["core_tumor", "peritumoral_tissue", "enhancing_tumor"]
    fig = plt.figure(figsize=(12, 12))
    # Panels 1-3: one histogram per class.
    for panel in range(3):
        plt.subplot(2, 2, panel + 1)
        plt.hist(data[panel])
        plt.title("Distribution of volume of %s" % data_string[panel])
    # Panel 4: average volume of each class.
    plt.subplot(2, 2, 4)
    class_averages = (np.average(core_tumor), np.average(peritumoral_tissue), np.average(enhancing_tumor))
    plt.bar(x=[0, 1, 2], height=list(class_averages))
    plt.title("Average volume of %s, %s, %s" % tuple(data_string))
    return class_averages

In [None]:
def compute_two_splits(data_list, split_ratio, tolerance=0.01):
    """Randomly split ``data_list`` in two parts with matching class proportions.

    Random splits are drawn until, for both parts, the proportion of each
    tumor class (average class volume divided by the sum of the three class
    averages) lies within ``tolerance`` of the proportion over the whole list.

    Changes compared with the previous implementation:
      * ``split_ratio`` is honoured (0.8 used to be hard-coded),
      * the reference proportions come from the true average over all
        patients rather than from the sum of the two per-split averages,
      * class volumes are computed once up front instead of reloading every
        segmentation file on each retry,
      * a ratio that leaves one part empty raises instead of looping forever.

    Args:
        data_list: list of patient folder paths.
        split_ratio: fraction of ``data_list`` that goes into the first part.
        tolerance: maximum absolute deviation allowed per class proportion.

    Returns:
        ``(list1, list2)``: the first part in sampled order and the remaining
        items in their original order.

    Raises:
        ValueError: if ``split_ratio`` would leave one of the parts empty.

    Note:
        The loop has no retry limit; it only ends once a split satisfies the
        tolerance.
    """
    n_total = len(data_list)
    n_first = int(split_ratio * n_total)
    if n_first <= 0 or n_first >= n_total:
        raise ValueError(
            "split_ratio=%r leaves an empty split for %d items" % (split_ratio, n_total)
        )

    # Rows: core tumor, peritumoral tissue, enhancing tumor; columns follow data_list.
    per_class = np.asarray(compute_class_volumes(data_list), dtype=float)
    overall_averages = per_class.mean(axis=1)
    overall_class_proportion = overall_averages / overall_averages.sum()

    all_positions = range(n_total)
    while True:
        first_positions = random.sample(all_positions, n_first)
        chosen = set(first_positions)
        second_positions = [k for k in all_positions if k not in chosen]

        list1_volumes_averages = per_class[:, first_positions].mean(axis=1)
        list2_volumes_averages = per_class[:, second_positions].mean(axis=1)
        list1_class_proportion = list1_volumes_averages / list1_volumes_averages.sum()
        list2_class_proportion = list2_volumes_averages / list2_volumes_averages.sum()
        print("Current split data: ", list1_class_proportion, list2_class_proportion)

        list1_ok = np.all(np.abs(list1_class_proportion - overall_class_proportion) <= tolerance)
        list2_ok = np.all(np.abs(list2_class_proportion - overall_class_proportion) <= tolerance)
        if list1_ok and list2_ok:
            list1 = [data_list[k] for k in first_positions]
            list2 = [data_list[k] for k in second_positions]
            return list1, list2

In [None]:
patient_data_folder = "/kaggle/input/brain-tumor-segmentation-in-mri-brats-2015/MICCAI_BraTS2020_TrainingData"
patient_data_list = sorted(glob.glob(os.path.join(patient_data_folder, "*")))[: -2]

In [None]:
train_split, test_split = compute_two_splits(patient_data_list, 0.8)

In [None]:
train_volumes = compute_class_volumes(train_split)
test_volumes = compute_class_volumes(test_split)
visualize_class_distributions(train_volumes)

In [None]:
visualize_class_distributions(test_volumes)

In [None]:
training_final, validation = compute_two_splits(train_split, 0.8)

In [None]:
train_final_volumes = compute_class_volumes(training_final)
validation_volumes = compute_class_volumes(validation)
visualize_class_distributions(train_final_volumes)

In [None]:
visualize_class_distributions(validation_volumes)

# Models

In [4]:
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple, Union

import torch
import torch.nn as nn

from monai.networks.blocks import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock
from monai.networks.blocks import UnetOutBlock
from monai.networks.nets import ViT


class UNETR(nn.Module):
    """
    UNETR based on: "Hatamizadeh et al.,
    UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>"

    A ViT processes the input as 16x16x16 patches; the hidden states taken at
    indices 3, 6 and 9 together with the final ViT output are reshaped to
    feature volumes and merged with a convolutional encoding of the raw input
    through a chain of up-sampling decoder blocks.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        img_size: Tuple[int, int, int],
        feature_size: int = 16,
        hidden_size: int = 768,
        mlp_dim: int = 3072,
        num_heads: int = 12,
        pos_embed: str = "perceptron",
        norm_name: Union[Tuple, str] = "instance",
        conv_block: bool = False,
        res_block: bool = True,
        dropout_rate: float = 0.0,
    ) -> None:
        """
        Args:
            in_channels: dimension of input channels.
            out_channels: dimension of output channels.
            img_size: dimension of input image.
            feature_size: dimension of network feature size.
            hidden_size: dimension of hidden layer.
            mlp_dim: dimension of feedforward layer.
            num_heads: number of attention heads.
            pos_embed: position embedding layer type.
            norm_name: feature normalization type and arguments.
            conv_block: bool argument to determine if convolutional block is used.
            res_block: bool argument to determine if residual block is used.
            dropout_rate: fraction of the input units to drop.

        Raises:
            AssertionError: if ``dropout_rate`` is outside [0, 1] or
                ``hidden_size`` is not divisible by ``num_heads``.
            KeyError: if ``pos_embed`` is neither "conv" nor "perceptron".

        Examples::

            # for single channel input 4-channel output with patch size of (96,96,96), feature size of 32 and batch norm
            >>> net = UNETR(in_channels=1, out_channels=4, img_size=(96,96,96), feature_size=32, norm_name='batch')

            # for 4-channel input 3-channel output with patch size of (128,128,128), conv position embedding and instance norm
            >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='conv', norm_name='instance')

        """

        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise AssertionError("dropout_rate should be between 0 and 1.")

        if hidden_size % num_heads != 0:
            raise AssertionError("hidden size should be divisible by num_heads.")

        if pos_embed not in ["conv", "perceptron"]:
            raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.")

        self.num_layers = 12
        self.patch_size = (16, 16, 16)
        # Number of patches along each spatial axis (integer division).
        self.feat_size = (
            img_size[0] // self.patch_size[0],
            img_size[1] // self.patch_size[1],
            img_size[2] // self.patch_size[2],
        )
        self.hidden_size = hidden_size
        self.classification = False
        self.vit = ViT(
            in_channels=in_channels,
            img_size=img_size,
            patch_size=self.patch_size,
            hidden_size=hidden_size,
            mlp_dim=mlp_dim,
            num_layers=self.num_layers,
            num_heads=num_heads,
            pos_embed=pos_embed,
            classification=self.classification,
            dropout_rate=dropout_rate,
        )
        self.encoder1 = UnetrBasicBlock(
            spatial_dims=3,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.encoder2 = UnetrPrUpBlock(
            spatial_dims=3,
            in_channels=hidden_size,
            out_channels=feature_size * 2,
            num_layer=2,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.encoder3 = UnetrPrUpBlock(
            spatial_dims=3,
            in_channels=hidden_size,
            out_channels=feature_size * 4,
            num_layer=1,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.encoder4 = UnetrPrUpBlock(
            spatial_dims=3,
            in_channels=hidden_size,
            out_channels=feature_size * 8,
            num_layer=0,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.decoder5 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=hidden_size,
            out_channels=feature_size * 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder4 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder3 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder2 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.out = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels)  # type: ignore

    def proj_feat(self, x, hidden_size, feat_size):
        """Reshape a token sequence into a channels-first feature volume.

        ``x`` is viewed as ``(batch, feat_size[0], feat_size[1], feat_size[2],
        hidden_size)`` and then permuted to ``(batch, hidden_size, feat_size[0],
        feat_size[1], feat_size[2])``.
        """
        x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size)
        x = x.permute(0, 4, 1, 2, 3).contiguous()
        return x

    def load_from(self, weights):
        """Copy pretrained transformer weights into ``self.vit`` in place.

        Args:
            weights: checkpoint mapping whose ``"state_dict"`` entry holds the
                ``module.transformer.*`` tensors read below.
        """
        with torch.no_grad():
            state_dict = weights["state_dict"]
            # copy weights from patch embedding
            self.vit.patch_embedding.position_embeddings.copy_(
                state_dict["module.transformer.patch_embedding.position_embeddings_3d"]
            )
            self.vit.patch_embedding.cls_token.copy_(
                state_dict["module.transformer.patch_embedding.cls_token"]
            )
            self.vit.patch_embedding.patch_embeddings[1].weight.copy_(
                state_dict["module.transformer.patch_embedding.patch_embeddings.1.weight"]
            )
            self.vit.patch_embedding.patch_embeddings[1].bias.copy_(
                state_dict["module.transformer.patch_embedding.patch_embeddings.1.bias"]
            )

            # copy weights from encoding blocks (default: num of blocks: 12)
            # NOTE(review): relies on each transformer block exposing
            # `loadFrom(weights, n_block=...)` - confirm against the installed
            # MONAI version.
            for bname, block in self.vit.blocks.named_children():
                block.loadFrom(weights, n_block=bname)
            # last norm layer of transformer
            self.vit.norm.weight.copy_(state_dict["module.transformer.norm.weight"])
            self.vit.norm.bias.copy_(state_dict["module.transformer.norm.bias"])

    def forward(self, x_in):
        """Run the network on ``x_in`` and return the output of the final ``UnetOutBlock``."""
        x, hidden_states_out = self.vit(x_in)
        # Skip connections: raw input plus three intermediate ViT hidden states.
        enc1 = self.encoder1(x_in)
        x2 = hidden_states_out[3]
        enc2 = self.encoder2(self.proj_feat(x2, self.hidden_size, self.feat_size))
        x3 = hidden_states_out[6]
        enc3 = self.encoder3(self.proj_feat(x3, self.hidden_size, self.feat_size))
        x4 = hidden_states_out[9]
        enc4 = self.encoder4(self.proj_feat(x4, self.hidden_size, self.feat_size))
        # Decoder: start from the final ViT output and merge one skip per stage.
        dec4 = self.proj_feat(x, self.hidden_size, self.feat_size)
        dec3 = self.decoder5(dec4, enc4)
        dec2 = self.decoder4(dec3, enc3)
        dec1 = self.decoder3(dec2, enc2)
        out = self.decoder2(dec1, enc1)
        logits = self.out(out)
        return logits

# Data Generator

In [5]:
import torch
import nibabel as nib
import numpy as np
import os
import cv2
import glob
import torchvision.transforms as transforms

In [6]:
# device = "cuda" if torch.cuda.is_available() else "cpu"

In [7]:
def resize_3d_image(original_img, size, interpolation=None):
    """Resize every slice along the third axis of a 3D volume.

    Args:
        original_img: array indexed as ``original_img[:, :, i]`` per slice.
        size: target size of each slice, passed to ``cv2.resize`` as ``dsize``
            (OpenCV order: width, height). Previously this argument was
            ignored and (128, 128) was always used.
        interpolation: OpenCV interpolation flag; ``None`` selects
            ``cv2.INTER_LINEAR``, the ``cv2.resize`` default that was used
            implicitly before.

    Returns:
        Array with the resized slices stacked back along the last axis.
    """
    if interpolation is None:
        interpolation = cv2.INTER_LINEAR
    resized_img = [
        cv2.resize(original_img[:, :, i], tuple(size), interpolation=interpolation)
        for i in range(original_img.shape[2])
    ]
    resized_img = np.moveaxis(np.stack(resized_img), 0, -1)
    return resized_img


class BTSDataset(torch.utils.data.Dataset):
    """Dataset yielding stacked T1ce/T2/FLAIR volumes and one-hot segmentation labels.

    Each volume is cropped along its third axis (``[:, :, 5:-6]``) and every
    slice is resized to 128x128 before conversion to a tensor.
    """

    def __init__(self, patient_data_list_path, src_folder, no_classes=4):
        """
        Args:
            patient_data_list_path: text file with one patient folder per
                line, relative to ``src_folder``.
            src_folder: folder that the listed patient folders are joined onto.
            no_classes: number of classes used for the one-hot encoding.
        """
        patient_data_list = sorted(text_file_reader(patient_data_list_path))
        patient_data_list = [os.path.join(src_folder, i) for i in patient_data_list]
        self.patient_flair_scans_list = [glob.glob(os.path.join(i, "*_flair.nii.gz"))[0] for i in patient_data_list]
        self.patient_t1ce_scans_list = [glob.glob(os.path.join(i, "*_t1ce.nii.gz"))[0] for i in patient_data_list]
        self.patient_t2_scans_list = [glob.glob(os.path.join(i, "*_t2.nii.gz"))[0] for i in patient_data_list]
        self.patient_seg_scans_list = [glob.glob(os.path.join(i, "*_seg.nii.gz"))[0] for i in patient_data_list]
        self.transform = transforms.ToTensor()
        self.no_classes = no_classes

    def __len__(self):
        """Number of patients in the dataset."""
        return len(self.patient_flair_scans_list)

    def _load_volume(self, path, interpolation=None):
        """Load a NIfTI volume, crop its third axis and resize each slice to 128x128."""
        volume = np.asarray(nib.load(path).get_fdata())[:, :, 5: -6]
        return resize_3d_image(volume, (128, 128), interpolation)

    def __getitem__(self, idx):
        """Return ``(images, one_hot_labels)`` for patient ``idx`` as float32 tensors.

        ``images`` stacks the T1ce, T2 and FLAIR volumes (in that order) along
        a new first axis; ``one_hot_labels`` has the class axis first.
        """
        t1ce_scan = self.transform(self._load_volume(self.patient_t1ce_scans_list[idx]))
        t2_scan = self.transform(self._load_volume(self.patient_t2_scans_list[idx]))
        flair_scan = self.transform(self._load_volume(self.patient_flair_scans_list[idx]))
        # Labels are categorical: nearest-neighbour resizing keeps the original
        # class ids. Bilinear interpolation blends neighbouring ids into
        # fractional values that the integer cast below turns into wrong classes.
        seg_label = self._load_volume(self.patient_seg_scans_list[idx], cv2.INTER_NEAREST)
        # Remap label 4 to 3 so class ids are contiguous for one-hot encoding.
        seg_label[seg_label == 4] = 3
        seg_label = self.transform(seg_label)
        seg_label_ohe = torch.nn.functional.one_hot(seg_label.to(torch.int64), self.no_classes)
        seg_label_ohe = torch.moveaxis(seg_label_ohe, -1, 0)
        image_scans_stacked = torch.stack([t1ce_scan, t2_scan, flair_scan])
        return image_scans_stacked.to(torch.float32), seg_label_ohe.to(torch.float32)

In [15]:
# path = "/content/drive/MyDrive/Personal/MS/Brain_Tumor_segmentaion3D/dataset/BraTS2021_Training_Data"
# folder_list = sorted(glob.glob(os.path.join(path, "*")))
# folder_list = [i for i in folder_list if os.path.isdir(i)]
# folder_list.pop(354)

In [16]:
# src_folder = "/content/drive/MyDrive/Personal/MS/Brain_Tumor_segmentaion3D"
# folder_list_path = "/content/drive/MyDrive/Personal/MS/Brain_Tumor_segmentaion3D/dataset/lists/sample_100.txt"
# # patient_data_list = sorted(text_file_reader(folder_list_path))
# # patient_data_list = [os.path.join(src_folder, i) for i in patient_data_list]
# dataset = BTSDataset(folder_list_path, src_folder)

In [17]:
# ex_idx = 0
# ex_img, ex_label = dataset[ex_idx]
# ex_img.shape, ex_label.shape

(torch.Size([3, 144, 128, 128]), torch.Size([4, 144, 128, 128]))

# Training setup

In [8]:
from tqdm import tqdm
import torch
import os
import numpy as np


device = "cuda" if torch.cuda.is_available() else "cpu"
os.environ["KMP_DUPLICATE_LIB_OK"] = 'True'


class Trainer:
    """Minimal training loop with per-epoch validation and best-weights tracking.

    Fixes compared with the previous version:
      * the loss is back-propagated before ``optimizer.step()`` (the step used
        to be a no-op because ``backward()`` was never called),
      * losses are recorded as Python floats, so the autograd graph of every
        batch is released instead of piling up in GPU memory,
      * ``np.avg`` (which does not exist) is replaced by ``np.mean``,
      * tensors and the model are moved to ``device`` rather than
        unconditionally to CUDA,
      * the best weights are stored as a detached CPU copy instead of an
        alias of the live parameters,
      * the model is switched between train and eval mode.
    """

    def __init__(self, model, loss, optimizer, number_of_epochs, weights_save_folder):
        """
        Args:
            model: ``torch.nn.Module`` to train; it is moved to ``device``.
            loss: keyword arguments for ``instantiate_class`` (``path`` and
                ``params``) describing the loss.
            optimizer: mapping with a ``"path"`` entry (dotted path of the
                optimizer class) and a ``"params"`` entry (its keyword arguments).
            number_of_epochs: number of passes over the training dataloader.
            weights_save_folder: folder that receives the checkpoint files.
        """
        self.loss = instantiate_class(**loss)
        self.number_of_epochs = number_of_epochs
        self.model = model
        self.model.to(device)
        self.optimizer = instantiate_attribute(optimizer["path"])(self.model.parameters(), **optimizer["params"])
        self.best_model_weights = None
        self.weights_save_folder = weights_save_folder
        check_path(self.weights_save_folder)

    def __call__(self, training_dataloader, validation_dataloader):
        """Train for ``number_of_epochs`` epochs.

        After each epoch the average validation loss is computed; whenever it
        improves, the weights are remembered and written to
        ``model_weights_<epoch>.pth`` inside ``weights_save_folder``.

        Returns:
            ``(best_model_weights, model)`` with the model moved to the CPU.
            ``best_model_weights`` is ``None`` if no epoch produced a finite
            improvement.
        """
        best_val_loss = float("inf")
        for epoch in range(self.number_of_epochs):
            print("Epoch number: ", epoch)

            self.model.train()
            training_loss = []
            for image, label in tqdm(training_dataloader):
                self.optimizer.zero_grad()
                probabilities = self.model(image.to(device))
                batch_loss = self.loss(probabilities, label.to(device))
                batch_loss.backward()
                self.optimizer.step()
                # .item() keeps only the number, so the graph can be freed.
                training_loss.append(batch_loss.item())
                del image, label, probabilities, batch_loss
            print("Average training loss: ",
                  np.mean(training_loss) if training_loss else float("nan"))

            print("Validating now")
            self.model.eval()
            validation_loss = []
            with torch.no_grad():
                for image, label in tqdm(validation_dataloader):
                    probabilities = self.model(image.to(device))
                    validation_loss.append(self.loss(probabilities, label.to(device)).item())
                    del image, label, probabilities
            epoch_val_loss = np.mean(validation_loss) if validation_loss else float("nan")
            print("Average validation loss: ", epoch_val_loss)

            if epoch_val_loss < best_val_loss:
                best_val_loss = epoch_val_loss
                # Snapshot on the CPU: state_dict() returns references to the
                # live parameters, which later epochs keep updating.
                self.best_model_weights = {
                    name: tensor.detach().cpu().clone()
                    for name, tensor in self.model.state_dict().items()
                }
                torch.save(self.model.state_dict(),
                           os.path.join(self.weights_save_folder, 'model_weights_%d.pth' % epoch))
        return self.best_model_weights, self.model.to("cpu")


# Evaluation

In [9]:
import numpy as np
import torch


device = "cuda" if torch.cuda.is_available() else "cpu"


def generate_predictions(model, test_dataloader):
    """Run ``model`` over ``test_dataloader`` and collect class-index maps.

    Both the model output and the one-hot labels are reduced with an argmax
    over axis 1, the class axis of a ``(batch, classes, ...)`` tensor. The
    previous version reduced over the last axis, which is a spatial axis.

    Args:
        model: ``torch.nn.Module``; it is moved to ``device`` and evaluated in
            eval mode (its previous train/eval mode is restored afterwards).
        test_dataloader: iterable of ``(image, one_hot_label)`` batches.

    Returns:
        ``(test_predictions, test_labels)``: two lists with one NumPy array of
        class indices per sample.
    """
    model.to(device)
    was_training = model.training
    model.eval()
    test_predictions = []
    test_labels = []
    try:
        with torch.no_grad():
            for image, label in test_dataloader:
                # .to(device) also works on CPU-only machines, unlike .cuda().
                probabilities = model(image.to(device))
                test_predictions.extend(probabilities.argmax(dim=1).cpu().numpy())
                test_labels.extend(np.argmax(label.cpu().numpy(), axis=1))
                del probabilities, label
    finally:
        model.train(was_training)
    return test_predictions, test_labels


# Experiments

In [10]:
# Create the experiment output folder (and its parent) in one idempotent call.
os.makedirs("/content/Experiemnts/test", exist_ok=True)

In [11]:
config_path = "/content/drive/MyDrive/Personal/MS/Brain_Tumor_segmentaion3D/source_code/configs/overfitting_test.yaml"
with open(config_path, "r") as f:
    config = yaml.safe_load(f)

In [12]:
model = UNETR(**config["Model"])

In [13]:
training_datagenerator = BTSDataset(**config["Training_Dataset"])
training_dataloader = torch.utils.data.DataLoader(training_datagenerator, batch_size=1,
                                                  shuffle=True)

In [14]:
trainer = Trainer(**config["Trainer"], model=model, weights_save_folder=config["save_folder"])

In [15]:
model_weights, trained_model = trainer(training_dataloader, training_dataloader)

Epoch number:  0


  4%|▍         | 4/100 [00:33<13:17,  8.31s/it]


OutOfMemoryError: ignored