<a href="https://colab.research.google.com/github/ayachaieb/_dental_care/blob/main/dental_x_rays_caries_segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# %pip (unlike !pip) installs into the environment of the running kernel.
%pip install -q segmentation_models_pytorch

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/106.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.0/106.7 kB[0m [31m1.0 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m106.7/106.7 kB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/58.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/58.8 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m29.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m19.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━

In [None]:
# %pip (unlike !pip) installs into the environment of the running kernel.
%pip install -q kaggle



In [None]:
# %pip (unlike !pip) installs into the environment of the running kernel.
%pip install -q kaleido

Collecting kaleido
  Downloading kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl (79.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.9/79.9 MB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: kaleido
Successfully installed kaleido-0.2.1


In [None]:
# -p keeps the cell re-runnable: no error when the directory already exists.
! mkdir -p ~/.kaggle

In [None]:
# Copy kaggle.json from the working directory into ~/.kaggle/
# (presumably the API token the kaggle CLI reads — confirm it has been uploaded first).
! cp kaggle.json ~/.kaggle/

In [None]:
# The token was already copied by the previous cell; restrict it to the current
# user, since the Kaggle CLI warns when the key is readable by other users.
! chmod 600 ~/.kaggle/kaggle.json

In [None]:
# Download the dataset archive (panoramic-dental-dataset.zip) into the current directory.
! kaggle datasets download thunderpede/panoramic-dental-dataset

Downloading panoramic-dental-dataset.zip to /content
 96% 234M/244M [00:02<00:00, 78.1MB/s]
100% 244M/244M [00:04<00:00, 60.5MB/s]


In [None]:
# -q suppresses the per-file listing (thousands of lines of notebook output);
# -o overwrites without prompting so the cell does not block on a re-run.
! unzip -q -o panoramic-dental-dataset.zip

Archive:  panoramic-dental-dataset.zip
  inflating: annotations/bboxes_caries/1008.txt  
  inflating: annotations/bboxes_caries/1009.txt  
  inflating: annotations/bboxes_caries/1016.txt  
  inflating: annotations/bboxes_caries/1018.txt  
  inflating: annotations/bboxes_caries/1026.txt  
  inflating: annotations/bboxes_caries/1033.txt  
  inflating: annotations/bboxes_caries/1042.txt  
  inflating: annotations/bboxes_caries/1050.txt  
  inflating: annotations/bboxes_caries/1058.txt  
  inflating: annotations/bboxes_caries/1067.txt  
  inflating: annotations/bboxes_caries/1074.txt  
  inflating: annotations/bboxes_caries/1080.txt  
  inflating: annotations/bboxes_caries/1083.txt  
  inflating: annotations/bboxes_caries/1088.txt  
  inflating: annotations/bboxes_caries/1090.txt  
  inflating: annotations/bboxes_caries/1091.txt  
  inflating: annotations/bboxes_caries/1092.txt  
  inflating: annotations/bboxes_caries/1096.txt  
  inflating: annotations/bboxes_caries/306.txt  
  inflating:

In [None]:
import os
import json
import warnings


import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import segmentation_models_pytorch as smp

from PIL import Image
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm


warnings.filterwarnings(action='ignore', category=UserWarning)

In [None]:
def seed_everything(seed=42):
    """Seed every random number generator this notebook may draw from.

    Args:
        seed: Integer seed applied to Python's ``random`` module, NumPy and
            PyTorch (CPU generator and the generators of all CUDA devices).
    """
    # Local import: the notebook's import cell does not bring in `random`.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every GPU, not only the currently selected one.
    torch.cuda.manual_seed_all(seed)
    # Give up cuDNN autotuning in exchange for deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

seed_everything()

In [None]:
# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

### Create CariesDataset

In [None]:
class CariesDataset(Dataset):
    """Dataset of (image, binary mask) pairs loaded from disk on access.

    Args:
        images_path_list: Paths of the input images.
        labels_path_list: Paths of the mask images, index-aligned with
            ``images_path_list``.
        augmentation: When true, random augmentation is applied on every access.
        device: Device every returned tensor is moved to.
        image_size: Target size handed to ``transforms.Resize``.

    Raises:
        ValueError: If the two path lists differ in length.
    """

    def __init__(self, images_path_list, labels_path_list, augmentation = True, device = 'cpu', image_size = (384, 768)):
        if len(images_path_list) != len(labels_path_list):
            raise ValueError(
                f'Got {len(images_path_list)} images but {len(labels_path_list)} labels.'
            )

        self.images = images_path_list
        self.labels = labels_path_list
        self.augmentation = augmentation
        self.device = device

        # Applied to both image and mask after any augmentation.
        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.Grayscale(),
            transforms.ToTensor()
        ])

        if self.augmentation:
            # Geometric changes: must be identical for the image and its mask.
            self.same_augmentation = transforms.Compose([
                transforms.RandomRotation(degrees = 5),
                transforms.RandomHorizontalFlip(p = 0.5)
            ])

            # Photometric changes: applied to the image only.
            self.different_augmentation = transforms.Compose([
                transforms.RandomAdjustSharpness(2),
                transforms.ColorJitter(brightness=0.5, contrast=0.5)
            ])


    def __getitem__(self, idx):
        """Return the ``(image, label)`` tensors for sample ``idx`` on ``self.device``.

        The label is binarised: every non-zero pixel becomes 1.0.
        """
        image = Image.open(self.images[idx])
        label = Image.open(self.labels[idx])

        if self.augmentation:
            # Snapshot the CPU generator so the geometric transform can be
            # replayed on the mask with exactly the same random draws.
            # Re-seeding the global generator here instead would reset the RNG
            # of every device to one of only 10000 seeds on each sample.
            geometry_state = torch.get_rng_state()
            image = self.same_augmentation(image)
            image = self.different_augmentation(image)
            resume_state = torch.get_rng_state()

            torch.set_rng_state(geometry_state)
            label = self.same_augmentation(label)
            # Continue the global stream from where the image pipeline left it.
            torch.set_rng_state(resume_state)

        image = self.transform(image).to(self.device)
        label = self.transform(label).to(self.device)

        label = 1. * (label != 0)

        return image, label

    def __len__(self):
        """Number of samples (one per image path)."""
        return len(self.images)

In [None]:
image_path = '/content/images_cut'
labels_path = '/content/labels_cut'

# os.listdir returns entries in arbitrary order; sorting makes the seeded
# train/validation split identical from run to run.
file_names = sorted(os.listdir(image_path))
train_files, val_files = train_test_split(file_names, test_size=0.2, random_state=42)

# os.path.join inserts the path separator. Plain string concatenation produced
# paths such as '/content/images_cut<name>', which do not exist.
# Masks are looked up under the same file name as their image.
train_image_path = [os.path.join(image_path, file_name) for file_name in train_files]
train_mask_path = [os.path.join(labels_path, file_name) for file_name in train_files]

eval_image_path = [os.path.join(image_path, file_name) for file_name in val_files]
eval_mask_path = [os.path.join(labels_path, file_name) for file_name in val_files]

In [None]:
# Both splits live on the same device; only the training split is augmented.
train_dataset = CariesDataset(train_image_path, train_mask_path, augmentation=True, device=device)
eval_dataset = CariesDataset(eval_image_path, eval_mask_path, augmentation=False, device=device)

# Tensor -> PIL image converter.
toPIL = transforms.ToPILImage()

batch_size = 8

# Training batches are shuffled; evaluation visits samples one at a time, in order.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1, shuffle=False)

### Create the loss function (`DiceLoss`) and the metric function (`metric_calculate`)

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss computed over all elements of the batch at once.

    Args:
        smooth: Constant added to both numerator and denominator, so the ratio
            stays defined when inputs and targets both sum to zero.
        activation: Optional callable applied to ``inputs`` before the loss is
            computed (for example a sigmoid when the inputs are logits).
    """

    def __init__(self, smooth = 1, activation = None):
        super().__init__()
        self.smooth = smooth
        self.activation = activation

    def forward(self, inputs, targets):
        """Return ``1 - dice`` as a scalar tensor.

        Args:
            inputs: Predictions; passed through ``activation`` when one is set.
            targets: Ground truth with the same number of elements as ``inputs``.
        """
        # Explicit None check: a callable's truthiness is not a reliable signal.
        if self.activation is not None:
            inputs = self.activation(inputs)

        # reshape (unlike view) also accepts non-contiguous tensors.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (inputs * targets).sum()
        dice = (2. * intersection + self.smooth)/(inputs.sum() + targets.sum() + self.smooth)

        return 1 - dice



def metric_calculate(prediction: np.ndarray, target: np.ndarray):
    """Compute binary segmentation metrics after thresholding both inputs at 0.5.

    Args:
        prediction: Array of predicted values; entries above 0.5 count as positive.
        target: Array of reference values; entries above 0.5 count as positive.

    Returns:
        Tuple ``(acc, iou, dice, pre, spe, sen)``. Every denominator carries an
        extra 1e-4, so an empty denominator yields 0 rather than a division error.
    """
    eps = 1e-4
    truth = target.ravel() > 0.5
    guess = prediction.ravel() > 0.5

    # Confusion-matrix counts.
    TP = np.sum(guess & truth)
    FN = np.sum(~guess & truth)
    TN = np.sum(~guess & ~truth)
    FP = np.sum(guess & ~truth)

    acc = (TP + TN) / (TP + TN + FP + FN + eps)
    iou = TP / (TP + FP + FN + eps)
    dice = (2 * TP) / (2 * TP + FP + FN + eps)
    pre = TP / (TP + FP + eps)
    spe = TN / (FP + TN + eps)
    sen = TP / (TP + FN + eps)

    return acc, iou, dice, pre, spe, sen

### Create the semantic segmentation model with the SMP library

In [None]:
# U-Net++ with an EfficientNet-B0 encoder initialised from ImageNet weights;
# one input channel and one output channel.
model = smp.UnetPlusPlus(
    classes = 1,
    in_channels = 1,
    encoder_weights = 'imagenet',
    encoder_name = 'efficientnet-b0',
)
model = model.to(device)

# Prefix of the checkpoint files written during training.
model_name = 'UNetEfficientnetB0'

num_epoch = 50

# The loss applies a sigmoid to the model output before computing Dice.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = DiceLoss(activation=F.sigmoid)

Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth" to /root/.cache/torch/hub/checkpoints/efficientnet-b0-355c32eb.pth
100%|██████████| 20.4M/20.4M [00:00<00:00, 119MB/s] 


### Training Process

In [None]:
print (f'Training {model_name} start.')

# Best validation IoU seen so far, plus per-epoch histories.
IoU_max = 0.
losses_train, losses_val = [], []
metrics = []

for epoch in tqdm(range(num_epoch)):
    current_train_loss, current_val_loss = 0., 0.
    # Running mean of (acc, iou, dice, pre, spe, sen) over the evaluation set.
    current_metric = np.zeros(6)

    model.train()
    for images, labels in train_dataloader:
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # Dividing each term by the batch count accumulates the epoch mean.
        current_train_loss += loss.item() / len(train_dataloader)

    model.eval()
    with torch.no_grad():
        for images, labels in eval_dataloader:
            logits = model(images)
            loss = criterion(logits, labels)

            # metric_calculate thresholds its inputs at 0.5, which is only
            # meaningful for probabilities, so convert the raw logits first.
            probs = torch.sigmoid(logits)

            current_val_loss += loss.item() / len(eval_dataloader)
            current_metric += np.array(metric_calculate(
                probs.cpu().numpy(),
                labels.cpu().numpy())) / len(eval_dataloader)

    losses_train.append(current_train_loss)
    losses_val.append(current_val_loss)
    metrics.append(current_metric.tolist())

    # Index 1 of the metric tuple is IoU; keep the checkpoint with the best one.
    if IoU_max < metrics[-1][1]:
        torch.save(model, f'{model_name}-best.pth')
        IoU_max = metrics[-1][1]

    print (f'Epoch: {epoch + 1}, train_loss: {losses_train[-1]:.4f}, val_loss: {losses_val[-1]:.4f}, IoU: {metrics[-1][1]:.4f}')


# Persist the training history next to the checkpoints.
log = {}
log['train_loss'] = losses_train
log['eval_loss'] = losses_val
log['metric'] = metrics
log['best_score'] = IoU_max

torch.save(model, f'{model_name}-last.pth')

with open('log.txt', 'w') as outfile:
    json.dump(log, outfile)

torch.cuda.empty_cache()

print ('- - ' * 30)
print (f'Training {model_name} done. Best IoU: {IoU_max:.4f}.')
print ('- - ' * 30)

Training UNetEfficientnetB0 start.


  0%|          | 0/50 [00:00<?, ?it/s]

FileNotFoundError: ignored

In [None]:
# Left panel: loss curves. Right panel: validation IoU with its best value marked.
fig, (loss_ax, iou_ax) = plt.subplots(1, 2, figsize=(12, 4))

for history, name in ((losses_train, 'train_loss'), (losses_val, 'val_loss')):
    loss_ax.plot(history, label=name)
loss_ax.grid()
loss_ax.legend()
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')

iou_history = [epoch_metrics[1] for epoch_metrics in metrics]
iou_ax.plot(iou_history, label='IoU')
iou_ax.axhline(IoU_max, linestyle='--', color='red', label=f'IoU_max: {IoU_max:.4f}')
iou_ax.grid()
iou_ax.legend()
iou_ax.set_xlabel('Epoch')
iou_ax.set_ylabel('IoU')

fig.savefig('efficientnet-b0.png')
plt.show()

### Inference

In [None]:
# The checkpoint is a fully pickled model object (saved with torch.save(model, ...)),
# so it needs weights_only=False; recent PyTorch releases default to True and
# refuse such files. NOTE: unpickling executes code, so only load checkpoints
# produced by this notebook. map_location lets a GPU-trained checkpoint load on
# a CPU-only runtime.
model = torch.load('UNetEfficientnetB0-best.pth', map_location=device, weights_only=False).eval()
image, label = eval_dataset[6]

In [None]:
# Inference only: no_grad avoids building an autograd graph for the forward pass.
with torch.no_grad():
    # Add a batch dimension for the model, then drop it again from the output.
    probabilities = torch.sigmoid(model(image.unsqueeze(0))).squeeze(0)

In [None]:
# 2x2 grid: input, ground truth, prediction, and prediction blended with ground truth.
fig, axes = plt.subplots(2, 2, figsize=(16, 10))

single_panels = (
    (axes[0, 0], image, 'Input Image'),
    (axes[0, 1], label, 'GroundTruth Mask'),
    (axes[1, 0], probabilities, 'Prediction Mask'),
)
for ax, tensor, title in single_panels:
    ax.imshow(toPIL(tensor), cmap='gray')
    ax.set_title(title)

# Two half-transparent layers drawn on the same axes.
overlay_ax = axes[1, 1]
for tensor in (probabilities, label):
    overlay_ax.imshow(toPIL(tensor), alpha=0.5, cmap='gray')
overlay_ax.set_title('Prediction & GroundTruth')

plt.show()