In [None]:
%matplotlib inline
import sys
sys.path.insert(0, "../..")

# Example. Bacterial motility analysis using MAGIK

## 1. Setup

Import the objects needed for this example.

In [2]:
import deeptrack as dt
from deeptrack.models.gnns.generators import GraphGenerator

import tensorflow as tf

import pandas as pd
import numpy as np
import scipy.io

from deeptrack.extras import datasets

import logging

# Suppress every log record of severity WARNING or lower, from all loggers,
# to keep the notebook output readable.
logging.disable(level=logging.WARNING)

--------------------------------------------------------------------------------

  CuPy may not function correctly because multiple CuPy packages are installed
  in your environment:

    cupy, cupy-cuda101

  Follow these steps to resolve this issue:

    1. For all packages listed above, run the following command to remove all
       existing CuPy installations:

         $ pip uninstall <package_name>

      If you previously installed CuPy via conda, also run the following:

         $ conda uninstall cupy

    2. Install the appropriate CuPy package.
       Refer to the Installation Guide for detailed instructions.

         https://docs.cupy.dev/en/stable/install.html

--------------------------------------------------------------------------------



## Overview

In this example, we show how to use [MAGIK](https://arxiv.org/abs/2202.06355) (Motion Analysis through GNN Inductive Knowledge) for its most natural application, trajectory linking. We analyze bacterial motility experiments and discuss practical considerations when using MAGIK.

## 2. Defining the dataset

### 2.1 Defining the training set

In [4]:
# Training trajectories exported from MATLAB (path relative to this notebook).
_TRAINING_FILE = "datasets/training/EC1212_Si16.mat"
data = scipy.io.loadmat(_TRAINING_FILE)

Extract the active particles' trajectories from the loaded data.

In [5]:
# Each field is densified via ``.todense()`` (presumably a sparse matrix — confirm);
# the rows are later treated as individual trajectories.
_mat_struct = data.get("Data")
ty = np.array(_mat_struct["AtrajectoriesY_px"][0][0].todense())  # y-coordinates
tx = np.array(_mat_struct["AtrajectoriesX_px"][0][0].todense())  # x-coordinates

In the loaded matrices, missing samples are stored as zeros; replace them with NaN so they can be masked out.

In [6]:
# Zero entries mark missing samples; turn them into NaN (in place) so they can be masked.
# NOTE(review): a genuine coordinate of exactly 0.0 is treated as missing too.
for _coords in (ty, tx):
    _coords[_coords == 0.0] = np.nan

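Keep only trajectories with at least 25 valid x samples whose 1st–99th percentile range exceeds 8 pixels along both axes, and split each broken trajectory at its gaps into separate sub-trajectories.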

In [7]:
ty = list(ty)
tx = list(tx)


def split(a, mask=None):
    """Split a 1D array into its contiguous runs of valid samples.

    Parameters
    ----------
    a : np.ndarray
        1D array of samples.
    mask : array-like of bool, optional
        True where a sample is missing. By default the non-finite entries
        (NaN/inf) of ``a`` are treated as missing. Passing the same mask for
        several arrays splits them at identical positions.

    Returns
    -------
    list of np.ndarray
        One slice of ``a`` per contiguous run of non-missing samples.
    """
    if mask is None:
        masked = np.ma.masked_invalid(a)
    else:
        masked = np.ma.masked_array(a, mask=mask)
    return [a[s] for s in np.ma.clump_unmasked(masked)]


def _quantile_spread(v, q=(0.01, 0.99)):
    """Return the distance between the ``q`` quantiles of the non-NaN entries of ``v``."""
    lo, hi = np.quantile(v[~np.isnan(v)], q)
    return hi - lo


tr_y, tr_x = [], []
for _tx, _ty in zip(tx, ty):
    # Keep trajectories with at least 25 valid x samples whose 1st-99th
    # percentile range exceeds 8 px along both axes.
    if (
        np.count_nonzero(~np.isnan(_tx)) >= 25
        and _quantile_spread(_tx) > 8
        and _quantile_spread(_ty) > 8
    ):
        # Split x and y at the SAME gaps (a sample missing in either coordinate
        # is dropped from both). Splitting each coordinate on its own NaNs can
        # yield segments of different count/length, misaligning tr_x and tr_y.
        _gaps = ~np.isfinite(_tx) | ~np.isfinite(_ty)
        tr_x.extend(split(_tx, mask=_gaps))
        tr_y.extend(split(_ty, mask=_gaps))

Keep only sub-trajectories with more than 15 points.

In [8]:
# Minimum number of points a sub-trajectory must exceed to be kept.
_MIN_SEGMENT_LENGTH = 15

# Filter x/y segments as pairs so tr_x[i] and tr_y[i] keep describing the same
# sub-trajectory (filtering the two lists independently could break the pairing).
_kept_pairs = [
    (_x, _y)
    for _x, _y in zip(tr_x, tr_y)
    if len(_x) > _MIN_SEGMENT_LENGTH and len(_y) > _MIN_SEGMENT_LENGTH
]
tr_x = [pair[0] for pair in _kept_pairs]
tr_y = [pair[1] for pair in _kept_pairs]

MAGIK models the objects’ motion and physical interactions using a graph representation. Graphs can define arbitrary relational structures between nodes connecting them pairwise through edges. In MAGIK, each node describes an object detection at a specific time.

Create the node dataframe from the training trajectories:

In [9]:
# Pixel coordinates are divided by this factor to bring them to roughly [0, 1].
_COORD_SCALE = 1000.0
_CENTROID_COLUMNS = ["centroid-x", "centroid-y"]

_trajectory_frames = []
for _label, (_xs, _ys) in enumerate(zip(tr_x, tr_y)):
    # One row per detection; every sub-trajectory is re-indexed from frame 0.
    # ``frame`` stays float64 until after sorting (cast to int at the end).
    _trajectory_frames.append(
        pd.DataFrame(
            {
                "frame": np.arange(len(_xs), dtype=np.float64),
                "centroid-x": np.asarray(_xs, dtype=np.float64),
                "centroid-y": np.asarray(_ys, dtype=np.float64),
                "label": _label,
            }
        )
    )

nodesdf = (
    pd.concat(_trajectory_frames)
    .sort_values(by="frame")
    .reset_index(drop=True)
)

# Normalize the centroid columns.
nodesdf[_CENTROID_COLUMNS] = nodesdf[_CENTROID_COLUMNS] / _COORD_SCALE

# Placeholder node ground truth and a single video index (set 0).
nodesdf = nodesdf.assign(solution=0.0, set=0.0).astype({"frame": "int", "set": "int"})

Display the first 20 rows of the dataframe.

In [10]:
# First 20 rows of the node table.
nodesdf.iloc[:20]

Unnamed: 0,frame,centroid-x,centroid-y,label,solution,set
0,0,0.3245,0.656706,0,0.0,0
1,0,0.2465,0.1605,122,0.0,0
2,0,0.721529,0.793569,25,0.0,0
3,0,0.353347,0.573347,123,0.0,0
4,0,0.93298,0.897157,124,0.0,0
5,0,0.7135,0.961,125,0.0,0
6,0,0.46098,0.452118,24,0.0,0
7,0,0.553347,0.421347,121,0.0,0
8,0,0.355647,0.726824,126,0.0,0
9,0,0.735865,0.227904,127,0.0,0


``nodesdf`` contains the following columns:

- ``label``: trajectory label (identity of the tracked object). Only used during training.

- ``centroid-x``: x-centroid coordinate (normalized between 0 and 1).

- ``centroid-y``: y-centroid coordinate (normalized between 0 and 1).

- ``frame``: frame corresponding to the detection.

- ``solution``: node ground truth (ignored). Only used for node classification/regression tasks.

- ``set``: index of the video in the dataset. The first video in the dataset is set 0, the second is set 1, and so on. Useful if multiple videos are available.

Importantly, there are no intrinsic restrictions on the type or number of descriptors (e.g., location and morphological features, image-based quantities) that can be encoded in the node representation.
In this example, we have only used the position of the object to train the model.

For object linking, the aim of MAGIK is to prune the wrong edges while retaining the true connections, i.e., an edge-classification problem with a binary label (linked/unlinked). We thus define:

In [None]:
# Output type
_OUTPUT_TYPE = "edges"

In MAGIK, nodes are connected to spatio-temporal neighbors within a distance-based likelihood radius:

In [None]:
# Search radius for the graph edges, in the normalized centroid units.
radius = 0.2

Finally, let's create a dummy feature to store our configuration for the graph-generation process:

In [None]:
# Time window, in frames, within which nodes may be associated.
_TIME_WINDOW = 3

# Graph-construction settings, bundled so they can be forwarded as keyword
# arguments through ``variables.properties()``.
variables = dt.DummyFeature(
    radius=radius,
    output_type=_OUTPUT_TYPE,
    nofframes=_TIME_WINDOW,
)

## 3. Defining the network

We define MAGIK and compile it with binary cross-entropy as the loss function.

In [None]:
# MAGIK with 2 node features (the centroid coordinates), 1 edge feature and
# a single sigmoid output per edge.
model = dt.models.gnns.MAGIK(
    dense_layer_dimensions=(64, 96),        # width of each dense encoder layer
    base_layer_dimensions=(96, 96, 96),     # latent width of each message-passing layer
    number_of_node_features=2,              # features per node
    number_of_edge_features=1,              # features per edge
    number_of_edge_outputs=1,               # predicted values per edge
    edge_output_activation="sigmoid",       # activation of the edge output layer
    output_type=_OUTPUT_TYPE,               # "edges", "nodes", or "graph"
)

# Binary cross-entropy pairs with the single sigmoid output (linked / unlinked).
_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(optimizer=_optimizer, loss="binary_crossentropy", metrics=["accuracy"])

model.summary()


## 4. Training the network

If ``_LOAD_MODEL`` is set to ``False``, we train the model on the bacterial training set prepared above. `GraphGenerator` creates the graph representation from the detections and returns a continuous generator that asynchronously generates graphs during training.

In [None]:
# Set to True to skip training and restore previously saved weights instead.
_LOAD_MODEL = False
# Path to the saved weights; only read when _LOAD_MODEL is True.
# TODO: fill in before enabling _LOAD_MODEL.
_WEIGHTS_PATH = ""

if _LOAD_MODEL:
    # Fail early with a clear message instead of handing an empty path to Keras.
    if not _WEIGHTS_PATH:
        raise ValueError("Set `_WEIGHTS_PATH` to the saved MAGIK weights before loading.")
    print("Loading model...")
    model.load_weights(_WEIGHTS_PATH)
else:
    # Graphs are built from the training detections in ``nodesdf`` using the
    # centroid columns; the settings stored in ``variables`` (radius, output
    # type, frame window) are forwarded as keyword arguments.
    generator = GraphGenerator(
        nodesdf=nodesdf,
        properties=["centroid"],
        min_data_size=511,
        max_data_size=512,
        batch_size=8,
        **variables.properties()
    )

    # Train inside the generator's context (presumably it starts/stops the
    # asynchronous graph generation — confirm in GraphGenerator).
    with generator:
        model.fit(generator, epochs=10)

## 5. Evaluating the network

Now, let's download our dataset.

In [None]:
# REPLACE BY TESTING DATA
# TODO: fetch/load the test recordings here. NOTE(review): ``datasets`` from
# ``deeptrack.extras`` is imported at the top but never used — presumably it
# was meant for this download step; confirm.

We previously detected the position of the cells in each frame of the dataset using LodeSTAR and stored them in ``test_nodesdf``:

In [None]:
# REPLACE BY TESTING DATA: assign the LodeSTAR detections of the test video here.
# NOTE(review): needs at least the ``centroid-x``/``centroid-y`` columns (the
# "centroid" property is used for predictions), and presumably ``frame`` and
# ``set`` like ``nodesdf`` — confirm against ``dt.models.gnns.get_predictions``.
test_nodesdf = None

if test_nodesdf is None:
    # Stop here with an actionable message rather than failing later on a
    # placeholder value.
    raise NotImplementedError(
        "Assign the test-set detections to `test_nodesdf` before running this cell."
    )

# display the first 20 rows of the dataframe
test_nodesdf.head(20)

Compute predictions for the test set:

In [None]:
# Run the trained model on the graph built from the test detections.
_prediction_outputs = dt.models.gnns.get_predictions(
    test_nodesdf, ["centroid"], model, variables
)
pred, gt, scores, graph = _prediction_outputs

Create a dataframe from the results:

In [None]:
# ``df_from_results`` returns three values; only ``edges_df`` and ``nodes`` are
# used below. The third is bound to a named throwaway rather than ``_``, because
# assigning ``_`` in IPython shadows its last-output cache variable.
edges_df, nodes, _unused_results = dt.models.gnns.df_from_results(pred, gt, scores, graph)

# display the first 10 rows of the dataframe
edges_df.head(10)

Compute trajectories and filter out those shorter than 8 frames:

In [None]:
# Length threshold (frames) passed to ``get_traj`` as ``th``.
_MIN_TRAJECTORY_FRAMES = 8
traj = dt.models.gnns.get_traj(edges_df, th=_MIN_TRAJECTORY_FRAMES)

Display results:

In [None]:
import glob
import cv2

import matplotlib.pyplot as plt

# TODO: glob pattern matching the test video's frame images. The empty pattern
# matches no files, so nothing is drawn until it is filled in.
_FRAMES_PATTERN = ""
# Factor mapping the normalized node coordinates back to pixels.
# NOTE(review): the training centroids were divided by 1000; confirm that 1200
# matches how ``test_nodesdf`` was normalized.
_PIXEL_SCALE = 1200
# Number of previous frames drawn as a tail behind each object.
_TAIL_FRAMES = 10

# ``glob`` returns paths in arbitrary order; sort them so the enumeration index
# ``f`` follows file-name order (assumes names sort chronologically, e.g.
# zero-padded frame numbers — TODO confirm).
frames = sorted(glob.glob(_FRAMES_PATTERN))

for f, frame in enumerate(frames):
    img = cv2.imread(frame, -1)  # -1: load the image unchanged
    if img is None:
        raise FileNotFoundError(f"Could not read frame image: {frame}")

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(img)
    ax.text(10, 40, "Frame: " + str(f), fontsize=20, color="white")
    ax.axis("off")

    for t, c in traj:
        # ``t`` selects the rows of ``nodes`` belonging to one trajectory and
        # ``c`` is its colour; column 0 of ``nodes`` is compared to the frame index.
        in_window = (nodes[t, 0] <= f) & (nodes[t, 0] >= f - _TAIL_FRAMES)
        detections = nodes[t][in_window, :]

        # Skip trajectories with nothing in the window or that ended before f.
        if (len(detections) == 0) or (np.max(nodes[t, 0]) < f):
            continue

        # Column 2 goes on the horizontal axis, column 1 on the vertical axis.
        ax.plot(
            detections[:, 2] * _PIXEL_SCALE,
            detections[:, 1] * _PIXEL_SCALE,
            color=c,
            linewidth=2,
        )
        # Mark the most recent detection inside the window.
        ax.scatter(
            detections[-1, 2] * _PIXEL_SCALE,
            detections[-1, 1] * _PIXEL_SCALE,
            linewidths=1.5,
            c=c,
        )

    plt.show()
    # Release the figure so memory does not grow with the number of frames.
    plt.close(fig)