In [1]:
from google.colab import drive
import os

# Mount Drive
drive.mount('/content/drive')

# Define dataset paths
base_path = "/content/drive/MyDrive/augmented_dataset"

train_images = os.path.join(base_path, "images/train")
train_masks  = os.path.join(base_path, "masks/train")

val_images   = os.path.join(base_path, "images/valid")
val_masks    = os.path.join(base_path, "masks/valid")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
from PIL import Image
import numpy as np
import os

def get_unique_classes(mask_dir):
    unique_classes = set()
    for mask_file in os.listdir(mask_dir):
        mask_path = os.path.join(mask_dir, mask_file)
        # Check if the item is a file before opening it
        if os.path.isfile(mask_path):
            mask = np.array(Image.open(mask_path))
            unique_classes.update(np.unique(mask))
    return sorted(list(unique_classes))

# Detect classes from training masks
classes = get_unique_classes(train_masks)
print("Detected Classes:", classes)

# Optional: Remove ignore class (255)
if 255 in classes:
    classes.remove(255)

NUM_CLASSES = len(classes)
print("Total Classes:", NUM_CLASSES)


Detected Classes: [np.uint8(0), np.uint8(1), np.uint8(2), np.uint8(3), np.uint8(4), np.uint8(5), np.uint8(6), np.uint8(7), np.uint8(8), np.uint8(9), np.uint8(10), np.uint8(11), np.uint8(12), np.uint8(13), np.uint8(14), np.uint8(15), np.uint8(16), np.uint8(17), np.uint8(18), np.uint8(19), np.uint8(20), np.uint8(21), np.uint8(22), np.uint8(23), np.uint8(24), np.uint8(25), np.uint8(26), np.uint8(27), np.uint8(28), np.uint8(29), np.uint8(30), np.uint8(31), np.uint8(32), np.uint8(33), np.uint8(34), np.uint8(35), np.uint8(36), np.uint8(37), np.uint8(38), np.uint8(39), np.uint8(40), np.uint8(41), np.uint8(42), np.uint8(43), np.uint8(44), np.uint8(45), np.uint8(46), np.uint8(47), np.uint8(48), np.uint8(49), np.uint8(50), np.uint8(51), np.uint8(52), np.uint8(53), np.uint8(54), np.uint8(55), np.uint8(56), np.uint8(57), np.uint8(58), np.uint8(59), np.uint8(60), np.uint8(61), np.uint8(62), np.uint8(63), np.uint8(64), np.uint8(65), np.uint8(66), np.uint8(67), np.uint8(68), np.uint8(69), np.uint8(70

In [24]:
from torchvision import transforms
from PIL import Image
import numpy as np
import os
import torch

class SegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, feature_extractor, image_size=(512, 512)): # Added image_size
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.images = sorted(os.listdir(image_dir))
        # Filter out directories in the mask directory
        self.masks = sorted([f for f in os.listdir(mask_dir) if os.path.isfile(os.path.join(mask_dir, f))])
        self.feature_extractor = feature_extractor
        self.image_size = image_size # Store image_size
        # Define a transform to resize images and masks
        self.transform = transforms.Compose([
            transforms.Resize(self.image_size), # Resize to the specified size
            transforms.ToTensor(), # Convert to tensor
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.masks[idx])

        # Load image and mask
        image = Image.open(image_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        # Apply transforms to image and mask
        image = self.transform(image)  # Resize and convert image to tensor
        mask = self.transform(mask)  # Resize and convert mask to tensor
        mask = torch.round(mask).long()  # Round mask to integers after scaling

        # Apply feature extractor to image
        encoded = self.feature_extractor(images=image, return_tensors="pt")
        pixel_values = encoded["pixel_values"].squeeze()

        # Convert mask to tensor [H, W]
        # No longer need this, as it was handled by the transform
        # mask_tensor = torch.tensor(np.array(mask), dtype=torch.long)

        return {
            "pixel_values": pixel_values,
            "labels": mask.squeeze() # Use the resized mask tensor
        }

In [25]:
from torch.utils.data import DataLoader

feature_extractor = SegformerFeatureExtractor(reduce_labels=False)

train_dataset = SegmentationDataset(train_images, train_masks, feature_extractor)
val_dataset = SegmentationDataset(val_images, val_masks, feature_extractor)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4)




In [26]:
from transformers import SegformerForSemanticSegmentation

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b0-finetuned-ade-512-512",
    num_labels=NUM_CLASSES,
    ignore_mismatched_sizes=True
).to(device)


Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b0-finetuned-ade-512-512 and are newly initialized because the shapes did not match:
- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([255]) in the model instantiated
- decode_head.classifier.weight: found shape torch.Size([150, 256, 1, 1]) in the checkpoint and torch.Size([255, 256, 1, 1]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [27]:
from torch import nn, optim
from tqdm import tqdm

optimizer = optim.AdamW(model.parameters(), lr=5e-5)
EPOCHS = 5

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(pixel_values=pixel_values, labels=labels)
        loss = outputs.loss

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1} - Training Loss: {avg_loss:.4f}")


Epoch 1/5: 100%|██████████| 375/375 [04:00<00:00,  1.56it/s]


Epoch 1 - Training Loss: 3.4230


Epoch 2/5: 100%|██████████| 375/375 [03:47<00:00,  1.65it/s]


Epoch 2 - Training Loss: 0.9354


Epoch 3/5: 100%|██████████| 375/375 [03:45<00:00,  1.66it/s]


Epoch 3 - Training Loss: 0.4441


Epoch 4/5: 100%|██████████| 375/375 [03:46<00:00,  1.66it/s]


Epoch 4 - Training Loss: 0.3475


Epoch 5/5: 100%|██████████| 375/375 [03:47<00:00,  1.65it/s]

Epoch 5 - Training Loss: 0.3204





In [28]:
def evaluate(model, dataloader, num_classes):
    model.eval()
    total_correct = 0
    total_pixels = 0
    ious = []

    with torch.no_grad():
        for batch in dataloader:
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(pixel_values=pixel_values)
            preds = torch.argmax(outputs.logits, dim=1)

            total_correct += (preds == labels).sum().item()
            total_pixels += torch.numel(preds)

            for cls in range(num_classes):
                pred_inds = preds == cls
                label_inds = labels == cls
                intersection = (pred_inds & label_inds).sum().item()
                union = (pred_inds | label_inds).sum().item()
                if union != 0:
                    ious.append(intersection / union)

    pixel_acc = total_correct / total_pixels
    mean_iou = sum(ious) / len(ious)

    print(f"Pixel Accuracy: {pixel_acc:.4f}")
    print(f"Mean IoU: {mean_iou:.4f}")
