In [11]:
import torch
import cv2
import os
import glob
import numpy as np
from celldet_utils import output2file
from torchvision.utils import save_image
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders import get_preprocessing_fn

# Load the nuclei segmentation models (U-Net family) for inference

In [12]:
# Side length (pixels) of the square model input; matches the resize done
# in get_segmentation_nuclei.
SIZE: int = 192
# Number of channels fed to the segmentation networks.
N_INPUT_CHANNEL: int = 3
# NOTE(review): not referenced in the visible code; NUCLEI_CLASSES is what
# configures the model output channels.
N_CLASSES: int = 3

# ImageNet input normalisation shared by the encoders. The 'resnet50' name
# differs from the encoders in nuclei_model_list; presumably the ImageNet
# mean/std are identical across them — TODO confirm.
preprocessing_fn = get_preprocessing_fn('resnet50', pretrained='imagenet')


class NucleiTypes:
    """Integer class labels used in the nuclei segmentation masks."""

    BACKGROUND: int = 0
    NUCLEUS: int = 1
    BORDER: int = 2


# Model output channel index -> nuclei class label (an identity mapping here).
NUCLEI_CLASS_MAP = dict(enumerate((
    NucleiTypes.BACKGROUND,
    NucleiTypes.NUCLEUS,
    NucleiTypes.BORDER,
)))


# Number of distinct class labels; used as the output-channel count of the models.
NUCLEI_CLASSES = len(frozenset(NUCLEI_CLASS_MAP.values()))


# Candidate ensemble members: smp architecture name, encoder backbone, and the
# Lightning checkpoint to load for each.
nuclei_model_list = [
    dict(
        architecture="unet",
        encoder_name="resnet34",
        weight="lightning_logs_astro/version_0/checkpoints/epoch=45-step=736.ckpt",
    ),
    dict(
        architecture="unet++",
        encoder_name="mobilenet_v2",
        weight="lightning_logs_astro/version_1/checkpoints/epoch=64-step=1040.ckpt",
    ),
    dict(
        architecture="manet",
        encoder_name="efficientnet-b0",
        weight="lightning_logs_astro/version_2/checkpoints/epoch=47-step=768.ckpt",
    ),
]


# Run inference on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


class SegmentationCell(pl.LightningModule):
    """
    Lightning wrapper around a segmentation_models_pytorch network.

    Parameters
    ----------
    architecture : str
        One of "unet", "unet++" or "manet".
    encoder_name : str
        smp encoder backbone, e.g. "resnet34", "mobilenet_v2", "efficientnet-b0".
    encoder_weights : str or None
        Pretrained weights used to initialise the encoder ("imagenet" by
        default, as before). Pass None when the weights will be overwritten
        by a checkpoint anyway to skip pretrained initialisation.

    Raises
    ------
    ValueError
        If ``architecture`` is not one of the supported names.
    """

    def __init__(self, architecture="unet", encoder_name="resnet34", encoder_weights="imagenet"):
        super().__init__()

        # TODO: register more smp architectures here if you like.
        builders = {
            "unet": smp.Unet,
            "unet++": smp.UnetPlusPlus,
            "manet": smp.MAnet,
        }
        if architecture not in builders:
            # Fail here instead of leaving self.model unset, which would only
            # surface later as an AttributeError in forward().
            raise ValueError(
                f"Unsupported architecture {architecture!r}; expected one of {sorted(builders)}"
            )

        self.model = builders[architecture](
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=N_INPUT_CHANNEL,  # model input channels (3 for RGB)
            classes=NUCLEI_CLASSES,       # one output channel per nuclei class
        )

    def forward(self, x):
        """Run the wrapped network on ``x`` cast to float32."""
        return self.model(x.float())


# Load the nuclei segmentation ensemble. nuclei_model_list holds several
# configurations; only the first is loaded here to keep trial runs fast —
# widen the slice to use more ensemble members.
nuclei_models = [
    SegmentationCell.load_from_checkpoint(
        cfg["weight"],
        **{key: cfg[key] for key in ("architecture", "encoder_name")},
    ).to(device)
    for cfg in nuclei_model_list[:1]
]


def ensemble_prediction(image_batch, *models, target_device=None):
    """
    Run every model on ``image_batch`` and average their softmax probabilities.

    Parameters
    ----------
    image_batch : torch.Tensor
        Input batch; moved to the run device before inference. Assumed to be
        (batch, channels, height, width) — TODO confirm for other callers.
    *models : torch.nn.Module
        One or more models. Each is switched to eval mode, and this persists
        after the call.
    target_device : torch.device or str, optional
        Device to run on. Defaults to the module-level ``device``.

    Returns
    -------
    numpy.ndarray
        Element-wise mean over models of ``softmax(model(x), dim=1)``, on the
        CPU, with the same shape as a single model's output.

    Raises
    ------
    ValueError
        If no model is given.
    """
    if not models:
        raise ValueError("ensemble_prediction needs at least one model")

    run_device = device if target_device is None else target_device
    batch = image_batch.to(run_device)

    with torch.no_grad():
        probabilities = []
        for model in models:
            model.eval()
            probabilities.append(torch.softmax(model(batch), dim=1).cpu())
        # Stack along a new leading "model" axis and average over it.
        mean_probabilities = torch.stack(probabilities, dim=0).mean(dim=0)

    return mean_probabilities.numpy()


def get_segmentation_nuclei(image):
    """
    Predict the nuclei class map for one image and return it at the input resolution.

    Only the first channel of ``image`` is used. It is replicated to
    N_INPUT_CHANNEL channels, resized to SIZE x SIZE, normalised with
    ``preprocessing_fn``, and run through the ``nuclei_models`` ensemble.

    Parameters
    ----------
    image : array-like
        (H, W, C) image (the caller in this notebook passes RGB from OpenCV).
        A 2-D (H, W) grayscale array is also accepted.

    Returns
    -------
    numpy.ndarray
        (H, W, 1) uint8 array of per-pixel class ids (argmax over the ensemble's
        averaged class probabilities).
    """
    image = np.asarray(image)
    if image.ndim == 2:
        # Grayscale input: add a channel axis so the channel slice below works.
        image = image[..., np.newaxis]

    # We only need one channel for the nuclei; replicate it to the model's input channels.
    image = np.repeat(image[..., 0:1], N_INPUT_CHANNEL, axis=-1)
    original_h, original_w = image.shape[:2]

    image = cv2.resize(image, (SIZE, SIZE))
    image = preprocessing_fn(image)

    # HWC -> NCHW float tensor with a batch of one.
    batch = torch.from_numpy(np.expand_dims(image.transpose(2, 0, 1), 0)).float()

    probabilities = ensemble_prediction(batch, *nuclei_models)
    labels = probabilities[0].argmax(axis=0).astype(np.uint8)

    # Use nearest-neighbour interpolation so labels stay discrete. The default
    # bilinear mode would blend class ids along boundaries (e.g. 0 next to 2 -> 1).
    labels = cv2.resize(labels, (original_w, original_h), interpolation=cv2.INTER_NEAREST)
    return labels[..., np.newaxis]

Lightning automatically upgraded your loaded checkpoint from v1.9.1 to v2.1.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint lightning_logs_astro/version_0/checkpoints/epoch=45-step=736.ckpt`


In [13]:
def save_segmented_nuclei(source_folder, destination_folder, extensions=(".jpg", ".jpeg", ".png")):
    """
    Segment every image in ``source_folder`` and save the results to ``destination_folder``.

    Files directly inside ``source_folder`` are selected when their extension
    is in ``extensions`` (compared case-insensitively). Hidden dotfiles are
    skipped. Files are processed in sorted order with get_segmentation_nuclei()
    and written with output2file() under the same file name in
    ``destination_folder``, which is created if missing.

    Parameters
    ----------
    source_folder : str
        Directory containing the input images (not searched recursively).
    destination_folder : str
        Directory that receives the segmentation outputs.
    extensions : iterable of str
        File extensions, with leading dot, to process.

    Raises
    ------
    FileNotFoundError
        If ``source_folder`` does not exist.
    IOError
        If a selected file cannot be decoded by OpenCV.
    """
    wanted = {ext.lower() for ext in extensions}
    file_names = sorted(
        name
        for name in os.listdir(source_folder)
        # Skip dotfiles (e.g. macOS "._x.png"), as the previous glob did.
        if not name.startswith(".")
        and os.path.splitext(name)[1].lower() in wanted
        and os.path.isfile(os.path.join(source_folder, name))
    )

    # Create directory if it doesn't exist
    os.makedirs(destination_folder, exist_ok=True)
    print(f'Loaded {len(file_names)} files')

    for file_name in file_names:
        source_file_path = os.path.join(source_folder, file_name)
        destination_file_path = os.path.join(destination_folder, file_name)

        bgr_image = cv2.imread(source_file_path)
        if bgr_image is None:
            # cv2.imread returns None instead of raising for unreadable files.
            raise IOError(f"Could not read image: {source_file_path}")
        image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

        output = get_segmentation_nuclei(image)
        # NOTE(review): the meaning of 2. is defined by output2file (celldet_utils);
        # presumably the maximum label value (NucleiTypes.BORDER) — confirm.
        output2file(output, 2., destination_file_path)

In [14]:
save_segmented_nuclei(
    source_folder="cropped",
    destination_folder="output/segmented_astrocytes",
)