### Library imports

In [1]:
import os
import zipfile
from io import BytesIO
from pathlib import Path

import cv2
import numpy as np
import numpy as np
import pandas as pd
import pandas as pd
import PIL.Image
import sklearn.metrics
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt
import matplotlib.pyplot as plt
from scipy import ndimage as ndi
from skimage.filters import rank
from skimage.filters import sobel
from skimage.morphology import disk
from skimage.segmentation import watershed, felzenszwalb
from torch import Tensor
from tqdm import tqdm

### Datasets import 

In [2]:
from utils import get_data, plot_slice_seg, rand_index_dataset, prediction_to_df

In [3]:
# Load the datasets through the project helper utils.get_data.
# NOTE(review): judging by the names, X_train / X_test are presumably the
# input images and y_train the training targets — TODO confirm against utils.get_data.
X_train, X_test, y_train = get_data()

#### X_train collected ####
#### X_test collected ####
#### y_train collected ####


### First model: W-Net architecture
See the W-Net paper: https://arxiv.org/pdf/1711.08506.pdf

![W-Net architecture diagram](https://miro.medium.com/max/1029/1*FU2BbaWCShWLvf6QsNGXlA.png)

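Each block consists of two stacked convolution → batch normalization → ReLU stages, each followed by dropout. On the contracting (downsampling) path, consecutive blocks are separated by a 2×2 max-pooling layer. On the expanding (upsampling) path, they are separated by a stride-2 transposed convolution whose output is concatenated with the output of the matching encoder block (skip connection).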

In [28]:
class Block(nn.Module):
    """Two stacked (convolution -> batch norm -> ReLU -> dropout) stages.

    Parameters
    ----------
    in_filters : int
        Number of input channels.
    out_filters : int
        Number of output channels of both convolution stages.
    seperable : bool, default True
        If True, each 3x3 convolution is depthwise-separable (a per-channel
        3x3 conv followed by a 1x1 pointwise conv). Otherwise it is a
        standard 3x3 conv. The misspelled name is kept so existing calls
        keep working.
    dropout : float, default 0.3
        Dropout probability applied after each ReLU.

    Spatial size is preserved (3x3 kernels with padding=1).
    """

    def __init__(self, in_filters, out_filters, seperable=True, dropout=0.3):
        super().__init__()

        if seperable:
            # Depthwise conv (groups == channels) followed by a 1x1 pointwise conv.
            # The individual convs stay as attributes so state_dict keys are stable.
            self.depth_conv1 = nn.Conv2d(in_filters, in_filters, kernel_size=3, groups=in_filters, padding=1)
            self.point_conv1 = nn.Conv2d(in_filters, out_filters, kernel_size=1)
            self.conv1 = nn.Sequential(self.depth_conv1, self.point_conv1)

            self.depth_conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=3, groups=out_filters, padding=1)
            self.point_conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=1)
            self.conv2 = nn.Sequential(self.depth_conv2, self.point_conv2)
        else:
            self.conv1 = nn.Conv2d(in_filters, out_filters, kernel_size=3, padding=1)
            self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=3, padding=1)

        self.batchnorm1 = nn.BatchNorm2d(out_filters)
        self.batchnorm2 = nn.BatchNorm2d(out_filters)

        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        # ReLU already clips negatives; no separate clamp(0) is needed.
        x = self.dropout1(self.relu1(self.batchnorm1(self.conv1(x))))
        x = self.dropout2(self.relu2(self.batchnorm2(self.conv2(x))))
        return x


class _UNet(nn.Module):
    """Shared 5-level U-Net body used by both halves of the W-Net.

    Four (Block -> 2x2 max-pool) stages downsample the input. A bottleneck
    Block then runs at 1/16 resolution. Four (transposed conv -> concat skip
    -> Block) stages restore full resolution, and a final 1x1 conv maps
    ``filters`` channels to ``out_channels``. The first and last Blocks use
    standard convolutions; all others are depthwise-separable.

    Expects (N, C, H, W) input with H and W multiples of 16, because the four
    poolings must be exactly undone for the skip concatenations to line up.
    """

    def __init__(self, in_channels, out_channels, filters):
        super().__init__()

        # Contracting path. Attribute names and registration order define the
        # state_dict keys, so keep them unchanged.
        self.enc1 = Block(in_channels, filters, seperable=False)
        self.enc2 = Block(filters, 2 * filters)
        self.enc3 = Block(2 * filters, 4 * filters)
        self.enc4 = Block(4 * filters, 8 * filters)
        self.enc5 = Block(8 * filters, 16 * filters)  # bottleneck, 1/16 resolution

        # Expanding path: each transposed conv maps size n to
        # (n - 1) * 2 - 2 * 1 + 3 + 1 = 2n and halves the channels. The Block
        # after it sees twice as many channels because of the skip concat.
        self.up1 = nn.ConvTranspose2d(16 * filters, 8 * filters, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.dec1 = Block(16 * filters, 8 * filters)
        self.up2 = nn.ConvTranspose2d(8 * filters, 4 * filters, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.dec2 = Block(8 * filters, 4 * filters)
        self.up3 = nn.ConvTranspose2d(4 * filters, 2 * filters, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.dec3 = Block(4 * filters, 2 * filters)
        self.up4 = nn.ConvTranspose2d(2 * filters, filters, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.dec4 = Block(2 * filters, filters, seperable=False)

        self.final = nn.Conv2d(filters, out_channels, kernel_size=(1, 1))

    def forward(self, x):
        height, width = x.shape[-2:]
        if height % 16 or width % 16:
            # Otherwise torch.cat fails deep inside the decoder with an opaque message.
            raise ValueError(
                f"input height and width must be multiples of 16, got {height}x{width}"
            )

        # Downsampling: Block, then 2x2 max pooling before the next level.
        enc1 = self.enc1(x)
        enc2 = self.enc2(F.max_pool2d(enc1, (2, 2)))
        enc3 = self.enc3(F.max_pool2d(enc2, (2, 2)))
        enc4 = self.enc4(F.max_pool2d(enc3, (2, 2)))
        enc5 = self.enc5(F.max_pool2d(enc4, (2, 2)))

        # Upsampling with skip connections: concatenate the matching encoder
        # output with the upsampled features along the channel axis.
        dec1 = self.dec1(torch.cat([enc4, self.up1(enc5)], 1))
        dec2 = self.dec2(torch.cat([enc3, self.up2(dec1)], 1))
        dec3 = self.dec3(torch.cat([enc2, self.up3(dec2)], 1))
        dec4 = self.dec4(torch.cat([enc1, self.up4(dec3)], 1))
        return self.final(dec4)


class UEnc(_UNet):
    """W-Net encoder: ``in_channels``-channel image -> ``k`` per-pixel class logits."""

    def __init__(self, k, in_channels=3, filters=64):
        super().__init__(in_channels, k, filters)


class UDec(_UNet):
    """W-Net decoder: ``k``-channel segmentation map -> ``in_channels``-channel image."""

    def __init__(self, k, in_channels=3, filters=64):
        super().__init__(k, in_channels, filters)


class WNet(nn.Module):
    """Encoder U-Net followed by decoder U-Net (W-Net, arXiv:1711.08506).

    Parameters
    ----------
    k : int
        Number of segmentation classes (encoder output channels).
    filters : int, default 64
        Channel count at the first level; doubled at each level down.
    in_channels : int, default 3
        Channels of the input image; the decoder reconstructs this many.
    out_chans : int, default 3
        Unused; accepted only so existing calls keep working.

    ``forward(x)`` returns ``(enc, dec)``. ``enc`` holds the encoder's raw
    class logits, shape (N, k, H, W). ``dec`` is the decoder's reconstruction,
    shape (N, in_channels, H, W), computed from the softmax of those logits.
    """

    def __init__(self, k, filters=64, in_channels=3, out_chans=3):
        super().__init__()
        self.UEnc = UEnc(k, in_channels, filters)
        self.UDec = UDec(k, in_channels, filters)

    def forward(self, x):
        enc = self.UEnc(x)
        # The decoder consumes per-pixel class probabilities, not raw logits.
        dec = self.UDec(F.softmax(enc, 1))
        return enc, dec


In [30]:
# Smoke test: build each component, then push a small dummy batch through a
# narrow W-Net (filters=4) to check that the encoder/decoder shapes line up.
# The probe is a separate model so WNet1's BatchNorm statistics are untouched.
block1 = Block(3, 3)
enc1 = UEnc(10)
dec1 = UDec(10)
WNet1 = WNet(10)

with torch.no_grad():
    _probe_enc, _probe_dec = WNet(10, filters=4)(torch.zeros(2, 3, 32, 32))
assert _probe_enc.shape == (2, 10, 32, 32), _probe_enc.shape
assert _probe_dec.shape == (2, 3, 32, 32), _probe_dec.shape
del _probe_enc, _probe_dec

### Now, we have to create the loss functions 

### Wrap the datasets in PyTorch DataLoaders

In [33]:
# Mini-batch loaders over X_train / X_test. Only the inputs are wrapped (no
# y_train) — presumably because W-Net training is unsupervised; confirm.
# NOTE(review): assumes indexing X_train yields one sample the model can
# consume (channel layout / dtype) — TODO confirm against utils.get_data.
batch_size = 64
LOADER_SEED = 0  # seeds the shuffle order so Restart & Run All is reproducible

trainloader = torch.utils.data.DataLoader(
    X_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    generator=torch.Generator().manual_seed(LOADER_SEED),
)

testloader = torch.utils.data.DataLoader(
    X_test,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)


### Now, let's train the model 

In [38]:
import torch.optim as optim