In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import segmentation_models_pytorch as smp


from src.datasets.DubaiSemanticSegmentationDataset import (
    DubaiSemanticSegmentationDataset,
)

In [None]:
# U-Net segmentation model built with segmentation_models_pytorch.
# NOTE(review): `classes=3` must match the number of classes in the dataset's
# masks — confirm against DubaiSemanticSegmentationDataset.
model = smp.Unet(
    # Encoder backbone; alternatives include mobilenet_v2 or efficientnet-b7.
    encoder_name="resnet34",
    # Initialise the encoder with `imagenet` pre-trained weights.
    encoder_weights="imagenet",
    # Model input channels (1 for gray-scale images, 3 for RGB, etc.).
    in_channels=3,
    # Model output channels (number of classes in the dataset).
    classes=3,
)

In [None]:
# Location of the Dubai semantic segmentation dataset. This is a relative path,
# so it is resolved against the working directory the notebook is run from.
DUBAI_DATASET_PATH: str = "data/DubaiSemanticSegmentationDataset"


In [None]:
# Instantiate the dataset from DUBAI_DATASET_PATH and report how many samples it holds.
example_dataset = DubaiSemanticSegmentationDataset(
    DUBAI_DATASET_PATH,
)
print(len(example_dataset))

In [None]:
# Serve the dataset one sample per batch, in shuffled order.
example_loader = DataLoader(
    dataset=example_dataset,
    batch_size=1,
    shuffle=True,
)

In [None]:
# Draw a single batch from the loader; element 0 of the batch is used as the
# model input in the cells that follow.
batch = next(iter(example_loader))
# NOTE(review): `input` shadows the Python builtin of the same name. Later cells
# reference this name, so a rename has to be applied to all of them together.
input = batch[0]
print(input.shape)

In [None]:
# Pad the right and bottom edges so that height (dim 2) and width (dim 3) become
# multiples of 32, which the U-Net's encoder/decoder downsampling requires
# (presumably stride 32 for the default depth — see the smp Unet docs).
#
# `(-size) % 32` is the distance to the next multiple of 32 and is 0 when the
# size is already aligned. The previous `32 - (size % 32)` returned 32 in that
# case and padded an aligned image by a needless extra 32 pixels.
pad_right = (-input.shape[3]) % 32
pad_bottom = (-input.shape[2]) % 32

# F.pad takes pads for the last dimension first: (left, right, top, bottom).
padded_input = F.pad(input, (0, pad_right, 0, pad_bottom))
print(padded_input.shape)

In [None]:
# Put the model in evaluation mode before inference. In training mode the
# BatchNorm layers normalise with the statistics of the current batch (a single
# image here) and keep updating their running statistics, even under no_grad.
model.eval()

# Forward pass without building the autograd graph.
with torch.no_grad():
    output = model(padded_input)

In [None]:
# Show the shape of the model output produced from the padded input.
print(output.shape)

In [None]:
# Element 1 of the batch is used as the target: it is plotted next to the model
# output and passed to the metric computation.
target = batch[1]

In [None]:
# Side-by-side view of the model prediction and the target for the first sample.
#
# The model returns raw logits. imshow clips float RGB data to [0, 1], so the
# logits are first turned into per-class probabilities with a softmax over the
# class axis (dim 0 of a single (C, H, W) sample). With three classes the
# probabilities map directly onto the R, G and B channels.
#
# The output was computed from an input padded on the right/bottom edges, so it
# is cropped to the target's height and width to make both panels cover the same
# area. NOTE(review): assumes the target has the same spatial size as the
# unpadded input — confirm against the dataset.
target_height, target_width = target.shape[-2:]
output_probs = output[0].softmax(dim=0)[:, :target_height, :target_width]

fig, (output_ax, target_ax) = plt.subplots(1, 2)

# Left panel: predicted class probabilities, channels-last for imshow.
output_ax.imshow(output_probs.permute(1, 2, 0))
output_ax.set_title('Output')
output_ax.axis('off')

# Right panel: the target taken from batch[1], channels-last for imshow.
target_ax.imshow(target[0].permute(1, 2, 0))
target_ax.set_title('Target')
target_ax.axis('off')

# Display the plot
plt.show()


In [None]:
# Convert the raw logits into a hard per-pixel prediction before computing metrics.
#
# 1. Crop away the right/bottom padding added before inference so the prediction
#    has the same spatial size as the target (get_stats requires equal shapes).
# 2. Pick the highest-scoring class per pixel. Softmax is monotonic, so the
#    argmax of the logits is the same as the argmax of the softmax probabilities.
# 3. One-hot encode the class indices back to (N, C, H, W) so the prediction has
#    the layout the 'multilabel' mode expects.
#
# NOTE(review): get_stats documents `target` as an integer tensor laid out like
# the output, i.e. (N, C, H, W) — confirm the dtype and layout of the masks
# yielded by the dataset.
num_classes = output.shape[1]
cropped_output = output[..., : target.shape[-2], : target.shape[-1]]
predicted_classes = cropped_output.argmax(dim=1)
predicted_one_hot = F.one_hot(predicted_classes, num_classes).permute(0, 3, 1, 2)

# first compute statistics for true positives, false positives, false negative and
# true negative "pixels"
tp, fp, fn, tn = smp.metrics.get_stats(predicted_one_hot, target, mode='multilabel', threshold=0.5)

# then compute metrics with required reduction (see metric docs)
iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")
f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro")
f2_score = smp.metrics.fbeta_score(tp, fp, fn, tn, beta=2, reduction="micro")
accuracy = smp.metrics.accuracy(tp, fp, fn, tn, reduction="macro")
recall = smp.metrics.recall(tp, fp, fn, tn, reduction="micro-imagewise")