# Segmenting Images using YOLO model

## Training Semantic Segmentation Model

### Agumenting the images

In [7]:
from PIL import ImageEnhance
import random
import os
import matplotlib.pyplot as plt
from PIL import Image

def augment_image(img):
    augmented = []
    # 1. Horizontal flip
    augmented.append(img.transpose(Image.FLIP_LEFT_RIGHT))
    # 2. Vertical flip
    augmented.append(img.transpose(Image.FLIP_TOP_BOTTOM))
    # 3. Rotate 90 degrees
    augmented.append(img.rotate(90, expand=True))
    # 4. Change saturation
    enhancer = ImageEnhance.Color(img)
    augmented.append(enhancer.enhance(random.uniform(0.1, 2)))
    # 5. Change brightness
    enhancer = ImageEnhance.Brightness(img)
    augmented.append(enhancer.enhance(random.uniform(0.1, 2)))
    return augmented

def visualize_augmentation(img):
    augmented = augment_image(img)
    fig, axes = plt.subplots(1, len(augmented) + 1, figsize=(15, 4))
    axes[0].imshow(img)
    axes[0].set_title('Original')
    axes[0].axis('off')
    for i, aug_img in enumerate(augmented):
        axes[i+1].imshow(aug_img)
        axes[i+1].set_title(f'Aug {i+1}')
        axes[i+1].axis('off')
    plt.tight_layout()
    plt.show()

def augment_and_save(base_dir, visualize=False):
    #for split in ['train', 'valid']:
    for split in ['valid']:
        for subfolder in ['images', 'masks']:
            folder = os.path.join(base_dir, split, subfolder)
            for fname in os.listdir(folder):
                if not fname.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')):
                    continue
                fpath = os.path.join(folder, fname)
                with Image.open(fpath) as img:
                    # For masks, only apply geometric transforms (no color/brightness)
                    if subfolder == 'masks':
                        # For masks, only apply geometric transforms (flips and 90-degree rotation), 
                        # for Aug 4 and 5 just copy the original mask (no-op)
                        augmented = [
                            img.transpose(Image.FLIP_LEFT_RIGHT),
                            img.transpose(Image.FLIP_TOP_BOTTOM),
                            img.rotate(90, expand=True),
                            img.copy(),  # Aug 4: just copy the original mask
                            img.copy()   # Aug 5: just copy the original mask
                        ]
                        if visualize:
                            fig, axes = plt.subplots(1, len(augmented) + 1, figsize=(15, 4))
                            axes[0].imshow(img)
                            axes[0].set_title('Original Mask')
                            axes[0].axis('off')
                            for i, aug_img in enumerate(augmented):
                                axes[i+1].imshow(aug_img)
                                axes[i+1].set_title(f'Aug {i+1}')
                                axes[i+1].axis('off')
                            plt.tight_layout()
                            plt.show()
                            # Only visualize the first mask and return
                            return
                    else:
                        augmented = augment_image(img)
                        if visualize:
                            visualize_augmentation(img)
                            # Only visualize the first image and return
                            return
                    for i, aug_img in enumerate(augmented):
                        name, ext = os.path.splitext(fname)
                        aug_fname = f"{name}_aug{i+1}{ext}"
                        aug_img.save(os.path.join(folder, aug_fname))

In [8]:
augment_and_save(base_dir='../images_train', visualize=False)

### Creating Ground Truth Masks

In [4]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt

def create_masks_from_yolov8(images_folder, labels_folder, masks_folder, img_ext='jpg', label_ext='txt', show_example=True):
    os.makedirs(masks_folder, exist_ok=True)
    image_files = [f for f in os.listdir(images_folder) if f.endswith(f'.{img_ext}')]
    for img_file in image_files:
        img_path = os.path.join(images_folder, img_file)
        label_file = img_file.replace(f'.{img_ext}', f'.{label_ext}')
        label_path = os.path.join(labels_folder, label_file)
        img = cv2.imread(img_path)
        h, w = img.shape[:2]
        mask = np.zeros((h, w), dtype=np.uint8)
        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                for line in f:
                    parts = line.strip().split()
                    if len(parts) < 2:
                        continue
                    # Yolov8 format: class_id x_center y_center width height [polygon points...]
                    if len(parts) > 5:
                        # polygon mask (segmentation)
                        pts = np.array(parts[1:], dtype=float).reshape(-1, 2)
                        pts[:, 0] *= w
                        pts[:, 1] *= h
                        pts = pts.astype(np.int32)
                        cv2.fillPoly(mask, [pts], 255)
                    else:
                        # bbox mask
                        _, x, y, bw, bh = map(float, parts[:5])
                        x1 = int((x - bw/2) * w)
                        y1 = int((y - bh/2) * h)
                        x2 = int((x + bw/2) * w)
                        y2 = int((y + bh/2) * h)
                        cv2.rectangle(mask, (x1, y1), (x2, y2), 255, -1)
        # Show example
        if show_example:
            plt.figure(figsize=(10,5))
            plt.subplot(1,2,1)
            plt.title('Image')
            plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            plt.axis('off')
            plt.subplot(1,2,2)
            plt.title('Mask')
            plt.imshow(mask, cmap='gray')
            plt.axis('off')
            plt.show()
            break
        else:
            mask_path = os.path.join(masks_folder, img_file.replace(f'.{img_ext}', '.png'))
            cv2.imwrite(mask_path, mask)

In [6]:
create_masks_from_yolov8('../images_train/valid/images', '../images_train/valid/labels', '../images_train/valid/masks', show_example=False)

### Loading the database

In [1]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF
import numpy as np

IMG_SIZE = 512
BATCH_SIZE = 8
NUM_CLASSES = 1

dataset_path = "../images_train/"

# --- Custom Dataset for Segmentation ---
class SegmentationDataset(Dataset):
    def __init__(self, root_dir, train=True):
        self.root_dir = root_dir
        self.train = train
        self.img_dir = os.path.join(root_dir, "train/images" if train else "valid/images")
        self.mask_dir = os.path.join(root_dir, "train/masks" if train else "valid/masks")

        self.image_files = sorted([
            f for f in os.listdir(self.img_dir)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        ])
        self.mask_files = sorted([
            f for f in os.listdir(self.mask_dir)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        ])

        assert len(self.image_files) == len(self.mask_files), \
            f"Image/Mask count mismatch: {len(self.image_files)} vs {len(self.mask_files)}"

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.image_files[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_files[idx])

        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")  # grayscale mask

        # --- To Tensor ---
        image = TF.to_tensor(image)
        image = TF.normalize(image, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

        # Mask â†’ tensor (no normalization)
        mask = torch.from_numpy(np.array(mask)).float().unsqueeze(0)  # [1, H, W]
        mask = (mask > 0).float()

        return image, mask

# --- Datasets & Loaders ---
train_dataset = SegmentationDataset(dataset_path, train=True)
test_dataset = SegmentationDataset(dataset_path, train=False)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"âœ… Train size: {len(train_dataset)} | Test size: {len(test_dataset)}")

âœ… Train size: 4788 | Test size: 546


### Importing Models

In [2]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)

cuda


In [11]:
import sys
sys.path.append('../')
from utils.models.uNet import UNet

channels = [32, 64, 128, 256, 512]
model = UNet(in_channels=3, out_channels=1, channels=channels, bilinear=True, use_batchnorm=True)
model.to(device)
modelName = "U-NET"

In [3]:
import sys
sys.path.append('../')
from utils.models.deeplabv3p import DeepLabV3Plus

model = DeepLabV3Plus(num_classes=1, output_stride=16, backbone_width_mult=1.0).to(device)
model.to(device)
modelName = "DeepLabV3Plus"

In [3]:
import sys
sys.path.append('../')
from utils.models.SegFormer import segformer

model = segformer(in_channels = 3, num_classes = 1)
model.to(device)
modelName = "SegFormer"

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import sys
sys.path.append('../')
from utils.models.SegNet import segnet

model = segnet(in_channels=3, num_classes=1, pretrained=False)
model.to(device)
modelName = "SegNet"

In [3]:
import sys
sys.path.append('../')

from utils.models.maskFormer import MaskFormer
from utils.models.resnet101 import resnet101_backbone

resnet101 = resnet101_backbone()
model = MaskFormer(
    backbone=resnet101, 
    num_classes=1, 
    num_queries=5,          # Reduced from 10
    embed_dim=64,           # Reduced from 128  
    transformer_layers=1,   # Reduced from 2
    transformer_heads=2,    # Reduced from 4
    transformer_ffn_dim=256, # Reduced from 512
    return_binary=True
).to(device)
modelName = "MaskFormer"



### Training

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import os
from tqdm import tqdm
import matplotlib.pyplot as plt

In [5]:
# --- Configuration ---
EPOCHS = 200
early_stopping_patience = 5
best_val_loss = float('inf') 
patience_counter = 0
scaler = torch.amp.GradScaler('cuda')

# --- Loss & Optimizer ---
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)

checkpoint_path = os.path.join("../models/", modelName + "_seg.pt")
print(f"Model checkpoints will be saved to: {checkpoint_path}")

# --- Tracking ---
train_losses, val_losses = [], []

Model checkpoints will be saved to: ../models/MaskFormer_seg.pt


In [None]:
# --- Training Loop ---
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0

    with tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", unit="batch") as tepoch:
        for inputs, labels in tepoch:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            with torch.amp.autocast('cuda'):
                outputs = model(inputs)
                if torch.isnan(outputs).any() or torch.isinf(outputs).any():
                    print("NaN or Inf in model outputs!")
                if torch.isnan(labels).any() or torch.isinf(labels).any():
                    print("NaN or Inf in labels!")
                if labels.sum() == 0:
                    print("Skipping batch with all empty masks")
                    continue
                loss = loss_fn(outputs, labels)
            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            running_loss += loss.item()
            tepoch.set_postfix(loss=running_loss / (len(tepoch)))
    avg_train_loss = running_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # --- Validation Phase ---
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        with tqdm(test_loader, desc=f"Epoch {epoch+1}/{EPOCHS} - Validation", unit="batch") as vepoch:
            for inputs, labels in vepoch:
                inputs, labels = inputs.to(device), labels.to(device)
                with torch.amp.autocast('cuda'):
                    outputs = model(inputs)
                    if torch.isnan(outputs).any() or torch.isinf(outputs).any():
                        print("NaN or Inf in model outputs!")
                    if torch.isnan(labels).any() or torch.isinf(labels).any():
                        print("NaN or Inf in labels!")
                    if labels.sum() == 0:
                        print("Skipping batch with all empty masks")
                        continue
                    loss = loss_fn(outputs, labels)
                    preds = (torch.sigmoid(outputs) > 0.5).float()
                val_correct += (preds == labels).sum().item()
                val_total += labels.numel()
                val_acc = val_correct / val_total if val_total > 0 else 0
                val_loss += loss.item()
                vepoch.set_postfix(loss=val_loss / (len(vepoch)))

    avg_val_loss = val_loss / len(test_loader)
    val_losses.append(avg_val_loss)

    print(f"Epoch {epoch+1} - Val Loss: {avg_val_loss:.4f} - Val Acc: {val_acc:.4f}%")

    # --- Scheduler step (use validation loss) ---
    lr_scheduler.step(avg_val_loss)

    # --- Checkpointing ---
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0
        torch.save(model.state_dict(), checkpoint_path)
        print(f"âœ… Model improved (Val Loss={avg_val_loss:.3f}), saved to {checkpoint_path}")
    else:
        patience_counter += 1
        print(f"ðŸ”´ No improvement, patience counter: {patience_counter}")

    if patience_counter >= early_stopping_patience:
        print("ðŸ›‘ Early stopping triggered!")
        break

In [None]:
# --- Plot losses ---
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss')
plt.title('Segmentation Training & Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

### Inference on Test Images

In [None]:
import random

# Pick 5 random indices from the test dataset
indices = random.sample(range(len(test_dataset)), 5)

fig, axs = plt.subplots(5, 3, figsize=(12, 18))
for i, idx in enumerate(indices):
    img, mask = test_dataset[idx]
    img_input = img.unsqueeze(0).to(device)
    model.eval()
    with torch.no_grad():
        with torch.cuda.amp.autocast():
            pred_logits = model(img_input)
            pred_mask = (torch.sigmoid(pred_logits) > 0.5).float().cpu().squeeze().numpy()
    axs[i, 0].imshow(img.permute(1, 2, 0).cpu() * 0.5 + 0.5)
    axs[i, 0].set_title("Input Image")
    axs[i, 1].imshow(mask.squeeze().cpu(), cmap='gray')
    axs[i, 1].set_title("Ground Truth Mask")
    axs[i, 2].imshow(pred_mask, cmap='gray')
    axs[i, 2].set_title("Predicted Mask")
    for j in range(3):
        axs[i, j].axis('off')
plt.tight_layout()
plt.show()


### Choosing the best model

In [10]:
import cv2
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF

class SegEvalDataset(Dataset):
    def __init__(self, image_dir, mask_dir, image_files):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_files = image_files

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        mask_name = os.path.splitext(img_name)[0] + ".png"
        image = cv2.imread(os.path.join(self.image_dir, img_name))
        mask = cv2.imread(os.path.join(self.mask_dir, mask_name), cv2.IMREAD_GRAYSCALE)
        if mask.ndim == 3:
            mask = mask[..., 0]
        image = TF.to_tensor(image)
        image = TF.normalize(image, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        mask = (mask > 0).astype(np.float32)
        mask = torch.from_numpy(mask).float().unsqueeze(0)
        return image, mask, img_name

In [3]:
# ======================
# Utility functions
# ======================

def iou_score(y_true, y_pred):
    intersection = np.sum(y_true * y_pred)
    union = np.sum(y_true) + np.sum(y_pred) - intersection
    return intersection / (union + 1e-7)

def load_model(model_name, device):
    if model_name == "U-NET":
        channels = [32, 64, 128, 256, 512]
        model = UNet(in_channels=3, out_channels=1, channels=channels, bilinear=True, use_batchnorm=True).to(device)
    elif model_name == "SegFormer":
        model = segformer(in_channels = 3, num_classes = 1).to(device)
    elif model_name == "DeepLabV3Plus":
        model = DeepLabV3Plus(num_classes=1, output_stride=16, backbone_width_mult=1.0).to(device)
    elif model_name == "SegNet":
        model = segnet(in_channels=3, num_classes=1, pretrained=False).to(device)
    elif model_name == "MaskFormer":
        resnet101 = resnet101_backbone()
        model = MaskFormer(
            backbone=resnet101, 
                num_classes=1, 
                num_queries=5,
                embed_dim=64, 
                transformer_layers=1,
                transformer_heads=2,
                transformer_ffn_dim=256,
                return_binary=True
            ).to(device)
    else:
        raise ValueError(f"Unknown model name: {model_name}")
    model.load_state_dict(torch.load(f'../models/{model_name}_seg.pt', map_location=device))
    model.eval()
    return model

def predict_mask(model, image, device):
    img_input = image.unsqueeze(0).to(device)
    with torch.no_grad():
        with torch.amp.autocast('cuda'):
            pred_logits = model(img_input)
            pred_mask = (torch.sigmoid(pred_logits) > 0.5).float().cpu().squeeze().numpy()
    return pred_mask

In [6]:
import sys
sys.path.append('../')
import torch
from utils.models.maskFormer import MaskFormer
from utils.models.resnet101 import resnet101_backbone
from utils.models.uNet import UNet
from utils.models.SegNet import segnet
from utils.models.SegFormer import segformer
from utils.models.deeplabv3p import DeepLabV3Plus

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data_dir = "../images_train/"
models = ["DeepLabV3Plus", "SegFormer", "SegNet", "U-NET", "MaskFormer"]

In [13]:
# ======================
# Main evaluation
# ======================

import pandas as pd
import os
from tqdm import tqdm
import numpy as np
results = []

BATCH_SIZE = 8

image_dir = os.path.join(data_dir,"valid", "images")
mask_dir = os.path.join(data_dir,"valid", "masks")
image_files = [f for f in os.listdir(image_dir) if f.endswith(('.png', '.jpg', '.jpeg'))]

dataset = SegEvalDataset(image_dir, mask_dir, image_files)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

for model_name in models:
    print(f"Evaluating {model_name}...")
    model = load_model(model_name, device)
    model.eval()

    for images, masks, img_names in tqdm(loader):
        images = images.to(device)
        masks_np = masks.squeeze(1).cpu().numpy()  # [B, H, W]
        with torch.no_grad():
            with torch.amp.autocast('cuda'):
                logits = model(images)
                preds = (torch.sigmoid(logits) > 0.5).float().cpu().numpy().squeeze(1)  # [B, H, W]

        for i in range(images.size(0)):
            mask = masks_np[i]
            pred = preds[i]
            iou = iou_score(mask, pred)
            results.append({
                "modelname": model_name,
                "IoU": iou
            })

# Aggregate results
df = pd.DataFrame(results)
summary = df.groupby("modelname")["IoU"].agg(['mean', 'std', 'count'])
summary['coef_var'] = summary['std'] / (summary['mean'] + 1e-8)

# Display results
print("Model\tMean IoU\tCoeff. of Variation")
for idx, row in summary.iterrows():
    print(f"{idx}\t{row['mean']:.4f}\t{row['coef_var']:.4f}")

Evaluating DeepLabV3Plus...


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 69/69 [00:12<00:00,  5.71it/s]


Evaluating SegFormer...


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 69/69 [00:10<00:00,  6.42it/s]


Evaluating SegNet...


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 69/69 [00:16<00:00,  4.18it/s]


Evaluating U-NET...


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 69/69 [00:14<00:00,  4.67it/s]


Evaluating MaskFormer...


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 69/69 [00:13<00:00,  5.19it/s]

Model	Mean IoU	Coeff. of Variation
DeepLabV3Plus	0.9429	0.0306
MaskFormer	0.9386	0.0314
SegFormer	0.9359	0.0402
SegNet	0.9298	0.0456
U-NET	0.9454	0.0309





| Model | Mean IoU | Coeff. of Variation |
| - | - | - |
| DeepLabV3Plus | 0.9429 | 0.0306 |
| MaskFormer | 0.9386 | 0.0314 |
| SegFormer	| 0.9359	| 0.0402
| SegNet	| 0.9298 | 0.0456 |
| U-NET	| 0.9454	| 0.0309|

U-NET is the chosen one

## Using the model to get segmented images

Segmenting the images

In [2]:
import torch
import cv2
import os
import numpy as np
import sys
sys.path.append('../')
from utils.models.uNet import UNet
import torchvision.transforms.functional as TF

# Load U-NET model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
channels = [32, 64, 128, 256, 512]
model = UNet(in_channels=3, out_channels=1, channels=channels, bilinear=True, use_batchnorm=True).to(device)
model.load_state_dict(torch.load("../models/U-NET_seg.pt", map_location=device))
model.eval()

image_folder = "../images"
output_folder = "../segmented_images"
os.makedirs(output_folder, exist_ok=True)

for filename in os.listdir(image_folder):
    if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
        image_path = os.path.join(image_folder, filename)
        image = cv2.imread(image_path)
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image_tensor = TF.to_tensor(image_rgb)
        image_tensor = TF.normalize(image_tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        image_tensor = image_tensor.unsqueeze(0).to(device)

        with torch.no_grad():
            with torch.cuda.amp.autocast():
                pred_logits = model(image_tensor)
                pred_mask = (torch.sigmoid(pred_logits) > 0.5).float().cpu().squeeze().numpy()

        # Save mask as PNG
        mask_path = os.path.join(output_folder, filename.replace('.jpg', '.png').replace('.jpeg', '.png').replace('.png', '_mask.png'))
        cv2.imwrite(mask_path, (pred_mask * 255).astype(np.uint8))

  with torch.cuda.amp.autocast():


KeyboardInterrupt: 