<a href="https://colab.research.google.com/github/NoeGille/UNet-on-fashion-mnist/blob/main/UNet_on_fashion_MNIST_in_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Imports

In [231]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import cv2
from tqdm import tqdm

In [232]:
INPUT_SIZE = (28, 28, 1)
BATCH_SIZE = 64
LEARNING_RATE = 0.001
NUM_CLASSES = 10
EPOCHS = 10


train_dataset = datasets.FashionMNIST(root='tmp/dataset/', train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataset = datasets.FashionMNIST(root='tmp/dataset/', train=False, transform=transforms.ToTensor(), download=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=True)


In [233]:
# A faire : ajouter des batch normalization si problème de valeur extremes en sorties


class DownSampleBlock(nn.Module):
    '''Reduce the dimension of the image in input by 2'''
    def __init__(self, in_channels, out_channels):
        super(DownSampleBlock, self).__init__()
        # We keep the same dimension in input and ouput
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                               kernel_size=(3, 3), stride=(1, 1), padding=(1,1))
        self.norm1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels,
                               kernel_size=(3, 3), stride=(1, 1), padding=(1,1))
        self.norm2 = nn.BatchNorm2d(out_channels)
        self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2,2))

    def forward(self, x):
        x = self.norm2(self.conv2(self.relu(self.norm1(self.conv1(x)))))
        return self.pool(x), x

class DoubleConvolution(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConvolution, self).__init__()
        # We keep the same dimension in input and ouput
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                               kernel_size=(3, 3), stride=(1, 1), padding=(1,1))
        self.norm1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels,
                               kernel_size=(3, 3), stride=(1, 1), padding=(1,1))
        self.norm2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
      x = self.norm2(self.conv2(self.relu(self.norm1(self.conv1(x)))))
      return x

class UpSampleBlock(nn.Module):
    '''Increase the dimension of the input and reduce its number of channels by 2'''
    def __init__(self, in_channels):
        super(UpSampleBlock, self).__init__()
        self.up1 = nn.ConvTranspose2d(in_channels, in_channels, 2, 2)
        self.conv1 = nn.Conv2d(in_channels=in_channels,
                                     out_channels=in_channels, kernel_size=(3, 3),
                                     stride=(1, 1), padding=(1,1))
        
    def forward(self, x):
        x = self.conv1(self.up1(x))
        return x

class ResidualConnection(nn.Module):
    '''Concatenate inputs of two blocs'''

    def __init__(self, in_channels, out_channels):
        '''in_channels has the same dimensions as out_channels'''

        super(ResidualConnection, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels * 2, out_channels=in_channels,
                               kernel_size=(3, 3), stride=(1, 1), padding=(1,1))
        self.norm1 = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                               kernel_size=(3, 3), stride=(1, 1), padding=(1,1))
        self.norm2 = nn.BatchNorm2d(out_channels)

    def forward(self, x1, x2):
        x = torch.cat([x2, x1], dim=1)
        x = self.norm2(self.conv2(self.relu(self.norm1(self.conv1(x)))))
        return x

class UNet(nn.Module):

    # Numbers of filters for the first layer of convolution
    # <!> The number of filters will double for each down sample blocks which can
    # lead to very high numbers of parameters very quickly <!>
    NB_OF_FILTERS = 16

    def __init__(self, input_size, num_classes:int=10, depth:int=2):
        '''### Initialize a UNet model
        input_size : dimension of input
        num_classes : specify the number of classes in ouput
        depth : the number of blocks (depth of the model)'''
        super(UNet, self).__init__()
        channels = [input_size[-1]] + [self.NB_OF_FILTERS * (i + 1) for i in range(depth)]
        # first downsampling block
        self.dblocks = nn.ModuleList([DownSampleBlock(in_channels=channels[0], out_channels=channels[1])])
        self.bottleneck = DoubleConvolution(in_channels=channels[-1], out_channels=channels[-1])
        # Concatenate outputs from encoder and decoder to keep tracks of objects positions
        self.res_connect = nn.ModuleList([ResidualConnection(in_channels=channels[1], out_channels=num_classes)])
        # Last upsampling block
        self.ublocks = nn.ModuleList([UpSampleBlock(in_channels=channels[1])])

        for i in range(1,depth):
            # The number of channels double each time the depth increases
            self.dblocks.append(DownSampleBlock(in_channels=channels[i], out_channels=channels[i + 1]))
            self.res_connect.append(ResidualConnection(in_channels=channels[i + 1], out_channels=channels[i]))
            self.ublocks.append(UpSampleBlock(in_channels=channels[i + 1]))
        self.ublocks = self.ublocks[::-1]
        self.res_connect = self.res_connect[::-1]
        self.output = nn.Conv2d(in_channels=num_classes, out_channels=num_classes,
                               kernel_size=(3, 3), stride=(1, 1), padding=(1,1))

    def forward(self, x):
        depth = len(self.dblocks)

        # Encoder
        # Copy of output of each blocks before downsampling
        xs_down =[]
        for i, down_block in enumerate(self.dblocks):
            x, copy = down_block.forward(x)
            xs_down.append(copy)
        x = self.bottleneck.forward(x)
        xs_down = xs_down[::-1]
        # Decoder
        for i, up_block in enumerate(self.ublocks):
            x_up = up_block.forward(x)
            x = self.res_connect[i](x_up, xs_down[i])
        # Flatten the output for loss computation
        x = self.output(x)
        x = x.permute(0, 2, 3, 1).contiguous().view(-1, x.size(1))
        return x


In [234]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = UNet(INPUT_SIZE).to(device=device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)


In [235]:
def transform_targets(X, y_label):
    y = np.array([cv2.threshold(img, 8, 255, type=cv2.THRESH_BINARY)[1] for img 
                  in X.permute(0, 3, 2, 1).numpy() * 255])
    y = np.where(y == 0, 0, 1)   
    for s in range(y.shape[0]):
      y[s] *= y_label[s].numpy()
    return torch.from_numpy(np.eye(NUM_CLASSES, dtype='uint8')[y]).permute(0, 3, 2, 1)

for epoch in range(EPOCHS):
    for data,y_label in tqdm(train_loader):
        targets = transform_targets(data, y_label)
        
        data = data.to(device=device)
        targets = targets.to(device=device)
        # data.shape = (64, 1, 28, 28)
        # we want (64, 28, 28, 1)

        # forward 
        # prediction
        scores = model(data)
        # Flatten and convert to float targets for loss computation
        targets = targets.permute(0, 2, 3, 1).contiguous().view(-1, targets.size(1)).float().to(device=device)
        # Calculate loss 
        loss = criterion(scores, targets)

        # backward
        optimizer.zero_grad()
        loss.backward()

        #gradient descent or adam step
        optimizer.step()

100%|██████████| 938/938 [00:18<00:00, 51.49it/s]
100%|██████████| 938/938 [00:19<00:00, 49.35it/s]
100%|██████████| 938/938 [00:18<00:00, 52.05it/s]
100%|██████████| 938/938 [00:18<00:00, 49.46it/s]
100%|██████████| 938/938 [00:18<00:00, 51.51it/s]
100%|██████████| 938/938 [00:18<00:00, 49.64it/s]
100%|██████████| 938/938 [00:17<00:00, 52.23it/s]
100%|██████████| 938/938 [00:18<00:00, 49.98it/s]
100%|██████████| 938/938 [00:18<00:00, 51.68it/s]
100%|██████████| 938/938 [00:18<00:00, 49.61it/s]


In [236]:
from math import floor
def check_accuracy(loader, model):
    if loader.dataset.train:
        print("Checking accuracy on training data")
    else:
        print("Checking accuracy on test data")

    num_correct = 0
    num_samples = 0
    model.eval()

    with torch.no_grad():
        for x, y in tqdm(loader):
            y = transform_targets(x, y)
            x = x.to(device=device)
            y = y.to(device=device)
            
            y = y.permute(0, 2, 3, 1).contiguous().view(-1, targets.size(1)).float().to(device=device)

            # Accuracy
            scores = model(x)
            _, predictions= scores.max(1)
            num_correct += predictions.eq(y.argmax(1)).sum().item()
            num_samples += y.size(0)

            # Dice score

        acc = round(float(num_correct) / float(num_samples)*100, ndigits=2)
        print(f"Got {num_correct} / {num_samples} with accuracy {acc}")
      
    model.train()
    return acc

check_accuracy(train_loader, model)

Checking accuracy on training data


100%|██████████| 938/938 [00:14<00:00, 65.58it/s]

Got 45055987 / 47040000 with accuracy 95.78





95.78

In [237]:
check_accuracy(test_loader, model)

Checking accuracy on test data


100%|██████████| 157/157 [00:02<00:00, 66.46it/s]

Got 7411854 / 7840000 with accuracy 94.54





94.54