# Identifying Trajectories with MAGIK

This notebook provides you with a complete code example to track the motion of cells using a simplified version of MAGIK.

## Exploring the Cell Tracking Data

Download the DIC-C2DH-HeLa dataset from the Cell Tracking Challenge ...

In [1]:
import os
from torchvision.datasets.utils import download_url, _extract_zip

dataset_name = "DIC-C2DH-HeLa"
dataset_path = os.path.join(".", "cell_detection_dataset")

# Fetch and unpack the archive only on the first run; the extracted folder
# doubles as the "already downloaded" marker.
if not os.path.exists(dataset_path):
    zip_name = f"{dataset_name}.zip"
    url = "http://data.celltrackingchallenge.net/training-datasets/" + zip_name
    download_url(url, ".")
    _extract_zip(zip_name, dataset_path, None)
    os.remove(zip_name)

... load the training images and their segmentation masks ...

In [2]:
import cv2, glob

def load_images(path):
    """Load all ``.tif`` images in a directory, sorted by file path.

    Parameters
    ----------
    path : str
        Directory containing the ``.tif`` files.

    Returns
    -------
    list of numpy.ndarray
        Images read with ``cv2.IMREAD_UNCHANGED`` (original bit depth is
        kept). Empty if the directory contains no ``.tif`` files.

    Raises
    ------
    IOError
        If OpenCV cannot decode one of the files.
    """
    images = []
    for file in sorted(glob.glob(os.path.join(path, "*.tif"))):
        image = cv2.imread(file, cv2.IMREAD_UNCHANGED)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise IOError(f"Could not read image file: {file}")
        images.append(image)
    return images

# Sequence "02" is used for training: raw frames plus the segmentation
# masks stored under 02_ST/SEG.
train_image_path = os.path.join(dataset_path, "DIC-C2DH-HeLa", "02")
train_seg_path = os.path.join(dataset_path, "DIC-C2DH-HeLa", "02_ST", "SEG")

train_images = load_images(train_image_path)
train_segs = load_images(train_seg_path)

... and visualize some of the images and corresponding segmentations.

In [None]:
import matplotlib.pyplot as plt

# Show about five evenly spaced frames. The step is at least 1 so short
# sequences (fewer than 5 frames) do not make range() fail.
frame_step = max(1, len(train_segs) // 5)
frames_to_plot = list(range(0, len(train_segs), frame_step))

# squeeze=False keeps `axs` 2-D even when only one column is plotted.
fig, axs = plt.subplots(2, len(frames_to_plot), figsize=(20, 6),
                        squeeze=False)
fig.patch.set_facecolor("white")

# One row per image stack: (images, colormap, row label).
rows = [(train_images, "gray", "Intensity image"),
        (train_segs, "tab20b", "Segmentation")]
for col, frame in enumerate(frames_to_plot):
    axs[0, col].set_title(f"Frame {frame}", fontsize=16)
    for row, (stack, cmap, ylabel) in enumerate(rows):
        ax = axs[row, col]
        ax.imshow(stack[frame], cmap=cmap)
        # Hide ticks and tick labels but keep the axes so the ylabel shows.
        ax.tick_params(axis="both", which="both", bottom=False, top=False,
                       left=False, right=False, labelleft=False,
                       labelbottom=False)
        if col == 0:
            ax.set_ylabel(ylabel, fontsize=16)
plt.subplots_adjust(wspace=0.02, hspace=0.02)
plt.show()

## Creating a Graph From Segmented Images

Implement a class to obtain a graph from the segmentations ...

In [4]:
import numpy as np
import torch
from skimage import measure
from torch_geometric.data import Data

class GraphFromSegmentations:
    """Build a graph describing the motion of cells across frames.

    Each node is a segmented cell in one frame, with features
    ``[row, col, eccentricity]``, where the centroid is normalized by the
    image shape. Directed edges link a cell to cells in later frames that
    lie within ``connectivity_radius`` (normalized units) and at most
    ``max_frame_distance`` frames ahead.
    """

    def __init__(self, connectivity_radius, max_frame_distance):
        """Initialize graph construction parameters.

        Parameters
        ----------
        connectivity_radius : float
            Maximum centroid distance, in normalized image coordinates,
            for two cells to be connected.
        max_frame_distance : int
            Maximum number of frames an edge may span.
        """
        self.connectivity_radius = connectivity_radius
        self.max_frame_distance = max_frame_distance

    def compute_node_attr(self, segmentation):
        """Compute node attributes for every labeled cell in one frame.

        Returns
        -------
        node_attr : list of [float, float, float]
            Normalized centroid (row, col) and eccentricity of each cell.
        indices : list of int
            Segmentation label of each cell, in ascending label order.
        """
        # A single regionprops call on the label image measures every cell
        # at once. Label 0 (background) is skipped by regionprops itself,
        # so no background label needs to be present in the frame.
        regions = measure.regionprops(
            np.asarray(segmentation).astype(np.int32)
        )
        shape = np.array(segmentation.shape)

        node_attr, indices = [], []
        for props in regions:
            centroid = np.asarray(props.centroid) / shape
            node_attr.append([*centroid, props.eccentricity])
            indices.append(props.label)
        return node_attr, indices

    def compute_connectivity(self, x, frames):
        """Compute directed edges between nearby cells in later frames.

        An edge (i, j) exists when the centroid distance between nodes i
        and j is below ``connectivity_radius`` and
        ``0 < frames[j] - frames[i] <= max_frame_distance``.

        Returns
        -------
        edge_index : numpy.ndarray of shape (E, 2)
            (sender, receiver) node indices.
        edge_attr : numpy.ndarray of shape (E,)
            Centroid distance of each edge, in the same order.
        """
        positions = x[:, :2]
        distances = np.linalg.norm(positions[:, None] - positions, axis=-1)

        # frame_diff[i, j] = frames[j] - frames[i]: positive when j is later.
        frame_diff = frames[None, :] - frames[:, None]

        mask = ((distances < self.connectivity_radius)
                & (frame_diff > 0)
                & (frame_diff <= self.max_frame_distance))

        # argwhere and boolean indexing both traverse `mask` in row-major
        # order, so edge_attr[k] belongs to edge_index[k].
        edge_index = np.argwhere(mask)
        edge_attr = distances[mask]
        return edge_index, edge_attr

    def compute_ground_truth(self, indices, edge_index, relation):
        """Label each edge as a true link or not.

        An edge is a true link when sender and receiver carry the same
        label (same cell in another frame), or when (sender label,
        receiver label) equals (last column, first column) of a row of
        ``relation`` whose last column is non-zero.

        Parameters
        ----------
        indices : numpy.ndarray
            Segmentation label of every node.
        edge_index : numpy.ndarray of shape (E, 2)
            (sender, receiver) node indices.
        relation : numpy.ndarray
            Lineage table, e.g. the rows of a Cell Tracking Challenge
            ``man_track.txt`` (label, begin, end, parent). A single row
            given as a 1-D array is also accepted.

        Returns
        -------
        numpy.ndarray of bool, shape (E,)
        """
        sender = indices[edge_index[:, 0]]
        receiver = indices[edge_index[:, 1]]
        self_connections_mask = sender == receiver

        # np.loadtxt returns a 1-D array for a one-line file.
        relation = np.atleast_2d(relation)
        relation_indices = relation[:, [-1, 0]]
        relation_indices = relation_indices[relation_indices[:, 0] != 0]

        # Set lookup instead of scanning the whole table for every edge.
        linked_pairs = {tuple(pair) for pair in relation_indices.tolist()}
        relation_mask = np.fromiter(
            ((s, r) in linked_pairs
             for s, r in zip(sender.tolist(), receiver.tolist())),
            dtype=bool, count=len(edge_index),
        )

        return self_connections_mask | relation_mask

    def __call__(self, segmentations, relation):
        """Build a graph from a sequence of per-frame segmentations.

        Parameters
        ----------
        segmentations : sequence of 2-D label images
            One segmentation per frame, in temporal order.
        relation : numpy.ndarray
            Lineage table forwarded to ``compute_ground_truth``.

        Returns
        -------
        Data
            With node features ``x`` (N, 3), ``edge_index`` (2, E),
            ``edge_attr`` and ``distance`` (E, 1), node ``frames`` (N,)
            and edge labels ``y`` (E, 1).
        """
        x, node_index_labels, frames = [], [], []
        for frame, segmentation in enumerate(segmentations):
            features, index_labels = self.compute_node_attr(segmentation)
            x.extend(features)
            node_index_labels.extend(index_labels)
            frames.extend([frame] * len(features))
        # Flat lists with explicit shapes tolerate frames without any cell,
        # which made np.concatenate fail on mismatched dimensions.
        # 3 = (row, col, eccentricity) from compute_node_attr.
        x = np.asarray(x, dtype=float).reshape(-1, 3)
        node_index_labels = np.asarray(node_index_labels, dtype=np.int64)
        frames = np.asarray(frames, dtype=np.int64)

        edge_index, edge_attr = self.compute_connectivity(x, frames)
        edge_ground_truth = self.compute_ground_truth(
            node_index_labels, edge_index, relation
        )

        # Graph layout: edge_index is (2, E); per-edge values are (E, 1).
        edge_index = edge_index.T
        edge_attr = edge_attr[:, None]
        edge_ground_truth = edge_ground_truth[:, None]

        graph = Data(
            x=torch.tensor(x, dtype=torch.float),
            edge_index=torch.tensor(edge_index, dtype=torch.long),
            edge_attr=torch.tensor(edge_attr, dtype=torch.float),
            distance=torch.tensor(edge_attr, dtype=torch.float),
            frames=torch.tensor(frames, dtype=torch.float),
            y=torch.tensor(edge_ground_truth, dtype=torch.float),
        )
        return graph

... instantiate it ...

In [5]:
# The radius is in normalized image coordinates (centroids are divided by
# the image shape); edges may span one or two frames.
graph_constructor = GraphFromSegmentations(
    connectivity_radius=0.2,
    max_frame_distance=2,
)

... use it to construct the training graph ...

In [6]:
# Lineage table of the training sequence, used to label true links.
train_file = os.path.join(
    dataset_path, "DIC-C2DH-HeLa", "02_GT", "TRA", "man_track.txt"
)
train_relation = np.loadtxt(train_file, dtype=int)
train_graph = graph_constructor(segmentations=train_segs,
                                relation=train_relation)

... and plot the graph.

In [None]:
fig, ax = plt.subplots(figsize=(8, 8))

# Each edge is a segment between the (col, row) centroids of its nodes.
for sender, receiver in train_graph.edge_index.T:
    ax.plot([train_graph.x[sender, 1], train_graph.x[receiver, 1]],
            [train_graph.x[sender, 0], train_graph.x[receiver, 0]],
            c="k", alpha=0.5)

# Nodes on top of the edges, colored by frame.
nodes = ax.scatter(train_graph.x[:, 1], train_graph.x[:, 0],
                   c=train_graph.frames, cmap="viridis", zorder=10)
cb = fig.colorbar(nodes, ax=ax)
cb.ax.set_title("Frame", fontsize=14)
ax.set_xlabel("x", fontsize=14)
ax.set_ylabel("y", fontsize=14)
plt.show()

## Building a Training Dataset

Implement a class to prepare the graph dataset ...

In [8]:
class CellTracingDataset(torch.utils.data.Dataset):
    """Dataset of random temporal windows cut from one large cell graph.

    Each item is the subgraph induced by the nodes of ``Dt`` consecutive
    frames, with the window chosen uniformly at random. ``idx`` is
    ignored, so ``dataset_size`` only sets how many windows one epoch has.
    """

    def __init__(self, graph, Dt, dataset_size, transform=None):
        """Initialize the graph dataset.

        Parameters
        ----------
        graph : Data
            Full graph with ``x``, ``edge_index`` (2, E), ``edge_attr``,
            ``frames`` (N,) and ``y``.
        Dt : int
            Number of consecutive frames per sampled window.
        dataset_size : int
            Value reported by ``len``.
        transform : callable, optional
            Applied to each sampled subgraph (e.g. data augmentation).
        """
        self.graph, self.Dt, self.dataset_size, self.transform = \
            graph, Dt, dataset_size, transform

        frames, edge_index = graph.frames, graph.edge_index
        # Frames of the sender and receiver of every edge, shape (E, 2).
        self.pair_frames = torch.stack(
            [frames[edge_index[0, :]], frames[edge_index[1, :]]], axis=1
        )
        self.frames, self.max_frame = frames, frames.max()

    def __len__(self):
        """Obtain length of dataset."""
        return self.dataset_size

    def __getitem__(self, idx):
        """Sample a random window of ``Dt`` frames as a standalone graph.

        The window covers frames ``[frame_idx - Dt, frame_idx)``. Edge
        indices are remapped so that they index the window's own ``x``.
        """
        # frame_idx ranges over [Dt, max_frame + 1] so every frame,
        # including the last one (max_frame), can appear in a window.
        frame_idx = np.random.randint(self.Dt, int(self.max_frame) + 2)
        start_frame = frame_idx - self.Dt

        node_mask = (self.frames >= start_frame) & (self.frames < frame_idx)
        x = self.graph.x[node_mask]

        edge_mask = ((self.pair_frames >= start_frame)
                     & (self.pair_frames < frame_idx)).all(axis=1)

        # Map global node ids to local ids. Subtracting edge_index.min()
        # is wrong when the window's first node(s) have no edges, and it
        # fails on an empty edge set.
        node_ids = torch.nonzero(node_mask, as_tuple=True)[0]
        global_to_local = torch.full((node_mask.numel(),), -1,
                                     dtype=torch.long)
        global_to_local[node_ids] = torch.arange(node_ids.numel())
        edge_index = global_to_local[self.graph.edge_index[:, edge_mask]]

        edge_attr = self.graph.edge_attr[edge_mask]
        ground_truth_edges = self.graph.y[edge_mask]

        graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
                     distance=edge_attr, y=ground_truth_edges)
        if self.transform:
            graph = self.transform(graph)
        return graph

... implement classes that apply a random rotation and a random flip to the node positions (data augmentation) ...

In [9]:
from math import cos, pi, sin

class RandomRotation:
    """Rotate node positions by a uniformly random angle about (0.5, 0.5).

    Positions are the first two columns of ``graph.x`` (normalized image
    coordinates); other node features are left unchanged.
    """

    def __call__(self, graph):
        """Return a rotated copy of ``graph``; the input is not modified."""
        graph = graph.clone()
        positions = graph.x[:, :2] - 0.5  # Centered positions.
        angle = np.random.rand() * 2 * pi
        # Build R with the dtype/device of x so matmul also works for
        # float64 or GPU node features.
        R = torch.tensor([[cos(angle), -sin(angle)], [sin(angle), cos(angle)]],
                         dtype=positions.dtype, device=positions.device)
        rotated_positions = torch.matmul(positions, R)
        graph.x[:, :2] = rotated_positions + 0.5  # Restored positions.
        return graph
    
class RandomFlip:
    """Randomly mirror node positions about the image center.

    Each of the two position axes (the first two columns of ``graph.x``)
    is independently reflected around 0.5 with probability 1/2. Axis 0 is
    decided before axis 1.
    """

    def __call__(self, graph):
        """Return a flipped copy of ``graph``; the input is not modified."""
        graph = graph.clone()
        centered = graph.x[:, :2] - 0.5
        for axis in (0, 1):
            if np.random.randint(2):
                centered[:, axis] *= -1
        graph.x[:, :2] = centered + 0.5
        return graph

... create the training dataset ...

In [10]:
from torchvision.transforms import Compose

# Augment every sampled window with a random rotation followed by a flip.
augmentation = Compose([RandomRotation(), RandomFlip()])
train_set = CellTracingDataset(train_graph, Dt=5, dataset_size=512,
                               transform=augmentation)

... and define the data loaders.

In [None]:
# The graph DataLoader lives in torch_geometric.loader; the old
# torch_geometric.data.DataLoader is a deprecated alias.
from torch_geometric.loader import DataLoader

train_loader = DataLoader(train_set, batch_size=8, shuffle=True)

## Making MAGIK

Define the MAGIK model ...

In [None]:
import deeplay as dl

model = dl.GraphToEdgeMAGIK([96] * 4, 1, out_activation=torch.nn.Sigmoid)

# Both encoders get the same MLP: hidden sizes 32 and 64, 96 outputs, ReLU.
for encoder_index in (0, 1):
    model.encoder[encoder_index].configure(
        hidden_features=[32, 64], out_features=96,
        out_activation=torch.nn.ReLU,
    )
model.head.configure(hidden_features=[64, 32])

print(model)

... train the MAGIK model ...

In [None]:
# Wrap MAGIK as a binary edge classifier trained with Adam.
classifier = dl.BinaryClassifier(model=model,
                                 optimizer=dl.Adam(lr=1e-3)).create()

trainer = dl.Trainer(max_epochs=10)
trainer.fit(classifier, train_loader)

## Evaluating Performance

Load the test data ...

In [14]:
# Sequence "01" is held out for testing.
test_image_path, test_seg_path = (
    os.path.join(dataset_path, "DIC-C2DH-HeLa", "01"),
    os.path.join(dataset_path, "DIC-C2DH-HeLa", "01_ST", "SEG"),
)

test_images = load_images(test_image_path)
test_segs = load_images(test_seg_path)

... construct the test graph ...

In [15]:
# Lineage table of the test sequence, used to label true links.
test_file = os.path.join(
    dataset_path, "DIC-C2DH-HeLa", "01_GT", "TRA", "man_track.txt"
)
test_relation = np.loadtxt(test_file, dtype=int)
test_graph = graph_constructor(segmentations=test_segs,
                               relation=test_relation)

... assess the model performance with the F1-score ...

In [None]:
from sklearn.metrics import f1_score

classifier.eval()
# Inference only: skip autograd bookkeeping for the whole test graph.
with torch.no_grad():
    edge_probabilities = classifier(test_graph)
# Boolean link prediction per edge (presumably shaped (E, 1) like
# test_graph.y, since the model has a single output feature).
predictions = edge_probabilities.detach().cpu().numpy() > 0.5

ground_truth = test_graph.y

# f1_score expects 1-D label vectors.
score = f1_score(ground_truth.numpy().ravel(), predictions.ravel())
print(f"Test F1 score: {score}")

... implement a class to compute trajectories from MAGIK results ...

In [17]:
import networkx as nx

class ComputeTrajectories:
    """Turn MAGIK edge predictions into cell trajectories.

    Each node keeps at most one outgoing predicted link; trajectories are
    the connected components of the resulting undirected graph.
    """

    def __call__(self, graph, predictions):
        """Compute trajectories.

        Parameters
        ----------
        graph : Data
            Graph with ``edge_index`` (2, E) and node ``frames`` (N,).
        predictions : array-like of bool, shape (E,) or (E, 1)
            Predicted link for each edge of ``graph``.

        Returns
        -------
        list of set
            Node indices of each trajectory. Nodes without any kept link
            do not appear.
        """
        pruned_edges = self.prune_edges(graph, predictions)
        pruned_graph = nx.Graph()
        pruned_graph.add_edges_from(pruned_edges)
        trajectories = list(nx.connected_components(pruned_graph))
        return trajectories

    def prune_edges(self, graph, predictions):
        """Keep, for each sender node, its single closest-in-time link.

        Among the predicted outgoing edges of a node, only those with the
        smallest frame gap are considered. The link is kept only if exactly
        one edge has that gap; ambiguous cases (e.g. two candidates in the
        same next frame) are dropped.

        Returns
        -------
        list of (int, int)
            Kept (sender, receiver) node pairs, ordered by sender.
        """
        # Work entirely in numpy instead of mixing torch and numpy indexing.
        edge_index = np.asarray(graph.edge_index)
        frames = np.asarray(graph.frames)
        predicted = np.asarray(predictions, dtype=bool).reshape(-1)

        senders, receivers = edge_index[0], edge_index[1]
        frame_gaps = frames[receivers] - frames[senders]

        pruned_edges = []
        for src_cell in np.unique(senders):
            candidates = np.flatnonzero((senders == src_cell) & predicted)
            if candidates.size == 0:
                continue
            gaps = frame_gaps[candidates]
            closest = candidates[gaps == gaps.min()]
            if closest.size == 1:
                k = closest[0]
                pruned_edges.append((senders[k], receivers[k]))
        return pruned_edges

... compute the trajectories ...

In [18]:
compute_trajectories = ComputeTrajectories()
# Drop singleton dimensions so there is one boolean prediction per edge.
edge_predictions = predictions.squeeze()
trajectories = compute_trajectories(test_graph, edge_predictions)

... and visualize the cell trajectories as a video.

In [19]:
import matplotlib as mpl

# Allow embedded HTML animations of up to 60 MB (the unit of this setting).
mpl.rcParams.update({"animation.embed_limit": 60})

In [None]:
from IPython.display import HTML
from matplotlib.animation import FuncAnimation

fig, ax = plt.subplots(figsize=(8, 8))
list_of_colors = plt.get_cmap("tab20b", max(1, len(trajectories)))

# Node positions are normalized by the image shape; scale back to pixels
# using the actual image size instead of a hard-coded 512.
image_shape = torch.tensor(test_segs[0].shape[:2], dtype=test_graph.x.dtype)

# Trajectories are sets (no defined order); sort their nodes by frame so
# that the drawn tails follow the cells through time.
sorted_trajectories = [
    sorted(t, key=lambda node: float(test_graph.frames[node]))
    for t in trajectories
]
TAIL_LENGTH = 10  # number of past frames drawn behind each cell


def update(frame):
    """Draw one animation frame: image, segmentation and trajectories."""
    ax.clear()
    ax.imshow(test_images[frame], cmap="gray")

    segmentation = test_segs[frame]
    labels = np.unique(segmentation)
    for label in labels[labels != 0]:  # 0 is background
        contour = measure.find_contours(segmentation == label, 0.5)[0]
        ax.fill(contour[:, 1], contour[:, 0], color=list_of_colors(label),
                alpha=0.5, linewidth=2)
    ax.text(0, -5, f"Frame: {frame}", fontsize=16, c="k")

    for i, nodes in enumerate(sorted_trajectories):
        frames = test_graph.frames[nodes]
        xy_all = test_graph.x[nodes, :2] * image_shape
        xy_frame = xy_all[frames == frame]
        if len(xy_frame) != 0:
            ax.scatter(xy_frame[:, 1], xy_frame[:, 0])
            ax.text(xy_frame[0, 1], xy_frame[0, 0], str(i), fontsize=16,
                    c="w")

            # Tail over the last TAIL_LENGTH frames, in temporal order.
            recent = (frames <= frame) & (frames >= frame - TAIL_LENGTH)
            xy_previous = xy_all[recent]
            ax.plot(xy_previous[:, 1], xy_previous[:, 0], c="w")
    return ax


animation = FuncAnimation(fig, update, frames=len(test_segs))
video = HTML(animation.to_jshtml())
plt.close(fig)
video