In [1]:
import os
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.metrics import f1_score
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
import torch.nn.functional as F

In [2]:
def patchify_image(img, mask, patch_size=512, overlap=128):
    """Cut an image and its mask into aligned square patches.

    Windows are laid out on a regular grid with stride
    ``patch_size - overlap``. If the grid does not end exactly on the
    bottom/right border, one extra window aligned to that border is added so
    that every pixel is covered by at least one patch (previously the trailing
    strip was silently dropped).

    Args:
        img: Array whose first two axes are (height, width).
        mask: Array with the same height and width as ``img``.
        patch_size: Side length of each square patch, in pixels.
        overlap: Number of pixels shared by neighbouring grid windows.

    Returns:
        List of ``(img_patch, mask_patch)`` tuples in row-major order. The
        patches are slices (views) of the inputs. The list is empty when the
        image is smaller than ``patch_size`` along either axis.

    Raises:
        ValueError: If ``patch_size`` is not positive, if ``overlap`` is not
            smaller than ``patch_size``, or if ``img`` and ``mask`` differ in
            height/width.
    """
    step = patch_size - overlap
    if patch_size <= 0 or step <= 0:
        raise ValueError(
            f"patch_size must be positive and larger than overlap "
            f"(got patch_size={patch_size}, overlap={overlap})"
        )

    H, W = img.shape[:2]
    if tuple(mask.shape[:2]) != (H, W):
        raise ValueError(
            f"img and mask must share height/width "
            f"(got {tuple(img.shape[:2])} vs {tuple(mask.shape[:2])})"
        )

    def _window_starts(length):
        # Too small to hold even one full patch: nothing to emit on this axis.
        if length < patch_size:
            return []
        starts = list(range(0, length - patch_size + 1, step))
        last_start = length - patch_size
        # Add a border-aligned window when the regular grid stops short.
        if starts[-1] != last_start:
            starts.append(last_start)
        return starts

    x_starts = _window_starts(W)
    return [
        (
            img[y:y + patch_size, x:x + patch_size],
            mask[y:y + patch_size, x:x + patch_size],
        )
        for y in _window_starts(H)
        for x in x_starts
    ]


class XBDDataset(Dataset):
    """In-memory patch dataset built from post-disaster images and target masks.

    The constructor reads every ``*_post_disaster.png`` under
    ``<root_dir>/images``, pairs it with ``<name>_target.png`` under
    ``<root_dir>/targets``, cuts both into patches with ``patchify_image`` and
    keeps only the patches whose mask contains at least one non-zero pixel.

    All kept patches are held in ``self.samples`` for the lifetime of the
    dataset, so memory use grows with the number of patches.

    Args:
        root_dir: Directory containing the ``images`` and ``targets`` folders.
        patch_size: Side length of the square patches.
        overlap: Overlap in pixels between neighbouring patches.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``"image"`` and ``"mask"`` entries
            (the albumentations calling convention).
    """

    def __init__(self, root_dir, patch_size=512, overlap=128, transform=None):
        self.img_dir = os.path.join(root_dir, "images")
        self.mask_dir = os.path.join(root_dir, "targets")
        self.patch_size = patch_size
        self.overlap = overlap
        self.transform = transform
        self.samples = []

        # Sorted so the sample order does not depend on the filesystem's
        # directory enumeration order.
        img_files = sorted(
            f for f in os.listdir(self.img_dir) if f.endswith("_post_disaster.png")
        )
        print(f" Creating patch dataset from: {root_dir}")

        for f in tqdm(img_files):
            img_path = os.path.join(self.img_dir, f)
            mask_path = os.path.join(self.mask_dir, f.replace(".png", "_target.png"))
            # Images without a matching target mask are skipped.
            if not os.path.exists(mask_path):
                continue

            # Context managers close the underlying file handles promptly.
            with Image.open(img_path) as img_file:
                img = np.array(img_file.convert("RGB"))
            with Image.open(mask_path) as mask_file:
                mask = np.array(mask_file.convert("L"))

            for img_patch, mask_patch in patchify_image(img, mask, patch_size, overlap):
                # Drop patches whose mask is entirely zero.
                if np.any(mask_patch > 0):
                    self.samples.append((img_patch, mask_patch))

        print(f" Total usable patches: {len(self.samples)}")

    def __len__(self):
        """Return the number of stored patches."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for patch ``idx`` as torch tensors.

        When the transform already returns tensors they are passed through
        unchanged (apart from casting the mask to ``int64``). If the image or
        mask is still a NumPy array (no transform, or a transform without a
        tensor conversion step), it is converted here: the image is moved from
        HWC to CHW layout with its values and dtype preserved.
        """
        img, mask = self.samples[idx]
        if self.transform:
            augmented = self.transform(image=img, mask=mask)
            img = augmented["image"]
            mask = augmented["mask"]

        # Stored patches are array slices, so make them contiguous before
        # handing them to torch.
        if isinstance(img, np.ndarray):
            img = torch.from_numpy(np.ascontiguousarray(img.transpose(2, 0, 1)))
        if isinstance(mask, np.ndarray):
            mask = torch.from_numpy(np.ascontiguousarray(mask))
        return img, mask.long()

In [8]:
# Normalisation statistics shared by the train and validation pipelines
# (and needed again to undo the normalisation when displaying images).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Gaussian noise strength expressed as a pixel variance range on 0-255 images.
GAUSS_NOISE_VAR_LIMIT = (10.0, 50.0)

try:
    # Recent albumentations releases take the noise level as `std_range`, a
    # standard deviation given as a fraction of the maximum pixel value;
    # passing `var_limit` there only produces a warning.
    gauss_noise = A.GaussNoise(
        std_range=tuple(float(np.sqrt(v)) / 255.0 for v in GAUSS_NOISE_VAR_LIMIT)
    )
except TypeError:
    # Older releases do not know `std_range` and expect the variance range.
    gauss_noise = A.GaussNoise(var_limit=GAUSS_NOISE_VAR_LIMIT)

train_transforms = A.Compose([
    A.RandomRotate90(),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2),
    gauss_noise,
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])

# Validation uses no augmentation, only normalisation and tensor conversion.
val_transforms = A.Compose([
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])

# Dataset location: set the XVIEW2_ROOT environment variable to run on another
# machine; the fallback is the original author's local path.
DATA_ROOT = os.environ.get(
    "XVIEW2_ROOT",
    r"C:\Users\LENOVO\OneDrive - Plaksha University\xView2_PNG",
)
train_root = os.path.join(DATA_ROOT, "train")
val_root = os.path.join(DATA_ROOT, "hold")

train_dataset = XBDDataset(train_root, patch_size=256, overlap=64, transform=train_transforms)
val_dataset   = XBDDataset(val_root,   patch_size=256, overlap=64, transform=val_transforms)

# Pinned host memory only helps host-to-GPU copies; requesting it without CUDA
# just triggers a warning.
PIN_MEMORY = torch.cuda.is_available()

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_dataset,   batch_size=2, shuffle=False, num_workers=0, pin_memory=PIN_MEMORY)

  A.GaussNoise(var_limit=(10.0, 50.0)),


 Creating patch dataset from: C:\Users\LENOVO\OneDrive - Plaksha University\xView2_PNG\train


100%|██████████| 2799/2799 [01:43<00:00, 27.10it/s]


 Total usable patches: 31410
 Creating patch dataset from: C:\Users\LENOVO\OneDrive - Plaksha University\xView2_PNG\hold


100%|██████████| 933/933 [00:55<00:00, 16.73it/s]


 Total usable patches: 10473


In [9]:
# Number of output channels (one logit map per class) produced by the model.
NUM_CLASSES = 5

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# U-Net decoder on top of the "mit_b2" encoder (described by the original
# author as the SegFormer-B2 transformer backbone), initialised from ImageNet
# weights. `activation=None` leaves the outputs as raw logits.
unet_config = dict(
    encoder_name="mit_b2",
    encoder_weights="imagenet",
    in_channels=3,
    classes=NUM_CLASSES,
    activation=None,
)
model = smp.Unet(**unet_config).to(device)
print(f" Model initialized on {device}")


def dice_loss(pred, target, smooth=1e-6):
    """Soft multi-class Dice loss computed from raw logits.

    Args:
        pred: Logits with the class axis at dim 1 and two trailing spatial
            axes, i.e. (batch, classes, H, W).
        target: Integer class indices of shape (batch, H, W).
        smooth: Constant added to numerator and denominator so that classes
            absent from both prediction and target do not divide by zero.

    Returns:
        Scalar tensor: one minus the Dice coefficient averaged over every
        (sample, class) pair.
    """
    spatial_axes = (2, 3)
    num_classes = pred.shape[1]

    probs = F.softmax(pred, dim=1)
    # one_hot appends the class axis last; move it next to the batch axis.
    onehot = F.one_hot(target, num_classes=num_classes).movedim(-1, 1).float()

    overlap = (probs * onehot).sum(dim=spatial_axes)
    total_mass = probs.sum(dim=spatial_axes) + onehot.sum(dim=spatial_axes)
    dice_score = (2. * overlap + smooth) / (total_mass + smooth)
    return 1 - dice_score.mean()


def weighted_ce_loss(pred, target, weights):
    """Class-weighted cross-entropy on raw logits.

    Args:
        pred: Logits with the class axis at dim 1.
        target: Integer class indices.
        weights: 1-D tensor with one weight per class; it is moved to the
            device of ``pred`` before use.

    Returns:
        Scalar loss tensor.
    """
    class_weights = weights.to(pred.device)
    return F.cross_entropy(input=pred, target=target, weight=class_weights)


def combined_loss(pred, target, weights, alpha=1.0, beta=0.5):
    """Weighted sum of class-weighted cross-entropy and soft Dice loss.

    Args:
        pred: Raw logits passed to both loss terms.
        target: Integer class indices passed to both loss terms.
        weights: Per-class weights used by the cross-entropy term only.
        alpha: Multiplier of the cross-entropy term.
        beta: Multiplier of the Dice term.

    Returns:
        Scalar tensor ``alpha * cross_entropy + beta * dice``.
    """
    ce_term = alpha * weighted_ce_loss(pred, target, weights)
    dice_term = beta * dice_loss(pred, target)
    return ce_term + dice_term


# AdamW over every model parameter, with decoupled weight decay.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=1e-4,
    weight_decay=1e-4,
)

# Cosine learning-rate schedule whose half-period is 20 scheduler steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer=optimizer,
    T_max=20,
)


 Model initialized on cpu


In [10]:
# Per-class loss weights, indexed by class id; class 0 is weighted far below
# the others.
weights = torch.tensor([0.2, 1.0, 1.2, 1.3, 1.3])

# Rescale so the weights average to 1 (their sum equals the class count).
weights = weights.div(weights.sum()).mul(len(weights))
print("Class weights:", weights)

Class weights: tensor([0.2000, 1.0000, 1.2000, 1.3000, 1.3000])


In [11]:
def compute_iou(preds, targets, num_classes=5):
    """Per-class intersection-over-union from logits and integer labels.

    Args:
        preds: Score tensor with the class axis at dim 1; the predicted class
            is the argmax over that axis.
        targets: Tensor of integer class labels matching the argmax shape.
        num_classes: Class ids ``0 .. num_classes - 1`` are evaluated.

    Returns:
        List with one IoU value per class. A class that appears in neither
        the predictions nor the targets yields ``np.nan``.
    """
    predicted = preds.argmax(dim=1).cpu().numpy()
    actual = targets.cpu().numpy()

    per_class = []
    for class_id in range(num_classes):
        pred_hits = predicted == class_id
        true_hits = actual == class_id
        union_px = (pred_hits | true_hits).sum()
        if union_px == 0:
            # Class absent everywhere: IoU is undefined.
            per_class.append(np.nan)
        else:
            per_class.append((pred_hits & true_hits).sum() / union_px)
    return per_class


def compute_weighted_f1(preds, targets, num_classes=5):
    """Support-weighted F1 plus the per-class F1 scores.

    Args:
        preds: Score tensor with the class axis at dim 1; the predicted class
            is the argmax over that axis.
        targets: Tensor of integer class labels matching the argmax shape.
        num_classes: Class ids ``0 .. num_classes - 1`` are evaluated.

    Returns:
        Tuple ``(weighted_f1, f1_per_class)`` where ``f1_per_class`` is the
        array returned by ``f1_score(..., average=None)`` and ``weighted_f1``
        averages it using each class's number of target pixels as weight.
    """
    class_ids = list(range(num_classes))
    y_pred = preds.argmax(dim=1).cpu().numpy().reshape(-1)
    y_true = targets.cpu().numpy().reshape(-1)

    f1_per_class = f1_score(y_true, y_pred, labels=class_ids, average=None)

    # Support = number of ground-truth pixels belonging to each class.
    support = np.array([np.count_nonzero(y_true == class_id) for class_id in class_ids])
    weighted_f1 = (f1_per_class * support).sum() / support.sum()
    return weighted_f1, f1_per_class


def train_one_epoch(model, dataloader, optimizer, weights):
    """Run one optimisation pass over ``dataloader``.

    Batches are moved to the module-level ``device``; the objective is
    ``combined_loss`` with the given class ``weights``.

    Args:
        model: Network to train; switched to training mode here.
        dataloader: Iterable yielding ``(images, masks)`` batches.
        optimizer: Optimizer stepped once per batch.
        weights: Per-class weights forwarded to ``combined_loss``.

    Returns:
        Mean of the per-batch loss values over the epoch.
    """
    model.train()
    loss_sum = 0.0
    progress = tqdm(dataloader, desc="Training", leave=False)
    for batch_imgs, batch_masks in progress:
        batch_imgs = batch_imgs.to(device)
        batch_masks = batch_masks.to(device)

        optimizer.zero_grad()
        batch_loss = combined_loss(model(batch_imgs), batch_masks, weights)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
    return loss_sum / len(dataloader)


def validate(model, dataloader, weights):
    """Evaluate ``model`` on ``dataloader`` and return loss and pixel metrics.

    Instead of keeping every batch's full logit tensor in memory and
    concatenating them at the end, predictions are folded batch by batch into
    a (classes x classes) confusion matrix. Memory use is therefore
    independent of the dataset size, and the metrics derived from the matrix
    are the same pixel-level F1 / IoU definitions:

    * F1 per class  = 2*TP / (2*TP + FP + FN), 0 when the denominator is 0
    * IoU per class = TP / (TP + FP + FN), NaN when the class is absent from
      both predictions and targets
    * weighted F1   = F1 averaged with the number of target pixels per class
      as weights

    Batches are moved to the module-level ``device``; the number of classes
    is taken from the channel axis (dim 1) of the model output.

    Args:
        model: Network to evaluate; switched to eval mode here.
        dataloader: Iterable yielding ``(images, masks)`` batches with integer
            class-index masks.
        weights: Per-class weights forwarded to ``combined_loss``.

    Returns:
        Dict with keys ``"val_loss"`` (mean per-batch loss), ``"weighted_f1"``
        (float), ``"f1_per_class"`` (NumPy array) and ``"iou_per_class"``
        (list of floats / NaN).

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    val_loss = 0.0
    confusion = None  # rows: true class, columns: predicted class

    with torch.no_grad():
        for imgs, masks in tqdm(dataloader, desc="Validating", leave=False):
            imgs, masks = imgs.to(device), masks.to(device)
            outputs = model(imgs)
            loss = combined_loss(outputs, masks, weights)
            val_loss += loss.item()

            num_classes = outputs.shape[1]
            if confusion is None:
                confusion = torch.zeros(num_classes, num_classes, dtype=torch.int64)

            # Encode each (true, predicted) pair as a single index so that one
            # bincount call produces the whole confusion matrix for the batch.
            pred_classes = outputs.argmax(dim=1)
            pair_index = masks.reshape(-1) * num_classes + pred_classes.reshape(-1)
            batch_confusion = torch.bincount(pair_index, minlength=num_classes ** 2)
            confusion += batch_confusion.reshape(num_classes, num_classes).cpu()

    if confusion is None:
        raise ValueError("validate() received a dataloader with no batches")

    cm = confusion.numpy().astype(np.float64)
    true_pos = np.diag(cm)
    support = cm.sum(axis=1)          # target pixels per class (TP + FN)
    predicted = cm.sum(axis=0)        # predicted pixels per class (TP + FP)

    f1_denominator = support + predicted          # 2*TP + FP + FN
    f1_per_class = np.divide(
        2.0 * true_pos,
        f1_denominator,
        out=np.zeros_like(true_pos),
        where=f1_denominator > 0,
    )
    weighted_f1 = np.sum(f1_per_class * support) / np.sum(support)

    union = f1_denominator - true_pos             # TP + FP + FN
    ious = [np.nan if u == 0 else tp / u for tp, u in zip(true_pos, union)]

    return {
        "val_loss": val_loss / len(dataloader),
        "weighted_f1": weighted_f1,
        "f1_per_class": f1_per_class,
        "iou_per_class": ious
    }

In [13]:
# Training driver: one train pass and one validation pass per epoch, keeping
# the checkpoint with the highest validation weighted F1.
NUM_EPOCHS = 10
# Best validation weighted F1 seen so far; a checkpoint is written only when
# an epoch's score is strictly greater than this value.
best_f1 = 0
save_path = "best_model.pth"

# NOTE(review): the scheduler is created with T_max=20 while this loop runs
# NUM_EPOCHS = 10 steps, so only the first half of the cosine period is used
# — confirm this is intended.
for epoch in range(NUM_EPOCHS):
    print(f"\n Epoch {epoch+1}/{NUM_EPOCHS}")
    train_loss = train_one_epoch(model, train_loader, optimizer, weights)
    metrics = validate(model, val_loader, weights)
    # The learning-rate schedule advances once per epoch, after validation.
    scheduler.step()

    print(f"Train Loss: {train_loss:.4f} | Val Loss: {metrics['val_loss']:.4f}")
    print(f"Weighted F1: {metrics['weighted_f1']:.4f}")
    print(f"IoU: {np.round(metrics['iou_per_class'], 3)}")
    print(f"F1:  {np.round(metrics['f1_per_class'], 3)}")

    # Only the model weights (state_dict) are saved, not optimizer/scheduler state.
    if metrics["weighted_f1"] > best_f1:
        best_f1 = metrics["weighted_f1"]
        torch.save(model.state_dict(), save_path)
        print(f" Saved new best model with Weighted F1 = {best_f1:.4f}")


📘 Epoch 1/50


                                                                  

KeyboardInterrupt: 

In [None]:
# Restore the best checkpoint. `map_location` lets a checkpoint that was saved
# from a GPU run be loaded when only the CPU is available (and vice versa).
model.load_state_dict(torch.load("best_model.pth", map_location=device))
model.eval()

# Predict on the first validation batch.
imgs, masks = next(iter(val_loader))
imgs, masks = imgs.to(device), masks.to(device)
with torch.no_grad():
    preds = model(imgs)
pred_classes = torch.argmax(preds, dim=1)

i = 0  # index of the sample within the batch to display

# The validation pipeline normalises images with these statistics; undo that
# so the picture is shown in its real colours instead of being clipped.
norm_mean = torch.tensor((0.485, 0.456, 0.406)).view(3, 1, 1)
norm_std = torch.tensor((0.229, 0.224, 0.225)).view(3, 1, 1)
image_rgb = (imgs[i].cpu() * norm_std + norm_mean).clamp(0, 1).permute(1, 2, 0).numpy()

# Fix the colour scale to the full class range so that a given class gets the
# same colour in the ground-truth and prediction panels.
max_class_id = preds.shape[1] - 1

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
axes[0].imshow(image_rgb)
axes[0].set_title("Image")
axes[1].imshow(masks[i].cpu().numpy(), vmin=0, vmax=max_class_id)
axes[1].set_title("Ground Truth")
axes[2].imshow(pred_classes[i].cpu().numpy(), vmin=0, vmax=max_class_id)
axes[2].set_title("Prediction")
plt.show()