Tutorial 5: Fused Gromov-Wasserstein distances with segmented image files
=========================================================================

Fused Gromov-Wasserstein (Fused GW) is a variant of Gromov-Wasserstein that attaches a feature vector to each point of the two cells being compared. It looks for a matching of points that preserves the geometry of each cell, as ordinary GW does, and also pairs points with similar feature vectors.
In this tutorial, we apply Fused GW to two-dimensional segmented images acquired in multiple channels. The resulting alignment tends to match points that have similar intensities in each channel.
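To make the objective concrete, here is a minimal NumPy sketch of the cost that Fused GW assigns to a given coupling `T` between two cells. It assumes the standard Fused GW formulation with squared losses, and follows the convention of the POT library in which `alpha` weights the structural (GW) term and `1 - alpha` weights the feature term. CAJAL's internal conventions may differ, and `fgw_objective` is a hypothetical helper, not part of CAJAL.

```python
import numpy as np

def fgw_objective(C1, C2, F1, F2, T, alpha=0.5):
    """Fused GW cost of a coupling T (hypothetical sketch, squared losses).

    C1: (n, n) intracellular distances of cell 1; C2: (m, m) for cell 2.
    F1: (n, d) and F2: (m, d) feature vectors (e.g. channel intensities).
    T:  (n, m) coupling; T[i, j] is the mass sent from point i to point j.
    Assumed POT-style weighting: alpha on structure, 1 - alpha on features.
    """
    # Feature term: squared Euclidean distance between feature vectors,
    # weighted by the mass T moves from point i to point j.
    M = ((F1[:, None, :] - F2[None, :, :]) ** 2).sum(axis=-1)
    feature_cost = np.sum(M * T)
    # Structure term: sum over i, j, k, l of
    # (C1[i, k] - C2[j, l])**2 * T[i, j] * T[k, l].
    diff = C1[:, None, :, None] - C2[None, :, None, :]  # shape (n, m, n, m)
    structure_cost = np.einsum('ijkl,ij,kl->', diff ** 2, T, T)
    return (1 - alpha) * feature_cost + alpha * structure_cost

# Two 2-point "cells" matched point-to-point, each point carrying mass 1/2.
C1 = np.array([[0.0, 1.0], [1.0, 0.0]])
C2 = np.array([[0.0, 2.0], [2.0, 0.0]])
F = np.array([[0.0], [1.0]])
T = np.eye(2) / 2
print(fgw_objective(C1, C2, F, F, T, alpha=1.0))
```

Fused GW itself minimizes this cost over all couplings `T` whose row and column sums match the point weights of the two cells; the sketch only evaluates the cost of one fixed `T`.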

Here, rather than individual cells, we analyze gastruloid images; we still refer to each segmented object as a "cell" throughout. We start by reading the image data into memory and converting it into the format CAJAL expects. This file I/O code is specific to our dataset and will vary by application.

In [1]:
from os.path import join
import numpy as np
from tqdm import tqdm
import skimage.io as skio

def process_gastroid_images(data_path, dataset_dir, gastr_meta, im_fname_pre, im_fname_suf):
    """Load the phase, green and red channels and the mask polygon of each annotated gastruloid."""
    image_gastrs = []
    gastr_metadata_list = []

    for ann in tqdm(gastr_meta['annotations']):
        # Build the phase-contrast image path; the green and red channels share
        # the same filename with 'Phase' replaced by the channel name.
        im_phase_path = join(data_path, dataset_dir, im_fname_pre + ann['location']['XY'] + im_fname_suf)
        im_green_path = im_phase_path.replace('Phase', 'Green')
        im_red_path = im_phase_path.replace('Phase', 'Red')
        im_phase = skio.imread(im_phase_path)
        im_green = skio.imread(im_green_path)
        im_red = skio.imread(im_red_path)

        nrow, ncol = im_green.shape

        # Stack the channels into one array of shape (3, nrow, ncol).
        image_intensities = np.stack((im_phase, im_green, im_red), axis=0)
        # Mask polygon vertices as a (z, 2) array of (x, y) pairs.
        mask_verts = np.array([(pt['x'], pt['y']) for pt in ann['coordinates']])
        # Sanity check: every vertex must lie inside the image bounds.
        for row in mask_verts:
            assert 0 <= row[0] < nrow
            assert 0 <= row[1] < ncol

        image_gastrs.append((image_intensities, mask_verts))

        # Save metadata for later analysis.
        gastr_meta_i = {'clone': ann['tags'],
                        'location_XY': [ann['location']['XY']],
                        'location_Z': [ann['location']['Z']],
                        'time': [ann['location']['Time']],
                        'id': [ann['id']],
                        'dataset_id': [ann['datasetId']],
                        'dataset_dir': [dataset_dir],
                        'im_path': [im_green_path]}
        gastr_metadata_list.append(gastr_meta_i)

    return image_gastrs, gastr_metadata_list
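If you adapt this loader to your own data, it helps to know which fields it reads from each JSON annotation. The sketch below builds one made-up annotation with those fields and reproduces the path construction and vertex extraction, without reading any images. All values (the `'0001'` position, tags, ids, and the `'data'` root directory) are hypothetical placeholders.

```python
from os.path import join
import numpy as np

# Hypothetical annotation containing only the fields the loader reads;
# the values are made up for illustration.
ann = {
    'location': {'XY': '0001', 'Z': 0, 'Time': 0},
    'coordinates': [{'x': 10, 'y': 12}, {'x': 40, 'y': 15}, {'x': 25, 'y': 50}],
    'tags': ['clone_A'],
    'id': 'ann-001',
    'datasetId': 'ds-001',
}

# Filenames are prefix + XY position + suffix; the green and red channel
# paths are derived by swapping 'Phase' for the channel name.
phase_path = join('data', 'Ewx_eWx_ewX_fixed/',
                  'Phase_EwxeWxewX_' + ann['location']['XY'] + '_1_00d00h00m.tif')
green_path = phase_path.replace('Phase', 'Green')
red_path = phase_path.replace('Phase', 'Red')

# The mask polygon becomes a (z, 2) array of (x, y) vertices.
mask_verts = np.array([(pt['x'], pt['y']) for pt in ann['coordinates']])

print(green_path)
print(mask_verts.shape)
```

Note that `str.replace` substitutes every occurrence of `'Phase'` in the full path, so a data directory whose name contains `Phase` would break this naming scheme.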

In [2]:
import os
import json
data_path = "/home/jovyan/dropbox/Projects/PGC022_Spatial_Protein_GW/Data/Input Data - May 2024"
dataset_dir = 'Ewx_eWx_ewX_fixed/'
im_fname_pre = 'Phase_EwxeWxewX_'
im_fname_suf = '_1_00d00h00m.tif'
# Read gastruloid metadata and mask info from the JSON file in the dataset directory.
json_fname = next(fname for fname in os.listdir(join(data_path, dataset_dir)) if fname.endswith('.json'))
json_path = join(data_path, dataset_dir, json_fname)
with open(json_path) as json_data:
    gastr_meta = json.load(json_data)
# Process the gastruloid images.
image_gastrs, gastr_metadata_list = process_gastroid_images(
    data_path, dataset_dir, gastr_meta, im_fname_pre, im_fname_suf)

100%|██████████| 31/31 [00:01<00:00, 24.01it/s]


`gastr_metadata_list` is not needed for Fused GW itself, but will be useful later in the analysis. The important list is `image_gastrs`, which holds what we need for each cell. Each element of `image_gastrs` is a pair `(image_intensities, vertex_coords)`, where `image_intensities` is an array of shape `(3, nrow, ncol)` containing the three channels (phase, green and red) of an `nrow x ncol` image, and `vertex_coords` is an array of shape `(z, 2)` listing the vertices of the polygon that bounds the cell, with `z` the number of vertices.
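As a concrete illustration of this layout, the sketch below builds a synthetic pair with random intensities and a made-up triangular mask, then uses `skimage.draw.polygon` to collect one feature vector (one intensity per channel) for every pixel inside the polygon. Treating the first vertex coordinate as the row index mirrors the bounds check in `process_gastroid_images`; whether `CellImage` rasterizes the mask this way is an assumption.

```python
import numpy as np
from skimage.draw import polygon

nrow, ncol = 64, 48
rng = np.random.default_rng(0)

# Three channels (phase, green, red) of an nrow x ncol image.
image_intensities = rng.random((3, nrow, ncol))
# Hypothetical triangular mask: z = 3 vertices, one coordinate pair per row.
vertex_coords = np.array([[10, 5], [50, 10], [30, 40]])

# Rasterize the polygon, using the first coordinate as the row index.
rr, cc = polygon(vertex_coords[:, 0], vertex_coords[:, 1], shape=(nrow, ncol))

# One feature vector per pixel inside the mask: shape (n_pixels, 3).
features = image_intensities[:, rr, cc].T
print(image_intensities.shape, vertex_coords.shape, features.shape)
```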

CAJAL's basic class for applying Fused GW to cell images is `CellImage`; its constructor takes `image_intensities` and `vertex_coords` as arguments.

In [3]:
from cajal.sample_seg import CellImage

cell_list = list(tqdm(
    (CellImage(image_intensities, vertex_list, downsample=8)
     for (image_intensities, vertex_list) in image_gastrs),
    total=len(image_gastrs)))

100%|██████████| 31/31 [00:26<00:00,  1.16it/s]


Some cells are too small to give meaningful results, so we filter them out: we discard every cell with fewer than five sample points, that is, every cell whose distance matrix has fewer than five rows.

In [4]:
cells = [ (cell, metadata) for (cell, metadata) in zip(cell_list, gastr_metadata_list) if cell.distance_matrix.shape[0] >= 5]
cell_list, gastr_metadata_list = list(zip(*cells))
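The filter keys on `cell.distance_matrix.shape[0]`, the number of points sampled from a cell. The sketch below shows why small cells end up with few points. It is only a guess at what `downsample=8` does: it assumes the mask is sampled on a regular grid with a stride of 8 pixels and that the distance matrix holds Euclidean distances between the sampled pixels. `sampled_distance_matrix` is hypothetical and is not CAJAL's implementation.

```python
import numpy as np

def sampled_distance_matrix(mask, stride):
    """Hypothetical stand-in for CellImage.distance_matrix (not CAJAL's code).

    Keeps the pixels of a boolean mask that lie on a grid with the given
    stride (an assumed meaning of `downsample`) and returns their pairwise
    Euclidean distances in pixel units.
    """
    rows, cols = np.nonzero(mask[::stride, ::stride])
    points = np.column_stack((rows, cols)) * stride  # grid indices -> pixels
    diffs = points[:, None, :] - points[None, :, :]
    return np.linalg.norm(diffs, axis=-1)

big = np.zeros((100, 100), dtype=bool)
big[10:90, 10:90] = True      # 80 x 80 pixel square
small = np.zeros((100, 100), dtype=bool)
small[40:48, 40:48] = True    # 8 x 8 pixel square

for name, mask in [('big', big), ('small', small)]:
    D = sampled_distance_matrix(mask, stride=8)
    print(name, D.shape, 'kept' if D.shape[0] >= 5 else 'discarded')
```

Under these assumptions the 80 x 80 square keeps 100 grid points, while the 8 x 8 square keeps a single point and would be discarded by the filter above.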

Once all cells are loaded, we normalize the intensities so that, pooled across all cells, the pixel intensities in each channel have a standard deviation of one.

In [5]:
from cajal.sample_seg import normalize_across_cells
normalize_across_cells(cell_list)
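The scaling step can be sketched as follows, assuming each cell's features are stored as an `(n_points, n_channels)` array and that the normalization only rescales each channel (it does not subtract a mean). `normalize_channels` is a hypothetical helper; `normalize_across_cells` may work differently internally, for example by modifying the cells in place.

```python
import numpy as np

def normalize_channels(feature_list):
    """Rescale every channel so its pooled standard deviation is 1.

    feature_list: list of (n_points_i, n_channels) arrays, one per cell.
    Returns new arrays; the inputs are left unchanged.
    """
    pooled = np.concatenate(feature_list, axis=0)  # all points from all cells
    scale = pooled.std(axis=0)                     # one std per channel
    return [f / scale for f in feature_list]

rng = np.random.default_rng(1)
# Three synthetic cells with very different spreads in each channel.
cells = [rng.normal(scale=(2.0, 5.0, 0.5), size=(n, 3)) for n in (10, 20, 30)]
normed = normalize_channels(cells)
print(np.concatenate(normed, axis=0).std(axis=0))
```

Pooling before computing the standard deviation keeps the relative intensity differences between cells intact, which a per-cell normalization would erase.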

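To compute the Fused GW distance between two cells, call `fused_gromov_wasserstein`. The argument `channels=(1,2)` selects the green and red channels (indices 1 and 2 of the `(phase, green, red)` stack) as the features, leaving out the phase channel.

In [6]:
from cajal.sample_seg import fused_gromov_wasserstein

fgw = fused_gromov_wasserstein(cell_list[0], cell_list[1], channels=(1, 2), log=True)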

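To compare many cells at once, `fused_gromov_wasserstein_parallel` computes the Fused GW distance between every pair of cells in a list and returns a symmetric distance matrix. For the first five cells this means 5·4/2 = 10 comparisons, matching the progress bar below. The `alpha` parameter presumably sets the trade-off between the feature term and the structural (Gromov-Wasserstein) term, as in the standard Fused GW formulation.

In [7]: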
from cajal.sample_seg import fused_gromov_wasserstein_parallel

fused_gromov_wasserstein_parallel(cell_list[:5], channels=(1,2), alpha=0.5)

100%|██████████| 10/10 [00:21<00:00,  2.10s/it]


array([[   0.        ,  749.52240592,  313.31525804,  430.54228087,
         312.05444085],
       [ 749.52240592,    0.        ,  210.82831396, 1606.5842935 ,
         718.72868684],
       [ 313.31525804,  210.82831396,    0.        ,  787.26256275,
         259.11088005],
       [ 430.54228087, 1606.5842935 ,  787.26256275,    0.        ,
         261.22256883],
       [ 312.05444085,  718.72868684,  259.11088005,  261.22256883,
           0.        ]])
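The 5 x 5 result is symmetric with zeros on the diagonal, so only the 10 entries above the diagonal need to be computed. A minimal sketch of that strategy is below, using a thread pool and a toy one-dimensional distance in place of Fused GW. `pairwise_matrix` is hypothetical, and CAJAL may parallelize differently (for example with processes rather than threads).

```python
from concurrent.futures import ThreadPoolExecutor
from itertools import combinations

import numpy as np

def pairwise_matrix(items, dist, max_workers=4):
    """Symmetric distance matrix, calling dist once per unordered pair."""
    n = len(items)
    pairs = list(combinations(range(n), 2))  # n * (n - 1) / 2 pairs
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        values = list(pool.map(lambda ij: dist(items[ij[0]], items[ij[1]]), pairs))
    D = np.zeros((n, n))  # diagonal stays 0: each item is at distance 0 from itself
    for (i, j), v in zip(pairs, values):
        D[i, j] = D[j, i] = v  # mirror the upper triangle
    return D

# Toy stand-in for five cells: points on a line with absolute-difference distance.
D = pairwise_matrix([0.0, 1.0, 3.0, 6.0, 10.0], lambda a, b: abs(a - b))
print(D)
```

Computing only the upper triangle halves the number of expensive distance evaluations compared with filling the full matrix.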