In [1]:
import os

# Environment detection: on Google Colab the project lives on the mounted
# Drive; everywhere else the notebook runs from the current directory.
# Defines `base_dir` (project root) and `LOCAL` (True when not on Colab).
try:
    # Only this import decides "Colab or not", so it is the only statement
    # guarded by the ImportError handler.
    from google.colab import drive
except ImportError:
    base_dir = "."
    LOCAL = True
else:
    import subprocess
    import sys

    # Install through the kernel's own interpreter so the packages land in the
    # environment the imports below use; check=True surfaces a failed install
    # here instead of as a confusing ImportError further down.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "torchmetrics", "optuna"],
        check=True,
    )
    base_dir = "/content/drive/MyDrive/Colab_Notebooks/Crack_Detection"
    drive.mount("/content/drive")
    # Make the project's modules (e.g. `train`) importable from the Drive copy.
    sys.path.append(os.path.join(base_dir, "semantic-segmentation"))
    LOCAL = False

import matplotlib.pyplot as plt
import numpy as np
import optuna
import torch
from scipy.ndimage import distance_transform_edt
from skimage.segmentation import find_boundaries
from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall
from torchmetrics.clustering import RandScore

from train import (
    DataLoader,
    ImageMaskTransform,
    SegmentationDataset,
    UNet,
    cross_entropy_weighted,
    init_data_loaders,
    init_datasets,
    tune_hyperparams,
    train_fixed_hyperparams,
)


%load_ext tensorboard

In [None]:
# Open TensorBoard inline, pointed at the "runs" directory under base_dir.
%tensorboard --logdir {os.path.join(base_dir, "runs")}

In [None]:
# Presumably launches the hyper-parameter search; the implementation lives in
# train.py and is not visible from this notebook.
tune_hyperparams(base_dir, LOCAL)

In [2]:
# NOTE(review): the saved output of this cell is `KeyError: 'min_save_epoch'`,
# i.e. the last recorded run failed. The lookup happens inside train.py (not
# visible here) — confirm that the key is provided before relying on this cell.
train_fixed_hyperparams(base_dir, LOCAL)

KeyError: 'min_save_epoch'

In [7]:
import pickle
import optuna.visualization as vis

# `study_name` selects the run directory under Data/. The Optuna study name
# itself is read from the database below and is not necessarily identical
# (the saved output shows a different name).
# study_name = "study-0506-190624"
study_name = "study-0502-175711"
# study_name = "study-0428"
storage = f"sqlite:///Data/{study_name}/Data/seg-study.db"
studies = optuna.study.get_all_study_summaries(storage=storage)
print("Num of studies:", len(studies))
# Fail with an actionable message instead of a bare IndexError on studies[-1]
# when the database holds no studies (for example because the path is wrong).
if not studies:
    raise RuntimeError(
        f"No Optuna studies found in {storage!r}; "
        "check study_name and the working directory."
    )
# Use the last summary returned by the storage.
study = studies[-1]
print(study.study_name)

Num of studies: 14
study-0502-175709


In [8]:
# Re-open the selected study. If a pickled sampler/pruner was saved next to the
# database, re-attach them; otherwise let Optuna use its defaults.
sampler = pruner = None
try:
    # NOTE(security): pickle.load can execute arbitrary code. Only open study
    # files produced by your own runs.
    with open(f"Data/{study_name}/Data/studies/{study.study_name}.pkl", "rb") as f:
        study_dict = pickle.load(f)
except FileNotFoundError:
    # No pickled state for this study: keep sampler/pruner as None.
    pass
else:
    sampler = study_dict["sampler"]
    pruner = study_dict["pruner"]

# load_study is deliberately outside the try block so that a FileNotFoundError
# raised while opening the storage is not mistaken for "no pickle available".
study = optuna.load_study(
    study_name=study.study_name,
    storage=storage,
    sampler=sampler,
    pruner=pruner,
)

In [9]:
# Quick summary of the loaded study: best objective value, number of trials,
# the pruner attached to the study, and the best parameter set.
study.best_value, len(study.trials), study.pruner, study.best_params

(0.13715654611587524,
 47,
 <optuna.pruners._median.MedianPruner at 0x7437597b1870>,
 {'flip_prob': 0.042747388239577966,
  'rotate_prob': 0.4719902190284881,
  'elastic_prob': 0.01567998717280196,
  'translate_prob': 0.15377448395168106,
  'brightness_prob': 0.138486416623626,
  'batch_size': 8,
  'dropout_p': 0.5328890961652364,
  'vanilla_loss': True,
  'use_adam': True,
  'lr': 0.00036296166101104416,
  'use_cosine_scheduler': False,
  'min_lr': 1.0017766083008169e-08,
  'lr_patience': 18,
  'lr_cooldown': 1,
  'lr_factor': 0.06925485872989952})

In [29]:
# NOTE(review): verbatim repeat of the summary cell above. Its saved output
# (execution count 29, 100 trials, HyperbandPruner) does not match the run shown
# above (47 trials, MedianPruner), so it appears to be left over from a
# different session — re-run or remove.
study.best_value, len(study.trials), study.pruner, study.best_params

(0.14338421821594238,
 100,
 <optuna.pruners._hyperband.HyperbandPruner at 0x7c74eecb5390>,
 {'flip_prob': 0.04459255293028834,
  'rotate_prob': 0.40379419104333913,
  'elastic_prob': 0.014062767704262479,
  'translate_prob': 0.33608821897374763,
  'brightness_prob': 0.23828063471331679,
  'batch_size': 9,
  'dropout_p': 0.2887745942859943,
  'vanilla_loss': True,
  'use_adam': True,
  'lr': 0.0003909349124991186,
  'use_cosine_scheduler': False,
  'min_lr': 6.945255506527192e-08,
  'lr_patience': 24,
  'lr_cooldown': 1,
  'lr_factor': 0.2878405900553068})

In [10]:
# Print every trial that did NOT use the vanilla loss and whose objective value
# is below 0.25.
for trial in study.trials:
    value = trial.value
    # Explicit None check: trials without a recorded value are skipped, but a
    # legitimate value of exactly 0.0 is no longer dropped by a truthiness test.
    has_good_value = value is not None and value < 0.25
    # .get(): a trial that carries no "vanilla_loss" parameter is skipped
    # instead of aborting the whole listing with a KeyError.
    if has_good_value and not trial.params.get("vanilla_loss", True):
        print(f"{value:0.3f}", trial.number, trial.params, trial.state)

In [14]:
# Per-parameter importance scores for the loaded study, using Optuna's default
# evaluator (none is passed explicitly).
optuna.importance.get_param_importances(study)

{'use_adam': np.float64(0.3977645095745252),
 'vanilla_loss': np.float64(0.1406793931554693),
 'elastic_prob': np.float64(0.08746186350862907),
 'batch_size': np.float64(0.08671901945241697),
 'use_cosine_scheduler': np.float64(0.08402969690245228),
 'brightness_prob': np.float64(0.06316926781526236),
 'translate_prob': np.float64(0.043377337361138406),
 'rotate_prob': np.float64(0.031254978146683376),
 'flip_prob': np.float64(0.025227132797285463),
 'min_lr': np.float64(0.021256423150298954),
 'dropout_p': np.float64(0.0190603781358387)}

In [None]:
# Plot the hyper-parameter importances of the loaded study
# (`vis` is optuna.visualization).
vis.plot_param_importances(study)

In [None]:
# Resolve the training/validation file lists from the project directory.
train_image_dir, train_images, train_mask_dir, val_images, val_percent = init_datasets(
    base_dir
)

# Augmentation probabilities passed on to init_data_loaders.
flip_prob = rotate_prob = translate_prob = brightness_prob = 0.1
elastic_prob = 0.11
# A batch of one locally, nine on Colab.
batch_size = 9 if not LOCAL else 1

# Arguments are positional and must stay in exactly this order.
train_dataloader, val_dataloader = init_data_loaders(
    batch_size,
    brightness_prob,
    elastic_prob,
    flip_prob,
    rotate_prob,
    train_image_dir,
    train_images,
    train_mask_dir,
    translate_prob,
    val_images,
)

In [None]:
# Sanity check of the training pipeline: for every batch, draw the first image
# next to its mask and verify the expected input/output sizes.
for images, masks in train_dataloader:
    # Channels-last layout for matplotlib.
    image = images[0].permute(1, 2, 0).cpu().numpy()
    mask = np.squeeze(masks[0].cpu().numpy())

    fig, (image_ax, mask_ax) = plt.subplots(1, 2, figsize=(12, 6))
    image_ax.imshow(image, cmap="gray")
    mask_ax.imshow(mask, cmap="gray")

    # Input tiles are 572x572, target masks 388x388.
    assert image.shape[:2] == (572, 572)
    assert mask.shape == (388, 388)
    plt.show()

In [None]:
# Evaluate a saved checkpoint on the test split and report loss, Rand error,
# pixel error, recall and precision.

# Loss settings forwarded to cross_entropy_weighted, and the model's dropout.
loss_w0, loss_sigma = 5, 5
loss_w1 = 1.0
dropout_p = 0.2
vanilla_loss = True

test_image_dir = os.path.join(base_dir, "isbi_2012_challenge/test/imgs")
test_mask_dir = os.path.join(base_dir, "isbi_2012_challenge/test/labels")

# os.listdir returns entries in arbitrary order; sort them so the batches (and
# any figures drawn from them) are identical on every run.
test_images = sorted(os.listdir(test_image_dir))
test_transforms = ImageMaskTransform(train=False)
test_dataset = SegmentationDataset(
    test_image_dir, test_mask_dir, test_images, transform=test_transforms
)
# `batch_size` comes from the data-loader setup cell above.
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

# NOTE(review): unlike the dataset directories, this path is relative to the
# current working directory rather than base_dir — confirm this is intended
# when running on Colab.
pretrained_weights_path = "checkpoints/C1.pth"
model = UNet(dropout_p=dropout_p)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# NOTE(security): torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load(pretrained_weights_path, map_location=device))
model.eval()

accuracy_metric_test = BinaryAccuracy().to(device)
precision_metric_test = BinaryPrecision().to(device)
recall_metric_test = BinaryRecall().to(device)
rand_score_metric_test = RandScore().to(device)
test_loss = 0.0
with torch.no_grad():
    for images, labels in test_dataloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # Drop the channel axis so the labels line up with the argmax output.
        labels = labels.squeeze(1)
        loss = cross_entropy_weighted(
            outputs,
            labels,
            device,
            loss_w0,
            loss_sigma,
            loss_w1,
            vanilla=vanilla_loss,
        )
        test_loss += loss.item()
        preds = torch.argmax(outputs, dim=1)
        accuracy_metric_test.update(preds, labels)
        recall_metric_test.update(preds, labels)
        precision_metric_test.update(preds, labels)
        # The Rand score is fed flat vectors; reshape (unlike view) also works
        # when a tensor is not contiguous.
        rand_score_metric_test.update(preds.reshape(-1), labels.reshape(-1))

# The loss is averaged over batches; the metrics aggregate over all updates.
print(
    f"(Test) Loss: {test_loss / len(test_dataloader):.4f}, "
    f"Rand error: {1 - rand_score_metric_test.compute():.4f} "
    f"Pixel Error: {1 - accuracy_metric_test.compute():.4f} "
    f"Recall: {recall_metric_test.compute():.4f} "
    f"Precision: {precision_metric_test.compute():.4f}"
)

In [None]:
# Qualitative check: for each test batch, show the first image, its ground-truth
# mask and the mask predicted by the model.
with torch.no_grad():
    for images, labels in test_dataloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)

        # Channels-last layout for matplotlib.
        image = images[0].permute(1, 2, 0).cpu().numpy()
        mask = labels[0].cpu().numpy().squeeze()

        # Predicted class per pixel for the first sample of the batch.
        predicted_mask = torch.argmax(outputs[0], dim=0).cpu().numpy()

        fig, axs = plt.subplots(1, 3, figsize=(18, 6))
        axs[0].imshow(image, cmap="gray")
        axs[0].set_title("Image")
        axs[0].axis("off")

        axs[1].imshow(mask, cmap="gray")
        axs[1].set_title("Ground Truth Mask")
        axs[1].axis("off")

        axs[2].imshow(predicted_mask, cmap="gray")
        axs[2].set_title("Predicted Mask")
        axs[2].axis("off")

        # Render and release each figure as soon as it is complete. Showing
        # only once after the loop keeps every figure open at the same time
        # (one per batch), which triggers matplotlib's "too many open figures"
        # warning for small batch sizes.
        plt.show()
        plt.close(fig)

In [None]:
# Illustrate the per-pixel loss weights for one training mask: a border term
# W0 * exp(-2 * d^2 / SIGMA^2), where d is the distance to the nearest label
# boundary, plus an inverse-frequency class weight.
W0, SIGMA = 5, 5
for images, labels in train_dataloader:
    # Use a single mask. Feeding the whole batch to find_boundaries /
    # distance_transform_edt treats the batch axis as a spatial one, and
    # imshow cannot draw the resulting 3-D array, so the cell previously only
    # worked with batch_size == 1.
    # Assumes labels is a CPU integer tensor whose first sample squeezes to
    # (H, W), as asserted for the masks in the preview cell — TODO confirm.
    label = labels[0].squeeze()
    label_np = label.numpy()

    borders = find_boundaries(label_np)
    # Distance of every pixel to the nearest boundary pixel.
    dist = distance_transform_edt(~borders)
    w = W0 * np.exp(-2 * dist**2 / SIGMA**2)

    # Inverse-frequency class weights: total pixels / pixels of that class.
    labels_bincount = torch.bincount(label.flatten())
    w_class = labels_bincount.sum() / labels_bincount
    print(w_class, labels_bincount.sum())

    # Look up each pixel's class weight.
    class_map = w_class[label]
    print(class_map)

    w_final = class_map.numpy() + w

    fig, axs = plt.subplots(1, 4, figsize=(19, 6))
    axs[0].imshow(label_np, cmap="gray")
    axs[0].set_title("Mask")
    axs[1].imshow(~borders, cmap="gray")
    axs[1].set_title("Non-boundary pixels")
    axs[2].imshow(w, cmap="coolwarm")
    axs[2].set_title("Border weight")
    axs[3].imshow(w_final, cmap="coolwarm")
    axs[3].set_title("Class weight + border weight")
    plt.show()
    plt.hist(w.flatten(), bins=10)
    plt.show()
    plt.hist(w_final.flatten(), bins=10)
    # Explicit show so the cell displays the histogram rather than the raw
    # return value of plt.hist.
    plt.show()
    # One batch is enough for the illustration.
    break