<div class="alert alert-block alert-info">

----------
---------
# <b> 1. Imports</b> 

--------------
----------------
</div>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import os, csv
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
import matplotlib.pyplot as plt


  from .autonotebook import tqdm as notebook_tqdm


<div class="alert alert-block alert-info">

----------
---------
# <b> 2. Configuration & Paths</b> 

--------------
----------------
</div>

In [None]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)


# Paths 

ROOT = "../data"

# The following absolute path was used on the HPC server and is kept here
# only for reference. It is NOT required to run this notebook locally.
# ROOT = "/home/javid/segmentation_resnet/data"

TRAIN_LIST = os.path.join(ROOT, "train.txt")
VAL_LIST   = os.path.join(ROOT, "val.txt")
TEST_LIST  = os.path.join(ROOT, "test.txt")

IMG_SIZE = 512
BATCH_SIZE = 4
EPOCHS = 35   
NUM_WORKERS = 4
LR = 5e-5


Using device: cuda


<div class="alert alert-block alert-info">

----------
---------
# <b> 3. Model and Image Processor</b> 

--------------
----------------
</div>

In [None]:
processor = SegformerImageProcessor(
    do_resize=True,
    size={"height": 512, "width": 512},
    resample=Image.BILINEAR,
    do_normalize=True
)

model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b2-finetuned-ade-512-512",
    num_labels=1,
    ignore_mismatched_sizes=True
).to(DEVICE)

print("SegFormer-B2 loaded successfully.")


# The warning about newly initialized weights is expected.
# The pretrained checkpoint uses 150 classes (ADE20K),
# while this project performs binary segmentation (1 class).

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b2-finetuned-ade-512-512 and are newly initialized because the shapes did not match:
- decode_head.classifier.weight: found shape torch.Size([150, 768, 1, 1]) in the checkpoint and torch.Size([1, 768, 1, 1]) in the model instantiated
- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([1]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


SegFormer-B2 loaded successfully.


<div class="alert alert-block alert-info">

----------
---------
# <b> 4. Dataset Definition</b> 

--------------
----------------
</div>

In [4]:
class SegFormerDataset(Dataset):
    def __init__(self, list_path):
        with open(list_path, "r") as f:
            self.items = [line.strip().split() for line in f.readlines()]

        self.mask_resize = transforms.Resize((512, 512), interpolation=Image.NEAREST)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        img_path, mask_path = self.items[idx]

        image = Image.open(img_path).convert("RGB")
        mask  = Image.open(mask_path).convert("L")

        mask = self.mask_resize(mask)
        mask = (np.array(mask) > 127).astype(np.float32)
        mask = torch.tensor(mask).unsqueeze(0)

        encoded = processor(images=image, return_tensors="pt")
        img_tensor = encoded["pixel_values"].squeeze(0)

        return img_tensor, mask


<div class="alert alert-block alert-info">

----------
---------
# <b> 5. DataLoader Setup</b> 

--------------
----------------
</div>

In [5]:
train_dataset = SegFormerDataset(TRAIN_LIST)
val_dataset   = SegFormerDataset(VAL_LIST)
test_dataset  = SegFormerDataset(TEST_LIST)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader   = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
test_loader  = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=NUM_WORKERS)

print("Train:", len(train_dataset))
print("Val:", len(val_dataset))
print("Test:", len(test_dataset))


Train: 571
Val: 122
Test: 123


<div class="alert alert-block alert-info">

----------
---------
# <b> 6. Loss Functions</b> 

--------------
----------------
</div>

In [6]:
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, targets):
        probas = torch.sigmoid(logits)
        pt = torch.where(targets == 1, probas, 1 - probas)
        focal_weight = self.alpha * (1 - pt) ** self.gamma
        loss = -focal_weight * torch.log(pt + 1e-8)
        return loss.mean()

bce_loss = FocalLoss(alpha=0.25, gamma=2)


In [7]:
def dice_loss(pred, target, eps=1e-6):
    pred = torch.sigmoid(pred)
    pred = pred.view(-1)
    target = target.view(-1)
    inter = (pred * target).sum()
    return 1 - (2 * inter + eps) / (pred.sum() + target.sum() + eps)


<div class="alert alert-block alert-info">

----------
---------
# <b>7. Evaluation Metrics</b> 

--------------
----------------
</div>

In [8]:
def iou_score(pred, true):
    pred = pred.bool()
    true = true.bool()
    inter = (pred & true).sum().item()
    union = (pred | true).sum().item()
    return inter / union if union > 0 else 1.0

def dice_score(pred, true):
    pred = pred.bool()
    true = true.bool()
    inter = (pred & true).sum().item()
    return (2 * inter) / (pred.sum().item() + true.sum().item() + 1e-6)


<div class="alert alert-block alert-info">

----------
---------
# <b> 8. Optimizer and Learning Rate Scheduler</b> 

--------------
----------------
</div>

In [9]:
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=EPOCHS,
    eta_min=1e-6
)


<div class="alert alert-block alert-info">

----------
---------
# <b> 9. Output Directories</b> 

--------------
----------------
</div>

In [10]:
save_dir = "../outputs/checkpoints/segformer_b2_finetuned"
os.makedirs(save_dir, exist_ok=True)

csv_path = os.path.join(save_dir, "training_history.csv")

with open(csv_path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["epoch", "train_loss", "val_loss", "val_iou", "val_dice"])

print("Saving to:", save_dir)


Saving to: ../outputs/checkpoints/segformer_b2_finetuned


<div class="alert alert-block alert-info">

----------
---------
# <b> 10. Training and Validation Loop</b> 

--------------
----------------
</div>

In [11]:
best_val_loss = float("inf")
no_improve = 0
early_stop_patience = 7

for epoch in range(EPOCHS):
    print(f"\n===== Epoch {epoch+1}/{EPOCHS} =====")

    # Adjust LR
    if epoch == 10:
        for g in optimizer.param_groups:
            g["lr"] = 3e-5
        print("LR ↓ to 3e-5")

    if epoch == 20:
        for g in optimizer.param_groups:
            g["lr"] = 1e-5
        print("LR ↓ to 1e-5")

    model.train()
    train_losses = []

    for imgs, masks in train_loader:
        imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)

        optimizer.zero_grad()
        logits = model(pixel_values=imgs).logits

        logits = F.interpolate(logits, size=(512,512), mode="bilinear", align_corners=False)

        fl = bce_loss(logits, masks)
        dl = dice_loss(logits, masks)
        loss = fl + dl

        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())

    avg_train_loss = np.mean(train_losses)

    # VALIDATION
    model.eval()
    val_losses, val_ious, val_dices = [], [], []

    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)

            logits = model(pixel_values=imgs).logits
            logits = F.interpolate(logits, size=(512,512), mode="bilinear", align_corners=False)

            fl = bce_loss(logits, masks)
            dl = dice_loss(logits, masks)
            vloss = fl + dl
            val_losses.append(vloss.item())

            preds = (torch.sigmoid(logits) > 0.5).cpu()
            masks_cpu = masks.cpu()

            for p, t in zip(preds, masks_cpu):
                val_ious.append(iou_score(p, t))
                val_dices.append(dice_score(p, t))

    avg_val_loss = np.mean(val_losses)
    avg_val_iou  = np.mean(val_ious)
    avg_val_dice = np.mean(val_dices)

    print(f"Train Loss: {avg_train_loss:.4f}")
    print(f"Val Loss:   {avg_val_loss:.4f}")
    print(f"Val IoU:    {avg_val_iou:.4f}")
    print(f"Val Dice:   {avg_val_dice:.4f}")

    scheduler.step()

    # Save best model
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        no_improve = 0
        torch.save(model.state_dict(), os.path.join(save_dir, "best_model.pth"))
        print("Best model saved.")
    else:
        no_improve += 1
        if no_improve >= early_stop_patience:
            print("Early stopping triggered.")
            break

print("\nTraining completed.")



===== Epoch 1/35 =====
Train Loss: 0.6169
Val Loss:   0.5040
Val IoU:    0.5550
Val Dice:   0.6940
Best model saved.

===== Epoch 2/35 =====
Train Loss: 0.4136
Val Loss:   0.3547
Val IoU:    0.6218
Val Dice:   0.7465
Best model saved.

===== Epoch 3/35 =====
Train Loss: 0.3067
Val Loss:   0.2980
Val IoU:    0.6592
Val Dice:   0.7791
Best model saved.

===== Epoch 4/35 =====
Train Loss: 0.2431
Val Loss:   0.2901
Val IoU:    0.6470
Val Dice:   0.7690
Best model saved.

===== Epoch 5/35 =====
Train Loss: 0.2105
Val Loss:   0.2573
Val IoU:    0.6786
Val Dice:   0.7947
Best model saved.

===== Epoch 6/35 =====
Train Loss: 0.1797
Val Loss:   0.2532
Val IoU:    0.6783
Val Dice:   0.7946
Best model saved.

===== Epoch 7/35 =====
Train Loss: 0.1644
Val Loss:   0.2525
Val IoU:    0.6734
Val Dice:   0.7893
Best model saved.

===== Epoch 8/35 =====
Train Loss: 0.1529
Val Loss:   0.2538
Val IoU:    0.6741
Val Dice:   0.7902

===== Epoch 9/35 =====
Train Loss: 0.1367
Val Loss:   0.2423
Val IoU:    

<div class="alert alert-block alert-info">

----------
---------
# <b>12. Test Set Evaluation</b> 

--------------
----------------
</div>

In [12]:
model.eval()
test_losses, test_ious, test_dices = [], [], []

with torch.no_grad():
    for imgs, masks in test_loader:
        imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)

        logits = model(pixel_values=imgs).logits
        logits = F.interpolate(logits, size=(512,512), mode="bilinear", align_corners=False)

        fl = bce_loss(logits, masks)
        dl = dice_loss(logits, masks)
        loss = fl + dl
        test_losses.append(loss.item())

        preds = (torch.sigmoid(logits) > 0.5).cpu()
        masks_cpu = masks.cpu()

        for p, t in zip(preds, masks_cpu):
            test_ious.append(iou_score(p, t))
            test_dices.append(dice_score(p, t))

print("\n===== TEST RESULTS =====")
print("Test Loss:", np.mean(test_losses))
print("Test IoU: ", np.mean(test_ious))
print("Test Dice:", np.mean(test_dices))



===== TEST RESULTS =====
Test Loss: 0.2323145622735828
Test IoU:  0.7234881565825485
Test Dice: 0.8231578442212272


<div class="alert alert-block alert-info">

----------
---------
# <b> 13. Saving Test Predictions</b> 

--------------
----------------
</div>

In [13]:
pred_dir = "../outputs/predictions/segformer_b2_finetuned_predictions"
os.makedirs(pred_dir, exist_ok=True)

print("Prediction folder:", pred_dir)

model.eval()

with torch.no_grad():
    for idx, (imgs, masks) in enumerate(test_loader):
        imgs = imgs.to(DEVICE)

        logits = model(pixel_values=imgs).logits
        logits = F.interpolate(logits, size=(512,512), mode="bilinear", align_corners=False)

        preds = (torch.sigmoid(logits) > 0.5).float().cpu().numpy()[0,0]

        pred_img = (preds * 255).astype(np.uint8)
        pred_pil = Image.fromarray(pred_img)

        img_path = test_dataset.items[idx][0]
        base_name = os.path.basename(img_path)
        name_no_ext = os.path.splitext(base_name)[0]

        pred_pil.save(os.path.join(pred_dir, f"{name_no_ext}_pred.png"))

print("All predictions saved.")


Prediction folder: ../outputs/predictions/segformer_b2_finetuned_predictions
All predictions saved.
