In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os
import numpy as np
import tqdm


class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.enc1 = self.contracting_block(3, 64)
        self.enc2 = self.contracting_block(64, 128)
        self.enc3 = self.contracting_block(128, 256)
        self.enc4 = self.contracting_block(256, 512)
        self.enc5 = self.contracting_block(512, 1024)

        self.up5 = self.expanding_block(1024, 512)
        self.up4 = self.expanding_block(1024, 256)
        self.up3 = self.expanding_block(512, 128)
        self.up2 = self.expanding_block(256, 64)
        self.final_conv = nn.Conv2d(128, 1, kernel_size=1)

    def contracting_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU()
        )

    def expanding_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU()
        )

    def crop_and_concat(self, upsampled, bypass):
        return torch.cat((upsampled, bypass), dim=1)

    def forward(self, x):
        enc1 = self.enc1(x)
        enc2 = self.enc2(nn.MaxPool2d(2)(enc1))
        enc3 = self.enc3(nn.MaxPool2d(2)(enc2))
        enc4 = self.enc4(nn.MaxPool2d(2)(enc3))
        enc5 = self.enc5(nn.MaxPool2d(2)(enc4))

        up5 = self.crop_and_concat(self.up5(enc5), enc4)
        up4 = self.crop_and_concat(self.up4(up5), enc3)
        up3 = self.crop_and_concat(self.up3(up4), enc2)
        up2 = self.crop_and_concat(self.up2(up3), enc1)

        output = torch.sigmoid(self.final_conv(up2))
        return output


class SegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_filenames = sorted(os.listdir(image_dir))
        self.mask_filenames = sorted(os.listdir(mask_dir))

        if self.image_filenames != self.mask_filenames:
            raise ValueError("Image and mask filenames do not match!")

        self.transform = transform

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_filenames[idx])
        image = Image.open(image_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        if self.transform:
            image = self.transform(image)
            mask =self.transform(mask)  # Separate transform for masks


        return image, mask

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

dataset = SegmentationDataset("./training_dataset/image", "./training_dataset/mask", transform=transform)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

model = UNet().to(device)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    for images, masks in dataloader:
        images, masks = images.to(device), masks.to(device)

        outputs = model(images)
        loss = criterion(outputs, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")

print("Training Complete!")

Epoch 1/100, Loss: 6.9346
Epoch 2/100, Loss: 6.9138
Epoch 3/100, Loss: 6.9117
Epoch 4/100, Loss: 6.9021
Epoch 5/100, Loss: 6.8944
Epoch 6/100, Loss: 6.8834
Epoch 7/100, Loss: 6.8591
Epoch 8/100, Loss: 6.8242
Epoch 9/100, Loss: 6.6687
Epoch 10/100, Loss: 6.5634
Epoch 11/100, Loss: 6.5928
Epoch 12/100, Loss: 6.5365
Epoch 13/100, Loss: 6.4461
Epoch 14/100, Loss: 6.5064
Epoch 15/100, Loss: 6.4632
Epoch 16/100, Loss: 6.4747
Epoch 17/100, Loss: 6.3575
Epoch 18/100, Loss: 6.3155
Epoch 19/100, Loss: 6.2835
Epoch 20/100, Loss: 6.3212
Epoch 21/100, Loss: 6.4838
Epoch 22/100, Loss: 6.3973
Epoch 23/100, Loss: 6.5653
Epoch 24/100, Loss: 6.2909
Epoch 25/100, Loss: 6.2588
Epoch 26/100, Loss: 6.2575
Epoch 27/100, Loss: 6.1635
Epoch 28/100, Loss: 6.3339
Epoch 29/100, Loss: 6.2724
Epoch 30/100, Loss: 6.2152
Epoch 31/100, Loss: 6.1807
Epoch 32/100, Loss: 6.0921
Epoch 33/100, Loss: 5.9941
Epoch 34/100, Loss: 6.1166
Epoch 35/100, Loss: 6.0382
Epoch 36/100, Loss: 6.0893
Epoch 37/100, Loss: 5.9581
Epoch 38/1