<a href="https://colab.research.google.com/github/ChiemBosboom/LEXY_seqmentation/blob/main/Models/U_Net_LEXY.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Input/output locations used by the rest of the notebook -- fill in before running.
LEXY_file_path: str = ''  # directory containing the LEXY tif files
working_dir: str = ''  # directory where the results text file and model weights are written

In [None]:
# Torch and torchvision imports
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor

# Scikit-image, scipy, Scikit-learn imports
from skimage import io, filters, measure, morphology, segmentation, color, feature
from skimage.filters import threshold_li, gaussian, threshold_otsu
from skimage.feature import peak_local_max, canny
from skimage.morphology import remove_small_objects
from skimage.transform import resize, AffineTransform, warp
from skimage.metrics import hausdorff_distance
from scipy.ndimage import distance_transform_edt, rotate
from scipy import ndimage as ndi
from scipy.spatial.distance import directed_hausdorff
from sklearn.metrics import confusion_matrix, recall_score, accuracy_score, f1_score

# miscellaneous imports
import numpy as np
import matplotlib
matplotlib.rcParams["image.interpolation"] = 'none'
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
from PIL import Image
import cv2
import random
from tqdm.auto import tqdm
import os
import albumentations as A
import re

# Device agnostic: use the CUDA GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Unet model
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolution + ReLU stages.

    ``padding=1`` keeps height and width unchanged; only the channel count
    changes, from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1))
            stages.append(nn.ReLU(inplace=True))
        # Attribute name and layer order match saved state_dicts (conv_op.0 / conv_op.2).
        self.conv_op = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv/ReLU stages to ``x``."""
        features = self.conv_op(x)
        return features


class DownSample(nn.Module):
    """Encoder stage: a DoubleConv followed by 2x2 max-pooling (stride 2).

    ``forward`` returns ``(skip, pooled)``: the full-resolution DoubleConv output,
    used as the skip connection, and its pooled version for the next stage.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        skip = self.conv(x)
        pooled = self.pool(skip)
        return skip, pooled


class UpSample(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, skip concatenation, DoubleConv.

    The transposed convolution maps ``in_channels`` to ``in_channels // 2`` and
    doubles height/width. The result is concatenated with the skip tensor on the
    channel axis, and the DoubleConv expects ``in_channels`` channels. The skip
    must therefore supply the remaining channels and share the upsampled size.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(
            in_channels, in_channels // 2, kernel_size=2, stride=2
        )
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        merged = torch.cat((self.up(x1), x2), dim=1)
        return self.conv(merged)

In [None]:
# classic UNet
class UNet(nn.Module):
    """Four-level U-Net: 1-channel input, 1-channel sigmoid probability map out.

    Height and width must be divisible by 16. Each of the four 2x poolings must
    be undone exactly by an upsampling, otherwise the skip concatenations see
    mismatched spatial sizes.
    """

    def __init__(self):
        super().__init__()
        # Encoder: 1 -> 64 -> 128 -> 256 -> 512 channels.
        self.down_convolution_1 = DownSample(1, 64)
        self.down_convolution_2 = DownSample(64, 128)
        self.down_convolution_3 = DownSample(128, 256)
        self.down_convolution_4 = DownSample(256, 512)

        self.bottle_neck = DoubleConv(512, 1024)

        # Decoder mirrors the encoder: 1024 -> 512 -> 256 -> 128 -> 64 channels.
        self.up_convolution_1 = UpSample(1024, 512)
        self.up_convolution_2 = UpSample(512, 256)
        self.up_convolution_3 = UpSample(256, 128)
        self.up_convolution_4 = UpSample(128, 64)

        # 1x1 projection to a single output channel.
        self.out = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1)

    def forward(self, x):
        encoders = (self.down_convolution_1, self.down_convolution_2,
                    self.down_convolution_3, self.down_convolution_4)
        decoders = (self.up_convolution_1, self.up_convolution_2,
                    self.up_convolution_3, self.up_convolution_4)

        skips = []
        features = x
        for encoder in encoders:
            skip, features = encoder(features)
            skips.append(skip)

        features = self.bottle_neck(features)

        # The deepest skip feeds the first decoder, so walk the skips backwards.
        for decoder, skip in zip(decoders, reversed(skips)):
            features = decoder(features, skip)

        return torch.sigmoid(self.out(features))

# UNet based on VGG16 encoder
class VGG16_UNet(nn.Module):
    """U-Net whose five encoder stages reuse pretrained VGG16 convolution layers.

    The pretrained first convolution takes 3 input channels. It is replaced by a
    1-channel convolution whose kernel is the mean of the pretrained kernel over
    its input channels, so grayscale images can be fed directly. The decoder
    (DoubleConv bottleneck and five UpSample stages) is freshly initialised. The
    1-channel output goes through a sigmoid.

    Height and width must be divisible by 32 (five 2x poolings); otherwise the
    skip concatenations see mismatched spatial sizes.
    """

    def __init__(self):
        super().__init__()
        vgg16 = torchvision.models.vgg16(weights='DEFAULT')
        pretrained_first_conv = vgg16.features[0]
        self.pool = nn.MaxPool2d(2, 2)

        # Full VGG16 feature stack with a grayscale (1-channel) first layer.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            *vgg16.features[1:]
        )
        grayscale_conv = self.encoder[0]
        # Collapse the pretrained RGB kernel to one input channel by averaging.
        grayscale_conv.weight.data = pretrained_first_conv.weight.mean(dim=1, keepdim=True)
        grayscale_conv.bias.data = pretrained_first_conv.bias.data

        # Encoder stages. The indices skipped between slices (4, 9, 16, 23) should
        # be VGG16's max-pool layers; pooling is applied explicitly in forward().
        self.conv_1 = nn.Sequential(*self.encoder[:4])     # 1 -> 64 channels
        self.conv_2 = nn.Sequential(*self.encoder[5:9])    # 64 -> 128
        self.conv_3 = nn.Sequential(*self.encoder[10:16])  # 128 -> 256
        self.conv_4 = nn.Sequential(*self.encoder[17:23])  # 256 -> 512
        self.conv_5 = nn.Sequential(*self.encoder[24:30])  # 512 -> 512

        self.bottle_neck = DoubleConv(512, 1024)

        # Each UpSample's in_channels is (upsampled channels + skip channels).
        self.up_convolution_1 = UpSample(1024, 1024)
        self.up_convolution_2 = UpSample(1024, 512)
        self.up_convolution_3 = UpSample(512, 256)
        self.up_convolution_4 = UpSample(256, 128)
        self.up_convolution_5 = UpSample(128, 64)

        self.out = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1)

    def forward(self, x):
        stages = (self.conv_1, self.conv_2, self.conv_3, self.conv_4, self.conv_5)
        decoders = (self.up_convolution_1, self.up_convolution_2, self.up_convolution_3,
                    self.up_convolution_4, self.up_convolution_5)

        skips = []
        features = x
        for position, stage in enumerate(stages):
            # The first stage sees the raw input; every later stage sees a pooled map.
            stage_input = features if position == 0 else self.pool(features)
            features = stage(stage_input)
            skips.append(features)

        features = self.bottle_neck(self.pool(features))

        # Deepest skip (conv_5 output) pairs with the first decoder stage.
        for decoder, skip in zip(decoders, reversed(skips)):
            features = decoder(features, skip)

        return torch.sigmoid(self.out(features))

# UNet based on ResNet34 encoder
class ResNet34_UNet(nn.Module):
    """U-Net whose encoder is a pretrained ResNet34.

    The stem convolution is replaced by a 1-channel 7x7 stride-2 convolution. It
    is initialised with the input-channel mean of the pretrained kernel and
    followed by the pretrained stem BatchNorm and ReLU. ResNet's max-pool is not
    applied after the stem; it is used once, in front of the bottleneck. The stem
    halves the resolution, so the output head upsamples once more with a
    transposed convolution before the 1x1 convolutions and the final sigmoid.
    """

    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet34(weights='DEFAULT')
        self.pool = resnet.maxpool

        # Grayscale stem + pretrained BatchNorm/ReLU.
        gray_stem = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.layer0 = nn.Sequential(gray_stem, resnet.bn1, resnet.relu)
        gray_stem.weight.data = resnet.conv1.weight.mean(dim=1, keepdim=True)

        # Reuse the four pretrained residual stages unchanged (same attribute names).
        for stage_name in ('layer1', 'layer2', 'layer3', 'layer4'):
            setattr(self, stage_name, getattr(resnet, stage_name))

        self.bottle_neck = DoubleConv(512, 1024)

        self.up_convolution_1 = UpSample(1024, 512)
        self.up_convolution_2 = UpSample(512, 256)
        self.up_convolution_3 = UpSample(256, 128)
        self.up_convolution_4 = UpSample(128, 64)

        # Output head: undo the stem's stride-2 downsampling, then project to 1 channel.
        self.out = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=32, out_channels=1, kernel_size=1)
        )

    def forward(self, x):
        features = self.layer0(x)  # 1 -> 64 channels, half resolution

        skips = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)  # 64, 128, 256, 512 channels in turn
            skips.append(features)

        features = self.bottle_neck(self.pool(features))

        decoders = (self.up_convolution_1, self.up_convolution_2,
                    self.up_convolution_3, self.up_convolution_4)
        # layer4 output pairs with the first decoder; the stem output is not a skip.
        for decoder, skip in zip(decoders, reversed(skips)):
            features = decoder(features, skip)

        return torch.sigmoid(self.out(features))

In [None]:
class DiceBCELoss(nn.Module):
    """Soft Dice loss plus class-weighted binary cross-entropy.

    ``forward`` always returns ``(primary, dice_loss, bce)``. ``primary`` is the
    loss meant for back-propagation, chosen by ``loss_type``:
    ``'combined'`` (bce + dice), ``'dice'`` or ``'bce'``. In the BCE term, pixels
    whose target equals 1 get weight 2.0 and all others weight 1.0.

    ``weight`` and ``size_average`` are accepted for signature compatibility
    but ignored.
    """

    def __init__(self, weight=None, size_average=True, loss_type='combined'):
        super().__init__()
        self.loss_type = loss_type  # default selection, overridable per forward call

    def forward(self, inputs, targets, smooth=1, loss_type=None):
        """Compute the losses for predicted probabilities ``inputs`` vs ``targets``.

        ``inputs`` must already be probabilities in [0, 1] because they go
        straight into ``F.binary_cross_entropy``. Both tensors are flattened, so
        any shapes with the same number of elements work. ``loss_type=None``
        falls back to the value given at construction.

        Raises:
            ValueError: if the effective ``loss_type`` is not one of the three options.
        """
        mode = self.loss_type if loss_type is None else loss_type

        pred = inputs.reshape(-1)
        truth = targets.reshape(-1)

        # Soft Dice with additive smoothing in numerator and denominator.
        overlap = (pred * truth).sum()
        dice_loss = 1 - (2. * overlap + smooth) / (pred.sum() + truth.sum() + smooth)

        # Foreground pixels count double in the cross-entropy.
        pixel_weights = torch.where(truth == 1, 2.0, 1.0)
        bce = F.binary_cross_entropy(pred, truth, weight=pixel_weights, reduction='mean')

        if mode == 'combined':
            primary = bce + dice_loss
        elif mode == 'dice':
            primary = dice_loss
        elif mode == 'bce':
            primary = bce
        else:
            raise ValueError("loss_type must be 'combined', 'dice', or 'bce'")

        return primary, dice_loss, bce

In [None]:
def train_model(model, train_loader, loss_fn, optimizer):
    """Run one training epoch and return the batch-averaged losses.

    Each batch is indexed as ``batch[:, channel, ...]``: channel 0 is the input
    image and channel 1 the target mask. ``loss_fn(prediction, target)`` must
    return ``(loss, dice_loss, bce)``; only ``loss`` is back-propagated.

    Args:
        model: network applied to the ``(B, 1, H, W)`` input images.
        train_loader: iterable of batches that also supports ``len()``.
        loss_fn: callable returning a 3-tuple of scalar tensors.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        ``(mean_loss, mean_dice_loss, mean_bce)``: each is summed over batches
        and divided by ``len(train_loader)``.
    """
    model.train()
    running = [0, 0, 0]  # loss, dice_loss, bce

    for batch in train_loader:
        inputs = batch[:, 0].unsqueeze(1)  # re-insert the channel axis for the model
        targets = batch[:, 1]

        predictions = model(inputs).squeeze()
        loss, dice_loss, bce = loss_fn(predictions, targets)
        for slot, term in enumerate((loss, dice_loss, bce)):
            running[slot] += term.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    num_batches = len(train_loader)
    mean_loss, mean_dice, mean_bce = (total / num_batches for total in running)
    return mean_loss, mean_dice, mean_bce

def test_model(model, test_loader, loss_fn):
    # Initialize cumulative loss values
    test_loss = 0
    test_dice_loss = 0
    test_BCE = 0

    # Set the model to evaluation mode
    model.eval()

    with torch.no_grad():
        for img in test_loader:
            # Extract input image (X_test) and target mask (y_test)
            X_test = img[:, 0, :, :].unsqueeze(1)
            y_test = img[:, 1, :, :]

            # Forward pass: Predict output mask using the model
            test_pred = model(X_test).squeeze()

            # Calculate the losses
            loss, dice_loss, BCE = loss_fn(test_pred, y_test)
            test_loss += loss.item()
            test_dice_loss += dice_loss.item()
            test_BCE += BCE.item()

    # Calculate average losses over the entire test set
    test_loss /= len(test_loader)
    test_dice_loss /= len(test_loader)
    test_BCE /= len(test_loader)

    return test_loss, test_dice_loss, test_BCE

def fit_model(model, train_loader, test_loader, loss_fn, optimizer, epochs=5):
    """Train for ``epochs`` epochs and evaluate on ``test_loader`` after each one.

    Prints the epoch number and the train/test losses (primary, Dice, BCE).

    Returns:
        ``(train_loss_list, test_loss_list)``: per-epoch mean primary losses,
        each rounded to 4 decimals.
    """
    history = {'train': [], 'test': []}

    for epoch in tqdm(range(epochs)):
        print(f'Epoch {epoch + 1}/{epochs}')

        tr_loss, tr_dice, tr_bce = train_model(model, train_loader, loss_fn, optimizer)
        history['train'].append(round(tr_loss, 4))

        te_loss, te_dice, te_bce = test_model(model, test_loader, loss_fn)
        history['test'].append(round(te_loss, 4))

        print(f'Training Loss: {tr_loss:.4f} (DICE: {tr_dice:.4f}, BCE: {tr_bce:.4f}) | Test Loss: {te_loss:.4f} (DICE: {te_dice:.4f}, BCE: {te_bce:.4f})')

    return history['train'], history['test']

def plot_loss(train_loss_list, test_loss_list):
    """Draw per-epoch training (blue) and test (red) loss curves and show the figure.

    The x-axis runs from 1 to ``len(train_loss_list)``.
    """
    epoch_axis = range(1, len(train_loss_list) + 1)

    curves = (
        (train_loss_list, 'b', 'Training Loss'),
        (test_loss_list, 'r', 'test Loss'),
    )
    for values, colour, label in curves:
        plt.plot(epoch_axis, values, colour, label=label)

    plt.title('Training and test Loss Over Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

def evaluate_model(model, data_loader):
    """Compute per-image segmentation metrics for ``model`` over ``data_loader``.

    Each batch is indexed as ``batch[:, channel, ...]``: channel 0 is the input
    image and channel 1 the binary ground-truth mask. Predictions are binarised
    at 0.5.

    Returns:
        ``(metrics, predictions)``:

        * ``metrics`` is ``[[mean, std] sensitivity, [mean, std] specificity,
          [mean, std] Dice (F1), [mean, std] Hausdorff distance]``, computed
          over individual images. The Hausdorff statistics skip images whose
          distance is infinite, i.e. edges found in only one of prediction and
          ground truth.
        * ``predictions`` holds one ``(B, H, W)`` int array of binary
          predictions per batch.
    """
    model.eval()

    # Per-image metric values.
    sensitivity_list, specificity_list, dice_list, hd_list = [], [], [], []
    predictions = []

    with torch.no_grad():
        for img in data_loader:
            X_test = img[:, 0, :, :].unsqueeze(1)
            y_test = img[:, 1, :, :].to(torch.int)

            # Drop only the channel axis. A bare .squeeze() would also drop the
            # batch axis for a single-image batch, and the per-image loop below
            # would then index image rows instead of images.
            test_pred = model(X_test).squeeze(1)

            # Binarize the predictions based on a threshold (0.5)
            test_pred = (test_pred >= 0.5).to(torch.int)

            test_pred_np = test_pred.cpu().numpy()
            y_test_np = y_test.cpu().numpy()
            predictions.append(test_pred_np)

            for pred_img, true_img in zip(test_pred_np, y_test_np):
                # Boundary pixels via the Canny edge detector.
                edges_y = canny(true_img.astype(bool))
                edges_pred = canny(pred_img.astype(bool))

                # skimage's hausdorff_distance already takes the max over both
                # directions, so a single call gives the symmetric distance.
                hd_list.append(hausdorff_distance(edges_y, edges_pred))

                pred_flat = pred_img.flatten()
                true_flat = true_img.flatten()

                tn, fp, fn, tp = confusion_matrix(true_flat, pred_flat, labels=[0, 1]).ravel()

                # Metrics for this image.
                sensitivity_list.append(recall_score(true_flat, pred_flat))
                specificity_list.append(tn / (tn + fp) if (tn + fp) != 0 else 0)
                dice_list.append(f1_score(true_flat, pred_flat))

    # ignore any inf in hd (edges present in only one of the two images)
    hd_array = np.array(hd_list)[np.isfinite(hd_list)]

    avg_sensitivity = np.mean(sensitivity_list)
    std_sensitivity = np.std(sensitivity_list)

    avg_specificity = np.mean(specificity_list)
    std_specificity = np.std(specificity_list)

    avg_dice_coefficient = np.mean(dice_list)
    std_dice_coefficient = np.std(dice_list)

    avg_hausdorff_distance = np.mean(hd_array)
    std_hausdorff_distance = np.std(hd_array)

    return [[avg_sensitivity, std_sensitivity],
            [avg_specificity, std_specificity],
            [avg_dice_coefficient, std_dice_coefficient],
            [avg_hausdorff_distance, std_hausdorff_distance]], predictions

In [None]:
# Define the transformations for data augmentation
class RandomRot90:
    """Rotate a ``(..., H, W)`` tensor by a randomly chosen multiple of 90 degrees.

    A fresh multiple (0, 1, 2 or 3 quarter turns) is drawn on every call, so
    each sample gets its own rotation. Image and mask live in the same tensor
    and are therefore rotated together. ``torch.rot90`` permutes pixels
    exactly, with no interpolation or fill.
    """

    def __call__(self, sample):
        quarter_turns = random.randint(0, 3)
        return torch.rot90(sample, quarter_turns, dims=(-2, -1))


train_transforms = transforms.Compose([
    transforms.RandomCrop(size=(512, 512)),
    transforms.Resize((128, 128), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    # Draw the quarter-turn per sample. A random.choice(...) evaluated when the
    # pipeline is built would fix one angle for the whole run.
    RandomRot90(),
])

test_transforms = transforms.Compose([
    transforms.CenterCrop(size=(512, 512)),
    transforms.Resize((128, 128), interpolation=transforms.InterpolationMode.NEAREST),
])

class LEXY(Dataset):
    """In-memory dataset of stacked image/mask arrays.

    Every element of ``data`` is converted once, at construction, to a float32
    tensor on the module-level ``device``. ``transform`` (if truthy) is applied
    on every item access.
    """

    def __init__(self, data, transform=None):
        self.data = []
        for array in data:
            self.data.append(torch.tensor(array, dtype=torch.float32).to(device))
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        return self.transform(sample) if self.transform else sample

In [None]:
def extract_number(s):
    """Return the 1-2 digit number that follows the first letter-digit pair in ``s``.

    Used as a sort key. Strings without a letter followed by a digit get
    ``float('inf')`` so they sort last.
    """
    found = re.search(r'[a-zA-Z](\d{1,2})', s)
    if found is None:
        return float('inf')
    return int(found.group(1))

def load_images(file_path):
    """Read every file directly inside ``file_path`` and group the images by cell id.

    Files are visited in ``extract_number`` order. The cell id is the fifth
    ``'_'``-separated field of the file name, so every file name in the
    directory must contain at least four underscores.

    Returns:
        dict mapping cell id -> list of images converted to float32 arrays, in
        visiting order.
    """
    _, _, names = next(os.walk(file_path), (None, None, []))
    grouped = {}

    for name in sorted(names, key=extract_number):
        pixels = io.imread(os.path.join(file_path, name))
        cell_id = name.split('_')[4]
        grouped.setdefault(cell_id, []).append(pixels.astype(np.float32))

    return grouped

def process_images(images):
    """Bring every mask in ``images`` to a common polarity, in place.

    Each ``images[i]`` is indexed as ``[channel, row, col]``: channel 0 is the
    image, channel 1 the mask (mask values presumably 0/255 -- TODO confirm).
    Two checks run in order on the mask's corner pixel ``[0, 0]`` and the pixel
    diagonally inside it, ``[1, 1]``:

    1. corner == 255: the mask is inverted (``255 - mask``).
    2. then, corner == 0 and ``[1, 1]`` == 255: the mask is inverted and its
       outermost rows/columns are set to 0. This presumably targets an inverted
       mask surrounded by a 1-pixel frame of 0.

    Check 2 sees the result of check 1. A mask with a 255 frame around a 0
    interior is therefore inverted twice, ending with its original polarity,
    and then has its frame zeroed.

    Args:
        images: list of 3-D arrays; list entries are replaced by new arrays.

    Returns:
        The same ``images`` list object.
    """
    for i in range(len(images)):
        # Check 1: mask corner is 255 -> invert the mask channel.
        if images[i][1, 0, 0] == 255:
            images[i] = np.stack((images[i][0, :, :], 255 - images[i][1, :, :]), axis=0)

        # Check 2 (sees the output of check 1): corner 0 but 255 just inside it.
        if (images[i][1, 0, 0] == 0) & (images[i][1, 1, 1] == 255):
            images[i] = np.stack((images[i][0, :, :], 255 - images[i][1, :, :]), axis=0)

            # Zero the 1-pixel frame: left column, top row, right column, bottom row.
            images[i][1, :, 0] = 0
            images[i][1, 0, :] = 0
            images[i][1, :, images[i].shape[2] - 1] = 0
            images[i][1, images[i].shape[1] - 1, :] = 0

    return images

def min_max_normalize(image):
    """Linearly rescale ``image`` so its minimum maps to 0 and its maximum to 1.

    A constant image (max == min) has no range to stretch. Dividing by zero
    would fill it with NaN, so an all-zero array is returned instead.

    Args:
        image: numeric NumPy array of any shape.

    Returns:
        Floating-point array with the same shape as ``image``.
    """
    min_val = np.min(image)
    max_val = np.max(image)
    value_range = max_val - min_val

    if value_range == 0:
        # Keep float32 inputs float32 (as the division below would); wider types widen.
        return np.zeros_like(image, dtype=np.result_type(image, np.float32))

    # Apply the Min-Max normalization
    return (image - min_val) / value_range

def adjust_background_mean(images, desired_mean=0.01):
    """Rescale each cell image so that its background averages ``desired_mean``.

    For every stack in ``images`` (index 0 = cell image, index 1 = mask):

    * background = pixels of the *raw* cell image below its Li threshold;
    * the cell image is min-max normalised and the mask divided by 255;
    * the whole normalised cell image (not only the background) is multiplied
      by ``desired_mean / mean(normalised background)``.

    Returns:
        New list of ``np.stack((scaled_cell, mask))`` arrays; the input arrays
        are not modified.
    """
    result = []
    for pair in images:
        raw_cell = pair[0]
        raw_mask = pair[1]

        # Min-max scaling is monotonic, so the raw-image threshold still
        # selects the same pixels after normalisation.
        background = raw_cell < threshold_li(raw_cell)
        normalised = min_max_normalize(raw_cell)

        scale = desired_mean / np.mean(normalised[background])
        result.append(np.stack((normalised * scale, raw_mask / 255.0), axis=0))

    return result

In [None]:
# Load the images
images_dict = load_images(LEXY_file_path)

# Leave-one-cell-out cross-validation: every cell id is the test set once,
# with the images of all other cells used for training.
cell_ids = sorted(images_dict.keys())
results = {cell_id: [] for cell_id in cell_ids}

for cell in tqdm(cell_ids, desc="preprocessing images"):
    # Fix mask polarity, then normalise intensities, for every image of this cell.
    images = process_images(images_dict[cell])
    images = adjust_background_mean(images)
    images_dict[cell] = images

for test_cell in tqdm(cell_ids, desc="cross validation"):
    images_train = []
    images_test = images_dict[test_cell]

    for train_cell in cell_ids:
        if train_cell != test_cell:
            images_train.extend(images_dict[train_cell])

    # Random augmentation for training; deterministic crop/resize for testing.
    train_dataset = LEXY(data=images_train, transform=train_transforms)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    test_dataset = LEXY(data=images_test, transform=test_transforms)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # model initialization (a fresh model for every fold)
    loss_type = '' # loss_type must be 'combined', 'dice', or 'bce'
    loss_fn = DiceBCELoss(loss_type=loss_type)
    model = UNet().to(device)  # model_name must be 'UNet', 'VGG16_UNet', or 'ResNet34_UNet'
    pretrained_path = ''  # optional path to pretrained model state_dict

    if pretrained_path != '':
        pretrained = True
        # map_location puts the stored tensors on the current device, so weights
        # saved from a GPU run also load on a CPU-only machine.
        model.load_state_dict(torch.load(pretrained_path, map_location=device))
    else:
        pretrained = False

    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)

    # train model
    train_loss_list, test_loss_list = fit_model(model, train_loader, test_loader, loss_fn, optimizer, epochs=10)
    plot_loss(train_loss_list, test_loss_list)

    # test model: metrics[0..3] are [mean, std] of sensitivity, specificity, Dice
    # and Hausdorff distance; train losses, test losses and per-batch predictions
    # are appended as metrics[4], metrics[5] and metrics[6].
    metrics, predictions = evaluate_model(model, test_loader)
    metrics.extend([train_loss_list, test_loss_list, predictions])
    results[test_cell] = metrics

In [None]:
#plot results
cell_id = 'C10'

# results[cell_id][6] holds one prediction array per test batch; plot the first batch.
predictions = results[cell_id][6][0]
# Guarantee a leading image axis even if a single-image batch came back 2-D.
predictions = predictions.reshape(-1, *predictions.shape[-2:])
cell_images = images_dict[cell_id]
num_images = len(predictions)

# Create subplots; squeeze=False keeps `axes` 2-D even when only one row is drawn.
fig, axes = plt.subplots(nrows=num_images, ncols=3, figsize=(15, 5 * num_images), squeeze=False)

for i in range(num_images):
    # The model saw test_transforms(image) (center crop + resize) and the test
    # loader does not shuffle. Applying the same transform here makes all three
    # panels show the same region at the same size.
    shown = test_transforms(torch.tensor(cell_images[i], dtype=torch.float32)).cpu().numpy()
    cell_image = shown[0]  # Extract cell image
    ground_truth = shown[1]  # Extract ground truth

    # Plot cell image
    axes[i, 0].imshow(cell_image, cmap='gray')
    axes[i, 0].set_title(f'Cell Image {i+1}')
    axes[i, 0].axis('off')

    # Plot ground truth
    axes[i, 1].imshow(ground_truth, cmap='gray')
    axes[i, 1].set_title(f'Ground Truth {i+1}')
    axes[i, 1].axis('off')

    # Plot prediction
    axes[i, 2].imshow(predictions[i], cmap='gray')
    axes[i, 2].set_title(f'Prediction {i+1}')
    axes[i, 2].axis('off')

plt.tight_layout()
plt.show()

In [None]:
# save results
# `model`, `loss_type` and `pretrained` hold the values left by the last
# cross-validation fold, so the saved weights belong to the final fold's model.
name_prefix = 'pretrained_' if pretrained else ''
model_name = model.__class__.__name__
results_file = os.path.join(working_dir, f"{name_prefix}LEXY_results_{model_name}_{loss_type}.txt")
model_file = os.path.join(working_dir, f"{name_prefix}LEXY_model_{model_name}_{loss_type}.pth")

with open(results_file, 'w') as out_file:
    for cell, stats in results.items():
        # stats[0..3] are [mean, std] pairs; stats[4] / stats[5] are per-epoch loss lists.
        report = (f'results for cell {cell}:\n'
                  f'sensitivity: {stats[0]}\n'
                  f'specificity: {stats[1]}\n'
                  f'dice coefficient: {stats[2]}\n'
                  f'hausdorff distance: {stats[3]}\n'
                  f'train loss: {stats[4]}\n'
                  f'test loss: {stats[5]}\n\n')
        out_file.write(report)

torch.save(model.state_dict(), model_file)

In [None]:
# Average the per-cell mean Dice coefficient and mean Hausdorff distance over
# all cross-validation folds (one entry of `results` per held-out cell).
num_cells = len(results)

total_dice_coefficient = sum(stats[2][0] for stats in results.values())
total_hausdorff_distance = sum(stats[3][0] for stats in results.values())

mean_dice_coefficient = total_dice_coefficient / num_cells
mean_hausdorff_distance = total_hausdorff_distance / num_cells

# Print the means
print(f"Mean Dice Coefficient: {mean_dice_coefficient}")
print(f"Mean Hausdorff Distance: {mean_hausdorff_distance}")