In [2]:
import os

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from transformers import (
    SegformerFeatureExtractor,
    SegformerForSemanticSegmentation,
    SegformerImageProcessor,
)

In [9]:
# === CONFIG ===
IMAGE_DIR = "input_images"        # folder scanned for input images
OUTPUT_DIR = "segmented_output"   # root folder for the per-class mask images
MODEL_NAME = "nvidia/segformer-b0-finetuned-cityscapes-768-768"

# Cityscapes class mapping: predicted class index -> label.
# Each label is also used as the name of that class's output sub-folder.
CITYSCAPES_ID2LABEL = dict(enumerate([
    'road', 'sidewalk', 'building', 'wall', 'fence',
    'pole', 'traffic_light', 'traffic_sign', 'vegetation', 'terrain',
    'sky', 'person', 'rider', 'car', 'truck',
    'bus', 'train', 'motorcycle', 'bicycle',
]))

# === SETUP ===
# Output root first, then one sub-folder per class; exist_ok keeps re-runs harmless.
_output_dirs = [OUTPUT_DIR] + [
    os.path.join(OUTPUT_DIR, label) for label in CITYSCAPES_ID2LABEL.values()
]
for _output_dir in _output_dirs:
    os.makedirs(_output_dir, exist_ok=True)


In [4]:
# Load model.
# SegformerImageProcessor is the current name for the preprocessing class;
# SegformerFeatureExtractor is deprecated and emits a FutureWarning on creation.
# The variable keeps the name `feature_extractor` because later cells call it.
feature_extractor = SegformerImageProcessor.from_pretrained(MODEL_NAME)
# .eval(): inference only (disables dropout / training-mode behaviour).
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_NAME).eval()

preprocessor_config.json:   0%|          | 0.00/272 [00:00<?, ?B/s]

  return func(*args, **kwargs)


config.json:   0%|          | 0.00/1.68k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/15.0M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/14.9M [00:00<?, ?B/s]

In [7]:
def segment_and_save(image_path, output_dir=None):
    """Segment one image and save a masked copy for every class that was predicted.

    For each class in ``CITYSCAPES_ID2LABEL`` that is predicted for at least one
    pixel, a copy of the RGB image is written. In that copy, every pixel *not*
    assigned to the class is set to black. Files are written to
    ``<output_dir>/<class_name>/<image_stem>_<class_name>.png``.

    Uses the module-level ``feature_extractor`` and ``model`` loaded in an
    earlier cell.

    Args:
        image_path: Path to the input image (any format PIL can open).
        output_dir: Root folder for the per-class sub-folders. Defaults to the
            ``OUTPUT_DIR`` config constant, looked up at call time.

    Returns:
        list[str]: Paths of the PNG files written, in class-index order.
    """
    if output_dir is None:
        output_dir = OUTPUT_DIR

    # The context manager closes the file handle once the pixels are decoded.
    with Image.open(image_path) as src:
        image = src.convert("RGB")
    image_np = np.array(image)  # (H, W, 3) uint8

    inputs = feature_extractor(images=image, return_tensors="pt")
    # Keep inputs on the model's device so moving the model to a GPU just works.
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits  # (1, num_classes, h/4, w/4) of the preprocessed input
        # Upsample the logits (not the labels) to the original resolution before
        # the argmax. PIL's size is (W, H); interpolate expects (H, W).
        upsampled_logits = torch.nn.functional.interpolate(
            logits,
            size=image.size[::-1],
            mode="bilinear",
            align_corners=False,
        )
        predicted = upsampled_logits.argmax(dim=1)[0].cpu().numpy()  # (H, W)

    base_name = os.path.splitext(os.path.basename(image_path))[0]

    saved_paths = []
    for class_idx, class_name in CITYSCAPES_ID2LABEL.items():
        mask = predicted == class_idx  # boolean (H, W)
        if not mask.any():
            continue

        # Keep the target-class pixels and black out everything else.
        masked_img = np.where(mask[..., None], image_np, 0).astype(np.uint8)

        class_dir = os.path.join(output_dir, class_name)
        os.makedirs(class_dir, exist_ok=True)  # don't rely on a setup cell having run
        save_path = os.path.join(class_dir, f"{base_name}_{class_name}.png")
        Image.fromarray(masked_img).save(save_path)
        saved_paths.append(save_path)

    return saved_paths


In [10]:
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

# os.listdir returns entries in arbitrary order. Sort them so runs are
# reproducible, and skip directories whose names happen to end in an extension.
image_files = sorted(
    f for f in os.listdir(IMAGE_DIR)
    if f.lower().endswith(IMAGE_EXTENSIONS)
    and os.path.isfile(os.path.join(IMAGE_DIR, f))
)

if not image_files:
    print(f"No images with extensions {IMAGE_EXTENSIONS} found in: {IMAGE_DIR}")

for filename in tqdm(image_files, desc="Segmenting images"):
    segment_and_save(os.path.join(IMAGE_DIR, filename))

print("Done. Masks saved in:", OUTPUT_DIR)


Segmenting images: 100%|██████████| 1/1 [00:01<00:00,  1.97s/it]

Done. Masks saved in: segmented_output



