Dataset: [MoNuSeg](https://monuseg.grand-challenge.org/Data/)

Model Architecture: [UNET](https://pypi.org/project/segmentation-models-pytorch/#architectures)

In [None]:
%matplotlib inline

import os

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
from tqdm import tqdm as loadingbar
import segmentation_models_pytorch as smp

In [None]:
from Datasets.MoNuSeg.FileReader import files, get_image, get_mask
from Datasets.MoNuSeg.FileViewer import show_image, show_mask

## Data Preparation

In [None]:
# Dataset and Dataloader

class MoNuSegDataset(Dataset):
    """Map-style dataset yielding ``(image, mask)`` pairs for MoNuSeg files.

    The file list is whatever the module-level name ``files`` is bound to at
    the moment the dataset is constructed.
    """

    def __init__(self) -> None:
        super().__init__()
        self.files = files

    def __len__(self) -> int:
        """Return the number of files in the dataset."""
        return len(self.files)

    def __getitem__(self, index) -> tuple:
        """Load the image and the mask belonging to the file at ``index``."""
        selected = self.files[index]
        image = get_image(selected)
        mask = get_mask(selected)
        return image, mask

class MoNuSegDataLoader(DataLoader):
    """Shuffling ``DataLoader`` over ``MoNuSegDataset``.

    Incomplete trailing batches are dropped (``drop_last=True``).

    Args:
        batch_size: Number of samples per batch. Defaults to 5.
    """

    def __init__(self, batch_size=5):
        # Forward the caller's batch_size. It used to be hard-coded to 5 here,
        # so MoNuSegDataLoader(n) silently ignored n.
        super().__init__(
            MoNuSegDataset(),
            shuffle=True,
            batch_size=batch_size,
            drop_last=True,
        )
        # Re-seeds torch's global RNG every time a loader is created; the
        # random crop offsets in this file are drawn from that RNG.
        torch.manual_seed(0)


dataloader = MoNuSegDataLoader()

# crops an entire batch for faster processing
image_size = 1000
crop_size = 512
ignored_edge_size = 4


def crop_batch(images, masks):
    """Cut the same random ``crop_size`` x ``crop_size`` window from both inputs.

    One (x, y) offset is drawn from torch's global RNG and applied to the last
    two dimensions of ``images`` and ``masks``, so every element of a batch
    gets the same window. Offsets lie in
    ``[ignored_edge_size, image_size - crop_size - ignored_edge_size)``.

    Args:
        images: Tensor whose last two dimensions are (height, width).
        masks: Tensor with the same number of dimensions as ``images``.

    Returns:
        Tuple ``(cropped_images, cropped_masks)``.

    Raises:
        ValueError: If the inputs differ in rank or have fewer than two
            dimensions.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if images.ndim != masks.ndim:
        raise ValueError(
            f"images and masks must have the same rank, "
            f"got {images.ndim} and {masks.ndim}"
        )
    if images.ndim < 2:
        raise ValueError("images and masks need at least two dimensions to crop")

    x, y = torch.randint(
        ignored_edge_size,
        image_size - crop_size - ignored_edge_size,
        size=(1, 2),
    ).squeeze().tolist()

    # Ellipsis indexing crops the trailing (height, width) axes for any rank,
    # so unsupported ranks are no longer passed through uncropped.
    window = (..., slice(y, y + crop_size), slice(x, x + crop_size))
    return images[window], masks[window]

rotations = [0, 90, 180, 270]


def rotate_batch(images, masks):
    """Rotate ``images`` and ``masks`` by one shared, randomly chosen angle.

    The angle is picked from ``rotations`` with NumPy's global RNG, so it is
    not affected by ``torch.manual_seed``.

    Returns:
        Tuple ``(rotated_images, rotated_masks)``.
    """
    angle = int(np.random.choice(rotations))
    rotate = transforms.functional.rotate
    rotated_images = rotate(images, angle)
    rotated_masks = rotate(masks, angle)
    return rotated_images, rotated_masks

def preprocess_batch(images, masks):
    """Apply the random crop and then the random rotation to a batch.

    Returns:
        Tuple ``(images, masks)`` after ``crop_batch`` and ``rotate_batch``.
    """
    cropped_images, cropped_masks = crop_batch(images, masks)
    return rotate_batch(cropped_images, cropped_masks)

In [None]:
# example use
# Take one batch from the shuffling dataloader, crop and rotate it, and
# display the first image together with its mask.
for batch_num, (images, masks) in enumerate(dataloader):
    images, masks = crop_batch(images, masks)
    images, masks = rotate_batch(images, masks)
    show_image(images[0])
    show_mask(masks[0])
    # a single batch is enough for the preview
    break

In [None]:
# utils for batch and individual image shaping
def ensure_batch(t):
    """Return ``t`` with a leading batch axis if it is a single rank-3 image.

    Tensors of any other rank are returned unchanged.
    """
    return t.unsqueeze(0) if t.ndim == 3 else t
def ensure_individual_image(t, batch_to_single_index = 0):
    """Reduce a batch to one image; pass rank-3 input through unchanged.

    Args:
        t: Tensor; rank 3 is treated as already being a single image.
        batch_to_single_index: Which element to pick when the leading
            dimension holds more than one item. Ignored otherwise.

    Returns:
        ``t`` itself for rank-3 input, ``t[batch_to_single_index]`` when the
        leading dimension is larger than 1, else ``t.squeeze(0)``.
    """
    if t.ndim == 3:
        return t
    if t.shape[0] > 1:
        return t[batch_to_single_index]
    return t.squeeze(0)

## Model Training and Testing

In [None]:
def save_model(model, model_name):
    """Write ``model`` to ``pretrained-models/<model_name>.model`` via ``torch.save``.

    The whole object is serialized, not just its ``state_dict``. The
    ``pretrained-models`` directory is not created here.
    """
    destination = f"pretrained-models/{model_name}.model"
    torch.save(model, destination)

def load_model(model_name):
    """Load the object stored at ``pretrained-models/<model_name>.model``.

    Args:
        model_name: File stem inside the ``pretrained-models`` directory.

    Returns:
        Whatever ``torch.load`` deserializes from that file.

    Raises:
        FileNotFoundError: If the file does not exist.

    NOTE(review): ``torch.load`` unpickles data, so only load files you trust.
    Recent PyTorch releases may also refuse fully pickled modules unless
    ``weights_only=False`` is passed — confirm against the installed version.
    """
    # torch.load needs a binary stream; the previous text-mode handle ('r')
    # cannot be deserialized.
    with open(f"pretrained-models/{model_name}.model", "rb") as f:
        return torch.load(f)

def train_model(model, epochs=3, batch_size=5, loss=None, optimizer=None):
    """Train ``model`` on freshly loaded MoNuSeg batches and return it.

    Args:
        model: The ``torch.nn.Module`` to optimise in place.
        epochs: Number of passes over the dataloader.
        batch_size: Batch size handed to ``MoNuSegDataLoader``.
        loss: Loss callable taking ``(prediction, target)``; defaults to
            ``smp.utils.losses.DiceLoss()``.
        optimizer: Optimizer over ``model``'s parameters; defaults to Adam
            with ``lr=0.0001``.

    Returns:
        The same ``model`` object, after training.
    """
    # process args
    if loss is None:
        loss = smp.utils.losses.DiceLoss()
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    # initialize dataloader
    dataloader = MoNuSegDataLoader(batch_size)
    # Make sure layers such as BatchNorm/Dropout are in training mode even if
    # the caller previously switched the model to eval mode.
    model.train()
    # train model
    for _ in loadingbar(range(epochs)):
        for images, masks in dataloader:
            images, masks = preprocess_batch(images, masks)
            optimizer.zero_grad()
            # Call the module rather than .forward() so registered hooks run.
            yhat_batch = model(images)
            loss_batch = loss(yhat_batch, masks)
            loss_batch.backward()
            optimizer.step()
    return model

def test_model(model, image, mask=None, show_things=False):
    """Run ``model`` on one image and return the predicted mask.

    Args:
        model: The ``torch.nn.Module`` to evaluate.
        image: A single image or a batch; a single image gets a batch axis
            added before the forward pass.
        mask: Optional ground-truth mask, only used for display.
        show_things: When true, display the input, the given mask (if any)
            and the prediction.

    Returns:
        The prediction reduced to a single image by
        ``ensure_individual_image``.
    """
    if show_things:
        show_image(ensure_individual_image(image))
        if mask is not None:
            show_mask(ensure_individual_image(mask))

    # Inference must not track gradients or update BatchNorm statistics.
    # Switch to eval mode for the forward pass and restore the previous mode
    # afterwards, so training can continue unaffected.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Call the module rather than .forward() so registered hooks run.
            predicted_mask = model(ensure_batch(image))
    finally:
        model.train(was_training)

    predicted_mask = ensure_individual_image(predicted_mask)
    if show_things:
        show_mask(predicted_mask)
    return predicted_mask


In [None]:
# initialize model
# U-Net from segmentation_models_pytorch with a VGG16 encoder and a sigmoid
# activation on the output.
model = smp.Unet(
    encoder_name="vgg16",
    activation="sigmoid"
)

# training details
# NOTE(review): `epochs` is not read by the do_epochs(5) calls further down in
# this file; they pass their own count.
epochs = 40
batch_size = 5
loss = smp.utils.losses.DiceLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# initialize dataloader
# Rebinds the module-level `dataloader` that do_epoch() iterates over.
dataloader = MoNuSegDataLoader(batch_size)

# epoch functions
def do_epoch():
    """Run one optimisation pass over the module-level ``dataloader``.

    Reads the module-level ``model``, ``loss``, ``optimizer`` and
    ``dataloader``; ``model`` is updated in place.
    """
    # Make sure layers such as BatchNorm/Dropout are in training mode, even
    # if something switched the model to eval mode between epochs.
    model.train()
    for images, masks in dataloader:
        images, masks = preprocess_batch(images, masks)
        optimizer.zero_grad()
        # Call the module rather than .forward() so registered hooks run.
        yhat_batch = model(images)
        loss_batch = loss(yhat_batch, masks)
        loss_batch.backward()
        optimizer.step()
def do_epochs(epochs):
    """Call ``do_epoch`` ``epochs`` times behind an "Epoch Counter" progress bar."""
    progress = loadingbar(range(epochs), desc="Epoch Counter", leave=True)
    for _ in progress:
        do_epoch()

In [None]:
def test_model():
    """Display the first dataset sample, its mask and the model's prediction.

    Uses the module-level ``model``. The sample comes from a newly built
    ``MoNuSegDataset`` and is cropped and rotated with ``preprocess_batch``
    before being shown.

    NOTE: this definition rebinds the name ``test_model``, replacing the
    earlier four-parameter helper of the same name defined in this file.
    """
    image, mask = preprocess_batch(*MoNuSegDataset()[0])
    show_image(image)
    show_mask(mask)

    # Inference must not track gradients or update BatchNorm statistics.
    # Switch to eval mode for the forward pass and restore the previous mode
    # afterwards, so the following do_epochs() calls keep training normally.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Call the module rather than .forward() so registered hooks run.
            test_out = model(image.unsqueeze(0))
    finally:
        model.train(was_training)

    show_mask(test_out.squeeze(0))

In [None]:
# Train for 5 epochs, then preview the prediction on the first dataset sample.
do_epochs(5)
test_model()

In [None]:
# Train for 5 more epochs and preview the prediction again.
do_epochs(5)
test_model()

In [None]:
# Train for 5 more epochs and preview the prediction again.
do_epochs(5)
test_model()

In [None]:
import copy
# Keep an independent copy of the model before training further; the name
# assumes the three 5-epoch cells above were each run exactly once.
model_save_15_epochs = copy.deepcopy(model)
do_epochs(5)
test_model()

In [None]:
# Snapshot the model again before another 5 epochs (requires `copy` to have
# been imported by an earlier cell).
model_save_20_epochs = copy.deepcopy(model)
do_epochs(5)
test_model()

In [None]:
# Roll back to the snapshot taken before the last 5 epochs and save its weights.
# NOTE(review): `optimizer` was built from the parameters of the previous
# model object, so further do_epochs() calls will not update this restored
# copy unless the optimizer is rebuilt from `model.parameters()`.
model = model_save_20_epochs
torch.save(model.state_dict(), "temp.model")

## Apply Model to Labelled Cells

In [None]:
# get cell crops
# Hand-picked cell crops, listed by file name within each class folder.
files_eocell = ["96.jpg", "138.jpg", "1584.jpg", "1755.jpg"]
files_neutro = ["193.jpg", "461.jpg", "667.jpg", "816.jpg"]
# Full relative paths: all eosinophil crops first, then all neutrophil crops.
# This rebinds the module-level name `files`.
files = [
    f"pretrained-data/LabeledCellCrops/{relative_path}"
    for relative_path in (
        *(f"eosinophil/{name}" for name in files_eocell),
        *(f"neutrophil/{name}" for name in files_neutro),
    )
]

In [None]:
# Cell crop padding and unpadding

class CellCropPadder(object):
    """Place an image inside a square canvas filled with a constant value.

    Args:
        size: Side length of the square output canvas.
        value: Fill value used everywhere outside the image.
        placement: Row and column offset of the image's top-left corner.
    """

    def __init__(self, size=256, value=0, placement=50) -> None:
        self.size = size
        self.value = value
        self.placement = placement

    def __call__(self, image) -> torch.Tensor:
        """Return a ``(channels, size, size)`` canvas containing ``image``.

        Args:
            image: A single image or a batch; a batch is reduced to one image
                with ``ensure_individual_image``.

        Raises:
            ValueError: If the image does not fit on the canvas at the
                configured placement.
        """
        image = ensure_individual_image(image)
        channels, height, width = image.shape
        bottom = self.placement + height
        right = self.placement + width
        if bottom > self.size or right > self.size:
            raise ValueError(
                f"image of size {height}x{width} placed at offset "
                f"{self.placement} does not fit in a {self.size}x{self.size} canvas"
            )
        padded_image = torch.full(
            (channels, self.size, self.size), float(self.value)
        )
        # Assign rather than add: adding onto the fill value would shift every
        # image pixel by `value` whenever the fill is not zero.
        padded_image[:, self.placement:bottom, self.placement:right] = image
        return padded_image

class CellCropUnpadder(object):
    """Measure the uniform border of a padded image and strip it from tensors.

    A row or column counts as padding when it holds a single distinct value
    across all channels. The four border widths are measured once from the
    reference image given to the constructor and then reused by ``__call__``.

    Args:
        padded_image: Reference image (or batch, reduced to one image with
            ``ensure_individual_image``) whose border is measured.
    """

    def __init__(self, padded_image):
        self.image = ensure_individual_image(padded_image)
        self._measure_pads()

    def _measure_pads(self):
        """Count the uniform border columns and rows of ``self.image``."""
        padded_img_width = self.image.shape[2]
        padded_img_height = self.image.shape[1]
        # initialize padding
        self.left_pad = 0
        self.right_pad = 0
        self.top_pad = 0
        self.bottom_pad = 0
        # Each scan is bounded by the image extent. Without the bounds, a
        # fully uniform image makes the left/top scan index past the edge.
        # left/right pad deals with columns
        while (self.left_pad < padded_img_width
               and self.col_is_padding(self.left_pad)):
            self.left_pad += 1
        while (self.left_pad + self.right_pad < padded_img_width
               and self.col_is_padding(padded_img_width - 1 - self.right_pad)):
            self.right_pad += 1
        # top/bottom pad deals with rows
        while (self.top_pad < padded_img_height
               and self.row_is_padding(self.top_pad)):
            self.top_pad += 1
        while (self.top_pad + self.bottom_pad < padded_img_height
               and self.row_is_padding(padded_img_height - 1 - self.bottom_pad)):
            self.bottom_pad += 1

    def __call__(self, image):
        """Return ``image`` with the measured border removed.

        Args:
            image: Rank-3 tensor indexed as (channel, row, column).

        Returns:
            The slice of ``image`` that excludes the measured pads on all
            four sides.
        """
        padded_image_width = image.shape[2]
        padded_image_height = image.shape[1]
        return image[
            :,
            self.top_pad : padded_image_height - self.bottom_pad,
            self.left_pad : padded_image_width - self.right_pad
        ]

    def row_is_padding(self, row_num):
        """Return True when row ``row_num`` of the reference holds one distinct value."""
        return len(torch.unique(self.image[:, row_num, :])) == 1

    def col_is_padding(self, col_num):
        """Return True when column ``col_num`` of the reference holds one distinct value."""
        return len(torch.unique(self.image[:, :, col_num])) == 1

    def get_pads(self):
        """Return ``(left_pad, right_pad, top_pad, bottom_pad)``."""
        return self.left_pad, self.right_pad, self.top_pad, self.bottom_pad

# Disabled manual check: pad a 3x2 image into a 6x6 canvas at offset 1, measure
# the pads, then strip them from a copy whose values were all shifted by +1.
if False: # test these functions
    unpadded_img = torch.Tensor([[[[1, 2], [3, 4], [5, 6]]]])
    print(unpadded_img.shape)
    print(unpadded_img)
    padder = CellCropPadder(6, placement=1)
    padded_img = padder(unpadded_img)
    print(padded_img)
    unpadder = CellCropUnpadder(padded_img)
    # the +1 makes the border non-zero, so unpadding relies on the stored pads
    padded_img_with_noise = padded_img.clone() + 1
    print(unpadder.get_pads())
    print(padded_img_with_noise)
    unpadded_img = unpadder(padded_img_with_noise)
    print(unpadded_img)

In [None]:
# process cell crop
def get_cell_crop(cell_crop_filename):
    """Load an image file and return it as a tensor.

    Args:
        cell_crop_filename: Path of the image file to open with PIL.

    Returns:
        The tensor produced by torchvision's ``pil_to_tensor``.

    NOTE(review): torchvision documents ``pil_to_tensor`` as keeping the
    image's integer values without rescaling — confirm this matches the value
    range the model was trained on.
    """
    # The bare name `pil_to_tensor` is never imported or defined in this file,
    # so reach it through the existing `transforms` import. The context
    # manager closes the file handle once the pixel data has been copied.
    with Image.open(cell_crop_filename) as pil_image:
        return transforms.functional.pil_to_tensor(pil_image)
def test_model_with_cell_crop(model, cell_crop_filename, threshold=0.5):
    """Predict a binary mask for one cell-crop image file.

    The crop is padded onto a default ``CellCropPadder`` canvas, passed
    through ``model``, cut back to the crop's extent and thresholded.

    Args:
        model: The ``torch.nn.Module`` used for prediction.
        cell_crop_filename: Path of the cell-crop image.
        threshold: Predictions strictly above this value become 1, the rest 0.

    Returns:
        The thresholded prediction, in the dtype of the model output.
    """
    img = get_cell_crop(cell_crop_filename)
    padded_img = CellCropPadder()(img)

    # Run the forward pass here instead of calling test_model(model, image):
    # by the time this runs, that name refers to the later zero-argument
    # helper, so the two-argument call raises TypeError.
    # Inference uses eval mode without gradient tracking; the previous mode
    # is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            prediction = ensure_individual_image(model(ensure_batch(padded_img)))
    finally:
        model.train(was_training)

    # The pads are measured on the padded input and applied to the prediction.
    unpadder = CellCropUnpadder(padded_img)
    unpadded_prediction = unpadder(prediction)
    final_prediction = (unpadded_prediction > threshold).to(unpadded_prediction.dtype)
    return final_prediction

In [None]:
# Predict a mask for every listed crop, display it, and save both the crop and
# the mask under the "test-mask" folder with a class tag in the file name.
for f in files:
    print(f)
    p = test_model_with_cell_crop(model, f)
    show_image(p)
    # 'e' marks eosinophil crops, 'n' everything else
    class_tag = 'e' if 'eosinophil' in f else 'n'
    output_base = f.replace("eosinophil","test-mask").replace("neutrophil","test-mask")
    torchvision.utils.save_image(get_cell_crop(f), output_base.replace(".jpg",f"-{class_tag}.jpg"))
    torchvision.utils.save_image(p, output_base.replace(".jpg",f"-{class_tag}-m.jpg"))