In [1]:
import os
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm

In [2]:
class DoubleConv(nn.Module):
    """Two consecutive Conv(3x3, padding=1) -> BatchNorm -> ReLU stages.

    The first stage maps ``in_channels`` to ``out_channels``; the second
    keeps ``out_channels``. Padding of 1 with a 3x3 kernel leaves the
    spatial size of the input unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # Stage 1 consumes the caller's channel count, stage 2 consumes the
        # output of stage 1. Layers are created in execution order.
        for stage_in in (in_channels, out_channels):
            stages.append(nn.Conv2d(stage_in, out_channels, 3, padding=1))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to ``x`` and return the result."""
        features = self.block(x)
        return features


In [3]:
# NOTE(review): stray module-level copy of DoubleConv.forward. It is not
# attached to any class and nothing in the visible cells calls it, so it
# looks like a copy/paste leftover -- candidate for deletion.
def forward(self, x):
        return self.block(x)

class UNet(nn.Module):
    """Four-level U-Net built from ``DoubleConv`` blocks.

    The encoder halves the spatial size four times (2x2 max-pool) while
    widening the channels 64 -> 128 -> 256 -> 512, followed by a
    1024-channel bottleneck. The decoder mirrors it with stride-2
    transposed convolutions and concatenates each upsampled map with the
    encoder output of the same level (skip connection).

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the output map. The final layer
            is a plain 1x1 convolution, so the output is raw logits.

    The output has the same height and width as the input. Sizes that are
    not multiples of 16 are supported: the upsampled tensor is zero-padded
    to the skip connection's size before concatenation.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        # Encoder
        self.enc1 = DoubleConv(in_channels, 64)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)
        self.enc4 = DoubleConv(256, 512)

        self.pool = nn.MaxPool2d(2)

        self.bottleneck = DoubleConv(512, 1024)

        # Decoder: each dec* block receives upsampled + skip channels,
        # hence twice the channel count produced by the matching up* layer.
        self.up1 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec1 = DoubleConv(1024, 512)
        self.up2 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec2 = DoubleConv(512, 256)
        self.up3 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec3 = DoubleConv(256, 128)
        self.up4 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec4 = DoubleConv(128, 64)

        self.out = nn.Conv2d(64, out_channels, kernel_size=1)

    @staticmethod
    def _merge_skip(upsampled, skip):
        """Concatenate ``upsampled`` with ``skip`` along the channel axis.

        Max-pooling floors odd sizes, so after the stride-2 upsampling the
        decoder map can be one pixel smaller than the encoder map. In that
        case ``upsampled`` is zero-padded to the skip's height/width first.
        When the sizes already match this is a plain ``torch.cat``.
        """
        pad_h = skip.size(2) - upsampled.size(2)
        pad_w = skip.size(3) - upsampled.size(3)
        if pad_h != 0 or pad_w != 0:
            # F.pad order for the last two dims: (left, right, top, bottom).
            upsampled = nn.functional.pad(
                upsampled,
                [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
            )
        return torch.cat([upsampled, skip], dim=1)

    def forward(self, x):
        """Return per-pixel logits with the same height/width as ``x``."""
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        b = self.bottleneck(self.pool(e4))

        d1 = self.dec1(self._merge_skip(self.up1(b), e4))
        d2 = self.dec2(self._merge_skip(self.up2(d1), e3))
        d3 = self.dec3(self._merge_skip(self.up3(d2), e2))
        d4 = self.dec4(self._merge_skip(self.up4(d3), e1))

        return self.out(d4)

In [4]:
# -----------------------------
# Custom Dataset (images and masks in separate directories)
# -----------------------------
class SegmentationDataset(Dataset):
    """Paired image / binary-mask dataset read from two directories.

    Every ``*.png`` file in ``image_dir`` is one sample. Its mask is looked
    up in ``mask_dir`` under the same name with ``_mask`` inserted before
    the extension (``foo.png`` -> ``foo_mask.png``).

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the corresponding masks.
        image_transform: Optional callable applied to the RGB PIL image.
        mask_transform: Optional callable applied to the grayscale PIL mask.
            If it is omitted (or does not return a tensor) the mask is
            converted with ``transforms.ToTensor()`` so it can be thresholded.

    ``__getitem__`` returns ``(image, mask)`` where ``mask`` is a float
    tensor holding only 0.0 and 1.0.
    """

    IMAGE_EXT = '.png'
    MASK_SUFFIX = '_mask'

    def __init__(self, image_dir, mask_dir, image_transform=None, mask_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        # Sorted so that sample indices are stable across runs.
        self.image_names = sorted([f for f in os.listdir(image_dir) if f.endswith('.png')])
        self.image_transform = image_transform
        self.mask_transform = mask_transform

    def __len__(self):
        return len(self.image_names)

    @classmethod
    def _mask_name(cls, image_name):
        """Return the mask file name that belongs to ``image_name``.

        Only the trailing extension is rewritten, so a name that happens to
        contain ``.png`` elsewhere (e.g. ``scan.png_01.png``) is handled
        correctly.
        """
        if image_name.endswith(cls.IMAGE_EXT):
            stem = image_name[:-len(cls.IMAGE_EXT)]
            return stem + cls.MASK_SUFFIX + cls.IMAGE_EXT
        # Names without the expected extension keep the previous behaviour.
        return image_name.replace('.png', '_mask.png')

    def __getitem__(self, idx):
        image_name = self.image_names[idx]
        mask_name = self._mask_name(image_name)

        image_path = os.path.join(self.image_dir, image_name)
        mask_path = os.path.join(self.mask_dir, mask_name)

        # convert() returns a fully loaded copy, so the files can be closed
        # immediately instead of waiting for garbage collection.
        with Image.open(image_path) as image_file:
            image = image_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")

        if self.image_transform:
            image = self.image_transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)

        # Without a tensor-producing mask transform the comparison below
        # would be evaluated on a PIL image and raise TypeError.
        if not torch.is_tensor(mask):
            mask = transforms.ToTensor()(mask)

        # Binarise, then keep exactly one channel: (1, H, W).
        mask = (mask > 0.5).float()
        if mask.ndim == 3:
            mask = mask[0].unsqueeze(0)

        return image, mask


In [5]:
# -----------------------------
# Evaluation Dataset (no masks)
# -----------------------------
class EvaluationDataset(Dataset):
    """Image-only dataset for inference.

    Lists every ``*.png`` file in ``image_dir`` (sorted by name) and yields
    ``(image, file_name)`` pairs so predictions can be saved under the same
    name as their input.
    """

    def __init__(self, image_dir, transform=None):
        png_files = [entry for entry in os.listdir(image_dir) if entry.endswith('.png')]
        png_files.sort()

        self.image_dir = image_dir
        self.image_names = png_files
        self.transform = transform

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        file_name = self.image_names[idx]
        rgb = Image.open(os.path.join(self.image_dir, file_name)).convert("RGB")
        # The transform is optional; without it the PIL image is returned as is.
        sample = self.transform(rgb) if self.transform else rgb
        return sample, file_name


In [6]:
# Transforms: resize to a fixed resolution and convert to a tensor.
# No cropping and no normalization are applied.
IMAGE_SIZE = (512, 512)  # (height, width) shared by images and masks

image_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

# Masks go through the same resize so they stay pixel-aligned with the images.
mask_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])


In [7]:
# -----------------------------
# Training
# -----------------------------
def train(model, dataloader, optimizer, loss_fn, device):
    """Run one training epoch and return the mean loss per batch.

    Args:
        model: Network to optimise; switched to training mode here.
        dataloader: Iterable of ``(images, masks)`` batches that supports ``len()``.
        optimizer: Optimizer stepping the model's parameters.
        loss_fn: Callable ``loss_fn(outputs, masks)`` returning a scalar tensor.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Sum of the per-batch loss values divided by ``len(dataloader)``.
    """
    model.train()
    epoch_loss = 0.0
    progress = tqdm(dataloader, desc="Training", leave=False)

    for batch_images, batch_masks in progress:
        batch_images = batch_images.to(device)
        batch_masks = batch_masks.to(device)

        # Forward pass
        batch_loss = loss_fn(model(batch_images), batch_masks)

        # Backward pass: clear stale gradients, back-propagate, update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        epoch_loss = epoch_loss + batch_loss.item()

    num_batches = len(dataloader)
    return epoch_loss / num_batches

In [12]:
def predict_and_save(model, eval_loader, device, output_dir,
                     output_size=(1280, 720), threshold=0.5):
    """Predict a binary mask for every evaluation image and save it as an image file.

    Args:
        model: Network producing one logit map per image; switched to eval mode here.
        eval_loader: Iterable of ``(images, filenames)`` batches.
        device: Device the image batches are moved to.
        output_dir: Directory the masks are written to (created if missing).
            Each mask is saved under the file name of its input image.
        output_size: ``(width, height)`` the saved mask is resized to, or
            ``None`` to keep the network's output resolution.
        threshold: Probability above which a pixel is marked foreground.

    Saved masks are 8-bit grayscale images containing only 0 and 255.
    """
    os.makedirs(output_dir, exist_ok=True)
    model.eval()
    with torch.no_grad():
        for images, filenames in tqdm(eval_loader, desc="Predicting"):
            images = images.to(device)
            probs = torch.sigmoid(model(images))
            # uint8 values in {0, 255}, ready for Image.fromarray.
            binary = (probs > threshold).to(torch.uint8) * 255

            for pred, filename in zip(binary, filenames):
                # squeeze(0) drops only the channel axis; a bare squeeze()
                # would also collapse a height or width of 1.
                pred_mask = Image.fromarray(pred.squeeze(0).cpu().numpy())
                if output_size is not None:
                    # Nearest-neighbour keeps the mask strictly binary; an
                    # interpolating filter would add gray values along edges.
                    pred_mask = pred_mask.resize(output_size, resample=Image.NEAREST)
                pred_mask.save(os.path.join(output_dir, filename))

In [13]:
def main(
    image_dir=r"D:\\Indira\\Pooja_segmentation\\input_image",
    mask_dir=r"D:\\Indira\\Pooja_segmentation\\input_mask",
    eval_image_dir=r"D:\\Indira\\Pooja_segmentation\\eval_img",
    output_dir=r"D:\\Indira\\Pooja_segmentation\\output_img",
    batch_size=4,
    epochs=20,
    lr=1e-4,
    checkpoint_path="unet_model.pth",
):
    """Train the U-Net, save its weights, then write predictions for the eval images.

    Calling ``main()`` with no arguments reproduces the original hard-coded run.

    Args:
        image_dir: Directory with the training images.
        mask_dir: Directory with the training masks.
        eval_image_dir: Directory with the images to predict on after training.
        output_dir: Directory the predicted masks are written to.
        batch_size: Batch size for both the training and the evaluation loader.
        epochs: Number of passes over the training set.
        lr: Learning rate for Adam.
        checkpoint_path: File the trained ``state_dict`` is saved to.

    NOTE(review): the default paths are raw strings, so every ``\\`` in them
    is two literal backslashes. They are kept exactly as they were; consider
    moving them to a config cell with ``pathlib.Path``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Training data
    dataset = SegmentationDataset(image_dir, mask_dir, image_transform, mask_transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Single-channel logits + BCEWithLogitsLoss: the sigmoid is applied
    # inside the loss during training and explicitly at prediction time.
    model = UNet(in_channels=3, out_channels=1).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        loss = train(model, dataloader, optimizer, loss_fn, device)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss:.4f}")

    torch.save(model.state_dict(), checkpoint_path)

    # Evaluation and saving predictions (order preserved so that file names
    # line up with their predictions).
    eval_dataset = EvaluationDataset(eval_image_dir, transform=image_transform)
    eval_loader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)
    predict_and_save(model, eval_loader, device, output_dir)

In [14]:

# Entry point: run the full train-then-predict pipeline defined in main().
if __name__ == '__main__':
    main()

                                                         

Epoch 1/20, Loss: 0.4282


                                                         

Epoch 2/20, Loss: 0.2921


                                                         

Epoch 3/20, Loss: 0.2463


                                                         

Epoch 4/20, Loss: 0.2153


                                                         

Epoch 5/20, Loss: 0.2025


                                                         

Epoch 6/20, Loss: 0.1986


                                                         

Epoch 7/20, Loss: 0.1867


                                                         

Epoch 8/20, Loss: 0.1785


                                                         

Epoch 9/20, Loss: 0.1723


                                                         

Epoch 10/20, Loss: 0.1684


                                                         

Epoch 11/20, Loss: 0.1711


                                                         

Epoch 12/20, Loss: 0.1688


                                                         

Epoch 13/20, Loss: 0.1588


                                                         

Epoch 14/20, Loss: 0.1484


                                                         

Epoch 15/20, Loss: 0.1432


                                                         

Epoch 16/20, Loss: 0.1425


                                                         

Epoch 17/20, Loss: 0.1344


                                                         

Epoch 18/20, Loss: 0.1283


                                                         

Epoch 19/20, Loss: 0.1253


                                                         

Epoch 20/20, Loss: 0.1210


Predicting: 100%|██████████| 3/3 [00:01<00:00,  2.55it/s]
