# Install Dependencies

In [2]:
%%capture
!pip install monai
!pip install dicom2nifti

In [None]:
# Mount Google Drive at /content/drive; the following cells read the project
# files and the dataset from paths below /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Copy the project's helper modules from Drive into /content, presumably so
# that `final_preprocess`, `final_utilities` and `get_model` can be imported
# by the next cell (assumes /content is the working directory — confirm).
!cp /content/drive/MyDrive/Machine-Learning-Biomedicine/PankVision-3D/preprocess/final_preprocess.py /content
!cp /content/drive/MyDrive/Machine-Learning-Biomedicine/PankVision-3D/utilities/final_utilities.py /content
!cp /content/drive/MyDrive/Machine-Learning-Biomedicine/PankVision-3D/model/get_model.py /content

In [None]:
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.losses import DiceLoss, DiceCELoss

import torch
from final_utilities import train
from get_model import get_model
from final_preprocess import preprocess_image, prepare

import numpy as np
from skimage import exposure
from monai.transforms import (
    Compose,
    LoadImaged,
    AddChanneld,
    Spacingd,
    Orientationd,
    ScaleIntensityRanged,
    CropForegroundd,
    Resized,
    ToTensord,
)
from monai.data import CacheDataset, Dataset, DataLoader
from glob import glob
import os
from monai.utils import set_determinism

## Checking for CUDA

In [None]:
# Report up front whether PyTorch can see a CUDA device.
cuda_status = (
    "CUDA is available. You can use the GPU."
    if torch.cuda.is_available()
    else "CUDA is not available. You can only use the CPU."
)
print(cuda_status)

CUDA is available. You can use the GPU.


## Preprocess Dataset

In [None]:
# Dataset folder on the mounted Drive; this path is what gets passed to prepare().
data_dir = '/content/drive/MyDrive/Machine-Learning-Biomedicine/PankVision-3D/dataset/dataset-007/Data_Train_Test/'

In [None]:
# Build the data object that is later handed to train().
# NOTE(review): what prepare() returns and what cache=True changes is defined
# in final_preprocess.py, not here — confirm there before relying on it.
data_in = prepare(data_dir, cache=True)

Loading dataset: 100%|██████████| 222/222 [08:43<00:00,  2.36s/it]
Loading dataset: 100%|██████████| 56/56 [02:51<00:00,  3.06s/it]


# Testing New Architecture

In [None]:
import torch
import torch.nn as nn

class conv_block(nn.Module):
    """Two stacked (3x3 Conv2d -> BatchNorm2d -> ReLU) stages.

    The first stage maps ``in_c`` channels to ``out_c``; the second keeps
    ``out_c`` channels. ``padding=1`` with a 3x3 kernel leaves height and
    width unchanged.
    """

    def __init__(self, in_c, out_c):
        super().__init__()

        stages = []
        # Stage 1 consumes in_c channels, stage 2 consumes the out_c it produced.
        for stage_in in (in_c, out_c):
            stages.append(nn.Conv2d(stage_in, out_c, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(out_c))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        features = self.conv(x)
        return features

class encoder_block(nn.Module):
    """Encoder stage: a ``conv_block`` followed by 2x2 max pooling.

    ``forward`` returns a pair ``(s, p)``: ``s`` is the conv_block output
    (kept as the skip connection) and ``p`` is ``s`` after pooling.
    """

    def __init__(self, in_c, out_c):
        super().__init__()

        self.conv = conv_block(in_c, out_c)
        self.pool = nn.MaxPool2d(kernel_size=(2, 2))

    def forward(self, x):
        skip = self.conv(x)
        return skip, self.pool(skip)

class attention_gate(nn.Module):
    """Gate that re-weights a skip tensor ``s`` using a gating tensor ``g``.

    ``in_c`` is indexed as ``in_c[0]`` (channels of ``g``) and ``in_c[1]``
    (channels of ``s``). Both inputs are projected to ``out_c`` channels with
    1x1 convolutions + batch norm, summed, passed through ReLU, then through a
    1x1 convolution + sigmoid. The resulting coefficients multiply ``s``.
    """

    def __init__(self, in_c, out_c):
        super().__init__()

        def pointwise_bn(channels):
            # 1x1 projection to out_c channels followed by batch norm.
            return nn.Sequential(
                nn.Conv2d(channels, out_c, kernel_size=1, padding=0),
                nn.BatchNorm2d(out_c),
            )

        self.Wg = pointwise_bn(in_c[0])
        self.Ws = pointwise_bn(in_c[1])
        self.relu = nn.ReLU(inplace=True)
        self.output = nn.Sequential(
            nn.Conv2d(out_c, out_c, kernel_size=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, g, s):
        combined = self.relu(self.Wg(g) + self.Ws(s))
        attention = self.output(combined)
        return attention * s

class decoder_block(nn.Module):
    """Decoder stage: upsample, gate the skip connection, concatenate, convolve.

    ``in_c`` is indexed as ``in_c[0]`` (channels of the incoming decoder
    tensor ``x``) and ``in_c[1]`` (channels of the skip tensor ``s``);
    ``out_c`` is the number of channels produced by the stage.

    ``forward(x, s)``:
      1. bilinearly upsamples ``x`` by a factor of 2,
      2. if the upsampled height/width still differs from ``s`` (the
         encoder's 2x2 max pooling floors odd sizes, so doubling can come up
         one pixel short), resizes ``x`` to exactly the spatial size of ``s``,
      3. gates ``s`` with ``attention_gate`` using ``x`` as the gating signal,
      4. concatenates ``x`` and the gated ``s`` along the channel dimension
         and applies a ``conv_block``.
    """

    def __init__(self, in_c, out_c):
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.ag = attention_gate(in_c, out_c)
        self.c1 = conv_block(in_c[0]+out_c, out_c)

    def forward(self, x, s):
        x = self.up(x)
        # Without this, odd-sized skip maps make the gate's addition and the
        # concatenation below fail with a shape mismatch.
        if x.shape[2:] != s.shape[2:]:
            x = nn.functional.interpolate(
                x, size=s.shape[2:], mode="bilinear", align_corners=True
            )
        s = self.ag(x, s)
        x = torch.cat([x, s], dim=1)
        x = self.c1(x)
        return x

class attention_unet(nn.Module):
    """2D attention U-Net with three encoder stages, a bottleneck and three
    attention-gated decoder stages.

    Args:
        in_channels: number of channels of the input image. Defaults to 3,
            the value that was previously hard-coded.
        out_channels: number of channels of the output map. Defaults to 1,
            the value that was previously hard-coded.

    ``forward`` returns the output of the final 1x1 convolution; no
    activation is applied to it here. The encoder pools by 2 three times and
    the decoder upsamples by 2 three times, so input heights/widths that are
    divisible by 8 round-trip to the same size.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()

        # Encoder: each stage returns (skip, pooled).
        self.e1 = encoder_block(in_channels, 64)
        self.e2 = encoder_block(64, 128)
        self.e3 = encoder_block(128, 256)

        # Bottleneck.
        self.b1 = conv_block(256, 512)

        # Decoder: [channels of incoming tensor, channels of skip], out channels.
        self.d1 = decoder_block([512, 256], 256)
        self.d2 = decoder_block([256, 128], 128)
        self.d3 = decoder_block([128, 64], 64)

        self.output = nn.Conv2d(64, out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        s1, p1 = self.e1(x)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)

        b1 = self.b1(p3)

        # Skips are consumed deepest-first.
        d1 = self.d1(b1, s3)
        d2 = self.d2(d1, s2)
        d3 = self.d3(d2, s1)

        output = self.output(d3)
        return output

class TverskyLoss(nn.Module):
    """Tversky loss computed on sigmoid probabilities over all elements.

    Args:
        alpha: weight applied to the (soft) false positives.
        beta: weight applied to the (soft) false negatives.

    ``forward(input, target)`` applies a sigmoid to ``input`` (so it expects
    raw logits), flattens both tensors, and returns
    ``1 - (TP + smooth) / (TP + alpha*FP + beta*FN + smooth)`` as a scalar
    tensor. ``target`` must have the same number of elements as ``input``;
    boolean and integer masks are accepted.
    """

    def __init__(self, alpha=0.5, beta=0.5):
        super().__init__()
        self.alpha = alpha
        self.beta = beta

    def forward(self, input, target):
        smooth = 1e-5

        # reshape (not view) so non-contiguous tensors, e.g. after a
        # transpose/permute or slicing, can be flattened too.
        probs = torch.sigmoid(input).reshape(-1)
        target = target.reshape(-1)
        # `1 - target` is not defined for bool tensors; cast masks that are
        # not already floating point to the dtype of the probabilities.
        if not target.is_floating_point():
            target = target.to(probs.dtype)

        # True Positives, False Positives & False Negatives
        TP = (probs * target).sum()
        FP = ((1 - target) * probs).sum()
        FN = (target * (1 - probs)).sum()

        Tversky = (TP + smooth) / (TP + self.alpha * FP + self.beta * FN + smooth)

        return 1 - Tversky

if __name__ == "__main__":
    # Smoke test: push a random batch of eight 3-channel 256x256 images
    # through a freshly built network and show the shape that comes out.
    x = torch.randn(8, 3, 256, 256)
    model = attention_unet()
    output = model(x)
    print(output.shape)

# Training

In [None]:
# Results folder on Drive for this run; passed to train() as its last argument.
model_dir = '/content/drive/MyDrive/Machine-Learning-Biomedicine/PankVision-3D/results/dataset-007/v6 (dynunet)'

In [1]:
# Train on the first GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Options understood by the project's get_model() factory.
args = dict(
    model_name='DynUNet',
    pretrained=True,
    dropout=0.1,
)
model = get_model(args).to(device)

# Loss: Dice over softmax outputs with one-hot encoded labels.
# Alternatives that were tried instead of DiceLoss:
#   nn.CrossEntropyLoss().to(device)
#   TverskyLoss(alpha=0.7, beta=0.3).to(device)
loss_function = DiceLoss(to_onehot_y=True, softmax=True).to(device)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-5,
    amsgrad=True,
)

if __name__ == '__main__':
    # NOTE(review): 150 is presumably the number of epochs — confirm against
    # the signature of train() in final_utilities.py.
    train(model, data_in, loss_function, optimizer, 150, model_dir)

NameError: ignored

## Check the Size of Data (Add to main) - For Debugging

In [None]:
# assuming prepare function has been imported from another file
train_loader, test_loader = prepare(data_dir)

for batch_data in  train_loader:
  volume = batch_data["vol"]
  label = batch_data["seg"]

  # print size of volume and label tensors
  print(f"Volume size: {volume.size()}")
  print(f"Label size: {label.size()}")

  # only process one batch
  break