In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os
import numpy as np
import re
import time

device = "cuda" if torch.cuda.is_available() else "cpu"

def natural_sort_key(filename):
    return [int(s) if s.isdigit() else s for s in re.split("(\d+)", filename)]

def addPadding(srcShapeTensor, tensor_whose_shape_isTobechanged):

    if(srcShapeTensor.shape != tensor_whose_shape_isTobechanged.shape):
        target = torch.zeros(srcShapeTensor.shape)
        target[:, :, :tensor_whose_shape_isTobechanged.shape[2],
               :tensor_whose_shape_isTobechanged.shape[3]] = tensor_whose_shape_isTobechanged
        return target.to(device)
    return tensor_whose_shape_isTobechanged.to(device)

class SegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_filenames = sorted(os.listdir(image_dir), key=natural_sort_key)
        self.mask_filenames = sorted(os.listdir(mask_dir), key=natural_sort_key)

        if self.image_filenames != self.mask_filenames:
            raise ValueError("Image and mask filenames do not match!")

        self.preprocess = transforms.Compose([
            # transforms.Resize((256, 256)),
            transforms.ToTensor(),
        ])
        self.transform = transform
        self.color_transform = transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.05, hue=0.05)

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_filenames[idx])
        image = Image.open(image_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        
        image = self.preprocess(image)
        mask = self.preprocess(mask)
        

        if self.transform:
            image = self.color_transform(image)
            
            image_and_mask = torch.cat((image, mask))
            image_and_mask = self.transform(torch.cat((image, mask), dim=0))
            image = image_and_mask[0:3]
            mask = image_and_mask[3].unsqueeze(0)

        return image, mask


class AttentionBlock(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(AttentionBlock, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi

class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.enc1 = self.double_conv(3, 64)
        self.enc2 = self.double_conv(64, 128)
        self.enc3 = self.double_conv(128, 256)
        self.enc4 = self.double_conv(256, 512)
        self.enc5 = self.double_conv(512, 1024)

        self.up5 = self.up_trans(1024, 512)
        self.att5 = AttentionBlock(F_g=512, F_l=512, F_int=256)
        self.dec5 = self.double_conv(1024, 512)
        
        self.up4 = self.up_trans(512, 256)
        self.att4 = AttentionBlock(F_g=256, F_l=256, F_int=128)
        self.dec4 = self.double_conv(512, 256)
        
        self.up3 = self.up_trans(256, 128)
        
        self.dec3 = self.double_conv(256, 128)
        self.att3 = AttentionBlock(F_g=128, F_l=128, F_int=64)
        self.up2 = self.up_trans(128, 64)
        
        self.dec2 = self.double_conv(128, 64)
        self.att2 = AttentionBlock(F_g=64, F_l=64, F_int=32)
        self.final_conv = nn.Conv2d(64, 1, kernel_size=1)

    def double_conv(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def up_trans(self, in_channels, out_channels):
        return nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size = 2,
            stride = 2
        )

    def crop_and_concat(self, upsampled, bypass):
        return torch.cat((upsampled, bypass), dim=1)

    def forward(self, x):
        enc1 = self.enc1(x)
        enc2 = self.enc2(nn.MaxPool2d(2)(enc1))
        enc3 = self.enc3(nn.MaxPool2d(2)(enc2))
        enc4 = self.enc4(nn.MaxPool2d(2)(enc3))
        enc5 = self.enc5(nn.MaxPool2d(2)(enc4))

        x = self.up5(enc5)
        x = addPadding(enc4, x)
        x = self.att5(x, enc4)
        x = self.dec5(self.crop_and_concat(x, enc4))
        
        x = self.up4(x)
        x = addPadding(enc3, x)
        x = self.att4(x, enc3)
        x = self.dec4(self.crop_and_concat(x, enc3))
        
        x = self.up3(x)
        x = addPadding(enc2, x)
        x = self.att3(x, enc2)
        x = self.dec3(self.crop_and_concat(x, enc2))
        
        x = self.up2(x)
        x = addPadding(enc1, x)
        x = self.att2(x, enc1)
        x = self.dec2(self.crop_and_concat(x, enc1))

        output = torch.sigmoid(self.final_conv(x))
        return output


def calculate_iou(output, target, threshold=0.5):
    output = output > threshold
    target = target > 0

    output_flat = output.view(-1)
    target_flat = target.view(-1)

    intersection = (output_flat & target_flat).sum().float()
    union = (output_flat | target_flat).sum().float()
    if union == 0:
        print("UNION SHOULD NOT BE ZERO, THERE MUST BE SOME MISTAKE")
        return 1

    iou = intersection / union
    return iou.item()


def tensor_to_required_image(image, input_path, input_index):
    image = image > 0.5
    
    ref_image = Image.open(f"{input_path}/{input_index}.png")
    ref_shape = ref_image.size
    ref_shape = (ref_shape[0] * 2, ref_shape[1])
    
    image = Image.fromarray((image * 255).astype(np.uint8))
    image = image.resize(ref_shape)
    return image

In [11]:
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.RandomResizedCrop(256, scale=(0.3, 1.0)),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
])

directory_path = "./training_dataset"
train_dataset = SegmentationDataset(f"{directory_path}/image", f"{directory_path}/mask", transform=train_transform)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

directory_path = "./testing_dataset"
test_dataset = SegmentationDataset(f"{directory_path}/image", f"{directory_path}/mask")
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)



torch.cuda.empty_cache()

model = UNet().to(device)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [12]:
phases = ["train", "valid"]
data_loader = {"train": train_dataloader, "valid": test_dataloader}
best_mean_iou = 0
best_epoch = 0

num_epochs = 250
for epoch in range(num_epochs):
    start_time = time.time()
    for phase in phases:
        if phase == "train":
            model.train()
            train_loss = 0
            train_iou_scores = []
        elif phase == "valid":
            model.eval()
            val_loss = 0
            val_iou_scores = []

        for images, masks in data_loader[phase]:
            images, masks = images.to(device), masks.to(device)

            if phase == "train":
                outputs = model(images)
                
                loss = criterion(outputs, masks)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                train_loss += loss.item()
                
                iou = calculate_iou(outputs, masks)
                train_iou_scores.append(iou)
            elif phase == "valid":
                with torch.no_grad():
                    outputs = model(images)
                    
                    loss = criterion(outputs, masks)

                    val_loss += loss.item()
                    
                    iou = calculate_iou(outputs, masks)
                    val_iou_scores.append(iou)

    end_time = time.time()
    train_mean_iou = sum(train_iou_scores) / len(train_iou_scores)
    val_mean_iou = sum(val_iou_scores) / len(val_iou_scores)
    print(f"Epoch {epoch+1}/{num_epochs}, Time: {end_time - start_time:.2f} Training Loss: {train_loss:.4f}, Training Mean IoU: {train_mean_iou:.4f}, Validation Loss: {val_loss:.4f}, Validation Mean IoU: {val_mean_iou:.4f}")

    if val_mean_iou > best_mean_iou:
        best_mean_iou = val_mean_iou
        best_epoch = epoch
        torch.save(model.state_dict(), 'best_epoch.pth')

print("Training Complete!")
print(f"Best Mean IoU: {best_mean_iou:.4f}, at epoch {best_epoch}")

Epoch 1/250, Time: 16.88 Training Loss: 6.5110, Training Mean IoU: 0.3874, Validation Loss: 14.0847, Validation Mean IoU: 0.4048
Epoch 2/250, Time: 16.23 Training Loss: 5.9010, Training Mean IoU: 0.5087, Validation Loss: 14.5964, Validation Mean IoU: 0.4049
Epoch 3/250, Time: 16.21 Training Loss: 5.7665, Training Mean IoU: 0.5303, Validation Loss: 14.4551, Validation Mean IoU: 0.4121
Epoch 4/250, Time: 16.45 Training Loss: 5.4650, Training Mean IoU: 0.5568, Validation Loss: 13.4459, Validation Mean IoU: 0.4288
Epoch 5/250, Time: 16.27 Training Loss: 5.5213, Training Mean IoU: 0.5198, Validation Loss: 11.9318, Validation Mean IoU: 0.4556
Epoch 6/250, Time: 16.92 Training Loss: 5.3309, Training Mean IoU: 0.5449, Validation Loss: 11.9720, Validation Mean IoU: 0.4643
Epoch 7/250, Time: 16.18 Training Loss: 5.5890, Training Mean IoU: 0.5288, Validation Loss: 11.2664, Validation Mean IoU: 0.4638
Epoch 8/250, Time: 16.81 Training Loss: 5.1701, Training Mean IoU: 0.5547, Validation Loss: 11.33

In [13]:
model.load_state_dict(torch.load('best_epoch.pth', weights_only=True))
model.eval()
iou_scores = []
output_dir = "./output"
os.makedirs(output_dir, exist_ok=True)

with torch.no_grad():
    for idx, (images, masks) in enumerate(test_dataloader):
        images, masks = images.to(device), masks.to(device)

        outputs = model(images)
        iou = calculate_iou(outputs, masks)

        iou_scores.append(iou)

        output_image_path = os.path.join(output_dir, f"output_{idx + 1}.png")

        output_and_mask = torch.cat((masks, outputs), dim=3).squeeze().cpu().numpy()
        output_and_mask_image = tensor_to_required_image(output_and_mask, f"{directory_path}/mask", idx + 1)
        output_and_mask_image.save(output_image_path)

        print(f"Saved output and mask for image {idx + 1}. IoU: {iou:.4f}")

mean_iou = sum(iou_scores) / len(iou_scores)
print(f"Mean IoU: {mean_iou:.4f}")

  return self.fget.__get__(instance, owner)()


Saved output and mask for image 1. IoU: 0.5733
Saved output and mask for image 2. IoU: 0.5660
Saved output and mask for image 3. IoU: 0.2498
Saved output and mask for image 4. IoU: 0.4876
Saved output and mask for image 5. IoU: 0.7004
Saved output and mask for image 6. IoU: 0.4179
Saved output and mask for image 7. IoU: 0.7566
Saved output and mask for image 8. IoU: 0.9357
Saved output and mask for image 9. IoU: 0.4040
Saved output and mask for image 10. IoU: 0.7932
Saved output and mask for image 11. IoU: 0.7666
Saved output and mask for image 12. IoU: 0.6844
Saved output and mask for image 13. IoU: 0.5260
Saved output and mask for image 14. IoU: 0.7814
Saved output and mask for image 15. IoU: 0.1440
Saved output and mask for image 16. IoU: 0.8591
Saved output and mask for image 17. IoU: 0.6160
Saved output and mask for image 18. IoU: 0.9315
Saved output and mask for image 19. IoU: 0.9636
Saved output and mask for image 20. IoU: 0.6571
Mean IoU: 0.6407
