<a href="https://colab.research.google.com/github/NailKhelifa/FewShotsLearning/blob/main/exemple_deeplabv3_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Notation:** \
N : batch size \
C : nb classes \
H : height \
W : width \

In [41]:
%pip install opencv-python

Collecting opencv-python
  Downloading opencv_python-4.9.0.80-cp37-abi3-macosx_10_16_x86_64.whl.metadata (20 kB)
Downloading opencv_python-4.9.0.80-cp37-abi3-macosx_10_16_x86_64.whl (55.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m55.7/55.7 MB[0m [31m40.7 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: opencv-python
Successfully installed opencv-python-4.9.0.80
Note: you may need to restart the kernel to use updated packages.


In [42]:
from tqdm import tqdm
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.models.segmentation import deeplabv3_resnet50, DeepLabV3_ResNet50_Weights
from torchvision.models.segmentation.deeplabv3 import DeepLabHead
import torchvision.transforms.v2 as transforms
import torch.nn as nn
import torch.nn.functional as F
import os
import cv2
from pathlib import Path

In [43]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device
data_dir = os.getcwd() + '/data'

### Load the data

In [44]:
def load_dataset(dataset_dir):
    dataset_list = []
    for image_file in list(sorted(Path(dataset_dir).glob("*.png"), key=lambda filename: int(filename.name.rstrip(".png")))):
        dataset_list.append(cv2.imread(str(image_file), cv2.IMREAD_GRAYSCALE))
    return np.stack(dataset_list, axis=0)

data_train = load_dataset(data_dir + "/X_train")
data_test = load_dataset(data_dir + "/X_test")

labels_train = pd.read_csv(data_dir + "/Y_train.csv", index_col=0).T

In [45]:
data_train = load_dataset(data_dir + "/X_train")
data_test = load_dataset(data_dir + "/X_test")

Regardons le nombre de classes totales pour les labels, et le nombre de classes max par image

In [46]:
tot_classes = labels_train.max().max() + 1
max_classes = np.max([len(np.unique(labels_train.iloc[k])) for k in range(len(labels_train))])
tot_classes, max_classes

(106, 31)

Bien qu'il y aie 104 classes différentes dans les labels, on remarque qu'il n'y en a au maximum que 30 par image. On peut réduire la valeur des labels pour chaque image avec le code suivant (long) :

In [51]:
def consecutive_values(row):
    """Modifie les valeurs de la ligne pour obtenir des entiers entre 0 et le nb de classes sur l'image tout en conservant les différences"""
    l, _ = pd.factorize(row, sort=True)
    return l

labels_train = labels_train.to_numpy()
labels_trainr = np.array([consecutive_values(row) for row in labels_train])
labels_trainr = pd.DataFrame(labels_trainr)

tot_classes = labels_trainr.max().max() + 1

assert tot_classes == max_classes

### Créer des datasets simples

In [52]:
labels = []
for k in range(len(labels_trainr)) :
    labels.append(torch.tensor(np.array(labels_trainr.iloc[k]).reshape(512, 512)))


y_train = torch.stack(labels[0:300])
x_train = torch.tensor(data_train[0:300]).unsqueeze(1)  # unsqueeze pour la dimension des channels de couleur (1 car greyscale)

y_valid = torch.stack(labels[300:400])
x_valid = torch.tensor(data_train[300:400]).unsqueeze(1)

x_test = torch.tensor(data_test).unsqueeze(1)

In [74]:
(labels_trainr.sum(axis = 1) == 0)[0:400].sum()

0

In [8]:
class DataSet_with_transform(Dataset):
    def __init__(self, x_dataset, y_dataset, transform=None):
        self.x = x_dataset
        self.y = y_dataset
        assert len(self.x) == len(self.y), "x and y should have same length"
        self.transform = transform

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        x, y = self.x[idx], self.y[idx]
        if self.transform:
            x = self.transform(x)
        return [x, y]

# train_dataset = DataSet_with_transform(x_train, y_train)
# valid_dataset = DataSet_with_transform(y_valid, x_valid)

# un dataloader simple pour batcher le train dataset:
# train_dataloader = DataLoader(train_dataset, batch_size = 64, shuffle = True, pin_memory = True)

### Example de finetuning avec deeplabv3

Commencons par charger le modèle et l'adapter à nos données

In [19]:
weights = DeepLabV3_ResNet50_Weights.DEFAULT # on commence par charger les weights
model = deeplabv3_resnet50(weights=weights) # instanciation du modèle
model.backbone.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) # on modifie la première couche pour prendre du greyscale
resize = 520 # le resize adapté pour le modèle

# On fait une fonction de preprocess pour resize et normalize
preprocess = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(resize),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.225])])

del weights
# On fait une fonction de postprocess pour remettre faire correspondre l'output à la dimension des labels
def postprocess(batch):
  return F.interpolate(batch, 512, mode = 'nearest-exact')

# print(model) # (pour regarder l'architecture si besoin)

In [20]:
num_classes = 30 # le nombre de classes à détecter, 0 inclus
# On va mettre max 12 classes pour réduire le temps d'éxécution/ la RAM
in_channels = 256 # le nombre de canaux en entrée du classifier
model.classifier[4] = DeepLabHead(in_channels, num_classes) # on change le classifier pour avoir le bon nombre de classes

In [21]:
model = model.to(device)

On va créer un Dataset avec les images déja préprocess

In [20]:
train_dataset = DataSet_with_transform(x_train, y_train, transform = preprocess)
valid_dataset = DataSet_with_transform(x_valid, y_valid, transform = preprocess)
train_dataloader = DataLoader(train_dataset, batch_size = 8, shuffle = True, pin_memory = True)

#### La loss fonction qu'on va utiliser :

In [21]:
class MulticlassDiceLoss(nn.Module):
    """ Reference: https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch#Dice-Loss """
    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, logits, targets, smooth=1e-6):
        """Computes the dice loss for all classes and provides an overall weighted loss."""
        probabilities = logits

        targets_one_hot = torch.nn.functional.one_hot(targets.long(), num_classes=self.num_classes)
        # Convert from NHWC to NCHW
        targets_one_hot = targets_one_hot.permute(0, 3, 1, 2)

        # Multiply one-hot encoded ground truth labels with the probabilities to get the
        # prredicted probability for the actual class.
        intersection = (targets_one_hot * probabilities).sum()

        mod_a = intersection.sum()
        mod_b = targets.numel()

        dice_coefficient = 2. * intersection / (mod_a + mod_b + smooth)
        dice_loss = -dice_coefficient.log()
        return dice_loss

torch.cuda.empty_cache()

model.eval()
with torch.no_grad():
  criterion = MulticlassDiceLoss(num_classes=30)
  x, y = next(iter(train_dataloader))
  x = x.to(device)
  y = y.to(device)
  logit = postprocess(model(x)['out'].softmax(dim = 1))
print(criterion(logit, y))
del x
del y
del logit
torch.cuda.empty_cache()

tensor(2.7212, device='cuda:0')


In [None]:
import time

model.train()
num_epochs = 10
optimizer = torch.optim.AdamW(model.parameters(), lr=0.005)
loss_over_epochs = []

for epoch in range(num_epochs):
    batch_loss = []
    epoch_start_time = time.time()  # Mesure du temps de l'époque
    for x, y in train_dataloader:
        start_time = time.time()
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        outputs = postprocess(model(x)['out'].softmax(dim=1))
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()
        batch_loss.append(loss.item())

        batch_time = time.time() - start_time
        torch.cuda.empty_cache()
    epoch_time = time.time() - epoch_start_time  # Calcul du temps écoulé pour l'époque
    print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {np.mean(batch_loss)}, Mean batch Time: {epoch_time/100:.4f} seconds')

    loss_over_epochs.append(np.mean(batch_loss))

Epoch [1/10], Average Loss: 0.031249210285022855, Mean batch Time: 3.0026 seconds
