In [3]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

from src.datasets.DubaiSemanticSegmentationDataset import (
    DubaiSemanticSegmentationDataset,
)

import segmentation_models_pytorch as smp

from src.datasets.utils.ResizeToDivisibleBy32 import ResizeToDivisibleBy32

In [4]:
DUBAI_DATASET_PATH = "data/DubaiSemanticSegmentationDataset"

In [8]:
example_dataset = DubaiSemanticSegmentationDataset(
    DUBAI_DATASET_PATH, transforms=[ResizeToDivisibleBy32()]
)
print(len(example_dataset))

72


In [9]:
example_loader = DataLoader(example_dataset, batch_size=1, shuffle=True)

In [10]:
for images, masks in example_loader:
    print(images.shape)
    print(masks.shape)
    break

torch.Size([1, 3, 1504, 2176])
torch.Size([1, 3, 1504, 2176])


In [11]:
to_pil_transform = transforms.ToPILImage()
img = to_pil_transform(images.squeeze())

In [12]:
# img.show()

In [13]:
msk = to_pil_transform(masks.squeeze())

In [14]:
# msk.show()

In [15]:
model = smp.Unet(
    encoder_name="resnet34",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
    in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=6,  # model output channels (number of classes in your dataset)
)

In [17]:
model.eval()

with torch.no_grad():
    for images, masks in example_loader:

        height = images.shape[2]
        width = images.shape[3]

        if height % 32 == 0 and width % 32 == 0:
            print("Height and width are divisible by 32")
        else:
            print("Height and width are not divisible by 32")

        output = model(images)
        print('output')
        print(type(output))
        print(output.shape)
        print(output.max())
        print(output.min())
        print()
        print('masks')
        print(type(masks))
        print(masks.shape)
        print(masks.max())
        print(masks.min())

        tp, fp, fn, tn = smp.metrics.get_stats(
            output, masks, mode="multilabel", threshold=0.5
        )
        print(f"TP: {tp}")
        print(f"FP: {fp}")
        print(f"TN: {tn}")
        print(f"FN: {fn}")

        iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")
        print(f"IoU score: {iou_score}")

        break

Height and width are divisible by 32
output
<class 'torch.Tensor'>
torch.Size([1, 6, 672, 704])
tensor(1.5950)
tensor(-2.3121)

masks
<class 'torch.Tensor'>
torch.Size([1, 3, 672, 704])
tensor(0.9647)
tensor(0.0206)


ValueError: Target should be one of the integer types, got torch.float32.