In [None]:
# Environment bootstrap: on Google Colab install torchmetrics, mount Drive and
# keep project files under the Drive folder; everywhere else use the current
# directory as `base_dir`.
import subprocess
import sys

try:
    from google.colab import drive
except ImportError:
    # Not running on Colab: work relative to the current directory.
    base_dir = "."
else:
    try:
        # `sys.executable -m pip` installs into the interpreter that runs this
        # kernel, which a bare `pip` on PATH does not guarantee.
        subprocess.run([sys.executable, "-m", "pip", "install", "torchmetrics"])
        drive.mount("/content/drive")
        base_dir = "/content/drive/MyDrive/Colab_Notebooks/Crack_Detection"
    except Exception as exc:
        # Keep the original best-effort fallback, but report why it happened
        # and let KeyboardInterrupt / SystemExit propagate.
        print(f"Colab setup failed ({exc!r}); falling back to the current directory.")
        base_dir = "."

In [None]:
import os
import random
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms.functional as F
from PIL import Image
from scipy.ndimage import distance_transform_edt
from skimage.segmentation import find_boundaries
from torch import optim
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall
from torchmetrics.clustering import RandScore

In [None]:
class UNet(nn.Module):
    """U-Net with four encoder stages, a bottleneck and four decoder stages.

    Every stage is built by ``block`` (two convolutions, each followed by a
    ReLU). Encoder feature maps are centre-cropped to the size of the matching
    up-convolution output before being concatenated as skip connections.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of output channels (one logit map per class).
        init_features: Channel count of the first encoder stage; it doubles at
            every deeper stage.
        dropout_p: Probability used by the ``nn.Dropout2d`` applied to the
            deepest encoder output on its way to the bottleneck.
    """

    def __init__(self, in_channels=1, out_channels=2, init_features=64, dropout_p=0.5):
        # TODO adjust dropout probability
        super().__init__()
        features = init_features
        kernel_size = 3
        self.encoder1 = block(in_channels, features, kernel_size, "enc1")
        self.pool = nn.MaxPool2d(stride=2, kernel_size=2)
        self.encoder2 = block(features, features * 2, kernel_size, "enc2")
        self.encoder3 = block(features * 2, features * 4, kernel_size, "enc3")
        self.encoder4 = block(features * 4, features * 8, kernel_size, "enc4")
        self.dropout = nn.Dropout2d(p=dropout_p)
        self.bottleneck_encoder = block(
            features * 8, features * 16, kernel_size, "bottleneck_enc"
        )
        self.upconv1 = nn.ConvTranspose2d(
            features * 16, features * 8, kernel_size=2, stride=2
        )
        self.decoder1 = block(features * 16, features * 8, kernel_size, "dec1")
        self.upconv2 = nn.ConvTranspose2d(
            features * 8, features * 4, kernel_size=2, stride=2
        )
        self.decoder2 = block(features * 8, features * 4, kernel_size, "dec2")
        self.upconv3 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decoder3 = block(features * 4, features * 2, kernel_size, "dec3")
        self.upconv4 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decoder4 = block(features * 2, features, kernel_size, "dec4")
        self.out_conv = nn.Conv2d(features, out_channels, kernel_size=1)

    @staticmethod
    def _center_crop(skip: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Centre-crop ``skip`` (N, C, H, W) to the spatial size of ``target``.

        Height and width are handled independently, so non-square feature maps
        are cropped correctly. For square maps the window is the same
        ``(x - y) // 2`` offset that was previously applied to both axes.
        """
        target_h, target_w = target.shape[2], target.shape[3]
        top = (skip.shape[2] - target_h) // 2
        left = (skip.shape[3] - target_w) // 2
        return skip[:, :, top : top + target_h, left : left + target_w]

    def _merge_skip(self, skip: torch.Tensor, upsampled: torch.Tensor) -> torch.Tensor:
        """Concatenate the cropped encoder map with the upsampled map on channels."""
        return torch.cat((self._center_crop(skip, upsampled), upsampled), dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the per-class logit maps for the input batch ``x``."""
        # Contracting path.
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool(enc1))
        enc3 = self.encoder3(self.pool(enc2))
        enc4 = self.encoder4(self.pool(enc3))
        # Dropout only affects the bottleneck input; the skip connection below
        # still uses the un-dropped enc4.
        drop = self.dropout(enc4)
        bottleneck = self.bottleneck_encoder(self.pool(drop))

        # Expanding path: upsample, merge with the matching encoder map, convolve.
        upconv1 = self.upconv1(bottleneck)
        dec1 = self.decoder1(self._merge_skip(enc4, upconv1))

        upconv2 = self.upconv2(dec1)
        dec2 = self.decoder2(self._merge_skip(enc3, upconv2))

        upconv3 = self.upconv3(dec2)
        dec3 = self.decoder3(self._merge_skip(enc2, upconv3))

        upconv4 = self.upconv4(dec3)
        dec4 = self.decoder4(self._merge_skip(enc1, upconv4))

        return self.out_conv(dec4)


def crop_size(encoder, upconv) -> tuple:
    """Return the (start, stop) slice bounds that centre-crop ``encoder`` to ``upconv``.

    Only dimension 2 of both tensors is inspected, so the bounds describe a
    centred window of ``upconv.shape[2]`` elements inside ``encoder.shape[2]``.
    """
    encoder_side, upconv_side = encoder.shape[2], upconv.shape[2]
    start = (encoder_side - upconv_side) // 2
    stop = (encoder_side + upconv_side) // 2
    return start, stop


def block(
    in_channels: int, features: int, kernel_size: int, name: str
) -> nn.Sequential:
    """Build a conv -> ReLU -> conv -> ReLU stage whose layers are prefixed with ``name``.

    The convolutions use no padding, so each one shrinks the spatial size by
    ``kernel_size - 1``. The first maps ``in_channels`` to ``features``; the
    second keeps ``features`` channels.
    """
    layers = OrderedDict()
    layers[f"{name}_conv1"] = nn.Conv2d(in_channels, features, kernel_size)
    layers[f"{name}_relu1"] = nn.ReLU(inplace=True)
    layers[f"{name}_conv2"] = nn.Conv2d(features, features, kernel_size)
    layers[f"{name}_relu2"] = nn.ReLU(inplace=True)
    return nn.Sequential(layers)

In [None]:
class SegmentationDataset(Dataset):
    """In-memory dataset of (image, mask) pairs loaded as 8-bit grayscale PIL images.

    Every name in ``image_names`` is read once at construction time from both
    ``image_dir`` and ``mask_dir`` (the mask must share the image's file name).

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the masks.
        image_names: File names to load from both directories.
        transform: Optional callable ``transform(image, mask) -> (image, mask)``
            applied on every ``__getitem__`` call.
    """

    def __init__(
        self, image_dir: str, mask_dir: str, image_names: list, transform=None
    ):
        self.transform = transform
        self.image_data = []
        self.mask_data = []
        for image_name in image_names:
            image_path = os.path.join(image_dir, image_name)
            mask_path = os.path.join(mask_dir, image_name)
            self.image_data.append(self._load_grayscale(image_path))
            self.mask_data.append(self._load_grayscale(mask_path))

    @staticmethod
    def _load_grayscale(path: str):
        """Read ``path`` as a mode "L" image and close the underlying file."""
        # convert() yields an independent in-memory image, so the file handle
        # can be released right away instead of waiting for garbage collection.
        with Image.open(path) as source:
            return source.convert("L")

    def __len__(self) -> int:
        """Return the number of loaded (image, mask) pairs."""
        return len(self.image_data)

    def __getitem__(self, idx: int) -> tuple:
        """Return the ``idx``-th pair, passed through ``transform`` when one is set."""
        image = self.image_data[idx]
        mask = self.mask_data[idx]
        if self.transform:
            image, mask = self.transform(image, mask)
        return image, mask


class ImageMaskTransform:
    """Joint image/mask preprocessing with optional random augmentation.

    Both inputs are converted to tensors (the mask additionally to ``long``).
    The image is then reflect-padded and the mask centre-cropped to
    ``mask_size``. With ``train=True`` exactly one of the following is tried,
    in this order: a random rotation, a horizontal flip, a vertical flip; if
    none is drawn the sample is only padded/cropped.
    """

    def __init__(
        self,
        flip_prob=0.5,
        rotate_prob=0.5,
        # TODO adjust rotation degree, flip and rotation probabilities
        rotation_degree=30,
        train=True,
        image_size=512,
        input_size=572,
        mask_size=388,
        # TODO automatic rotate pad calc
        rotate_pad=134,
    ):
        # Augmentation settings.
        self.train = train
        self.flip_prob = flip_prob
        self.rotate_prob = rotate_prob
        self.rotation_degree = rotation_degree
        # Geometry: raw image size, network input size and mask size.
        self.image_size = image_size
        self.input_size = input_size
        self.mask_size = mask_size
        # Padding that grows image_size to input_size when no rotation is applied.
        self.default_pad = (input_size - image_size) // 2
        # Larger padding applied before rotating; the image is cropped back
        # to input_size afterwards.
        self.rotate_pad = rotate_pad

    def __call__(self, image, mask):
        """Return the transformed ``(image, mask)`` tensor pair."""
        image = F.to_tensor(image)
        mask = F.to_tensor(mask).long()

        # Rotation branch: pad generously, rotate both with the same angle,
        # then crop each to its final size.
        if self.train and random.random() < self.rotate_prob:
            image = F.pad(image, padding=self.rotate_pad, padding_mode="reflect")
            angle = random.uniform(-self.rotation_degree, self.rotation_degree)
            image = F.rotate(image, angle)
            mask = F.rotate(mask, angle)
            return (
                F.center_crop(image, self.input_size),
                F.center_crop(mask, self.mask_size),
            )

        # Flip branches: the vertical flip is only considered when the
        # horizontal one was not drawn.
        if self.train:
            if random.random() < self.flip_prob:
                image, mask = F.hflip(image), F.hflip(mask)
            elif random.random() < self.flip_prob:
                image, mask = F.vflip(image), F.vflip(mask)

        # Shared tail for every non-rotated sample, including evaluation.
        image = F.pad(image, padding=self.default_pad, padding_mode="reflect")
        mask = F.center_crop(mask, self.mask_size)
        return image, mask

In [None]:
def iou_loss(predictions, targets, eps=1e-6):
    """Soft intersection-over-union loss: ``1 - IoU(predictions, targets)``.

    The IoU is computed over all elements of both tensors as
    ``sum(p * t) / (sum(p + t) - sum(p * t))``, with ``eps`` added to the
    numerator and denominator so empty inputs do not divide by zero.

    Returns a scalar tensor that is ~0 for perfect overlap and ~1 for disjoint
    inputs, so it can be minimised directly. The raw IoU score was returned
    before, which is maximal for perfect overlap and therefore unusable as a loss.
    """
    intersection = torch.sum(predictions * targets)
    union = torch.sum(predictions + targets) - intersection
    iou = (intersection + eps) / (union + eps)
    return 1 - iou


# Border-weight hyperparameters, used as W0 * exp(-dist**2 / (2 * SIGMA**2)).
W0 = 5  # extra weight on a boundary pixel itself (dist == 0)
SIGMA = 5  # width of the Gaussian falloff, in distance-transform units


def cross_enthropy_weighted(outputs, targets, device, scale_factor=10**5):
    """Pixel-wise cross entropy weighted by class frequency and boundary distance.

    Each pixel's weight is the sum of
      * a class term ``scale_factor / count(class)``, with counts taken over
        the whole ``targets`` tensor, and
      * a border term ``W0 * exp(-d**2 / (2 * SIGMA**2))``, where ``d`` is the
        Euclidean distance to the nearest label-boundary pixel.

    Args:
        outputs: Class logits, as accepted by ``nn.functional.cross_entropy``.
        targets: Integer label maps; the callers in this notebook pass a
            (batch, H, W) tensor.
        device: Device on which the weight map is created.
        scale_factor: Numerator of the class-frequency weight.

    Returns:
        Scalar tensor: mean of the weighted per-pixel loss.
    """
    weight_class = scale_factor / torch.bincount(targets.flatten())

    labels_np = targets.detach().cpu().numpy()
    if labels_np.ndim <= 2:
        # A single unbatched label map.
        dist = distance_transform_edt(~find_boundaries(labels_np))
    else:
        # Process every sample on its own: running the N-d routines on the
        # whole batch would treat the batch axis as a spatial axis, so
        # boundaries and distances would leak between unrelated images.
        dist = np.stack(
            [distance_transform_edt(~find_boundaries(sample)) for sample in labels_np]
        )
    # Match the logits' dtype; the float64 array from scipy would otherwise
    # promote the whole loss to double precision.
    dist = torch.as_tensor(dist, dtype=outputs.dtype, device=device)
    weight_borders = W0 * torch.exp(-(dist**2) / (2 * SIGMA**2))

    class_map = weight_class[targets]
    weight = class_map + weight_borders
    loss_map = nn.functional.cross_entropy(outputs, targets, reduction="none")
    return torch.mean(loss_map * weight)

In [None]:
# Dataset locations and split configuration.
train_image_dir = os.path.join(base_dir, "isbi_2012_challenge/train/imgs")
train_mask_dir = os.path.join(base_dir, "isbi_2012_challenge/train/labels")
batch_size = 1
val_percent = 0.2
SPLIT_SEED = 42  # fixes the train/validation split across kernel restarts

# os.listdir returns names in arbitrary order, so sort first and shuffle with a
# dedicated seeded generator: the same files end up in the validation set on
# every run, and the global `random` state used by the augmentations is untouched.
all_images = sorted(os.listdir(train_image_dir))
val_size = int(val_percent * len(all_images))
train_size = len(all_images) - val_size
random.Random(SPLIT_SEED).shuffle(all_images)
val_images = all_images[:val_size]
train_images = all_images[val_size:]

# Training samples are augmented (rotation / flips) on every access.
train_transform = ImageMaskTransform(flip_prob=0.85, rotate_prob=0.85)
train_dataset = SegmentationDataset(
    train_image_dir, train_mask_dir, train_images, transform=train_transform
)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Validation samples are only padded / cropped, never augmented.
val_transform = ImageMaskTransform(train=False)
val_dataset = SegmentationDataset(
    train_image_dir, train_mask_dir, val_images, transform=val_transform
)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

In [None]:
# Training configuration.
lr = 1e-2
num_epochs = 500
model = UNet()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
# Stepped once per epoch below; the restart period starts at 1 and doubles
# after every restart.
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=1, T_mult=2, eta_min=1e-6
)
log_dir = os.path.join(base_dir, "runs")
writer = SummaryWriter(log_dir=log_dir)
# %tensorboard --logdir {log_dir}

accuracy_metric_val = BinaryAccuracy().to(device)
precision_metric_val = BinaryPrecision().to(device)
recall_metric_val = BinaryRecall().to(device)
rand_score_metric_val = RandScore().to(device)

accuracy_metric_train = BinaryAccuracy().to(device)
precision_metric_train = BinaryPrecision().to(device)
recall_metric_train = BinaryRecall().to(device)
rand_score_metric_train = RandScore().to(device)

# Metric groups in the order (accuracy, recall, precision, rand score).
val_metrics = (
    accuracy_metric_val,
    recall_metric_val,
    precision_metric_val,
    rand_score_metric_val,
)
train_metrics = (
    accuracy_metric_train,
    recall_metric_train,
    precision_metric_train,
    rand_score_metric_train,
)


def _update_segmentation_metrics(model, loader, device, metrics, compute_loss=False):
    """Run ``model`` over ``loader`` without gradients and update ``metrics``.

    ``metrics`` is an (accuracy, recall, precision, rand_score) tuple. Returns
    the summed weighted cross-entropy loss over all batches when
    ``compute_loss`` is true, otherwise 0.0.
    """
    accuracy, recall, precision, rand_score = metrics
    total_loss = 0.0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            labels = labels.squeeze(1)
            if compute_loss:
                total_loss += cross_enthropy_weighted(outputs, labels, device).item()
            preds = torch.argmax(outputs, dim=1)
            accuracy.update(preds, labels)
            recall.update(preds, labels)
            precision.update(preds, labels)
            # RandScore is fed flat label vectors.
            rand_score.update(preds.view(-1), labels.view(-1))
    return total_loss


try:
    for epoch in range(num_epochs):
        # Optimisation pass over the (augmented) training set.
        model.train()
        train_loss = 0.0
        for images, labels in train_dataloader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = cross_enthropy_weighted(outputs, labels.squeeze(1), device)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        scheduler.step()

        # Evaluation passes: validation loss + metrics, then training metrics
        # measured in eval mode.
        model.eval()
        val_loss = _update_segmentation_metrics(
            model, val_dataloader, device, val_metrics, compute_loss=True
        )
        _update_segmentation_metrics(model, train_dataloader, device, train_metrics)

        mean_train_loss = train_loss / len(train_dataloader)
        mean_val_loss = val_loss / len(val_dataloader)
        print(
            f"Epoch {epoch+1}/{num_epochs}, "
            f"Train loss: {mean_train_loss:.4f}, Val loss: {mean_val_loss:.4f}"
        )

        for split, mean_loss, (accuracy, recall, precision, rand_score) in (
            ("train", mean_train_loss, train_metrics),
            ("val", mean_val_loss, val_metrics),
        ):
            writer.add_scalar(f"Loss/{split}", mean_loss, epoch)
            writer.add_scalar(f"RandError/{split}", 1 - rand_score.compute(), epoch)
            writer.add_scalar(f"PixelError/{split}", 1 - accuracy.compute(), epoch)
            writer.add_scalar(f"Recall/{split}", recall.compute(), epoch)
            writer.add_scalar(f"Precision/{split}", precision.compute(), epoch)

        # Metrics accumulate across update() calls; start every epoch from scratch.
        for metric in (*val_metrics, *train_metrics):
            metric.reset()
finally:
    # Close the TensorBoard writer even when training is interrupted or
    # fails, so the event file is not left open.
    writer.close()

torch.save(model.state_dict(), os.path.join(base_dir, "checkpoint.pth"))

In [None]:
# Test data: same deterministic preprocessing as validation (no augmentation).
test_image_dir = os.path.join(base_dir, "isbi_2012_challenge/test/imgs")
test_mask_dir = os.path.join(base_dir, "isbi_2012_challenge/test/labels")

test_images = os.listdir(test_image_dir)
test_transforms = ImageMaskTransform(train=False)
test_dataset = SegmentationDataset(
    test_image_dir, test_mask_dir, test_images, transform=test_transforms
)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

# Rebuild the network and restore the weights written by the training cell.
checkpoint_path = os.path.join(base_dir, "checkpoint.pth")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet()
model.to(device)
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()

accuracy_metric_test = BinaryAccuracy().to(device)
precision_metric_test = BinaryPrecision().to(device)
recall_metric_test = BinaryRecall().to(device)
rand_score_metric_test = RandScore().to(device)
# Metrics that are updated with the 2-D prediction / label maps as they are.
pixel_metrics_test = (accuracy_metric_test, recall_metric_test, precision_metric_test)

test_loss = 0.0
with torch.no_grad():
    for images, labels in test_dataloader:
        images = images.to(device)
        labels = labels.to(device).squeeze(1)
        outputs = model(images)
        loss = cross_enthropy_weighted(outputs, labels, device)
        test_loss += loss.item()
        preds = torch.argmax(outputs, dim=1)
        for pixel_metric in pixel_metrics_test:
            pixel_metric.update(preds, labels)
        # RandScore is fed flat label vectors.
        rand_score_metric_test.update(preds.view(-1), labels.view(-1))

# Aggregate over the whole test set and report.
mean_test_loss = test_loss / len(test_dataloader)
rand_error_test = 1 - rand_score_metric_test.compute()
pixel_error_test = 1 - accuracy_metric_test.compute()
recall_test = recall_metric_test.compute()
precision_test = precision_metric_test.compute()
print(
    f"(Test) Loss: {mean_test_loss:.4f}, "
    f"Rand error: {rand_error_test:.4f} "
    f"Pixel Error: {pixel_error_test:.4f} "
    f"Recall: {recall_test:.4f} "
    f"Precision: {precision_test:.4f}"
)

In [None]:
# Show input, ground truth and prediction side by side for every test batch.
# Inference only: make sure dropout is disabled even if this cell is run
# without the evaluation cell above it.
model.eval()
with torch.no_grad():
    for images, labels in test_dataloader:
        outputs = model(images.to(device))

        # Only the first sample of each batch is displayed.
        image = images[0].permute(1, 2, 0).cpu().numpy()
        mask = labels[0].cpu().numpy().squeeze()
        predicted_mask = torch.argmax(outputs[0], dim=0).cpu().numpy()

        fig, axs = plt.subplots(1, 3, figsize=(18, 6))
        axs[0].imshow(image, cmap="gray")
        axs[0].set_title("Image")
        axs[0].axis("off")

        axs[1].imshow(mask, cmap="gray")
        axs[1].set_title("Ground Truth Mask")
        axs[1].axis("off")

        axs[2].imshow(predicted_mask, cmap="gray")
        axs[2].set_title("Predicted Mask")
        axs[2].axis("off")

        # Render and release each figure as it is produced instead of keeping
        # one open figure per test image until the loop ends.
        plt.show()
        plt.close(fig)

In [None]:
# Visual sanity check of the augmented training samples: show the first image
# and mask of every batch, then print their shapes and the raw mask values.
for images, masks in train_dataloader:
    image = images[0].permute(1, 2, 0).cpu().numpy()
    mask = masks[0].cpu().numpy().squeeze()
    fig, axs = plt.subplots(1, 2, figsize=(12, 6))
    for axis, picture in zip(axs, (image, mask)):
        axis.imshow(picture, cmap="gray")
    plt.show()
    print(image.shape[:2], mask.shape)
    print(mask)

In [None]:
# Debug view of the loss weights for the first training batch only: the same
# class-frequency and border terms that cross_enthropy_weighted combines.
for images, labels in train_dataloader:
    # Border term: Gaussian falloff of the distance to the nearest boundary pixel.
    borders = find_boundaries(labels)
    dist = distance_transform_edt(~borders)
    w = W0 * np.exp(-(dist**2) / (2 * SIGMA**2))

    # Class term: inverse pixel count of each label value.
    scale_factor = 10**5
    w_class = scale_factor / torch.bincount(labels.flatten())
    print(w_class)

    class_map = w_class[labels]
    print(class_map)

    w_final = class_map.numpy() + w

    # Panels, left to right: labels, inverted boundaries, total weight, border weight.
    fig, axs = plt.subplots(1, 4, figsize=(19, 6))
    panels = (
        (labels.squeeze(), "gray"),
        (~borders.squeeze(), "gray"),
        (w_final.squeeze(), "coolwarm"),
        (w.squeeze(), "coolwarm"),
    )
    for axis, (panel, colormap) in zip(axs, panels):
        axis.imshow(panel, cmap=colormap)
    plt.show()

    # Value distributions of the border-only and the combined weights.
    plt.hist(w.flatten(), bins=10)
    plt.show()
    plt.hist(w_final.flatten(), bins=10)
    break