### Librairies import

In [1]:
import pandas as pd
import numpy as np
import os
from matplotlib import pyplot as plt

import torch
from torch import Tensor
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary
import torch.nn.functional as F

### Datasets import 

In [2]:
from utils import get_data, plot_slice_seg, rand_index_dataset, prediction_to_df

In [None]:
X_train, X_test, y_train = get_data()

#### X_train collected ####
#### X_test collected ####


### U-Net

In [56]:
class Block(nn.Module):
    def __init__(self, in_filters, out_filters):
        
        super(Block, self).__init__()
            
        self.conv1 = nn.Conv2d(in_filters, out_filters, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=3, padding=1)
    
        self.batchnorm1 = nn.BatchNorm2d(out_filters)
        self.batchnorm2 = nn.BatchNorm2d(out_filters)
        
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        
        self.dropout1 = nn.Dropout(0.3)
        self.dropout2 = nn.Dropout(0.3)

    def forward(self, x):
        
        x = self.batchnorm1(self.conv1(x)).clamp(0)
        x = self.relu1(x)
        x = self.dropout1(x)
 
        x = self.batchnorm2(self.conv2(x)).clamp(0)
        x = self.relu2(x)
        x = self.dropout2(x)
        
        return x

In [65]:
class Unet(nn.Module):

    def __init__(self, k, in_channels, filters):

        super(Unet, self).__init__()
        
        self.enc1 = Block(in_channels, filters)
        self.enc2 = Block(filters, 2*filters)
        self.enc3 = Block(2*filters, 4*filters)
        self.enc4 = Block(4*filters, 8*filters)   
        self.enc5 = Block(8*filters, 16*filters) #minimal dimension space
        
# ---- We define the transpose convolution, that does up pooling 2x2 (inverse of maxpooling)

        self.up1 = nn.ConvTranspose2d(16*filters, 8*filters, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.dec1 = Block(16*filters, 8*filters)
        self.up2 = nn.ConvTranspose2d(8*filters, 4*filters, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.dec2 = Block(8*filters, 4*filters)
        self.up3 = nn.ConvTranspose2d(4*filters, 2*filters, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.dec3 = Block(4*filters, 2*filters)
        self.up4 = nn.ConvTranspose2d(2*filters, filters, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.dec4 = Block(2*filters, filters)
        
        self.output = nn.Conv2d(filters, k, kernel_size=(1, 1))

    def forward(self, x):

# ---- We define the down sampling phase : block/max pooling

        enc1 = self.enc1(x)
        maxpool1 = F.max_pool2d(enc1, (2, 2))   
        enc2 = self.enc2(maxpool1)
        maxpool2 = F.max_pool2d(enc2, (2, 2))    
        enc3 = self.enc3(maxpool2)
        maxpool3 = F.max_pool2d(enc3, (2, 2))    
        enc4 = self.enc4(maxpool3)
        maxpool4 = F.max_pool2d(enc4, (2, 2))      
        enc5=self.enc5(maxpool4)   
        
# ---- We define the up sampling phase : transposed convolution(= up pooling)/block
# ---- We define the skip connections

        up1 = torch.cat([enc4, self.up1(enc5)], 1) # skip connection : concatenate previous layer output + enc4 output
        dec1 = self.dec1(up1)   
        up2 = torch.cat([enc3, self.up2(dec1)], 1)
        dec2 = self.dec2(up2)   
        up3 = torch.cat([enc2, self.up3(dec2)], 1)
        dec3 = self.dec3(up3)  
        up4 = torch.cat([enc1, self.up4(dec3)], 1)
        dec4 = self.dec4(up4)       
        output = self.output(dec4)
            
        return output

### Model instanciation

In [61]:
Unet_model = Unet(k=10, in_channels=1, filters=64)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(Unet_model.parameters(), lr=1e-5, weight_decay=1e-8, foreach=True)

In [66]:
summary(Unet_model, (1, 512, 512), 64)

RuntimeError: [enforce fail at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\c10\core\impl\alloc_cpu.cpp:72] data. DefaultCPUAllocator: not enough memory: you tried to allocate 134217728 bytes.

### Convert our datasets with pytorch trainloader 

In [33]:
batch_size = 64

trainloader = torch.utils.data.DataLoader(X_train, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testloader = torch.utils.data.DataLoader(X_test, batch_size=batch_size,
                                         shuffle=False, num_workers=2)


### Now, let's train the model 

In [38]:
import torch.optim as optim