In [1]:
import os

import cv2
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import albumentations as album
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts, SequentialLR

import segmentation_models_pytorch as smp
from segmentation_models_pytorch.utils import base, metrics, train
from segmentation_models_pytorch.losses import DiceLoss, SoftBCEWithLogitsLoss

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Fixed seed so the train/test partition is identical on every
# Restart & Run All. Without it the held-out rows change per run, which makes
# validation scores from different sessions incomparable.
SPLIT_SEED = 42

alldata_df = pd.read_csv("./data.csv")

# 10% of the rows are held out; the split is shuffled but deterministic.
train_df, test_df = train_test_split(alldata_df, test_size=0.1, random_state=SPLIT_SEED)

In [3]:
# Stage 1: flips plus a small shift/scale/rotate jitter; borders exposed by
# the affine jitter are filled by reflection.
geometric_stage = [
    album.HorizontalFlip(),
    album.VerticalFlip(),
    album.ShiftScaleRotate(
        shift_limit=0.0625,
        scale_limit=0.2,
        rotate_limit=15,
        p=0.9,
        border_mode=cv2.BORDER_REFLECT,
    ),
]

# Stage 2: candidates for the distortion / blur / noise OneOf group.
distortion_candidates = [
    album.ElasticTransform(p=0.3),
    album.MedianBlur(p=0.3),
    album.MotionBlur(p=0.3),
    album.GaussianBlur(p=0.3),
    album.GaussNoise(p=0.3),
    album.OpticalDistortion(p=0.3),
    album.GridDistortion(p=0.1),
]

# Stage 3: candidates for the colour / contrast OneOf group.
colour_candidates = [
    album.ColorJitter(p=0.5),
    album.HueSaturationValue(15, 25, 0),
    album.CLAHE(clip_limit=2),
    album.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.75),
]

# Full augmentation pipeline: the whole Compose runs with p=0.9 and each
# OneOf group runs with p=0.3.
training_transform = album.Compose(
    geometric_stage
    + [
        album.OneOf(distortion_candidates, p=0.3),
        album.OneOf(colour_candidates, p=0.3),
    ],
    p=0.9,
)

In [4]:
# Names of the segmentation classes; the list length sets the number of model
# output channels further down.
class_name = ["tooth", "background"]

# Mask colour for each class, listed in the same order as class_name
# (white pixels first, black pixels second).
class_rgb_values = [[255, 255, 255], [0, 0, 0]]

def to_onehot(mask, mask_values):
    """Turn a colour-coded mask into a boolean one-hot map.

    Args:
        mask: array whose last axis holds the colour components of a pixel.
        mask_values: sequence of colours; each one is compared against every
            pixel of ``mask``.

    Returns:
        Boolean array in which the last axis of ``mask`` is replaced by one
        channel per entry of ``mask_values``. Channel ``k`` is True where the
        pixel matches ``mask_values[k]`` in every component.
    """
    channels = [np.all(np.equal(mask, value), axis=-1) for value in mask_values]
    return np.stack(channels, axis=-1)

def to_tensor(x, **kwargs):
    """Reorder a 3-D array from (axis0, axis1, axis2) to (axis2, axis0, axis1) as float32.

    Extra keyword arguments are accepted and ignored, so the function can be
    used as an ``album.Lambda`` callback.
    """
    channels_first = x.transpose((2, 0, 1))
    return channels_first.astype(np.float32)

def get_preprocessing(fn=None):
    """Assemble the preprocessing pipeline.

    Args:
        fn: optional callable applied to the image only; skipped when falsy.

    Returns:
        An ``album.Compose`` that runs ``fn`` on the image (when given) and
        then ``to_tensor`` on both the image and the mask.
    """
    to_channels_first = album.Lambda(image=to_tensor, mask=to_tensor)
    if not fn:
        return album.Compose([to_channels_first])
    return album.Compose([album.Lambda(image=fn), to_channels_first])

In [5]:
class ToothDataset(Dataset):
    """Dataset of images and colour-coded segmentation masks listed in a table.

    Args:
        df: table with an ``"image"`` column and a ``"mask"`` column, both
            holding file paths.
        transform: optional callable invoked with keyword arrays
            (``image=`` and, in train mode, ``mask=``) that returns a mapping
            with the same keys. Skipped when falsy.
        preprocessing: optional callable with the same calling convention,
            run after ``transform``. Skipped when falsy.
        train: when True, ``__getitem__`` returns ``(image, mask)``; when
            False the mask file is never read and ``(image, mask_path)`` is
            returned instead.
    """

    def __init__(self, df, transform=None, preprocessing=None, train=True):
        self.image_paths = df["image"].tolist()
        self.mask_paths = df["mask"].tolist()

        self.transform = transform
        self.preprocessing = preprocessing

        self.train = train

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _read_rgb(path):
        """Read an image file with OpenCV and return it in RGB channel order.

        Raises:
            OSError: if ``cv2.imread`` returns ``None``, which it does for a
                missing or undecodable file instead of raising.
        """
        bgr = cv2.imread(path)
        if bgr is None:
            raise OSError(f"cv2.imread failed (missing or unreadable file): {path}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        arrays = {"image": self._read_rgb(self.image_paths[idx])}
        if self.train:
            mask_rgb = self._read_rgb(self.mask_paths[idx])
            # One channel per entry of class_rgb_values, stored as float.
            arrays["mask"] = to_onehot(mask_rgb, class_rgb_values).astype("float")

        # Both stages are optional and share one calling convention; only the
        # keys that were passed in are carried over to the next stage.
        for stage in (self.transform, self.preprocessing):
            if stage:
                result = stage(**arrays)
                arrays = {key: result[key] for key in arrays}

        if self.train:
            return arrays["image"], arrays["mask"]
        # Inference mode: hand back the mask path so the caller knows where
        # to write the prediction.
        return arrays["image"], self.mask_paths[idx]

In [6]:
class MyLoss(base.Loss):
    """Weighted sum of a Dice loss (0.7) and a label-smoothed BCE loss (0.3).

    ``forward`` expects ``output`` to contain probabilities in [0, 1]: the
    model in this notebook is built with ``activation="sigmoid"``, so its
    outputs are already squashed. Both wrapped losses default to working on
    raw logits, so they are configured or fed here in a way that avoids a
    second sigmoid.
    """

    def __init__(self):
        super(MyLoss, self).__init__()

        # from_logits=False: the predictions are already probabilities, so
        # Dice must not apply its own sigmoid on top of them.
        self.diceloss = DiceLoss(mode="binary", from_logits=False)
        self.binloss = SoftBCEWithLogitsLoss(reduction="mean", smooth_factor=0.1)

    def forward(self, output, mask):
        """Return ``0.7 * dice + 0.3 * bce`` for predictions ``output`` and targets ``mask``."""
        dice_loss = self.diceloss(output, mask)

        # SoftBCEWithLogitsLoss applies a sigmoid internally, so map the
        # probabilities back to logits first. eps clamps exact 0/1 so the
        # logit stays finite.
        bin_loss = self.binloss(torch.logit(output, eps=1e-6), mask)

        return dice_loss * 0.7 + bin_loss * 0.3

In [7]:
# Architecture settings. ENCODER and ENCODER_WEIGHTS are module-level names
# because the matching input preprocessing is looked up from them later.
ENCODER, ENCODER_WEIGHTS = "mit_b2", "imagenet"
CLASSES = class_name
ACTIVATION = "sigmoid"

# One output channel per class; four decoder stages to match encoder_depth=4.
model_kwargs = dict(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    encoder_depth=4,
    decoder_channels=(512, 256, 128, 64),
    classes=len(CLASSES),
    activation=ACTIVATION,
)
model = smp.MAnet(**model_kwargs)

In [8]:
# Input normalisation that matches the chosen encoder and pretrained weights.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

# Only the training set is augmented; both sets get the same preprocessing.
train_ds = ToothDataset(
    train_df,
    transform=training_transform,
    preprocessing=get_preprocessing(preprocessing_fn),
)
test_ds = ToothDataset(
    test_df,
    transform=None,
    preprocessing=get_preprocessing(preprocessing_fn),
)

# Loader settings shared by both splits; only the training loader shuffles.
loader_kwargs = dict(batch_size=8, num_workers=8)
train_dl = DataLoader(train_ds, shuffle=True, **loader_kwargs)
test_dl = DataLoader(test_ds, shuffle=False, **loader_kwargs)

In [9]:
# Run switches: TRAIN gates the training loop, INFER gates the inference cell.
TRAIN, INFER = False, True

# Number of passes over the training loader when TRAIN is enabled.
EPOCHS = 100

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

In [10]:
# Replace the freshly built network with a previously saved checkpoint.
# NOTE(review): this unpickles a complete model object, which can execute
# arbitrary code - only load checkpoints from a trusted source. Recent PyTorch
# releases changed the default of torch.load's `weights_only` argument, so
# loading a full pickled model may need `weights_only=False` there; confirm
# against the installed version.
model_path = "./model/best_iou_0.9546_at_10.pth"
model = torch.load(model_path, map_location=DEVICE)

loss = MyLoss()

# Reach the IoU class through the package path rather than the bare name
# `metrics`: this assignment rebinds `metrics` from the imported module to a
# list, so `metrics.IoU` would raise AttributeError the second time the cell
# is executed.
metrics = [smp.utils.metrics.IoU(threshold=0.5)]

optimizer = torch.optim.AdamW([dict(params=model.parameters(), lr=1e-4, weight_decay=0.01)])

# Two-phase schedule: cosine annealing with warm restarts first, then plain
# cosine annealing once SequentialLR passes the milestone at step 60.
schedulers = [
    CosineAnnealingWarmRestarts(optimizer, T_0=1, T_mult=2, eta_min=1e-6),
    CosineAnnealingLR(optimizer=optimizer, T_max=10, eta_min=1e-6)
]
lr_scheduler = SequentialLR(optimizer, schedulers, milestones=[60])

In [11]:
# The training and validation runners share the model, loss, metrics and
# device; only the training runner receives the optimizer.
shared_epoch_args = dict(
    loss=loss,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)

train_epoch = train.TrainEpoch(model, optimizer=optimizer, **shared_epoch_args)
valid_epoch = train.ValidEpoch(model, **shared_epoch_args)

# Train

In [12]:
if TRAIN:
    best_ephoch = 0
    best_iou_score = 0.0
    train_logs_list, test_logs_list = [], []

    # torch.save does not create parent directories, so make sure the
    # checkpoint folder exists before the first improvement is written.
    os.makedirs("model", exist_ok=True)

    for i in range(0, EPOCHS):
        print(f"Epoch: {i}, Learning Rate: {optimizer.param_groups[0]['lr']}")

        # One pass over the training data, then one over the held-out data.
        train_logs = train_epoch.run(train_dl)
        test_logs = valid_epoch.run(test_dl)

        train_logs_list.append(train_logs)
        test_logs_list.append(test_logs)

        # The schedule advances once per epoch, after both passes.
        lr_scheduler.step()

        # Keep a checkpoint every time the held-out IoU improves; older
        # checkpoints are left on disk.
        if best_iou_score < test_logs["iou_score"]:
            # if best_iou_score != 0:
            #     os.remove(f"model/best_iou_{best_iou_score:.4f}_at_{best_ephoch}.pth")
            #     print(f"history score {best_iou_score:.4f} model removed!")

            best_ephoch = i
            best_iou_score = test_logs["iou_score"]

            torch.save(model, f"model/best_iou_{best_iou_score:.4f}_at_{i}.pth")

            print(f"best score {best_iou_score:.4f} model saved!")

# Infer

In [13]:
def _predict_with_flip_tta(net, batch):
    """Average channel 0 of ``net``'s output over three views of ``batch``.

    The views are the input as given, the input flipped along dim 2 and the
    input flipped along dim 3. Each flipped prediction is flipped back before
    averaging, so the three maps are aligned.

    Returns:
        Tensor with the channel dimension removed (index 0 taken on dim 1).
    """
    plain = net(batch)[:, 0, :]
    flipped_dim2 = torch.flip(net(torch.flip(batch, [2])), [2])[:, 0, :]
    flipped_dim3 = torch.flip(net(torch.flip(batch, [3])), [3])[:, 0, :]
    return (plain + flipped_dim2 + flipped_dim3) / 3.0


def _mass_normalised_entropy(prob):
    """Return ``-sum(p * log p) / sum(p)`` for a probability map.

    Pixels with ``p == 0`` contribute 0 (the limit of ``p * log p``) instead
    of ``0 * -inf = NaN``. A map with no probability mass at all scores 0.
    """
    positive = prob > 0
    plogp = np.zeros_like(prob)
    plogp[positive] = prob[positive] * np.log(prob[positive])

    total = np.sum(prob)
    if total <= 0:
        return prob.dtype.type(0)
    return -np.sum(plogp) / total


if INFER:
    infer_df = pd.read_csv("./unlabelled.csv")

    # train=False: the dataset yields (image, mask_path) and never reads the
    # mask file; the mask path is where the prediction gets written.
    infer_ds = ToothDataset(
        infer_df,
        transform=None,
        preprocessing=get_preprocessing(preprocessing_fn),
        train=False
    )

    uncertainty = []
    threshold = 0.5

    # Switching to eval mode once is enough; nothing in the loop changes it.
    model.eval()

    for i in tqdm(range(len(infer_ds))):
        image, result_path = infer_ds[i]

        with torch.no_grad():
            x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
            pred_mask = _predict_with_flip_tta(model, x_tensor)

        # Drop the batch dimension rather than reshaping to a fixed size, so
        # images of any resolution the model accepts are handled.
        prob = pred_mask[0].cpu().numpy()

        uncertainty.append({
            "path": result_path,
            "uncertainty": _mass_normalised_entropy(prob)
        })

        # Binarise: 255 where the averaged probability reaches the threshold.
        out = np.where(prob >= threshold, 255, 0).astype(np.uint8)

        # convert() returns a new image, so the result has to be kept for the
        # saved file to be 1-bit.
        image = Image.fromarray(out)
        image = image.convert("1")

        image.save(result_path)

    # Lowest-uncertainty predictions first.
    uncertainty = sorted(uncertainty, key=lambda x: x["uncertainty"], reverse=False)
    uncertainty_df = pd.DataFrame(uncertainty)
    uncertainty_df.to_csv("./uncertainty.csv", index=False)

100%|██████████| 2100/2100 [02:23<00:00, 14.61it/s]


# Post Processing

In [14]:
def filter_small_connected_components(binary_image, threshold):
    """Remove connected components smaller than ``threshold``.

    Components are labelled with 8-connectivity. Label 0 is never kept; every
    other label is kept when the last column of its stats row (the component
    area in OpenCV's output) is at least ``threshold``.

    Returns:
        Float array shaped like the label map: 255 on kept components and 0
        everywhere else.
    """
    n_labels, labels, stats, _ = cv2.connectedComponentsWithStats(binary_image, connectivity=8)
    areas = stats[:, -1]

    # Output value per label. Index 0 stays at 0 so that label is always
    # dropped; the slice is a view, so the assignment writes into label_value.
    label_value = np.zeros(n_labels)
    label_value[1:][areas[1:] >= threshold] = 255

    # Look up every pixel's label to paint the filtered image in one step.
    return label_value[labels]


def process_images_in_directory(input_dir, output_dir, threshold):
    """Filter small components out of every ``.png`` mask in a directory.

    Args:
        input_dir: directory scanned (non-recursively) for ``.png`` files.
        output_dir: directory the filtered masks are written to under the
            same file names; created if missing. It may equal ``input_dir``,
            in which case the files are overwritten.
        threshold: minimum component size passed to
            ``filter_small_connected_components``.

    Raises:
        OSError: if a ``.png`` file cannot be read by ``cv2.imread``.
    """
    os.makedirs(output_dir, exist_ok=True)
    for filename in tqdm(os.listdir(input_dir)):
        if not filename.endswith(".png"):
            continue

        image_path = os.path.join(input_dir, filename)
        binary_image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if binary_image is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise OSError(f"cv2.imread failed (missing or unreadable file): {image_path}")

        filtered_image = filter_small_connected_components(binary_image, threshold).astype(np.bool_)

        # A boolean array already becomes a 1-bit (mode "1") PIL image, so no
        # separate convert step is needed before saving.
        pil_image = Image.fromarray(filtered_image)
        output_path = os.path.join(output_dir, filename)
        pil_image.save(output_path)

In [15]:
# Clean the generated masks in place: the input and output directory are the
# same, and components smaller than 128 are removed.
# (Alternative target used previously: "./infer/label/".)
fake_label_dir = "./data/unlabelled/fakelabel/"
process_images_in_directory(fake_label_dir, fake_label_dir, 128)