# Wound Segmentation and Detection using U-Net

This notebook presents an end-to-end deep learning pipeline for wound segmentation from medical images using a U-Net architecture.

The workflow includes:
- Data preprocessing (images and ground-truth segmentation masks)
- Model training
- Inference on test images
- Visual comparison with ground truth masks
- Quantitative evaluation with the Dice coefficient



In [None]:
## 1. Environment Setup
#
# Mount Google Drive so the training/testing data and model checkpoints stored
# there are reachable from this Colab runtime.
from google.colab import drive

DRIVE_MOUNT_POINT = 'path to drive'  # placeholder — replace with the real mount point
drive.mount(DRIVE_MOUNT_POINT)



In [None]:
## 2. Image Preprocessing (Single Image Demonstration)

# A single image is preprocessed to demonstrate the standardization pipeline
# before applying it to the full dataset.
import cv2
import numpy as np
import matplotlib.pyplot as plt

# -------------------------------
# CONFIG (COLAB + DRIVE)
# -------------------------------
IMAGE_PATH = "image path"  # placeholder — one raw training image
IMG_SIZE = 512  # model input resolution; every image is resized to IMG_SIZE×IMG_SIZE

# -------------------------------
# LOAD IMAGE (PRESERVE QUALITY)
# -------------------------------
img = cv2.imread(IMAGE_PATH, cv2.IMREAD_COLOR)
# cv2.imread returns None instead of raising for a missing/unreadable file.
# Raise explicitly: an `assert` would be stripped when Python runs with -O.
if img is None:
    raise FileNotFoundError(f"Image not found: {IMAGE_PATH!r}. Check path or filename.")

# OpenCV decodes to BGR; convert to RGB for matplotlib and the rest of the pipeline.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# -------------------------------
# VERY GENTLE STANDARDIZATION
# -------------------------------
def ultra_gentle_crop(image, crop_ratio=0.98):
    """
    Center-crop a thin border from an image.

    Height and width are each scaled by ``crop_ratio``, so the default 0.98
    keeps about 0.98**2 ≈ 96% of the original pixels (~1% trimmed per edge).

    Args:
        image: array indexed as ``image[row, col, ...]``; both H×W×C colour
            images and H×W grayscale arrays are accepted.
        crop_ratio: fraction of height and width to keep, in (0, 1].

    Returns:
        A center-cropped view of ``image``. Each cropped side is at least one
        pixel, so tiny inputs never become empty.

    Raises:
        ValueError: if ``crop_ratio`` is outside (0, 1].
    """
    if not 0 < crop_ratio <= 1:
        raise ValueError(f"crop_ratio must be in (0, 1], got {crop_ratio}")
    h, w = image.shape[:2]
    ch = max(1, int(h * crop_ratio))
    cw = max(1, int(w * crop_ratio))
    y1 = (h - ch) // 2
    x1 = (w - cw) // 2
    return image[y1:y1 + ch, x1:x1 + cw]

# Trim a thin border, then bring the image to the model's input resolution.
img_std = ultra_gentle_crop(img)

# -------------------------------
# RESIZE (HIGH QUALITY)
# -------------------------------
img_std = cv2.resize(img_std, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_CUBIC)

# -------------------------------
# NORMALIZATION (MODEL READY)
# -------------------------------
# uint8 [0, 255] -> float32 [0, 1], the same scaling the training Dataset applies.
img_preprocessed = img_std.astype(np.float32) / 255.0

# -------------------------------
# DISPLAY COMPARISON
# -------------------------------
orig_h, orig_w = img.shape[:2]
fig, (ax_orig, ax_prep) = plt.subplots(1, 2, figsize=(12, 5))

# Report the real source size: the raw image is not necessarily 512×512.
ax_orig.set_title(f"Original ({orig_w}×{orig_h})")
ax_orig.imshow(img)
ax_orig.axis("off")

ax_prep.set_title(f"Preprocessed ({IMG_SIZE}×{IMG_SIZE})")
ax_prep.imshow(img_preprocessed)
ax_prep.axis("off")

plt.show()


In [None]:
## 3. Batch Preprocessing of Training Images

# All training images are standardized to a fixed resolution to ensure
# consistent input for the U-Net model.
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os

# -------------------------------
# CONFIG
# -------------------------------
INPUT_DIR = "image path"   # placeholder — raw training images
OUTPUT_DIR = "image path"  # placeholder — preprocessed training images (must differ)
IMG_SIZE = 512

# Writing into INPUT_DIR would overwrite the raw images, and every re-run of
# this cell would crop yet another border off them — refuse to run in place.
if os.path.realpath(INPUT_DIR) == os.path.realpath(OUTPUT_DIR):
    raise ValueError("INPUT_DIR and OUTPUT_DIR must differ; in-place preprocessing is not idempotent.")

os.makedirs(OUTPUT_DIR, exist_ok=True)

# -------------------------------
# VERY GENTLE STANDARDIZATION
# -------------------------------
def ultra_gentle_crop(image, crop_ratio=0.98):
    """
    Center-crop a thin border from an image.

    Height and width are each scaled by ``crop_ratio`` (default keeps about
    0.98**2 ≈ 96% of the pixels). Works for H×W×C and H×W arrays; each side
    is kept at least one pixel wide.

    Raises:
        ValueError: if ``crop_ratio`` is outside (0, 1].
    """
    if not 0 < crop_ratio <= 1:
        raise ValueError(f"crop_ratio must be in (0, 1], got {crop_ratio}")
    h, w = image.shape[:2]
    ch = max(1, int(h * crop_ratio))
    cw = max(1, int(w * crop_ratio))
    y1 = (h - ch) // 2
    x1 = (w - cw) // 2
    return image[y1:y1 + ch, x1:x1 + cw]

# -------------------------------
# PROCESS ALL IMAGES
# -------------------------------
n_saved = 0
# sorted() makes processing order (and the log) deterministic.
for filename in sorted(os.listdir(INPUT_DIR)):

    if not filename.lower().endswith((".png", ".jpg", ".jpeg")):
        continue

    input_path = os.path.join(INPUT_DIR, filename)

    # Load image (cv2.imread returns None instead of raising on failure)
    img = cv2.imread(input_path, cv2.IMREAD_COLOR)
    if img is None:
        print(f"Skipping {filename} (cannot read)")
        continue

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Standardization: trim a thin border, then resize to the model resolution
    img_std = ultra_gentle_crop(img)
    img_std = cv2.resize(img_std, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_CUBIC)

    # Save the uint8 pixels directly — scaling to [0, 1] happens in the Dataset
    # at load time, so a float round trip here adds nothing.
    # The filename is kept so images stay paired with their masks by name.
    # NOTE: .jpg/.jpeg names are re-encoded as lossy JPEG (quality 100 keeps
    # the extra loss small); only .png outputs are lossless.
    save_path = os.path.join(OUTPUT_DIR, filename)
    save_img = cv2.cvtColor(img_std, cv2.COLOR_RGB2BGR)
    if cv2.imwrite(save_path, save_img, [cv2.IMWRITE_JPEG_QUALITY, 100]):
        n_saved += 1
    else:
        print(f"Failed to save {filename}")

print(f"✅ Preprocessing completed: {n_saved} images saved to {OUTPUT_DIR}")


In [None]:
## 4. Preprocessing of Segmentation Masks

# Ground truth masks are preprocessed using the same spatial operations as
# images to maintain alignment.
# Nearest-neighbor interpolation is used to preserve label integrity.
import cv2
import numpy as np
import os

# -------------------------------
# CONFIG (MASKS)
# -------------------------------
INPUT_DIR = "image path"   # placeholder — raw ground-truth masks
OUTPUT_DIR = "image path"  # placeholder — preprocessed masks (must differ)
IMG_SIZE = 512  # must match the image preprocessing so masks stay aligned

# Writing into INPUT_DIR would overwrite the raw masks, and every re-run of
# this cell would crop yet another border off them — refuse to run in place.
if os.path.realpath(INPUT_DIR) == os.path.realpath(OUTPUT_DIR):
    raise ValueError("INPUT_DIR and OUTPUT_DIR must differ; in-place preprocessing is not idempotent.")

os.makedirs(OUTPUT_DIR, exist_ok=True)

# -------------------------------
# SAME GENTLE CROP AS IMAGES
# -------------------------------
def ultra_gentle_crop_mask(mask, crop_ratio=0.98):
    """
    Center-crop a mask using the same crop arithmetic as the image crop.

    Identical arithmetic keeps each mask pixel-aligned with its image. Height
    and width are each scaled by ``crop_ratio`` (default keeps ~96% of the
    area), and each side is kept at least one pixel wide.

    Args:
        mask: array indexed as ``mask[row, col, ...]`` (H×W; H×W×C also works).
        crop_ratio: fraction of height and width to keep, in (0, 1].

    Returns:
        A center-cropped view of ``mask``.

    Raises:
        ValueError: if ``crop_ratio`` is outside (0, 1].
    """
    if not 0 < crop_ratio <= 1:
        raise ValueError(f"crop_ratio must be in (0, 1], got {crop_ratio}")
    h, w = mask.shape[:2]
    ch = max(1, int(h * crop_ratio))
    cw = max(1, int(w * crop_ratio))
    y1 = (h - ch) // 2
    x1 = (w - cw) // 2
    return mask[y1:y1 + ch, x1:x1 + cw]

# -------------------------------
# PROCESS ALL MASKS
# -------------------------------
n_saved = 0
# sorted() makes processing order (and the log) deterministic.
for filename in sorted(os.listdir(INPUT_DIR)):

    if not filename.lower().endswith((".png", ".jpg", ".jpeg")):
        continue

    input_path = os.path.join(INPUT_DIR, filename)

    # Load mask as GRAYSCALE (single channel)
    mask = cv2.imread(input_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        print(f"Skipping {filename} (cannot read)")
        continue

    # Crop with the same geometry as the images so the pairs stay aligned
    mask = ultra_gentle_crop_mask(mask)

    # Resize with NEAREST so no intermediate label values are invented
    mask = cv2.resize(mask, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_NEAREST)

    # Binarize to {0, 255}. Masks that store labels as small values (e.g. 0/1)
    # would be wiped out by the 127 threshold — warn instead of failing silently.
    had_foreground = mask.any()
    _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
    if had_foreground and not mask.any():
        print(f"Warning: {filename} became empty after thresholding at 127 (labels stored as small values?)")

    # Keep the filename so masks pair with images by name.
    # NOTE: a .jpg/.jpeg name re-encodes the mask as lossy JPEG, which adds
    # non-binary values along edges; PNG masks avoid this.
    save_path = os.path.join(OUTPUT_DIR, filename)
    if cv2.imwrite(save_path, mask):
        n_saved += 1
    else:
        print(f"Failed to save {filename}")

print(f"✅ Mask preprocessing completed: {n_saved} masks saved to {OUTPUT_DIR}")


In [None]:
## 5. GPU Check
import torch

# Report whether PyTorch can see a CUDA device and, if so, which one.
cuda_ok = torch.cuda.is_available()
print("GPU available:", cuda_ok)
print("Device:", torch.cuda.get_device_name(0) if cuda_ok else "CPU")


In [None]:
## 6. Dependency Installation

# Required deep learning and utility libraries are installed.
# %pip (unlike !pip) installs into the environment of the running kernel.
# NOTE: torch is already imported by the GPU check above; if this upgrades
# torch, restart the runtime before continuing.
# TODO: pin versions (pkg==x.y.z) for a reproducible environment.
%pip install -q torch torchvision tqdm


In [None]:
## 7. Custom Dataset Definition

#A PyTorch Dataset class is defined to load images and masks, normalize inputs, and return tensors suitable for training.
import os
import cv2
import torch
from torch.utils.data import Dataset

class LepraSegmentationDataset(Dataset):
    """
    Paired image/mask dataset for binary segmentation.

    Images and masks are matched by identical filename in ``image_dir`` and
    ``mask_dir``. Each item is ``(image, mask)``:
      * ``image`` — float32 tensor (3, H, W), RGB, scaled to [0, 1];
      * ``mask``  — float32 tensor (1, H, W) with values in {0, 1}
        (any nonzero mask pixel counts as foreground).

    Args:
        image_dir: directory of (preprocessed) images.
        mask_dir: directory of masks named exactly like the images.
        extensions: lowercase filename suffixes treated as images; other
            directory entries (e.g. ``.ipynb_checkpoints``) are ignored.
    """

    def __init__(self, image_dir, mask_dir, extensions=(".png", ".jpg", ".jpeg")):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        # Only list image files: a stray non-image entry would otherwise make
        # __getitem__ raise partway through an epoch.
        self.images = sorted(
            name for name in os.listdir(image_dir)
            if name.lower().endswith(tuple(extensions))
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]

        img_path = os.path.join(self.image_dir, img_name)
        mask_path = os.path.join(self.mask_dir, img_name)

        # ---- Load image ----
        # cv2.imread returns None (no exception) for missing/unreadable files.
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Image not found: {img_name}")

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = image.astype("float32") / 255.0

        # ---- Load mask ----
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Mask not found: {img_name}")

        # Fail early with a clear message instead of a shape error inside the loss.
        if image.shape[:2] != mask.shape:
            raise ValueError(
                f"Image/mask size mismatch for {img_name}: "
                f"{image.shape[:2]} vs {mask.shape}"
            )

        mask = (mask > 0).astype("float32")

        # HWC -> CHW for PyTorch; give the mask an explicit channel axis.
        image = torch.from_numpy(image).permute(2, 0, 1)
        mask = torch.from_numpy(mask).unsqueeze(0)

        return image, mask


In [None]:
## 8. Model Architecture – U-Net

#A U-Net encoder–decoder architecture with skip connections is implemented for pixel-level wound segmentation.
import torch.nn as nn

class DoubleConv(nn.Module):
    """
    Two consecutive (3×3 Conv → BatchNorm → ReLU) stages.

    ``padding=1`` keeps the spatial size unchanged; only the channel count
    changes from ``in_ch`` to ``out_ch``. Layers live in ``self.block`` (a
    Sequential) so checkpoint keys are ``block.0`` … ``block.5``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second keeps out_ch -> out_ch.
        for source_ch in (in_ch, out_ch):
            stages += [
                nn.Conv2d(source_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        return self.block(x)


class UNet(nn.Module):
    """
    U-Net encoder–decoder for binary segmentation.

    Four downsampling stages (2×2 max-pool) and four upsampling stages
    (2×2 transposed conv) with skip connections. ``forward`` returns raw
    logits of shape (N, out_channels, H, W); apply ``torch.sigmoid`` (or use
    ``BCEWithLogitsLoss``) outside the model.

    Args:
        in_channels: channels of the input image (3 for RGB).
        out_channels: number of output logit maps (1 for a binary mask).
        base_channels: width of the first stage; stage k uses
            ``base_channels * 2**k``. The defaults reproduce the original
            layer shapes and names, so existing checkpoints still load.

    Inputs whose height/width are not multiples of 16 are handled by
    zero-padding each upsampled map to its skip connection's size (the input
    must still be at least 16×16).
    """

    def __init__(self, in_channels=3, out_channels=1, base_channels=64):
        super().__init__()
        c = base_channels

        # Encoder
        self.d1 = DoubleConv(in_channels, c)
        self.d2 = DoubleConv(c, c * 2)
        self.d3 = DoubleConv(c * 2, c * 4)
        self.d4 = DoubleConv(c * 4, c * 8)

        self.pool = nn.MaxPool2d(2)

        # Bottleneck
        self.mid = DoubleConv(c * 8, c * 16)

        # Decoder: upsample ×2, concat with the skip (channels double), fuse
        self.u4 = nn.ConvTranspose2d(c * 16, c * 8, 2, 2)
        self.c4 = DoubleConv(c * 16, c * 8)

        self.u3 = nn.ConvTranspose2d(c * 8, c * 4, 2, 2)
        self.c3 = DoubleConv(c * 8, c * 4)

        self.u2 = nn.ConvTranspose2d(c * 4, c * 2, 2, 2)
        self.c2 = DoubleConv(c * 4, c * 2)

        self.u1 = nn.ConvTranspose2d(c * 2, c, 2, 2)
        self.c1 = DoubleConv(c * 2, c)

        # 1×1 conv to per-pixel logits
        self.out = nn.Conv2d(c, out_channels, 1)

    @staticmethod
    def _up_and_fuse(up, fuse, x, skip):
        """Upsample ``x``, pad it to ``skip``'s spatial size, concatenate, and convolve."""
        x = up(x)
        # MaxPool floors odd sizes, so the upsampled map can be 1 px short per axis.
        dy = skip.size(2) - x.size(2)
        dx = skip.size(3) - x.size(3)
        if dy or dx:
            # pad order is (left, right, top, bottom) for the last two dims
            x = nn.functional.pad(x, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2])
        return fuse(torch.cat([x, skip], dim=1))

    def forward(self, x):
        d1 = self.d1(x)
        d2 = self.d2(self.pool(d1))
        d3 = self.d3(self.pool(d2))
        d4 = self.d4(self.pool(d3))

        m = self.mid(self.pool(d4))

        x = self._up_and_fuse(self.u4, self.c4, m, d4)
        x = self._up_and_fuse(self.u3, self.c3, x, d3)
        x = self._up_and_fuse(self.u2, self.c2, x, d2)
        x = self._up_and_fuse(self.u1, self.c1, x, d1)

        return self.out(x)


In [None]:
## 9. Training Configuration

# The dataset, DataLoader, loss function, and optimizer are initialized for model training.
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

SEED = 42
torch.manual_seed(SEED)  # reproducible weight init and DataLoader shuffling

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

IMAGE_DIR = "path of image"  # placeholder — preprocessed training images
MASK_DIR  = "image path"     # placeholder — preprocessed training masks
MODEL_DIR = "image path"     # placeholder — checkpoint output directory

os.makedirs(MODEL_DIR, exist_ok=True)

dataset = LepraSegmentationDataset(IMAGE_DIR, MASK_DIR)
if len(dataset) == 0:
    raise RuntimeError(f"No training images found in {IMAGE_DIR!r}")

loader = DataLoader(
    dataset,
    batch_size=2,        # safe for 512×512 on Colab GPU
    shuffle=True,
    num_workers=2,
    pin_memory=(DEVICE == "cuda"),  # pinned memory only helps host→GPU copies
)

model = UNet().to(DEVICE)
criterion = nn.BCEWithLogitsLoss()  # the model outputs raw logits
optimizer = optim.Adam(model.parameters(), lr=1e-4)


In [None]:
## 10. Model Training

# The U-Net model is trained for multiple epochs using binary cross-entropy loss.
# Model checkpoints are saved after each epoch.
import os

EPOCHS = 30

loss_history = []  # mean training loss per epoch, for plotting

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    n_samples = 0

    for images, masks in tqdm(loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        images = images.to(DEVICE, non_blocking=True)
        masks = masks.to(DEVICE, non_blocking=True)

        outputs = model(images)
        loss = criterion(outputs, masks)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # `loss` is a per-batch mean; weight it by batch size so a smaller
        # final batch does not skew the epoch average.
        running_loss += loss.item() * images.size(0)
        n_samples += images.size(0)

    avg_loss = running_loss / max(n_samples, 1)
    loss_history.append(avg_loss)
    print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f}")

    # Save checkpoint
    torch.save(
        model.state_dict(),
        os.path.join(MODEL_DIR, f"unet_epoch_{epoch+1}.pth")
    )


In [None]:
## 11. Preprocessing of Test Images

# Test images are preprocessed using the same pipeline as training images to
# ensure consistency during inference.
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os

# -------------------------------
# CONFIG
# -------------------------------
INPUT_DIR = "image path"   # placeholder — raw test images
OUTPUT_DIR = "image path"  # placeholder — preprocessed test images (must differ)
IMG_SIZE = 512  # must match the training resolution

# Writing into INPUT_DIR would overwrite the raw images, and every re-run of
# this cell would crop yet another border off them — refuse to run in place.
if os.path.realpath(INPUT_DIR) == os.path.realpath(OUTPUT_DIR):
    raise ValueError("INPUT_DIR and OUTPUT_DIR must differ; in-place preprocessing is not idempotent.")

os.makedirs(OUTPUT_DIR, exist_ok=True)

# -------------------------------
# VERY GENTLE STANDARDIZATION
# -------------------------------
def ultra_gentle_crop(image, crop_ratio=0.98):
    """
    Center-crop a thin border from an image (same as for training images).

    Height and width are each scaled by ``crop_ratio`` (default keeps about
    0.98**2 ≈ 96% of the pixels). Works for H×W×C and H×W arrays; each side
    is kept at least one pixel wide.

    Raises:
        ValueError: if ``crop_ratio`` is outside (0, 1].
    """
    if not 0 < crop_ratio <= 1:
        raise ValueError(f"crop_ratio must be in (0, 1], got {crop_ratio}")
    h, w = image.shape[:2]
    ch = max(1, int(h * crop_ratio))
    cw = max(1, int(w * crop_ratio))
    y1 = (h - ch) // 2
    x1 = (w - cw) // 2
    return image[y1:y1 + ch, x1:x1 + cw]

# -------------------------------
# PROCESS ALL IMAGES
# -------------------------------
n_saved = 0
# sorted() makes processing order (and the log) deterministic.
for filename in sorted(os.listdir(INPUT_DIR)):

    if not filename.lower().endswith((".png", ".jpg", ".jpeg")):
        continue

    input_path = os.path.join(INPUT_DIR, filename)

    # Load image (cv2.imread returns None instead of raising on failure)
    img = cv2.imread(input_path, cv2.IMREAD_COLOR)
    if img is None:
        print(f"Skipping {filename} (cannot read)")
        continue

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Standardization: trim a thin border, then resize to the model resolution
    img_std = ultra_gentle_crop(img)
    img_std = cv2.resize(img_std, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_CUBIC)

    # Save the uint8 pixels directly — scaling to [0, 1] happens at inference
    # time, so a float round trip here adds nothing.
    # The filename is kept so predictions pair with ground truth by name.
    # NOTE: .jpg/.jpeg names are re-encoded as lossy JPEG (quality 100 keeps
    # the extra loss small); only .png outputs are lossless.
    save_path = os.path.join(OUTPUT_DIR, filename)
    save_img = cv2.cvtColor(img_std, cv2.COLOR_RGB2BGR)
    if cv2.imwrite(save_path, save_img, [cv2.IMWRITE_JPEG_QUALITY, 100]):
        n_saved += 1
    else:
        print(f"Failed to save {filename}")

print(f"✅ Preprocessing completed: {n_saved} images saved to {OUTPUT_DIR}")


In [None]:
## 12. Model Loading for Inference

# The trained U-Net checkpoint is loaded and switched to evaluation mode for
# inference on unseen test images.
import torch

# -------------------------------
# CONFIG
# -------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Checkpoint written by the training loop, e.g. unet_epoch_<N>.pth
# (you can change epoch number if needed)
MODEL_PATH = "path"

# -------------------------------
# LOAD MODEL
# -------------------------------
model = UNet().to(DEVICE)

# map_location lets a GPU-trained checkpoint load on a CPU-only runtime.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (recent PyTorch supports weights_only=True for plain state_dicts).
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(state_dict)

# Evaluation mode: BatchNorm layers use their running statistics.
model.eval()

print("✅ Model loaded successfully and ready for inference")


In [None]:
## 13. Inference on Test Images

# The trained model is applied to preprocessed test images to generate
# predicted wound segmentation masks.
import os
import cv2
import torch
import numpy as np
from tqdm import tqdm

# -------------------------------
# CONFIG
# -------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

TEST_IMAGE_DIR = "path"   # placeholder — preprocessed test images
OUTPUT_MASK_DIR = "path"  # placeholder — where predicted masks are written
MODEL_PATH = "path"       # placeholder — trained checkpoint (.pth)
THRESHOLD = 0.5           # sigmoid-probability cut-off for the foreground

os.makedirs(OUTPUT_MASK_DIR, exist_ok=True)

# -------------------------------
# LOAD MODEL (already trained)
# -------------------------------
model = UNet().to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

print("✅ Model loaded for inference")

# -------------------------------
# INFERENCE LOOP
# -------------------------------
# Filter before tqdm so the progress bar counts only real images.
test_files = sorted(
    f for f in os.listdir(TEST_IMAGE_DIR)
    if f.lower().endswith((".png", ".jpg", ".jpeg"))
)

with torch.no_grad():
    for filename in tqdm(test_files, desc="Running inference"):
        img_path = os.path.join(TEST_IMAGE_DIR, filename)

        # Load preprocessed image with the same scaling as the training Dataset
        image = cv2.imread(img_path)
        if image is None:
            print(f"Skipping {filename}")
            continue

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = image.astype("float32") / 255.0

        # HWC -> batch of one, CHW
        image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).to(DEVICE)

        # Logits -> probabilities -> binary {0, 255} uint8 mask
        probs = torch.sigmoid(model(image))
        pred_mask = (probs > THRESHOLD).squeeze().cpu().numpy()
        pred_mask = pred_mask.astype(np.uint8) * 255

        # Same filename as the input so it pairs with the image and ground truth.
        # NOTE: a .jpg/.jpeg name re-encodes the binary mask as lossy JPEG.
        save_path = os.path.join(OUTPUT_MASK_DIR, filename)
        if not cv2.imwrite(save_path, pred_mask):
            print(f"Failed to save {filename}")

print("✅ Inference completed. Predicted masks saved.")


In [None]:
## 14. Visualization of a Single Prediction

# Overlay a predicted mask on its test image for a qualitative check.
import os
import cv2
import matplotlib.pyplot as plt
import numpy as np

# -------------------------------
# CONFIG
# -------------------------------
TEST_IMAGE_DIR = "path"
PRED_MASK_DIR  = "path"

# Pick the first image file (non-image entries such as .ipynb_checkpoints are skipped)
image_files = sorted(
    f for f in os.listdir(TEST_IMAGE_DIR)
    if f.lower().endswith((".png", ".jpg", ".jpeg"))
)
if not image_files:
    raise FileNotFoundError(f"No test images found in {TEST_IMAGE_DIR!r}")
filename = image_files[0]

img_path = os.path.join(TEST_IMAGE_DIR, filename)
mask_path = os.path.join(PRED_MASK_DIR, filename)

# -------------------------------
# LOAD IMAGE & MASK
# -------------------------------
# cv2.imread returns None instead of raising; fail with the offending path.
image = cv2.imread(img_path)
if image is None:
    raise FileNotFoundError(f"Cannot read test image: {img_path}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
if mask is None:
    raise FileNotFoundError(f"Cannot read predicted mask: {mask_path}")
if mask.shape != image.shape[:2]:
    raise ValueError(f"Mask {mask.shape} and image {image.shape[:2]} sizes differ for {filename}")

# -------------------------------
# CREATE OVERLAY
# -------------------------------
# Threshold instead of `== 255`: masks saved as JPEG are not exactly 0/255.
overlay = image.copy()
overlay[mask > 127] = [255, 0, 0]  # red color for ulcer

# -------------------------------
# DISPLAY RESULTS
# -------------------------------
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
panels = [
    ("Test Image", image, None),
    ("Predicted Mask", mask, "gray"),
    ("Overlay (Red = Ulcer)", overlay, None),
]
for ax, (title, data, cmap) in zip(axes, panels):
    ax.set_title(title)
    ax.imshow(data, cmap=cmap)
    ax.axis("off")

plt.show()


In [None]:
## 15. Comparison with Ground Truth Masks

# Predicted masks are compared with ground truth masks to evaluate spatial
# alignment and segmentation accuracy.
import os
import cv2
import matplotlib.pyplot as plt
import numpy as np

# -------------------------------
# CONFIG
# -------------------------------
TEST_IMAGE_DIR = "path"
GT_MASK_DIR    = "path"  # GT masks must be cropped/resized like the test images
PRED_MASK_DIR  = "path"

# Pick one sample (first image file; non-image entries are skipped)
image_files = sorted(
    f for f in os.listdir(TEST_IMAGE_DIR)
    if f.lower().endswith((".png", ".jpg", ".jpeg"))
)
if not image_files:
    raise FileNotFoundError(f"No test images found in {TEST_IMAGE_DIR!r}")
filename = image_files[0]

# -------------------------------
# LOAD IMAGE & MASKS
# -------------------------------
img_path  = os.path.join(TEST_IMAGE_DIR, filename)
gt_path   = os.path.join(GT_MASK_DIR, filename)
pred_path = os.path.join(PRED_MASK_DIR, filename)

image = cv2.imread(img_path)
gt_mask = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
pred_mask = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)

# cv2.imread returns None instead of raising; fail with the offending path.
for path, loaded in ((img_path, image), (gt_path, gt_mask), (pred_path, pred_mask)):
    if loaded is None:
        raise FileNotFoundError(f"Cannot read {path}")

image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# All three must share one pixel grid for the overlay and any pixel-wise metric.
if not (gt_mask.shape == pred_mask.shape == image.shape[:2]):
    raise ValueError(
        f"Size mismatch for {filename}: image {image.shape[:2]}, "
        f"GT {gt_mask.shape}, prediction {pred_mask.shape}"
    )

# -------------------------------
# CREATE OVERLAY (PREDICTION)
# -------------------------------
# Threshold instead of `== 255`: masks saved as JPEG are not exactly 0/255.
overlay = image.copy()
overlay[pred_mask > 127] = [255, 0, 0]  # red = predicted ulcer

# -------------------------------
# DISPLAY COMPARISON
# -------------------------------
fig, axes = plt.subplots(1, 4, figsize=(20, 5))
panels = [
    ("Test Image", image, None),
    ("Ground Truth Mask", gt_mask, "gray"),
    ("Predicted Mask", pred_mask, "gray"),
    ("Overlay (Prediction)", overlay, None),
]
for ax, (title, data, cmap) in zip(axes, panels):
    ax.set_title(title)
    ax.imshow(data, cmap=cmap)
    ax.axis("off")

plt.show()


In [None]:
## 16. Training Loss Curve
import matplotlib.pyplot as plt

# Per-epoch training losses, presumably copied from the "Epoch N: Loss = x"
# printout of the training cell (30 epochs, 4 decimals).
# NOTE: hard-coded — these values do not update if training is re-run.
losses = [
    0.2283, 0.1364, 0.1136, 0.1035, 0.0954,
    0.0923, 0.0854, 0.0844, 0.0821, 0.0759,
    0.0755, 0.0745, 0.0704, 0.0666, 0.0682,
    0.0650, 0.0651, 0.0606, 0.0634, 0.0597,
    0.0576, 0.0561, 0.0577, 0.0533, 0.0528,
    0.0514, 0.0499, 0.0524, 0.0481, 0.0484
]
# Derive the x-axis from the data so the two can never get out of sync.
epochs = list(range(1, len(losses) + 1))

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(epochs, losses, marker='o')
ax.set_xlabel("Epoch")
ax.set_ylabel("Training Loss")
ax.set_title("U-Net Training Loss Curve")
ax.grid(True)
plt.show()


In [None]:
import numpy as np

def dice_score(pred, gt, threshold=0):
    """
    Dice coefficient between two masks: 2·|P ∩ G| / (|P| + |G|).

    Args:
        pred: predicted mask array (e.g. uint8 0/255 or bool).
        gt: ground-truth mask array with the same shape as ``pred``.
        threshold: pixels with value > threshold count as foreground. The
            default 0 treats every nonzero pixel as foreground; a higher
            value (e.g. 127) ignores faint JPEG artifacts.

    Returns:
        Dice score as a Python float in [0, 1]. Two empty masks agree
        perfectly and score 1.0 rather than 0.0.

    Raises:
        ValueError: if the masks differ in shape (NumPy would otherwise
            silently broadcast some mismatched shapes).
    """
    pred = np.asarray(pred) > threshold
    gt = np.asarray(gt) > threshold
    if pred.shape != gt.shape:
        raise ValueError(f"Shape mismatch: pred {pred.shape} vs gt {gt.shape}")

    total = int(pred.sum()) + int(gt.sum())
    if total == 0:
        return 1.0
    intersection = int(np.logical_and(pred, gt).sum())
    return 2.0 * intersection / total


In [None]:
## 17. Quantitative Evaluation
# Dice overlap between the sample's predicted and ground-truth masks
# (range [0, 1]; higher means better overlap).
dice = dice_score(pred=pred_mask, gt=gt_mask)
print("Dice Score:", dice)
