### Prepare spatial training data

In [127]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tifffile
import imageio
import xarray
import scipy
import inspect # from package inspect-it
import sys
import torch

To install a pip package in the current Jupyter kernel:

In [9]:
#import sys
#!{sys.executable} -m pip install inspect-it

### Define Classes and Functions

In [176]:
def logger(comment):
    """Print a "Class::method() - comment" line describing the caller.

    The label is derived from the frame that called this function: the
    function name always, prefixed by the class name of the caller's local
    ``self`` when there is one. Callers without a ``self`` (plain functions,
    notebook cells) are labelled with the function name only instead of
    raising ``KeyError``.

    Args:
        comment (str): comment to add to the logger
    """
    # Look at the direct caller only; inspect.stack() would materialise the
    # whole call stack (with source context) just to read one frame.
    frame = inspect.currentframe()
    caller = frame.f_back if frame is not None else None
    try:
        if caller is None:
            class_fun_str = "<unknown>()"
        else:
            fun_str = caller.f_code.co_name + "()"
            if "self" in caller.f_locals:
                class_fun_str = caller.f_locals["self"].__class__.__name__ + "::" + fun_str
            else:
                class_fun_str = fun_str
    finally:
        # drop frame references promptly to avoid reference cycles
        del frame, caller
    print("{:<40} - {:<50}".format(class_fun_str, comment))

def drawProgressBar(percent, barLen = 100):
    """Redraw a one-line text progress bar on stdout.

    Args:
        percent (float): completed fraction, from 0 to 1
        barLen (int): width of the bar in characters
    """
    filled = "=" * int(barLen * percent)
    # the leading carriage return moves the cursor back to the start of the
    # line so each call overwrites the previous bar
    sys.stdout.write("\r" + "[{:<{}}] {:.0f}%".format(filled, barLen, percent * 100))
    sys.stdout.flush()
        
class Cell(object):
    """A single segmented cell together with the data cropped around it."""

    def __init__(self, x, loc, label, e=None, A=None, seg=None, mask=None, img=None, pi=None, z=None):
        """This class contains a Cell and all data related to a single cell c

        Args:
            x (array-like, 1D): the spatial gene expression, from dots confidently attributed to cell c;
                stored as a torch tensor
            loc (pd.Series, 1D): contains x,y location of cell c centroid, accessible via .loc.x or .loc.y
            label (int): label of cell c
            e (int): size of local neighborhood window, centroids will be enlarged by +/- e
            A (xarray, 3D): 3D spatial expression containing coordinate of gene counts in local neighborhood
            seg (np.array, 2D): segmentation centered on cell c containing neighbor cells in local neighborhood
            mask (np.array, 2D): binary mask from segmented objects in local neighborhood
            img (np.array, 2D): image centered on cell c containing neighbor cells in local neighborhood
            pi (torch 1D tensor): probability to be affiliated to each K cell type
            z (torch 1D tensor): assignment of cell type based on pi_c
        """
        self.x = torch.tensor(x)
        self.A = A
        self.loc = loc
        self.label = label
        self.e = e  # previously accepted but silently dropped
        self.seg = seg
        self.mask = mask
        self.img = img
        self.pi = pi
        self.z = z
        # window bounds in full-image coordinates, filled in by build_crops_c
        self.x_start = None ; self.x_end = None
        self.y_start = None ; self.y_end = None
        self.on_boarder = False
        self.input_CNN = None

    def build_crops_c(self, A, seg, img, e, x_width, y_width):
        """Crop full size image, segmentation and matrix A, centered on cell c.

        Cell centroids are enlarged by +/- e. When the window sticks out of the
        image, ``on_boarder`` is set, ``A``/``seg``/``mask``/``img`` hold the
        part of the window that lies inside the image, and ``input_CNN`` is
        zero-padded so that it always covers the full window.

        Args:
            A (3D array / xarray): gene counts, indexed as (gene, x, y)
            seg (np.array, 2D): full-size segmentation
            img (np.array, 2D): full-size image
            e (int): half-size of the window around the centroid
            x_width (int): image size along the first spatial axis
            y_width (int): image size along the second spatial axis
        """
        self.e = e
        # Record coordinates of the local neighborhood. floor() instead of a
        # bare int(): int() truncates toward zero, so a negative start was
        # pulled up to 0 and the window came out narrower than 2e.
        self.x_start = int(np.floor(self.loc.x - e)) ; self.x_end = int(np.floor(self.loc.x + e))
        self.y_start = int(np.floor(self.loc.y - e)) ; self.y_end = int(np.floor(self.loc.y + e))
        self.on_boarder = bool(self.x_start < 0 or self.y_start < 0 or
                               self.x_end > x_width or self.y_end > y_width)

        # Clamp the slice bounds to the image. A negative start would otherwise
        # be interpreted as "count from the end" and return an empty or
        # unrelated crop.
        x_lo, x_hi = max(self.x_start, 0), min(self.x_end, x_width)
        y_lo, y_hi = max(self.y_start, 0), min(self.y_end, y_width)

        # crop input matrices
        self.A = A[:, x_lo:x_hi, y_lo:y_hi] # size (G, 2e, 2e) away from the border, G number of genes
        self.seg = seg[x_lo:x_hi, y_lo:y_hi]
        # binary mask from segmentation
        self.mask = self.seg.copy()
        self.mask[self.mask > 0] = 1
        # cropped image
        self.img = img[x_lo:x_hi, y_lo:y_hi]

        # 3D array used as input for CNN: channel 0 = mask, 1 = image, 2.. = genes
        stacked = np.concatenate((self.mask[np.newaxis], self.img[np.newaxis], np.asarray(self.A)))
        if self.on_boarder:
            # pad the out-of-image part with zeros so every cell yields the same window size
            pad_x = (x_lo - self.x_start, max(self.x_end - x_hi, 0))
            pad_y = (y_lo - self.y_start, max(self.y_end - y_hi, 0))
            stacked = np.pad(stacked, ((0, 0), pad_x, pad_y), mode="constant", constant_values=0)
        self.input_CNN = torch.tensor(stacked)
        
class CellContainer(object):
    """All segmented cells of one image plus the tables derived from them.

    Typical use: construct, then call ``get_cell_centroids()``,
    ``build_X_matrix()``, ``build_A_matrix()`` and ``build_Cell_objects()``.
    """

    def __init__(self, img, seg, D, S, e, A=None, X=None, loc=None, cell_labels=None, gene_names=None, 
                 celltype_labels=None, ignore_cells_with_0transript = True):
        """CellContainer class contains all segmented cells from an image, and their associated spatial gene expression

        Args:
            img (np.array, 2D): dapi staining raw image, size (width_X, width_Y)
            seg (np.array, 2D): nuclei segmentation based on dapi staining; 0 is treated as background
            D (pd.DataFrame): Gene dots information; the columns ``name``, ``x`` and ``y`` are read.
            S (pd.DataFrame): Average cell-type gene expression from scRNA-seq, indexed by gene name
            e (int): size of enlargement around cell centroids for making crops used by CNN.
            A (xarray, 3D): spatial gene counts of size (G, width_X, width_Y); overwritten by build_A_matrix().
            X (pd.DataFrame): gene counts confidently assigned to each cell; overwritten by build_X_matrix().
            loc: NOTE(review): accepted but currently ignored; centroids come from get_cell_centroids().
            cell_labels: NOTE(review): accepted but currently ignored; labels are derived from seg.
            gene_names: NOTE(review): accepted but currently ignored; derived from the overlap of D and S.
            celltype_labels: NOTE(review): accepted but currently ignored.
            ignore_cells_with_0transript (bool): [default: True] defines if cells that get 0 RNA transcript should be ignored

        Raises:
            ValueError: if D and S share no gene name.
        """
        self.img = img
        self.seg = seg
        self.D = D
        self.S = S
        self.e = e
        self.A = A 
        self.X = X
        self.loc = None
        # Unique cell IDs. Drop the background label explicitly rather than
        # "the first unique value", which would discard a real cell whenever
        # the segmentation has no background pixel.
        seg_labels = np.unique(self.seg)
        self.cell_labels = pd.Series(seg_labels[seg_labels != 0])
        self.celltype_labels = None
        self.ignore_cells_with_0transript = ignore_cells_with_0transript
        self.X_miss = None # matrix X including cells with 0 transcripts
        self.cell_labels_miss = None # cell labels including cells with 0 transcripts
        self.cells = [] # list containing all Cell in image
        self.x_width = self.seg.shape[0]
        self.y_width = self.seg.shape[1]

        # class set-up function 
        self.define_gene_overlap()

    def define_gene_overlap(self):
        """Restrict D and S to the gene names they have in common.

        Sets ``gene_names`` (sorted), filters and re-indexes ``D`` from 0, and
        filters ``S`` and sorts it by gene name.

        Raises:
            ValueError: if there is no common gene name.
        """
        shared_genes = sorted(set(self.D.name).intersection(self.S.index))
        if not shared_genes:
            raise ValueError('No overlap between gene names in D and S')
        self.gene_names = pd.Series(shared_genes)
        # Filter this instance's own tables (the previous code read the
        # notebook globals D and S) and build fresh frames, so later column
        # assignments never write into a view of the caller's data.
        self.D = self.D.loc[np.asarray(self.D.name.isin(shared_genes))].reset_index(drop=True)
        self.S = self.S.loc[np.asarray(self.S.index.isin(shared_genes))].sort_index()
        logger(f"Found {len(self.gene_names)} overlapping gene names")

    def get_cell_centroids(self):
        """Compute the centroid of every label in ``cell_labels``.

        Sets ``loc``: a DataFrame indexed by cell label with columns ``x``
        (mean index along axis 0 of seg) and ``y`` (mean index along axis 1).
        """
        logger(f'Computing cell centroids based on segmentation file')
        coords = [[np.mean(axis_idx) for axis_idx in np.where(self.seg == label)]
                  for label in self.cell_labels]
        self.loc = pd.DataFrame(coords, columns = ['x', 'y'], index=self.cell_labels)

    def find_closest_cell(self, x, y, max_shift=30):
        """Return the label of the cell closest to pixel (x, y), or 0.

        Square windows of half-size 1, 2, ..., max_shift - 1 centred on (x, y)
        are scanned in turn. The first window that contains any cell decides:
        its label is returned if it holds exactly one cell, 0 if it holds
        several (ambiguous). 0 is also returned when no window contains a cell.

        Args:
            x, y: pixel coordinates of the dot (converted with int())
            max_shift (int): exclusive upper bound of the window half-size
        """
        x, y = int(x), int(y)
        for shift in range(1, max_shift):
            # Clamp the lower bounds: a negative start would be read as
            # "from the end" and silently select the wrong region near the
            # top/left border. Upper bounds are clipped by slicing itself.
            window = self.seg[max(x - shift, 0):x + shift + 1, max(y - shift, 0):y + shift + 1]
            cell_ids = np.unique(window)
            cell_ids = cell_ids[cell_ids != 0] # remove 'cell number 0' corresponding to background
            if len(cell_ids) == 1: # exactly one cell is detected
                return cell_ids[0]
            if len(cell_ids) > 1: # more than one cell is detected: ambiguous
                return 0 # 'cell number 0' means no cell was attributed
        return 0

    def build_X_matrix(self, max_shift=30):
        """Attribute every gene dot to its closest cell and build the count matrix X.

        Adds the columns ``cell`` (attributed label, 0 if none) and
        ``attributed`` (0/1) to ``D``, then sets ``X`` to the cell-by-gene
        table of attributed dot counts, without the row of unattributed dots.
        With ``ignore_cells_with_0transript`` set, ``cell_labels`` and ``loc``
        are restricted to the cells present in X; otherwise ``X_miss`` is built.

        Args:
            max_shift (int): forwarded to find_closest_cell()
        """
        logger(f'Attributing gene dots to closest cells')
        closest_cells = [self.find_closest_cell(dot_x, dot_y, max_shift=max_shift)
                         for dot_x, dot_y in zip(self.D.x, self.D.y)]
        # Add info in dataframe D. Assign by name (not insert-at-position) so
        # the method does not depend on D's column count and can be re-run.
        self.D['cell'] = np.array(closest_cells, dtype=int)
        self.D['attributed'] = (self.D['cell'] > 0).astype(int)

        logger("""Build matrix X based on dot attribution""")
        self.X = self.D.pivot_table(index='cell', columns='name', values='attributed', 
                                    aggfunc= 'sum').fillna(0).astype(int)
        # Remove the row of unattributed dots by label; slicing off the first
        # row would drop a real cell whenever every dot was attributed.
        self.X = self.X.drop(index=0, errors='ignore')
        # Denominator is the number of cells, background label excluded.
        n_cells_total = int(np.count_nonzero(np.unique(self.seg)))
        pct = round(100*self.X.shape[0]/n_cells_total) if n_cells_total else 0
        logger(str(pct) + '% of cells got at least 1 RNA count') 
        if self.ignore_cells_with_0transript:
            logger("""Ignoring cells with 0 RNA transcript. Updating cell_labels and loc""")
            self.cell_labels = self.X.index
            if self.loc is None:
                # centroids were not computed yet; do it now instead of failing on None
                self.get_cell_centroids()
            self.loc = self.loc.loc[np.asarray(self.loc.index.isin(self.cell_labels))]
        else:
            self.add_missing_cells_in_X()

    def add_missing_cells_in_X(self):
        """Build ``X_miss``: X plus an all-zero row for every cell without counts.

        Rows are sorted by cell label. ``cell_labels_miss`` is set to the
        resulting index.
        """
        counted_ids = np.asarray(self.X.index, dtype=int)
        all_ids = np.union1d(counted_ids, np.asarray(self.cell_labels, dtype=int))
        # reindex adds the missing labels as zero rows and sorts in one step
        self.X_miss = self.X.reindex(all_ids, fill_value=0).astype(int)
        self.cell_labels_miss = self.X_miss.index

    def subset_cells(self, cell_ids):
        """Subsetting cells in all tables to fit cell labels in self.cell_labels variable"""
        # NOTE(review): placeholder, only logs; cell_ids is unused and no table is subset yet.
        logger("""Subsetting cells in self.X, self.X, and self.X""")

    def build_A_matrix(self):
        """Build matrix A based on gene dots in D.

        Sets ``A`` to an xarray.DataArray of shape (G, width_X, width_Y) with
        dims ("genes", "x", "y"), where A[g, x, y] is the number of dots of
        gene g at pixel (x, y). Dots whose coordinates are not whole numbers
        inside the image are skipped. A constant ``value`` column (1 per dot)
        is also written to ``D``.
        """
        logger("Build matrix A based on gene dots in D")
        # kept for backward compatibility; assigned by name so this method no
        # longer requires build_X_matrix() to have added two columns first
        self.D['value'] = 1
        n_x, n_y = self.seg.shape[0], self.seg.shape[1]
        gene_idx = pd.Index(self.gene_names).get_indexer(self.D.name) # -1 for unknown genes
        xs = np.asarray(self.D.x, dtype=float)
        ys = np.asarray(self.D.y, dtype=float)
        keep = ((gene_idx >= 0)
                & (xs >= 0) & (xs < n_x) & (np.floor(xs) == xs)
                & (ys >= 0) & (ys < n_y) & (np.floor(ys) == ys))
        counts = np.zeros((len(self.gene_names), n_x, n_y), dtype=int)
        # unbuffered add: several dots on the same (gene, x, y) accumulate
        np.add.at(counts, (gene_idx[keep], xs[keep].astype(int), ys[keep].astype(int)), 1)
        self.A = xarray.DataArray(counts, [ ("genes", self.gene_names), ("x", range(n_x)), ("y", range(n_y))] )

    def build_Cell_objects(self):
        """Build Cell objects based on cell_labels, X, seg, and cell centroids, used as training data for our model.

        Requires ``X`` (build_X_matrix), ``A`` (build_A_matrix) and ``loc``
        (get_cell_centroids). ``cells`` is rebuilt from scratch on every call.
        """
        logger("Build Cell objects based on X, A, seg, and cell centroids")
        logger("Iterating over cells:")
        # When zero-count cells are kept, cell_labels also lists cells that
        # are absent from X; their rows live in X_miss.
        if not self.ignore_cells_with_0transript and self.X_miss is not None:
            counts_table = self.X_miss
        else:
            counts_table = self.X
        self.cells = [] # start fresh so re-running the cell does not duplicate entries
        n_cells = len(self.cell_labels)
        for cn, c in enumerate(self.cell_labels):
            Cell_obj = Cell(x=np.array(counts_table.loc[c]), loc=self.loc.loc[c], label=c)
            Cell_obj.build_crops_c(A=self.A, seg=self.seg, img=self.img, e=self.e, 
                                   x_width=self.x_width, y_width=self.y_width)
            self.cells.append(Cell_obj)
            progress = (cn + 1) / n_cells
            drawProgressBar(progress)

## Main 

Load the necessary tables and build a `CellContainer` object, which in turn builds all `Cell` objects containing the data needed for training our models.

In [154]:
# Gene dots table; CellContainer reads its `name`, `x` and `y` columns.
D = pd.read_csv("../data/crop1_genes.csv")
# Both rasters are transposed right after loading, presumably so that axis 0 is x
# and axis 1 is y, matching how D.x / D.y index the segmentation. TODO confirm.
img = tifffile.imread("../data/crop1_dapi.tif").T
seg = imageio.v2.imread('../data/crop1_dapi_cp_masks.png').T

In [155]:
# fix mistake in input D (extra row): keep only the dots whose y is not 1645
D = D.loc[(D.y != 1645).to_numpy()].copy()

In [156]:
# scRNA-seq reference table, indexed by its 'Gene' column; CellContainer
# intersects this index with the gene names found in D.
S = pd.read_csv("../data/scrna_muX_clust16_TaxonomyRank3.csv", index_col='Gene')

In [177]:
# Copies are passed in so the container works on its own tables and leaves
# the notebook-level img / seg / D / S untouched.
CellContainer_obj = CellContainer(img=img.copy(), seg=seg.copy(), D=D.copy(), S=S.copy(), e=50,
                                  ignore_cells_with_0transript=True)
# The calls below are order-dependent: build_Cell_objects() needs the centroids,
# X and A produced by the three calls before it.
CellContainer_obj.get_cell_centroids()
CellContainer_obj.build_X_matrix()
CellContainer_obj.build_A_matrix()
CellContainer_obj.build_Cell_objects()

CellContainer::define_gene_overlap()     - Found 187 overlapping gene names                  
CellContainer::get_cell_centroids()      - Computing cell centroids based on segmentation file
CellContainer::build_X_matrix()          - Attributing gene dots to closest cells            
CellContainer::build_X_matrix()          - Build matrix X based on dot attribution           
CellContainer::build_X_matrix()          - 82% of cells got at least 1 RNA count             
CellContainer::build_X_matrix()          - Ignoring cells with 0 RNA transcript. Updating cell_labels and loc
CellContainer::build_A_matrix()          - Build matrix A based on gene dots in D            
CellContainer::build_Cell_objects()      - Build Cell objects based on X, A, seg, and cell centroids
CellContainer::build_Cell_objects()      - Iterating over cells:                             

In [160]:
# position in CellContainer_obj.cells (a list index, not a segmentation label)
cell_id = 600
all(CellContainer_obj.cells[cell_id].A.genes == CellContainer_obj.gene_names) # check that the cropped A of this cell keeps the container's gene names

True

In [178]:
# Stacked tensor built by Cell.build_crops_c: channel 0 is the binary mask,
# channel 1 the cropped image, the remaining channels the cropped matrix A.
CellContainer_obj.cells[cell_id].input_CNN

tensor([[[   0,    0,    0,  ...,    0,    0,    0],
         [   0,    0,    0,  ...,    0,    0,    0],
         [   0,    0,    0,  ...,    0,    0,    0],
         ...,
         [   0,    0,    0,  ...,    0,    0,    0],
         [   0,    0,    0,  ...,    0,    0,    0],
         [   0,    0,    0,  ...,    0,    0,    0]],

        [[ 985, 1004,  994,  ..., 1873, 1918, 1956],
         [ 994, 1000,  983,  ..., 1851, 1895, 1933],
         [1013, 1026, 1009,  ..., 1816, 1884, 1908],
         ...,
         [1018, 1001,  987,  ...,  930,  939,  931],
         [1012, 1002,  988,  ...,  904,  915,  931],
         [ 994, 1009, 1003,  ...,  906,  912,  908]],

        [[   0,    0,    0,  ...,    0,    0,    0],
         [   0,    0,    0,  ...,    0,    0,    0],
         [   0,    0,    0,  ...,    0,    0,    0],
         ...,
         [   0,    0,    0,  ...,    0,    0,    0],
         [   0,    0,    0,  ...,    0,    0,    0],
         [   0,    0,    0,  ...,    0,    0,    0]],

In [171]:
# NOTE(review): scratch arithmetic; the meaning of 756 and 11 is not stated in
# this notebook. Name these values or remove the cell.
756*11*11

91476