<a href="https://colab.research.google.com/github/NoeGille/UNet-on-fashion-mnist/blob/main/UNet_on_fashion_MNIST_in_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Imports

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import cv2
from tqdm import tqdm

In [2]:
INPUT_SIZE = (28, 28, 1)
BATCH_SIZE = 64
LEARNING_RATE = 0.001
NUM_CLASSES = 10
EPOCHS = 1


train_dataset = datasets.FashionMNIST(root='tmp/dataset/', train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataset = datasets.FashionMNIST(root='tmp/dataset/', train=False, transform=transforms.ToTensor(), download=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=True)


Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to tmp/dataset/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:01<00:00, 15033974.83it/s]


Extracting tmp/dataset/FashionMNIST/raw/train-images-idx3-ubyte.gz to tmp/dataset/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to tmp/dataset/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 261146.95it/s]


Extracting tmp/dataset/FashionMNIST/raw/train-labels-idx1-ubyte.gz to tmp/dataset/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to tmp/dataset/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:00<00:00, 4960463.71it/s]


Extracting tmp/dataset/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to tmp/dataset/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to tmp/dataset/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 10233306.63it/s]

Extracting tmp/dataset/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to tmp/dataset/FashionMNIST/raw






In [23]:
# A faire : ajouter des batch normalization si problème de valeur extremes en sorties


class DownSampleBlock(nn.Module):
    '''Reduce the dimension of the image in input by 2'''
    def __init__(self, in_channels, out_channels):
        super(DownSampleBlock, self).__init__()
        # We keep the same dimension in input and ouput
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                               kernel_size=(3, 3), stride=(1, 1), padding=(1,1))
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels,
                               kernel_size=(3, 3), stride=(1, 1), padding=(1,1))
        self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2,2))

    def forward(self, x):
        x = self.pool(self.conv2(self.relu(self.conv1(x))))
        return x

class DoubleConvolution(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConvolution, self).__init__()
        # We keep the same dimension in input and ouput
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                               kernel_size=(3, 3), stride=(1, 1), padding=(1,1))
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels,
                               kernel_size=(3, 3), stride=(1, 1), padding=(1,1))

    def forward(self, x):
      x = self.conv2(self.relu(self.conv1(x)))
      return x

class UpSampleBlock(nn.Module):
    '''Increase the dimension of the input and reduce its number of channels by 2'''
    def __init__(self, in_channels, out_channels):
        super(UpSampleBlock, self).__init__()
        self.up1 = nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
        self.conv1 = nn.Conv2d(in_channels=out_channels,
                                     out_channels=out_channels, kernel_size=(3, 3),
                                     stride=(1, 1), padding=(1,1))
        
    def forward(self, x):
        x = self.conv1(self.up1(x))
        return x

class ResidualConnection(nn.Module):
    '''Concatenate inputs of two blocs'''
    def __init__(self, in_channels):
        super(ResidualConnection, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels,
                               kernel_size=(3, 3), stride=(1, 1), padding=(1,1))
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels,
                               kernel_size=(3, 3), stride=(1, 1), padding=(1,1))

    def forward(self, x, y):
        x = torch.cat([x, y], -1)
        x = self.conv2(self.relu(self.conv1(x)))
        return x

class UNet(nn.Module):

    # Numbers of filters for the first layer of convolution
    # <!> The number of filters will double for each down sample blocks which can
    # lead to very high numbers of parameters very quickly <!>
    NB_OF_FILTERS = 16

    def __init__(self, input_size, num_classes:int=10, depth:int=2):
        '''### Initialize a UNet model
        input_size : dimension of input
        num_classes : specify the number of classes in ouput
        depth : the number of blocks (depth of the model)'''
        super(UNet, self).__init__()
        self.double_conv1 = DoubleConvolution(in_channels=1, out_channels = self.NB_OF_FILTERS)
        # Reduce size
        self.dblocks = [DownSampleBlock(in_channels=1, out_channels=self.NB_OF_FILTERS)]
        self.bottleneck = DoubleConvolution(in_channels=(2**(depth - 1)) * self.NB_OF_FILTERS, out_channels=(2**(depth - 1)) * self.NB_OF_FILTERS)
        # Concatenate outputs from encoder and decoder to keep tracks of objects
        # positions
        self.res_connect = []
        # Increase size
        self.ublocks = [UpSampleBlock(in_channels=self.NB_OF_FILTERS, out_channels=num_classes)]
        in_channell=[]
        out_channell=[]
        for i in range(1,depth):
            # The number of channels double each time the depth increases
            out_channels = INPUT_SIZE[-1] * 2**(i) * self.NB_OF_FILTERS
            out_channell.append(out_channels)
            in_channels = INPUT_SIZE[-1] * 2**(i - 1) * self.NB_OF_FILTERS
            in_channell.append(in_channels)
            self.dblocks.append(DownSampleBlock(in_channels=in_channels, out_channels=out_channels))
            self.res_connect.append(ResidualConnection(in_channels=out_channels))
            self.ublocks.append(UpSampleBlock(in_channels=out_channels, out_channels=in_channels))
        self.output = nn.Conv2d(in_channels=num_classes, out_channels=num_classes,
                               kernel_size=(3, 3), stride=(1, 1), padding=(1,1))

    def forward(self, x):
        depth = len(self.dblocks)
        # Encoder
        for i, down_block in enumerate(self.dblocks):
            x = down_block.forward(x)
        x = self.bottleneck.forward(x)
        # Decoder
        for i, up_block in enumerate(self.ublocks[::-1]):
            x = up_block.forward(x)
        # Flatten the output for loss computation
        x = self.output(x)
        x = x.permute(0, 2, 3, 1).contiguous().view(-1, x.size(1))
        return x


In [20]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cuda')
model = UNet(INPUT_SIZE,device=device).to(device=device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)


In [24]:

def transform_targets(X, y_label):
    y = np.array([cv2.threshold(img, 8, 255, type=cv2.THRESH_BINARY)[1] for img in X.permute(0, 3, 2, 1).numpy()])
    y = np.where(y == 0, 0, 1)
    for s in range(y.shape[0]):
      y[s] *= y_label[s].numpy() + 1
    return torch.from_numpy(np.eye(NUM_CLASSES, dtype='uint8')[y]).permute(0, 3, 2, 1)

for epoch in range(EPOCHS):
    for data,y_label in tqdm(train_loader):
        targets = transform_targets(data, y_label)
        
        data.to(device=device)
        targets = targets.to(device=device)
        # data.shape = (64, 1, 28, 28)
        # we want (64, 28, 28, 1)

        # forward 
        # prediction
        scores = model(data)
        # Flatten and convert to float targets for loss computation
        targets =  targets.permute(0, 2, 3, 1).contiguous().view(-1, targets.size(1)).float()
        # Calculate loss 
        loss = criterion(scores, targets)

        # backward
        optimizer.zero_grad()
        loss.backward()

        #gradient descent or adam step
        optimizer.step()

  0%|          | 0/938 [00:00<?, ?it/s]


RuntimeError: ignored