Revise the paths below to point to your copy of the dataset.

In [1]:
# Root folder that holds the ADE20K download.
data_root = "/data/DATA/ADE20K"
# Challenge release folder located directly under the root.
data_path = f"{data_root}/ADEChallengeData2016"

Load the required packages and list the contents of the dataset directory.

In [None]:
import os

# Entries that sit directly under the dataset folder.
images_path = os.path.join(data_path, "images")
objectInfo150_path = os.path.join(data_path, "objectInfo150.txt")
annotations_path = os.path.join(data_path, "annotations")

# Entries nested under the annotations folder.
(
    annotations_instance_path,
    sceneCategories_path,
    annotations_detectron2_path,
) = (
    os.path.join(annotations_path, leaf)
    for leaf in (
        "annotations_instance",
        "sceneCategories.txt",
        "annotations_detectron2",
    )
)

# Last expression of the cell: shows what is actually present on disk.
os.listdir(data_path)

Initialize the config.

In [3]:
class Config:
    """Plain namespace of settings used by the notebook.

    Settings are stored as class attributes of the nested ``DATASET`` class,
    so they can be read as ``cfg.DATASET.<name>`` from an instance or from
    the class itself.
    """

    class DATASET:
        # Dataset root directory (taken from the module-level ``data_root``).
        root_dataset = data_root
        # Validation list file passed to ``ValDataset``.
        list_val = "Data/validation.odgt"
        # Number of semantic classes used by the metrics and the clustering.
        num_class = 150
        # Presumably the candidate image sizes used when resizing inputs;
        # consumed by the dataset class -- TODO confirm against dataset.py.
        imgSizes = (300, 375, 450, 525, 600)
        # Presumably the upper bound on the longer image side -- TODO confirm.
        imgMaxSize = 1000
        # Presumably image dimensions are padded to a multiple of this value
        # -- TODO confirm against dataset.py.
        padding_constant = 32
        # Presumably the label down-sampling factor -- TODO confirm.
        segm_downsampling_rate = 4
        # Must be an assignment: a bare annotation (``random_flip: True``)
        # would not create the attribute at all.
        random_flip = True


Load the dataset and initialize the dataloader.

In [None]:
import torch
from dataset import ValDataset
from PIL import Image

cfg = Config()

# Validation dataset: built from the dataset root, the validation list file
# and the DATASET settings namespace.
dataset_val = ValDataset(cfg.DATASET.root_dataset, cfg.DATASET.list_val, cfg.DATASET)

# One sample per batch, in dataset order, loaded by a single worker process.
loader_val = torch.utils.data.DataLoader(
    dataset=dataset_val,
    batch_size=1,
    num_workers=1,
    shuffle=False,
    drop_last=True,
)


Define appropriate metrics for image segmentation.

In [5]:
def accuracy(preds, label):
    """Return pixel accuracy over labelled pixels and the labelled-pixel count.

    Pixels whose label is negative are treated as unlabelled and ignored.

    Args:
        preds: array of predicted class ids.
        label: array of ground-truth class ids, same shape as ``preds``.

    Returns:
        ``(acc, valid_sum)`` where ``acc`` is the fraction of labelled pixels
        predicted correctly and ``valid_sum`` is the number of labelled pixels.
    """
    labelled = label >= 0
    valid_sum = labelled.sum()
    hits = ((preds == label) * labelled).sum()
    # The small epsilon keeps the ratio defined when nothing is labelled.
    return float(hits) / (valid_sum + 1e-10), valid_sum

def intersectionAndUnion(imPred, imLab, numClass):
    """Return per-class intersection and union pixel counts.

    Both inputs are copied and shifted by +1, so a ground-truth value of -1
    becomes 0 and is treated as unlabelled.

    Args:
        imPred: predicted class ids.
        imLab: ground-truth class ids, same shape as ``imPred``.
        numClass: number of classes, i.e. the number of histogram bins.

    Returns:
        ``(area_intersection, area_union)``, two arrays of length ``numClass``.
    """
    def _per_class_area(values):
        # One bin per shifted class id in [1, numClass]; zeros fall outside.
        counts, _ = np.histogram(values, bins=numClass, range=(1, numClass))
        return counts

    pred = np.array(imPred, copy=True)
    lab = np.array(imLab, copy=True)
    pred += 1
    lab += 1

    # Predictions on unlabelled ground-truth pixels are zeroed so they are
    # not counted against the prediction.
    pred = (lab > 0) * pred

    # Keep the class id where prediction and label agree, 0 elsewhere.
    agree = np.where(pred == lab, pred, 0)

    area_intersection = _per_class_area(agree)
    area_union = _per_class_area(pred) + _per_class_area(lab) - area_intersection
    return (area_intersection, area_union)

Test different traditional segmentation methods and visualize selected samples.

In [None]:
import cv2
import collections
from sklearn.cluster import KMeans
from torch.autograd import Variable
from sklearn.cluster import MiniBatchKMeans

class AverageMeter(object):
    """Track the latest value and a weighted running average."""

    def __init__(self):
        # Every statistic stays None until the first update arrives.
        self.initialized = False
        self.val = self.avg = self.sum = self.count = None

    def initialize(self, val, weight):
        """Seed the meter with its first observation."""
        self.val = val
        self.avg = val
        self.sum = val * weight
        self.count = weight
        self.initialized = True

    def update(self, val, weight=1):
        """Record ``val`` with the given ``weight``."""
        if self.initialized:
            self.add(val, weight)
            return
        self.initialize(val, weight)

    def add(self, val, weight):
        """Fold another observation into the running totals."""
        self.val = val
        self.sum += val * weight
        self.count += weight
        self.avg = self.sum / self.count

    def value(self):
        """Return the most recently recorded value."""
        return self.val

    def average(self):
        """Return the weighted average of everything recorded so far."""
        return self.avg

def as_numpy(obj):
    """Recursively convert tensors inside ``obj`` to NumPy arrays.

    Args:
        obj: a tensor/Variable, a sequence or mapping of convertible objects,
            or anything ``np.array`` accepts.

    Returns:
        A list for sequences, a dict for mappings (same keys), and a NumPy
        array for everything else.
    """
    # The ABC aliases were removed from the top-level ``collections``
    # namespace in Python 3.10, so they are taken from ``collections.abc``.
    from collections.abc import Mapping, Sequence

    # Strings are sequences of one-character strings; recursing into them
    # would never terminate, so they are converted as a single value.
    if isinstance(obj, (str, bytes)):
        return np.array(obj)
    if isinstance(obj, Sequence):
        return [as_numpy(v) for v in obj]
    if isinstance(obj, Mapping):
        return {k: as_numpy(v) for k, v in obj.items()}
    if isinstance(obj, Variable):
        return obj.data.cpu().numpy()
    if torch.is_tensor(obj):
        return obj.cpu().numpy()
    return np.array(obj)

def traditional_watershed_segmentation(image, num_classes):
    """Segment ``image`` with marker-based watershed.

    Markers are seeded from an Otsu threshold and a distance transform. If
    watershed yields more distinct marker values than ``num_classes``, the
    regions are grouped by their mean LAB colour with k-means.

    Args:
        image: colour image array; assumed to be RGB and 8-bit, as required
            by the OpenCV calls used here -- TODO confirm against the caller.
        num_classes: maximum number of labels wanted in the output.

    Returns:
        An ``int32`` label map with the same height and width as ``image``.
        When no merging happens, watershed boundary pixels keep the value -1;
        after merging they are 0.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    blur = cv2.GaussianBlur(gray, (5, 5), 0)

    # Inverted Otsu threshold gives the initial foreground mask.
    _, thresh = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    # Morphological opening removes small speckles from the mask.
    kernel = np.ones((3, 3), np.uint8)
    opening = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel, iterations=2)

    # Dilation yields the area that is certainly background outside it.
    sure_bg = cv2.dilate(opening, kernel, iterations=3)

    # Pixels far from the mask border are certainly foreground.
    dist_transform = cv2.distanceTransform(opening, cv2.DIST_L2, 5)
    _, sure_fg = cv2.threshold(dist_transform, 0.5 * dist_transform.max(), 255, 0)
    sure_fg = np.uint8(sure_fg)

    # Everything in between is left for watershed to decide.
    unknown = cv2.subtract(sure_bg, sure_fg)

    _, markers = cv2.connectedComponents(sure_fg)
    # Shift so that the background component is 1 and 0 can mean "unknown".
    markers = markers + 1
    markers[unknown == 255] = 0

    markers = cv2.watershed(image, markers)

    unique_markers = np.unique(markers)
    if len(unique_markers) > num_classes:
        lab_image = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)

        # -1 marks watershed boundaries and is not a region of its own.
        region_ids = [marker for marker in unique_markers if marker != -1]
        region_features = [
            np.mean(lab_image[markers == region_id], axis=0)
            for region_id in region_ids
        ]

        kmeans = MiniBatchKMeans(n_clusters=num_classes, random_state=42)
        region_labels = kmeans.fit_predict(region_features)

        # Pair each region with its own cluster label explicitly, so the
        # mapping stays correct whether or not -1 is among the markers.
        new_markers = np.zeros_like(markers)
        for region_id, region_label in zip(region_ids, region_labels):
            new_markers[markers == region_id] = region_label
        markers = new_markers

    return markers.astype(np.int32)

def traditional_canny_edge_segmentation(image, num_classes):
    """Cluster pixels into ``num_classes`` groups by colour and position.

    Despite the name, no Canny edge detection happens here: the image is
    halved in size, each pixel is described by its LAB colour plus
    down-weighted x/y coordinates, and the pixels are clustered with
    mini-batch k-means.

    Args:
        image: colour image array; assumed to be RGB and 8-bit because of the
            ``COLOR_RGB2LAB`` conversion -- TODO confirm against the caller.
        num_classes: number of clusters.

    Returns:
        An ``int32`` array of cluster ids with the height and width of
        ``image``.
    """
    shrink = 0.5
    full_h, full_w = image.shape[:2]
    small_h = int(full_h * shrink)
    small_w = int(full_w * shrink)

    # Work on the half-size image to keep clustering cheap.
    lab_image = cv2.cvtColor(
        cv2.resize(image, (small_w, small_h)), cv2.COLOR_RGB2LAB
    )

    # Pixel coordinates rescaled to the 0-255 range.
    y, x = np.mgrid[0:small_h, 0:small_w]
    x = x / small_w * 255
    y = y / small_h * 255

    # Feature rows: three colour channels, then x and y weighted by 0.2.
    features = np.concatenate(
        [
            lab_image.reshape(-1, 3),
            x.reshape(-1, 1) * 0.2,
            y.reshape(-1, 1) * 0.2,
        ],
        axis=1,
    )

    clusterer = MiniBatchKMeans(
        n_clusters=num_classes,
        batch_size=1000,
        random_state=42,
        n_init='auto',
    )
    label_map = clusterer.fit_predict(features).reshape(small_h, small_w)

    # Nearest-neighbour resizing keeps the values valid cluster ids.
    upsampled = cv2.resize(
        label_map.astype(float),
        (full_w, full_h),
        interpolation=cv2.INTER_NEAREST,
    )
    return upsampled.astype(np.int32)



import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
def visualize_image(img_tensor):
    """Show an image given as a torch tensor or a NumPy array.

    A 3-D input whose first dimension is 1 or 3 is treated as channels-first
    and transposed to channels-last. Values outside [0, 1] are min-max
    rescaled before plotting.
    """
    img_array = img_tensor.cpu().numpy() if torch.is_tensor(img_tensor) else img_tensor

    channels_first = img_array.ndim == 3 and img_array.shape[0] in [1, 3]
    if channels_first:
        img_array = np.transpose(img_array, (1, 2, 0))

    lowest, highest = img_array.min(), img_array.max()
    if lowest < 0 or highest > 1:
        img_array = (img_array - lowest) / (highest - lowest)

    plt.figure(figsize=(5, 4))
    plt.imshow(img_array)
    plt.axis('off')
    plt.show()


def visualize_segmentation(segmentation_map):
    """Show a segmentation result as a colour image.

    Args:
        segmentation_map: array of class ids, expected to have shape (H, W).
            Extra singleton dimensions are squeezed away. Ids index a
            150-entry colour table, so ids of 150 or more raise
            ``IndexError`` and negative ids wrap around to the table's end.
    """
    # Make sure the input is a 2-D array.
    if segmentation_map.ndim > 2:
        segmentation_map = segmentation_map.squeeze()

    # Build the random colour table from a private, fixed-seed generator:
    # colours are identical on every call and the global NumPy random state
    # is left untouched.
    rng = np.random.RandomState(42)
    colors = rng.randint(0, 255, (150, 3), dtype=np.uint8)
    colors[0] = [0, 0, 0]  # class 0 (background) is drawn in black

    # Look up one colour per pixel.
    visualization = colors[segmentation_map.astype(np.int32)]

    plt.figure(figsize=(5, 4))
    plt.imshow(visualization.astype(np.uint8))
    plt.axis('off')
    plt.show()

# Running statistics accumulated over the validation set.
acc_meter = AverageMeter()
intersection_meter = AverageMeter()
union_meter = AverageMeter()
time_meter = AverageMeter()

for i, one_batch in enumerate(loader_val):
    # The loader uses batch_size=1, so index 0 selects the only sample.
    original_image, seg_label = (
        one_batch[key][0] for key in ('img_ori', 'seg_label')
    )
    resized_images = one_batch['img_data']

    # Show the ground truth and the input once, for the first sample only.
    if i == 0:
        for sample in (seg_label, original_image):
            visualize_image(sample)

    pred = traditional_canny_edge_segmentation(
        as_numpy(original_image), cfg.DATASET.num_class
    )
    pred_np = as_numpy(pred)
    label_np = as_numpy(seg_label)

    # The prediction is displayed for every sample; a leading axis is added
    # before it is handed to visualize_image.
    visualize_image(pred_np[None, ...])

    acc, pix = accuracy(pred_np, label_np)
    intersection, union = intersectionAndUnion(
        pred_np, label_np, cfg.DATASET.num_class
    )

    # Accuracy is weighted by the number of labelled pixels in the sample.
    acc_meter.update(acc, pix)
    intersection_meter.update(intersection)
    union_meter.update(union)

# Per-class IoU from the summed intersections and unions.
iou = intersection_meter.sum / (union_meter.sum + 1e-10)
for i, _iou in enumerate(iou):
    print('class [{}], IoU: {:.4f}'.format(i, _iou))

print('[Eval Summary]:')
print('Mean IoU: {:.4f}, Accuracy: {:.2f}%'.format(iou.mean(), acc_meter.average() * 100))

