# Install Dependencies

In [1]:
%%capture
!pip install monai
!pip install dicom2nifti

In [3]:
# Mount Google Drive: the helper modules, the dataset and the results folder used below all live under MyDrive.
import google.colab.drive as drive

drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
!cp /content/drive/MyDrive/Machine-Learning-Biomedicine/PankVision-3D/preprocess/final_preprocess.py /content
!cp /content/drive/MyDrive/Machine-Learning-Biomedicine/PankVision-3D/utilities/final_utilities.py /content
!cp /content/drive/MyDrive/Machine-Learning-Biomedicine/PankVision-3D/model/get_model.py /content

In [12]:
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.losses import DiceLoss, DiceCELoss

import torch
from final_utilities import train
from get_model import get_model
from final_preprocess import preprocess_image, prepare

import numpy as np
from skimage import exposure
from monai.transforms import (
    Compose,
    LoadImaged,
    AddChanneld,
    Spacingd,
    Orientationd,
    ScaleIntensityRanged,
    CropForegroundd,
    Resized,
    ToTensord,
)
from monai.data import CacheDataset, Dataset, DataLoader
from glob import glob
import os
from monai.utils import set_determinism

In [15]:
from monai.utils import first
import matplotlib.pyplot as plt
import torch
import os
import numpy as np
from monai.losses import DiceLoss
from tqdm import tqdm
from sklearn.metrics import confusion_matrix

def dice_metric(predicted, target):
    '''
    Soft Dice score used as the train/test metric, computed as ``1 - DiceLoss(predicted, target)``.

    `predicted`: raw network outputs; DiceLoss applies a sigmoid to them (``sigmoid=True``).
    `target`: label tensor; DiceLoss one-hot encodes it (``to_onehot_y=True``).
    Returns a Python float.

    NOTE(review): the training cell's loss uses ``softmax=True`` while this metric applies a sigmoid
    to the same outputs — confirm that mismatch is intended.
    '''
    return 1 - DiceLoss(squared_pred=True, sigmoid=True, to_onehot_y=True)(predicted, target).item()

def calculate_weights(val1, val2):
    '''
    Turn background/foreground pixel counts into normalised inverse-frequency class weights.

    `val1`: number of background pixels.
    `val2`: number of foreground pixels.
    Returns a float32 tensor `[w_background, w_foreground]` whose entries sum to 1 and are
    proportional to 1 / count, so the rarer class gets the larger weight (e.g. for a weighted
    cross-entropy loss).
    '''
    counts = np.array([val1, val2])
    frequencies = counts / counts.sum()
    inverse_frequencies = 1 / frequencies
    normalised = inverse_frequencies / inverse_frequencies.sum()
    return torch.tensor(normalised, dtype=torch.float32)

def train(model, data_in, loss, optim, max_epochs, model_dir, test_interval=1, device=torch.device("cuda:0")):
    """
    Train `model`, evaluate it every `test_interval` epochs and keep the best checkpoint.

    `model`: the network to optimise; it should already live on `device`.
    `data_in`: a `(train_loader, test_loader)` pair whose batches are dicts with "vol" and "seg" tensors.
    `loss`: callable `loss(outputs, label)` returning a scalar tensor; `label` is the boolean mask `seg != 0`.
    `optim`: optimizer over `model`'s parameters.
    `max_epochs`: number of passes over the training loader.
    `model_dir`: output directory (created if missing). Every epoch (re)writes `loss_train.npy` and
        `metric_train.npy`; every evaluation (re)writes `loss_test.npy` and `metric_test.npy`; the weights
        with the best mean test dice are saved to `best_metric_model.pth`.
    `test_interval`: evaluate every `test_interval` epochs; a value <= 0 disables evaluation.
    `device`: device each batch is moved to.

    Dice values come from `dice_metric`; all saved losses/metrics are means over batches.
    """
    # np.save / torch.save below fail if the output folder does not exist yet.
    os.makedirs(model_dir, exist_ok=True)

    best_metric = -1
    best_metric_epoch = -1
    save_loss_train = []
    save_loss_test = []
    save_metric_train = []
    save_metric_test = []
    train_loader, test_loader = data_in

    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        train_epoch_loss = 0
        train_step = 0
        epoch_metric_train = 0
        for batch_data in train_loader:
            train_step += 1

            volume = batch_data["vol"]
            # Binarise the segmentation: any non-zero label counts as foreground.
            label = batch_data["seg"] != 0
            volume, label = volume.to(device), label.to(device)

            optim.zero_grad()
            outputs = model(volume)
            train_loss = loss(outputs, label)
            train_loss.backward()
            optim.step()

            train_epoch_loss += train_loss.item()
            # len(DataLoader) already counts batches, so it is the right denominator for progress.
            print(
                f"{train_step}/{len(train_loader)}, "
                f"Train_loss: {train_loss.item():.4f}")

            # The metric needs no gradients; detaching keeps it out of the autograd graph.
            train_metric = dice_metric(outputs.detach(), label)
            epoch_metric_train += train_metric
            print(f'Train_dice: {train_metric:.4f}')

        print('-' * 20)

        train_epoch_loss /= train_step
        print(f'Epoch_loss: {train_epoch_loss:.4f}')
        save_loss_train.append(train_epoch_loss)
        np.save(os.path.join(model_dir, 'loss_train.npy'), save_loss_train)

        epoch_metric_train /= train_step
        print(f'Epoch_metric: {epoch_metric_train:.4f}')
        save_metric_train.append(epoch_metric_train)
        np.save(os.path.join(model_dir, 'metric_train.npy'), save_metric_train)

        if test_interval > 0 and (epoch + 1) % test_interval == 0:
            test_epoch_loss, epoch_metric_test, cm = _evaluate_epoch(model, test_loader, loss, device)

            print(f'test_loss_epoch: {test_epoch_loss:.4f}')
            print(f'test_dice_epoch: {epoch_metric_test:.4f}')
            print('Confusion Matrix:')
            print(cm)

            save_loss_test.append(test_epoch_loss)
            np.save(os.path.join(model_dir, 'loss_test.npy'), save_loss_test)
            save_metric_test.append(epoch_metric_test)
            np.save(os.path.join(model_dir, 'metric_test.npy'), save_metric_test)

            if epoch_metric_test > best_metric:
                best_metric = epoch_metric_test
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(
                    model_dir, "best_metric_model.pth"))

            print(
                f"current epoch: {epoch + 1} current mean dice: {epoch_metric_test:.4f}"
                f"\nbest mean dice: {best_metric:.4f} "
                f"at epoch: {best_metric_epoch}"
            )

    print(
        f"train completed, best_metric: {best_metric:.4f} "
        f"at epoch: {best_metric_epoch}")


def _evaluate_epoch(model, test_loader, loss, device):
    """Run one gradient-free pass over `test_loader`; return (mean loss, mean dice, confusion matrix)."""
    model.eval()
    test_epoch_loss = 0
    epoch_metric_test = 0
    test_step = 0
    cm = None
    with torch.no_grad():
        for test_data in test_loader:
            test_step += 1
            test_volume = test_data["vol"].to(device)
            test_label = (test_data["seg"] != 0).to(device)

            test_outputs = model(test_volume)
            test_epoch_loss += loss(test_outputs, test_label).item()
            epoch_metric_test += dice_metric(test_outputs, test_label)

            # Predicted class per voxel = argmax over the channel dimension. Counts are accumulated
            # per batch rather than collecting every voxel into Python lists. At least 2 classes are
            # used because the labels are binary masks.
            preds = test_outputs.argmax(dim=1)
            batch_cm = _confusion_counts(preds, test_label, max(test_outputs.shape[1], 2))
            cm = batch_cm if cm is None else cm + batch_cm

    # Each sum is divided by the number of batches exactly once.
    return test_epoch_loss / test_step, epoch_metric_test / test_step, cm


def _confusion_counts(preds, labels, num_classes):
    """Return a (num_classes, num_classes) int array of voxel counts: row = true class, column = predicted class."""
    y_true = labels.reshape(-1).long().cpu().numpy()
    y_pred = preds.reshape(-1).long().cpu().numpy()
    counts = np.bincount(y_true * num_classes + y_pred, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)



def show_patient(data, SLICE_NUMBER=1, train=True, test=False):
    """
    Display one slice of the first patient in the train and/or test loader, to check the data before training.

    `data`: a `(train_loader, test_loader)` pair — call `prepare` first with the transforms you want to
    inspect, then pass its result here. Only the first item of the first batch of each selected loader is shown.
    `SLICE_NUMBER`: index along the last axis of the volume to display.
    `train`: show a patient from the training loader (default True).
    `test`: show a patient from the testing loader (default False).
    """
    check_patient_train, check_patient_test = data

    # Only pull a batch from the loaders that are actually displayed: fetching a batch can mean
    # loading and transforming a whole volume.
    if train:
        _plot_patient_slice(first(check_patient_train), SLICE_NUMBER, "Visualization Train")
    if test:
        _plot_patient_slice(first(check_patient_test), SLICE_NUMBER, "Visualization Test")


def _plot_patient_slice(patient, slice_number, figure_name):
    """Plot volume (grey) and segmentation of slice `slice_number` of the batch's first item side by side."""
    plt.figure(figure_name, (12, 6))
    plt.subplot(1, 2, 1)
    plt.title(f"vol {slice_number}")
    plt.imshow(patient["vol"][0, 0, :, :, slice_number], cmap="gray")

    plt.subplot(1, 2, 2)
    plt.title(f"seg {slice_number}")
    plt.imshow(patient["seg"][0, 0, :, :, slice_number])
    plt.show()


def calculate_pixels(data):
    """
    Count background and foreground voxels over every batch of `data` (e.g. for `calculate_weights`).

    `data`: iterable of batches whose "seg" entry is array-like; any non-zero value counts as foreground.
    Returns a float array of shape (1, 2): `[[background_count, foreground_count]]`, which is also printed.
    """
    val = np.zeros((1, 2))

    for batch in tqdm(data):
        # Count foreground directly: with np.unique, a batch holding a single value could not be told
        # apart as all-background or all-foreground (the old code always assumed background).
        foreground_mask = np.asarray(batch["seg"] != 0)
        foreground = np.count_nonzero(foreground_mask)
        val += [foreground_mask.size - foreground, foreground]

    print('The last values:', val)
    return val

## Checking for CUDA

In [None]:
# Report whether training can run on a GPU.
print(
    "CUDA is available. You can use the GPU."
    if torch.cuda.is_available()
    else "CUDA is not available. You can only use the CPU."
)

CUDA is available. You can use the GPU.


## Preprocess Dataset

In [6]:
# Root folder of the dataset-007 train/test split on the mounted Drive.
data_dir = (
    '/content/drive/MyDrive/Machine-Learning-Biomedicine/PankVision-3D/'
    'dataset/dataset-007/Data_Train_Test/'
)

In [7]:
# Build the data loaders from `data_dir`; `train` later unpacks the result as (train_loader, test_loader).
# NOTE(review): `cache=True` presumably makes `prepare` use MONAI's CacheDataset (imported above) — TODO confirm
# in final_preprocess.py. The progress bars below show 222 training and 56 test items being loaded.
data_in = prepare(data_dir, cache=True)

Loading dataset: 100%|██████████| 222/222 [12:10<00:00,  3.29s/it]
Loading dataset: 100%|██████████| 56/56 [05:34<00:00,  5.98s/it]


# Training

In [10]:
# Output folder for this run (v7): `train` writes its loss/metric histories and best checkpoint here.
model_dir = (
    '/content/drive/MyDrive/Machine-Learning-Biomedicine/PankVision-3D/'
    'results/dataset-007/v7'
)

In [17]:
import torch
from torch import nn
from monai.losses import DiceLoss
from sklearn.metrics import confusion_matrix

# Train on the GPU when present. `device` is also passed to `train`, whose own default is "cuda:0"
# and would otherwise fail on a CPU-only runtime.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

SEED = 0
MAX_EPOCHS = 150
LEARNING_RATE = 1e-4

# Seed the random number generators (MONAI helper) before building the model so runs are reproducible.
set_determinism(seed=SEED)

args = {
    'model_name': 'UNet',
    'pretrained': False,
    'dropout': 0.2
}
model = get_model(args)
model = model.to(device)

# Dice loss on softmax probabilities; labels are one-hot encoded to match the output channels.
loss_function = DiceLoss(to_onehot_y=True, softmax=True).to(device)
optimizer = torch.optim.Adam(model.parameters(), LEARNING_RATE, amsgrad=True)

train(model, data_in, loss_function, optimizer, MAX_EPOCHS, model_dir, device=device)

----------
epoch 1/150
1/222, Train_loss: 0.7491
Train_dice: 0.3845
2/222, Train_loss: 0.7434
Train_dice: 0.3900
3/222, Train_loss: 0.7373
Train_dice: 0.3913
4/222, Train_loss: 0.7418
Train_dice: 0.3869
5/222, Train_loss: 0.7346
Train_dice: 0.3961
6/222, Train_loss: 0.7266
Train_dice: 0.3963
7/222, Train_loss: 0.7288
Train_dice: 0.3925
8/222, Train_loss: 0.7227
Train_dice: 0.3949
9/222, Train_loss: 0.7208
Train_dice: 0.3937
10/222, Train_loss: 0.7169
Train_dice: 0.3939
11/222, Train_loss: 0.7130
Train_dice: 0.3959
12/222, Train_loss: 0.7117
Train_dice: 0.3990
13/222, Train_loss: 0.7041
Train_dice: 0.4041
14/222, Train_loss: 0.7038
Train_dice: 0.4049
15/222, Train_loss: 0.7027
Train_dice: 0.3983
16/222, Train_loss: 0.7079
Train_dice: 0.3966
17/222, Train_loss: 0.7022
Train_dice: 0.3976
18/222, Train_loss: 0.7062
Train_dice: 0.3969
19/222, Train_loss: 0.7032
Train_dice: 0.3937
20/222, Train_loss: 0.7024
Train_dice: 0.3927
21/222, Train_loss: 0.7013
Train_dice: 0.3943
22/222, Train_loss: 

KeyboardInterrupt: ignored

## Check the Size of the Data (for debugging; to be added to main)

In [None]:
# `prepare` comes from final_preprocess (imported at the top). Rebuild the loaders and report the
# tensor sizes of the first training batch only.
train_loader, test_loader = prepare(data_dir)

batch_data = next(iter(train_loader), None)
if batch_data is not None:
    volume, label = batch_data["vol"], batch_data["seg"]
    print(f"Volume size: {volume.size()}")
    print(f"Label size: {label.size()}")