In [None]:
# Mount Google Drive in the Colab runtime; the dataset archives and model
# checkpoints used below are read from /content/drive/MyDrive/cityscapes.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Imports

In [None]:
import torchvision.transforms as T
from torchvision.datasets import Cityscapes
from torch.utils.data import DataLoader
from PIL import Image
import torch
import numpy as np
from torch.utils.data import Subset
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
import torchvision.models.segmentation as seg_models
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import cv2
import gc
from torchvision.transforms.functional import to_pil_image

from __future__ import print_function, absolute_import, division
from collections import namedtuple

# Constants

In [None]:

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Define the transforms
# Images are resized to 512x1024 (H, W), converted to float tensors and then
# normalized per channel with the mean/std below (the usual ImageNet
# statistics expected by torchvision's pretrained backbones).
input_transform = T.Compose([
    T.Resize((512, 1024)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
])


# Record type describing one Cityscapes class; one instance per row of the
# `labels` table below.
Label = namedtuple( 'Label' , [

    'name'        , # The identifier of this label, e.g. 'car', 'person', ... .
                    # We use them to uniquely name a class

    'id'          , # An integer ID that is associated with this label.
                    # The IDs are used to represent the label in ground truth images
                    # An ID of -1 means that this label does not have an ID and thus
                    # is ignored when creating ground truth images (e.g. license plate).
                    # Do not modify these IDs, since exactly these IDs are expected by the
                    # evaluation server.

    'trainId'     , # Feel free to modify these IDs as suitable for your method. Then create
                    # ground truth images with train IDs, using the tools provided in the
                    # 'preparation' folder. However, make sure to validate or submit results
                    # to our evaluation server using the regular IDs above!
                    # For trainIds, multiple labels might have the same ID. Then, these labels
                    # are mapped to the same class in the ground truth images. For the inverse
                    # mapping, we use the label that is defined first in the list below.
                    # For example, mapping all void-type classes to the same ID in training,
                    # might make sense for some approaches.
                    # Max value is 255!

    'category'    , # The name of the category that this label belongs to

    'categoryId'  , # The ID of this category. Used to create ground truth images
                    # on category level.

    'hasInstances', # Whether this label distinguishes between single instances or not

    'ignoreInEval', # Whether pixels having this class as ground truth label are ignored
                    # during evaluations or not

    'color'       , # The color of this label
    ] )


# Full Cityscapes label table. Rows with trainId 255 are the classes ignored
# in evaluation; trainIds 0-18 are the 19 evaluated classes, of which 11-18
# are the ones with hasInstances=True.
labels = [
    #       name                     id    trainId   category            catId     hasInstances   ignoreInEval   color
    Label(  'unlabeled'            ,  0 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'ego vehicle'          ,  1 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'rectification border' ,  2 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'out of roi'           ,  3 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'static'               ,  4 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'dynamic'              ,  5 ,      255 , 'void'            , 0       , False        , True         , (111, 74,  0) ),
    Label(  'ground'               ,  6 ,      255 , 'void'            , 0       , False        , True         , ( 81,  0, 81) ),
    Label(  'road'                 ,  7 ,        0 , 'flat'            , 1       , False        , False        , (128, 64,128) ),
    Label(  'sidewalk'             ,  8 ,        1 , 'flat'            , 1       , False        , False        , (244, 35,232) ),
    Label(  'parking'              ,  9 ,      255 , 'flat'            , 1       , False        , True         , (250,170,160) ),
    Label(  'rail track'           , 10 ,      255 , 'flat'            , 1       , False        , True         , (230,150,140) ),
    Label(  'building'             , 11 ,        2 , 'construction'    , 2       , False        , False        , ( 70, 70, 70) ),
    Label(  'wall'                 , 12 ,        3 , 'construction'    , 2       , False        , False        , (102,102,156) ),
    Label(  'fence'                , 13 ,        4 , 'construction'    , 2       , False        , False        , (190,153,153) ),
    Label(  'guard rail'           , 14 ,      255 , 'construction'    , 2       , False        , True         , (180,165,180) ),
    Label(  'bridge'               , 15 ,      255 , 'construction'    , 2       , False        , True         , (150,100,100) ),
    Label(  'tunnel'               , 16 ,      255 , 'construction'    , 2       , False        , True         , (150,120, 90) ),
    Label(  'pole'                 , 17 ,        5 , 'object'          , 3       , False        , False        , (153,153,153) ),
    Label(  'polegroup'            , 18 ,      255 , 'object'          , 3       , False        , True         , (153,153,153) ),
    Label(  'traffic light'        , 19 ,        6 , 'object'          , 3       , False        , False        , (250,170, 30) ),
    Label(  'traffic sign'         , 20 ,        7 , 'object'          , 3       , False        , False        , (220,220,  0) ),
    Label(  'vegetation'           , 21 ,        8 , 'nature'          , 4       , False        , False        , (107,142, 35) ),
    Label(  'terrain'              , 22 ,        9 , 'nature'          , 4       , False        , False        , (152,251,152) ),
    Label(  'sky'                  , 23 ,       10 , 'sky'             , 5       , False        , False        , ( 70,130,180) ),
    Label(  'person'               , 24 ,       11 , 'human'           , 6       , True         , False        , (220, 20, 60) ),
    Label(  'rider'                , 25 ,       12 , 'human'           , 6       , True         , False        , (255,  0,  0) ),
    Label(  'car'                  , 26 ,       13 , 'vehicle'         , 7       , True         , False        , (  0,  0,142) ),
    Label(  'truck'                , 27 ,       14 , 'vehicle'         , 7       , True         , False        , (  0,  0, 70) ),
    Label(  'bus'                  , 28 ,       15 , 'vehicle'         , 7       , True         , False        , (  0, 60,100) ),
    Label(  'caravan'              , 29 ,      255 , 'vehicle'         , 7       , True         , True         , (  0,  0, 90) ),
    Label(  'trailer'              , 30 ,      255 , 'vehicle'         , 7       , True         , True         , (  0,  0,110) ),
    Label(  'train'                , 31 ,       16 , 'vehicle'         , 7       , True         , False        , (  0, 80,100) ),
    Label(  'motorcycle'           , 32 ,       17 , 'vehicle'         , 7       , True         , False        , (  0,  0,230) ),
    Label(  'bicycle'              , 33 ,       18 , 'vehicle'         , 7       , True         , False        , (119, 11, 32) ),
    Label(  'license plate'        , -1 ,       -1 , 'vehicle'         , 7       , False        , True         , (  0,  0,142) ),
]


# name to label object
name2label      = { label.name    : label for label in labels           }
# id to label object
id2label        = { label.id      : label for label in labels           }
# trainId to label object
# Several labels share a trainId (e.g. 255). Later entries overwrite earlier
# ones in a dict comprehension, so iterating in reverse leaves the label that
# is defined FIRST in `labels` as the representative of each trainId.
trainId2label   = { label.trainId : label for label in reversed(labels) }

# Utils

In [None]:

def target_transform(label):
    """Convert a ground-truth label image into an int64 tensor.

    No id remapping happens here: the raw pixel values are kept as they are
    and are translated to train ids later by `vectorized_map_labels`.

    Args:
        label: anything `np.array` accepts -- in this notebook the PIL image
            that the Cityscapes dataset yields for the target.

    Returns:
        A `torch.int64` tensor with the same shape and values as `label`.
    """
    # Convert to int64 inside NumPy. 16-bit PNGs decode to uint16 arrays,
    # which `torch.from_numpy` rejects on PyTorch builds without uint16
    # support; int64 is accepted everywhere and is what `.long()` produced.
    label = torch.from_numpy(np.array(label, dtype=np.int64))
    return label

In [None]:

def prepare_targets(label, sorted_keys, sorted_train_ids, output_size=(64, 128)):
    """Turn a batch of raw label maps into 21-channel multi-label targets.

    Channel layout of the result:
        0-18  one-hot Cityscapes train ids,
        19    void (train id 255 and every pixel that could not be mapped),
        20    binary mask of the pixels whose train id is 11-18, i.e. the
              classes marked `hasInstances=True` in the label table.

    Args:
        label: integer tensor [B, H, W] of label ids / instance ids.
        sorted_keys, sorted_train_ids: lookup tensors from
            `prepare_label_mapping`.
        output_size: (height, width) the targets are resized to with
            nearest-neighbour sampling.

    Returns:
        An int64 tensor [B, 21, output_size[0], output_size[1]] on the
        module-level `device`.
    """
    void_channel = 19
    objectness_channel = 20

    train_ids = vectorized_map_labels(label, sorted_keys, sorted_train_ids)

    # 255 is the "ignore" train id. Negative values are ids missing from the
    # table or the 'license plate' entry (train id -1); F.one_hot raises on
    # negative classes, so both are folded into the void channel.
    train_ids[(train_ids == 255) | (train_ids < 0)] = void_channel

    one_hot = F.one_hot(train_ids, num_classes=21)               # [B, H, W, C]
    one_hot = one_hot.permute(0, 3, 1, 2).float().to(device)     # [B, C, H, W]

    # one_hot leaves channel 20 empty (ids stop at 19); reuse it for the
    # objectness mask.
    thing_mask = ((train_ids > 10) & (train_ids < void_channel)).unsqueeze(1)   # [B, 1, H, W]
    one_hot[:, objectness_channel:objectness_channel + 1] = thing_mask.to(
        device=one_hot.device, dtype=one_hot.dtype
    )

    # Nearest-neighbour keeps the targets strictly binary after resizing.
    one_hot = F.interpolate(one_hot, size=output_size, mode='nearest').long()

    return one_hot


def vectorized_map_labels(label_tensor, sorted_keys, sorted_train_ids):
    """Translate raw label / instance ids to train ids with one table lookup.

    Values below 1000 are used as label ids directly; values of 1000 and
    above are reduced to a label id with an integer division by 1000.

    Args:
        label_tensor: integer tensor of any shape.
        sorted_keys: 1-D tensor of label ids in ascending order.
        sorted_train_ids: 1-D tensor, `sorted_train_ids[i]` is the train id
            of `sorted_keys[i]`.

    Returns:
        A tensor shaped like `label_tensor` holding train ids, with -1
        wherever the label id does not occur in `sorted_keys`.
    """
    keys_for_label = torch.where(label_tensor < 1000, label_tensor, label_tensor // 1000)
    flat_keys = keys_for_label.reshape(-1)

    mapped_flat = torch.full_like(flat_keys, -1)
    if sorted_keys.numel() == 0:
        # Nothing can match an empty table.
        return mapped_flat.view(label_tensor.shape)

    pos = torch.searchsorted(sorted_keys, flat_keys)

    # searchsorted returns len(sorted_keys) for keys above the largest known
    # id. Tensor `&` does not short-circuit, so indexing with that position
    # would raise; clamp it and let the equality test reject the entry.
    safe_pos = pos.clamp(max=sorted_keys.numel() - 1)
    valid_mask = sorted_keys[safe_pos] == flat_keys

    mapped_flat[valid_mask] = sorted_train_ids[safe_pos[valid_mask]]

    return mapped_flat.view(label_tensor.shape)

def prepare_label_mapping(id2label, device='cpu'):
    """Build the sorted lookup tensors consumed by `vectorized_map_labels`.

    Args:
        id2label: mapping from label id to an object with a `trainId`
            attribute.
        device: device the two tensors are created on.

    Returns:
        `(sorted_keys, sorted_train_ids)`: the label ids in ascending order
        and the train id belonging to each of them, position by position.
    """
    key_tensor = torch.tensor(list(id2label.keys()), device=device)
    train_id_tensor = torch.tensor(
        [entry.trainId for entry in id2label.values()], device=device
    )

    # Sort the ids and carry the train ids along with the same permutation.
    sorted_keys, order = torch.sort(key_tensor)
    return sorted_keys, train_id_tensor[order]

# Build the id -> trainId lookup tensors once, on the GPU when one is
# available, so they can be reused for every batch.
sorted_keys, sorted_train_ids = prepare_label_mapping(id2label, device='cuda' if torch.cuda.is_available() else 'cpu')




# Data

In [None]:
# Create the local dataset root on the Colab VM.
!mkdir -p /content/cityscapes

# Extract the Cityscapes image archive from Drive into the dataset root.
!unzip /content/drive/MyDrive/cityscapes/leftImg8bit_trainvaltest.zip -d /content/cityscapes

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
 extracting: /content/cityscapes/leftImg8bit/train/jena/jena_000074_000019_leftImg8bit.png  
 extracting: /content/cityscapes/leftImg8bit/train/jena/jena_000040_000019_leftImg8bit.png  
 extracting: /content/cityscapes/leftImg8bit/train/jena/jena_000020_000019_leftImg8bit.png  
 extracting: /content/cityscapes/leftImg8bit/train/jena/jena_000030_000019_leftImg8bit.png  
 extracting: /content/cityscapes/leftImg8bit/train/jena/jena_000005_000019_leftImg8bit.png  
 extracting: /content/cityscapes/leftImg8bit/train/jena/jena_000059_000019_leftImg8bit.png  
 extracting: /content/cityscapes/leftImg8bit/train/jena/jena_000100_000019_leftImg8bit.png  
 extracting: /content/cityscapes/leftImg8bit/train/jena/jena_000034_000019_leftImg8bit.png  
 extracting: /content/cityscapes/leftImg8bit/train/jena/jena_000089_000019_leftImg8bit.png  
 extracting: /content/cityscapes/leftImg8bit/train/jena/jena_000104_000019_leftImg8bit.png  
 extr

In [None]:
# Extract the fine annotations (gtFine) next to the images.
!unzip /content/drive/MyDrive/cityscapes/gtFine_trainvaltest.zip -d /content/cityscapes

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: /content/cityscapes/gtFine/test/berlin/berlin_000117_000019_gtFine_color.png  
  inflating: /content/cityscapes/gtFine/test/berlin/berlin_000114_000019_gtFine_color.png  
  inflating: /content/cityscapes/gtFine/test/berlin/berlin_000434_000019_gtFine_labelIds.png  
  inflating: /content/cityscapes/gtFine/test/berlin/berlin_000420_000019_gtFine_color.png  
  inflating: /content/cityscapes/gtFine/test/berlin/berlin_000483_000019_gtFine_instanceIds.png  
  inflating: /content/cityscapes/gtFine/test/berlin/berlin_000420_000019_gtFine_instanceIds.png  
  inflating: /content/cityscapes/gtFine/test/berlin/berlin_000254_000019_gtFine_color.png  
  inflating: /content/cityscapes/gtFine/test/berlin/berlin_000490_000019_gtFine_color.png  
  inflating: /content/cityscapes/gtFine/test/berlin/berlin_000448_000019_gtFine_polygons.json  
  inflating: /content/cityscapes/gtFine/test/berlin/berlin_000099_000019_gtFine_labelIds

In [None]:
# Training split. `target_type='instance'` makes the dataset return the
# instance-id maps, which `vectorized_map_labels` reduces to train ids.
cityscapes_dataset = Cityscapes(
    root='/content/cityscapes', # Use the path where you downloaded the dataset
    split='train',
    mode='fine',
    target_type='instance',
    transform=input_transform,
    target_transform=target_transform,
)
# Shuffle the training data so that SGD does not see the samples in the same
# fixed order in every epoch.
train_loader = DataLoader(cityscapes_dataset, batch_size=4, shuffle=True)

In [None]:
# Validation split, built with the same transforms as the training split.
cityscapes_datase_val = Cityscapes(
    root='/content/cityscapes', # Use the path where you downloaded the dataset
    split='val',
    mode='fine',
    target_type='instance',
    transform=input_transform,
    target_transform=target_transform,
)
# No shuffling: validation only averages a loss, and a fixed order keeps the
# runs comparable.
val_loader = DataLoader(cityscapes_datase_val, batch_size=4, shuffle=False)

# Model

In [None]:

class DualSegmentationModel(nn.Module):
    """DeepLabV3-ResNet50 backbone followed by two 1x1 convolutions.

    The network emits one logit map per channel; the loss in this notebook
    treats every channel as an independent binary (sigmoid) prediction.

    Args:
        n_classes: number of output channels (21 in this notebook).
    """

    def __init__(self, n_classes=21):
        super().__init__()
        # `weights=` replaces the deprecated `pretrained=True`; this enum
        # member is the checkpoint that `pretrained=True` used to select.
        base_model = seg_models.deeplabv3_resnet50(
            weights=seg_models.DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1
        )
        # Only the backbone is kept; DeepLab's own classifier head is dropped.
        self.encoder = base_model.backbone
        # 1x1 projection of the 2048-channel backbone features to 256 channels.
        self.decoder = nn.Conv2d(2048,256,kernel_size=1)
        self.head_semantic = nn.Conv2d(256, n_classes, kernel_size=1)

    def forward(self, x):
        """Return the logits [B, n_classes, h, w] at backbone resolution."""
        features = self.encoder(x)['out']
        x = self.decoder(features)
        semantic_logits = self.head_semantic(x)
        return semantic_logits


In [None]:
def boundary_aware_bce_loss(logits, targets, boundary_mask, device, lambda_weight = 5,num_classes = 21):
    """Binary cross-entropy with an extra term that emphasises boundary pixels.

    loss = mean(BCE) + lambda_weight * sum(BCE * mask) / (sum(mask) + 1e-6)

    Args:
        logits: raw network outputs.
        targets: float targets with the same shape as `logits`.
        boundary_mask: weights marking boundary pixels; a mask with a single
            channel is broadcast over all channels of the loss.
        device: unused, kept for call compatibility.
        lambda_weight: weight of the boundary term.
        num_classes: unused, kept for call compatibility.

    Returns:
        The scalar loss tensor.
    """
    per_element = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')

    # Stretch a single-channel mask across every channel of the loss map.
    if boundary_mask.shape[1] == 1:
        boundary_mask = boundary_mask.expand_as(per_element)

    # Mean BCE over the masked pixels; the epsilon guards an all-zero mask.
    masked_total = (per_element * boundary_mask).sum()
    boundary_term = masked_total / (boundary_mask.sum() + 1e-6)

    return per_element.mean() + lambda_weight * boundary_term

def compute_boundary_mask_batch(semantic_labels, kernel_size=3):
    """Mark the pixels that lie on a class boundary in each label map.

    A pixel counts as boundary when the dilated and the eroded label map
    differ there, i.e. when its `kernel_size` x `kernel_size` neighbourhood
    contains more than one label value.

    Args:
        semantic_labels: tensor [B, H, W] of class indices; values are cast
            to uint8 before the morphology operations.
        kernel_size: side length of the square structuring element.

    Returns:
        A float32 tensor [B, 1, H, W] on the device of `semantic_labels`,
        1.0 on boundaries and 0.0 elsewhere.
    """
    structuring_element = np.ones((kernel_size, kernel_size), np.uint8)
    label_maps = semantic_labels.cpu().numpy()

    per_image = []
    for label_map in label_maps:
        as_uint8 = label_map.astype(np.uint8)
        grown = cv2.dilate(as_uint8, structuring_element, iterations=1)
        shrunk = cv2.erode(as_uint8, structuring_element, iterations=1)
        edge = (grown != shrunk).astype(np.float32)
        per_image.append(torch.tensor(edge))

    stacked = torch.stack(per_image).unsqueeze(1)
    return stacked.to(semantic_labels.device)



In [None]:

def compute_loss(sem_logits, sem_target, device ,lambda_weight = 5.0):
    """Boundary-aware BCE between the logits and the multi-channel targets.

    The boundary mask is derived from the arg-max over the first 20 target
    channels only; the last channel is left out of that step but is still
    part of the loss itself.

    Args:
        sem_logits: network output [B, 21, h, w].
        sem_target: float targets with the same shape.
        device: forwarded to `boundary_aware_bce_loss`.
        lambda_weight: weight of the boundary term.

    Returns:
        The scalar loss tensor.
    """
    class_index_map = sem_target[:, :20, :, :].argmax(dim=1)
    edges = compute_boundary_mask_batch(class_index_map)
    return boundary_aware_bce_loss(sem_logits, sem_target, edges, device ,lambda_weight)


In [None]:

# Restore the previously trained weights. The file is a plain state dict, so
# `weights_only=True` is sufficient and keeps torch.load from unpickling
# arbitrary Python objects.
state_dict = torch.load(
    "/content/drive/MyDrive/cityscapes/best_model_ap2.pth",
    map_location=torch.device(device),
    weights_only=True,
)
model = DualSegmentationModel()
model.to(device)
model.load_state_dict(state_dict)


<All keys matched successfully>

In [None]:


def validate(model, val_loader, device, loss_lambda=None):
    """Compute the mean boundary-aware loss of `model` over `val_loader`.

    Args:
        model: network to evaluate; it is switched to eval mode and left
            there.
        val_loader: iterable of `(images, targets)` batches.
        device: device the batches are moved to.
        loss_lambda: weight of the boundary term. `None` (default) falls
            back to the module-level `lambda_weight`, which is what the
            training cell sets.

    Returns:
        The loss averaged over the batches of `val_loader` (0.0 for an empty
        loader).
    """
    if loss_lambda is None:
        loss_lambda = lambda_weight

    model.eval()
    running_loss = 0.0

    with torch.no_grad():
        for images, targets in val_loader:
            images = images.to(device)
            sem_target = prepare_targets(targets.to(device), sorted_keys, sorted_train_ids)

            out_sem = model(images)

            loss = compute_loss(out_sem, sem_target.float(), device, loss_lambda)
            running_loss += loss.item()

    # Average over the VALIDATION batches; dividing by len(train_loader)
    # under-reported the loss by the ratio of the two loader sizes.
    avg_loss = running_loss / max(len(val_loader), 1)

    # The value is a loss, not a percentage.
    print(f"Validation Pixel loss: {avg_loss:.4f}")
    return avg_loss

# Train

In [None]:
num_epochs = 1 # the model was originally trained with 15 epochs; this run uses a single epoch

In [None]:
# --- Training configuration -------------------------------------------------
best_val_loss = float('inf')   # lowest validation loss seen so far
patience = 5                   # validations without improvement before stopping
lambda_weight = 3              # weight of the boundary term in the loss
early_stop_counter = 0

save_path_loss = '/content/drive/MyDrive/cityscapes/best_model_loss.pth'
current_path = '/content/drive/MyDrive/cityscapes/current_epoch_model.pth'

# Optimizer & scheduler
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0001)
# The polynomial schedule is stepped once per batch, so it spans all iterations.
total_iters = num_epochs * len(train_loader)
scheduler = optim.lr_scheduler.PolynomialLR(optimizer, total_iters=total_iters, power=0.9)

# Loss scaling for mixed-precision training.
scaler = torch.amp.GradScaler(device)

for epoch in range(num_epochs):
    # validate() leaves the model in eval mode, so re-enter train mode here.
    model.train()
    running_loss = 0.0

    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for images, targets in loop:
        images = images.to(device)

        # prepare_targets already returns its result on `device`.
        sem_target = prepare_targets(targets.to(device), sorted_keys, sorted_train_ids)

        optimizer.zero_grad()

        with torch.amp.autocast(device):
            out_sem = model(images)
            loss = compute_loss(out_sem, sem_target.float(), device, lambda_weight)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Read the scalar once; every .item() call forces a device sync.
        batch_loss = loss.item()
        running_loss += batch_loss
        loop.set_postfix(loss=batch_loss)

        scheduler.step()
    avg_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {avg_loss:.4f}")
    torch.save(model.state_dict(), current_path)

    print(f"Model saved (Epoch [{epoch+1}/{num_epochs}] Loss: {avg_loss:.4f})")


    # Validate after the first epoch and then after every third one.
    if epoch == 0 or (epoch + 1) % 3 == 0:
        val_loss = validate(model, val_loader, device)

        if val_loss < best_val_loss:
            # Remember the new best value. The previous code assigned it to
            # `best_val_acc`, so `best_val_loss` stayed at infinity, every
            # validation overwrote the "best" checkpoint and early stopping
            # could never trigger.
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path_loss)
            print(f"Model saved (val_loss={val_loss:.4f})")
            early_stop_counter = 0
        else:
            early_stop_counter += 1
            print(f"No improvement. Early stop counter: {early_stop_counter}/{patience}")

        if early_stop_counter >= patience:
            print("Early stopping triggered.")
            break

print("Training complete!")


Epoch 1/1: 100%|██████████| 744/744 [13:03<00:00,  1.05s/it, loss=0.217]


Epoch [1/1] Loss: 0.2276
Model saved (Epoch [1/1] Loss: 0.2276)
Validation Pixel loss: 0.04%
No improvement. Early stop counter: 1/5
Training complete!


# Test

In [None]:
# Test split, requested with semantic targets and the same transforms as the
# other splits.
cityscapes_dataset_test = Cityscapes(
    target_transform=target_transform,
    transform=input_transform,
    target_type='semantic',
    mode='fine',
    split='test',
    root='/content/cityscapes', # Use the path where you downloaded the dataset
)

# Fixed order, four images per batch.
test_loader = DataLoader(dataset=cityscapes_dataset_test, shuffle=False, batch_size=4)


# Fishyscapes Benchmark

This benchmark needs the additional `bdlb` library and downgraded versions of some installed libraries. Colab may ask you to restart the session after the installation.

In [None]:
# NOTE(review): the recorded output of this cell contains a
# "metadata-generation-failed" error before the bdlb wheel is built --
# presumably from the tensorflow-gpu line; confirm whether it is still needed.
!pip3 install --quiet tensorflow-gpu
# bdlb provides the Fishyscapes benchmark used below.
!pip3 install --quiet --upgrade git+https://github.com/hermannsblum/bdl-benchmark.git

  [1;31merror[0m: [1msubprocess-exited-with-error[0m
  
  [31m×[0m [32mpython setup.py egg_info[0m did not run successfully.
  [31m│[0m exit code: [1;36m1[0m
  [31m╰─>[0m See above for output.
  
  [1;35mnote[0m: This error originates from a subprocess, and is likely not a problem with pip.
  Preparing metadata (setup.py) ... [?25l[?25herror
[1;31merror[0m: [1mmetadata-generation-failed[0m

[31m×[0m Encountered error while generating package metadata.
[31m╰─>[0m See above for output.

[1;35mnote[0m: This is an issue with the package mentioned above, not pip.
[1;36mhint[0m: See above for details.
  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for bdlb (setup.py) ... [?25l[?25hdone


In [None]:
# Pin tensorflow-datasets to 3.1.0; this replaces the preinstalled 4.9.9.
!pip install tensorflow-datasets==3.1.0


Collecting tensorflow-datasets==3.1.0
  Downloading tensorflow_datasets-3.1.0-py3-none-any.whl.metadata (4.6 kB)
Downloading tensorflow_datasets-3.1.0-py3-none-any.whl (3.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m39.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tensorflow-datasets
  Attempting uninstall: tensorflow-datasets
    Found existing installation: tensorflow-datasets 4.9.9
    Uninstalling tensorflow-datasets-4.9.9:
      Successfully uninstalled tensorflow-datasets-4.9.9
Successfully installed tensorflow-datasets-3.1.0


In [None]:
# Force protobuf back to the 3.20 series; this replaces the preinstalled 5.29.5.
!pip install --force-reinstall -v protobuf==3.20.*

Using pip 24.1.2 from /usr/local/lib/python3.11/dist-packages/pip (python 3.11)
Collecting protobuf==3.20.*
  Obtaining dependency information for protobuf==3.20.* from https://files.pythonhosted.org/packages/8d/14/619e24a4c70df2901e1f4dbc50a6291eb63a759172558df326347dce1f0d/protobuf-3.20.3-py2.py3-none-any.whl.metadata
  Downloading protobuf-3.20.3-py2.py3-none-any.whl.metadata (720 bytes)
Downloading protobuf-3.20.3-py2.py3-none-any.whl (162 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m162.1/162.1 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: protobuf
  Attempting uninstall: protobuf
    Found existing installation: protobuf 5.29.5
    Uninstalling protobuf-5.29.5:
      Removing file or directory /usr/local/lib/python3.11/dist-packages/google/_upb/
      Removing file or directory /usr/local/lib/python3.11/dist-packages/google/protobuf/
      Removing file or directory /usr/local/lib/python3.11/dist-packages/protobuf-5.29.5.d

In [None]:
import bdlb
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

%matplotlib inline



In [None]:
# Create the Fishyscapes benchmark object without downloading anything yet.
fs = bdlb.load(benchmark="fishyscapes", download_and_prepare=False)
# Download and prepare only the LostAndFound subset.
fs.download_and_prepare('LostAndFound')
# load the dataset, there is only a validation dataset
ds = tfds.load('fishyscapes/LostAndFound', split='validation')


Downloading and preparing dataset fishyscapes/LostAndFound/1.0.0 (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/fishyscapes/LostAndFound/1.0.0...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]





Extraction completed...: 0 file [00:00, ? file/s]






Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]







Extraction completed...: 0 file [00:00, ? file/s]




0 examples [00:00, ? examples/s]

Shuffling and writing examples to /root/tensorflow_datasets/fishyscapes/LostAndFound/1.0.0.incompleteOJ9E2N/fishyscapes-validation.tfrecord


  0%|          | 0/100 [00:00<?, ? examples/s]

Computing statistics...:   0%|          | 0/1 [00:00<?, ? split/s]

0 examples [00:00, ? examples/s]

Dataset fishyscapes downloaded and prepared to /root/tensorflow_datasets/fishyscapes/LostAndFound/1.0.0. Subsequent calls will reuse this data.


In [None]:
def estimator_unknown_objectness_scores(image):
    """Score each pixel by how likely it shows an object of no known class.

    score = sigmoid(logit[20]) * prod_c (1 - sigmoid(logit[c])),  c = 0..19

    i.e. the objectness channel multiplied by the probability that none of
    the 20 semantic channels fires.

    Args:
        image: a 3-channel image, either a `tf.Tensor` of shape (H, W, 3)
            (floats in [0, 1] or values in 0-255) or a `PIL.Image`.

    Returns:
        A `tf.float32` tensor of shape (H, W) at the resolution of the input
        image; higher values mean "more likely an unknown object".

    Side effects:
        Switches the module-level `model` to eval mode.
    """
    if isinstance(image, tf.Tensor):
        img_np = image.numpy()
        if np.issubdtype(img_np.dtype, np.floating):
            # Floats in [0, 1] are stretched to 0-255; floats that already
            # span 0-255 are kept. Clipping guards the uint8 cast.
            if img_np.size and img_np.max() <= 1.0:
                img_np = img_np * 255
            img_np = np.clip(img_np, 0, 255)
        # Integer images are used as they are: multiplying a uint8 array by
        # 255 would wrap around and corrupt the picture.
        # NOTE(review): confirm which dtype/range bdlb passes to estimators.
        image = Image.fromarray(img_np.astype(np.uint8))

    # The benchmark compares the score map with a mask of the input's size.
    out_size = (image.height, image.width)

    # Same preprocessing as used for training.
    img_tensor = input_transform(image).unsqueeze(0).to(device)

    # Inference must not update the BatchNorm statistics or build a graph.
    model.eval()
    with torch.no_grad(), torch.amp.autocast(device):
        sem_logits = model(img_tensor)
    # Leave reduced precision before multiplying 20 probabilities together.
    sem_logits = sem_logits.float()

    objectness = torch.sigmoid(sem_logits[:, 20:21, :, :])
    sem_probs = torch.sigmoid(sem_logits[:, :20, :, :])

    none_known = (1 - sem_probs).prod(dim=1, keepdim=True)      # [1, 1, h, w]
    unknown_objectness_score = none_known * objectness

    unknown_objectness_score = F.interpolate(
        unknown_objectness_score,
        size=out_size,
        mode='bilinear',
        align_corners=False
    )[0, 0]

    return tf.convert_to_tensor(unknown_objectness_score.cpu().numpy(), dtype=tf.float32)

In [None]:
def estimator_product_scores(image):
    """Score each pixel by the probability that no semantic channel fires.

    score = prod_c (1 - sigmoid(logit[c])),  c = 0..19

    The objectness channel (20) is not used by this estimator.

    Args:
        image: a 3-channel image, either a `tf.Tensor` of shape (H, W, 3)
            (floats in [0, 1] or values in 0-255) or a `PIL.Image`.

    Returns:
        A `tf.float32` tensor of shape (H, W) at the resolution of the input
        image; higher values mean "less explained by the known classes".

    Side effects:
        Switches the module-level `model` to eval mode.
    """
    if isinstance(image, tf.Tensor):
        img_np = image.numpy()
        if np.issubdtype(img_np.dtype, np.floating):
            # Floats in [0, 1] are stretched to 0-255; floats that already
            # span 0-255 are kept. Clipping guards the uint8 cast.
            if img_np.size and img_np.max() <= 1.0:
                img_np = img_np * 255
            img_np = np.clip(img_np, 0, 255)
        # Integer images are used as they are: multiplying a uint8 array by
        # 255 would wrap around and corrupt the picture.
        # NOTE(review): confirm which dtype/range bdlb passes to estimators.
        image = Image.fromarray(img_np.astype(np.uint8))

    # The benchmark compares the score map with a mask of the input's size.
    out_size = (image.height, image.width)

    # Same preprocessing as used for training.
    img_tensor = input_transform(image).unsqueeze(0).to(device)

    # Inference must not update the BatchNorm statistics or build a graph.
    model.eval()
    with torch.no_grad(), torch.amp.autocast(device):
        sem_logits = model(img_tensor)
    # Leave reduced precision before multiplying 20 probabilities together.
    sem_logits = sem_logits.float()

    sem_probs = torch.sigmoid(sem_logits[:, :20, :, :])
    sem_logits_inv_prod = (1 - sem_probs).prod(dim=1, keepdim=True)   # [1, 1, h, w]

    sem_logits_inv_prod = F.interpolate(
        sem_logits_inv_prod,
        size=out_size,
        mode='bilinear',
        align_corners=False
    )[0, 0]

    # float32 is the dtype the former `tf.random.uniform` placeholder had.
    return tf.convert_to_tensor(sem_logits_inv_prod.cpu().numpy(), dtype=tf.float32)



In [None]:
def estimator(image):
    """Max-probability baseline: score each pixel with 1 - max_c p_c.

    `p_c` is the sigmoid of semantic channel c (c = 0..19); the objectness
    channel (20) is not used by this estimator.

    Args:
        image: a 3-channel image, either a `tf.Tensor` of shape (H, W, 3)
            (floats in [0, 1] or values in 0-255) or a `PIL.Image`.

    Returns:
        A `tf.float32` tensor of shape (H, W) at the resolution of the input
        image; higher values mean "no known class is confident here".

    Side effects:
        Switches the module-level `model` to eval mode.
    """
    if isinstance(image, tf.Tensor):
        img_np = image.numpy()
        if np.issubdtype(img_np.dtype, np.floating):
            # Floats in [0, 1] are stretched to 0-255; floats that already
            # span 0-255 are kept. Clipping guards the uint8 cast.
            if img_np.size and img_np.max() <= 1.0:
                img_np = img_np * 255
            img_np = np.clip(img_np, 0, 255)
        # Integer images are used as they are: multiplying a uint8 array by
        # 255 would wrap around and corrupt the picture.
        # NOTE(review): confirm which dtype/range bdlb passes to estimators.
        image = Image.fromarray(img_np.astype(np.uint8))

    # The benchmark compares the score map with a mask of the input's size.
    out_size = (image.height, image.width)

    # Same preprocessing as used for training.
    img_tensor = input_transform(image).unsqueeze(0).to(device)

    # Inference must not update the BatchNorm statistics or build a graph.
    model.eval()
    with torch.no_grad(), torch.amp.autocast(device):
        sem_logits = model(img_tensor)
    sem_logits = sem_logits.float()

    sem_probs = torch.sigmoid(sem_logits[:, :20, :, :])

    # keepdim=True keeps the map 4-D ([1, 1, h, w]) as bilinear interpolation
    # requires; max without keepdim followed by two unsqueeze(0) calls gave a
    # 5-D tensor that F.interpolate rejects.
    sem_probs_max, _ = sem_probs.max(dim=1, keepdim=True)
    sem_probs_inv = 1 - sem_probs_max

    sem_probs_inv = F.interpolate(
        sem_probs_inv,
        size=out_size,
        mode='bilinear',
        align_corners=False
    )[0, 0]

    # float32 is the dtype the former `tf.random.uniform` placeholder had.
    return tf.convert_to_tensor(sem_probs_inv.cpu().numpy(), dtype=tf.float32)



In [None]:
# Run the Fishyscapes evaluation with the unknown-objectness estimator.
metrics = fs.evaluate(estimator_unknown_objectness_scores, ds)


100%|██████████| 100/100 [00:12<00:00,  7.75it/s]


In [None]:
# Report AP and AUROC (as percentages) of the evaluation stored in `metrics`.
print('My method achieved {:.2f}% AP {:.2f}% AUC'.format(100 * metrics['AP'],100 * metrics['auroc']))


My method achieved 0.69% AP 78.17% AUC


In [None]:
# Run the Fishyscapes evaluation with the `estimator` function; this
# overwrites the previous `metrics`.
metrics = fs.evaluate(estimator, ds)


100%|██████████| 100/100 [00:12<00:00,  8.10it/s]


In [None]:
# Report AP and AUROC (as percentages) of the evaluation stored in `metrics`.
print('My method achieved {:.2f}% AP {:.2f}% AUC'.format(100 * metrics['AP'],100 * metrics['auroc']))


My method achieved 0.61% AP 72.94% AUC


In [None]:
# Run the Fishyscapes evaluation with the product-score estimator; this
# overwrites the previous `metrics`.
metrics = fs.evaluate(estimator_product_scores, ds)


100%|██████████| 100/100 [00:12<00:00,  7.93it/s]


In [None]:
# Report AP and AUROC (as percentages) of the evaluation stored in `metrics`.
print('My method achieved {:.2f}% AP {:.2f}% AUC'.format(100 * metrics['AP'],100 * metrics['auroc']))


My method achieved 0.86% AP 74.56% AUC
