In [None]:
pip install transformers datasets torch torchvision segmentation-models-pytorch albumentations

In [None]:
import os
import torch
import torchvision.transforms as T
import albumentations as A
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
import torch.nn as nn
import torch.optim as optim

# Set dataset directory
DATASET_DIR = "/path/to/your/dataset"  # UPDATE THIS

# Define paths
train_img_dir = os.path.join(DATASET_DIR, "train")
train_mask_dir = os.path.join(DATASET_DIR, "train_label")
val_img_dir = os.path.join(DATASET_DIR, "validation")
val_mask_dir = os.path.join(DATASET_DIR, "validation_label")

# SegFormer processor (used here as the source of the mean/std statistics the
# pretrained checkpoint expects)
processor = SegformerImageProcessor(do_resize=True, size=512, do_normalize=True)
IMAGE_MEAN = processor.image_mean
IMAGE_STD = processor.image_std

# Augmentations
# Images reach these pipelines as float arrays already scaled to [0, 1]
# (SegmentationDataset divides by 255), hence max_pixel_value=1.0.
# Normalisation goes last so the photometric augmentation operates on the
# [0, 1] range it expects.
train_transform = A.Compose([
    A.Resize(512, 512),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.3),
    A.Normalize(mean=IMAGE_MEAN, std=IMAGE_STD, max_pixel_value=1.0),
])

val_transform = A.Compose([
    A.Resize(512, 512),
    A.Normalize(mean=IMAGE_MEAN, std=IMAGE_STD, max_pixel_value=1.0),
])

# Custom dataset
class SegmentationDataset(Dataset):
    """Paired image / mask dataset for semantic segmentation.

    Images and masks are matched by position after sorting the file names of
    ``img_dir`` and ``mask_dir``, so both directories must contain the same
    number of files and their names must sort into the same order.

    Args:
        img_dir: Directory containing the input images.
        mask_dir: Directory containing the label masks.
        transform: Optional callable invoked as
            ``transform(image=..., mask=...)`` that returns a mapping with
            ``"image"`` and ``"mask"`` entries.

    Raises:
        ValueError: If the two directories hold a different number of files.
    """

    def __init__(self, img_dir: str, mask_dir: str, transform=None) -> None:
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = sorted(os.listdir(img_dir))
        self.masks = sorted(os.listdir(mask_dir))

        # Pairing is purely positional, so a count mismatch would either raise
        # an IndexError in the middle of an epoch or silently train on
        # mismatched image/mask pairs. Fail early with a clear message instead.
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images in {img_dir!r} but "
                f"{len(self.masks)} masks in {mask_dir!r}; "
                "every image needs exactly one mask."
            )

    def __len__(self) -> int:
        """Return the number of image/mask pairs."""
        return len(self.images)

    def __getitem__(self, idx: int):
        """Load, augment and return the ``idx``-th ``(image, mask)`` pair.

        Returns:
            A tuple ``(image, mask)``: ``image`` is a float tensor in
            channel-first layout whose raw pixel values were divided by 255
            before the transform, and ``mask`` is a ``torch.long`` tensor
            built from the first channel of the mask file.
        """
        img_path = os.path.join(self.img_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.masks[idx])

        # Load image and mask
        image = read_image(img_path).float() / 255.0  # scale to [0, 1]
        mask = read_image(mask_path)[0]  # keep only the first channel

        # The transform is called with channel-last numpy arrays
        image_np = image.permute(1, 2, 0).numpy()
        mask_np = mask.numpy()

        # Apply augmentations jointly so geometric changes stay aligned
        if self.transform:
            augmented = self.transform(image=image_np, mask=mask_np)
            image_np, mask_np = augmented["image"], augmented["mask"]

        # Back to channel-first tensors
        image = torch.tensor(image_np).permute(2, 0, 1)
        mask = torch.tensor(mask_np, dtype=torch.long)

        return image, mask

# Build the train / validation datasets from the configured directories
train_dataset = SegmentationDataset(
    img_dir=train_img_dir,
    mask_dir=train_mask_dir,
    transform=train_transform,
)
val_dataset = SegmentationDataset(
    img_dir=val_img_dir,
    mask_dir=val_mask_dir,
    transform=val_transform,
)

# Batch the datasets; only the training split is shuffled
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=4,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=4,
)

# Start from the ADE20K-finetuned SegFormer-B0 checkpoint.
# ignore_mismatched_sizes lets the checkpoint load even though its
# classification head was trained for a different number of labels.
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b0-finetuned-ade-512-512",
    num_labels=5,  # 5 classes (0,1,2,3,4)
    ignore_mismatched_sizes=True,
)

# Prefer the GPU when one is available
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = model.to(device)

# Per-pixel cross-entropy loss, optimised with AdamW
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(params=model.parameters(), lr=5e-5)

# Training loop
EPOCHS = 20
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0

    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()

        # SegFormer's decode head emits logits at 1/4 of the input resolution,
        # i.e. (B, num_classes, H/4, W/4). Upsample them to the mask size
        # before the loss; otherwise CrossEntropyLoss fails on the spatial
        # shape mismatch between logits and target.
        logits = model(pixel_values=images).logits
        logits = nn.functional.interpolate(
            logits,
            size=masks.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        loss = criterion(logits, masks)  # CE Loss

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # max(..., 1) avoids a ZeroDivisionError when the loader yields no batches
    avg_loss = total_loss / max(len(train_loader), 1)
    print(f"Epoch {epoch+1}/{EPOCHS} - Loss: {avg_loss:.4f}")

# Save model (weights only; reload with load_state_dict on the same architecture)
torch.save(model.state_dict(), "segformer_model.pth")
print("Model saved successfully!")
