In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os
import numpy as np
import re


def natural_sort_key(filename):
    """Build a sort key that orders embedded numbers numerically.

    The string is split into alternating text and digit runs, and each digit
    run is converted to ``int`` so that, for example, ``"2.png"`` sorts before
    ``"10.png"``.

    Args:
        filename: File name (or any string) to build a key for.

    Returns:
        list: Alternating ``str`` / ``int`` pieces, always starting with a
        (possibly empty) ``str`` piece.
    """
    # Raw string: "\d" inside a normal literal is an invalid escape sequence.
    # isdecimal() accepts exactly the characters "\d" matches, whereas
    # isdigit() is also true for characters such as "²" that int() rejects.
    return [
        int(piece) if piece.isdecimal() else piece
        for piece in re.split(r"(\d+)", filename)
    ]


class SegmentationDataset(Dataset):
    """Paired image/mask dataset read from two parallel directories.

    Both directories are listed and naturally sorted; the two listings must be
    identical so that index ``i`` refers to the same file name in each.

    Every sample is resized to 256x256 and converted to a tensor. When
    ``transform`` is given, the image is additionally colour-jittered and the
    image and mask are stacked into one 4-channel tensor before ``transform``
    is applied, so both receive the same random geometric augmentation.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the masks, with matching file names.
        transform: Optional callable applied to the stacked
            ``(image, mask)`` tensor; ``None`` disables all augmentation.

    Raises:
        ValueError: If the sorted file listings of the two directories differ.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_filenames = sorted(os.listdir(image_dir), key=natural_sort_key)
        self.mask_filenames = sorted(os.listdir(mask_dir), key=natural_sort_key)

        if self.image_filenames != self.mask_filenames:
            raise ValueError("Image and mask filenames do not match!")

        # NOTE(review): the same Resize (with its default interpolation) is
        # applied to masks, so mask values may not stay strictly binary.
        self.preprocess = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
        ])
        self.transform = transform
        self.color_transform = transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.05, hue=0.05)

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.image_filenames)

    @staticmethod
    def _load(path, mode):
        """Open ``path``, convert it to ``mode`` and release the file handle."""
        with Image.open(path) as source:
            return source.convert(mode)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` tensor pair at position ``idx``.

        The image is loaded as RGB (3 channels) and the mask as single-channel
        greyscale.
        """
        image = self._load(os.path.join(self.image_dir, self.image_filenames[idx]), "RGB")
        mask = self._load(os.path.join(self.mask_dir, self.mask_filenames[idx]), "L")

        image = self.preprocess(image)
        mask = self.preprocess(mask)

        if self.transform:
            # Colour jitter only makes sense for the image, not the mask.
            image = self.color_transform(image)

            # Stack along the channel axis so one random draw transforms both.
            stacked = self.transform(torch.cat((image, mask), dim=0))
            image = stacked[0:3]
            mask = stacked[3].unsqueeze(0)

        return image, mask


class UNet(nn.Module):
    """U-Net style encoder/decoder that outputs a per-pixel probability map.

    The encoder has five convolutional stages separated by 2x2 max pooling;
    the decoder has four transposed-convolution stages, each concatenated with
    the encoder feature map of matching resolution. A final 1x1 convolution
    followed by a sigmoid produces values in ``[0, 1]``.

    Because the input is halved four times and each decoder output is
    concatenated with an encoder map, the input height and width must be
    divisible by 16.

    Args:
        in_channels: Number of channels in the input image.
        out_channels: Number of channels in the output map.
        base_channels: Width of the first encoder stage; each deeper stage
            doubles it.
    """

    def __init__(self, in_channels=3, out_channels=1, base_channels=64):
        super().__init__()
        c = base_channels
        self.enc1 = self.contracting_block(in_channels, c)
        self.enc2 = self.contracting_block(c, c * 2)
        self.enc3 = self.contracting_block(c * 2, c * 4)
        self.enc4 = self.contracting_block(c * 4, c * 8)
        self.enc5 = self.contracting_block(c * 8, c * 16)

        # From up4 onwards each stage consumes the previous decoder output
        # concatenated with a skip connection, hence the doubled input width.
        self.up5 = self.expanding_block(c * 16, c * 8)
        self.up4 = self.expanding_block(c * 16, c * 4)
        self.up3 = self.expanding_block(c * 8, c * 2)
        self.up2 = self.expanding_block(c * 4, c)
        self.final_conv = nn.Conv2d(c * 2, out_channels, kernel_size=1)

    def contracting_block(self, in_channels, out_channels):
        """Return an encoder stage: two 3x3 conv + BatchNorm + ReLU layers with dropout between."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def expanding_block(self, in_channels, out_channels):
        """Return a decoder stage: 2x transposed-conv upsampling followed by two conv layers."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def crop_and_concat(self, upsampled, bypass):
        """Concatenate a decoder map with its skip connection along the channel axis.

        No cropping is performed; both tensors must already share spatial size.
        """
        return torch.cat((upsampled, bypass), dim=1)

    def forward(self, x):
        """Run the network on ``x`` and return sigmoid probabilities."""
        # Functional pooling avoids building a fresh nn.MaxPool2d per call.
        pool = nn.functional.max_pool2d

        enc1 = self.enc1(x)
        enc2 = self.enc2(pool(enc1, 2))
        enc3 = self.enc3(pool(enc2, 2))
        enc4 = self.enc4(pool(enc3, 2))
        enc5 = self.enc5(pool(enc4, 2))

        up5 = self.crop_and_concat(self.up5(enc5), enc4)
        up4 = self.crop_and_concat(self.up4(up5), enc3)
        up3 = self.crop_and_concat(self.up3(up4), enc2)
        up2 = self.crop_and_concat(self.up2(up3), enc1)

        return torch.sigmoid(self.final_conv(up2))


def calculate_iou(output, target, threshold=0.5):
    """Compute intersection-over-union between a prediction and a target.

    The prediction is binarised with ``output > threshold`` and the target with
    ``target > 0``; both are flattened, so the score is computed over all
    elements jointly (including the batch dimension, if any).

    Args:
        output: Tensor of predicted scores.
        target: Tensor with the same number of elements as ``output``.
        threshold: Cut-off above which a prediction counts as foreground.

    Returns:
        float: Intersection divided by union, or ``1.0`` when both the
        binarised prediction and target are empty (a message is printed in
        that case).
    """
    # reshape(-1) rather than view(-1): the comparison result can be
    # non-contiguous (e.g. for transposed inputs), which view() rejects.
    predicted = (output > threshold).reshape(-1)
    actual = (target > 0).reshape(-1)

    intersection = (predicted & actual).sum().float()
    union = (predicted | actual).sum().float()
    if union == 0:
        print("UNION SHOULD NOT BE ZERO, THERE MUST BE SOME MISTAKE")
        # Float, to match the type returned by the normal path.
        return 1.0

    return (intersection / union).item()


def tensor_to_required_image(image, input_path, input_index):
    """Binarise an array and turn it into an image sized from a reference file.

    The array is thresholded at 0.5 and scaled to 0/255. The result is resized
    to twice the reference image's width and the same height; the reference is
    read from ``{input_path}/{input_index}.png``.

    Args:
        image: 2-D NumPy array of scores (the caller passes a mask and a
            prediction concatenated side by side).
        input_path: Directory containing the reference image.
        input_index: File stem of the reference image (without ``.png``).

    Returns:
        PIL.Image.Image: The binarised, resized image.
    """
    binary = image > 0.5

    # Only the reference's dimensions are needed, so release the file at once.
    with Image.open(f"{input_path}/{input_index}.png") as ref_image:
        ref_width, ref_height = ref_image.size

    result = Image.fromarray((binary * 255).astype(np.uint8))
    return result.resize((ref_width * 2, ref_height))

In [2]:
# Seed torch's RNG so weight initialisation, shuffling and the random
# augmentations start from the same state on every "Restart & Run All".
SEED = 42
torch.manual_seed(SEED)

# Geometric augmentations; the dataset applies them to the stacked
# image + mask tensor so both receive the same random transformation.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.RandomResizedCrop(256, scale=(0.8, 1.0)),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
])

train_dir = "./training_dataset"
train_dataset = SegmentationDataset(f"{train_dir}/image", f"{train_dir}/mask", transform=train_transform)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# No transform for the test set: it is evaluated without augmentation.
test_dir = "./testing_dataset"
test_dataset = SegmentationDataset(f"{test_dir}/image", f"{test_dir}/mask")
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Kept so later cells that read `directory_path` still see the test-set path.
directory_path = test_dir

device = "cuda" if torch.cuda.is_available() else "cpu"

model = UNet().to(device)
# BCELoss expects probabilities, which the model's final sigmoid provides.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [3]:
# Each epoch runs a training pass followed by a validation pass on the test
# loader; the weights with the best validation mean IoU are checkpointed.
phases = ["train", "valid"]
data_loader = {"train": train_dataloader, "valid": test_dataloader}
best_mean_iou = 0
best_epoch = 0

num_epochs = 250
for epoch in range(num_epochs):
    for phase in phases:
        is_training = phase == "train"
        if is_training:
            model.train()
            epoch_loss = 0  # sum of per-batch losses for this epoch
        elif phase == "valid":
            model.eval()
            iou_scores = []

        # Gradients are only needed while training; disabling autograd during
        # validation avoids building (and holding) the computation graph.
        with torch.set_grad_enabled(is_training):
            for images, masks in data_loader[phase]:
                images, masks = images.to(device), masks.to(device)

                outputs = model(images)

                if is_training:
                    loss = criterion(outputs, masks)

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    epoch_loss += loss.item()
                elif phase == "valid":
                    iou = calculate_iou(outputs, masks)
                    iou_scores.append(iou)

        if phase == "valid":
            mean_iou = sum(iou_scores) / len(iou_scores)
            print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {epoch_loss:.4f}, Testing Mean IoU: {mean_iou:.4f}")

            if mean_iou > best_mean_iou:
                best_mean_iou = mean_iou
                # Stored 1-based so it matches the epoch numbers logged above.
                best_epoch = epoch + 1
                torch.save(model.state_dict(), 'best_epoch.pth')

print("Training Complete!")
print(f"Best Mean IoU: {best_mean_iou:.4f}, at epoch {best_epoch}")

Epoch 1/250, Training Loss: 6.8589, Testing Mean IoU: 0.4143
Epoch 2/250, Training Loss: 6.2515, Testing Mean IoU: 0.4143
Epoch 3/250, Training Loss: 5.8953, Testing Mean IoU: 0.4143
Epoch 4/250, Training Loss: 5.7828, Testing Mean IoU: 0.4147
Epoch 5/250, Training Loss: 5.7169, Testing Mean IoU: 0.4180
Epoch 6/250, Training Loss: 5.5618, Testing Mean IoU: 0.4596
Epoch 7/250, Training Loss: 5.4034, Testing Mean IoU: 0.4936
Epoch 8/250, Training Loss: 5.4333, Testing Mean IoU: 0.5215
Epoch 9/250, Training Loss: 5.2396, Testing Mean IoU: 0.4669
Epoch 10/250, Training Loss: 5.2546, Testing Mean IoU: 0.5324
Epoch 11/250, Training Loss: 5.2300, Testing Mean IoU: 0.5016
Epoch 12/250, Training Loss: 5.1629, Testing Mean IoU: 0.4728
Epoch 13/250, Training Loss: 5.1612, Testing Mean IoU: 0.5004
Epoch 14/250, Training Loss: 5.0116, Testing Mean IoU: 0.4826
Epoch 15/250, Training Loss: 4.8574, Testing Mean IoU: 0.4951
Epoch 16/250, Training Loss: 4.9174, Testing Mean IoU: 0.4856
Epoch 17/250, Tra

In [4]:
# Reload the best checkpoint and evaluate it on the test set, saving a
# side-by-side (ground-truth mask | prediction) image for every sample.
# map_location keeps the load working whichever device the checkpoint was
# saved from.
model.load_state_dict(torch.load('best_epoch.pth', map_location=device, weights_only=True))
model.eval()
iou_scores = []
output_dir = "./output"
os.makedirs(output_dir, exist_ok=True)

# Take the mask directory from the dataset itself instead of relying on
# whatever value a mutable path variable holds from an earlier cell.
reference_mask_dir = test_dataset.mask_dir

with torch.no_grad():
    for idx, (images, masks) in enumerate(test_dataloader):
        images, masks = images.to(device), masks.to(device)

        outputs = model(images)
        iou = calculate_iou(outputs, masks)

        iou_scores.append(iou)

        output_image_path = os.path.join(output_dir, f"output_{idx + 1}.png")

        # Concatenate along the width axis; squeeze() drops the batch and
        # channel axes (the test loader uses batch_size=1).
        output_and_mask = torch.cat((masks, outputs), dim=3).squeeze().cpu().numpy()
        # NOTE(review): assumes the test masks are named 1.png, 2.png, ... in
        # loader order — confirm against the dataset.
        output_and_mask_image = tensor_to_required_image(output_and_mask, reference_mask_dir, idx + 1)
        output_and_mask_image.save(output_image_path)

        print(f"Saved output and mask for image {idx + 1}. IoU: {iou:.4f}")

mean_iou = sum(iou_scores) / len(iou_scores)
print(f"Mean IoU: {mean_iou:.4f}")

Saved output and mask for image 1. IoU: 0.0000
Saved output and mask for image 2. IoU: 0.1864
Saved output and mask for image 3. IoU: 0.4665
Saved output and mask for image 4. IoU: 0.2504
Saved output and mask for image 5. IoU: 0.7739
Saved output and mask for image 6. IoU: 0.4680
Saved output and mask for image 7. IoU: 0.7683
Saved output and mask for image 8. IoU: 0.9501
Saved output and mask for image 9. IoU: 0.2811
Saved output and mask for image 10. IoU: 0.7462
Saved output and mask for image 11. IoU: 0.8927
Saved output and mask for image 12. IoU: 0.9932
Saved output and mask for image 13. IoU: 0.7065
Saved output and mask for image 14. IoU: 0.7853
Saved output and mask for image 15. IoU: 0.1620
Saved output and mask for image 16. IoU: 0.8340
Saved output and mask for image 17. IoU: 0.7860
Saved output and mask for image 18. IoU: 0.8060
Saved output and mask for image 19. IoU: 0.7595
Saved output and mask for image 20. IoU: 0.1253
Mean IoU: 0.5871
