In [1]:
import os
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
from torchvision import transforms
from torch.optim.lr_scheduler import StepLR
import numpy as np



In [2]:
# -----------------------------
# U-Net++ Architecture
# -----------------------------
class ConvBlock(nn.Module):
    """Two stacked 3x3 convolutions, each followed by batch norm and ReLU.

    Both convolutions use ``padding=1`` with a 3x3 kernel, so only the channel
    count changes: ``in_channels`` -> ``out_channels`` -> ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Layers are created in execution order and packed into ``self.conv``
        # so the sub-module names stay ``conv.0`` ... ``conv.5``.
        stages = []
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv / norm / activation stages."""
        features = self.conv(x)
        return features

In [3]:
class UNetPP(nn.Module):
    """U-Net++ style network with nested, densely connected skip pathways.

    Naming: ``conv{i}_{j}`` / ``x{i}_{j}`` is the node at depth ``i`` and
    column ``j``. Depth ``i`` works at 1/2**i of the input resolution with
    64, 128, 256, 512 and 1024 channels for ``i`` = 0..4. ``up{i}_{j}``
    upsamples node ``x{i}_{j}`` by 2 (transposed conv) and halves its channels
    so it can be concatenated with every earlier node of depth ``i - 1``.

    Only the last node ``x0_4`` feeds the output head, and ``forward`` returns
    the raw output of a 1x1 convolution (no sigmoid / softmax is applied here).

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the final 1x1 convolution.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()

        # One shared 2x2 max-pool. It has no parameters or buffers, so it adds
        # nothing to ``state_dict`` and existing checkpoints still load.
        self.pool = nn.MaxPool2d(2)

        # Encoder backbone (column 0).
        self.conv0_0 = ConvBlock(in_channels, 64)
        self.conv1_0 = ConvBlock(64, 128)
        self.conv2_0 = ConvBlock(128, 256)
        self.conv3_0 = ConvBlock(256, 512)
        self.conv4_0 = ConvBlock(512, 1024)

        # Column 1: each node sees its same-depth predecessor plus one
        # upsampled tensor from the depth below.
        self.up1_0 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.up2_0 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.up3_0 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.up4_0 = nn.ConvTranspose2d(1024, 512, 2, stride=2)

        self.conv0_1 = ConvBlock(64 + 64, 64)
        self.conv1_1 = ConvBlock(128 + 128, 128)
        self.conv2_1 = ConvBlock(256 + 256, 256)
        self.conv3_1 = ConvBlock(512 + 512, 512)

        # Column 2: two same-depth predecessors + one upsampled tensor.
        self.up1_1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.up2_1 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.up3_1 = nn.ConvTranspose2d(512, 256, 2, stride=2)

        self.conv0_2 = ConvBlock(64 * 3, 64)
        self.conv1_2 = ConvBlock(128 * 3, 128)
        self.conv2_2 = ConvBlock(256 * 3, 256)

        # Column 3: three same-depth predecessors + one upsampled tensor.
        self.up1_2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.up2_2 = nn.ConvTranspose2d(256, 128, 2, stride=2)

        self.conv0_3 = ConvBlock(64 * 4, 64)
        self.conv1_3 = ConvBlock(128 * 4, 128)

        # Column 4: four same-depth predecessors + one upsampled tensor.
        self.up1_3 = nn.ConvTranspose2d(128, 64, 2, stride=2)

        self.conv0_4 = ConvBlock(64 * 5, 64)

        # 1x1 projection from the 64 feature channels to the output channels.
        self.final = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        """Compute the output map for a batch ``x``.

        Args:
            x: Input tensor whose last two dimensions are height and width.
                Both must be multiples of 16: the input is pooled four times
                by a factor of 2, and every upsampled tensor has to match the
                size of the skip tensors it is concatenated with.

        Returns:
            The output of the final 1x1 convolution applied to node ``x0_4``.

        Raises:
            ValueError: If height or width is not a multiple of 16.
        """
        height, width = x.shape[-2:]
        if height % 16 or width % 16:
            # Without this check the failure surfaces as a size-mismatch error
            # inside torch.cat, far from the actual cause.
            raise ValueError(
                "UNetPP expects input height and width to be multiples of 16, "
                f"got {height}x{width}"
            )

        x0_0 = self.conv0_0(x)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.up1_0(x1_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.up2_0(x2_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up1_1(x1_1)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.up3_0(x3_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up2_1(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up1_2(x1_2)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up4_0(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up3_1(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up2_2(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up1_3(x1_3)], 1))

        return self.final(x0_4)

In [4]:
class SegmentationDataset(Dataset):
    """Image / binary-mask pairs for segmentation training.

    Images are the files in ``image_dir`` whose name ends with ``.png``
    (sorted by name). The mask for ``<stem>.png`` is read from
    ``mask_dir/<stem>_mask.png``.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the matching ``*_mask.png`` files.
        transform: Optional callable applied to both the image and the mask.
            It must return tensors. When omitted, both are converted with
            ``transforms.ToTensor()``.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_names = sorted(f for f in os.listdir(image_dir) if f.endswith('.png'))
        self.transform = transform

    def __len__(self):
        """Number of images found in ``image_dir``."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx``; the mask is 0/1 floats."""
        image_name = self.image_names[idx]
        # Insert "_mask" before the extension only. str.replace would rewrite
        # every ".png" occurrence in the name (e.g. "a.png.copy.png").
        stem, ext = os.path.splitext(image_name)
        mask_name = f"{stem}_mask{ext}"

        image = Image.open(os.path.join(self.image_dir, image_name)).convert("RGB")
        mask = Image.open(os.path.join(self.mask_dir, mask_name)).convert("L")

        if self.transform:
            # Replay the same random stream for image and mask so random
            # augmentations stay aligned. Restoring the captured state (rather
            # than reseeding from a small integer) leaves the global generator
            # advancing normally and does not depend on NumPy's RNG.
            # Assumes the transform draws from torch's default CPU generator.
            rng_state = torch.get_rng_state()
            image = self.transform(image)

            torch.set_rng_state(rng_state)
            mask = self.transform(mask)
        else:
            # Without a transform the PIL images must still become tensors,
            # otherwise the threshold below raises a TypeError.
            to_tensor = transforms.ToTensor()
            image = to_tensor(image)
            mask = to_tensor(mask)

        # Binarise the mask (values after resizing / rotation may be
        # fractional).
        mask = (mask > 0.5).float()
        return image, mask


In [5]:
# Training Loop
# -----------------------------
def train(model, dataloader, optimizer, loss_fn, device):
    """Run one optimisation pass over ``dataloader``.

    Puts ``model`` in training mode, performs one optimizer step per batch and
    returns the mean of the per-batch loss values.

    Args:
        model: Module being optimised.
        dataloader: Iterable of ``(images, masks)`` batches that supports ``len``.
        optimizer: Optimizer holding ``model``'s parameters.
        loss_fn: Callable ``loss_fn(outputs, masks)`` returning a scalar tensor.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Sum of the batch losses divided by ``len(dataloader)``.
    """
    model.train()
    epoch_loss = 0.0

    progress = tqdm(dataloader, desc="Training", leave=False)
    for batch_images, batch_masks in progress:
        batch_images, batch_masks = batch_images.to(device), batch_masks.to(device)

        # Forward pass and loss.
        batch_loss = loss_fn(model(batch_images), batch_masks)

        # Clear stale gradients, back-propagate, update the weights.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        epoch_loss += batch_loss.item()

    num_batches = len(dataloader)
    return epoch_loss / num_batches

In [6]:
def main(
    image_dir=r"D:\Indira\Pooja_segmentation\input_image",
    mask_dir=r"D:\Indira\Pooja_segmentation\input_mask",
    batch_size=2,
    epochs=50,
    lr=1e-5,
    model_path="unetpp_model.pth",
):
    """Train ``UNetPP`` on an image/mask folder pair and save its weights.

    All arguments default to the values that used to be hard-coded, so
    ``main()`` behaves as before.

    Args:
        image_dir: Directory with the training images.
        mask_dir: Directory with the matching masks.
        batch_size: Samples per optimisation step.
        epochs: Number of passes over the dataset.
        lr: Initial Adam learning rate.
        model_path: Where the trained ``state_dict`` is written.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Augmentation pipeline. SegmentationDataset applies it to both the image
    # and the mask.
    train_transforms = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(degrees=15),
        transforms.ToTensor()
    ])

    dataset = SegmentationDataset(image_dir, mask_dir, transform=train_transforms)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = UNetPP(in_channels=3, out_channels=1).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Halve the learning rate every 5 epochs.
    # NOTE(review): with the defaults this reaches lr * 0.5**9 by epoch 50;
    # confirm the decay is not too aggressive for this dataset.
    scheduler = StepLR(optimizer, step_size=5, gamma=0.5)
    # The model returns raw outputs; the sigmoid is folded into the loss.
    loss_fn = nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        loss = train(model, dataloader, optimizer, loss_fn, device)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss:.4f}")
        scheduler.step()

    torch.save(model.state_dict(), model_path)


In [7]:
# Entry point: start the training run defined by main() above.
if __name__ == "__main__":
    main()

Using device: cuda


                                                         

Epoch 1/50, Loss: 0.4366


                                                         

Epoch 2/50, Loss: 0.3481


                                                         

Epoch 3/50, Loss: 0.3070


                                                         

Epoch 4/50, Loss: 0.2739


                                                         

Epoch 5/50, Loss: 0.2612


                                                         

Epoch 6/50, Loss: 0.2397


                                                         

Epoch 7/50, Loss: 0.2339


                                                         

Epoch 8/50, Loss: 0.2266


                                                         

Epoch 9/50, Loss: 0.2248


                                                         

Epoch 10/50, Loss: 0.2191


                                                         

Epoch 11/50, Loss: 0.2155


                                                         

Epoch 12/50, Loss: 0.2148


                                                         

Epoch 13/50, Loss: 0.2139


                                                         

Epoch 14/50, Loss: 0.2108


                                                         

Epoch 15/50, Loss: 0.2101


                                                         

Epoch 16/50, Loss: 0.2093


                                                         

Epoch 17/50, Loss: 0.2066


                                                         

Epoch 18/50, Loss: 0.2083


                                                         

Epoch 19/50, Loss: 0.2062


                                                         

Epoch 20/50, Loss: 0.2076


                                                         

Epoch 21/50, Loss: 0.2055


                                                         

Epoch 22/50, Loss: 0.2066


                                                         

Epoch 23/50, Loss: 0.2054


                                                         

Epoch 24/50, Loss: 0.2047


                                                         

Epoch 25/50, Loss: 0.2039


                                                         

Epoch 26/50, Loss: 0.2050


                                                         

Epoch 27/50, Loss: 0.2048


                                                         

Epoch 28/50, Loss: 0.2025


                                                         

Epoch 29/50, Loss: 0.2038


                                                         

Epoch 30/50, Loss: 0.2046


                                                         

Epoch 31/50, Loss: 0.2045


                                                         

Epoch 32/50, Loss: 0.2039


                                                         

Epoch 33/50, Loss: 0.2030


                                                         

Epoch 34/50, Loss: 0.2031


                                                         

Epoch 35/50, Loss: 0.2032


                                                         

Epoch 36/50, Loss: 0.2036


                                                         

Epoch 37/50, Loss: 0.2014


                                                         

Epoch 38/50, Loss: 0.2034


                                                         

Epoch 39/50, Loss: 0.2037


                                                         

Epoch 40/50, Loss: 0.2031


                                                         

Epoch 41/50, Loss: 0.2039


                                                         

Epoch 42/50, Loss: 0.2033


                                                         

Epoch 43/50, Loss: 0.2019


                                                         

Epoch 44/50, Loss: 0.2026


                                                         

Epoch 45/50, Loss: 0.2037


                                                         

Epoch 46/50, Loss: 0.2030


                                                         

Epoch 47/50, Loss: 0.2026


                                                         

Epoch 48/50, Loss: 0.2018


                                                         

Epoch 49/50, Loss: 0.2026


                                                         

Epoch 50/50, Loss: 0.2033


In [8]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F

In [9]:
# Custom Evaluation Dataset
# -----------------------------
class EvaluationDataset(Dataset):
    """Image / mask / filename triples for evaluation (no augmentation).

    Images are the files in ``image_dir`` whose name ends with ``.png``
    (sorted by name). The mask for ``<stem>.png`` is read from
    ``mask_dir/<stem>_mask.png``. Both are resized to 512x512 and converted to
    tensors; the mask is then thresholded at 0.5.

    Args:
        image_dir: Directory containing the evaluation images.
        mask_dir: Directory containing the matching ``*_mask.png`` files.
    """

    def __init__(self, image_dir, mask_dir):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_names = sorted(f for f in os.listdir(image_dir) if f.endswith('.png'))
        self.resize = transforms.Resize((512, 512))
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        """Number of images found in ``image_dir``."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Return ``(image, mask, filename)`` for sample ``idx``."""
        filename = self.image_names[idx]
        # Insert "_mask" before the extension only. str.replace would rewrite
        # every ".png" occurrence in the name (e.g. "scan.png_v2.png").
        stem, ext = os.path.splitext(filename)
        mask_filename = f"{stem}_mask{ext}"

        image = Image.open(os.path.join(self.image_dir, filename)).convert("RGB")
        mask = Image.open(os.path.join(self.mask_dir, mask_filename)).convert("L")

        image_resized = self.to_tensor(self.resize(image))
        mask_resized = self.to_tensor(self.resize(mask))
        # Binarise the mask (resizing can introduce fractional values).
        mask_resized = (mask_resized > 0.5).float()

        return image_resized, mask_resized, filename

In [10]:
# IoU and Dice Score Functions
# -----------------------------
def compute_iou(pred, target):
    """Smoothed intersection-over-union of two masks.

    ``pred`` and ``target`` are multiplied element-wise to get the overlap;
    a small constant is added to numerator and denominator so two empty masks
    score 1 instead of dividing by zero.
    """
    eps = 1e-6
    overlap = (pred * target).sum()
    pred_total = pred.sum()
    target_total = target.sum()
    union = pred_total + target_total - overlap
    return (overlap + eps) / (union + eps)

def compute_dice(pred, target):
    """Smoothed Dice coefficient of two masks.

    Computes ``(2 * overlap + eps) / (pred.sum() + target.sum() + eps)`` where
    the overlap is the sum of the element-wise product.
    """
    eps = 1e-6
    overlap = (pred * target).sum()
    numerator = 2. * overlap + eps
    denominator = pred.sum() + target.sum() + eps
    return numerator / denominator


In [11]:
# Prediction and Evaluation
# -----------------------------
def predict_and_evaluate(model, dataloader, device, output_dir):
    """Predict masks, save them as images and report mean IoU / Dice.

    For every batch the model output is passed through a sigmoid and
    thresholded at 0.5. Each predicted mask is scored against its ground-truth
    mask and written to ``output_dir`` under the input file's name, as an
    8-bit image with values 0 / 255.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        dataloader: Yields ``(images, masks, filenames)`` batches.
        device: Device the batches are moved to.
        output_dir: Directory for the predicted masks (created if missing).

    Returns:
        ``(avg_iou, avg_dice)`` averaged over all samples. The values are also
        printed.

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    os.makedirs(output_dir, exist_ok=True)
    model.eval()
    total_iou = 0.0
    total_dice = 0.0
    count = 0

    with torch.no_grad():
        for images, masks, filenames in tqdm(dataloader, desc="Evaluating"):
            images = images.to(device)
            masks = masks.to(device)

            # Raw outputs -> probabilities -> hard 0/1 mask.
            preds = (torch.sigmoid(model(images)) > 0.5).float()

            for pred, mask, filename in zip(preds, masks, filenames):
                total_iou += compute_iou(pred, mask).item()
                total_dice += compute_dice(pred, mask).item()
                count += 1

                # Save predicted mask
                pred_img = (pred.squeeze().cpu().numpy() * 255).astype(np.uint8)
                Image.fromarray(pred_img).save(os.path.join(output_dir, filename))

    if count == 0:
        # Previously this surfaced as a bare ZeroDivisionError.
        raise ValueError("dataloader yielded no samples; cannot compute average metrics")

    avg_iou = total_iou / count
    avg_dice = total_dice / count
    print(f"\nAverage IoU: {avg_iou:.4f}")
    print(f"Average Dice Score: {avg_dice:.4f}")
    return avg_iou, avg_dice

# -----------------------------
# Main

In [12]:
def main(
    image_dir=r"D:\Indira\Pooja_segmentation\eval_img",
    mask_dir=r"D:\Indira\Pooja_segmentation\eval_mask",
    output_dir=r"D:\Indira\Pooja_segmentation\output_unetpp_img",
    model_path="unetpp_model.pth",
    batch_size=4,
):
    """Load trained ``UNetPP`` weights and evaluate them on a held-out folder.

    All arguments default to the values that used to be hard-coded, so
    ``main()`` behaves as before. Note that this definition replaces the
    training ``main`` defined earlier in the notebook once this cell runs.

    Args:
        image_dir: Directory with the evaluation images.
        mask_dir: Directory with the matching ground-truth masks.
        output_dir: Directory where predicted masks are written.
        model_path: Path of the saved ``state_dict`` to load.
        batch_size: Samples per forward pass.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    dataset = EvaluationDataset(image_dir, mask_dir)
    # Keep file order stable so results are comparable between runs.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    model = UNetPP(in_channels=3, out_channels=1).to(device)
    # NOTE(security): torch.load unpickles the file; only load checkpoints from
    # a trusted source (newer PyTorch versions offer weights_only=True).
    model.load_state_dict(torch.load(model_path, map_location=device))

    predict_and_evaluate(model, dataloader, device, output_dir)

In [13]:
# Entry point: run the evaluation defined by main() above.
if __name__ == "__main__":
    main()

Using device: cuda


Evaluating: 100%|██████████| 3/3 [00:04<00:00,  1.55s/it]


Average IoU: 0.8604
Average Dice Score: 0.9248



