In [59]:
import os
from glob import glob
import torch
import torch_em
from torch_em.model import UNet3d
from torch.utils.data import Dataset, DataLoader
import random

In [60]:
%matplotlib inline
%load_ext tensorboard
import os
import imageio
import shutil
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from scipy.ndimage import binary_erosion
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
from torch.nn import functional as F
from torchvision import transforms
from tensorflow import keras


The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [118]:
#train_images = glob(os.path.join('./rescaled' , 'train' , 'images' , '*.tif'))
#train_labels = glob(os.path.join('./rescaled' , 'train' , 'labels' , '*.tif'))

In [19]:
dataset = './rescaled'
train_images_path = os.path.join(dataset , 'train' , 'images')
val_images_path = os.path.join(dataset , 'validation' , 'images')
test_images_path = os.path.join(dataset , 'test' , 'images')


In [20]:
class DrosophilaDataset(Dataset): #Dataset from torch
    def __init__(self , root_dir , transform=None):
        self.root_dir = root_dir
        self.samples = os.listdir(root_dir)
        self.transform = transform
        #self.inp_transforms = transforms.Compose([transforms.Grayscale(), # some of the images are RGBtransforms.ToTensor(),transforms.Normalize([0.5], [0.5])])
        #self.inp_transforms = transforms.ToTensor()
        #self.mask_transforms = transforms.ToTensor()

        
    def __len__(self):
        return len(self.samples)
        
    def __getitem__(self , idx):
        im_name = self.samples[idx]
        im_path = os.path.join(self.root_dir , im_name)
        image = imageio.volread(im_path)
        #image = np.expand_dims(image,axis=0)
        image = torch.Tensor(image) # Transform image to tensor
        
        # Look at masks
        labels_dir = self.root_dir.replace('images' , 'labels')
        label_name = im_name.replace('Rec' , 'Rec_labels')
        label_path = os.path.join(labels_dir , label_name)
        
        label = imageio.volread(label_path)
        label = torch.Tensor(label)
        #label = self.mask_transforms(label)
        label[label == 4] = 3 #converts ovaries to same label
        
        image = torch.unsqueeze(image, dim=0)
        #label = torch.unsqueeze(label, dim=0)

    
        return image , label
    

In [None]:
def show_random_dataset_image(dataset):
    idx = np.random.randint(0, len(dataset))    # take a random sample
    img, label = dataset[idx]                    # get the image and the nuclei masks
    f, axarr = plt.subplots(1, 2)               # make two plots on one figure
    axarr[0].imshow(img[0])                     # show the image
    axarr[1].imshow(label[0])                    # show the masks
    _ = [ax.axis('off') for ax in axarr]        # remove the axes
    print('Image size is %s' % {img[0].shape})
    plt.show()
    
#show_random_dataset_image(train_data)

In [21]:
train_data = DrosophilaDataset(root_dir = train_images_path)
train_loader = DataLoader(train_data, batch_size=1, shuffle=True)

val_data = DrosophilaDataset(root_dir = val_images_path)
val_loader = DataLoader(val_data, batch_size=1)

test_data = DrosophilaDataset(root_dir = test_images_path)
test_loader = DataLoader(test_data, batch_size=1)

# U-NET

In [22]:
class UNet(nn.Module):
    """ UNet implementation
    Arguments:
      in_channels: number of input channels
      out_channels: number of output channels
      final_activation: activation applied to the network output
    """
   
    def _conv_block(self, in_channels, out_channels):
        return nn.Sequential(nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
                             nn.ReLU(),
                             nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
                             nn.ReLU())       


    # upsampling via transposed 3d convolutions
    def _upsampler(self, in_channels, out_channels):
        return nn.ConvTranspose3d(in_channels, out_channels,
                                kernel_size=2, stride=2)
    
    def __init__(self, in_channels=1, out_channels=5, 
                 final_activation=None):
        super().__init__()
    
        self.depth = 4

        # the final activation must either be None or a Module
        if final_activation is not None:
            assert isinstance(final_activation, nn.Module), "Activation must be torch module"
        
        # all lists of conv layers (or other nn.Modules with parameters) must be wraped
        # itnto a nn.ModuleList
        
        # modules of the encoder path
        self.encoder = nn.ModuleList([self._conv_block(in_channels, 16),
                                      self._conv_block(16, 32),
                                      self._conv_block(32, 64),
                                      self._conv_block(64, 128)])
        # the base convolution block
        self.base = self._conv_block(128, 256)
        # modules of the decoder path
        self.decoder = nn.ModuleList([self._conv_block(256, 128),
                                      self._conv_block(128, 64),
                                      self._conv_block(64, 32),
                                      self._conv_block(32, 16)])
        
        # the pooling layers; we use 2x2 MaxPooling
        self.poolers = nn.ModuleList([nn.MaxPool3d(2) for _ in range(self.depth)])
        # the upsampling layers
        self.upsamplers = nn.ModuleList([self._upsampler(256, 128),
                                         self._upsampler(128, 64),
                                         self._upsampler(64, 32),
                                         self._upsampler(32, 16)])
        # output conv and activation
        # the output conv is not followed by a non-linearity, because we apply
        # activation afterwards
        self.out_conv = nn.Conv3d(16, out_channels, 1)
        self.activation = final_activation
    
    def forward(self, input):
        x = input
        # apply encoder path
        encoder_out = []
        for level in range(self.depth):
            x = self.encoder[level](x)
            encoder_out.append(x)
            x = self.poolers[level](x)

        # apply base
        x = self.base(x)
        
        # apply decoder path
        encoder_out = encoder_out[::-1]
        for level in range(self.depth):
            x = self.upsamplers[level](x)
            x = self.decoder[level](torch.cat((x, encoder_out[level]), dim=1))
        
        # apply output conv and activation (if given)
        x = self.out_conv(x)
        if self.activation is not None:
            x = self.activation(x)
        return x


In [32]:
# check if we have  a gpu
if torch.cuda.is_available():
    print("GPU is available")
    device = torch.device("cuda")
else:
    print("GPU is not available")
    device = torch.device("cpu")
    
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')
#

GPU is available
Using cuda device


In [33]:
model = UNet(in_channels = 1 , out_channels=4)
model.to(device)

UNet(
  (encoder): ModuleList(
    (0): Sequential(
      (0): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): ReLU()
      (2): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (3): ReLU()
    )
    (1): Sequential(
      (0): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): ReLU()
      (2): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (3): ReLU()
    )
    (2): Sequential(
      (0): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): ReLU()
      (2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (3): ReLU()
    )
    (3): Sequential(
      (0): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): ReLU()
      (2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (3): ReLU()
    )
  )
  (base): Sequential(
 

# Train & Test Loop

In [112]:
def train_loop(dataloader, model, loss_fn, optimizer , epoch , tb_logger=tb_logger):
    size = len(dataloader.dataset)
    model.train() #set the model to train mode
    # iterate over the batches of this epoch
    for batch_id, (x, y) in enumerate(dataloader):
        # move input and target to the active device (either cpu or gpu)
        #X, y = X.to(device), y.to(device)
        
        # zero the gradients for this iteration
        #optimizer.zero_grad()
        # apply model, calculate loss and run backwards pass
        y = y.long()
        x, y = x.to(device), y.to(device)

        
        prediction = model(x) # apply model
        loss = loss_fn(prediction , y) # calculate loss
        #loss = loss_fn(prediction , y[:0])#Expected target size [1, 128, 128, 128], got [1, 1, 128, 128, 128]
        loss.backward()
        optimizer.step()
        
        # Log to console
        log_interval = 100
        if batch_id % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                  epoch, batch_id * len(x),
                  len(dataloader.dataset),
                  100. * batch_id / len(dataloader), loss.item()))
            
        # Log to tensorboard
        log_image_interval=20
        if tb_logger is not None:
            step = epoch * len(dataloader) + batch_id
            tb_logger.add_scalar(tag='train_loss', scalar_value=loss.item(), global_step=step)
            # check if we log images in this iteration
            if step % log_image_interval == 0:
                x_new = x[0][0][64]
                x_new = x_new[None , None ,:]
                tb_logger.add_images(tag='input', img_tensor = x_new.to('cpu'), global_step=step)
                #tb_logger.add_images(tag='target', img_tensor = y.to('cpu'), global_step=step)
                #tb_logger.add_images(tag='prediction', img_tensor=prediction.to('cpu').detach(), global_step=step)
        
def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    
    

def validate(model, loader, loss_function, metric,  step=None, tb_logger=tb_logger):
    model.eval
    # running loss and metric values
    val_loss = 0
    val_metric = 0
    
    # disable gradients during validation
    with torch.no_grad():
        # iterate over validation loader and update loss and metric values
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            prediction = model(x)
            y = y.long()
            val_loss += loss_function(prediction , y).item()
            val_metric += metric(prediction, y).item()

    
    # normalize loss and metric
    val_loss /= len(loader)
    val_metric /= len(loader)
    
    if tb_logger is not None:
        tb_logger.add_scalar(tag='val_loss', scalar_value=val_loss, global_step=step)
        tb_logger.add_scalar(tag='val_metric', scalar_value=val_metric, global_step=step)

        
    print('\nValidate: Average loss: {:.4f}, Average Metric: {:.4f}\n'.format(val_loss, val_metric))



In [113]:
class DiceCoefficient(nn.Module):
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps
        
    # the dice coefficient of two sets represented as vectors a, b ca be 
    # computed as (2 *|a b| / (a^2 + b^2))
    def forward(self, prediction, target):
        intersection = (prediction * target).sum()
        denominator = (prediction * prediction).sum() + (target * target).sum()
        return (2 * intersection / denominator.clamp(min=self.eps))

In [114]:
from torch.optim import Adam


loss_function = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=1.e-4) #and not e.-3
metric = DiceCoefficient()

n_epochs = 10 #Number of time the NN is iterate
for epoch in range(n_epochs):
    print(f"Epoch {epoch+1}\n-------------------------------")
    train_loop(train_loader, model, loss_function, optimizer , epoch , tb_logger=tb_logger)
    #test_loop(test_loader, model, loss_function)
    step = epoch * len(train_loader.dataset)
    validate(model, test_loader, loss_function, metric, step, tb_logger=tb_logger)
print("Done!")

Epoch 1
-------------------------------

Validate: Average loss: 0.3659, Average Metric: -0.0000

Epoch 2
-------------------------------

Validate: Average loss: 0.3471, Average Metric: -0.0000

Epoch 3
-------------------------------

Validate: Average loss: 0.3549, Average Metric: -0.0000

Epoch 4
-------------------------------

Validate: Average loss: 0.3415, Average Metric: -0.0000

Epoch 5
-------------------------------

Validate: Average loss: 0.3573, Average Metric: -0.0000

Epoch 6
-------------------------------

Validate: Average loss: 0.3270, Average Metric: -0.0000

Epoch 7
-------------------------------

Validate: Average loss: 0.3304, Average Metric: -0.0000

Epoch 8
-------------------------------

Validate: Average loss: 0.3871, Average Metric: -0.0000

Epoch 9
-------------------------------

Validate: Average loss: 0.3506, Average Metric: -0.0000

Epoch 10
-------------------------------

Validate: Average loss: 0.3438, Average Metric: -0.0000

Done!


In [115]:
#Save model
torch.save(UNet , "model.pth")
# Load model
#model = torch.load('model.pth')


In [117]:
#tb_logger = SummaryWriter('runs/Unet')

%tensorboard --logdir runs --port 6508

In [71]:
#!kill 6880