# Welcome to this single-cell alternating oxygen analysis notebook

This notebook is designed to perform segmentation and tracking of E. coli time-lapses and has been jointly developed by Keitaro Kasahara and Johannes Seiffarth 💪

The original paper can be found here: [Unveiling microbial single-cell growth dynamics under rapid periodic oxygen oscillations](https://doi.org/10.1039/D5LC00065C) where K. Kasahara et al. applied alternating oxygen conditions (between 0 and 21 %) at different switching rates to living E. coli cell populations.

In this notebook we extend the population-based analysis to a single-cell view and address the question of how individual cells react to the switches in oxygen conditions.

In the analysis we perform the following steps:

1. Perform segmentation on an omero sequence
2. Extracting individual cell information
3. Filtering cells based on their individual information to reduce the number of artifacts (including borders)
4. Perform tracking
5. Perform population based analysis
6. Perform single-cell based analysis


In [None]:
# Enable IPython's autoreload extension in mode 2, so that imported modules are
# reloaded automatically before code is executed.
%load_ext autoreload
%autoreload 2

In [None]:
# Install the pinned notebook dependencies, unless the environment variable
# JYPN_NO_DEP_INSTALL is set (any value disables the installation).
import os
if os.environ.get("JYPN_NO_DEP_INSTALL", None) is None:
    # remove any existing acia installation before installing the pinned version
    %pip uninstall acia -y
    %pip install acia==0.3.1
    
    # dependencies for segmentation
    %pip install torch torchvision torchaudio # --index-url https://download.pytorch.org/whl/cpu
    %pip install omnipose==1.0.6
    %pip install natsort
    %pip install scipy==1.11.4
    
    # dependencies for tracking
    %pip install "trackastra"
else:
    # the environment is expected to provide all requirements already
    print("Running in scaling mode! Do not install requirements!")

In [None]:
# User-editable parameters of the analysis.
import os
from pathlib import Path

# get the acia unit registry
from acia import ureg

# index of the phase-contrast channel (used for segmentation and tracking)
phase_contrast_channel = 0

# use current working directory as default storage folder for outputs
storage_folder = os.getcwd()

# size of a single pixel in the image
pixel_size = 0.074 * ureg.micrometer

# subsampling factor: makes analysis faster but leads to temporal resolution loss. Factor of 1 is the original imaging sequence.
subsampling_factor = 10

# define the imaging interval (a unit expression string; it is converted to a quantity in a later cell)
imaging_interval = "10 * second"

# define the number of frames (use None for the full sequence)
# NOTE(review): this variable is not referenced anywhere else in the visible notebook;
# `num_images` below appears to play that role - confirm whether it is still needed.
number_of_frames = None

# file name of the TIFF image stack to analyze
image_id = "34849.tif"

# number of frames used for the video timestamps; None is replaced by the full sequence length later
num_images = None

In [None]:
# Derive paths and the effective imaging interval, and fetch the example data if it is missing.
image_path = Path(image_id)

# correct the imaging interval with the subsampling factor
# NOTE(review): `imaging_interval` is overwritten with the scaled quantity, so re-running
# this cell would presumably apply the subsampling factor again - confirm.
imaging_interval = ureg.Quantity(imaging_interval) * subsampling_factor

# create the output directory
output_path = Path(storage_folder) / "output/"
output_path.mkdir(parents=True, exist_ok=True)

# make path relative (advantage in video embedding)
# (relative_to raises ValueError if storage_folder is not located below the current working directory)
output_path_rel = output_path.relative_to(Path(os.getcwd()))

if not image_path.exists():
    # download data if necessary
    # NOTE(review): the download target is hard-coded and only matches `image_path` for the default `image_id`
    !wget -O 34849.tif https://fz-juelich.sciebo.de/s/7w2atDj5YjfdPjZ/download

In [None]:
import torch
import logging

# Detect whether the notebook runs inside Google Colab by probing for the
# `google.colab` package.
try:
    import google.colab  # noqa: F401 (imported only to probe for Colab)
    IN_COLAB = True
except ImportError:
    # Only a failed import means "not in Colab". A bare `except:` would also
    # swallow KeyboardInterrupt/SystemExit raised while importing.
    IN_COLAB = False

# True if PyTorch can use a CUDA device
cuda = torch.cuda.is_available()

if not cuda:
    logging.warning("You are not using GPU computation. Thus the deep learning segmentation might take a while!")
    if IN_COLAB:
        logging.warning("Please go to 'Runtime > Change runtime type' in order to select a GPU based runtime in colab!")

In [None]:
# Load the TIFF image stack and wrap every `subsampling_factor`-th frame as the image source.
from pathlib import Path
import tifffile
from acia.segm.local import THWCSequenceSource
import numpy as np
from tqdm.auto import tqdm

# read the complete stack into memory
image_stack = tifffile.imread(image_path)

# bring the image stack into TxHxWxC (time, height, width, channels) format
# (only every `subsampling_factor`-th entry along the first axis is kept)
source = THWCSequenceSource(image_stack[::subsampling_factor,...])

# Information about the image stack

In [None]:
import matplotlib.pyplot as plt
from IPython.display import Video, Markdown, display

# number of (subsampled) time points and number of channels of the source
T = source.size_t
C = source.size_c

# default: use all frames of the sequence (e.g. for the video timestamps)
if num_images is None:
    num_images = T

# display markdown
display(Markdown("# Image information"))

table = f"""
| Value    | Content |
| --- | --- |
| Image Path | {image_path} |
| T Size | { T } |
| C Size | { C } |
| Channels | {','.join([f"{c}" for c in range(C)])} |
| Imaging Interval | {imaging_interval} |
| Pixel Size | {pixel_size} |
| Phase-Contrast Channel | {phase_contrast_channel} |
| Image dtype | {image_stack.dtype} |
"""

display(Markdown(table))
display(Markdown("## Preview of channels"))

# preview the frame in the middle of the sequence
t = T // 2
image = source.get_frame(t).raw

# squeeze=False always returns a 2D array of axes, so a single channel needs no special case
fig, ax = plt.subplots(1, C, figsize=(15, 15), squeeze=False)
for c in range(C):  # channel indices are zero-based
    loc_ax = ax[0, c]
    loc_ax.imshow(image[..., c], cmap="gray")
    loc_ax.set_title(f"Channel {c}, t: {t}")

plt.tight_layout()

# 1. Cell Segmentation

Now we specify the segmentation model: [Omnipose](https://doi.org/10.1101/2021.11.03.467199) and the channel we want to select to extract the image data. The channel data can be observed in the [Omero Web Viewer](http://ibt056.ibt.kfa-juelich.de:4080/). Please keep in mind that you have to enter the channel value+1 in `channels`. With the model and image sequence we kick off the segmentation.

In [None]:
# Segment all frames of the phase-contrast channel with Omnipose.
import torch
from acia.segm.processor.omnipose import OmniposeSegmenter

# instantiate the Omnipose segmentation model
# NOTE(review): the earlier comment called this a "remote" model - confirm where inference actually runs
model = OmniposeSegmenter()

# silence FutureWarnings emitted during prediction
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

# perform overlay prediction
print("Perform Prediction...")
# inference only: disable gradient tracking
with torch.no_grad():
  result = model(source.to_channel(phase_contrast_channel), omnipose_parameters=dict(batch_size=30))

To validate the segmentation result, we create a short video:

In [None]:
# Render the segmentation result (with timestamp and scalebar) into a video and show it inline.
import acia
from acia.viz import render_segmentation_mask, render_video, render_time, render_scalebar
import numpy as np
from acia import ureg

# video rendering configuration
video_config = dict(codec="vp9", ffmpeg_params = ["-crf", "30", "-b:v", "0", "-speed", "1"])

# scalebar placement
scalebar_config = dict(
    xy_position=(300, 625),
    size_of_pixel = pixel_size,
    bar_width=10 * ureg.micrometer, # width of the scalebar
    bar_height="1 micrometer" # height of the scalebar
)

# timestamp placement
time_config = dict(
    xy_position=(300, 50),
    timepoints=np.array(list(range(num_images))) * imaging_interval, # timepoints of the individual frames (with correct unit)
    background_color = (0, 0, 0),
)

# framerate of the video
framerate=20

# output file of the rendered segmentation video (relative path)
video_file = str(output_path_rel / "segmented.mp4")

# do the different rendering steps sequentially
source_rend = render_time(source.to_rgb(), **time_config)
source_rend = render_scalebar(source_rend, **scalebar_config)
source_rend = render_segmentation_mask(source_rend, result, alpha=0.5)
render_video(source_rend, filename=video_file, **video_config, framerate=framerate)

# Display the rendered segmentation
from IPython.display import Video, Markdown, display
display(Markdown("# Your segmentation"))

# NOTE(review): star import; later cells use `VideoFileClip` without importing it again
from moviepy.editor import *
myvideo =  VideoFileClip(video_file)
myvideo.ipython_display(maxduration=400)

# 2. Extracting individual cell properties

Now that we have the cell segmentation, we can move on and extract individual cell properties like Area, Time, Length, ....
and visualize them in a table:

In [None]:
# Extract per-cell properties from the segmentation into a dataframe.
from acia.analysis import ExtractorExecutor, AreaEx, IdEx, FrameEx, TimeEx, LengthEx, PositionEx
from acia import ureg
import numpy as np
import pint

ex = ExtractorExecutor()

df = ex.execute(result, source, [
    # define the cell properties that you want to extract here
    AreaEx(input_unit=pixel_size ** 2),  # pass the correct area of pixels
    FrameEx(),
    LengthEx(input_unit=pixel_size),  # pass the correct size of pixels
    PositionEx(input_unit=pixel_size),
    TimeEx(input_unit=imaging_interval),  # time between two (subsampled) frames
])

print(df)

# 3. Filtering artifacts in segmentation

In the segmentation, we can often observe artifacts, that is, objects that are mistakenly recognized as cells. To reduce the number of artifacts in our analysis we can utilize some simple filtering functionality for the area: We only keep all the objects that have an area between `min_area` and `max_area` as defined below in the code:

In [None]:
# Filter detections by area and by distance to the image border, plot the area
# distribution before and after filtering, and export the remaining cells as CSV.
import matplotlib.pyplot as plt

min_area = 0.7  # the minimal area in micrometer ** 2. All smaller objects are dropped
max_area = 15 # the maximal area in micrometer ** 2. All larger objects are dropped
# usually max 15

fig, ax = plt.subplots(2, 1, facecolor='white', figsize=(15,10))

# unit of the extracted area values (used for the axis labels)
area_unit = ex.units['area']

# plot the area distribution before filtering
ax[0].hist(df['area'], bins=100)
ax[0].set_title('Area distribution before filtering')
ax[0].set_ylabel('Frequency')
ax[0].set_xlabel(f'Cell area [${area_unit:~L}$]')
ax[0].set_yscale("log")

# filter by position: cell center should at least be .5 micrometer away from border
margin = .5
img = source.get_frame(0).raw
# NOTE: `left` and `top` are not used by the filter below, which compares against `margin` directly
left, top = 0,0
# image extent (height, width) converted from pixels to micrometer
bottom, right = np.array(img.shape[:2]) * pixel_size.to("micrometer").magnitude

# filter by cell area and drop cells closer than `margin` to any image border
filtered_df = df[(min_area < df['area']) & (df['area'] < max_area) & ~(df["position_x"] < margin) & ~(df["position_x"] > right - margin) & ~(df["position_y"] < margin) & ~(df["position_y"] > bottom - margin)]

# plot the area distribution after filtering
ax[1].hist(filtered_df['area'], bins=100)
ax[1].set_title('Area distribution after filtering')
ax[1].set_ylabel('Frequency')
ax[1].set_xlabel(f'Cell area [${area_unit:~L}$]')
ax[1].set_yscale("log")

plt.tight_layout()

# export with decimal . and separation ;
filtered_df.to_csv(str(output_path / 'allcells.csv'), decimal='.', sep=';')

print("Done")

And now let's look at the new video with filtered content

# 4. Render filtered segmentation

In [None]:
# Build an overlay containing only the filtered detections and render it as a video.
from acia.base import Overlay
import numpy as np

# ids in the filtered dataframe
# (presumably the dataframe is indexed by the contour ids - verify against the extractor output)
id_set = set(filtered_df.index)

# keep only the detections that passed the filtering
filtered_overlay = Overlay([c for c in result if c.id in id_set])

# output file of the filtered segmentation video (relative path)
video_file = str(output_path_rel / "filter_segmented.mp4")

# render timestamp, scalebar and segmentation masks one after the other
source_time = render_time(source.to_rgb(), (800, 50), timepoints=np.array(range(num_images)) * imaging_interval)
source_scalebar = render_scalebar(source_time, xy_position=(750, 1050), size_of_pixel = pixel_size, bar_width="10 micrometer", bar_height="1 micrometer")
source_segm = render_segmentation_mask(source_scalebar, filtered_overlay, alpha=0.5)
render_video(source_segm, filename=video_file, **video_config, framerate=framerate)

# display in markdown
display(Markdown("# Your filtered segmentation"))
myvideo = VideoFileClip(video_file)
# NOTE(review): unlike the first video cell, no `maxduration` is passed here - confirm the default limit suffices
myvideo.ipython_display()

# 5. Perform Tracking

In [None]:
# Track the filtered detections over time with Trackastra.
from acia.tracking.processor.trackastra import TrackastraTracker

tracker = TrackastraTracker()

# the tracker returns an overlay, a tracklet graph and a tracking graph
ov, tracklet_graph, tracking_graph = tracker(source.to_channel(phase_contrast_channel), filtered_overlay)

In [None]:
# size of the tracking graph: (number of nodes, number of edges)
tracking_graph.number_of_nodes(), tracking_graph.number_of_edges()

In [None]:
# Render the tracked masks and the tracking links on top of the phase-contrast images.
from acia.viz import render_tracking_mask, render_tracking

# first draw the masks of the tracked overlay, then the tracking graph on top
segm_sequence = render_tracking_mask(source.to_channel(phase_contrast_channel).to_rgb(), ov)
tracked_sequence = render_tracking(segm_sequence, ov, tracking_graph)

In [None]:
# Encode the rendered tracking sequence as a video file.
from acia.viz import render_video

# NOTE: this overrides the `video_config` of the segmentation videos (crf 36 instead of 30)
video_config = dict(codec="vp9", ffmpeg_params = ["-crf", "36", "-b:v", "0", "-speed", "1"])

video_file = str(output_path_rel / "tracking.mp4")
# NOTE(review): filename and framerate (20) are passed positionally here, whereas the
# other render_video calls use the `filename=` / `framerate=` keywords
render_video(tracked_sequence, video_file, 20, **video_config)

In [None]:
# display in markdown
display(Markdown("# Your tracked sequence"))
# `VideoFileClip` comes from the star import of moviepy.editor in an earlier cell
myvideo = VideoFileClip(video_file)
myvideo.ipython_display()

# 6. Tracking Analysis

# Now remove artifacts coming from inconsistent segmentation

Segmentation is especially inconsistent when cells divide, flickering between a single cell and two cells.
We merge such pairs of cell detections, pushing the cell division event to the latest possible time point.

In [None]:
from acia.base import Contour
from shapely.ops import unary_union
from shapely.geometry import MultiPolygon
import networkx as nx
from acia.base import Instance

def remove_incosistent_seg(sub_ov: Overlay, tracklet_graph: nx.DiGraph, num_timesteps=3):
    """Merge short-lived spurious sibling detections into the persisting cell.

    The tracklet graph is searched for divisions where one child tracklet has
    no successors and spans at most ``num_timesteps`` frames (``end_frame`` -
    ``start_frame``) while its sibling has at least one successor. In every
    frame in which both siblings are present, their polygons are merged into a
    single contour that carries the id and label of the continuing sibling.
    Contours of the dead-end sibling are dropped in all frames. Afterwards the
    dead-end tracklets are removed from the graph and tracklets that are
    connected by a single edge are collapsed into one label.

    Relies on ``np``, ``logging`` and ``Overlay`` being available as notebook
    globals (they are not imported in this cell).

    Args:
        sub_ov: overlay providing ``timeIterator()``; its items expose
            ``label``, ``id`` and ``polygon``.
        tracklet_graph: directed graph whose nodes are tracklet labels with
            ``start_frame`` and ``end_frame`` attributes. Modified in place.
        num_timesteps: maximal duration (in frames) of a dead-end tracklet that
            is still merged.

    Returns:
        Tuple ``(new_ov, tracklet_graph)``: the rebuilt overlay and the
        (in-place modified) tracklet graph.
    """

    def cond(n, graph, num_nodes=num_timesteps):
        """Return True if node ``n`` splits into a short dead end and a continuing tracklet."""
        if graph.out_degree(n) != 2:
            return False

        # sort so that children[0] is the child with fewer successors
        children = sorted(graph.successors(n), key=lambda v: graph.out_degree(v))

        # check whether we have a dead end and a continuous cell
        if not (graph.out_degree(children[0]) == 0 and graph.out_degree(children[1]) >= 1):
            return False

        # check that the dead end duration is not too long
        dur = graph.nodes[children[0]]["end_frame"] - graph.nodes[children[0]]["start_frame"]
        if dur > num_nodes:
            return False

        return True

    # collect all the siblings that should be joined
    # (each entry is [dead_end_label, continuing_label])
    to_join = []
    for n in tracklet_graph.nodes:
        # check the join condition
        if cond(n, tracklet_graph):

            children = sorted(tracklet_graph.successors(n), key=lambda v: tracklet_graph.out_degree(v))
            to_join.append(children)

    new_ov = Overlay([])

    # labels of the dead-end tracklets: their contours are dropped in every frame
    remove_labels = set([join_set[0] for join_set in to_join])

    # create the new overlay where masks are joined
    for i, ov in enumerate(sub_ov.timeIterator()):
        frame_label_set = set([it.label for it in ov])

        to_add = []
        to_remove = []

        for join_set in to_join:
            # only merge in frames where both siblings are present
            if set(join_set).issubset(frame_label_set):
                #print(f"Frame: {i} -> Need to change overlay")

                def label_lookup(label):
                    # first contour of the current frame with the given label
                    return [cont for cont in ov if cont.label == label][0]

                #print(ov.cont_lookup)
                # grow both polygons, unite them and shrink the union again so that neighboring masks fuse
                # NOTE(review): the dead-end polygon is grown by 2 but the union is shrunk by 5 - confirm the asymmetry is intended
                polys = [label_lookup(join_set[0]).polygon.buffer(2), label_lookup(join_set[1]).polygon.buffer(5)]

                res_poly = unary_union(polys)
                res_poly = res_poly.buffer(-5)

                if isinstance(res_poly, MultiPolygon):
                    # the union fell apart: keep only the largest part
                    area_before = res_poly.area
                    max_size_index = np.argmax([g.area for g in res_poly.geoms])
                    res_poly = res_poly.geoms[max_size_index]
                    logging.warning(f"Need to fix multipolygon. Area from {area_before} to {res_poly.area}")


                # this polygon needs to be added
                # (it reuses id and label of the continuing sibling; the positional -1 is
                # presumably a score placeholder - confirm against the Contour signature)
                cont = Contour(np.stack(res_poly.exterior.xy, axis=-1), -1, frame=i, id=label_lookup(join_set[1]).id, label=join_set[1])
                #print(cont.coordinates)
                to_add.append(cont)
                # the original contour of the continuing sibling is replaced by the merged one
                to_remove.append(join_set[1])

        all_remove = remove_labels.union(set(to_remove))
        new_ov.add_contours([cont for cont in ov if cont.label not in all_remove] + to_add)

    # remove joined labels from the tracklet graph
    remove_labels = set([join_set[0] for join_set in to_join])
    for n in remove_labels:
        tracklet_graph.remove_node(n)

    # join tracklets (we have removed wrong divisions but still need to join the tracklets)
    tracklets_to_join = []

    # depth-first preorder visits a parent before its children, so chains are resolved from their start
    for n in list(nx.dfs_preorder_nodes(tracklet_graph)):
        if tracklet_graph.out_degree(n) == 1:
            tracklets_to_join.append((n, list(tracklet_graph.successors(n))[0]))

    # maps every label to the label it is merged into (initially itself)
    relabel_actions = {n:n for n in tracklet_graph.nodes}

    for a,b in tracklets_to_join:

        # the child takes over the (possibly already relabeled) label of its parent
        relabel_actions[b] = relabel_actions[a]

    # actually join the tracklets
    # (b is the original label, a the label it is merged into; a == b for unchanged tracklets)
    for b,a in relabel_actions.items():

        # join the two
        b_children = tracklet_graph.successors(b)

        # ensure connectivity
        for b_child in b_children:
            tracklet_graph.add_edge(a, b_child)

        # update end frame
        tracklet_graph.nodes[a]["end_frame"] = np.max([tracklet_graph.nodes[a]["end_frame"], tracklet_graph.nodes[b]["end_frame"]])

    # remove nodes that have been merged into another label
    tracklet_graph.remove_nodes_from(set(tracklet_graph.nodes).difference(relabel_actions.values()))

    # apply the relabeling to the overlay
    for cont in new_ov:
        if isinstance(cont, Instance) and cont.label != relabel_actions[cont.label]:
            # mask-based instances: rewrite the pixels carrying the old label to the new label
            cont.mask = (cont.mask == cont.label) * relabel_actions[cont.label]
        cont.label = relabel_actions[cont.label]

    return new_ov, tracklet_graph

In [None]:
# Merge inconsistent detections using the implementation shipped with acia.
# NOTE: the `remove_incosistent_seg` function defined in the notebook is not called here.
from acia.tracking.utils import merge_incosistent_segmentation

new_ov, tracklet_graph = merge_incosistent_segmentation(ov, tracklet_graph)

In [None]:
# Re-extract the per-cell properties from the merged overlay, now including the tracking label.
from acia.analysis import ExtractorExecutor, AreaEx, IdEx, FrameEx, TimeEx, LengthEx, PositionEx, LabelEx
from acia import ureg
import numpy as np

ex = ExtractorExecutor()

# NOTE: this replaces the dataframe `df` created from the unfiltered segmentation
df = ex.execute(new_ov, source, [
    # define the cell properties that you want to extract here
    LabelEx(),
    AreaEx(input_unit=pixel_size ** 2),  # pass the correct area of pixels
    LengthEx(input_unit=pixel_size),  # pass the correct size of pixels
    PositionEx(input_unit=pixel_size),
    FrameEx(),
    TimeEx(input_unit=imaging_interval),  # time between two (subsampled) frames
])

In [None]:
# show the extracted per-cell dataframe
df

# 🔍 We are first investigating the temporal population development

In [None]:
# width of one shaded phase in the plots below, given in units of the time axis
# (presumably the oxygen switching interval in hours - confirm against the experiment)
switch_interval = 0.5

In [None]:
import seaborn as sns

fig, axes = plt.subplots(1, 1, figsize=(6, 6))

# summed area of all cells that belong to a tracklet, per frame / time point
sns.lineplot(df[df.label.isin(tracklet_graph.nodes)].groupby(["frame", "time"]).agg("sum"), x="time", y="area", ax=axes)
plt.yscale("log")

# alternating background colors for the consecutive phases
colors = ["green", "red"]
# plotted time range (in units of the time axis)
total_time = 3

# Shade the alternating phases. The phase start is computed from the phase
# index instead of repeated addition to avoid accumulating rounding errors.
for phase in range(int(np.ceil(total_time / switch_interval))):
    x_start = phase * switch_interval
    plt.axvspan(xmin=x_start, xmax=x_start + switch_interval, color=colors[phase % 2], alpha=0.2)

plt.xlim((0, total_time))

# raw string: "\m" is not a valid escape sequence in a normal string literal
plt.ylabel(r"Total Single-Cell Area [$\mu m^2$]")
plt.xlabel("Time [h]")
plt.grid(True)

plt.tight_layout()

plt.savefig(output_path / "tsca.png", dpi=300)

## Investigate the lineage

In [None]:
# Work on a copy so that the original tracklet graph stays untouched.
new_tracklet_graph = tracklet_graph.copy()

# Tracklets with neither a predecessor nor a successor are isolated and are dropped.
nodes_to_remove = [
    node
    for node in new_tracklet_graph
    if new_tracklet_graph.out_degree(node) == 0 and new_tracklet_graph.in_degree(node) == 0
]

new_tracklet_graph.remove_nodes_from(nodes_to_remove)

In [None]:
from acia.tracking.utils import tracklet_to_tracking

# Expand the cleaned tracklet graph into a tracking graph.
new_tracking = tracklet_to_tracking(new_ov, new_tracklet_graph)

# Attach time and label of the matching dataframe row to every node.
for node in new_tracking:
    node_attrs = new_tracking.nodes[node]
    node_attrs["time"] = df.loc[node]["time"]
    node_attrs["label"] = df.loc[node]["label"]

In [None]:
from acia.viz import hierarchy_pos_loop_multi, plot_lineage_tree

G = new_tracking

fig, axes = plt.subplots(1, 1, figsize=(10, 6))

# Shade the alternating phases in the background of the lineage plot.
colors = ["green", "red"]
total_time = 3
state = 0
x_start = 0

num_phases = int(np.ceil(total_time / switch_interval))
for _ in range(num_phases):
    x_end = x_start + switch_interval
    plt.axvspan(xmin=x_start, xmax=x_end, color=colors[state], alpha=0.2)

    # advance to the next phase and toggle the color
    x_start = x_end
    state = 1 - state

plt.xlim((0, 3))

# Lineage roots: nodes without predecessor that exist in the very first frame.
roots = [node for node, in_deg in G.in_degree() if in_deg == 0 and G.nodes[node]["frame"] == 0]
print(roots)
pos = hierarchy_pos_loop_multi(G, roots)
plot_lineage_tree(
    G,
    pos,
    mode='horizontal',
    draw_labels=False,
    tick_color=None,
    branch_color="black",
    flip_vertical=False,
    ax=axes,
    y_attr="time",
    label_attr="label",
)

axes.set_xlabel("Time [h]")

plt.tight_layout()

plt.savefig(output_path / "lineage.png", dpi=300)

In [None]:
# Draw the lineage of the tracked cells with births and track ends highlighted.
from acia.viz import plot_cell_lineage

# select the plotly renderer used for notebook output
import plotly.io as pio
pio.renderers.default = "notebook_connected"

# NOTE(review): `fig` is overwritten by the return value of plot_cell_lineage below;
# confirm whether this matplotlib figure is actually drawn on or left empty
fig, ax = plt.subplots(1, 1)

# Plotly interactive plot
fig = plot_cell_lineage(
    new_tracking, orientation='horizontal', time_feature='time',
    show_label=False, label_name='label',
    node_marker="o", node_ms=5,
    line_color="blue", line_lw=2,
    mark_births=True, birth_color="orange", birth_marker=">", birth_ms=7,
    mark_ends=True, end_color="black", end_marker='s', end_ms=7,
    #figure_title="Plotly interactive lineage"
)

## Investigate single-cell behavior

In [None]:
# First and last observation time of every cell label.
df_timings = df[["label", "time"]].groupby("label").agg(["min", "max"]).reset_index()

# cells living between 1.4h and 1.6h
appears_before = df_timings[("time", "min")] < 1.4
present_after = df_timings[("time", "max")] > 1.6
interesting_labels = df_timings[appears_before & present_after]["label"].to_list()

# Restrict the candidates to labels that occur in the tracking graph.
base_labels = {new_tracking.nodes[node]["label"] for node in new_tracking}
interesting_labels = list(set(interesting_labels).intersection(base_labels))
print("Interesting cell candidates: ", interesting_labels)

### Compute temporal single-cell development

In [None]:
# Compute a smoothed area curve and a smoothed area change rate (per hour) for every cell label.
from scipy.ndimage import gaussian_filter1d

df["deriv_filtered"] = np.nan
df["image_id"] = image_id

# standard deviation of the Gaussian smoothing kernel (in frames)
sigma = 4

# Compute derivative
for il in np.unique(df["label"]):
    # area values of this label in dataframe row order
    # (assumes the rows of a label are ordered by time - TODO confirm)
    y = np.array(df[df.label==il]["area"].to_list())
    y_filtered = np.array(gaussian_filter1d(y, sigma=sigma))
    # finite differences of the raw area, divided by the frame interval in hours, then smoothed
    derivative_filtered = gaussian_filter1d((y[1:] - y[:-1]) * (1/imaging_interval.to("hour").magnitude), sigma=sigma)

    # the first observation of a label has no predecessor, hence no derivative
    val = [np.nan] + list(derivative_filtered)

    df.loc[df.label==il, "deriv_filtered"] = val
    df.loc[df.label==il, "area_filtered"] = y_filtered


# export only the rows of the interesting cells
df[df.label.isin(interesting_labels)].to_csv(output_path / "interesting_single_cell.csv")

### Show single-cell development with oxygen fluctuations

In [None]:
import matplotlib.gridspec as gridspec

# One row per interesting cell: the left column shows the smoothed area, the
# right column the smoothed area change rate.
num_cells = len(interesting_labels)

# NOTE(review): the grid reserves 2 * num_cells rows although only the first
# num_cells rows are populated - kept unchanged to preserve the figure layout.
gs = gridspec.GridSpec(2 * num_cells, 2, hspace=0.0)
fig = plt.figure(figsize=(6, 2 * num_cells))

axes = np.zeros((num_cells, 2), dtype=object)

for i, l in enumerate(interesting_labels):
    axes[i, 0] = fig.add_subplot(gs[i, 0])
    axes[i, 1] = fig.add_subplot(gs[i, 1])

    # rows of the current cell (selected once and used for both plots)
    cell_df = df[df["label"] == l]
    sns.lineplot(cell_df, x="time", y="area_filtered", ax=axes[i, 0], marker="|", markeredgecolor="black")
    sns.lineplot(cell_df, x="time", y="deriv_filtered", ax=axes[i, 1], marker="|", markeredgecolor="black")

    # set the view limits once, after plotting
    axes[i, 0].set_xlim((1.3, 1.7))
    axes[i, 1].set_xlim((1.3, 1.7))

    axes[i, 0].set_ylim((3, 6.5))
    axes[i, 1].set_ylim((0, 14))

axes[0, 0].set_title("Single Cell Area")
axes[0, 1].set_title("Single Cell Instant growth rate")

# all rows except the last one: hide the x ticks and share the x axis with the bottom-left plot
for ax in axes[:-1].flatten():
    ax.tick_params(
        axis='x',          # changes apply to the x-axis
        which='both',      # both major and minor ticks are affected
        bottom=False,      # ticks along the bottom edge are off
        top=False,         # ticks along the top edge are off
        labelbottom=False) # labels along the bottom edge are off
    ax.sharex(axes[-1, 0])

for i in range(2):
    axes[-1, i].set_xlabel("Time [h]")


colors = ["green", "red"]
total_time = 3
num_phases = int(np.ceil(total_time / switch_interval))

# shade the alternating phases and enable the grid in every subplot
for ax in axes.flatten():
    for phase in range(num_phases):
        x_start = phase * switch_interval
        ax.axvspan(xmin=x_start, xmax=x_start + switch_interval, color=colors[phase % 2], alpha=0.2)

    ax.grid(True)

for i, l in enumerate(interesting_labels):
    # "\\mu": a plain "\m" is not a valid escape sequence in a normal (f-)string literal
    axes[i, 0].set_ylabel(f"Cell #{l}\n[$\\mu m^2$]")
    axes[i, 1].set_ylabel(r"[$\frac{\mu m^2}{h}$]")

plt.tight_layout()

plt.savefig(output_path / "sca.png", dpi=300)


## 🔁 Reproducibility Information

pip and conda environment details

In [None]:
# list the installed pip packages with their versions
%pip freeze

In [None]:
# export the definition of the active mamba/conda environment
%mamba env export