# Astrocyte Segmentation Ensemble/Non-Ensemble

In [None]:

# !pip install segmentation-models-pytorch pytorch-lightning

In [None]:
import torch
import albumentations as albu
from PIL import Image
import numpy as np
import os
import cv2
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders import get_preprocessing_fn
import pytorch_lightning as pl
import matplotlib.pyplot as plt

np.random.seed(0)

In [None]:
# Ensemble members: one (decoder architecture, encoder backbone) pair per model.
# The position in this list is the index used by the "Train model N" cells.
model_list = [
    {"architecture": architecture, "encoder_name": encoder}
    for architecture, encoder in (
        ("unet++", "vgg16"),
        ("manet", "resnet152"),
        ("unet++", "resnet152"),
    )
]


# Please set these to the folders containing the input images and the label masks.
# An image and its mask are matched by file name (without extension).

IMAGE_FOLDER_PATH = "/data2/SrikanthData/ResearchWork2/Latest_build/20221219-unet/Final_Models/Img_data/unetRGBData/singles"
MASK_FOLDER_PATH = "/data2/SrikanthData/ResearchWork2/Latest_build/20221219-unet/Final_Models/Img_data/unetRGBData/masks_3"

# Side length (in pixels) every image and mask is resized to before entering the model.
SIZE = 192

class Types:
    """Integer class ids used in the training masks."""

    BACKGROUND, BRANCH, BODY = range(3)

# Maps the raw label value stored in a single-channel mask file to the class id
# used for training (consumed by SegmentationDataset.load_mask).
CLASS_MAP = {
    0: Types.BACKGROUND,
    1: Types.BRANCH,
    2: Types.BODY,

    # You can add more entries here, e.g. to treat the body as background set 2: Types.BACKGROUND.
    # Several raw branch labels can share one class: 3: Types.BRANCH, 4: Types.BRANCH, etc.
}

# Number of distinct target classes; this is the number of output channels of each model.
N_CLASSES = len(set(CLASS_MAP.values()))
# RGB colour used to draw each class when a mask is visualised (see get_colored_mask).
COLOR_MAP = {
    Types.BACKGROUND: (0, 0, 0),
    Types.BRANCH: (255, 255, 255),
    Types.BODY: (0, 255, 0),
}


# Define augmentation function for training and validation

In [None]:
def get_training_augmentation():
    """Build the random augmentation pipeline applied to training samples.

    Every sample is first resized to SIZE x SIZE (SIZE must be divisible by 32,
    as required by the models), then randomly flipped, shifted/scaled, and
    perturbed with noise, perspective, brightness/gamma, blur/sharpen and
    contrast/hue changes.

    Returns:
        albu.Compose: the composed transform, called as
        ``transform(image=..., mask=...)``.

    NOTE(review): the ``IAA*`` transforms are imgaug-backed and, together with
    ``RandomBrightness`` / ``RandomContrast``, appear to be deprecated or
    removed in newer albumentations releases — confirm against the installed
    version before upgrading.
    """
    train_transform = [
        albu.Resize(SIZE, SIZE),

        albu.HorizontalFlip(p=0.5),

        # rotate_limit=0: shift and scale only, no rotation; exposed borders are filled with 0
        albu.ShiftScaleRotate(
            scale_limit=0.2, rotate_limit=0, shift_limit=0.1, p=1, border_mode=0,
            value=0, mask_value=0
        ),
        # Pad + crop keep the output at exactly SIZE x SIZE
        albu.PadIfNeeded(min_height=SIZE, min_width=SIZE, always_apply=True, border_mode=0,
            value=0, mask_value=0
        ),
        albu.RandomCrop(height=SIZE, width=SIZE, always_apply=True),

        albu.IAAAdditiveGaussianNoise(p=0.1),
        albu.IAAPerspective(p=0.3),

        # With probability 0.9 apply exactly one intensity transform
        albu.OneOf(
            [
                albu.CLAHE(p=1),
                albu.RandomBrightness(p=1),
                albu.RandomGamma(p=1),
            ],
            p=0.9,
        ),

        # With probability 0.9 apply exactly one sharpen/blur transform
        albu.OneOf(
            [
                albu.IAASharpen(p=1),
                albu.Blur(blur_limit=3, p=1),
                albu.MotionBlur(blur_limit=3, p=1),
            ],
            p=0.9,
        ),

        # With probability 0.9 apply exactly one contrast/colour transform
        albu.OneOf(
            [
                albu.RandomContrast(p=1),
                albu.HueSaturationValue(p=1),
            ],
            p=0.9,
        ),
    ]
    return albu.Compose(train_transform)


def get_validation_augmentation():
    """Build the deterministic transform applied to validation samples.

    No random augmentation is performed: the image and mask are only resized
    to SIZE x SIZE, the same input size the training pipeline produces.

    Returns:
        albu.Compose: the composed transform, called as
        ``transform(image=..., mask=...)``.
    """
    return albu.Compose([albu.Resize(SIZE, SIZE)])


def to_tensor(x, **kwargs):
    """Move the last axis of a 3-D array to the front (H, W, C -> C, H, W).

    Extra keyword arguments are accepted and ignored so the function can be
    used as an ``albu.Lambda`` callback.
    """
    channels_first = x.transpose(2, 0, 1)
    return channels_first


def get_preprocessing(preprocessing_fn):
    """Build the encoder-specific preprocessing transform.

    Args:
        preprocessing_fn (callable): data normalization function applied to
            the image (can be specific to each pretrained encoder).

    Returns:
        albu.Compose: normalizes the image with ``preprocessing_fn`` and then
        moves the channel axis of both image and mask to the front.
    """
    normalize = albu.Lambda(image=preprocessing_fn)
    channels_first = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([normalize, channels_first])

# Define the Dataset and Dataloader

In [None]:
class SegmentationDataset(torch.utils.data.Dataset):
    """Dataset yielding ``{"name", "image", "mask"}`` samples for segmentation.

    Images are read from ``IMAGE_FOLDER_PATH`` and masks from
    ``MASK_FOLDER_PATH``; both are looked up by base name (no extension).

    Args:
        name_list: base names of the samples in this dataset.
        augmentation: optional callable invoked as
            ``augmentation(image=..., mask=...)`` and returning a mapping with
            ``"image"`` and ``"mask"`` entries.
        preprocessing: optional callable with the same calling convention,
            applied after ``augmentation``.
    """

    def __init__(self, name_list, augmentation=None, preprocessing=None):
        self.name_list = name_list
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __len__(self):
        """Return the number of samples (required by torch.utils.data.Dataset)."""
        return len(self.name_list)

    @staticmethod
    def _read_array(path):
        """Read an image file into a numpy array and close the file handle."""
        with Image.open(path) as handle:
            return np.array(handle)

    @staticmethod
    def _remap_labels(mask, class_map):
        """Return a copy of ``mask`` with each key of ``class_map`` replaced by its value.

        Matching is done against the untouched input, so a pixel is remapped
        at most once even when the map contains chains or swaps
        (e.g. ``{1: 2, 2: 1}``). Values absent from ``class_map`` are kept.
        """
        remapped = mask.copy()
        for source, target in class_map.items():
            remapped[mask == source] = target
        return remapped

    def load_image_rgb(self, name):
        """Load ``<name>.jpg`` and replicate it into three identical channels."""
        image = self._read_array(os.path.join(IMAGE_FOLDER_PATH, f"{name}.jpg"))
        image = np.expand_dims(image, axis=-1)
        return np.repeat(image, repeats=3, axis=-1)

    def load_image_rgb_v2(self, name):
        """Load ``<name>.tif`` from the image folder as stored in the file."""
        return self._read_array(os.path.join(IMAGE_FOLDER_PATH, f"{name}.tif"))

    def load_mask(self, name):
        """Load a label-valued ``<name>.tif`` mask and remap it through CLASS_MAP.

        Returns:
            int32 array with a trailing singleton channel axis.
        """
        # the file is expected to hold a 1-channel label image
        mask = self._read_array(os.path.join(MASK_FOLDER_PATH, f"{name}.tif"))
        mask = self._remap_labels(mask, CLASS_MAP)
        return np.expand_dims(mask, axis=-1).astype(np.int32)

    def load_mask_v2(self, name):
        """Load a colour-coded ``<name>.tif`` mask and convert it to class ids.

        Pure white pixels become ``Types.BRANCH``, pure magenta (255, 0, 255)
        pixels become ``Types.BODY``, everything else is background (0).

        Returns:
            int32 array with a trailing singleton channel axis.
        """
        # the file is expected to hold a 3-channel colour image
        mask = self._read_array(os.path.join(MASK_FOLDER_PATH, f"{name}.tif"))
        # binarize every channel: anything below 255 is set to 0
        mask[mask < 255] = 0

        output_mask = np.zeros(mask.shape[:2], dtype=np.int32)
        # cell body
        output_mask[np.all(mask == (255, 0, 255), axis=-1)] = Types.BODY
        # cell branch
        output_mask[np.all(mask == (255, 255, 255), axis=-1)] = Types.BRANCH

        return np.expand_dims(output_mask, axis=-1)

    def __getitem__(self, idx):
        """Return sample ``idx`` (required by torch.utils.data.Dataset)."""
        name = self.name_list[idx]

        # load the image and the mask
        image = self.load_image_rgb_v2(name)
        mask = self.load_mask_v2(name)

        # transform the image and mask
        if self.augmentation is not None:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample["image"], sample["mask"]

        if self.preprocessing is not None:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample["image"], sample["mask"]

        return {"name": name, "image": image, "mask": mask}


def _list_base_names(folder):
    """Return the extension-less names of the entries in ``folder``, skipping empty stems."""
    stems = (entry.rsplit(".", 1)[0] for entry in os.listdir(folder))
    return [stem for stem in stems if stem]


def create_data_loaders(preprocessing_fn=None, n_val=35, batch_size=4):
    """Create the train / val / test datasets and their data loaders.

    Images and masks are assumed to share the same base name and differ only
    in extension. Samples are shuffled with a fixed seed; 80% go to training,
    the first ``n_val`` of the remainder to validation and the rest to test.

    The test dataset uses the same resize and preprocessing as validation:
    the later cells feed test batches straight into the models and undo the
    normalization for display, so they need fixed-size, normalized,
    channel-first inputs (and the loader needs equal-sized samples to batch).

    Args:
        preprocessing_fn: encoder-specific normalization function, or ``None``
            to skip preprocessing (useful for visualizing raw samples).
        n_val: number of held-out samples used for validation.
        batch_size: batch size of every data loader.

    Returns:
        tuple: ``(datasets, data_loaders)``, two dicts keyed by ``"train"``,
        ``"val"`` and ``"test"``.
    """
    # image and mask names without extension
    image_names = _list_base_names(IMAGE_FOLDER_PATH)
    mask_names = _list_base_names(MASK_FOLDER_PATH)

    assert len(image_names) == len(mask_names), (
        f"found {len(image_names)} images but {len(mask_names)} masks"
    )
    # images that have a mask
    image_with_mask_names = sorted(set(image_names) & set(mask_names))

    # split 80% for train, 20% held out (fixed seed for a reproducible split)
    np.random.seed(0)
    order = np.arange(len(image_with_mask_names))
    np.random.shuffle(order)
    n_train = int(0.8 * len(order))
    train_images = [image_with_mask_names[i] for i in order[:n_train]]
    held_out_images = [image_with_mask_names[i] for i in order[n_train:]]

    def make_preprocessing():
        # a fresh transform per dataset, or None when no normalization is requested
        if preprocessing_fn is None:
            return None
        return get_preprocessing(preprocessing_fn)

    datasets = {
        "train": SegmentationDataset(
            name_list=train_images,
            augmentation=get_training_augmentation(),
            preprocessing=make_preprocessing(),
        ),
        "val": SegmentationDataset(
            name_list=held_out_images[:n_val],
            augmentation=get_validation_augmentation(),
            preprocessing=make_preprocessing(),
        ),
        "test": SegmentationDataset(
            name_list=held_out_images[n_val:],
            augmentation=get_validation_augmentation(),
            preprocessing=make_preprocessing(),
        ),
    }

    data_loaders = {
        split: torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=(split == "train")
        )
        for split, dataset in datasets.items()
    }

    return datasets, data_loaders

# Visualize example of train data and validation data
## Some training examples

In [None]:
# No preprocessing function here: the samples stay un-normalized so they can be displayed directly.
datasets, dataloaders = create_data_loaders(None)

In [None]:
def inverse_process(processed_image):
    """Undo mean/std normalization and return a displayable uint8 image.

    Args:
        processed_image: normalized channel-first array of shape (3, H, W).

    Returns:
        uint8 array of shape (H, W, 3) with values in [0, 255].

    NOTE(review): the mean/std below are the usual ImageNet statistics and are
    assumed to match the preprocessing function of every encoder in use —
    confirm if a different encoder is added.
    """
    std = np.array([[[0.229, 0.224, 0.225]]])
    mean = np.array([[[0.485, 0.456, 0.406]]])
    restored = (processed_image.transpose(1, 2, 0) * std + mean) * 255
    # Round instead of truncating (floating point error would otherwise turn
    # e.g. 127.9999 into 127) and clip so out-of-range values saturate rather
    # than wrap around in the uint8 cast.
    return np.clip(np.rint(restored), 0, 255).astype(np.uint8)

def get_colored_mask(processed_mask):
    """Turn a class-id mask into an RGB image using COLOR_MAP.

    Args:
        processed_mask: array of shape (H, W, C); only channel 0 is read and
            it holds the class id of every pixel.

    Returns:
        uint8 array of shape (H, W, 3).
    """
    height, width, _ = processed_mask.shape
    labels = processed_mask[:, :, 0]

    canvas = np.zeros((height, width, 3))
    for label in (Types.BACKGROUND, Types.BRANCH, Types.BODY):
        canvas[labels == label] = COLOR_MAP[label]
    return canvas.astype(np.uint8)

In [None]:
# Show the first three (randomly augmented) training samples next to their labels.
plt.figure(figsize=(8, 12))
for row in range(3):
    one_example = datasets["train"][row]
    panels = (
        ("training image", one_example['image']),
        ("training label", get_colored_mask(one_example["mask"])),
    )
    for column, (title, picture) in enumerate(panels, start=1):
        plt.subplot(3, 2, 2*row + column)
        plt.title(title)
        plt.imshow(picture)
        plt.axis("off")

plt.savefig("training_examples.jpg")

## Some validation examples (without augmentation, only resizing)

In [None]:
# Show the first three validation samples (resized only) next to their labels.
plt.figure(figsize=(8, 12))
for row in range(3):
    one_example = datasets["val"][row]
    panels = (
        ("validation image", one_example['image']),
        ("validation label", get_colored_mask(one_example["mask"])),
    )
    for column, (title, picture) in enumerate(panels, start=1):
        plt.subplot(3, 2, 2*row + column)
        plt.title(title)
        plt.imshow(picture)
        plt.axis("off")

plt.savefig("validation_examples.jpg")

# Create the model

In [None]:
import copy
class SegmentationCell(pl.LightningModule):
    """Lightning module wrapping one segmentation_models_pytorch network.

    The network is trained with an equally weighted sum of the multiclass
    Jaccard and Dice losses computed on the raw logits.

    Args:
        architecture: decoder architecture, one of ``"unet"``, ``"unet++"``
            or ``"manet"``.
        encoder_name: name of the encoder backbone understood by
            segmentation_models_pytorch (e.g. ``"vgg16"``, ``"resnet152"``).

    Raises:
        ValueError: if ``architecture`` is not one of the supported names.
    """

    # TODO: add more architectures if you like
    # architecture argument -> name of the model class inside segmentation_models_pytorch
    _ARCHITECTURES = {
        "unet": "Unet",
        "unet++": "UnetPlusPlus",
        "manet": "MAnet",
    }

    def __init__(self, architecture="unet", encoder_name=""):
        super().__init__()

        if architecture not in self._ARCHITECTURES:
            # fail here instead of with an AttributeError on the first forward pass
            raise ValueError(
                f"Unknown architecture {architecture!r}; "
                f"expected one of {sorted(self._ARCHITECTURES)}"
            )

        model_class = getattr(smp, self._ARCHITECTURES[architecture])
        self.model = model_class(
            encoder_name=encoder_name,      # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
            encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
            in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
            classes=N_CLASSES,              # model output channels (number of classes in your dataset)
        )

        # (name, weight, loss callable) triples summed into the total loss
        self.losses = [
            ("jaccard", 0.5, smp.losses.JaccardLoss(mode="multiclass", from_logits=True)),
            ("dice", 0.5, smp.losses.DiceLoss(mode="multiclass", from_logits=True))
        ]

        self.collection = []

    def forward(self, x):
        """Return the class logits for the image batch ``x`` (cast to float)."""
        return self.model(x.float())

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (learning rate 1e-4)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        return optimizer

    def _shared_step(self, batch, stage):
        """Compute and log the weighted loss for one batch.

        Args:
            batch: mapping with ``"image"`` and ``"mask"`` tensors.
            stage: prefix for the logged names, ``"train"`` or ``"val"``.

        Returns:
            dict with the total ``"loss"`` and a ``"log"`` dict holding every
            individual loss term plus the total.
        """
        logits = self.forward(batch["image"])
        masks = batch["mask"].long()

        total_loss = 0
        logs = {}
        for loss_name, weight, loss in self.losses:
            ls_mask = loss(logits, masks)
            total_loss += weight * ls_mask
            logs[f"{stage}_mask_{loss_name}"] = ls_mask

        logs[f"{stage}_loss"] = total_loss
        self.log(f"{stage}_loss", total_loss, on_epoch=True, prog_bar=True, logger=True)

        return {"loss": total_loss, "log": logs}

    def training_step(self, train_batch, batch_idx):
        """Run one training batch; logs ``train_loss``."""
        return self._shared_step(train_batch, "train")

    def validation_step(self, batch, idx):
        """Run one validation batch; logs ``val_loss``."""
        return self._shared_step(batch, "val")


class MetricsCallback(pl.callbacks.Callback):
    """Callback that keeps a snapshot of the trainer metrics per validation epoch.

    The snapshots are appended to ``self.metrics`` in the order the
    validation epochs finish.
    """

    def __init__(self):
        super().__init__()
        self.metrics = []

    def on_validation_epoch_end(self, trainer, pl_module):
        # deep copy so later epochs cannot modify the stored snapshot
        self.metrics.append(copy.deepcopy(trainer.callback_metrics))

# Train model 1

In [None]:
# Select the first ensemble member and rebuild the data loaders with the
# normalization function that belongs to its encoder.
architecture_name = model_list[0]["architecture"]
encoder_name = model_list[0]["encoder_name"]
preprocessing_fn = get_preprocessing_fn(encoder_name, pretrained='imagenet')
datasets, dataloaders = create_data_loaders(preprocessing_fn=preprocessing_fn)

model_0 = SegmentationCell(architecture_name, encoder_name)

# Stop after 75 epochs without val_loss improvement; checkpoint on the lowest val_loss.
early_stopping_0 = pl.callbacks.early_stopping.EarlyStopping(monitor="val_loss", patience=75, mode="min")
checkpoint_0 = pl.callbacks.model_checkpoint.ModelCheckpoint(monitor="val_loss", mode="min")
metrics_0 = MetricsCallback()

trainer = pl.Trainer(
    devices=1,
    accelerator='auto',
    callbacks=[early_stopping_0, checkpoint_0, metrics_0],
    limit_train_batches=1.0, 
    max_epochs=300,
    log_every_n_steps=1)
trainer.fit(model_0, dataloaders["train"], dataloaders["val"])

# load the best weights for the model
model_0 = SegmentationCell.load_from_checkpoint(checkpoint_0.best_model_path, architecture=architecture_name, encoder_name=encoder_name)

In [None]:
def plot_metrics(callback_metrics, name):
    """Plot the training and validation loss curves recorded by a callback.

    Args:
        callback_metrics: object whose ``metrics`` attribute is a sequence of
            per-epoch metric mappings; entries lacking either ``train_loss``
            or ``val_loss`` are skipped.
        name: title of the figure.
    """
    train_loss = []
    val_loss = []
    for epoch_metrics in callback_metrics.metrics:
        if 'train_loss' not in epoch_metrics or 'val_loss' not in epoch_metrics:
            continue
        train_loss.append(epoch_metrics['train_loss'].cpu())
        val_loss.append(epoch_metrics['val_loss'].cpu())

    epochs = np.arange(len(train_loss))

    plt.title(name)
    for curve, label in ((train_loss, 'train_loss'), (val_loss, 'val_loss')):
        plt.plot(epochs, curve, label=label)

    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid()

    plt.show()

plot_metrics(metrics_0, f"{architecture_name} - {encoder_name}")

# Train model 2

In [None]:
# Select the second ensemble member and rebuild the data loaders with the
# normalization function that belongs to its encoder.
architecture_name = model_list[1]["architecture"]
encoder_name = model_list[1]["encoder_name"]
preprocessing_fn = get_preprocessing_fn(encoder_name, pretrained='imagenet')
datasets, dataloaders = create_data_loaders(preprocessing_fn=preprocessing_fn)

model_1 = SegmentationCell(architecture_name, encoder_name)

# Stop after 75 epochs without val_loss improvement; checkpoint on the lowest val_loss.
early_stopping_1 = pl.callbacks.early_stopping.EarlyStopping(monitor="val_loss", patience=75, mode="min")
checkpoint_1 = pl.callbacks.model_checkpoint.ModelCheckpoint(monitor="val_loss", mode="min")
metrics_1 = MetricsCallback()

trainer = pl.Trainer(
        devices=1,
    accelerator='auto',
    callbacks=[early_stopping_1, checkpoint_1, metrics_1],
    limit_train_batches=1.0, 
    max_epochs=300,
    log_every_n_steps=1)
trainer.fit(model_1, dataloaders["train"], dataloaders["val"])

# load the best weights for the model
model_1 = SegmentationCell.load_from_checkpoint(checkpoint_1.best_model_path, architecture=architecture_name, encoder_name=encoder_name)

In [None]:
# Loss curves recorded by metrics_1 while training the second model.
plot_metrics(metrics_1, f"{architecture_name} - {encoder_name}")

# Train model 3

In [None]:
# Select the third ensemble member and rebuild the data loaders with the
# normalization function that belongs to its encoder.
architecture_name = model_list[2]["architecture"]
encoder_name = model_list[2]["encoder_name"]
preprocessing_fn = get_preprocessing_fn(encoder_name, pretrained='imagenet')
datasets, dataloaders = create_data_loaders(preprocessing_fn=preprocessing_fn)

model_2 = SegmentationCell(architecture_name, encoder_name)

# Stop after 75 epochs without val_loss improvement; checkpoint on the lowest val_loss.
early_stopping_2 = pl.callbacks.early_stopping.EarlyStopping(monitor="val_loss", patience=75, mode="min")
checkpoint_2 = pl.callbacks.model_checkpoint.ModelCheckpoint(monitor="val_loss", mode="min")
metrics_2 = MetricsCallback()

trainer = pl.Trainer(
    devices=1,
    accelerator='auto',
    callbacks=[early_stopping_2, checkpoint_2, metrics_2],
    limit_train_batches=1.0, 
    max_epochs=300,
    log_every_n_steps=1)
trainer.fit(model_2, dataloaders["train"], dataloaders["val"])

# load the best weights for the model
model_2 = SegmentationCell.load_from_checkpoint(checkpoint_2.best_model_path, architecture=architecture_name, encoder_name=encoder_name)

In [None]:
# Loss curves recorded by metrics_2 while training the third model.
plot_metrics(metrics_2, f"{architecture_name} - {encoder_name}")

# Create ensemble model

In [None]:
def ensemble_prediction(image_batch, *models):
    """Average the per-class probabilities predicted by several models.

    Every model is switched to eval mode, run without gradient tracking on
    ``image_batch`` and its logits are turned into probabilities with a
    softmax over axis 1; the probabilities are then averaged across models.

    Args:
        image_batch: input tensor accepted by every model.
        *models: one or more ``torch.nn.Module`` instances returning logits
            with the class axis at position 1.

    Returns:
        numpy.ndarray: mean class probabilities with the same shape as a
        single model output.

    Raises:
        ValueError: if no model is given.
    """
    if not models:
        raise ValueError("ensemble_prediction needs at least one model")

    probabilities = []
    with torch.no_grad():
        for model in models:
            model.eval()
            # Run the batch on whatever device the model lives on, so a model
            # sitting on the GPU can be fed a batch coming from a CPU loader.
            first_parameter = next(model.parameters(), None)
            inputs = image_batch
            if first_parameter is not None:
                inputs = image_batch.to(first_parameter.device)
            probabilities.append(
                torch.nn.functional.softmax(model(inputs), dim=1).cpu()
            )

    return torch.stack(probabilities, dim=0).mean(dim=0).numpy()

# Visualize the results

In [None]:
# Take the first test batch and predict it with the three-model ensemble.
# NOTE(review): this cell treats batch["image"] as normalized, channel-first
# data (it is fed to the models and passed to inverse_process) and
# batch["mask"] as channel-first — confirm the "test" dataset is built with
# the matching resize/preprocessing.
batch = next(iter(dataloaders["test"]))
logits = ensemble_prediction(batch["image"], model_0, model_1, model_2)

# One row per sample: input image | ground truth | ensemble prediction
plt.figure(figsize=(16, 24))
for i, (name, image, gt_mask, pr_mask) in enumerate(zip(batch["name"], batch["image"], batch["mask"], logits)):
    
    # undo the normalization so the image can be displayed
    image = inverse_process(image.numpy())
    plt.subplot(len(batch["name"]), 3, i*3 + 1)
    plt.imshow(image)
    plt.title(f"Image {name}")
    plt.axis("off")

    plt.subplot(len(batch["name"]), 3, i*3 + 2)
    # get_colored_mask reads channel 0 of a channel-last array
    gt_mask = gt_mask.permute(1, 2, 0).numpy()
    gt_mask = get_colored_mask(gt_mask) 
    plt.imshow(gt_mask)
    plt.title(f"Ground truth")
    plt.axis("off")

    plt.subplot(len(batch["name"]), 3, i*3 + 3)
    # most probable class per pixel, with a trailing channel axis for get_colored_mask
    pr_mask = np.transpose(pr_mask, (1, 2, 0)).argmax(axis=-1)
    pr_mask = np.expand_dims(pr_mask, axis=-1)
    pr_mask = get_colored_mask(pr_mask)
    plt.imshow(pr_mask) 
    plt.title(f"Prediction")
    plt.axis("off")
    # show at most five samples
    if i >= 4:
        break


# Calculate Dice score

In [None]:
from segmentation_models_pytorch.losses.dice import DiceLoss
from segmentation_models_pytorch.losses.jaccard import JaccardLoss


# Loss objects reused as metrics (score = 1 - loss). from_logits=False: the
# predictions handed to them are used as-is, without an extra activation.
dice_loss = DiceLoss(mode='multiclass', from_logits=False)

jaccard_loss = JaccardLoss(mode='multiclass', from_logits=False)

def calculate_dice():
    """Compute the mean Dice and Jaccard scores of the ensemble on the test set.

    Every test image is predicted with the three-model ensemble, the most
    probable class per pixel is one-hot encoded and scored against the ground
    truth; each score is ``1 - loss``. Both means are printed.

    Returns:
        The mean Dice score over all test images (``nan`` when the test set
        is empty).
    """
    dice_scores = []
    jaccard_scores = []

    for batch in dataloaders["test"]:
        probabilities = ensemble_prediction(batch["image"], model_0, model_1, model_2)
        # class axis of the ensemble output (softmax is taken over axis 1)
        n_classes = probabilities.shape[1]

        # loop through all images of the batch
        for y_true, y_prob in zip(batch["mask"], probabilities):
            # most probable class per pixel, with a leading batch axis
            y_pred = np.expand_dims(np.argmax(y_prob, axis=0), axis=0)

            # 1-hot, then change to (n, c, h, w)
            y_pred = np.eye(n_classes)[y_pred]
            y_pred = torch.from_numpy(np.transpose(y_pred, (0, 3, 1, 2)))

            y_true = y_true.long()

            # calculate dice and jaccard
            dice_scores.append(1 - dice_loss(y_pred, y_true).numpy())
            jaccard_scores.append(1 - jaccard_loss(y_pred, y_true).numpy())

    if not dice_scores:
        # avoid numpy's "mean of empty slice" warning and say what happened
        print("\nThe test set is empty; no scores were computed.")
        return np.nan

    print("\nDice Score: ", np.mean(dice_scores))
    print("\nJaccard Score: ", np.mean(jaccard_scores))
    return np.mean(dice_scores)

dice = calculate_dice()