<a href="https://colab.research.google.com/github/NoeGille/UNet-on-fashion-mnist/blob/main/UNet_on_fashion_MNIST_in_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import cv2
from tqdm import tqdm

In [None]:
# --- Hyper-parameters -------------------------------------------------------
# Image dimensions given channel-last; only the last entry (channel count)
# is read by the model constructor.
INPUT_SIZE = (28, 28, 1)
BATCH_SIZE = 64
LEARNING_RATE = 0.001
# One more than the number of dataset labels: index 0 is used for background
# pixels when the segmentation targets are built.
NUM_CLASSES = 11
EPOCHS = 10

# --- Data -------------------------------------------------------------------
# Fashion-MNIST is downloaded on first use into tmp/dataset/ and converted to
# tensors by ToTensor().
train_dataset = datasets.FashionMNIST(
    root='tmp/dataset/',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

test_dataset = datasets.FashionMNIST(
    root='tmp/dataset/',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)


In [None]:
# TODO: add batch normalization if extreme output values become a problem


class DownSampleBlock(nn.Module):
    '''Encoder stage: conv -> batch-norm -> ReLU -> conv -> batch-norm, then 2x2 max-pool.

    Both convolutions use a 3x3 kernel with stride 1 and padding 1, so they
    keep the spatial size; only the max-pool (kernel 2, stride 2) shrinks it.

    in_channels : number of channels of the input tensor
    out_channels : number of channels produced by both convolutions
    '''

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Shared settings for the two size-preserving 3x3 convolutions
        same_size = dict(kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv1 = nn.Conv2d(in_channels, out_channels, **same_size)
        self.norm1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, **same_size)
        self.norm2 = nn.BatchNorm2d(out_channels)
        self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

    def forward(self, x):
        '''Return ``(pooled, features)``.

        ``features`` is the feature map before pooling (handed to the decoder
        as a skip connection) and ``pooled`` is its max-pooled version.
        '''
        features = self.conv1(x)
        features = self.norm1(features)
        features = self.relu(features)
        features = self.conv2(features)
        # No activation after the second normalisation
        features = self.norm2(features)
        return self.pool(features), features

class DoubleConvolution(nn.Module):
    '''Two size-preserving 3x3 convolutions, each followed by batch-norm.

    The pipeline is conv -> batch-norm -> ReLU -> conv -> batch-norm; there is
    no pooling, so the spatial size of the input is kept.

    in_channels : number of channels of the input tensor
    out_channels : number of channels produced by both convolutions
    '''

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # kernel 3, stride 1, padding 1: output has the same height and width
        same_size = dict(kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv1 = nn.Conv2d(in_channels, out_channels, **same_size)
        self.norm1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, **same_size)
        self.norm2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        '''Apply both convolution stages and return the resulting feature map.'''
        hidden = self.relu(self.norm1(self.conv1(x)))
        return self.norm2(self.conv2(hidden))

class UpSampleBlock(nn.Module):
    '''Decoder stage that doubles the spatial size of its input.

    A transposed convolution (kernel 2, stride 2) doubles height and width,
    then a size-preserving 3x3 convolution refines the result. Both layers
    map ``in_channels`` to ``in_channels``, so the channel count is unchanged.

    in_channels : number of channels of the input (and of the output)
    '''

    def __init__(self, in_channels):
        super().__init__()
        self.up1 = nn.ConvTranspose2d(in_channels, in_channels,
                                      kernel_size=2, stride=2)
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=(3, 3),
                               stride=(1, 1), padding=(1, 1))

    def forward(self, x):
        '''Upsample ``x`` by a factor of two and apply the 3x3 convolution.'''
        upsampled = self.up1(x)
        return self.conv1(upsampled)

class ResidualConnection(nn.Module):
    '''Merge a decoder feature map with the matching encoder feature map.

    The two inputs are concatenated along the channel axis and passed through
    conv -> batch-norm -> ReLU -> conv -> batch-norm (3x3, size-preserving).
    '''

    def __init__(self, in_channels, out_channels):
        '''in_channels : channel count of EACH of the two inputs (the first
        convolution therefore receives ``2 * in_channels`` channels)
        out_channels : channel count of the merged output'''
        super().__init__()
        same_size = dict(kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        # First convolution squeezes the concatenation back to in_channels
        self.conv1 = nn.Conv2d(in_channels * 2, in_channels, **same_size)
        self.norm1 = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels, out_channels, **same_size)
        self.norm2 = nn.BatchNorm2d(out_channels)

    def forward(self, x1, x2):
        '''Concatenate ``x2`` then ``x1`` on dim 1 and apply the two convolution stages.'''
        merged = torch.cat([x2, x1], dim=1)
        hidden = self.relu(self.norm1(self.conv1(merged)))
        return self.norm2(self.conv2(hidden))

class UNet(nn.Module):
    '''U-Net style encoder / decoder producing per-pixel class probabilities.

    The encoder is a stack of ``DownSampleBlock``; the decoder mirrors it with
    ``UpSampleBlock`` followed by ``ResidualConnection`` (which concatenates
    the matching encoder feature map). A final 3x3 convolution and a softmax
    over the channel axis produce the output.
    '''

    # Number of filters of the first encoder level.
    # <!> Level i (0-based) uses NB_OF_FILTERS * (i + 1) filters, i.e. the
    # width grows linearly with depth (16, 32, 48, ...) <!>
    NB_OF_FILTERS = 16

    def __init__(self, input_size, num_classes:int=10, depth:int=2):
        '''### Initialize a UNet model
        input_size : dimension of input; only the last entry (number of
                     channels) is used
        num_classes : number of channels (classes) of the output
        depth : number of down-sampling / up-sampling levels (>= 1).
                Every level halves the spatial size with a max-pool and the
                decoder doubles it back, so height and width should be
                divisible by ``2 ** depth`` for the skip connections to match.

        Raises ValueError if ``depth`` is smaller than 1.'''
        super().__init__()
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")

        # channels[0] is the input channel count, channels[i + 1] the width of level i
        channels = [input_size[-1]] + [self.NB_OF_FILTERS * (i + 1) for i in range(depth)]

        # Outermost level: its decoder side maps straight to num_classes
        self.dblocks = nn.ModuleList([DownSampleBlock(in_channels=channels[0], out_channels=channels[1])])
        self.bottleneck = DoubleConvolution(in_channels=channels[-1], out_channels=channels[-1])
        # Concatenate outputs from encoder and decoder to keep track of object positions
        self.res_connect = nn.ModuleList([ResidualConnection(in_channels=channels[1], out_channels=num_classes)])
        self.ublocks = nn.ModuleList([UpSampleBlock(in_channels=channels[1])])

        # Inner levels, from the outside towards the bottleneck
        for i in range(1, depth):
            self.dblocks.append(DownSampleBlock(in_channels=channels[i], out_channels=channels[i + 1]))
            self.res_connect.append(ResidualConnection(in_channels=channels[i + 1], out_channels=channels[i]))
            self.ublocks.append(UpSampleBlock(in_channels=channels[i + 1]))

        # Decoder modules were collected outermost-first; the forward pass
        # needs them deepest-first.
        self.ublocks = self.ublocks[::-1]
        self.res_connect = self.res_connect[::-1]
        self.output = nn.Conv2d(in_channels=num_classes, out_channels=num_classes,
                                kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    def forward(self, x):
        '''Return softmax class probabilities with ``num_classes`` channels on dim 1.

        NOTE: the output is already normalised. Losses that expect raw logits
        (e.g. ``nn.CrossEntropyLoss``) should be given the log of this output.
        '''
        # Encoder: keep the pre-pooling feature map of every level
        skips = []
        for down_block in self.dblocks:
            x, skip = down_block(x)
            skips.append(skip)

        x = self.bottleneck(x)

        # Decoder: deepest level first, each merged with its encoder counterpart
        for up_block, merge, skip in zip(self.ublocks, self.res_connect, reversed(skips)):
            x = merge(up_block(x), skip)

        # Explicit channel axis: the implicit-dim form of softmax is deprecated
        return F.softmax(self.output(x), dim=1)


In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Segmentation network, moved to the selected device
model = UNet(INPUT_SIZE, num_classes=NUM_CLASSES)
model = model.to(device=device)

# Loss and optimiser used by the training loop
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)


In [None]:
def transform_targets(X, y_label):
    '''Build one-hot segmentation targets from a batch of images and their labels.

    Every pixel whose intensity (rescaled by 255) exceeds the threshold 8 is
    given the class ``label + 1``; all other pixels get class 0 (background).
    The class map is then one-hot encoded over NUM_CLASSES channels.

    X : CPU tensor of images -- assumed (batch, 1, H, W) with values in [0, 1]
        as produced by ToTensor(); TODO confirm for other callers
    y_label : 1-D tensor holding one class index per image
    return : uint8 tensor with the NUM_CLASSES one-hot channels on dim 1
    '''
    # Swap dims 1 and 3 (channel-last layout for OpenCV) and rescale by 255
    scaled = X.permute(0, 3, 2, 1).numpy() * 255
    binarised = [cv2.threshold(img, 8, 255, type=cv2.THRESH_BINARY)[1]
                 for img in scaled]
    # 1 on foreground pixels, 0 elsewhere
    y = np.where(np.array(binarised) == 0, 0, 1)
    # Foreground pixels take the image label shifted by one (0 stays background)
    for s, label in enumerate(y_label):
        y[s] *= label.numpy() + 1
    one_hot = np.eye(NUM_CLASSES, dtype='uint8')[y]
    # The second permute undoes the axis swap and moves the classes to dim 1
    return torch.from_numpy(one_hot).permute(0, 3, 2, 1)

def metrics_values(y_valid, y_pred):
    '''Return the raw counts needed to compute accuracy and dice score.

    y_valid : 2-D tensor of ground-truth scores / one-hot rows (one row per sample)
    y_pred : 2-D tensor of predicted scores with the same layout
    return : tuple ``(num_correct, num_samples, valid_mask, pred_mask, intersection)``
        - num_correct : int, rows where predicted and true class agree
        - num_samples : int, number of rows in ``y_valid``
        - valid_mask : 0-dim tensor, rows whose true class is not 0
        - pred_mask : 0-dim tensor, rows whose predicted class is not 0
        - intersection : 0-dim tensor, rows correctly predicted as a non-zero class
    '''
    true_cls = y_valid.argmax(1)
    pred_cls = y_pred.argmax(1)

    # Accuracy ingredients
    num_samples = y_valid.size(0)
    num_correct = (pred_cls == true_cls).sum().item()

    # Dice ingredients: class 0 is treated as background
    pred_foreground = pred_cls != 0
    valid_mask = (true_cls != 0).sum()
    pred_mask = pred_foreground.sum()
    intersection = ((true_cls == pred_cls) & pred_foreground).sum()
    return (num_correct, num_samples, valid_mask, pred_mask, intersection)


# Per-epoch training metrics, kept for plotting afterwards
accuracies = []
dice_scores = []

for epoch in range(EPOCHS):
    # Make sure batch-norm layers are in training mode, even if an evaluation
    # left the model in eval mode.
    model.train()

    # Running counts for this epoch (see metrics_values)
    num_correct = 0
    num_samples = 0
    valid_mask = 0
    pred_mask = 0
    intersection = 0
    for data, y_label in tqdm(train_loader):
        # Targets are built on the CPU (numpy / OpenCV) before moving to device
        targets = transform_targets(data, y_label)

        data = data.to(device=device)
        targets = targets.to(device=device)

        # prediction
        scores = model(data)
        # Flatten to one row per pixel: (batch * H * W, num_classes)
        scores = scores.permute(0, 2, 3, 1).contiguous().view(-1, scores.size(1))
        # Same layout for the one-hot targets, as float class probabilities
        targets = targets.permute(0, 2, 3, 1).contiguous().view(-1, targets.size(1)).float()

        # The model output is already softmax-normalised, while
        # CrossEntropyLoss applies log_softmax to its input. Feeding the
        # probabilities directly would apply softmax twice; feeding their log
        # is equivalent to a plain cross-entropy on the probabilities
        # (log_softmax(log p) == log p when p sums to 1). The clamp avoids
        # log(0) on probabilities that underflowed.
        loss = criterion(torch.log(scores.clamp_min(1e-12)), targets)

        # Metrics only need class indices, no gradient tracking required
        with torch.no_grad():
            values = metrics_values(targets, scores)
        num_correct += values[0]
        num_samples += values[1]
        valid_mask += values[2]
        pred_mask += values[3]
        intersection += values[4]

        # backward
        optimizer.zero_grad()
        loss.backward()

        # gradient descent or adam step
        optimizer.step()

    acc = round(float(num_correct) / float(num_samples)*100, ndigits=2)
    # Dice = 2 * |A n B| / (|A| + |B|); defined as 0 when both masks are empty
    dice_denominator = float(valid_mask) + float(pred_mask)
    if dice_denominator:
        dice_score = round(2 * float(intersection) / dice_denominator, ndigits=2)
    else:
        dice_score = 0.0
    print(f"accuracy {acc} and dice score {dice_score}")
    accuracies.append(acc)
    dice_scores.append(dice_score)

In [None]:
# Plot the per-epoch training accuracy and dice score on a common axis.
# Accuracies are stored as percentages, so they are divided by 100 to share
# the same scale as the dice score.
plt.plot(np.array(accuracies) / 100, label='accuracy')
plt.plot(dice_scores, label='dice score')
plt.legend()
plt.show()


In [None]:
from math import floor

def check_accuracy(loader, model):
    '''Evaluate ``model`` on every batch of ``loader`` and print the metrics.

    loader : DataLoader whose dataset exposes a boolean ``train`` attribute
             (only used to choose the message printed at the start)
    model : network to evaluate; it is switched to eval mode for the duration
            of the call and restored to its previous mode afterwards
    return : ``(acc, img)`` where ``acc`` is the pixel accuracy in percent and
             ``img`` is a numpy array of predicted class indices for the first
             image of the last batch

    Raises ValueError if ``loader`` yields no batch.
    '''
    if loader.dataset.train:
        print("Checking accuracy on training data")
    else:
        print("Checking accuracy on test data")

    num_correct = 0
    num_samples = 0
    intersection = 0
    valid_mask = 0
    pred_mask = 0
    img = None

    # Remember the current mode so evaluation does not silently change it
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for x, y in tqdm(loader):
                # Targets are built on the CPU before moving to the device
                y = transform_targets(x, y)
                x = x.to(device=device)
                y = y.to(device=device)
                scores = model(x)

                # Keep the first image of the batch to visualize the prediction
                img = scores[0]

                # Flatten to one row per pixel. The class count comes from the
                # local target tensor, not from a global left by another cell.
                y_valid = y.permute(0, 2, 3, 1).contiguous().view(-1, y.size(1)).float()
                y_pred = scores.permute(0, 2, 3, 1).contiguous().view(-1, scores.size(1))
                values = metrics_values(y_valid, y_pred)
                num_correct += values[0]
                num_samples += values[1]
                valid_mask += values[2]
                pred_mask += values[3]
                intersection += values[4]

        if img is None:
            raise ValueError("loader yielded no batches; nothing to evaluate")

        # Per-pixel predicted class of the retained image
        img = img.argmax(0).cpu().detach().numpy()
        acc = round(float(num_correct) / float(num_samples)*100, ndigits=2)
        # Dice = 2 * |A n B| / (|A| + |B|); defined as 0 when both masks are empty
        dice_denominator = float(valid_mask) + float(pred_mask)
        if dice_denominator:
            dice_score = round(2 * float(intersection) / dice_denominator, ndigits=2)
        else:
            dice_score = 0.0
        print(f"Got {num_correct} / {num_samples} with accuracy {acc} and dice score {dice_score}")
    finally:
        model.train(was_training)
    return acc, img

# Evaluate on the training split; keep the predicted class map returned as
# the second value for the optional preview below.
_, img = check_accuracy(train_loader, model)

# Set to True to display the predicted class map in an OpenCV window
showing_result = False
if showing_result:
    # Showing a result (to be improved)
    # Class indices are multiplied by 24 to spread them over the grey range
    img = np.uint8(img) * 24
    cv2.namedWindow("img", cv2.WINDOW_NORMAL)
    cv2.resizeWindow('img', 600, 600)
    cv2.imshow('img', img)


    # Block until a key is pressed, then close the window
    cv2.waitKey(0)
    cv2.destroyAllWindows()

In [None]:
# Evaluate the trained model on the held-out test split
check_accuracy(test_loader, model)