## BE UNet 
> Victor Ludvig, February 2024 <br>
> Prof. L. Chen

In [None]:
import torch 
import torch.nn as nn
from torch.nn import BCELoss
from torch.optim import Adam
import torchvision
from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image
from torchvision.transforms import ToTensor, Resize, Compose
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import copy
import matplotlib.pyplot as plt 
from torch.optim.lr_scheduler import MultiplicativeLR
# Root folder of the training split; images and masks are read from its
# 'images' and 'masks' sub-folders.
PATH_ROOT = './competition_data/competition_data/train'
PATH_IMAGES = os.path.join(PATH_ROOT, 'images')
PATH_MASKS = os.path.join(PATH_ROOT, 'masks')

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

%matplotlib widget

### Implementation of UNet

#### Computation of the transpose convolution parameters
> The transpose convolution layers upsample the image by doubling its width and height. <br>
> The downsampling is achieved using a MaxPool2D layer in the encoder with kernel k=2, stride s=2 and padding p=0. <br>
> To upsample with PyTorch, we can simply use the same parameters for the TransposeConvolution: k=2, s=2, p=0. <br>
> The stride s=2 adds one 0 between each pixel, so the convolution with a kernel of size 2 will effectively double the image width and height. <br>

> This result can be checked using the formula from the [Pytorch website](https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html):
$$H_{out} = (H_{in} - 1) \times stride - 2 \times padding + dilation \times (kernel\_size - 1) + output\_padding + 1$$

> With : 
>> $stride=2$ <br>
>> $padding = 0$ <br>
>> $dilation = 1$ <br>
>> $kernel\_size = 2$ <br>
>> $output\_padding = 0$ <br>
>> we get $H_{out} = 2(H_{in} - 1) + 1 + 1 = 2H_{in}$

In [None]:
class UNet(nn.Module):
    """U-Net style encoder/decoder producing a one-channel probability map.

    The encoder applies five double-convolution stages (16, 32, 64, 128 and
    256 channels); the first four are each followed by a 2x2 max-pooling that
    halves height and width. The decoder mirrors it: each of its four stages
    upsamples with a stride-2 transpose convolution, concatenates the encoder
    feature map of the same resolution along the channel axis (skip
    connection) and applies a double convolution. A final 1x1 convolution
    followed by a sigmoid yields one value in [0, 1] per pixel.

    Inputs are (batch, 3, H, W). Because of the four halvings followed by
    four doublings, H and W must be multiples of 16 for the skip
    concatenations to line up.
    """

    @staticmethod
    def _double_conv(n_in, n_out):
        """Return two (3x3 'same' convolution -> BatchNorm -> ReLU) stages."""
        return nn.Sequential(
            nn.Conv2d(in_channels=n_in, out_channels=n_out, kernel_size=3, padding='same'),
            nn.BatchNorm2d(n_out),
            nn.ReLU(),
            nn.Conv2d(in_channels=n_out, out_channels=n_out, kernel_size=3, padding='same'),
            nn.BatchNorm2d(n_out),
            nn.ReLU(),
        )

    def __init__(self):
        super().__init__()

        # Encoder: channels double at every stage, resolution halves after
        # each of the first four stages.
        self.conv1 = self._double_conv(3, 16)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = self._double_conv(16, 32)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = self._double_conv(32, 64)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv4 = self._double_conv(64, 128)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv5 = self._double_conv(128, 256)

        # Decoder: every transpose convolution halves the channel count; the
        # concatenated skip connection doubles it again, which is why each
        # decoder double convolution takes twice as many input channels as it
        # outputs.
        self.transpose_conv1 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
        self.conv6 = self._double_conv(256, 128)

        self.transpose_conv2 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)
        self.conv7 = self._double_conv(128, 64)

        self.transpose_conv3 = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=2, stride=2)
        self.conv8 = self._double_conv(64, 32)

        self.transpose_conv4 = nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=2, stride=2)
        self.conv9 = self._double_conv(32, 16)

        # 1x1 convolution mapping the 16 feature channels to a single logit.
        self.conv10 = nn.Conv2d(in_channels=16, out_channels=1, kernel_size=1)

    def forward(self, images):
        """Return per-pixel probabilities with the spatial size of ``images``."""
        # Encoder: keep every pre-pooling activation for the skip connections.
        skip1 = self.conv1(images)
        skip2 = self.conv2(self.pool1(skip1))
        skip3 = self.conv3(self.pool2(skip2))
        skip4 = self.conv4(self.pool3(skip3))
        bottleneck = self.conv5(self.pool4(skip4))

        # Decoder: upsample, concatenate along the channel axis (dim 1, dim 0
        # being the batch), then convolve.
        merged = torch.cat((self.transpose_conv1(bottleneck), skip4), dim=1)
        decoded = self.conv6(merged)

        merged = torch.cat((self.transpose_conv2(decoded), skip3), dim=1)
        decoded = self.conv7(merged)

        merged = torch.cat((self.transpose_conv3(decoded), skip2), dim=1)
        decoded = self.conv8(merged)

        merged = torch.cat((self.transpose_conv4(decoded), skip1), dim=1)
        decoded = self.conv9(merged)

        return torch.sigmoid(self.conv10(decoded))

In [None]:
# NOTE: this rebinds the name `UNet` from the class to a model instance, so
# from here on `UNet` refers to the instantiated network.
UNet = UNet()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Move the model parameters to the selected device.
UNet.to(device)

#### Training

> The learning rate is divided by 10 when 10 consecutive epochs don't improve the validation loss <br>
> It works very well, and the threshold could be increased to 15 epochs without improvement.

> Creation of custom Dataset class to load data <br>
> Some images are loaded by Pillow with an alpha channel with all values to 255. <br>
The alpha channel is discarded using im = im.convert("RGB")<br>

In [None]:
class ImageDataset(Dataset):
    """Dataset of (image, mask) pairs read from two parallel lists of paths.

    ``list_images[i]`` and ``list_masks[i]`` must refer to the same sample.
    Images are converted to RGB (dropping any alpha channel) and masks to
    single-channel greyscale. When ``transform`` is given it is applied to
    the image and to the mask independently.
    """

    def __init__(self, list_images, list_masks, transform=None):
        self.list_images, self.list_masks = list_images, list_masks
        self.transform = transform

    def __len__(self):
        return len(self.list_images)

    def _load(self, path, mode):
        """Open ``path``, convert it to PIL mode ``mode`` and transform it."""
        with Image.open(path) as handle:
            picture = handle.convert(mode)
        if self.transform:
            picture = self.transform(picture)
        return picture

    def __getitem__(self, idx):
        image = self._load(self.list_images[idx], "RGB")
        mask = self._load(self.list_masks[idx], 'L')
        return image, mask

#### Divide train/val

> The train images are divided into train/val images <br>
> 20% of the images are used as validation images

In [None]:
# os.listdir returns entries in arbitrary order, and images and masks are
# paired by position, so both listings are sorted to guarantee that entry i of
# each list refers to the same sample (this also makes the split reproducible).
image_names = sorted(os.listdir(PATH_IMAGES))
mask_names = sorted(os.listdir(PATH_MASKS))
if [os.path.splitext(name)[0] for name in image_names] != [os.path.splitext(name)[0] for name in mask_names]:
    raise ValueError('Images and masks do not match one-to-one by file name')

list_images = [os.path.join(PATH_IMAGES, x) for x in image_names]
list_masks = [os.path.join(PATH_MASKS, x) for x in mask_names]
# 80% train / 20% validation, with a fixed seed.
path_train_images, path_val_images, path_train_masks, path_val_masks = train_test_split(list_images, list_masks, test_size=0.2, random_state=0)

assert len(path_train_images) == len(path_train_masks)
assert len(path_val_images) == len(path_val_masks)
print(f'Number of training images: {len(path_train_images)}')
print(f'Number of validation images: {len(path_val_images)}')

In [None]:
def get_hyperparameters(model):
    """Build the training configuration for ``model``.

    Returns the tuple ``(N_EPOCHS, LR, BATCH_SIZE, optimizer, scheduler,
    criterion)``: 200 epochs, learning rate 1e-4, batch size 32, an Adam
    optimizer over ``model.parameters()``, a scheduler that multiplies the
    learning rate by 0.1 on each ``step()`` and a binary cross-entropy loss.
    """
    n_epochs, learning_rate, batch_size = 200, 0.0001, 32
    adam = Adam(model.parameters(), lr=learning_rate)
    # The factor is constant: the epoch index passed by the scheduler is ignored.
    lr_scheduler = MultiplicativeLR(adam, lr_lambda=lambda _epoch: .1)
    return n_epochs, learning_rate, batch_size, adam, lr_scheduler, BCELoss()

def get_data_loader(path_train_images, path_train_masks, path_val_images, path_val_masks, BATCH_SIZE):
    """Create the train/validation data loaders and report the split sizes.

    Every sample is converted to a tensor and resized to 128x128. Both
    loaders shuffle their data and load in the main process.

    Returns ``(dataloaders, dataset_sizes)``, two dicts keyed by ``'train'``
    and ``'val'`` holding the DataLoader and the number of samples.
    """
    splits = {
        'train': (path_train_images, path_train_masks),
        'val': (path_val_images, path_val_masks),
    }

    dataloaders = {}
    dataset_sizes = {}
    for split_name, (image_paths, mask_paths) in splits.items():
        preprocessing = Compose([ToTensor(), Resize((128,128))])
        split_dataset = ImageDataset(image_paths, mask_paths, preprocessing)
        dataloaders[split_name] = torch.utils.data.DataLoader(
            split_dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=0,
        )
        dataset_sizes[split_name] = len(image_paths)

    return dataloaders, dataset_sizes

In [None]:
# Training configuration (optimizer, scheduler, loss) bound to the UNet parameters.
N_EPOCHS, LR, BATCH_SIZE, optimizer, scheduler, criterion = get_hyperparameters(UNet)
# Train/validation loaders and split sizes for the chosen batch size.
dataloaders, dataset_sizes = get_data_loader(path_train_images, path_train_masks, path_val_images, path_val_masks, BATCH_SIZE)

In [None]:
def train(model, optimizer, criterion, dataloaders, N_EPOCHS, scheduler, patience_lr=10, patience_stop=20):
    """Train ``model`` and return it loaded with its best validation weights.

    Each epoch runs a training pass followed by a validation pass. The
    weights giving the lowest average validation loss are remembered and
    restored before returning, whether training runs for all epochs or stops
    early.

    Args:
        model: network to train; batches are moved to the device its
            parameters live on.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss called as ``criterion(outputs, masks)``.
        dataloaders: dict with ``'train'`` and ``'val'`` iterables yielding
            ``(inputs, masks)`` batches.
        N_EPOCHS: maximum number of epochs.
        scheduler: learning-rate scheduler; ``step()`` is called every time
            ``patience_lr`` epochs have passed without improvement since the
            last best epoch or the last scheduler step.
        patience_lr: epochs without validation improvement before a
            scheduler step.
        patience_stop: consecutive epochs without validation improvement
            before training stops early.

    Returns:
        ``(model, train_losses, val_losses)`` where the two lists hold the
        per-epoch average losses of the epochs that were actually run.
    """
    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    best_model_weights = copy.deepcopy(model.state_dict())
    # Epochs without improvement since the last best epoch / last LR decay.
    threshold_lr = 0
    # Consecutive epochs without improvement, used for early stopping.
    threshold_stop = 0
    # Run on the device the model already lives on rather than on a global.
    model_device = next(model.parameters()).device

    for epoch in tqdm(range(N_EPOCHS)):
        print(f'\n\nEpoch {epoch}/{N_EPOCHS}')
        print('-'*10)

        for phase in ['train', 'val']:
            print(f'Phase: {phase}')
            is_training = phase == 'train'
            if is_training:
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            n_samples = 0

            for inputs, masks in dataloaders[phase]:
                inputs = inputs.to(model_device)
                masks = masks.to(model_device)

                if is_training:
                    # Gradients accumulate across backward() calls, so they
                    # must be reset before every optimisation step.
                    optimizer.zero_grad()

                with torch.set_grad_enabled(is_training):
                    outputs = model(inputs)
                    loss = criterion(outputs, masks)

                    if is_training:
                        loss.backward()
                        optimizer.step()

                # The criterion averages over the batch; weight by the batch
                # size so the epoch average is exact even for a smaller last
                # batch.
                running_loss += loss.item() * inputs.size(0)
                n_samples += inputs.size(0)

            average_epoch_loss = running_loss / max(n_samples, 1)

            print(f'Average {phase} loss: {average_epoch_loss:.4f}')

            if is_training:
                train_losses.append(average_epoch_loss)
                continue

            val_losses.append(average_epoch_loss)
            if average_epoch_loss < best_val_loss:
                best_val_loss = average_epoch_loss
                best_model_weights = copy.deepcopy(model.state_dict())
                threshold_lr = 0
                threshold_stop = 0
            else:
                threshold_lr += 1
                threshold_stop += 1
            if threshold_lr >= patience_lr:
                # Decay the learning rate and restart the decay counter; the
                # early-stopping counter keeps counting epochs only.
                scheduler.step()
                threshold_lr = 0
            print(f'threshold_lr: {threshold_lr}')

        # Early stopping after `patience_stop` epochs without improvement.
        if threshold_stop >= patience_stop:
            break

    # Always hand back the best weights, including after an early stop.
    model.load_state_dict(best_model_weights)
    return model, train_losses, val_losses

In [None]:
# Train the baseline UNet and keep the per-epoch loss curves for plotting.
trained_unet, train_losses, val_losses = train(UNet, optimizer, criterion, dataloaders, N_EPOCHS, scheduler)

In [None]:
# Create the output folder for checkpoints on first use.
if not os.path.isdir('./model'):
    os.mkdir('model')
# Only the state dict (parameters and buffers) is stored, not the module object.
torch.save(trained_unet.state_dict(), './model/trained_unet_1.pt')

# Folder used by the plotting helpers to save figures.
if not os.path.isdir('./figures'):
    os.mkdir('figures')

> The UNet model trains well over 200 epochs, without overfitting. <br>
> The training could have gone further. <br>
> However we see that at the end the incremental improvements are minimal. <br>

In [None]:
def plot_training(N_epochs, train_losses, val_losses, title='test.png'):
    """Plot the training and validation loss curves and save the figure.

    Args:
        N_epochs: kept for backward compatibility and no longer used. The x
            axis is derived from the number of recorded losses, so the plot
            also works when training stopped early and fewer epochs were
            recorded than planned.
        train_losses: per-epoch training losses.
        val_losses: per-epoch validation losses.
        title: file name of the figure written to ``./figures``.
    """
    # Epochs are numbered from 1.
    train_epochs = range(1, len(train_losses) + 1)
    val_epochs = range(1, len(val_losses) + 1)

    plt.figure(figsize=(10, 6))
    plt.plot(val_epochs, val_losses, label='Validation Loss')
    plt.plot(train_epochs, train_losses, label='Training Loss')
    plt.title('Training and Validation Loss over Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    # Save before showing: the figure may no longer be available afterwards.
    os.makedirs('./figures', exist_ok=True)
    plt.savefig(f'./figures/{title}')
    plt.show()

In [None]:
# Size the x axis from the epochs actually run: with early stopping fewer than
# N_EPOCHS losses are recorded and N_EPOCHS+1 would not match the curves.
plot_training(len(train_losses) + 1, train_losses, val_losses, 'training_1.png')

![figures/unet_training.png](figures/unet_training.png)

#### 4) Inference
> On training set first

In [None]:
# `UNet` was rebound to a model instance earlier in the notebook, so calling
# `UNet()` would run a forward pass without input; recover the class first.
unet_class = UNet if isinstance(UNet, type) else type(UNet)
trained_unet = unet_class()
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
trained_unet.load_state_dict(state_dict=torch.load('./model/trained_unet_1.pt', map_location=device))
trained_unet.to(device)
# A new module starts in training mode; switch to inference behaviour so that
# BatchNorm uses the loaded running statistics.
trained_unet.eval()

In [None]:
def plot_predictions(n, inputs, masks_rgb, outputs, img_name=''):
    """Show inputs, ground-truth masks, predictions and binary predictions side by side.

    Each of the first ``n`` samples becomes one row made of four images
    concatenated along the width: input, mask, predicted probabilities and
    the probabilities thresholded at 0.5.

    Args:
        n: number of samples (rows) to plot.
        inputs: batch of 3-channel input images.
        masks_rgb: batch of ground-truth masks already repeated to 3 channels.
        outputs: batch of 1-channel predicted probabilities.
        img_name: file name of the figure saved under ``figures/``; nothing
            is saved when it is empty.
    """
    # Drop any autograd history so the tensors can be converted to numpy.
    outputs = outputs.detach()
    outputs_binary = torch.where(outputs > .5, 1., 0.).repeat(1, 3, 1, 1)
    outputs = outputs.repeat(1, 3, 1, 1)

    # One wide image per sample: the four views are concatenated along the width.
    combined_images = torch.cat([inputs[0:n,:,:,:], masks_rgb[0:n,:,:,:], outputs[0:n,:,:,:], outputs_binary[0:n,:,:,:]], dim=3)

    # Stack the rows in a single-column grid.
    grid_image = torchvision.utils.make_grid(combined_images, nrow=1, padding=10, normalize=True)

    # Channels-last numpy array for matplotlib.
    grid_image_np = grid_image.permute(1, 2, 0).cpu().numpy()

    plt.figure(figsize=(10, 15))
    plt.imshow(grid_image_np)
    plt.axis('off')

    # Save before showing: the figure may no longer be available afterwards,
    # in which case an empty image would be written.
    if img_name:
        os.makedirs('figures', exist_ok=True)
        plt.savefig(f'figures/{img_name}')
    plt.show()

In [None]:
def predictions(model, dataloader, name_img):
    """Run ``model`` on the first batch of ``dataloader`` and plot five predictions.

    The forward pass is done in evaluation mode and without gradient
    tracking; the model's previous training/evaluation mode is restored
    afterwards. The figure is saved under the name ``name_img``.
    """
    inputs, masks = next(iter(dataloader))
    masks = masks.to(device)
    masks_rgb = masks.repeat(1, 3, 1, 1)
    inputs = inputs.to(device)

    was_training = model.training
    # Inference must use the BatchNorm running statistics, not batch statistics.
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(inputs)
    finally:
        model.train(was_training)

    plot_predictions(5, inputs, masks_rgb, outputs, name_img)

> Results are very good on these training samples

In [None]:
# Plot predictions for one batch drawn from the validation loader.
predictions(trained_unet, dataloaders['val'], 'inference_validation.png')

![figures/inference_validation.png](figures/inference_validation.png)

> Results are also pretty good on the validation samples, although a bit less accurate on the last image.

#### 5- Questions

> 1) The 2x2 MaxPool kernels can be replaced with 2x2 Conv2D kernels with padding 0 and stride 2. The resulting size will still be divided by two since the Conv2D kernel then has the same effect on the size as the MaxPool kernel. <br>
It can be checked with the formula from [Pytorch website](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html): <br>
$$H_{out} = \left\lfloor \frac{H_{in} + 2 \times padding - dilation \times (kernel\_size - 1) - 1}{stride} + 1 \right\rfloor$$
With $padding = 0$, $dilation = 1$, $kernel\_size = 2$ and $stride = 2$, this gives $H_{out} = \lfloor H_{in} / 2 \rfloor$.

In [None]:
class Conv2(nn.Module):
    """Two consecutive (3x3 'same' convolution -> BatchNorm -> ReLU) stages.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. Height and width are preserved.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        stages = []
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding='same'),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)
    
class ConvUnet(nn.Module):
    """U-Net variant whose downsampling is done by learnt strided convolutions.

    Same layout as the max-pooling version, except that each of the four
    downsampling steps is a 2x2 convolution with stride 2 that keeps the
    number of channels. The output is a one-channel sigmoid map with the
    spatial size of the input; height and width must be multiples of 16 for
    the skip concatenations to line up.
    """

    def __init__(self):
        super().__init__()

        # Encoder: double convolution, then a strided convolution halving H and W.
        self.conv1 = Conv2(in_channels=3, out_channels=16)
        self.pool_conv1 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=2, stride=2)

        self.conv2 = Conv2(in_channels=16, out_channels=32)
        self.pool_conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=2, stride=2)

        self.conv3 = Conv2(in_channels=32, out_channels=64)
        self.pool_conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=2, stride=2)

        self.conv4 = Conv2(in_channels=64, out_channels=128)
        self.pool_conv4 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=2, stride=2)

        self.conv5 = Conv2(in_channels=128, out_channels=256)

        # Decoder: transpose convolution doubling H and W, then a double
        # convolution over the upsampled map concatenated with the skip.
        self.transpose_conv1 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
        self.conv6 = Conv2(in_channels=256, out_channels=128)

        self.transpose_conv2 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)
        self.conv7 = Conv2(in_channels=128, out_channels=64)

        self.transpose_conv3 = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=2, stride=2)
        self.conv8 = Conv2(in_channels=64, out_channels=32)

        self.transpose_conv4 = nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=2, stride=2)
        self.conv9 = Conv2(in_channels=32, out_channels=16)

        # 1x1 convolution mapping the 16 feature channels to a single logit.
        self.conv10 = nn.Conv2d(in_channels=16, out_channels=1, kernel_size=1)

    def forward(self, images):
        """Return per-pixel probabilities with the spatial size of ``images``."""
        # Encoder: keep every pre-downsampling activation for the skips.
        skip1 = self.conv1(images)
        skip2 = self.conv2(self.pool_conv1(skip1))
        skip3 = self.conv3(self.pool_conv2(skip2))
        skip4 = self.conv4(self.pool_conv3(skip3))
        bottleneck = self.conv5(self.pool_conv4(skip4))

        # Decoder: concatenation is along the channel axis (dim 1, dim 0
        # being the batch).
        merged = torch.cat((self.transpose_conv1(bottleneck), skip4), dim=1)
        decoded = self.conv6(merged)

        merged = torch.cat((self.transpose_conv2(decoded), skip3), dim=1)
        decoded = self.conv7(merged)

        merged = torch.cat((self.transpose_conv3(decoded), skip2), dim=1)
        decoded = self.conv8(merged)

        merged = torch.cat((self.transpose_conv4(decoded), skip1), dim=1)
        decoded = self.conv9(merged)

        return torch.sigmoid(self.conv10(decoded))
    
# Instantiate the strided-convolution variant and move it to the selected device.
UNet_conv = ConvUnet()
UNet_conv.to(device)

In [None]:
# New optimizer/scheduler/loss bound to this model's parameters, and new loaders.
N_EPOCHS, LR, BATCH_SIZE, optimizer, scheduler, criterion = get_hyperparameters(UNet_conv)
dataloaders, dataset_sizes = get_data_loader(path_train_images, path_train_masks, path_val_images, path_val_masks, BATCH_SIZE)

# Train the variant; the loss lists overwrite those of the previous experiment.
trained_unet_conv, train_losses, val_losses = train(UNet_conv, optimizer, criterion, dataloaders, N_EPOCHS, scheduler)

In [None]:
# Create the output folder for checkpoints on first use.
if not os.path.isdir('./model'):
    os.mkdir('model')
# Only the state dict (parameters and buffers) is stored.
torch.save(trained_unet_conv.state_dict(), './model/trained_unetconv.pt')

# Folder used by the plotting helpers to save figures.
if not os.path.isdir('./figures'):
    os.mkdir('figures')

In [None]:
# Size the x axis from the epochs actually run: with early stopping fewer than
# N_EPOCHS losses are recorded and N_EPOCHS+1 would not match the curves.
plot_training(len(train_losses) + 1, train_losses, val_losses, title='conv_unet_training.png')

![figures/conv_unet_training.png](figures/conv_unet_training.png)

In [None]:
# Plot predictions of the strided-convolution variant for one validation batch.
predictions(trained_unet_conv, dataloaders['val'], 'validation_predictions_unet_conv.png')

![figures/validation_predictions_unet_conv.png](figures/validation_predictions_unet_conv.png)

> Replacing the MaxPool layers with convolutional layers makes the model a bit **slower to learn**, however it is still able to learn. <br>
> It was quite predictable because MaxPool and convolution both read the same windows of the feature maps (kernel 2, stride 2). <br>
> The MaxPool operation is not learnt, so the original model has no pooling weights to optimize; the convolutional version has to learn these extra weights, which is probably the reason why its training is slower. <br>
> The MaxPool operation is non linear, therefore a convolutional layer (which applies a linear weighted sum of inputs) cannot learn to do it exactly. The optimal learnt convolutional kernel is probably a weighted average of the inputs, which mimics the max operation. <br>

> 2. As explained in the notes, the encoder loses spatial information. <br>
The latent space contains 8x8 feature maps, whereas the initial image size is 128x128. <br>
The 4 skip connections gradually add some spatial information to the decoded latent space. <br>
If we remove the skip connections, the decoder won't have any spatial information to reconstruct the white zones. <br>
It will probably be harder for the model to learn. <br>

> The following section removes skip connections. The size of the feature maps are not increased by the skip connection. <br>
> Hence, the convolutional layers of the features map don't decrease the size of the feature maps (in_channels = out_channels) <br>
> The no skip UNet is implemented in  [UNet_no_skip.py](UNet_no_skip.py).

In [None]:
# Variant without skip connections, implemented in UNet_no_skip.py.
from UNet_no_skip import UNet_no_skip

# NOTE: this rebinds the imported class name to a model instance.
UNet_no_skip = UNet_no_skip()
UNet_no_skip.to(device)

# New optimizer/scheduler/loss bound to this model's parameters, and new loaders.
N_EPOCHS, LR, BATCH_SIZE, optimizer, scheduler, criterion = get_hyperparameters(UNet_no_skip)
dataloaders, dataset_sizes = get_data_loader(path_train_images, path_train_masks, path_val_images, path_val_masks, BATCH_SIZE)

# Train the variant; the loss lists overwrite those of the previous experiment.
trained_unet_no_skip, train_losses, val_losses = train(UNet_no_skip, optimizer, criterion, dataloaders, N_EPOCHS, scheduler)

In [None]:
# Store the state dict of the trained skip-free model.
torch.save(trained_unet_no_skip.state_dict(), './model/trained_unetnoskip.pt')

In [None]:
# Size the x axis from the epochs actually run: with early stopping fewer than
# N_EPOCHS losses are recorded and N_EPOCHS+1 would not match the curves.
plot_training(len(train_losses) + 1, train_losses, val_losses, title='no_skip_unet_training.png')

![figures/no_skip_unet_training.png](figures/no_skip_unet_training.png)

In [None]:
# Plot predictions of the skip-free model for one validation batch.
predictions(trained_unet_no_skip, dataloaders['val'], 'validation_predictions_unet_no_skip.png')

![figures/validation_predictions_unet_no_skip.png](figures/validation_predictions_unet_no_skip.png)

> Without skip connections the model is able to learn, however the final validation loss is around 5% higher than the one with the skip connections. <br>
> Training is faster without skip connections: it takes around 75 epochs, compared to around 150 for the original model. <br>
> This is because without concatenations, the decoder feature maps have fewer channels (2x fewer), hence the decoder convolutional layers have fewer parameters to optimize.

> Alternative skip connection <br>
>> Addition <br>
>> The AddUnet is implemented in [UNet_add.py](UNet_add.py). <br>
>> The decoder convolutional layers keep the number channels, and the transpose convolutional layers divide the number of channels by two.

In [None]:
# Variant with additive skip connections, implemented in UNet_add.py.
from UNet_add import UNet_add

# NOTE: this rebinds the imported class name to a model instance.
UNet_add = UNet_add()
UNet_add.to(device)

# New optimizer/scheduler/loss bound to this model's parameters, and new loaders.
N_EPOCHS, LR, BATCH_SIZE, optimizer, scheduler, criterion = get_hyperparameters(UNet_add)
dataloaders, dataset_sizes = get_data_loader(path_train_images, path_train_masks, path_val_images, path_val_masks, BATCH_SIZE)

# Train the variant; the loss lists overwrite those of the previous experiment.
trained_unet_add, train_losses, val_losses = train(UNet_add, optimizer, criterion, dataloaders, N_EPOCHS, scheduler)

In [None]:
# Store the state dict of the trained additive-skip model.
torch.save(trained_unet_add.state_dict(), './model/trained_unetadd.pt')

In [None]:
# Size the x axis from the epochs actually run: with early stopping fewer than
# N_EPOCHS losses are recorded and N_EPOCHS+1 would not match the curves.
plot_training(len(train_losses) + 1, train_losses, val_losses, title='add_unet_training.png')

![figures/add_unet_training.png](figures/add_unet_training.png)

In [None]:
# Plot predictions of the additive-skip model for one validation batch.
predictions(trained_unet_add, dataloaders['val'], 'validation_predictions_unet_add.png')

![figures/validation_predictions_unet_add.png](figures/validation_predictions_unet_add.png)

>> Max <br>
The MaxUnet is implemented in [MaxUnet.py](MaxUnet.py)

In [None]:
# Variant with max-based skip connections, implemented in MaxUnet.py.
from MaxUnet import MaxUnet

# NOTE: this rebinds the imported class name to a model instance.
MaxUnet = MaxUnet()
MaxUnet.to(device)

# New optimizer/scheduler/loss bound to this model's parameters, and new loaders.
N_EPOCHS, LR, BATCH_SIZE, optimizer, scheduler, criterion = get_hyperparameters(MaxUnet)
dataloaders, dataset_sizes = get_data_loader(path_train_images, path_train_masks, path_val_images, path_val_masks, BATCH_SIZE)

# Train the variant; the loss lists overwrite those of the previous experiment.
trained_MaxUnet, train_losses, val_losses = train(MaxUnet, optimizer, criterion, dataloaders, N_EPOCHS, scheduler)

In [None]:
# Store the state dict of the trained max-skip model.
torch.save(trained_MaxUnet.state_dict(), './model/trained_unetmax.pt')

In [None]:
# The hard-coded 146 only matched the epoch count of one particular run; derive
# the x-axis length from the recorded losses so the cell works for any run.
plot_training(len(train_losses) + 1, train_losses, val_losses, title='max_unet_training.png')

![figures/max_unet_training.png](figures/max_unet_training.png)

In [None]:
# Plot predictions of the max-skip model for one validation batch.
predictions(trained_MaxUnet, dataloaders['val'], 'validation_predictions_unet_max.png')

![figures/validation_predictions_unet_max.png](figures/validation_predictions_unet_max.png)

> The results of MaxUnet are very close to the original UNet. <br>
> The validation loss crosses the training loss around epoch 130, like in the original UNet. <br>
> It seems that taking the max has the same effect as adding additional feature maps. <br>
> It can be linked to the MaxPool operations being replaced by Convolutions in question 1: here we use max of feature maps instead of convolutions on combined feature maps. Like in question 1, results (regarding the loss) are very similar, because convolutions and max can yield similar values, especially since the inputs are normalized.

#### 5.3 FCN
> The Fully Convolutional neural Network is implemented in [FCN.py](FCN.py). <br>
> All double convolution layers were kept. <br>
> Remaining shrinking or upsampling layers (MaxPool, TransposeConv) were dismissed. <br>

In [None]:
# Fully convolutional network, implemented in FCN.py.
from FCN import FCN

# NOTE: this rebinds the imported class name to a model instance.
FCN = FCN()
FCN.to(device)

# New optimizer/scheduler/loss bound to this model's parameters, and new loaders.
N_EPOCHS, LR, BATCH_SIZE, optimizer, scheduler, criterion = get_hyperparameters(FCN)
dataloaders, dataset_sizes = get_data_loader(path_train_images, path_train_masks, path_val_images, path_val_masks, BATCH_SIZE)

# Train the network; the loss lists overwrite those of the previous experiment.
trained_fcn, train_losses, val_losses = train(FCN, optimizer, criterion, dataloaders, N_EPOCHS, scheduler)

> First of all, a FCN which keeps the image size constant takes much more GPU memory than the auto-encoder architecture during training: up from 2.5GB to 12GB of GPU RAM. <br>
> As the feature map dimensions remain constant (initial image size) and the channels are the same as with the auto-encoder, there are many more computations being carried out by the network, and many more intermediate results to keep for the gradient computations during training. <br>
> The benefit of this method is the **explainability**: we can check feature maps at each stage of the network, contrary to the auto-encoder architecture which shrinks the feature maps. <br>

In [None]:
# Store the state dict of the trained FCN.
torch.save(trained_fcn.state_dict(), './model/trained_fcn.pt')

In [None]:
# Size the x axis from the epochs actually run: with early stopping fewer than
# N_EPOCHS losses are recorded and N_EPOCHS+1 would not match the curves.
plot_training(len(train_losses) + 1, train_losses, val_losses, title='fcn_training.png')

![figures/fcn_training.png](figures/fcn_training.png)

In [None]:
# Plot predictions of the FCN for one validation batch.
predictions(trained_fcn, dataloaders['val'], 'validation_predictions_fcn.png')

![figures/validation_predictions_fcn.png](figures/validation_predictions_fcn.png)

#### 5.4 Threshold for inference

> A classical metric that optimizes both precision and recall in our case is the Intersection over Union (IoU).   
> To determine the optimal threshold, we can simply make a grid search over the thresholds, and choose the one that gets the best IoU on the validation set. <br>

In [None]:
# `UNet` was rebound to a model instance earlier in the notebook, so calling
# `UNet()` would run a forward pass without input; recover the class first.
unet_class = UNet if isinstance(UNet, type) else type(UNet)
UNet = unet_class()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
UNet.to(device)

In [None]:
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
UNet.load_state_dict(torch.load('./model/trained_unet_1.pt', map_location=device))

In [None]:
# Rebuild the loaders with a batch size of 64 for the threshold search.
BATCH_SIZE = 64
dataloaders, dataset_sizes = get_data_loader(path_train_images, path_train_masks, path_val_images, path_val_masks, BATCH_SIZE)

In [None]:
def compute_IoU(masks, outputs):
    """Return the sum over the batch of the per-sample IoU of masks and outputs.

    Both arguments are 4-D binary tensors of the same shape; they are
    interpreted as booleans (non-zero is True).

    When the union of a sample is empty its IoU is defined as 1: the
    intersection is included in the union, so it is empty too, which means
    both the mask and the prediction are empty and the prediction is correct.
    """
    masks = masks.bool()
    outputs = outputs.bool()
    intersection = torch.sum(masks & outputs, dim=(1,2,3))
    union = torch.sum(masks | outputs, dim=(1,2,3))
    # Clamp the denominator so no 0/0 is evaluated for empty unions; those
    # entries are replaced by 1 below. ones_like keeps dtype and device
    # consistent with the ratio.
    ratio = intersection.float() / union.clamp(min=1).float()
    IoU = torch.where(union == 0, torch.ones_like(ratio), ratio)
    return IoU.sum()

In [None]:
def get_precision_recall(masks, outputs):
    """Return the sums over the batch of the per-sample precision and recall.

    Both arguments are 4-D binary tensors of the same shape. For each sample,
    precision is TP / (TP + FP) and recall is TP / (TP + FN); a ratio whose
    denominator is zero is defined as 1.

    Returns:
        ``(precision, recall)``, two 0-dim tensors.
    """
    mask_pos, mask_neg = masks == 1, masks == 0
    out_pos, out_neg = outputs == 1, outputs == 0

    TP = torch.sum(mask_pos & out_pos, dim=(1,2,3))
    FP = torch.sum(mask_neg & out_pos, dim=(1,2,3))
    FN = torch.sum(mask_pos & out_neg, dim=(1,2,3))

    def _ratio(numerator, denominator):
        # Clamp so no 0/0 is evaluated; empty denominators are mapped to 1.
        value = numerator.float() / denominator.clamp(min=1).float()
        return torch.where(denominator == 0, torch.ones_like(value), value)

    precision = _ratio(TP, TP + FP).sum()
    recall = _ratio(TP, TP + FN).sum()
    return precision, recall

In [None]:
import numpy as np

# Candidate thresholds 0.30, 0.31, ..., 0.90.
thresholds = [i/100 for i in range(30, 91)]
sum_IoU = [0 for _ in thresholds]
sum_precision = [0 for _ in thresholds]
sum_recall = [0 for _ in thresholds]

# The model was just constructed and loaded, so it is still in training mode;
# switch to evaluation mode so BatchNorm uses its running statistics.
UNet.eval()
with torch.no_grad():
    # One forward pass per batch, reused for every threshold, instead of one
    # full pass over the validation set per threshold.
    for inputs, masks in dataloaders['val']:
        # NOTE(review): .bool() marks every non-zero pixel as positive,
        # including fractional values produced by resizing the masks.
        masks = masks.to(device).bool()
        inputs = inputs.to(device)
        probabilities = UNet(inputs)
        for k, threshold in enumerate(thresholds):
            outputs = probabilities > threshold
            sum_IoU[k] += compute_IoU(masks, outputs)
            p, r = get_precision_recall(masks, outputs)
            sum_precision[k] += p
            sum_recall[k] += r

# Per-sample sums are turned into averages over the validation set.
n_val = dataset_sizes['val']
list_IoU = [total / n_val for total in sum_IoU]
list_precision = [total / n_val for total in sum_precision]
list_recall = [total / n_val for total in sum_recall]
for threshold, IoU, precision, recall in zip(thresholds, list_IoU, list_precision, list_recall):
    print(f't = {threshold}, IoU = {IoU:.3f}, precision = {precision:.3f}, recall = {recall:.3f}')

In [None]:
# Convert the 0-dim tensors accumulated during the search into Python floats.
list_precision = [x.item() for x in list_precision]
list_recall = [x.item() for x in list_recall]
list_IoU = [x.item() for x in list_IoU]

In [None]:
# Precision, recall and IoU as functions of the decision threshold.
plt.figure()
plt.plot(thresholds, list_precision, label='Precision')
plt.plot(thresholds, list_recall, label='Recall')
plt.plot(thresholds, list_IoU, label='IoU')

plt.xlabel('Threshold')
plt.ylabel('Value')
plt.title('Precision, Recall, and IoU vs. Threshold')
plt.legend()

plt.grid(True)
# The notebook displays figures/IoU_plot.png, so save the plot there rather
# than in the working directory.
os.makedirs('figures', exist_ok=True)
plt.savefig('figures/IoU_plot.png')
plt.show()

![figures/IoU_plot.png](figures/IoU_plot.png)

In [None]:
# Index of the highest IoU; on ties the largest index (highest threshold) wins.
best_index = max(range(len(list_IoU)), key=lambda i: (list_IoU[i], i))
best_IoU = list_IoU[best_index]
best_precision = list_precision[best_index]
best_recall = list_recall[best_index]

In [None]:
# Report the best IoU with the precision, recall and threshold at which it is reached.
print(f'Best IoU: {best_IoU:.3f}\nPrecision for best IoU: {best_precision:.3f}\nRecall for best IoU: {best_recall:.3f}\nThreshold: {thresholds[best_index]}')

> With this method, the best threshold is 0.69. <br>
> A few predictions with this threshold can be made.

In [None]:
# Predictions on one validation batch, binarised with the best threshold.
inputs, masks = next(iter(dataloaders['val']))
masks = masks.to(device)
masks_rgb = masks.repeat(1, 3, 1, 1)
inputs = inputs.to(device)
# Inference: BatchNorm must use its running statistics and no autograd graph
# is needed.
UNet.eval()
with torch.no_grad():
    outputs = UNet(inputs)
n = 5

outputs_binary = torch.where(outputs > thresholds[best_index], 1., 0.).repeat(1, 3, 1, 1)
outputs = outputs.repeat(1, 3, 1, 1)
# Concatenate the images along the width dimension to create a single tensor
combined_images = torch.cat([inputs[0:n,:,:,:], masks_rgb[0:n,:,:,:], outputs[0:n,:,:,:], outputs_binary[0:n,:,:,:]], dim=3)
# Make a grid with the combined images
grid_image = torchvision.utils.make_grid(combined_images, nrow=1, padding=10, normalize=True)
# Convert the grid tensor to a numpy array for visualization
grid_image_np = grid_image.permute(1, 2, 0).cpu().numpy()
# Display the grid of images
plt.figure(figsize=(10, 15))
plt.imshow(grid_image_np)

plt.axis('off')
# Save before showing: the figure may no longer be available afterwards, in
# which case an empty image would be written.
os.makedirs('figures', exist_ok=True)
plt.savefig(f'figures/best_threshold_outputs.png')
plt.show()

![figures/best_threshold_outputs.png](figures/best_threshold_outputs.png)