In [None]:
# Torch modules
import torch
from torch import nn
import torch.utils.data
import torchvision
import torchvision.transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data import Dataset

# Matplotlib modules
import matplotlib.pyplot as plt
import matplotlib.image as img

# numpy and pandas
import numpy as np
import pandas as pd

# Common python modules
import os
import re
import datetime
import sys
import time
from tqdm import tqdm
from collections import OrderedDict


In [None]:
# Download the challenge data: clones the repository into ./MVA-Dose-Prediction,
# whose train/, validation/ and test/ sub-folders are read by the next cells.
# NOTE: `git clone` refuses to clone into an existing non-empty folder, so
# re-running this cell only prints a git error once the data is present.
!git clone https://github.com/soniamartinot/MVA-Dose-Prediction.git

In [None]:
# One folder per split inside the cloned repository (trailing '/' kept).
train_dir, test_dir, val_dir = (
    './MVA-Dose-Prediction/train/',
    './MVA-Dose-Prediction/test/',
    './MVA-Dose-Prediction/validation/',
)

In [None]:
# Creation of the train, test and validation dataset

def _sorted_sample_names(data_path):
    """List every entry of ``data_path`` ordered by the first integer in its name.

    Sorting is numeric (``sample_2`` before ``sample_10``), not lexicographic.

    Raises:
        ValueError: an entry's name contains no digits, so it cannot be ordered.
    """
    def sample_number(name):
        match = re.search(r'\d+', name)
        if match is None:
            raise ValueError(
                'Cannot order entry {!r} of {!r}: its name contains no number'
                .format(name, data_path))
        return int(match.group())

    return sorted(os.listdir(data_path), key=sample_number)


def _load_model_input(sample_path):
    """Build the network input tensor for one sample folder.

    Loads ``ct.npy``, ``possible_dose_mask.npy`` and ``structure_masks.npy``,
    multiplies the CT by the possible-dose mask, replaces each structure mask
    by (masked CT x that mask), and concatenates
    ``[ct, structure_0, ..., structure_{K-1}, possible_dose_mask]`` along a new
    leading channel axis.

    Returns:
        float32 tensor of shape ``(K + 2, *ct.shape)`` (12 channels for K = 10).
    """
    ct_scan_np = np.load(os.path.join(sample_path, 'ct.npy'))
    possible_dose_mask_np = np.load(os.path.join(sample_path, 'possible_dose_mask.npy'))
    structure_masks_np = np.load(os.path.join(sample_path, 'structure_masks.npy'))

    # Apply the possible-dose mask to the CT scan.
    ct_scan_np = np.multiply(ct_scan_np, possible_dose_mask_np)

    # Apply every structure mask to the masked CT in one broadcast product.
    # Writing the product back into ``structure_masks_np`` slice by slice would
    # cast it to the masks' stored dtype (bool/uint8 masks would truncate the
    # CT values); a new array keeps numpy's promoted dtype and handles any
    # number of masks rather than exactly 10.
    structure_masks_np = np.multiply(structure_masks_np, ct_scan_np[np.newaxis])

    ct_scan = torch.from_numpy(ct_scan_np).float().unsqueeze(0)
    structure_masks = torch.from_numpy(structure_masks_np).float()
    possible_dose_mask = torch.from_numpy(possible_dose_mask_np).float().unsqueeze(0)
    return torch.cat((ct_scan, structure_masks, possible_dose_mask), 0)


class DoseDataset(torch.utils.data.Dataset):
    """Labelled samples: each item is ``(input, dose)``.

    Every entry of ``data_path`` is treated as a sample folder containing
    ``ct.npy``, ``possible_dose_mask.npy``, ``structure_masks.npy`` and
    ``dose.npy``. ``dose`` is returned with a leading channel axis of size 1.
    """

    def __init__(self, data_path):
        self.data_path = data_path
        # Sample folder names, ordered by the integer in their name.
        self.samples = _sorted_sample_names(data_path)

    def _sample_path(self, idx):
        """Path of the folder holding sample ``idx``."""
        return os.path.join(self.data_path, self.samples[idx])

    def __getitem__(self, idx):
        sample_path = self._sample_path(idx)
        inp = _load_model_input(sample_path)
        dose = torch.from_numpy(np.load(os.path.join(sample_path, 'dose.npy'))).float().unsqueeze(0)
        return inp, dose

    def __len__(self):
        return len(self.samples)


class DoseDatasetTest(DoseDataset):
    """Unlabelled variant for the test split: each item is the input tensor only.

    Sample folders do not need a ``dose.npy`` file.
    """

    def __getitem__(self, idx):
        return _load_model_input(self._sample_path(idx))

In [None]:
# Create dataloaders
batch_size = 16

# Training: reshuffle every epoch and drop the last incomplete batch so every
# optimisation step sees a full batch.
train_dataset = DoseDataset(train_dir)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

# Validation: keep every sample (drop_last=False). With drop_last=True up to
# batch_size - 1 validation samples were silently ignored, and a split smaller
# than batch_size produced no batches at all.
val_dataset = DoseDataset(val_dir)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

# Test set for the Codalab submission: one sample per batch, in sorted order.
test_dataset = DoseDatasetTest(test_dir)
dataset_test = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [None]:
class Squeeze_Excite(nn.Module):
    """Squeeze-and-Excitation channel gate.

    Global-average-pools each channel and passes the resulting channel vector
    through a bias-free bottleneck MLP (``channel -> hidden -> channel``, ReLU
    then sigmoid). Every channel of the input is then rescaled by its gate in (0, 1).

    Args:
        channel: number of input channels ``C``.
        reduction: bottleneck reduction ratio; ``hidden = channel // reduction``,
            clamped to at least 1. Without the clamp, ``channel < reduction``
            builds a zero-width bottleneck whose output is always 0, making
            every gate a constant sigmoid(0) = 0.5.
    """

    def __init__(self, channel, reduction):
        super().__init__()
        hidden = max(1, channel // reduction)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.lin1 = nn.Linear(channel, hidden, bias=False)
        self.relu = nn.ReLU(inplace=True)   # inplace = True decreases memory usage
        self.lin2 = nn.Linear(hidden, channel, bias=False)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        """Rescale a 4-D input ``(B, C, H, W)`` channel-wise by its learned gates."""
        b, c, _, _ = x.size()
        out = self.avgpool(x).view(b, c)
        out = self.relu(self.lin1(out))
        out = self.sig(self.lin2(out)).view(b, c, 1, 1)
        return x * out.expand_as(x)  # broadcast the (B, C, 1, 1) gates over H and W
        

In [None]:
class VGGBlock(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages followed by a Squeeze-and-Excitation gate.

    Spatial size is preserved (padding=1); channels go
    ``in_channels -> middle_channels -> out_channels``.
    """

    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        same_size = dict(kernel_size=3, padding=1)  # 3x3 convolutions that keep H and W
        # Submodules are created in the same order as before, so parameters are
        # registered (and randomly initialised) in the same sequence.
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, middle_channels, **same_size)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, **same_size)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.SE = Squeeze_Excite(out_channels, reduction=8)

    def forward(self, x):
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = self.relu(norm(conv(x)))
        return self.SE(x)

In [None]:
class Unetplus(nn.Module):
    """UNet++ (nested U-Net) with 5 resolution levels and dense nested skip paths.

    Node ``conv{i}_{j}`` sits at resolution level ``i`` (0 = full resolution;
    each level is max-pooled by 2) and nesting depth ``j``. The depth-0 nodes
    form the encoder. A node ``(i, j)`` with ``j >= 1`` consumes every earlier
    node on its own level plus the x2-upsampled node ``(i + 1, j - 1)``. The
    output is a 1x1 convolution of the full-resolution node ``conv0_4``.
    Input H and W should be divisible by 16, otherwise the concatenated
    feature maps have mismatched spatial sizes.

    Args:
        input_channels: number of input channels (12 by default).
        output_size: number of output channels of the final 1x1 convolution.
        **kwargs: accepted and ignored.
    """

    def __init__(self, input_channels=12, output_size=1, **kwargs):
        super().__init__()
        widths = [64, 128, 256, 512, 1024]
        levels = len(widths)

        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Encoder column (depth 0): conv0_0 ... conv4_0.
        for level in range(levels):
            source = input_channels if level == 0 else widths[level - 1]
            setattr(self, 'conv{}_0'.format(level),
                    VGGBlock(source, widths[level], widths[level]))

        # Nested columns, created depth by depth in the same order as the
        # explicit definitions (conv0_1, conv1_1, conv2_1, conv3_1, conv0_2, ...)
        # so parameter registration and initialisation order are unchanged.
        for depth in range(1, levels):
            for level in range(levels - depth):
                fan_in = widths[level] * depth + widths[level + 1]
                setattr(self, 'conv{}_{}'.format(level, depth),
                        VGGBlock(fan_in, widths[level], widths[level]))

        self.final = nn.Conv2d(widths[0], output_size, kernel_size=1)

    def forward(self, input):
        node = {}
        node[0, 0] = self.conv0_0(input)
        for level in range(1, 5):
            encoder = getattr(self, 'conv{}_0'.format(level))
            node[level, 0] = encoder(self.pool(node[level - 1, 0]))
            # Walk back up the new anti-diagonal: (level-1, 1), ..., (0, level).
            for depth in range(1, level + 1):
                row = level - depth
                same_row = [node[row, k] for k in range(depth)]
                below = self.up(node[row + 1, depth - 1])
                block = getattr(self, 'conv{}_{}'.format(row, depth))
                node[row, depth] = block(torch.cat(same_row + [below], 1))
        return self.final(node[0, 4])

In [None]:
# Seed PyTorch before building the network so the random weight
# initialisation is reproducible across kernel restarts.
SEED = 0
torch.manual_seed(SEED)
model = Unetplus()
model.cuda()  # moves parameters to the GPU; last expression displays the architecture

In [None]:
# Number of parameters of our model

def count_parameters(model):
    """Return the total number of scalar entries in ``model``'s trainable parameters.

    Parameters with ``requires_grad == False`` are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Display the number of trainable parameters of the network.
count_parameters(model)

In [None]:
learning_rate = 1e-4  # Adam step size

# L1 loss with the default 'mean' reduction: mean absolute error between the
# predicted and the reference dose.
criterion = nn.L1Loss()
criterion.cuda()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
def print_summary(epoch, i, nb_batch, loss, batch_time,
                  average_loss, average_time, mode):
    '''
        Print one progress line for a training or validation batch.

        The line shows the mode tag, the epoch, and the batch position ``i``
        out of ``nb_batch``. It also shows the current and running-average
        loss and the current and running-average batch time, each with
        4 decimals.
    '''
    template = ('[{mode}] Epoch: [{epoch}][{i}/{n}]\t'
                'L1 Loss {loss:.4f} (Average {avg_loss:.4f}) \t'
                'Batch Time {b_time:.4f} (Average {avg_time:.4f}) \t')
    print(template.format(mode=mode, epoch=epoch, i=i, n=nb_batch,
                          loss=loss, avg_loss=average_loss,
                          b_time=batch_time, avg_time=average_time))

In [None]:
# Train the model
def train_loop(loader, model, criterion, optimizer, epoch, print_frequency=None):
    """Run one pass over ``loader`` and return the mean per-batch loss.

    When ``model.training`` is True, the optimizer takes a step after every
    batch. Otherwise the pass is evaluation-only, and the caller is expected
    to wrap it in ``torch.no_grad()``.

    Args:
        loader: iterable of ``(inp, dose)`` batches.
        model: network mapping ``inp`` to a predicted dose.
        criterion: loss called as ``criterion(pred_dose, dose)``; must return a scalar tensor.
        optimizer: optimizer over ``model``'s parameters (used only in training mode).
        epoch: zero-based epoch index; displayed as ``epoch + 1``.
        print_frequency: print a summary every this many batches. ``None``
            falls back to the notebook-level ``print_frequency`` global, or 20
            if that global is not defined.

    Returns:
        Mean of the per-batch loss values, or ``nan`` if ``loader`` yields no batches.
    """
    if print_frequency is None:
        # The original implementation read this global implicitly; keep that
        # behaviour when the caller does not pass a value.
        print_frequency = globals().get('print_frequency', 20)

    logging_mode = 'Train' if model.training else 'Val'

    # Send batches to wherever the model's parameters live (the GPU once
    # model.cuda() has run); a parameter-free model falls back to CUDA as before.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    loss_total, time_total, nb_seen = 0.0, 0.0, 0
    for i, (inp, dose) in enumerate(loader, 1):
        start = time.time()

        inp = inp.float().to(device)
        dose = dose.float().to(device)

        pred_dose = model(inp)
        loss = criterion(pred_dose, dose)

        if model.training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # .item() copies the loss to the host, which waits for queued GPU work,
        # so the measured time covers the whole batch, not just kernel launches.
        loss_value = loss.item()
        batch_time = time.time() - start

        nb_seen = i
        loss_total += loss_value
        time_total += batch_time

        if i % print_frequency == 0:
            print_summary(epoch + 1, i, len(loader), loss_value, batch_time,
                          loss_total / nb_seen, time_total / nb_seen, logging_mode)

    if nb_seen == 0:
        return float('nan')  # empty loader: nothing to average
    return loss_total / nb_seen

In [None]:
epochs = 40           # number of passes over the training set
print_frequency = 20  # train_loop prints a summary every this many batches
train_loss = []       # per-epoch value returned by train_loop on the training set
val_loss = []         # per-epoch value returned by train_loop on the validation set

In [None]:
for epoch in range(epochs):  # Training loop
    print('******** Epoch [{}/{}]  ********'.format(epoch + 1, epochs))
    model.train()
    print('Training')
    # Pass the current epoch index (not the total `epochs`) so the per-batch
    # summaries display the right epoch number.
    train_loss.append(train_loop(train_dataloader, model, criterion, optimizer, epoch))

    # Evaluate on the validation set: eval mode (BatchNorm uses its running
    # statistics) and no autograd graph.
    print('Validation')
    model.eval()
    with torch.no_grad():   # Disable gradient computation
        val_loss.append(train_loop(val_dataloader, model, criterion, optimizer, epoch))



In [None]:
# Generate the predictions for the test set and save one .npy file per sample.
results_dir = os.path.join('.', 'Unet++results')
os.makedirs(results_dir, exist_ok=True)

# NOTE(review): files are numbered 9000 + position in the sorted test set,
# which assumes test samples are numbered consecutively from 9000 — confirm
# against test_dataset.samples.
first_test_id = 9000

model.eval()           # BatchNorm must use running statistics at inference
with torch.no_grad():  # no autograd graph needed for inference
    for i, batch in enumerate(dataset_test):
        # dataset_test uses batch_size=1, so `batch` already has a leading
        # batch axis of size 1; keep output channel 0 of that single prediction.
        pred_dose = model(batch.float().cuda())[0, 0]
        np.save(os.path.join(results_dir, 'sample_' + str(first_test_id + i) + '.npy'),
                pred_dose.cpu().numpy())

In [None]:
fig, ax = plt.subplots()
ax.plot(range(1, len(train_loss) + 1), train_loss)
ax.set(title='Training loss per epoch', xlabel='Epoch', ylabel='Mean L1 loss');

In [None]:
fig, ax = plt.subplots()
ax.plot(range(1, len(val_loss) + 1), val_loss)
ax.set(title='Validation loss per epoch', xlabel='Epoch', ylabel='Mean L1 loss');