In [1]:
from pathlib import Path
import matplotlib.pyplot as plt
import os
import cv2
import numpy as np

import torch
from torch import nn
from torch.nn import functional as F

from data_loader import CCAgTDataset
from torch.utils.data import random_split
from torch.utils.data import DataLoader

from utils import dice_loss

from tqdm import tqdm

## Path setting

In [2]:
os.environ['KMP_DUPLICATE_LIB_OK']='True'

image_root = Path('./data/images/')
mask_root = Path('./data/masks/')

model_root = Path('./model/')
result_root = Path('./image/')
log_root = Path('./log/')

## Configurations and random setting

In [3]:
# Data configuration
SCALE_SIZE = (800, 800)
BATCH_SIZE = 4

# GPU resources
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Set random seed for reproducibility
torch.manual_seed(12)
torch.backends.cudnn.deterministic = True

## Data Loading

In [4]:
# Create the dataset
dataset = CCAgTDataset(image_root, mask_root, scale_size=SCALE_SIZE)

# Training, validation and test split
generator = torch.Generator().manual_seed(42)
train_set, valid_set, test_set = random_split(dataset, [0.7, 0.1, 0.2], generator=generator)

# Data loaders
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

## Visualize function

In [5]:
def plot_img_and_mask(img, ground, mask, image_name:str):
    fig = plt.figure(figsize=(12, 5))
    plt.title(image_name)
    plt.axis('off')
    plt.subplot(1, 3, 1)
    plt.title('Image')
    plt.imshow(img.permute(1, 2, 0))
    plt.subplot(1, 3, 2)
    plt.title('Ground')
    plt.imshow(ground, vmin=0, vmax=7)
    plt.subplot(1, 3, 3)
    plt.title('Pred')
    plt.imshow(mask, vmin=0, vmax=7)
    plt.savefig(result_root / f'{image_name}.png')
    plt.close()

cmap = plt.get_cmap('Dark2')

def plot_fusion(ax, image, mask, alpha=0.05):
    for c, binary in enumerate((mask == c).astype(np.uint8) for c in range(1, 8)):
        countours, hierarchy = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        ax.imshow(image, cmap='tab10', alpha=alpha)
        color = cmap(c) if alpha < 1 else 'r'
        for i in range(len(countours)):
            ax.plot(countours[i][:, :, 0], countours[i][:, :, 1], c=color, linewidth=1.5)

def plot_inference(imgs, grounds, masks, image_names):
    # Create a 5x6 grid of subplots
    count = len(imgs)
    fig, axs = plt.subplots(6, count, figsize=(20, 4*count))  # Adjust the figsize as per your preference

    # Iterate through each subplot and plot something simple
    for i in range(count):
        img = imgs[i].permute(1, 2, 0).numpy()
        ground = grounds[i].numpy()
        mask = masks[i].numpy()
        name = image_names[i]
        axs[0, i].imshow(img)
        axs[0, i].set_title(name)
        axs[1, i].imshow(ground, cmap='bone', vmin=0, vmax=7)
        plot_fusion(axs[2, i], img, ground, alpha=0.05)
        axs[3, i].imshow(mask, cmap='bone', vmin=0, vmax=7)
        plot_fusion(axs[4, i], img, mask, alpha=0.05)
        plot_fusion(axs[5, i], ground, mask, alpha=1)

        for j in range(6):
            axs[j, i].axis('off')

    # Adjust layout and show the plots
    plt.tight_layout()
    plt.savefig(result_root / f'{image_names[0]}.png')
    plt.close()

def plot_metrics(train_metrics, valid_metrics, metric_name):
    plt.title(metric_name)
    plt.plot(train_metrics, label=f'train {metric_name}')
    plt.plot(valid_metrics, label=f'valid {metric_name}')
    plt.legend()
    plt.savefig(log_root / f'{metric_name}.png')
    plt.close()

def resize_mask(mask, size):
    mask = mask.unsqueeze(0).float()
    mask = F.interpolate(mask, size=size, mode="nearest").squeeze(0)
    return mask.long().contiguous()


## Metrics for evaluation

In [6]:
# Metric: iou, dice, PA, cPA, mPA
def calculate_iou(predicted_mask, ground_truth_mask):
    intersection = torch.logical_and(predicted_mask, ground_truth_mask).sum().item()
    union = torch.logical_or(predicted_mask, ground_truth_mask).sum().item()
    iou = intersection / union if union != 0 else 0
    return iou

def calculate_class_TF(predicted_masks, ground_truth_masks, num_classes, class_TF=[]):
    thresholds = torch.arange(0.5, 1.0, 0.05)

    if len(class_TF) == 0:
        for class_idx in range(num_classes):
            class_thred_TF = []
            for threshold in thresholds:
                class_thred_TF.append({'true_positives': 0, 'false_positives': 0, 'false_negatives': 0})
            class_TF.append(class_thred_TF)

    for class_idx in range(num_classes):
        for tid, threshold in enumerate(thresholds):
            for i in range(len(predicted_masks)):
                pred_mask = (predicted_masks[i] == class_idx).bool()
                gt_mask = (ground_truth_masks[i] == class_idx).bool()

                iou = calculate_iou(pred_mask, gt_mask)
                if iou >= threshold:
                    class_TF[class_idx][tid]['true_positives'] += 1
                else:
                    class_TF[class_idx][tid]['false_positives'] += 1
                    class_TF[class_idx][tid]['false_negatives'] += 1 if iou < 0.5 else 0

    return class_TF

def calculate_mean_iou(predicted_masks, ground_truth_masks, num_classes):
    class_iou = torch.zeros(num_classes)
    for class_idx in range(num_classes):
        class_mask_pred = (predicted_masks == class_idx)
        class_mask_gt = (ground_truth_masks == class_idx)
        class_iou[class_idx] = calculate_iou(class_mask_pred, class_mask_gt)

    mean_iou = class_iou.sum() / (class_iou != 0).sum()  # Calculate mean ignoring classes with IoU = 0
    return mean_iou.item()

def calculate_dice(predicted_mask, ground_truth_mask):
    intersection = torch.logical_and(predicted_mask, ground_truth_mask).sum().item()
    dice = (2. * intersection) / (predicted_mask.sum().item() + ground_truth_mask.sum().item()) if (predicted_mask.sum().item() + ground_truth_mask.sum().item()) != 0 else 0
    return dice

def calculate_dice_per_class(predicted_masks, ground_truth_masks, num_classes):
    dice_scores = []
    smooth = 1e-6  # Smoothing factor to avoid division by zero

    for class_idx in range(num_classes):
        pred = predicted_masks == class_idx
        target = ground_truth_masks == class_idx

        intersection = (pred & target).sum().item()
        union = pred.sum().item() + target.sum().item()

        dice = (2. * intersection + smooth) / (union + smooth)
        dice_scores.append(dice)

    return dice_scores

def calculate_mean_dice(dice_scores):
    return sum(dice_scores) / len(dice_scores)

def calculate_pixel_accuracy(predicted_mask, ground_truth_mask):
    correct_pixels = (predicted_mask == ground_truth_mask).sum().item()
    total_pixels = ground_truth_mask.numel()
    pixel_accuracy = correct_pixels / total_pixels
    return pixel_accuracy

def calculate_class_pixel_accuracy(predicted_masks, ground_truth_masks, num_classes):
    class_pixel_accuracy = torch.zeros(num_classes)
    for class_idx in range(num_classes):
        class_mask = (ground_truth_masks == class_idx)
        correct_pixels = (predicted_masks == ground_truth_masks)[class_mask].sum().item()
        total_pixels = class_mask.sum().item()
        class_pixel_accuracy[class_idx] = correct_pixels / total_pixels if total_pixels != 0 else 0
    return class_pixel_accuracy

def calculate_mean_pixel_accuracy(class_pixel_accuracy):
    return class_pixel_accuracy.mean().item()

## Model 

In [7]:
# Model
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

class UNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()

        # Encoder path
        self.conv_down1 = DoubleConv(in_channels, 64)
        self.conv_down2 = DoubleConv(64, 128)
        self.conv_down3 = DoubleConv(128, 256)
        self.conv_down4 = DoubleConv(256, 512)
        self.conv_down5 = DoubleConv(512, 1024)  # Additional layer
        
        self.maxpool = nn.MaxPool2d(kernel_size=2)
        
        # Decoder path
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv_up4 = DoubleConv(1024+512, 512)  # Additional layer
        self.conv_up3 = DoubleConv(512+256, 256)
        self.conv_up2 = DoubleConv(256+128, 128)
        self.conv_up1 = DoubleConv(128+64, 64)
        
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder path
        conv1 = self.conv_down1(x)
        x = self.maxpool(conv1)
        
        conv2 = self.conv_down2(x)
        x = self.maxpool(conv2)
        
        conv3 = self.conv_down3(x)
        x = self.maxpool(conv3)
        
        conv4 = self.conv_down4(x)
        x = self.maxpool(conv4)

        # Additional layer in encoder
        x = self.conv_down5(x)
        
        # Decoder path
        x = self.upsample(x)
        x = torch.cat([x, conv4], dim=1)
        x = self.conv_up4(x)
        
        x = self.upsample(x)
        x = torch.cat([x, conv3], dim=1)
        x = self.conv_up3(x)
        
        x = self.upsample(x)
        x = torch.cat([x, conv2], dim=1)
        x = self.conv_up2(x)
        
        x = self.upsample(x)
        x = torch.cat([x, conv1], dim=1)
        x = self.conv_up1(x)
        
        x = self.final_conv(x)
        return x

## Training parameters

In [8]:
# Training
input_channels = 3
output_channels = 8 # number of classes
lr = 1e-5
num_epochs = 100
train_mode = False

model = UNet(input_channels, output_channels).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(model.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9, foreach=True)

## Train

In [9]:
def train_step(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    running_miou = 0.0
    running_mdice = 0.0
    running_pa = 0.0
    running_mpa = 0.0
    processed_data = 0

    i_bar = tqdm(dataloader, desc="Training")
    for batch_idx, data in enumerate(i_bar):
        inputs = data['rescale_img'].to(device)
        labels = data['rescale_mask'].to(device)

        optimizer.zero_grad()

        # with torch.autocast(device_type=device):
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        d_loss = dice_loss(F.softmax(outputs, dim=1).float(),
                            F.one_hot(labels, num_classes=output_channels).permute(0, 3, 1, 2).float(),
                            multiclass=True)
        loss += d_loss
        pred = torch.argmax(outputs, dim=1)
        
        # resize the predicted mask to the original size
        origin_labels = data['mask'].to(device)
        pred = resize_mask(pred, origin_labels.shape[-2:])
        
        miou = calculate_mean_iou(pred, origin_labels, output_channels)
        mdice = calculate_mean_dice(calculate_dice_per_class(pred, origin_labels, output_channels))
        pa = calculate_pixel_accuracy(pred, origin_labels)
        mpa = calculate_mean_pixel_accuracy(calculate_class_pixel_accuracy(pred, origin_labels, output_channels))
        loss.backward()
        optimizer.step()
        # scheduler.step(d_loss)

        running_loss += loss.item() * inputs.size(0)
        running_miou += miou * inputs.size(0)
        running_mdice += mdice * inputs.size(0)
        running_pa += pa * inputs.size(0)
        running_mpa += mpa * inputs.size(0)
        processed_data += inputs.size(0)
        i_bar.set_postfix_str(f"Loss: {loss.item():.4f}| mIoU: {miou:.4f}, mDice: {mdice:.4f}, PA: {pa:.4f}, mPA: {mpa:.4f}")
        # plot_img_and_mask(data['image'][0], origin_labels[0].cpu(), pred[0].cpu(), "Train")

    train_loss = running_loss / processed_data
    train_miou = running_miou / processed_data
    train_mdice = running_mdice / processed_data
    train_pa = running_pa / processed_data
    train_mpa = running_mpa / processed_data

    return train_loss, train_miou, train_mdice, train_pa, train_mpa

def valid_step(model, dataloader, criterion, device, evaluate=False):
    model.eval()
    running_loss = 0.0
    running_miou = 0.0
    running_pa = 0.0
    running_mdice = 0.0
    running_mpa = 0.0
    processed_size = 0

    i_bar = tqdm(dataloader, desc="Validation" if not evaluate else "Testing")
    for batch_idx, data in enumerate(i_bar):
        inputs = data['rescale_img'].to(device)
        labels = data['rescale_mask'].to(device)

        with torch.no_grad():
            # with torch.autocast(device_type=device):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            d_loss = dice_loss(F.softmax(outputs, dim=1).float(),
                            F.one_hot(labels, num_classes=output_channels).permute(0, 3, 1, 2).float(),
                            multiclass=True)
            loss += d_loss
            pred = torch.argmax(outputs, dim=1)

            # resize the predicted mask to the original size
            origin_labels = data['mask'].to(device)
            pred = resize_mask(pred, origin_labels.shape[-2:])

            miou = calculate_mean_iou(pred, origin_labels, output_channels)
            mdice = calculate_mean_dice(calculate_dice_per_class(pred, origin_labels, output_channels))
            pa = calculate_pixel_accuracy(pred, origin_labels)
            mpa = calculate_mean_pixel_accuracy(calculate_class_pixel_accuracy(pred, origin_labels, output_channels))

        running_loss += loss.item() * inputs.size(0)
        processed_size += inputs.size(0)
        running_miou += miou * inputs.size(0)
        running_mdice += mdice * inputs.size(0)
        running_pa += pa * inputs.size(0)
        running_mpa += mpa * inputs.size(0)

        i_bar.set_postfix_str(f"Loss: {loss.item():.4f}| mIoU: {miou:.4f}, mDice: {mdice:.4f}, PA: {pa:.4f}, mPA: {mpa:.4f}")

        # plot_img_and_mask(data['image'][0], origin_labels[0].cpu(), pred[0].cpu(), "Validation")
        

    valid_loss = running_loss / processed_size
    valid_miou = running_miou / processed_size
    valid_mdice = running_mdice / processed_size
    valid_pa = running_pa / processed_size
    valid_mpa = running_mpa / processed_size

    return valid_loss, valid_miou, valid_mdice, valid_pa, valid_mpa

In [10]:
def train(model, train_loader, valid_loader, criterion, optimizer, device, epochs=10):
    best_loss = float('inf')

    train_losses = []
    valid_losses = []

    train_mious = []
    valid_mious = []
    train_mdices = []
    valid_mdices = []
    train_pas = []
    valid_pas = []
    train_mpas = []
    valid_mpas = []

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1} of {epochs}")
        train_loss, train_miou, train_mdice, train_pa, train_mpa = train_step(model, train_loader, criterion, optimizer, device)
        valid_loss, valid_miou, valid_mdice, valid_pa, valid_mpa = valid_step(model, valid_loader, criterion, device)

        train_losses.append(train_loss)
        valid_losses.append(valid_loss)
        train_mious.append(train_miou)
        valid_mious.append(valid_miou)
        train_mdices.append(train_mdice)
        valid_mdices.append(valid_mdice)
        train_pas.append(train_pa)
        valid_pas.append(valid_pa)
        train_mpas.append(train_mpa)
        valid_mpas.append(valid_mpa)

        print(f"Train Loss: {train_loss:.4f} | Train mIoU: {train_miou:.4f} | Train mDice: {train_mdice:.4f} | Train PA: {train_pa:.4f}, Train mPA: {train_mpa:.4f}")
        print(f"Valid Loss: {valid_loss:.4f} | Valid mIoU: {valid_miou:.4f} | Valid mDice: {valid_mdice:.4f} | Valid PA: {valid_pa:.4f}, Valid mPA: {valid_mpa:.4f}")

        # visualize the metrics
        plot_metrics(train_losses, valid_losses, 'Loss')
        plot_metrics(train_mious, valid_mious, 'mIoU')
        plot_metrics(train_mdices, valid_mdices, 'mDice')
        plot_metrics(train_pas, valid_pas, 'PA')
        plot_metrics(train_mpas, valid_mpas, 'mPA')

        if valid_loss < best_loss:
            best_loss = valid_loss
            torch.save(model.state_dict(), model_root / 'best_model.pth')

def evaluate(model, test_loader, criterion, device):
    test_loss, test_miou, test_mdice, test_pa, test_mpa = valid_step(model, test_loader, criterion, device, evaluate=True)
    print(f"Test Loss: {test_loss:.4f} | Test mIoU: {test_miou:.4f} | Test mDice: {test_mdice:.4f} | Test PA: {test_pa:.4f} | Test mPA: {test_mpa:.4f}")

def inference(model, test_loader, device):
    buffer = {'image': [], 'ground': [], 'pred': [], 'name': []}
    buffer_count = 0

    class_TF = []
    

    model.eval()
    with torch.no_grad():
        for batch_idx, data in enumerate(tqdm(test_loader, desc="Inference")):
            inputs = data['rescale_img'].to(device)
            labels = data['rescale_mask'].to(device)

            outputs = model(inputs)
            pred = torch.argmax(outputs, dim=1)
            origin_labels = data['mask'].to(device)
            pred = resize_mask(pred, origin_labels.shape[-2:])

            class_TF = calculate_class_TF(pred, origin_labels, output_channels, class_TF)

            if buffer_count > 50:
                print("Collecting enough data for inference")
                break

            # Plot that complex table
            for i in range(BATCH_SIZE):
                buffer['image'].append(data['image'][i])
                buffer['ground'].append(origin_labels[i].cpu())
                buffer['pred'].append(pred[i].cpu())
                buffer['name'].append(f"test/{data['name'][i]}")

                buffer_count += 1
                # plot_img_and_mask(data['image'][i], origin_labels[i].cpu(), pred[i].cpu(), f"test/{data['name'][i]}")
                if buffer_count % 5 == 0:
                    plot_inference(buffer['image'], buffer['ground'], buffer['pred'], buffer['name'])
                    buffer = {'image': [], 'ground': [], 'pred': [], 'name': []}

## Main function

In [11]:
print("Dataset size: ", len(dataset))
print("Image size: ", dataset[0]['rescale_mask'].shape)
print("Train set size: ", len(train_set))
print("Valid set size: ", len(valid_set))
print("Test set size: ", len(test_set))

if train_mode:
    train(model, train_loader, valid_loader, criterion, optimizer, device, epochs=num_epochs)

# load the best model
model.load_state_dict(torch.load(model_root / 'best_model.pth'))
evaluate(model, test_loader, criterion, device)
inference(model, test_loader, device)

Dataset size:  9339
Image size:  torch.Size([800, 800])
Train set size:  6538
Valid set size:  934
Test set size:  1867


Testing: 100%|██████████| 467/467 [36:40<00:00,  4.71s/it, Loss: 0.4511| mIoU: 0.8034, mDice: 0.9291, PA: 0.9993, mPA: 0.5855]


Test Loss: 0.5784 | Test mIoU: 0.6733 | Test mDice: 0.5811 | Test PA: 0.9967 | Test mPA: 0.4952


Inference:   3%|▎         | 13/467 [06:57<4:10:51, 33.15s/it]

Collecting enough data for inference


Inference:   3%|▎         | 13/467 [07:03<4:06:22, 32.56s/it]
