<a href="https://colab.research.google.com/github/Marten-Verbree/Reproducing_OrganNet2.5D/blob/main/Rebuilding_OrganNet2_5D.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#imports
import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from torchsummary import summary
!pip install simpleitk=="2.0.0"
!pip install MedPy
from medpy.io import load
from scipy.ndimage import zoom



In [None]:
# Write the currently installed package versions to requirements.txt.
!pip freeze > requirements.txt

In [None]:
# Google Cloud identifiers interpolated into the gcloud/gsutil shell cells,
# plus authentication of this Colab session.
import uuid

from google.colab import auth

project_id = 'rebuildingorgannet25d'
bucket_name = 'rebuildingbucket'

auth.authenticate_user()

In [None]:
# Make `project_id` the active project for the gcloud/gsutil commands.
!gcloud config set project {project_id}

Updated property [core/project].


In [None]:
# Recursively copy the bucket's `_data` directory to /content/_data/ (`-m` runs the copy in parallel).
!gsutil -m cp -r gs://{bucket_name}/_data /content/_data/

Copying gs://rebuildingbucket/_data/HaN_MICCAI2015/processed/test_offsite/data_3D/0522c0555/img_resampled_0522c0555.mha...
Copying gs://rebuildingbucket/_data/HaN_MICCAI2015/processed/test_offsite/data_3D/0522c0555/mask_resampled_0522c0555.mha...
/ [0 files][    0.0 B/  1.5 GiB]                                                / [0/137 files][    0.0 B/  2.1 GiB]   0% Done                                  Copying gs://rebuildingbucket/_data/HaN_MICCAI2015/processed/test_offsite/data_3D/0522c0555/voxelinfo.json...
/ [0/137 files][    0.0 B/  2.1 GiB]   0% Done                                  Copying gs://rebuildingbucket/_data/HaN_MICCAI2015/processed/test_offsite/data_3D/0522c0576/img_resampled_0522c0576.mha...
/ [0/137 files][    0.0 B/  2.1 GiB]   0% Done                                  Copying gs://rebuildingbucket/_data/HaN_MICCAI2015/processed/test_offsite/data_3D/0522c0576/mask_resampled_0522c0576.mha...
/ [0/137 files][    0.0 B/  2.1 GiB]   0% Done                          

In [None]:
# !pip install zipfile36
# import zipfile
# with zipfile.ZipFile("file.zip","r") as zip_ref:
#     zip_ref.extractall("targetdir")

In [None]:
# Listing files for the train and off-site test splits; TorchDataset reads
# one volume path per line from each of them.
_PROCESSED_ROOT = "/content/_data/_data/HaN_MICCAI2015/processed"
_TRAIN_DIR = _PROCESSED_ROOT + "/train/data_3D"
_TEST_DIR = _PROCESSED_ROOT + "/test_offsite/data_3D"

imgs = _TRAIN_DIR + "/img_resampled.csv"
msks = _TRAIN_DIR + "/mask_resampled.csv"
t_imgs = _TEST_DIR + "/img_resampled.csv"
t_msks = _TEST_DIR + "/mask_resampled.csv"

# Array shape every loaded volume is brought to inside TorchDataset.__getitem__.
shape = (1, 800, 800, 232)

class TorchDataset(torch.utils.data.Dataset):
    """Dataset of (image, mask) volume pairs listed in two text files.

    Each listing file holds one volume path per line; entry ``i`` of ``x_data``
    and entry ``i`` of ``y_data`` form one sample. Volumes are read with
    ``medpy.io.load`` only when a sample is requested.
    """

    def __init__(self, x_data, y_data, number_of_slices = 64, target_shape = None):
        """
        Args:
            x_data (str): path of the text file listing the image volumes.
            y_data (str): path of the text file listing the mask volumes.
            number_of_slices (int): stored on the instance; not used by this class.
            target_shape (tuple, optional): 4-D shape every volume is fitted to
                before down-sampling. Defaults to the module-level ``shape``.
        """
        self.x_data = self._read_paths(x_data)
        self.number_of_slices = number_of_slices
        self.y_data = self._read_paths(y_data)
        self.target_shape = target_shape

    @staticmethod
    def _read_paths(listing_file):
        """Return the non-blank lines of ``listing_file`` without their newline."""
        # rstrip (rather than [:-1]) keeps the last character of a final line
        # that has no trailing newline; the context manager always closes the file.
        with open(listing_file, "r") as handle:
            return [line.rstrip("\n") for line in handle if line.strip()]

    @staticmethod
    def _fit_to_shape(volume, target_shape):
        """Zero-pad / crop ``volume`` per axis so that it has ``target_shape``.

        Leading singleton axes are prepended until the ranks agree. For a
        volume that already matches the target (up to those leading axes) this
        is a plain copy; otherwise voxels keep their coordinates instead of
        being re-wrapped from the flat buffer.
        """
        volume = np.asarray(volume)
        target_shape = tuple(target_shape)
        if volume.ndim > len(target_shape):
            raise ValueError(
                "volume has {} dimensions but the target shape {} has only {}".format(
                    volume.ndim, target_shape, len(target_shape)))
        volume = volume.reshape((1,) * (len(target_shape) - volume.ndim) + volume.shape)
        fitted = np.zeros(target_shape, dtype=volume.dtype)
        overlap = tuple(slice(0, min(have, want))
                        for have, want in zip(volume.shape, target_shape))
        fitted[overlap] = volume[overlap]
        return fitted

    def __len__(self):
        """Number of image entries read from the image listing file."""
        return len(self.x_data)

    def __getitem__(self, idx):
        """Load, normalise and down-sample one (image, mask) pair.

        Args:
            idx (int or 0D tensor): index of the sample.

        Returns:
            tuple: ``(x, y)`` numpy arrays. ``x`` is the image scaled to zero
            mean and unit standard deviation; both arrays are fitted to the
            target shape and then scaled by 0.5 along axes 1 and 2.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        x,_ = load(self.x_data[idx])
        y,_ = load(self.y_data[idx])

        # Normalization (a constant volume has std 0; leave it centred, not NaN)
        std = x.std()
        x = (x - x.mean()) / (std if std > 0 else 1.0)

        # Resizing
        target = self.target_shape if self.target_shape is not None else shape
        x = self._fit_to_shape(x, target)
        y = self._fit_to_shape(y, target)

        x = zoom(x, (1, 0.5, 0.5, 1))
        # Nearest-neighbour for the mask: spline interpolation would blend
        # neighbouring label values into values that are no label at all.
        # presumably the mask stores integer class ids -- TODO confirm
        y = zoom(y, (1, 0.5, 0.5, 1), order=0)

        return (x, y)
# One dataset per split, built from the listing files defined above.
training_dataset = TorchDataset(x_data=imgs, y_data=msks)
test_dataset = TorchDataset(x_data=t_imgs, y_data=t_msks)



In [None]:
# DataLoaders for both splits: one volume per batch, fixed order, loaded in the main process.
_loader_settings = {"batch_size": 1, "shuffle": False, "num_workers": 0}
train_loader = torch.utils.data.DataLoader(training_dataset, **_loader_settings)
test_loader = torch.utils.data.DataLoader(test_dataset, **_loader_settings)

In [None]:
class ConvBlock2D(nn.Module):
  """Two conv -> batch norm -> ReLU stages using (3, 3, 1) kernels.

  Both convolutions are ``nn.Conv3d`` layers whose kernel has extent 1 along
  the last axis, and both use ``padding='same'``.

  Args:
    in_channels: channel count of the input tensor.
    out_channels: channel count produced by both convolutions.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()
    in_plane_kernel = (3, 3, 1)
    # The paper does not say whether padding is used; 'same' padding is assumed.
    self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=in_plane_kernel, padding='same')
    self.relu1 = nn.ReLU()
    self.batchnorm = nn.BatchNorm3d(out_channels)
    # The paper does not state the widths clearly; out_channels -> out_channels is assumed.
    self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=in_plane_kernel, padding='same')
    self.relu2 = nn.ReLU()
    self.batchnorm2 = nn.BatchNorm3d(out_channels)

  def forward(self, x):
    """Run ``x`` through both stages and return the activated feature map."""
    stages = (
        self.conv1, self.batchnorm, self.relu1,
        self.conv2, self.batchnorm2, self.relu2,
    )
    for stage in stages:
      x = stage(x)
    return x


In [None]:
class ConvBlock3DResse(nn.Module):
  """Two 3x3x3 conv -> batch norm -> ReLU stages followed by a gated residual.

  The feature map is globally average-pooled, passed through
  Linear -> ReLU -> Linear -> Sigmoid, and the resulting gate rescales the
  features before they are added back: ``gate * features + features``.

  Args:
    in_channels: channel count of the input tensor.
    out_channels: channel count produced by both convolutions.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()
    self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding='same')
    self.relu1 = nn.ReLU()
    self.batchnorm1 = nn.BatchNorm3d(out_channels)
    self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding='same')
    self.relu2 = nn.ReLU()
    self.batchnorm2 = nn.BatchNorm3d(out_channels)
    self.globalpool = nn.AdaptiveAvgPool3d(output_size=1)
    # Created but not called in forward(); start_dim might have to change if it is used.
    self.flatten1 = nn.Flatten()
    # NOTE(review): both Linear(1, 1) layers act on the trailing size-1 axis
    # left by the pooling, so every channel goes through the same scalar
    # mapping and channels are not mixed -- confirm this is the intended gate.
    self.linear1 = nn.Linear(1, 1)
    self.relu3 = nn.ReLU()
    self.linear2 = nn.Linear(1, 1)
    self.sigmoid1 = nn.Sigmoid()

  def forward(self, x):
    """Return the gated-residual feature map computed from ``x``."""
    features = x
    for stage in (self.conv1, self.batchnorm1, self.relu1,
                  self.conv2, self.batchnorm2, self.relu2):
      features = stage(features)

    gate = self.globalpool(features)
    for stage in (self.linear1, self.relu3, self.linear2, self.sigmoid1):
      gate = stage(gate)

    return gate*features + features



In [None]:
class HybridDilatedConv3DResse(nn.Module):
  """One dilated 3x3x3 conv -> batch norm -> ReLU stage with a gated residual.

  The gate is computed from the globally average-pooled features via
  Linear -> ReLU -> Linear -> Sigmoid and applied as
  ``gate * features + features``.

  Args:
    in_channels: channel count of the input tensor.
    out_channels: channel count produced by the convolution.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()
    # Open question from the reproduction: the paper mentions different
    # dilation rates but not which ones; a single rate of 2 is used here.
    self.hdc = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding='same', dilation = 2)
    self.batchnorm1 = nn.BatchNorm3d(out_channels)
    self.relu1 = nn.ReLU()
    self.globalpool = nn.AdaptiveAvgPool3d(output_size=1)
    # Created but not called in forward().
    self.flatten1 = nn.Flatten()
    # NOTE(review): Linear(1, 1) acts on the trailing size-1 axis of the
    # pooled tensor, i.e. the same scalar mapping is applied to each channel.
    self.linear1 = nn.Linear(1, 1)
    self.relu2 = nn.ReLU()
    self.linear2 = nn.Linear(1, 1)
    self.sigmoid1 = nn.Sigmoid()

  def forward(self, x):
    """Return the gated-residual feature map computed from ``x``."""
    features = self.relu1(self.batchnorm1(self.hdc(x)))
    pooled = self.globalpool(features)
    gate = self.sigmoid1(self.linear2(self.relu2(self.linear1(pooled))))
    return gate*features + features


In [None]:
class Conv3Dfine(nn.Module):
  """Pointwise (1x1x1) convolution followed by batch norm and ReLU.

  Args:
    in_channels: channel count of the input tensor.
    out_channels: channel count of the output tensor.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()
    self.conv3D = nn.Conv3d(in_channels, out_channels, kernel_size=1, padding='same')
    self.relu = nn.ReLU()
    self.batchnorm = nn.BatchNorm3d(out_channels)

  def forward(self, x):
    """Return ``relu(batchnorm(conv3D(x)))``."""
    return self.relu(self.batchnorm(self.conv3D(x)))

Critique: The paper specifies the type of pooling for only one layer, so it is unclear which pooling is used elsewhere. We assume the same type is used throughout and, on the advice of our TA, use max pooling; it is unfortunate that we have to rely on advice rather than on the paper itself. It would also help if the paper stated the strides and padding.

In [None]:
# Seed PyTorch's global random number generator with a fixed value.
torch.manual_seed(1)

<torch._C.Generator at 0x7f36b630cb90>

In [None]:
class TorchCNN(nn.Module):
  """Encoder/decoder segmentation network assembled from the blocks above.

  Args:
    in_channels: channel count of the input volume.
    hidden_channels: sequence of channel widths; entries 0-4 are read. The
      concatenated skip connections feed layers declared with the next width
      up, so the wiring lines up when every width is twice the previous one
      (e.g. ``[16, 32, 64, 128, 256]``).
    out_features: channel count of the output.
  """

  def __init__(self, in_channels, hidden_channels, out_features):
    super().__init__()
    ch = hidden_channels

    # Down path
    self.conv2D1 = ConvBlock2D(in_channels, ch[0])
    self.pool1 = nn.MaxPool3d(kernel_size=(2,2,1), stride=(2,2,1))
    self.conv3D_coarse1 = ConvBlock3DResse(ch[0], ch[1])
    self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)
    self.conv3D_coarse2 = ConvBlock3DResse(ch[1], ch[2])

    # Dilated middle section
    self.hdc1 = HybridDilatedConv3DResse(ch[2], ch[3])
    self.hdc2 = HybridDilatedConv3DResse(ch[3], ch[4])
    self.conv3D_fine1 = Conv3Dfine(ch[4], ch[3])
    self.hdc3 = HybridDilatedConv3DResse(ch[4], ch[3])
    self.conv3D_fine2 = Conv3Dfine(ch[3], ch[2])

    # Up path
    self.conv3D_coarse3 = ConvBlock3DResse(ch[3], ch[2])
    self.transpose1 = nn.ConvTranspose3d(ch[2], ch[1], kernel_size=2, stride=2)
    self.conv3D_coarse4 = ConvBlock3DResse(ch[2], ch[1])
    self.transpose2 = nn.ConvTranspose3d(ch[1], ch[0], kernel_size=(2,2,1), stride=(2,2,1))
    self.conv2D2 = ConvBlock2D(ch[1], ch[1])
    self.conv3D_fine3 = Conv3Dfine(ch[1], out_features)

  def forward(self, x):
    """Return the network output for the 5-D input ``x``; skips are concatenated along dim 1."""
    stage1 = self.conv2D1(x)
    stage2 = self.conv3D_coarse1(self.pool1(stage1))
    stage3 = self.conv3D_coarse2(self.pool2(stage2))
    stage4 = self.hdc1(stage3)

    out = self.conv3D_fine1(self.hdc2(stage4))
    out = self.hdc3(torch.cat((stage4, out), dim=1))
    out = self.conv3D_fine2(out)

    out = self.conv3D_coarse3(torch.cat((stage3, out), dim=1))
    out = self.transpose1(out)
    out = self.conv3D_coarse4(torch.cat((stage2, out), dim=1))
    out = self.transpose2(out)
    out = self.conv2D2(torch.cat((stage1, out), dim=1))
    return self.conv3D_fine3(out)
    # Earlier note, kept for reference: F.adaptive_max_pool2d(x.unsqueeze(0), output_size=1)
    # was considered as a stand-in for global average pooling.

In [None]:
# in_channels = 1
# hidden_channels = [16, 32, 64, 128, 256]
# out_channels = 10 # for Miccai data set
# CNN = TorchCNN(in_channels, hidden_channels, out_channels)
# x = torch.randn((1, 1, 100, 100, 232))
# out = CNN(x)
# print(out.shape)

In [None]:
def dice_coef(y_true, y_pred, epsilon=1e-6, testing = False):
    """Compute the soft Sørensen-Dice coefficient (DSC) of two tensors.

        DSC = 2 * sum(A * B) / (sum(A) + sum(B) + epsilon)

    ref: https://github.com/shalabh147/Brain-Tumor-Segmentation-and-Survival-Prediction-using-Deep-Neural-Networks/blob/master/utils.py
    ref: https://gist.github.com/jeremyjordan/9ea3032a32909f71dd2ab35fe3bacc08#file-soft_dice_loss-py

    Args:
        y_true: ground-truth tensor of the OAR.
        y_pred: predicted tensor of the OAR, broadcastable against ``y_true``.
            With ``testing=True`` the sums run over dims 2, 3 and 4, so the
            tensors must be 5-D -- presumably (N, C, H, W, D), TODO confirm.
        epsilon: added to the denominator to avoid division by zero.
        testing: ``False`` returns one score over all elements; ``True``
            returns one score per entry of the first two dimensions.

    Returns:
        A 0-dim tensor (``testing=False``) or a tensor shaped like the first
        two broadcast dimensions (``testing=True``). The score is all zeros
        when ``y_true`` does not sum to a positive value.
    """
    has_foreground = torch.sum(y_true) > 0

    if testing:
        if not has_foreground:
            # Same layout as the populated case below, on the inputs' device.
            score_shape = torch.broadcast_shapes(y_true.shape, y_pred.shape)[:2]
            return y_pred.new_zeros(score_shape)
        spatial_dims = (2, 3, 4)
        dice_numerator = 2 * torch.sum(y_true * y_pred, dim=spatial_dims)
        dice_denominator = torch.sum(y_true + y_pred, dim=spatial_dims) + epsilon
        return dice_numerator / dice_denominator

    if not has_foreground:
        # A tensor (not a Python int) so callers always get the same type back.
        return y_pred.new_zeros(())
    dice_numerator = 2 * torch.sum(y_true * y_pred)
    # No extra scaling of the denominator: identical inputs must score 1.
    dice_denominator = torch.sum(y_true + y_pred) + epsilon
    return dice_numerator / dice_denominator

def dice_coef_loss(y_true, y_pred):
    """Dice loss: one minus the value returned by ``dice_coef(y_true, y_pred)``."""
    return 1 - dice_coef(y_true, y_pred)


def focal_loss(y_true, y_pred,  epsilon =1e-6):
    """ Computes the class-weighted focal loss.

            FL = mean(alpha_c * (1 - p_c)^gamma * (-y_c * ln(p_c + epsilon)))

        where ``p`` is the softmax of ``y_pred`` over the class dimension
        (dim 1), so ``y_pred`` is expected to hold raw scores.
        ref: https://arxiv.org/pdf/2109.12634.pdf

        Args:
            :param y_true: ground truth of the OAR with one channel per class
                in dim 1, broadcastable against ``y_pred``
            :param y_pred: predicted class scores, classes in dim 1; any
                batch size and spatial size
            :param epsilon: Used for numerical stability to avoid log(0)

        Returns:
            0-dim tensor: mean of the weighted focal term over all elements.

        Raises:
            ValueError: if ``y_pred`` has more classes than there are weights.
        """

    # Class weights assigned according to Chen et al. (2021)
    ALPHA = torch.tensor([0.5, 1.0, 4.0, 1.0, 4.0, 4.0, 1.0, 1.0, 3.0, 3.0])
    # Focal exponent: controls the down-weighting of easy examples
    GAMMA = 2

    num_classes = y_pred.shape[1]
    if num_classes > ALPHA.shape[0]:
        raise ValueError(
            "focal_loss has weights for {} classes but y_pred has {}".format(
                ALPHA.shape[0], num_classes))

    # Shape the weights as (1, C, 1, ..., 1) so they broadcast over batch and space.
    class_weights = ALPHA[:num_classes].to(device=y_pred.device, dtype=y_pred.dtype)
    class_weights = class_weights.view(1, num_classes, *([1] * (y_pred.dim() - 2)))

    # Probabilities must be normalised across classes, not across a spatial axis.
    probabilities = torch.softmax(y_pred, dim=1)
    cross_entropy = -y_true * torch.log(probabilities + epsilon)
    focal_term = class_weights * torch.pow(1 - probabilities, GAMMA) * cross_entropy

    return torch.mean(focal_term)


def final_loss(y_true, y_pred):
    """Total loss: ``focal_loss`` plus ``dice_coef_loss`` for the same pair of tensors."""
    focal_term = focal_loss(y_true, y_pred)
    dice_term = dice_coef_loss(y_true, y_pred)
    return focal_term + dice_term


In [None]:
def try_gpu():
    """
    Return torch.device('cuda:0') when CUDA is available, otherwise
    torch.device('cpu').
    """
    device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)

In [None]:
# Optimisation settings
learning_rate = 0.001
epochs = 6

# Histories filled in by the training loop
train_losses, train_accs, test_accs = [], [], []

# Network dimensions
in_channels = 1
hidden_channels = [16, 32, 64, 128, 256]
out_channels = 10 # for Miccai data set

# Build the model and make sure its parameters are float32.
CNN = TorchCNN(in_channels, hidden_channels, out_channels).float()

In [None]:


optimizer = torch.optim.Adam(CNN.parameters(), lr = learning_rate) # check what optimizer is used in paper
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)
device = try_gpu()
CNN.to(device)

# Every volume is processed as a PATCH_GRID x PATCH_GRID grid of in-plane
# patches of PATCH_SIZE x PATCH_SIZE voxels (full extent along the last axis).
PATCH_SIZE = 100
PATCH_GRID = 4
PATCHES_PER_VOLUME = PATCH_GRID * PATCH_GRID


def _iter_patches(x_volume, y_volume):
    """Yield matching (image, mask) patches of one batch, row by row."""
    for idx_h in range(PATCH_GRID):
        for idx_w in range(PATCH_GRID):
            rows = slice(idx_h * PATCH_SIZE, (idx_h + 1) * PATCH_SIZE)
            cols = slice(idx_w * PATCH_SIZE, (idx_w + 1) * PATCH_SIZE)
            yield x_volume[:, :, rows, cols, :], y_volume[:, :, rows, cols, :]


def _one_hot_mask(mask, num_classes):
    """Expand a single-channel label mask (N, 1, ...) to one channel per class (N, C, ...)."""
    if mask.shape[1] != 1:
        return mask  # already has one channel per class
    # presumably each mask voxel holds a class id in [0, num_classes) -- TODO confirm
    labels = mask[:, 0].round().long().clamp(0, num_classes - 1)
    return nn.functional.one_hot(labels, num_classes).movedim(-1, 1).to(mask.dtype)


for epoch in range(epochs):

    # Back to training mode: the evaluation pass at the end of the previous
    # epoch leaves the network in eval mode.
    CNN.train()

    # train_dsc / test_dsc hold, per class, the percentage Dice summed over the
    # patches of a volume and averaged over volumes. The histories keep that
    # patch sum because the checkpoint cell divides them by 16 when saving;
    # only the printed values are divided by PATCHES_PER_VOLUME here.
    train_dsc = torch.zeros((1, out_channels), device=device)
    # Training loop
    for x_batch, y_batch in train_loader:

        # Set to same device
        x_batch, y_batch = x_batch.to(device).float(), y_batch.to(device).float()
        for x_patch, y_patch in _iter_patches(x_batch, y_batch):
            # Set the gradients to zero
            optimizer.zero_grad()

            # Perform forward pass
            y_pred = CNN(x_patch)
            y_patch = _one_hot_mask(y_patch, y_pred.shape[1])

            # Compute the loss; both helpers take (y_true, y_pred)
            loss = final_loss(y_patch, y_pred)
            # The metric is bookkeeping only: keep it out of the autograd graph
            # so the graph of every patch can be freed after backward().
            with torch.no_grad():
                train_dsc += 100/len(train_loader)*dice_coef(y_patch, y_pred, testing=True).to(device)

            # Backward computation and update
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
            del y_pred, y_patch, loss  # release patch-sized tensors before the next forward pass

    scheduler.step()

    # Development of performance
    train_accs.append(train_dsc.tolist())
    CNN.eval()
    test_dsc = torch.zeros((1, out_channels), device=device)
    with torch.no_grad():
        for x_batch, y_batch in test_loader:
            x_batch, y_batch = x_batch.to(device).float(), y_batch.to(device).float()
            for x_patch, y_patch in _iter_patches(x_batch, y_batch):
                y_pred = CNN(x_patch)
                y_patch = _one_hot_mask(y_patch, y_pred.shape[1])
                test_dsc += 100/len(test_loader)*dice_coef(y_patch, y_pred, testing=True).to(device)
                del y_pred, y_patch

    test_accs.append(test_dsc.tolist())
    print('Dice score of test set: {:.00f}%'.format(test_dsc.mean().item() / PATCHES_PER_VOLUME))
    # Print performance
    print('Epoch: {:.0f}'.format(epoch+1))
    print('Dice score of train set: {:.00f}%'.format(train_dsc.mean().item() / PATCHES_PER_VOLUME))
    print('loss of train set:', train_losses[-1])
    print('')

Dice score of test set: 28%
Epoch: 1
Dice score of train set: 1015%
loss of train set: 0.9977800846099854

Dice score of test set: 0%
Epoch: 2
Dice score of train set: 195%
loss of train set: 1.0

Dice score of test set: 0%
Epoch: 3
Dice score of train set: 0%
loss of train set: 1.0

Dice score of test set: 0%
Epoch: 4
Dice score of train set: 0%
loss of train set: 1.0

Dice score of test set: 0%
Epoch: 5
Dice score of train set: 0%
loss of train set: 1.0



In [None]:
#training cycle:
# print('Epoch: {:.0f}'.format(epoch+1))
# print('Accuracy of train set: {:.00f}%'.format(train_dsc))
# # print('Accuracy of test set: {:.00f}%'.format(test_dsc))
# print('')

In [None]:
# Persist the trained weights and the training histories under PATH.
PATH = "/content/"


def _per_patch_average(history, patches_per_volume=16):
    """Divide every recorded Dice value by the number of patches summed into it.

    ``history`` is a (nested) Python list, which cannot be divided directly;
    the result is returned as a nested list again.
    """
    return (torch.tensor(history, dtype=torch.float32) / patches_per_volume).tolist()


torch.save(CNN.state_dict(), PATH + 'model_params')
torch.save(train_losses, PATH +'training_loss')
torch.save(_per_patch_average(train_accs), PATH + 'training_dscs')
torch.save(_per_patch_average(test_accs), PATH + 'test_dscs')

In [None]:
# Rebuild the network, restore the saved weights and score it on the test set.
in_channels = 1
hidden_channels = [16, 32, 64, 128, 256]
out_channels = 10 # for Miccai data set
Test_CNN = TorchCNN(in_channels, hidden_channels, out_channels)
# Cast the newly built model itself, so the weights loaded below are the ones evaluated.
Test_CNN = Test_CNN.float()
# The weights are saved as PATH + 'model_params'; PATH on its own is a directory.
Test_CNN.load_state_dict(torch.load(PATH + 'model_params', map_location=device))
# The batches are moved to `device`, so the model has to live there as well.
Test_CNN.to(device)
Test_CNN.eval()
test_dsc = 0
with torch.no_grad():
  for x_batch, y_batch in test_loader:
    x_batch, y_batch = x_batch.to(device).float(), y_batch.to(device).float()

    y_pred = Test_CNN(x_batch)
    del x_batch
    if y_batch.shape[1] == 1:
      # presumably each mask voxel holds a class id in [0, C) -- TODO confirm.
      # Expand it to one channel per class so it lines up with y_pred.
      labels = y_batch[:, 0].round().long().clamp(0, y_pred.shape[1] - 1)
      y_batch = nn.functional.one_hot(labels, y_pred.shape[1]).movedim(-1, 1).float()
      del labels
    # dice_coef takes (y_true, y_pred)
    test_dsc += 100/len(test_loader)*dice_coef(y_batch, y_pred) #change to dice score
    del y_batch
    del y_pred
print('Dice score of test set: {:.00f}%'.format(test_dsc))

In [None]:
# Read a stored loss history back from disk and show it.
_LOSS_HISTORY_FILE = '/content/drive/MyDrive/_model/training_loss'
train_losses = torch.load(_LOSS_HISTORY_FILE)
print(train_losses)