### CellOMaps: A Compact Representation for Robust Classification of Lung Adenocarcinoma Growth Patterns

This notebook demonstrates how to create CellOMaps and use them to classify lung adenocarcinoma growth patterns.

### CellOMaps Construction:

Step 1: Perform nuclei segmentation and classification using Hover-Net. An implementation is available in TIAToolbox: (https://tia-toolbox.readthedocs.io/en/latest/_notebooks/jnb/08-nucleus-instance-segmentation.html)
To ensure compatibility with this notebook, we recommend using `pretrained_model="hovernet_fast-pannuke"` and running inference at 0.5 mpp.

Step 2: Convert the Hover-Net output into per-slide cell maps (written to `cellmaps/`) by running the cell below.

In [None]:
from CreatingCentroidsMasksFromHoverNet import create_maks

# Inputs: the whole-slide images and the Hover-Net predictions made on them.
path_to_slides = 'Data/slides/'
path_to_hoverNet_output = 'Cell_predictions/'
# Output: the cell maps that the classification section reads from 'cellmaps/'.
path_to_cell_masks = 'cellmaps/'

create_maks(path_to_slides, path_to_hoverNet_output, path_to_cell_masks)

### Growth Pattern classification

Import the necessary packages and define the configuration variables.

In [None]:
import os
import numpy as np
import torch
from PIL import Image
from sklearn.model_selection import StratifiedShuffleSplit
import torch.nn as nn
import torchvision.models as models
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torch.utils.data as data
import torchvision.transforms as transforms
from sklearn.metrics import auc, roc_curve, f1_score, precision_recall_curve, average_precision_score, roc_auc_score, \
    confusion_matrix, precision_score, recall_score, accuracy_score
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
from tiatoolbox.wsicore import WSIReader
from tiatoolbox.tools import patchextraction
from torch.utils.data import ConcatDataset
from tiatoolbox.annotation.storage import Annotation, SQLiteStore
from shapely.geometry import Polygon
from sklearn.preprocessing import MultiLabelBinarizer
from skmultilearn.model_selection import iterative_train_test_split
from sklearn.preprocessing import label_binarize
import cv2
import csv
from torch.utils.tensorboard import SummaryWriter
from focal_loss.focal_loss import FocalLoss

In [None]:
Image.MAX_IMAGE_PIXELS = None  # lift PIL's decompression-bomb limit; cell maps can be very large images
# Use the first GPU when present; otherwise fall back to CPU so the notebook still runs without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
multi_gpu = False

# --- Paths ---
dataPath = 'Data/'  # read as Data/slides and Data/masks by the helpers below
output = 'output'   # results root: per-trial folders, checkpoints and Results.csv
# Create the results root up front; later cells only create sub-folders inside it.
os.makedirs(output, exist_ok=True)
path_to_centroids_masks = 'cellmaps/'  # cell maps produced in the CellOMaps construction step

# Every cell-map file, sorted so the list order is deterministic across file systems.
all_centroid_masks = sorted(
    os.path.join(path_to_centroids_masks, m) for m in os.listdir(path_to_centroids_masks)
)

# --- Experiment / training settings ---
num_trials = 7      # number of slide-level train/test splits
runs = 1            # training runs per trial
batch_size = 5
nepochs = 50
learning_rate = 1e-5
workers = 12        # DataLoader worker processes
test_every = 1      # validate every `test_every` epochs
patch_size = (448, 448, 3)  # tiles are extracted at 2 mpp (approximately 5x)
patch_size1d, _, _ = patch_size
mask_percentage = 0.4       # min_mask_ratio passed to the TIAToolbox patch extractor

# Dilation applied to every tile before it is fed to the network
# (variable names kept as-is because TileDataset reads them).
kernel_size = 1
dialation_iterations = 2
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2 * kernel_size + 1, 2 * kernel_size + 1))
focal_loss_gamma = 0.7
cellMapChannels = 4

labelDict = {'Solid': 0, 'Acinar': 1, 'Papillary': 2, 'Micropapillary': 3, 'Lepidic': 4, 'Normal': 5}
# Inverse lookup (class index -> name), derived so the two mappings can never drift apart.
label_dict = {index: name for name, index in labelDict.items()}
numClasses = len(labelDict)

# logging results ...
csv_file_path = os.path.join(output, 'Results.csv')

In [None]:
class TileDataset(data.Dataset):
    """Dataset of cell-map tiles paired with integer growth-pattern labels.

    Each item is the tile's ``'patch'`` array dilated with the notebook-level
    ``kernel`` (``dialation_iterations`` times), converted to a PIL image,
    resized to ``patch_size1d`` x ``patch_size1d`` when either side is smaller,
    and finally passed through ``transform``.

    Args:
        tiles: sequence of tile dicts (keys read here: 'patch', 'slideName').
        labels: sequence of class indices aligned with ``tiles``.
        transform: callable applied to the PIL image (e.g. a torchvision Compose).
        version: stored on the instance; not read by this class.
    """

    def __init__(self, tiles, labels, transform, version):
        self.tiles, self.targets = tiles, labels
        self.transform, self.version = transform, version

    def __getitem__(self, index):
        tile, label = self.tiles[index], self.targets[index]
        dilated = Image.fromarray(cv2.dilate(tile['patch'], kernel, iterations=dialation_iterations))
        if min(dilated.width, dilated.height) < patch_size1d:
            # Tiles near the slide border may come out smaller than the network input
            # (presumably clipped by the array slice) — report and upscale them.
            print(f'Small patch found in slide {tile["slideName"]} pattern {label_dict[label]} with size {dilated.width}x{dilated.height}')
            dilated = dilated.resize((patch_size1d, patch_size1d), Image.BICUBIC)
        return self.transform(dilated), label

    def __len__(self):
        return len(self.targets)

    def shuffle(self):
        """Reorder tiles and targets in place using one shared random permutation."""
        order = torch.randperm(len(self.targets)).tolist()
        self.tiles = [self.tiles[k] for k in order]
        self.targets = [self.targets[k] for k in order]



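Helper functions: patch extraction, data preparation, slide-level and tile-level splitting, training/inference loops, and evaluation metrics.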

In [None]:

def patchExtraction(img, mask, slide_name):
    """Tile a slide's cell map into non-overlapping patches inside the annotated region.

    Patch coordinates come from TIAToolbox's sliding-window extractor run on the slide
    at 2 mpp with stride equal to the patch size. The cell map is resized to the slide's
    size at 2 mpp so the same coordinates can be used to crop it.

    Args:
        img: cell-map array; indexed as (rows, cols, channels) below.
        mask: binary region mask, passed to the extractor as ``input_mask``.
        slide_name: slide file name without the '.svs' extension (looked up in
            ``<dataPath>/slides``).

    Returns:
        List of dicts with keys 'x_start', 'x_end', 'y_start', 'y_end', 'patch'
        (the cropped cell-map array) and 'slideName'.
    """
    slide_path = os.path.join(dataPath, 'slides', f'{slide_name}.svs')
    wsi = WSIReader.open(slide_path)
    # PIL.Image.resize needs a plain (width, height) tuple of ints; slide_dimensions may
    # return a numpy array, which older Pillow versions reject.
    width, height = (int(v) for v in wsi.slide_dimensions(2, 'mpp'))
    # Resize the cell map so its pixel grid matches the slide at 2 mpp.
    cell_map = np.array(Image.fromarray(img).resize((width, height), Image.BICUBIC))

    extractor = patchextraction.SlidingWindowPatchExtractor(
        input_img=slide_path,
        patch_size=(patch_size1d, patch_size1d),
        input_mask=mask, resolution=2, units='mpp',
        stride=(patch_size1d, patch_size1d), min_mask_ratio=mask_percentage)

    patches = []
    # Each coordinate is (x_start, y_start, x_end, y_end); numpy slices rows (y) first.
    for x_start, y_start, x_end, y_end in extractor.coordinate_list:
        patches.append({
            'x_start': x_start, 'x_end': x_end,
            'y_start': y_start, 'y_end': y_end,
            'patch': cell_map[y_start:y_end, x_start:x_end, :],
            'slideName': slide_name,
        })
    return patches

def _stack_cell_maps(layer_paths):
    """Load cell-map images as greyscale and stack them as channels of one (H, W, 3) uint8 array.

    Channels without a layer stay zero (np.empty would leave them uninitialised).
    Raises ValueError when more than three layers are given.
    """
    if len(layer_paths) > 3:
        raise ValueError(f'Expected at most 3 cell-map layers, got {len(layer_paths)}: {layer_paths}')
    layers = []
    for layer_path in layer_paths:
        with Image.open(layer_path) as layer_img:
            layers.append(np.array(layer_img.convert('L')))
    height, width = layers[0].shape
    stacked = np.zeros((height, width, 3), dtype=np.uint8)
    for channel, layer in enumerate(layers):
        stacked[:, :, channel] = layer
    return stacked


def _load_binary_mask(mask_path, size):
    """Load a growth-pattern mask, resize it to `size` (width, height), and binarise it to {0, 1} uint8."""
    with Image.open(mask_path) as mask_img:
        resized = mask_img.convert('L').resize(size)
    return (np.array(resized) > 0).astype(np.uint8)


def prep_data(masks):
    """Build the tile dataset for a list of growth-pattern mask files.

    Mask files are parsed as '<slide>_..._<Label>.<ext>': the slide name is the text before
    the first '_' and the label (a key of ``labelDict``) the text after the last '_'. For
    each mask, the slide's cell maps whose file names contain 'Neoplastic' or
    'ConnectiveSoftTissue_Centroids' are stacked into a 3-channel image and tiled with
    ``patchExtraction`` inside the mask region.

    Args:
        masks: iterable of paths to mask images.

    Returns:
        (tiles, labels): list of tile dicts from ``patchExtraction`` and the matching list
        of integer class labels. Slides with no matching cell maps are skipped with a message.

    Raises:
        KeyError: if a mask that yields patches has a label not in ``labelDict``.
        ValueError: if more than three cell-map layers match one slide.
    """
    tiles = []
    labels = []
    for mask_path in masks:
        stem = os.path.splitext(os.path.basename(mask_path))[0]
        label = stem.split('_')[-1]
        mask_name = stem.split('_')[0]

        # NOTE(review): prefix match — slide '1' would also pick up cell maps of slide '10'.
        # Kept because the cell-map naming comes from CreatingCentroidsMasksFromHoverNet; confirm
        # that convention before tightening this to an exact match.
        img_layers = sorted(
            centroid_mask for centroid_mask in all_centroid_masks
            if os.path.basename(centroid_mask).startswith(mask_name)
            and ('Neoplastic' in os.path.basename(centroid_mask)
                 or 'ConnectiveSoftTissue_Centroids' in os.path.basename(centroid_mask))
        )
        if not img_layers:
            print('Slide#: {} has No hoverNet Features ! \n'.format(mask_name))
            continue

        stacked_array = _stack_cell_maps(img_layers)  # sorted -> same channel order for every slide
        height, width = stacked_array.shape[:2]
        region_mask = _load_binary_mask(mask_path, (width, height))

        patches = patchExtraction(stacked_array, region_mask, mask_name)
        tiles.extend(patches)
        labels.extend(labelDict[label] for _ in patches)  # one ground-truth label per tile

        print(f"slide {mask_name}: Number of patches: {len(patches)} for class: {label}")

    return tiles, labels

def append_to_csv(trial, epoch, unseen_slides, aucroc, f1, accuracy, precision, recall, pr_macro, pr_micro,
                  acc_per_class, f1_per_class, csv_path=None):
    """Append one row of test-set results to the results CSV, writing the header first if needed.

    Args:
        trial: trial index.
        epoch: last training epoch of the run.
        unseen_slides: held-out slide names; their count and their str() form are both written.
        aucroc, f1, accuracy, precision, recall, pr_macro, pr_micro: scalar test metrics.
        acc_per_class, f1_per_class: six values each, in the order
            solid, acinar, papillary, micropapillary, lepidic, normal.
        csv_path: destination file; defaults to the notebook-level ``csv_file_path``.

    Raises:
        ValueError: if either per-class sequence does not contain exactly six values.
    """
    if csv_path is None:
        csv_path = csv_file_path

    solid, acinar, papillary, micropapillary, lepidic, normal = acc_per_class
    f1_solid, f1_acinar, f1_papillary, f1_micropapillary, f1_lepidic, f1_normal = f1_per_class
    header = ['Trial', 'epoch', 'count', 'unseen_slides', 'AUC-ROC', 'F1_macro', 'Accuracy', 'Precision', 'Recall',
              'AUCPR_macro', 'AUCPR_micro', 'solid', 'acinar', 'papillary', 'micropapillary', 'lepidic', 'normal',
              'f1_solid', 'f1_acinar', 'f1_papillary', 'f1_micropapillary', 'f1_lepidic', 'f1_normal']
    row = [trial, epoch, len(unseen_slides), unseen_slides, aucroc, f1, accuracy, precision, recall, pr_macro,
           pr_micro, solid, acinar, papillary, micropapillary, lepidic, normal, f1_solid, f1_acinar, f1_papillary,
           f1_micropapillary, f1_lepidic, f1_normal]

    # Make sure the results folder exists so the first append does not fail.
    parent = os.path.dirname(csv_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    # A header is needed for a new file and also for an empty one (e.g. left by an interrupted run).
    needs_header = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0

    with open(csv_path, 'a', newline='') as file:
        writer = csv.writer(file)
        if needs_header:
            writer.writerow(header)
        writer.writerow(row)


def stratified_split(exclusions=None):
    """Pick a held-out set of whole slides, stratified by the growth patterns each slide contains.

    Every slide in ``<dataPath>/slides`` (minus ``exclusions``) is labelled with the patterns
    that have a mask file '<slide>_<Label>.<ext>' in ``<dataPath>/masks``. The slides are
    shuffled with NumPy's global RNG (so each call yields a different split), then about 19%
    are selected with iterative multi-label stratification.

    Args:
        exclusions: slide names (file names without extension) to leave out entirely.
            Defaults to no exclusions.

    Returns:
        Array of shape (n_test, 1) holding the held-out slide names.
    """
    if exclusions is None:
        exclusions = ()
    slides_path = os.path.join(dataPath, 'slides')
    masks_path = os.path.join(dataPath, 'masks')
    masks = os.listdir(masks_path)
    slide_names = [os.path.splitext(name)[0] for name in os.listdir(slides_path)]

    labels = []
    required_slides = []
    for slide in slide_names:
        if slide in exclusions:
            continue
        # Exact match on the text before the first '_' (the same slide-name parsing prep_data uses);
        # startswith() would also attribute e.g. '10_Solid.png' to slide '1'.
        slide_masks = [m for m in masks if m.split('_')[0] == slide]
        labels.append([labelDict[m.split('_')[1].split('.')[0]] for m in slide_masks])
        required_slides.append(slide)

    label_matrix = np.array(MultiLabelBinarizer().fit_transform(labels))
    img_paths_2d = np.array(required_slides).reshape(-1, 1)

    # Shuffle the data to obtain different splits each time
    shuffled_indices = np.random.permutation(len(img_paths_2d))
    test_size = 0.19
    _, _, test_paths, _ = iterative_train_test_split(
        img_paths_2d[shuffled_indices], label_matrix[shuffled_indices], test_size=test_size)

    return test_paths

def train_valid_split(x, y, seed=5):
    """Hold out 10% of the samples for validation, stratified by label.

    The split is made per sample (tile) with no grouping, so tiles of one slide can fall
    into both subsets.

    Args:
        x: sequence of samples (converted with np.array).
        y: sequence of labels aligned with ``x``.
        seed: random_state for the StratifiedShuffleSplit.

    Returns:
        (X_train, y_train, X_valid, y_valid) as numpy arrays.
    """
    splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=seed)
    samples, targets = np.array(x), np.array(y)

    # n_splits=1, so the generator yields exactly one (train, held-out) index pair.
    keep_idx, hold_idx = next(splitter.split(samples, targets))
    return samples[keep_idx], targets[keep_idx], samples[hold_idx], targets[hold_idx]


def inference(run, loader, model, criterion):
    """Evaluate `model` on every sample of `loader` without updating its weights.

    Args:
        run: zero-based epoch index; only used in progress messages.
        loader: DataLoader yielding (inputs, target) batches. It must not shuffle, because
            outputs are stored in loader order and later compared with ``dataset.targets``.
        model: network returning one logit per class (``numClasses`` columns).
        criterion: loss applied to the softmax probabilities.

    Returns:
        (probs, mean_loss, mean_acc, preds): probs is an (N, numClasses) float32 array of
        softmax probabilities and preds an (N,) float32 array of argmax class indices;
        mean_loss and mean_acc are sample-weighted averages over the whole dataset.
    """
    model.eval()

    n_samples = len(loader.dataset)
    probs = torch.zeros(n_samples, numClasses)
    preds = torch.zeros(n_samples)
    running_acc = 0.
    running_loss = 0.
    seen = 0  # samples processed so far: write offset and running-average denominator

    with torch.no_grad():
        for i, (inputs, target) in enumerate(loader):
            inputs = inputs.to(device)
            # as_tensor avoids the copy-construct warning when the collated target is already a tensor
            target = torch.as_tensor(target).to(device)
            output = model(inputs)
            y = F.softmax(output, dim=1)

            loss = criterion(y, target)
            acc = calculate_accuracy(output, target)
            _, pr = torch.max(output, 1)

            n = inputs.size(0)
            # Offset by samples seen (not i * global batch_size) so results stay aligned for any
            # loader batch size and for a short final batch.
            preds[seen:seen + n] = pr.cpu()
            probs[seen:seen + n] = y.cpu()
            seen += n
            running_acc += acc.item() * n
            running_loss += loss.item() * n
            if i % 100 == 0:
                print('Inference\tEpoch: [{:3d}/{:3d}]\tBatch: [{:3d}/{}]\t Validatoin: Loss: {:.4f}, acc: {:0.2f}%'
                      .format(run + 1, nepochs, i + 1, len(loader), running_loss / seen,
                              (100 * running_acc) / seen))

    return probs.numpy(), running_loss / n_samples, running_acc / n_samples, preds.numpy()


# cnn training
def train(run, loader, model, criterion, optimizer):
    """Run one training epoch over `loader`.

    Args:
        run: zero-based epoch index; only used in progress messages.
        loader: DataLoader yielding (inputs, target) batches.
        model: network returning one logit per class.
        criterion: loss applied to the softmax probabilities (not to the raw logits) —
            presumably what the focal-loss implementation expects; confirm if swapping losses.
        optimizer: optimizer stepping ``model``'s parameters.

    Returns:
        (mean_loss, mean_acc): sample-weighted averages over ``loader.dataset``.
    """
    model.train()

    running_loss = 0.
    running_acc = 0.
    seen = 0  # samples processed so far, for an exact running average on partial batches

    for i, (inputs, target) in enumerate(loader):
        inputs = inputs.to(device)
        # as_tensor avoids the copy-construct warning when the collated target is already a tensor
        target = torch.as_tensor(target).to(device)
        optimizer.zero_grad()
        output = model(inputs)
        y = F.softmax(output, dim=1)

        if torch.isnan(y).any():
            # Report the location and the (small) probability matrix instead of the whole image batch.
            print('NaN in softmax output at epoch {}, batch {}:'.format(run + 1, i + 1))
            print(y)

        loss = criterion(y, target)
        loss.backward()
        optimizer.step()

        n = inputs.size(0)
        seen += n
        running_loss += loss.item() * n
        acc = calculate_accuracy(output, target)
        running_acc += acc.item() * n

        if i % 100 == 0:
            print("Train Epoch : [{:3d}/{:3d}] Batch number: {:3d}, Training: Loss: {:.4f}, acc: {:.2f}%".
                  format(run + 1, nepochs, i + 1, running_loss / seen, (100 * running_acc) / seen))

    return running_loss / len(loader.dataset), running_acc / len(loader.dataset)


def calc_AUCPR(target, probs):
    """Area under the precision-recall curve (average precision), macro- and micro-averaged.

    Args:
        target: length-N sequence of integer class labels in [0, numClasses).
        probs: (N, numClasses) array of predicted class probabilities.

    Returns:
        (macro_ap, micro_ap), each rounded to 2 decimals. macro_ap is the unweighted mean of
        the per-class average precisions; micro_ap pools every (sample, class) decision.
    """
    # One-hot encode the labels so each class can be scored one-vs-rest.
    y_true = label_binarize(target, classes=np.arange(numClasses))
    probs = np.asarray(probs)

    per_class_ap = [average_precision_score(y_true[:, c], probs[:, c]) for c in range(numClasses)]
    # The macro mean must cover only the per-class scores; including the micro score
    # (as averaging every entry of a shared results dict did) skews it.
    macro_ap = np.mean(per_class_ap)
    micro_ap = average_precision_score(y_true, probs, average="micro")

    return round(macro_ap, 2), round(micro_ap, 2)



def calc_metrics(target, prediction, probs):
    """Headline classification metrics for one evaluation pass.

    Args:
        target: true class indices.
        prediction: predicted class indices.
        probs: per-class probabilities, one column per class.

    Returns:
        (f1, roc_auc, precision, recall): F1, precision and recall are macro-averaged over
        classes; ROC-AUC uses scikit-learn's one-vs-one multi-class scheme.
    """
    # NOTE(review): the one-vs-one ROC-AUC is expected to fail if `target` lacks a class that
    # has a probability column — confirm every evaluation set contains all classes.
    roc_auc = roc_auc_score(target, probs, multi_class='ovo')

    f1, precision, recall = (
        scorer(target, prediction, average='macro')
        for scorer in (f1_score, precision_score, recall_score)
    )
    return f1, roc_auc, precision, recall


def calculate_accuracy(output, target):
    """Fraction of rows of `output` whose argmax equals `target`.

    Args:
        output: (B, C) tensor of scores/logits.
        target: B class indices (any shape that views as (B, 1)).

    Returns:
        0-d float tensor with the batch accuracy.
    """
    predicted = torch.max(output, dim=1, keepdim=True).indices
    matches = predicted.eq(target.view_as(predicted))
    return matches.sum().float() / predicted.shape[0]

def calculate_accuracy_f1_perClass(output, target, num_classes=None):
    """Per-class accuracy (recall) and per-class one-vs-rest F1 score.

    Args:
        output: length-N sequence of predicted class indices (ints or integral floats).
        target: length-N sequence of true class indices.
        num_classes: number of classes; defaults to the notebook-level ``numClasses``.

    Returns:
        (accuracies, f1_scores): two lists of length ``num_classes``, values rounded to 2 decimals.
        accuracies[c] is the fraction of class-c samples predicted as c (NaN when class c does
        not occur in ``target``). f1_scores[c] = 2TP / (2TP + FP + FN) over the whole set, so
        false positives from other classes count; it is 0.0 when undefined.
    """
    if num_classes is None:
        num_classes = numClasses
    output = np.asarray(output)
    target = np.asarray(target)

    accuracies = []
    f1_scores = []
    for c in range(num_classes):
        is_true = target == c
        is_pred = output == c
        support = int(is_true.sum())

        class_accuracy = float(is_pred[is_true].mean()) if support else float('nan')

        # F1 must be computed over all samples: restricting to class-c samples (as before)
        # hides false positives and collapses F1 to a function of recall alone.
        tp = int(np.sum(is_true & is_pred))
        fp = int(np.sum(~is_true & is_pred))
        fn = support - tp
        denom = 2 * tp + fp + fn
        class_f1 = 2 * tp / denom if denom else 0.0

        accuracies.append(round(class_accuracy, 2))
        f1_scores.append(round(class_f1, 2))

    return accuracies, f1_scores

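Training and testing code: for each trial, hold out an unseen set of slides, train a ResNet-50 on tiles from the remaining slides, and evaluate the checkpoint with the best validation ROC-AUC on the held-out slides.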

In [None]:
tempout = output  # results root; never reassigned below, so re-running this cell does not nest folders


def _mask_slide_name(mask_file):
    """Slide name of a mask file: the text before the first '_' (same parsing as prep_data)."""
    return os.path.splitext(mask_file)[0].split('_')[0]


# Training tiles get random flips/rotations; validation and test tiles are only converted to tensors.
trans = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(degrees=90),
    transforms.ToTensor()
])
trans_Valid = transforms.Compose([
    transforms.ToTensor()
])

for trial in range(num_trials):

    # Hold out a label-stratified set of whole slides as the unseen test set.
    unseen_test_set = stratified_split().flatten()
    unseen_slides = set(unseen_test_set)

    masksp = os.path.join(dataPath, 'masks')
    masks = os.listdir(masksp)

    # Assign masks by exact slide name; prefix matching would also send e.g. '10_Solid.png'
    # to the test set whenever slide '1' is held out.
    training_masks = [os.path.join(masksp, m) for m in masks if _mask_slide_name(m) not in unseen_slides]
    test_masks = [os.path.join(masksp, m) for m in masks if _mask_slide_name(m) in unseen_slides]

    train_valid_data, train_valid_labels = prep_data(training_masks)
    # NOTE: this split is per tile, so tiles of one training slide can appear in both
    # train and validation; validation scores used for checkpoint selection may be optimistic.
    train_data, train_labels, valid_data, valid_labels = train_valid_split(train_valid_data, train_valid_labels)

    # prepare the unseen test set
    test_data, test_labels = prep_data(test_masks)

    print('\nStarting trial {}'.format(trial + 1))
    print('Unseen Test set is:\n')
    print(unseen_test_set)

    trial_output = os.path.join(tempout, 'trial' + str(trial + 1))
    os.makedirs(trial_output, exist_ok=True)

    writer = SummaryWriter(log_dir=os.path.join(trial_output, 'logs'))
    for sets in range(runs):  # number of runs for each trial
        run_output = os.path.join(trial_output, 'best' + str(sets))
        os.makedirs(run_output, exist_ok=True)
        checkpoint_path = os.path.join(run_output, 'checkpoint_best_AUC.pth')

        best_auc_v = 0.
        best_f1_v = 0.
        saved_best = False  # whether THIS run wrote a checkpoint

        # Two concatenated copies of the training tiles: each tile is drawn twice per epoch,
        # each time with independently sampled random augmentations.
        train_dset = ConcatDataset([TileDataset(train_data, train_labels, trans, version=0),
                                    TileDataset(train_data, train_labels, trans, version=1)])
        train_loader = torch.utils.data.DataLoader(
            train_dset,
            batch_size=batch_size, shuffle=True,
            num_workers=workers, pin_memory=False)

        # validation set
        val_dset = TileDataset(valid_data, valid_labels, trans_Valid, version=0)
        val_loader = torch.utils.data.DataLoader(
            val_dset,
            batch_size=batch_size, shuffle=False,
            num_workers=workers, pin_memory=False)

        # test set
        test_dset = TileDataset(test_data, test_labels, trans_Valid, version=0)
        test_loader = torch.utils.data.DataLoader(
            test_dset,
            batch_size=batch_size, shuffle=False,
            num_workers=workers, pin_memory=False)

        # ImageNet-pretrained ResNet-50 with a new classification head
        model = models.resnet50(pretrained=True)
        model.fc = nn.Linear(model.fc.in_features, numClasses)
        model.to(device)

        criterion = FocalLoss(gamma=focal_loss_gamma).to(device)
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        cudnn.benchmark = True

        epoch = 0
        for epoch in range(nepochs):
            loss, acc = train(epoch, train_loader, model, criterion, optimizer)
            print('Training\tEpoch : [{}]\tLoss: {:0.4f}\tAccuracy: {:3d}'.
                  format(epoch, loss, int(acc * 100)))

            # Log training loss
            writer.add_scalar('Loss/Training', loss, epoch)

            if (epoch + 1) % test_every == 0:
                val_probs, val_loss, val_acc, val_preds = inference(epoch, val_loader, model, criterion)
                print('Valdiation Set: Epoch {:3d}, Loss {:.4f}, Acccuracy {:2d}'.format(epoch + 1, val_loss,
                                                                                         int(val_acc * 100)))

                f1_valid, roc_auc_valid, precision_valid, recall_valid = calc_metrics(val_dset.targets, val_preds, val_probs)
                # calc_metrics returns the macro-averaged F1 (the old label said "f1 micro").
                print('Validation: F1 (macro): {:0.2f}\t'.format(f1_valid))
                print('Validation: ROC-AUC: {:0.2f}\t'.format(roc_auc_valid))

                # Log validation F1 score and accuracy
                writer.add_scalar('Metrics/Validation_F1_Score', f1_valid, epoch)
                writer.add_scalar('Metrics/Validation_Accuracy', val_acc, epoch)

                if roc_auc_valid > best_auc_v:
                    best_f1_v = f1_valid
                    best_auc_v = roc_auc_valid
                    # Store plain Python floats: numpy scalars from sklearn are rejected by
                    # torch.load(weights_only=True), the default since PyTorch 2.6.
                    torch.save({
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'best_ap_v': float(best_f1_v),
                        'best_auc_v': float(best_auc_v),
                        'optimizer': optimizer.state_dict()
                    }, checkpoint_path)
                    saved_best = True

        # Evaluate the best-validation-AUC checkpoint from this run only; a file left over from an
        # earlier execution must not be loaded when no epoch improved (e.g. AUC was always NaN).
        if saved_best:
            ch = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(ch['state_dict'])
        else:
            print('No epoch improved the validation ROC-AUC; evaluating the final-epoch model.')

        # inference
        test_probs, test_loss, test_acc, test_preds = inference(epoch, test_loader, model, criterion)
        print('Test Set: Epoch {:3d}, Loss {:.4f}, Acccuracy {:2d}'.format(epoch + 1, test_loss,
                                                                           int(test_acc * 100)))

        f1_test, roc_auc_test, precision_test, recall_test = calc_metrics(test_dset.targets, test_preds, test_probs)
        print('Test: F1 (macro): {:0.2f}\t Precision: {:0.2f}\t recall: {:0.2f}\t'.format(f1_test, precision_test, recall_test))
        print('Test: ROC-AUC: {:0.2f}\t'.format(roc_auc_test))

        aucPr_macro, aucPR_micro = calc_AUCPR(test_dset.targets, test_probs)
        acc_per_class, f1_per_class = calculate_accuracy_f1_perClass(test_preds, test_dset.targets)

        append_to_csv(trial, epoch, unseen_test_set, round(roc_auc_test, 2), round(f1_test, 2), round(test_acc, 2),
                      round(precision_test, 2), round(recall_test, 2), aucPr_macro, aucPR_micro, acc_per_class, f1_per_class)

        # labels= keeps the matrix numClasses x numClasses (matching the tick labels) even when a
        # class is missing from the held-out slides.
        conf_matrix = confusion_matrix(test_dset.targets, test_preds, labels=list(range(numClasses)))
        class_names = list(labelDict.keys())
        fig, ax = plt.subplots(figsize=(10, 8))
        sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', cbar=False,
                    xticklabels=class_names, yticklabels=class_names, ax=ax)
        ax.set_xlabel('Predicted')
        ax.set_ylabel('Actual')
        ax.set_title('Test results for fold ' + str(trial))
        fig.savefig(os.path.join(
            trial_output, 'Confusion_matrix_trial' + str(trial) + 'E' + str(epoch) + 'run' + str(sets) + '.png'))
        plt.show()
        plt.close(fig)

    # Close once per trial, after all runs have logged to it.
    writer.close()
