In [1]:
import os
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F

In [2]:
NUM_CLASSES = 2
BATCH_SIZE = 16
EPOCHS = 20
LR = 5e-5
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(42)

<torch._C.Generator at 0x7f9e74d80750>

In [3]:
class SegDataset(Dataset):
    def __init__(self, img_dir, mask_dir):
        self.imgs = sorted(os.listdir(img_dir))
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = A.Compose([
            A.Resize(512, 512),
            A.HorizontalFlip(p=0.5),
            A.Normalize(),
            ToTensorV2()
        ])

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img = Image.open(os.path.join(self.img_dir, self.imgs[idx])).convert("RGB")
        mask_pil = Image.open(os.path.join(self.mask_dir, self.imgs[idx]))
        mask_np = np.array(mask_pil)

        # If mask is RGB, convert to single-channel labels
        if mask_np.ndim == 3 and mask_np.shape[2] == 3:
            # If channels identical (e.g., saved gray as RGB), take first channel
            if np.all(mask_np[:, :, 0] == mask_np[:, :, 1]) and np.all(mask_np[:, :, 1] == mask_np[:, :, 2]):
                mask_np = mask_np[:, :, 0]
            else:
                # Colored segmentation map: map each unique RGB color to a class index
                h, w, _ = mask_np.shape
                flat = mask_np.reshape(-1, 3)
                colors, inverse = np.unique(flat, axis=0, return_inverse=True)
                label_mask = inverse.reshape(h, w).astype(np.uint8)
                mask_np = label_mask

        # Now mask_np should be 2D (H, W). Handle common encodings like 0/255 for binary masks
        if mask_np.ndim == 2:
            if mask_np.max() > (NUM_CLASSES - 1):
                if NUM_CLASSES == 2:
                    # Map any non-zero value (e.g., 255) to 1
                    mask_np = (mask_np > 0).astype(np.uint8)
                else:
                    # For multi-class masks with unexpected label values, remap uniques to 0..K-1
                    uniques = np.unique(mask_np)
                    mapping = {int(v): i for i, v in enumerate(uniques)}
                    vec_map = np.vectorize(lambda x: mapping[int(x)])
                    mask_np = vec_map(mask_np).astype(np.uint8)

        augmented = self.transform(image=np.array(img), mask=mask_np)
        return augmented["image"], augmented["mask"].long()


In [4]:
dataset = SegDataset("dataset/images", "dataset/masks")
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_ds, val_ds = torch.utils.data.random_split(dataset, [train_size, val_size])

In [5]:
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=1)

In [6]:
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b0-finetuned-ade-512-512",
    num_labels=NUM_CLASSES,
    ignore_mismatched_sizes=True
).to(DEVICE)

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b0-finetuned-ade-512-512 and are newly initialized because the shapes did not match:
- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([2]) in the model instantiated
- decode_head.classifier.weight: found shape torch.Size([150, 256, 1, 1]) in the checkpoint and torch.Size([2, 256, 1, 1]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
model

SegformerForSemanticSegmentation(
  (segformer): SegformerModel(
    (encoder): SegformerEncoder(
      (patch_embeddings): ModuleList(
        (0): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(3, 32, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
          (layer_norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        )
        (1): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        )
        (2): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(64, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((160,), eps=1e-05, elementwise_affine=True)
        )
        (3): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(160, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  

In [8]:
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

In [9]:
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    for x, y in tqdm(train_loader):
        x, y = x.to(DEVICE), y.to(DEVICE)
        out = model(pixel_values=x, labels=y)
        loss = out.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item()

    print(f"Epoch {epoch}: Train Loss = {total_loss/len(train_loader):.4f}")

  0%|          | 0/30 [00:00<?, ?it/s]

100%|██████████| 30/30 [00:09<00:00,  3.05it/s]
100%|██████████| 30/30 [00:09<00:00,  3.05it/s]


Epoch 0: Train Loss = 0.6438


100%|██████████| 30/30 [00:07<00:00,  4.28it/s]
100%|██████████| 30/30 [00:07<00:00,  4.28it/s]


Epoch 1: Train Loss = 0.5697


100%|██████████| 30/30 [00:07<00:00,  4.27it/s]
100%|██████████| 30/30 [00:07<00:00,  4.27it/s]


Epoch 2: Train Loss = 0.5205


100%|██████████| 30/30 [00:07<00:00,  4.22it/s]
100%|██████████| 30/30 [00:07<00:00,  4.22it/s]


Epoch 3: Train Loss = 0.4796


100%|██████████| 30/30 [00:07<00:00,  4.20it/s]
100%|██████████| 30/30 [00:07<00:00,  4.20it/s]


Epoch 4: Train Loss = 0.4473


100%|██████████| 30/30 [00:07<00:00,  4.25it/s]



Epoch 5: Train Loss = 0.4309


100%|██████████| 30/30 [00:07<00:00,  4.24it/s]
100%|██████████| 30/30 [00:07<00:00,  4.24it/s]


Epoch 6: Train Loss = 0.4041


100%|██████████| 30/30 [00:06<00:00,  4.29it/s]
100%|██████████| 30/30 [00:06<00:00,  4.29it/s]


Epoch 7: Train Loss = 0.3918


100%|██████████| 30/30 [00:06<00:00,  4.34it/s]
100%|██████████| 30/30 [00:06<00:00,  4.34it/s]


Epoch 8: Train Loss = 0.3830


100%|██████████| 30/30 [00:06<00:00,  4.37it/s]
100%|██████████| 30/30 [00:06<00:00,  4.37it/s]


Epoch 9: Train Loss = 0.3709


100%|██████████| 30/30 [00:06<00:00,  4.29it/s]
100%|██████████| 30/30 [00:06<00:00,  4.29it/s]


Epoch 10: Train Loss = 0.3609


100%|██████████| 30/30 [00:06<00:00,  4.32it/s]
100%|██████████| 30/30 [00:06<00:00,  4.32it/s]


Epoch 11: Train Loss = 0.3499


100%|██████████| 30/30 [00:06<00:00,  4.30it/s]
100%|██████████| 30/30 [00:06<00:00,  4.30it/s]


Epoch 12: Train Loss = 0.3464


100%|██████████| 30/30 [00:07<00:00,  4.25it/s]
100%|██████████| 30/30 [00:07<00:00,  4.25it/s]


Epoch 13: Train Loss = 0.3363


100%|██████████| 30/30 [00:07<00:00,  3.87it/s]
100%|██████████| 30/30 [00:07<00:00,  3.87it/s]


Epoch 14: Train Loss = 0.3285


100%|██████████| 30/30 [00:07<00:00,  3.89it/s]
100%|██████████| 30/30 [00:07<00:00,  3.89it/s]


Epoch 15: Train Loss = 0.3262


100%|██████████| 30/30 [00:06<00:00,  4.29it/s]
100%|██████████| 30/30 [00:06<00:00,  4.29it/s]


Epoch 16: Train Loss = 0.3183


100%|██████████| 30/30 [00:06<00:00,  4.34it/s]
100%|██████████| 30/30 [00:06<00:00,  4.34it/s]


Epoch 17: Train Loss = 0.3081


100%|██████████| 30/30 [00:06<00:00,  4.30it/s]
100%|██████████| 30/30 [00:06<00:00,  4.30it/s]


Epoch 18: Train Loss = 0.3072


100%|██████████| 30/30 [00:06<00:00,  4.34it/s]

Epoch 19: Train Loss = 0.3008





In [10]:
model.eval()
with torch.no_grad():
    total_correct = 0
    total_pixels = 0
    for x, y in tqdm(val_loader):
        x, y = x.to(DEVICE), y.to(DEVICE)
        out = model(pixel_values=x)
        upsampled_logits = F.interpolate(
            out.logits,
            size=y.shape[-2:],  # (H, W)
            mode="bilinear",
            align_corners=False
        )
        preds = torch.argmax(upsampled_logits, dim=1)
        
        total_correct += (preds == y).sum().item()
        total_pixels += torch.numel(y)

    print(f"Validation Accuracy = {total_correct/total_pixels:.4f}")

100%|██████████| 120/120 [00:02<00:00, 53.21it/s]

Validation Accuracy = 0.8442





In [11]:
from torchmetrics import JaccardIndex

model.eval()
iou_metric = JaccardIndex(
    task="binary",
).to(DEVICE)

with torch.no_grad():
    for x, mask in val_loader:
        x = x.to(DEVICE)
        mask = mask.to(DEVICE)

        logits = model(pixel_values=x).logits
        logits = F.interpolate(
            logits,
            size=mask.shape[-2:],  # (H, W)
            mode="bilinear",
            align_corners=False
        )

        preds = logits.argmax(dim=1)
        iou_metric.update(preds, mask)

mean_iou = iou_metric.compute()
print(f"Mean IoU: {mean_iou:.4f}")

Mean IoU: 0.7245
