In [2]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import sys
from icecream import ic

In [3]:
image_path = "./Image"
mask_path = "./Mask"
synthetic_image_path = "./synthetic_image"
synthetic_mask_path = "./synthetic_mask"
image_path_list = sorted(os.listdir(image_path))
mask_path_list = sorted(os.listdir(mask_path))
image_path_list = [os.path.join(image_path, i) for i in image_path_list]
mask_path_list = [os.path.join(mask_path, i) for i in mask_path_list]
synthetic_path_list = [os.path.join(synthetic_image_path, i) for i in os.listdir(synthetic_image_path)]
synthetic_mask_list = [os.path.join(synthetic_mask_path, i) for i in os.listdir(synthetic_mask_path)]

In [4]:
from sklearn.model_selection import train_test_split
train_size = 15
data = list(zip(image_path_list, mask_path_list))
train_images_path, test_images_path = train_test_split(data, train_size=train_size, shuffle=True)

In [5]:
synthetic_data = list(zip(synthetic_path_list, synthetic_mask_list))
train_images_path += synthetic_data

In [6]:
len(train_images_path), len(test_images_path)

(215, 5)

In [13]:
from torchvision.transforms import v2 as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader

image_height, image_width = 320, 320

transform_image = transforms.Compose([
    transforms.Resize((image_height, image_width), interpolation=Image.BICUBIC),
    transforms.ToTensor()
])

transform_image_mask = transforms.Compose([
    transforms.Resize((image_height, image_width), interpolation=Image.NEAREST),
    transforms.ToTensor()
])

class ImageDataset(Dataset):
    def __init__(self, images_path, transform_image, transform_image_mask):
        self.images_path = images_path
        self.transform_image = transform_image
        self.transform_image_mask = transform_image_mask

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, idx):
        image_path, mask_path = self.images_path[idx]
        image = Image.open(image_path)
        mask = Image.open(mask_path)
        return self.transform_image(image), self.transform_image_mask(mask)
    
train_dataset = ImageDataset(train_images_path, transform_image, transform_image_mask)
test_dataset = ImageDataset(test_images_path, transform_image, transform_image_mask)



In [8]:
# import model
from u2net import U2NET

model = U2NET()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

U2NET(
  (stage1): RSU7(
    (rebnconvin): REBNCONV(
      (conv_s1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn_s1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu_s1): ReLU(inplace=True)
    )
    (rebnconv1): REBNCONV(
      (conv_s1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn_s1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu_s1): ReLU(inplace=True)
    )
    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
    (rebnconv2): REBNCONV(
      (conv_s1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn_s1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu_s1): ReLU(inplace=True)
    )
    (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
    (rebnconv3): REBNCONV(
      (conv_s1): Conv2d(32, 32, k

In [9]:
class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):

        #comment out if your model contains a sigmoid or equivalent activation layer
        # inputs = nn.sigmoid(inputs)

        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        intersection = (inputs * targets).sum()
        dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)

        return 1 - dice

In [10]:
class IoULoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(IoULoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):

        #comment out if your model contains a sigmoid or equivalent activation layer
        # inputs = F.sigmoid(inputs)

        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        #intersection is equivalent to True Positive count
        #union is the mutually inclusive area of all labels & predictions
        intersection = (inputs * targets).sum()
        total = (inputs + targets).sum()
        union = total - intersection

        IoU = (intersection + smooth)/(union + smooth)

        return 1 - IoU

In [11]:
import matplotlib.pyplot as plt

# plot the training and testing loss and metrics
def plot_metrics(history):
    fig, axs = plt.subplots(3, 2, figsize=(18, 15))  # Adjusted to 3 rows and 2 columns
    axs[0, 0].plot(history[0], label="train")
    axs[0, 0].plot(history[1], label="test")
    axs[0, 0].set_title("Loss")
    axs[0, 0].legend()

    axs[0, 1].plot(history[2], label="train")
    axs[0, 1].plot(history[3], label="test")
    axs[0, 1].set_title("Pixel Accuracy")
    axs[0, 1].legend()

    axs[1, 0].plot(history[4], label="train")
    axs[1, 0].plot(history[5], label="test")
    axs[1, 0].set_title("IoU")
    axs[1, 0].legend()

    axs[1, 1].plot(history[6], label="train")
    axs[1, 1].plot(history[7], label="test")
    axs[1, 1].set_title("Dice")
    axs[1, 1].legend()

    axs[2, 0].plot(history[8], label="train")
    axs[2, 0].plot(history[9], label="test")
    axs[2, 0].set_title("Precision")
    axs[2, 0].legend()

    axs[2, 1].plot(history[10], label="train")
    axs[2, 1].plot(history[11], label="test")
    axs[2, 1].set_title("Recall")
    axs[2, 1].legend()

    plt.tight_layout()
    plt.show()

In [12]:
from tqdm import tqdm
from torchmetrics.classification import BinaryAccuracy, BinaryRecall, BinaryPrecision, BinaryJaccardIndex
from torchmetrics import Dice

def train(model, epochs, train_loader, test_loader, criterion, lr=0.001):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    iou_metric = BinaryJaccardIndex().to(device)
    dice_metric = Dice().to(device)
    precision_metric = BinaryPrecision().to(device)
    recall_metric = BinaryRecall().to(device)
    pixel_acc_metric = BinaryAccuracy().to(device)

    loss_train = [0]*epochs
    loss_test = [0]*epochs
    pixel_acc_train = [0]*epochs
    pixel_acc_test = [0]*epochs
    iou_train = [0]*epochs
    iou_test = [0]*epochs
    dice_train = [0]*epochs
    dice_test = [0]*epochs
    precision_train = [0]*epochs
    precision_test = [0]*epochs
    recall_train = [0]*epochs
    recall_test = [0]*epochs

    for epoch in tqdm(range(epochs)):
        model.train()
        for image, mask in train_loader:
            image, mask = image.to(device), mask.to(device)
            optimizer.zero_grad()
            pred_mask = model(image)
            if isinstance(pred_mask, tuple):
                pred_mask = pred_mask[0]
            loss = criterion(pred_mask, mask)
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                # Calculate metrics
                pred_mask_binary = pred_mask > 0.5
                pixel_acc = pixel_acc_metric(pred_mask_binary, mask)
                iou = iou_metric(pred_mask_binary, mask)
                dice = dice_metric(pred_mask_binary, mask)
                precision = precision_metric(pred_mask_binary, mask)
                recall = recall_metric(pred_mask_binary, mask)

                # Accumulate metrics
                loss_train[epoch] += loss.item()
                pixel_acc_train[epoch] += pixel_acc.item()
                iou_train[epoch] += iou.item()
                dice_train[epoch] += dice.item()
                precision_train[epoch] += precision.item()
                recall_train[epoch] += recall.item()

        # Calculate average metrics
        num_batches = len(train_loader)
        loss_train[epoch] /= num_batches
        pixel_acc_train[epoch] /= num_batches
        iou_train[epoch] /= num_batches
        dice_train[epoch] /= num_batches
        precision_train[epoch] /= num_batches
        recall_train[epoch] /= num_batches

        # Evaluate on test set
        model.eval()
        with torch.no_grad():
            for image, mask in test_loader:
                image, mask = image.to(device), mask.to(device)
                pred_mask = model(image)
                if isinstance(pred_mask, tuple):
                    pred_mask = pred_mask[0]
                loss = criterion(pred_mask, mask)

                pred_mask_binary = pred_mask > 0.5
                pixel_acc = pixel_acc_metric(pred_mask_binary, mask)
                iou = iou_metric(pred_mask_binary, mask)
                dice = dice_metric(pred_mask_binary, mask)
                precision = precision_metric(pred_mask_binary, mask)
                recall = recall_metric(pred_mask_binary, mask)

                loss_test[epoch] += loss.item()
                pixel_acc_test[epoch] += pixel_acc.item()
                iou_test[epoch] += iou.item()
                dice_test[epoch] += dice.item()
                precision_test[epoch] += precision.item()
                recall_test[epoch] += recall.item()

        num_batches_test = len(test_loader)
        loss_test[epoch] /= num_batches_test
        pixel_acc_test[epoch] /= num_batches_test
        iou_test[epoch] /= num_batches_test
        dice_test[epoch] /= num_batches_test
        precision_test[epoch] /= num_batches_test
        recall_test[epoch] /= num_batches_test

        if epoch%10 == 0:
            print(f"Epoch: {epoch+1}/{epochs}, Loss: {loss_train[epoch]:.4f}, Pixel Acc: {pixel_acc_train[epoch]:.4f}, IoU: {iou_train[epoch]:.4f}, Dice: {dice_train[epoch]:.4f}, Precision: {precision_train[epoch]:.4f}, Recall: {recall_train[epoch]:.4f}")
            print(f"Val Loss: {loss_test[epoch]:.4f}, Val Pixel Acc: {pixel_acc_test[epoch]:.4f}, Val IoU: {iou_test[epoch]:.4f}, Val Dice: {dice_test[epoch]:.4f}, Val Precision: {precision_test[epoch]:.4f}, Val Recall: {recall_test[epoch]:.4f}")

    return loss_train, loss_test, pixel_acc_train, pixel_acc_test, iou_train, iou_test, dice_train, dice_test, precision_train, precision_test, recall_train, recall_test

In [14]:
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True, prefetch_factor=1, num_workers=1)
test_loader = DataLoader(test_dataset, batch_size=5, shuffle=False, prefetch_factor=1, num_workers=1)