In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from PIL import Image
import os
import numpy as np
import re
import time

device = "cuda" if torch.cuda.is_available() else "cpu"


def natural_sort_key(filename):
    return [int(s) if s.isdigit() else s for s in re.split("(\d+)", filename)]


def addPadding(srcShapeTensor, tensor_whose_shape_isTobechanged):

    if(srcShapeTensor.shape != tensor_whose_shape_isTobechanged.shape):
        target = torch.zeros(srcShapeTensor.shape)
        target[:, :, :tensor_whose_shape_isTobechanged.shape[2],
               :tensor_whose_shape_isTobechanged.shape[3]] = tensor_whose_shape_isTobechanged
        return target.to(device)
    return tensor_whose_shape_isTobechanged.to(device)


class SegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir=None, transform=None):
        self.image_dir = image_dir
        self.image_filenames = sorted(os.listdir(image_dir), key=natural_sort_key)

        if mask_dir is None:
            self.mask_dir = None
        else:
            self.mask_dir = mask_dir
            self.mask_filenames = sorted(os.listdir(mask_dir), key=natural_sort_key)
            if self.image_filenames != self.mask_filenames:
                raise ValueError("Image and mask filenames do not match!")

        self.preprocess = transforms.Compose([
            # transforms.Resize((256, 256)),
            transforms.ToTensor(),
        ])
        self.transform = transform
        self.color_transform = transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.05, hue=0.05)

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.image_filenames[idx])
        image = Image.open(image_path).convert("RGB")
        image = self.preprocess(image)

        if self.mask_dir is None:
            return image
        else:
            mask_path = os.path.join(self.mask_dir, self.mask_filenames[idx])
            mask = Image.open(mask_path).convert("L")
            mask = self.preprocess(mask)

            if self.transform:
                image = self.color_transform(image)

                image_and_mask = torch.cat((image, mask))
                image_and_mask = self.transform(torch.cat((image, mask), dim=0))
                image = image_and_mask[0:3]
                mask = image_and_mask[3].unsqueeze(0)
            return image, mask


class AttentionBlock(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(AttentionBlock, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi

class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.enc1 = self.double_conv(3, 64)
        self.enc2 = self.double_conv(64, 128)
        self.enc3 = self.double_conv(128, 256)
        self.enc4 = self.double_conv(256, 512)
        self.enc5 = self.double_conv(512, 1024)

        self.up5 = self.up_trans(1024, 512)
        self.att5 = AttentionBlock(F_g=512, F_l=512, F_int=256)
        self.dec5 = self.double_conv(1024, 512)

        self.up4 = self.up_trans(512, 256)
        self.att4 = AttentionBlock(F_g=256, F_l=256, F_int=128)
        self.dec4 = self.double_conv(512, 256)

        self.up3 = self.up_trans(256, 128)

        self.dec3 = self.double_conv(256, 128)
        self.att3 = AttentionBlock(F_g=128, F_l=128, F_int=64)
        self.up2 = self.up_trans(128, 64)

        self.dec2 = self.double_conv(128, 64)
        self.att2 = AttentionBlock(F_g=64, F_l=64, F_int=32)
        self.final_conv = nn.Conv2d(64, 1, kernel_size=1)

    def double_conv(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def up_trans(self, in_channels, out_channels):
        return nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size = 2,
            stride = 2
        )

    def crop_and_concat(self, upsampled, bypass):
        return torch.cat((upsampled, bypass), dim=1)

    def forward(self, x):
        enc1 = self.enc1(x)
        enc2 = self.enc2(nn.MaxPool2d(2)(enc1))
        enc3 = self.enc3(nn.MaxPool2d(2)(enc2))
        enc4 = self.enc4(nn.MaxPool2d(2)(enc3))
        enc5 = self.enc5(nn.MaxPool2d(2)(enc4))

        x = self.up5(enc5)
        x = addPadding(enc4, x)
        x = self.att5(x, enc4)
        x = self.dec5(self.crop_and_concat(x, enc4))

        x = self.up4(x)
        x = addPadding(enc3, x)
        x = self.att4(x, enc3)
        x = self.dec4(self.crop_and_concat(x, enc3))

        x = self.up3(x)
        x = addPadding(enc2, x)
        x = self.att3(x, enc2)
        x = self.dec3(self.crop_and_concat(x, enc2))

        x = self.up2(x)
        x = addPadding(enc1, x)
        x = self.att2(x, enc1)
        x = self.dec2(self.crop_and_concat(x, enc1))

        output = torch.sigmoid(self.final_conv(x))
        return output


def calculate_iou(output, target, threshold=0.5):
    output = output > threshold
    target = target > 0

    output_flat = output.view(-1)
    target_flat = target.view(-1)

    intersection = (output_flat & target_flat).sum().float()
    union = (output_flat | target_flat).sum().float()
    if union == 0:
        print("UNION SHOULD NOT BE ZERO, THERE MUST BE SOME MISTAKE")
        return 1

    iou = intersection / union
    return iou.item()


def tensor_to_required_image(image, input_path, input_index, have_input_concated=True):
    image = image.squeeze().cpu().numpy() > 0.5

    input_image = Image.open(f"{input_path}/{input_index}.png")
    input_shape = input_image.size
    if have_input_concated:
        input_shape = (input_shape[0] * 2, input_shape[1])
    else:
        input_shape = (input_shape[0], input_shape[1])

    image = Image.fromarray((image * 255).astype(np.uint8))
    image = image.resize(input_shape)
    return image

In [3]:
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.RandomResizedCrop(256, scale=(0.3, 1.0)),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
])

directory_path = "/content/drive/MyDrive/WaterSegementation/training_dataset"
train_dataset = SegmentationDataset(f"{directory_path}/image", f"{directory_path}/mask", transform=train_transform)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

pretend_train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - pretend_train_size
_, val_subset = random_split(train_dataset, [pretend_train_size, val_size], generator=torch.Generator().manual_seed(42))
val_dataloader = DataLoader(val_subset, batch_size=8, shuffle=False)

directory_path = "/content/drive/MyDrive/WaterSegementation/testing_dataset"
test_dataset = SegmentationDataset(f"{directory_path}/image")
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

torch.cuda.empty_cache()

model = UNet().to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [None]:
phases = ["train", "valid"]
data_loader = {"train": train_dataloader, "valid": val_dataloader}
take_top_n = 5
top_mean_ious = [0] * take_top_n
top_epochs = [0] * take_top_n

num_epochs = 200
start_time = time.time()
for epoch in range(num_epochs):
    for phase in phases:
        if phase == "train":
            model.train()
            train_loss = 0
            train_iou_scores = []
        elif phase == "valid":
            model.eval()
            val_loss = 0
            val_iou_scores = []

        for images, masks in data_loader[phase]:
            images, masks = images.to(device), masks.to(device)

            if phase == "train":
                outputs = model(images)

                loss = criterion(outputs, masks)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                train_loss += loss.item()

                # iou = calculate_iou(outputs, masks)
                # train_iou_scores.append(iou)
            elif phase == "valid":
                with torch.no_grad():
                    outputs = model(images)

                    # loss = criterion(outputs, masks)

                    # val_loss += loss.item()

                    iou = calculate_iou(outputs, masks)
                    val_iou_scores.append(iou)

    end_time = time.time()
    # train_mean_iou = sum(train_iou_scores) / len(train_iou_scores)
    val_mean_iou = sum(val_iou_scores) / len(val_iou_scores)
    # print(f"Epoch {epoch+1}/{num_epochs}, Time: {end_time - start_time:.2f} Training Loss: {train_loss:.4f}, Training Mean IoU: {train_mean_iou:.4f}, Validation Loss: {val_loss:.4f}, Validation Mean IoU: {val_mean_iou:.4f}")
    print(f"Epoch {epoch+1}/{num_epochs}, Time: {end_time - start_time:.2f} Training Loss: {train_loss:.4f}, Validation Mean IoU: {val_mean_iou:.4f}")

    for i in range(5):
        if val_mean_iou > top_mean_ious[i]:
            top_mean_ious[i] = val_mean_iou
            top_epochs[i] = epoch
            torch.save(model.state_dict(), f'top_{i + 1}.pth')
            break

print("Training Complete!")
print(f"Top Mean IoUs: {top_mean_ious}, at epochs {top_epochs}")

In [12]:
model.load_state_dict(torch.load(f'top_{1}.pth', weights_only=True))
model.eval()
output_dir = "./output"
os.makedirs(output_dir, exist_ok=True)

with torch.no_grad():
    for idx, images in enumerate(test_dataloader):
        images = images.to(device)

        outputs = model(images)

        output_image_path = os.path.join(output_dir, f"{idx + 1}.png")
        output_image = tensor_to_required_image(outputs, f"{directory_path}/image", idx + 1, have_input_concated=False)
        output_image.save(output_image_path)

        # output_and_mask_image_path = os.path.join(output_dir, f"output_{idx + 1}.png")
        # output_and_mask = torch.cat((masks, outputs), dim=3)
        # output_and_mask_image = tensor_to_required_image(output_and_mask, f"{directory_path}/mask", idx + 1)
        # output_and_mask_image.save(output_image_path)

        print(f"Saved output and mask for image {idx + 1}")

Top Mean IoUs: [0.8796939849853516, 0.8736484944820404, 0.8458597362041473, 0.843244880437851, 0.8418437540531158], at epochs [235, 239, 236, 242, 243]
Saved output and mask for image 1
Saved output and mask for image 2
Saved output and mask for image 3
Saved output and mask for image 4
Saved output and mask for image 5
Saved output and mask for image 6
Saved output and mask for image 7
Saved output and mask for image 8
Saved output and mask for image 9
Saved output and mask for image 10
Saved output and mask for image 11
Saved output and mask for image 12
Saved output and mask for image 13
Saved output and mask for image 14
Saved output and mask for image 15
Saved output and mask for image 16
Saved output and mask for image 17
Saved output and mask for image 18
Saved output and mask for image 19
Saved output and mask for image 20
