In [6]:
# Root directory that receives the downloaded StreetHazards archives and the
# extracted train/test splits.
DATA_DIR = "./data"

In [2]:

!mkdir -p $DATA_DIR
# Download and unpack each split only once. The guard tests the FINAL directory
# name: the archive unpacks to `train/` (resp. `test/`) and is renamed afterwards,
# so testing for `train/` would succeed again on every re-run, triggering a new
# multi-GB download and an `mv` into the already existing target directory.
!test ! -d $DATA_DIR/streethazards_train \
    && wget -O $DATA_DIR/train.tar https://people.eecs.berkeley.edu/~hendrycks/streethazards_train.tar \
    && tar -xf $DATA_DIR/train.tar -C $DATA_DIR \
    && rm -f $DATA_DIR/train.tar \
    && mv $DATA_DIR/train $DATA_DIR/streethazards_train
!test ! -d $DATA_DIR/streethazards_test \
    && wget -O $DATA_DIR/test.tar https://people.eecs.berkeley.edu/~hendrycks/streethazards_test.tar \
    && tar -xf $DATA_DIR/test.tar -C $DATA_DIR \
    && rm -f $DATA_DIR/test.tar \
    && mv $DATA_DIR/test $DATA_DIR/streethazards_test

--2025-10-05 20:55:45--  https://people.eecs.berkeley.edu/~hendrycks/streethazards_train.tar
Resolving people.eecs.berkeley.edu (people.eecs.berkeley.edu)... 128.32.244.190
Connecting to people.eecs.berkeley.edu (people.eecs.berkeley.edu)|128.32.244.190|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 9386226176 (8.7G) [application/x-tar]
Saving to: ‘./data/train.tar’


2025-10-05 20:59:01 (45.7 MB/s) - ‘./data/train.tar’ saved [9386226176/9386226176]

--2025-10-05 20:59:25--  https://people.eecs.berkeley.edu/~hendrycks/streethazards_test.tar
Resolving people.eecs.berkeley.edu (people.eecs.berkeley.edu)... 128.32.244.190
Connecting to people.eecs.berkeley.edu (people.eecs.berkeley.edu)|128.32.244.190|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2150484992 (2.0G) [application/x-tar]
Saving to: ‘./data/test.tar’


2025-10-05 21:00:21 (36.9 MB/s) - ‘./data/test.tar’ saved [2150484992/2150484992]



In [4]:
!pip install -U segmentation-models-pytorch

Collecting segmentation-models-pytorch
  Downloading segmentation_models_pytorch-0.5.0-py3-none-any.whl.metadata (17 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.8->segmentation-models-pytorch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.8->segmentation-models-pytorch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.8->segmentation-models-pytorch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.8->segmentation-models-pytorch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.8->segmentation-models-pytorch)
  Downloading nvidia_cublas_cu12-12.4.5.8-

In [16]:
import json
import os
from enum import IntEnum
from pathlib import Path
from typing import Optional, Callable, Union, Tuple, Dict, List

import numpy as np
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch import Tensor
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision import transforms
from torchvision.transforms import v2
from tqdm import tqdm

In [7]:
"""
Source: https://github.com/hendrycks/anomaly-seg/issues/15#issuecomment-890300278
"""
# Per-class display colors, one row per class id; the trailing comment on each
# row gives the class name and its id (presumably RGB triplets — TODO confirm).
COLORS = np.array([
    [ 70,  70,  70],  # building     =   0
    [190, 153, 153],  # fence        =   1
    [250, 170, 160],  # other        =   2
    [220,  20,  60],  # pedestrian   =   3
    [153, 153, 153],  # pole         =   4
    [157, 234,  50],  # road line    =   5
    [128,  64, 128],  # road         =   6
    [244,  35, 232],  # sidewalk     =   7
    [107, 142,  35],  # vegetation   =   8
    [  0,   0, 142],  # car          =   9
    [102, 102, 156],  # wall         =  10
    [220, 220,   0],  # traffic sign =  11
    [ 60, 250, 240],  # anomaly      =  12
])

class StreetHazardsClasses(IntEnum):
    """Integer ids of the StreetHazards semantic classes.

    Ids run from 0 (``BUILDING``) to 12 (``ANOMALY``); ``ANOMALY`` is the
    highest id.
    """

    BUILDING = 0
    FENCE = 1
    OTHER = 2
    PEDESTRIAN = 3
    POLE = 4
    ROAD_LINE = 5
    ROAD = 6
    SIDEWALK = 7
    VEGETATION = 8
    CAR = 9
    WALL = 10
    TRAFFIC_SIGN = 11
    ANOMALY = 12
    
# Paths to the StreetHazards .odgt index files, one per split. The train and
# validation indexes both live in the extracted train directory.
train_odgt_file = f"{DATA_DIR}/streethazards_train/train.odgt"
val_odgt_file = f"{DATA_DIR}/streethazards_train/validation.odgt"
test_odgt_file = f"{DATA_DIR}/streethazards_test/test.odgt"

# When True, the normalization statistics are recomputed from the training
# split instead of using the hard-coded values further down.
COMPUTE_MEAN_STD = False

In [37]:
class StreetHazardsDataset(Dataset):
    """
    A custom PyTorch Dataset for the StreetHazards inlier dataset.

    This dataset reads image and segmentation label paths from a `.odgt` file,
    applies optional resizing and spatial transformations, and returns
    dictionary-style samples with normalized image tensors and label tensors.

    Args:
        odgt_file (str): Path to the `.odgt` file containing image and label metadata.
            Both layouts are accepted: a single JSON array of records, or one JSON
            value per line. Record paths are resolved relative to the directory
            that contains this file.
        image_resize (Tuple[int, int], optional): Target size to resize images and labels.
            A falsy value disables resizing.
        spatial_transforms (Callable, optional): Optional transformation function applied to both images and labels.
            It is called as `spatial_transforms(image, labels)` and must return the pair.
        mean_std (Tuple[List[float], List[float]], optional): Mean and standard deviation for image normalization.
            A falsy value disables normalization.

    Raises:
        ValueError: If the `.odgt` file is empty or is not valid JSON in either layout.
    """
    def __init__(
        self,
        odgt_file: str,
        image_resize: Optional[Tuple[int, int]] = (512, 896),
        spatial_transforms: Optional[Callable] = None,
        mean_std: Optional[Tuple[List[float], List[float]]] = None
    ):

        self.spatial_transforms = spatial_transforms
        self.mean_std = mean_std
        self.image_resize = image_resize

        # Resolve the index directory once instead of once per record.
        base_dir = Path(odgt_file).parent
        self.paths = [
            {
                "image": os.path.join(base_dir, record["fpath_img"]),
                "labels": os.path.join(base_dir, record["fpath_segm"]),
            }
            for record in self._read_odgt(odgt_file)
        ]

    @staticmethod
    def _read_odgt(odgt_file: str) -> List[dict]:
        """Return the records of an `.odgt` file as a flat list of dicts."""
        with open(odgt_file, "r") as f:
            text = f.read()

        if not text.strip():
            raise ValueError(f"Empty .odgt file: {odgt_file}")

        try:
            # Layout 1: the whole file is one JSON document.
            parsed = json.loads(text)
        except json.JSONDecodeError:
            # Layout 2: one JSON value per line. A malformed line still raises
            # json.JSONDecodeError, which is a ValueError subclass.
            parsed = [json.loads(line) for line in text.splitlines() if line.strip()]

        if isinstance(parsed, dict):
            # A file holding exactly one record.
            return [parsed]

        records = []
        for entry in parsed:
            # A line may itself hold a list of records; flatten one level.
            if isinstance(entry, list):
                records.extend(entry)
            else:
                records.append(entry)
        return records

    def __len__(self) -> int:
        """Number of (image, labels) pairs listed in the index file."""
        return len(self.paths)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Load, resize, transform and tensorize the sample at position `idx`.

        Returns:
            Dict with 'image' (the `ToTensor` output, normalized when
            `mean_std` is set) and 'labels' (int64 tensor with the leading
            channel dimension squeezed away).
        """
        image = Image.open(self.paths[idx]["image"]).convert("RGB")
        labels = Image.open(self.paths[idx]["labels"])

        if self.image_resize:
            image = transforms.Resize(self.image_resize, transforms.InterpolationMode.BILINEAR)(image)
            # Nearest-neighbour keeps label values discrete (no blended class ids).
            labels = transforms.Resize(self.image_resize, transforms.InterpolationMode.NEAREST)(labels)

        if self.spatial_transforms:
            image, labels = self.spatial_transforms(image, labels)

        image = transforms.ToTensor()(image)
        # Shift the stored label ids down by one (stored ids appear to be
        # 1-based — TODO confirm against the dataset documentation).
        labels = torch.as_tensor(transforms.functional.pil_to_tensor(labels), dtype=torch.int64) - 1

        labels = labels.squeeze(0)

        if self.mean_std:
            image = transforms.Normalize(mean=self.mean_std[0], std=self.mean_std[1])(image)

        return {'image': image, 'labels': labels}

In [38]:
def create_one_hot_prototypes_torch(num_known_classes: int, t_value: float = 3.0, device: Union[str, torch.device] = 'cpu') -> torch.Tensor:
    """
    Generates one-hot prototypes as a PyTorch tensor for a given number of known classes.
    Each prototype is a vector where only the element corresponding
    to its class index has the 't_value', and all other elements are 0.

    Args:
        num_known_classes (int): The total number of known (in-distribution) classes.
                                 This also determines the dimensionality of each prototype vector.
        t_value (float): The non-zero value at the class's specific index in the prototype.
        device (str | torch.device): CPU or CUDA device on which to create the tensor.
                                     Besides 'cpu' and 'cuda', anything `torch.device`
                                     accepts for those two device types is allowed
                                     (e.g. 'cuda:0' or a `torch.device` object).

    Returns:
        torch.Tensor: A 2D float32 tensor where each row is a prototype vector.
                      The shape is (num_known_classes, num_known_classes).

    Raises:
        ValueError: If `num_known_classes` is not a positive integer, `t_value`
                    is not numeric, or `device` is not a CPU/CUDA device.
    """
    if not isinstance(num_known_classes, int) or num_known_classes <= 0:
        raise ValueError("num_known_classes must be a positive integer.")
    if not isinstance(t_value, (int, float)):
        raise ValueError("t_value must be a numeric type.")
    try:
        target_device = torch.device(device)
    except (RuntimeError, TypeError) as exc:
        # torch.device raises on unknown device strings / unsupported types.
        raise ValueError("device must be 'cpu' or 'cuda'.") from exc
    if target_device.type not in ('cpu', 'cuda'):
        raise ValueError("device must be 'cpu' or 'cuda'.")

    # Build the diagonal directly. `torch.eye(...) * t_value` is avoided because
    # it would turn the off-diagonal zeros into NaN for a non-finite t_value.
    diagonal = torch.full((num_known_classes,), t_value, dtype=torch.float32, device=target_device)
    return torch.diag(diagonal)

# One-hot prototypes for 12 known classes, created with the function's default
# t_value and on its default 'cpu' device.
Prototype  = create_one_hot_prototypes_torch(12)

In [66]:

class DMLNetFeatureExtractor(torch.nn.Module):
    """DeepLabV3+ encoder/decoder that emits a per-pixel feature vector.

    The segmentation head of the wrapped `smp.DeepLabV3Plus` is replaced by an
    identity and a 1x1 convolution projects the decoder output to
    `num_feature_channels` channels. `forward` returns a tensor of shape
    (batch, num_feature_channels, H, W) with the same H and W as its input.

    Args:
        encoder_name: Encoder backbone name forwarded to `smp.DeepLabV3Plus`.
        encoder_weights: Pretrained weight identifier forwarded to `smp.DeepLabV3Plus`.
        num_feature_channels: Number of channels of the returned feature map.
        activation: Forwarded to `smp.DeepLabV3Plus`. It only configures the
            original segmentation head, which `forward` does not use.
    """

    def __init__(self, encoder_name, encoder_weights, num_feature_channels, activation):
        super().__init__()

        self.model = smp.DeepLabV3Plus(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            classes=num_feature_channels,
            activation=activation,
        )

        # Keep the original head reachable, but take it out of the model's
        # forward path so the raw decoder features can be used.
        self.original_segmentation_head = self.model.segmentation_head
        self.model.segmentation_head = torch.nn.Identity()

        # 1x1 convolution projecting the decoder output to the feature size.
        self.feature_projection = torch.nn.Conv2d(
            in_channels=self._decoder_out_channels(),
            out_channels=num_feature_channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )

    def _decoder_out_channels(self) -> int:
        """Return the channel count of the decoder output."""
        declared = getattr(self.model.decoder, "out_channels", None)
        if isinstance(declared, int):
            return declared

        # Fallback: probe with a dummy batch. Eval mode and no_grad keep the
        # probe from touching BatchNorm running statistics (which a train-mode
        # pass on random noise would shift) and from building an autograd graph.
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                probe = torch.zeros(1, 3, 256, 256)
                channels = self.model.decoder(self.model.encoder(probe)).shape[1]
        finally:
            self.model.train(was_training)
        return channels

    def forward(self, x):
        """Compute per-pixel features for the image batch `x` (batch, 3, H, W)."""
        # The encoder outputs a list of feature maps at different resolutions.
        encoder_features = self.model.encoder(x)

        # The decoder fuses them into one feature map with its own channel count.
        decoder_output = self.model.decoder(encoder_features)

        # Project the decoder output to the desired number of feature channels.
        final_features = self.feature_projection(decoder_output)

        # In segmentation_models_pytorch the final upsampling belongs to the
        # segmentation head that was removed above, so the decoder output is
        # spatially smaller than the input. Restore the input resolution so the
        # features line up pixel-by-pixel with the label maps.
        if final_features.shape[-2:] != x.shape[-2:]:
            final_features = torch.nn.functional.interpolate(
                final_features, size=x.shape[-2:], mode="bilinear", align_corners=False
            )

        # These `final_features` are f(X; θf)i,j with num_feature_channels channels.
        return final_features

In [67]:
class DiscriminativeCrossEntropyLoss(nn.Module):
    """Cross-entropy over negative squared distances to class prototypes.

    For every pixel feature the squared Euclidean distance to each prototype is
    computed; the negated distances are used as logits of a softmax
    cross-entropy against the target class index.

    Args:
        prototypes: 2D tensor of shape (num_classes, feature_dim).
        reduction: Reduction passed to `F.nll_loss` ('mean', 'sum' or 'none').

    Raises:
        ValueError: If `prototypes` is not 2D.
    """

    def __init__(self, prototypes: torch.Tensor, reduction: str = 'mean'):
        super().__init__()

        if prototypes.dim() != 2:
            raise ValueError("Prototypes must be a 2D tensor (num_classes, feature_dim)")

        if isinstance(prototypes, nn.Parameter):
            # Plain assignment registers a Parameter, keeping learnable
            # prototypes learnable.
            self.prototypes = prototypes
        else:
            # A buffer follows the module through .to()/.cuda()/.double().
            # Non-persistent keeps state_dict() free of the extra entry.
            self.register_buffer("prototypes", prototypes, persistent=False)
        self.reduction = reduction

    def forward(self, pixel_features: torch.Tensor, target_labels: torch.Tensor):
        """Return the loss for a flat batch of pixel features.

        Args:
            pixel_features: Tensor of shape (N_pixels, feature_dim).
            target_labels: Integer tensor of shape (N_pixels,) holding class
                indices in [0, num_classes).

        Raises:
            ValueError: On wrong tensor ranks or a feature-dimension mismatch.
        """
        if pixel_features.dim() != 2:
            raise ValueError("pixel_features must be a 2D tensor (N_pixels, feature_dim)")
        if target_labels.dim() != 1:
            raise ValueError("target_labels must be a 1D tensor (N_pixels,)")

        num_pixels, feature_dim = pixel_features.shape
        num_classes, proto_feature_dim = self.prototypes.shape

        if feature_dim != proto_feature_dim:
            raise ValueError(f"Feature dimension mismatch: pixel_features ({feature_dim}) "
                             f"vs prototypes ({proto_feature_dim})")

        # Guard against prototypes living on another device / in another dtype
        # than the features (a no-op when they already match).
        prototypes = self.prototypes.to(device=pixel_features.device, dtype=pixel_features.dtype)

        # All-pairs squared distances via ||a - b||^2 = ||a||^2 - 2<a,b> + ||b||^2.
        pixel_features_sq_norm = torch.sum(pixel_features**2, dim=1, keepdim=True)  # (N_pixels, 1)
        prototypes_sq_norm = torch.sum(prototypes**2, dim=1, keepdim=True).T  # (1, num_classes)
        dot_product = torch.matmul(pixel_features, prototypes.T) * 2  # (N_pixels, num_classes)
        distances_sq = pixel_features_sq_norm + prototypes_sq_norm - dot_product

        # Rounding can push tiny distances slightly below zero.
        distances_sq = torch.clamp(distances_sq, min=0.0)

        # log( exp(-d_k) / sum_j exp(-d_j) ) is a log-softmax over -d.
        log_probabilities = F.log_softmax(-distances_sq, dim=1)  # (N_pixels, num_classes)

        # nll_loss picks -log_probabilities[target] and applies the reduction.
        loss = F.nll_loss(log_probabilities, target_labels, reduction=self.reduction)

        return loss

In [68]:
# Target size handed to every split as `image_resize`.
shape_resize = (512, 896)

# Per-channel normalization statistics: the stored values are used unless
# COMPUTE_MEAN_STD asks for a recomputation on the un-normalized training split.
# NOTE(review): compute_mean_std_channels is not defined in the visible part of
# this notebook — confirm it exists before enabling COMPUTE_MEAN_STD.
if not COMPUTE_MEAN_STD:
    mean_streethazards, std_streethazards = [0.3302, 0.3459, 0.373], [0.1595, 0.1577, 0.1712]
else:
    mean_streethazards, std_streethazards = compute_mean_std_channels(
        StreetHazardsDataset(
            odgt_file=train_odgt_file,
            image_resize=shape_resize,
            spatial_transforms=None,
            mean_std=None,
        )
    )

# Augmentation used for the training split only.
spatial_transforms = v2.Compose([v2.RandomHorizontalFlip()])

train_dataset = StreetHazardsDataset(
    odgt_file=train_odgt_file,
    image_resize=shape_resize,
    spatial_transforms=spatial_transforms,
    mean_std=(mean_streethazards, std_streethazards),
)

# Validation and test samples are only resized and normalized.
val_dataset = StreetHazardsDataset(
    odgt_file=val_odgt_file,
    image_resize=shape_resize,
    spatial_transforms=None,
    mean_std=(mean_streethazards, std_streethazards),
)

test_dataset = StreetHazardsDataset(
    odgt_file=test_odgt_file,
    image_resize=shape_resize,
    spatial_transforms=None,
    mean_std=(mean_streethazards, std_streethazards),
)

# Only the training loader shuffles.
train_dl = DataLoader(train_dataset, num_workers=2, batch_size=8, shuffle=True)
val_dl = DataLoader(val_dataset, num_workers=2, batch_size=8, shuffle=False)
test_dl = DataLoader(test_dataset, num_workers=2, batch_size=8, shuffle=False)

In [69]:
# Release cached GPU memory that is no longer in use before building the model.
torch.cuda.empty_cache()

# Example configuration.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
encoder_name = "resnet34"
encoder_weights = "imagenet"
num_known_classes = 12  # number of classes that get a one-hot prototype
t_value = 3.0

# Build the feature extractor with one output channel per known class and move
# it to the selected device.
feature_extractor = DMLNetFeatureExtractor(
    encoder_name=encoder_name,
    encoder_weights=encoder_weights,
    num_feature_channels=num_known_classes,
    activation=None,
).to(device)

# Adam over every parameter of the feature extractor.
model_optimizer = torch.optim.Adam(feature_extractor.parameters(), lr=0.001)

def train(num_epochs, model, train_loader, verbose=False,
          criterion=None, optimizer=None, scheduler=None, device=None) -> None:
    """Train `model` on `train_loader` with a prototype-distance loss.

    The previous body was written as a class method: it read its losses,
    optimizer, scheduler, evaluation and logging from an undefined `self`, so it
    could not run as a function. Everything it needs is now passed in or built
    here; evaluation, early stopping and experiment logging are not part of
    this function.

    Args:
        num_epochs: Number of passes over `train_loader`.
        model: Module mapping an image batch to a (batch, C, H, W) feature map.
            A tuple output is also accepted; its last element is used.
        train_loader: Iterable of dicts with 'image' and 'labels' tensors.
        verbose: When True, also report the loss of every batch.
        criterion: Callable `(features (N, C), labels (N,)) -> loss`. Defaults
            to `DiscriminativeCrossEntropyLoss` over C one-hot prototypes, with
            C taken from the first feature map.
        optimizer: Optimizer over the model parameters. Defaults to Adam with
            lr=0.001.
        scheduler: Optional learning-rate scheduler, stepped after every batch.
        device: Device for the batches. Defaults to the device of the model's
            parameters.
    """
    if device is None:
        device = next(model.parameters()).device
    else:
        device = torch.device(device)
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    if isinstance(criterion, torch.nn.Module):
        criterion.to(device)

    for epoch in tqdm(range(num_epochs), desc="Epoch"):

        # Enable training behaviour (dropout, BatchNorm statistics updates).
        model.train()

        losses = []

        for batch in train_loader:

            imgs = batch['image'].to(device)
            labels = batch['labels'].to(device)

            features = model(imgs)
            if isinstance(features, tuple):
                features = features[-1]

            # The loss compares features and labels pixel by pixel, so both
            # must share the same spatial size.
            if features.shape[-2:] != labels.shape[-2:]:
                features = torch.nn.functional.interpolate(
                    features, size=labels.shape[-2:], mode="bilinear", align_corners=False
                )

            num_classes = features.shape[1]
            if criterion is None:
                # Built on the CPU and moved afterwards so that any device works.
                criterion = DiscriminativeCrossEntropyLoss(
                    create_one_hot_prototypes_torch(num_classes).to(device)
                )

            # (B, C, H, W) -> (B*H*W, C) and (B, H, W) -> (B*H*W,)
            flat_features = features.permute(0, 2, 3, 1).reshape(-1, num_classes)
            flat_labels = labels.reshape(-1)

            # Drop pixels whose label has no prototype (ids outside
            # [0, num_classes)); an out-of-range target makes the loss fail.
            valid = (flat_labels >= 0) & (flat_labels < num_classes)
            if not valid.any():
                continue
            if not valid.all():
                flat_features = flat_features[valid]
                flat_labels = flat_labels[valid]

            loss_res = criterion(flat_features, flat_labels)

            optimizer.zero_grad(set_to_none=True)
            loss_res.backward()
            optimizer.step()
            if scheduler is not None:
                scheduler.step()

            # .item() stores a Python float, so no autograd graph is kept alive.
            losses.append(loss_res.item())
            if verbose:
                tqdm.write(f"Epoch {epoch + 1} batch loss: {losses[-1]:.4f}")

            del imgs, labels, features, flat_features, flat_labels, loss_res

        mean_loss = sum(losses) / len(losses) if losses else float("nan")
        print(f"Epoch {epoch + 1} - mean train loss: {mean_loss:.4f}")

train(10, feature_extractor, train_dl, optimizer=model_optimizer)

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 MiB. GPU 0 has a total capacity of 15.89 GiB of which 3.12 MiB is free. Process 48725 has 15.88 GiB memory in use. Of the allocated memory 15.55 GiB is allocated by PyTorch, and 47.82 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)