# Goals of this Kernel
This kernel provides a starter template that loads all images and masks into memory and gets them ready for PyTorch Lightning.  
  
We will also use Splitter to create a 10-fold cross-validation dataset, which can easily be extended with more data.

## Why is memory an issue?
Each image is annotated with every single cell of interest. Our goal is to segment them individually.  
Because the cells can overlap, we cannot simply store one number per pixel corresponding to its cell; we need a separate mask for each single neuron.  
This array would be very big and not fit in memory (I tried).  
But because it almost only contains zeros, we can use [sparse matrices](https://sparse.pydata.org/en/stable/) and only store the positive pixels.  
We can easily convert this back into a dense representation at runtime.
  
## Where can I follow you?  
I am glad you asked: https://twitter.com/PSodmann

In [None]:
!pip install git+https://github.com/p-sodmann/splitter -q
!pip install sparse -q

In [None]:
import numpy as np
import pandas as pd
import os
from splitter.splitter import Splitter
from tqdm.auto import tqdm
import sparse
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import imageio
import matplotlib.pyplot as plt
import torch

In [None]:
# ref: https://www.kaggle.com/inversion/run-length-decoding-quick-start
def rle2mask(mask_rle, shape):
    '''
    Decode a run-length encoded mask into a boolean array.

    mask_rle: run-length encoding as a whitespace separated string of
              "start length" pairs; starts are 1-based positions in the
              flattened (row-major) array
    shape: (rows, columns) of the array to return, i.e. the order used
           by image.shape
    Returns numpy bool array, True - mask, False - background

    '''
    tokens = mask_rle.split()
    # even positions hold the run starts, odd positions the run lengths
    starts = np.asarray(tokens[0::2], dtype=int)
    lengths = np.asarray(tokens[1::2], dtype=int)

    # the encoding counts pixels from 1, numpy indexes from 0
    starts -= 1
    ends = starts + lengths

    # use the builtin bool: the np.bool alias was removed in NumPy 1.24
    img = np.zeros(shape[0] * shape[1], dtype=bool)

    for lo, hi in zip(starts, ends):
        img[lo:hi] = True

    return img.reshape(shape)

In [None]:
# pads and truncates the mask to max_size in z direction (number of possible annotated cells in one image)
def pad(array, max_size=128):
    """Pad or truncate a mask stack to exactly ``max_size`` layers.

    array: 3-D array indexed as (layer, row, column)
    max_size: number of layers in the returned array

    Returns a newly allocated float array of shape
    (max_size, array.shape[1], array.shape[2]). Layers beyond the input
    are zero; input layers beyond ``max_size`` are dropped.
    """
    # Always allocate a fresh buffer so both the padding and the truncating
    # case return the same dtype and never alias the caller's array.
    padded = np.zeros((max_size, array.shape[1], array.shape[2]))

    kept = min(array.shape[0], max_size)
    padded[:kept] = array[:kept]

    return padded

class CellDataset(Dataset):
    """Dataset returning a random square crop of an image and its cell masks.

    data: sequence of dicts, each with an "image" entry (array whose first
          two axes are cropped) and a "sparse_mask" entry (sparse array
          indexed as (cell, row, column)).
    """

    def __init__(self, data):
        self.data = data

        # tile size
        self.size = 256

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return (image crop with a leading channel axis, padded mask stack).

        The crop position is drawn with np.random, so repeated calls with
        the same index return different tiles. Raises ValueError (from
        np.random.randint) if the image is smaller than the tile size.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        item = self.data[idx]
        img = item["image"]

        # get a random crop
        # randint's upper bound is exclusive: the +1 makes the last valid
        # offset reachable and allows images that are exactly one tile large
        x = np.random.randint(img.shape[0] - self.size + 1)
        y = np.random.randint(img.shape[1] - self.size + 1)
        rows = slice(x, x + self.size)
        cols = slice(y, y + self.size)

        # crop while the mask is still sparse and make only the tile dense,
        # so the full-resolution stack is never materialised
        dense_mask = item["sparse_mask"][:, rows, cols].todense()

        # get only masks in the tile, which contain positive pixels (neurons)
        has_pixels = np.sum(dense_mask, axis=(1, 2)) > 0
        filled_mask = dense_mask[has_pixels]

        # pad in z direction
        padded_mask = pad(filled_mask)

        # crop image and return image and mask
        return np.array(np.expand_dims(img[rows, cols], 0)), padded_mask

In [None]:
import pytorch_lightning as pl

class CellDataModule(pl.LightningDataModule):
    """Lightning data module that keeps all images and sparse masks in memory.

    dataframe: annotation table with an "id" column (image id) and an
               "annotation" column (run-length encoded mask of one cell)
    batch_size: batch size handed to every DataLoader
    image_dir: directory containing one "<id>.png" file per image id
    """

    def __init__(
        self,
        dataframe,
        batch_size: int = 32,
        image_dir: str = "../input/sartorius-cell-instance-segmentation/train",
    ):
        super().__init__()

        self.dataframe = dataframe
        self.batch_size = batch_size
        self.image_dir = image_dir

        # make 10 split cross validation.
        self.splitter = Splitter(10, seed=21188)

    def setup(self, stage=None):
        """Register every image id with the splitter and fetch the id splits."""
        self.cell_ids = self.dataframe["id"].unique()

        # add data to cross-validation.
        # we can add more semi supervised data later without changing the splits
        # https://medium.com/analytics-vidhya/splitting-your-data-growing-beyond-train-test-split-dc0eb83d7dac
        for cell_id in self.cell_ids:
            self.splitter.add(cell_id)

        # 8 folds for training, 1 for validation, 1 for testing
        self.train_ids, self.valid_ids, self.tests_ids = self.splitter.get_split([[0,1,2,3,4,5,6,7], [8], [9]])

    def load_data(self, item_ids):
        """Load the image and all cell masks for each id in ``item_ids``.

        Returns a list of dicts with the keys "image" (the array read from
        disk) and "sparse_mask" (sparse.COO stack with one layer per
        annotation row of that image).
        """
        items = []
        for item_id in tqdm(item_ids):
            image = imageio.imread(f'{self.image_dir}/{item_id}.png')

            # all run-length encoded annotations belonging to this image
            annotations = self.dataframe.loc[self.dataframe["id"] == item_id, "annotation"]

            # decode one dense mask per annotated cell
            mask = [rle2mask(annotation, image.shape) for annotation in annotations]

            # make it sparse, so it fits into memory
            mask = sparse.COO(np.array(mask))

            items.append({"image":image, "sparse_mask":mask})

        return items

    def train_dataloader(self):
        """Load the training split into memory and wrap it in a DataLoader."""
        self.train_data = CellDataset(self.load_data(self.train_ids))
        return DataLoader(self.train_data, batch_size=self.batch_size)

    def val_dataloader(self):
        """Load the validation split into memory and wrap it in a DataLoader."""
        self.valid_data = CellDataset(self.load_data(self.valid_ids))
        return DataLoader(self.valid_data, batch_size=self.batch_size)

    def test_dataloader(self):
        """Load the test split into memory and wrap it in a DataLoader."""
        self.tests_data = CellDataset(self.load_data(self.tests_ids))
        return DataLoader(self.tests_data, batch_size=self.batch_size)

In [None]:
# read the annotation table from the competition's train.csv
annotation_df = pd.read_csv("../input/sartorius-cell-instance-segmentation/train.csv")

# build the data module and assign every image id to one of the 10 folds
cdm = CellDataModule(annotation_df)
cdm.setup()

# loads all training images and their sparse masks into memory (stored as
# cdm.train_data) and wraps them in a DataLoader
tdl = cdm.train_dataloader()

We managed to load all data into memory; this only works because we saved the masks in a sparse format.  
Before using them in a neural network, we need to convert them back into a dense representation; this happens in the dataset.

In [None]:
# index of the training sample to inspect
image_number = 50
# NOTE: reassigned to 3 in the next cell before it is first used
cell_number = 1

# (image crop, padded mask stack); the crop position is random, so running
# this cell again returns a different tile of the same image
data = cdm.train_data[image_number]

In [None]:
# index of the mask layer to display; layers past the number of cells found
# in the crop are zero padding
cell_number = 3

# the selected mask on its own
plt.imshow(data[1][cell_number,:,:])
plt.show()

# the image crop with the selected mask drawn semi-transparently on top
plt.imshow(data[0][0])
plt.imshow(data[1][cell_number,:,:], alpha=0.3)
plt.show()

In [None]:
# overlay of all cells in the image

plt.imshow(data[0][0])

# add up every mask layer in one vectorised call; the result takes its
# height and width from the masks instead of a hard-coded 256x256
all_masks = np.sum(data[1], axis=0, dtype=np.float64)

plt.imshow(all_masks, alpha=0.3)
plt.show()

Awesome, we managed to load the data and fit it into memory and get our mask back.  
Have fun building a model and competing in this challenge!
  
🐱 Phil