In [None]:
# %matplotlib inline
# !pip install deeptrack

# Example 1. Single-level trajectory analysis using Transformers


## 1. Setup

Import the packages needed for this example.


In [None]:
import deeptrack as dt
from deeptrack.extras import datasets

import tensorflow as tf

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

import scipy.sparse


## 2. Overview

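In this example, we train a Transformer to predict a label for every point of a trajectory. Trajectories from the STrajCh dataset are augmented with random rotations, translations and flips, zero-padded to a common length, and fed to the network. The trained model is then evaluated on a trajectory from the validation set.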


## 2. Defining the dataset

### 2.1 Defining the training set


In [None]:
# Download the STrajCh dataset through DeepTrack's dataset helper.
# The cells below read it back from "datasets/STrajCh/", so the files are
# presumably stored there relative to the working directory — TODO confirm.
datasets.load("STrajCh")

In [None]:
TRAINING_PATH = "datasets/STrajCh/training/{file}.npz"

# Load the three sparse training matrices and convert them to dense arrays.
# The tuple is ordered (data, indices, labels), which is the order in which
# `splitter` unpacks it.
train_data = tuple(
    scipy.sparse.load_npz(TRAINING_PATH.format(file=name)).toarray()
    for name in ("data", "indices", "labels")
)


In [None]:
def splitter(randset):
    """
    Returns a function that extracts a single trajectory from the dataset.

    The returned function takes a `(data, indices, labels)` triple, keeps the
    rows whose entry in `indices` equals `randset`, and inserts the
    point-to-point step length as an extra column right after the first two
    columns of `data`. It returns the resulting features and the matching
    rows of `labels`.
    """

    def inner(inputs):
        data, indices, labels = inputs

        # Unwrap the array stored in the `_value` attribute of `data`
        values = data._value

        # Rows of the dataset that belong to the requested trajectory
        rows = np.where(indices == randset)[0]
        trajectory = values[rows]

        positions = trajectory[:, :2]
        remaining_features = trajectory[:, 2:]

        # Euclidean distance between consecutive positions; the first point
        # has no predecessor, so its step length is set to 0
        steps = np.linalg.norm(np.diff(positions, axis=0), axis=1)
        step_column = np.array([0, *steps])[:, np.newaxis]

        features = np.concatenate(
            [positions, step_column, remaining_features], axis=1
        )

        return features, labels[rows]

    return inner


# Largest value found in the trajectory-index array
nsamples = np.max(train_data[1])

# `randset` is supplied as a callable that draws an integer in [0, nsamples]
# (the upper bound of np.random.randint is exclusive, hence the + 1)
full_training_data = dt.Value(lambda: tuple(train_data))
select_random_trajectory = dt.Lambda(
    splitter, randset=lambda: np.random.randint(0, nsamples + 1)
)
train_set = full_training_data >> select_random_trajectory


### 2.2 Visualizing the dataset


In [None]:
fig, axs = plt.subplots(3, 3, figsize=(10, 10), sharex=True, sharey=True)

# Maps a label value to a colour of the reversed "Oranges" colormap
cmap = plt.cm.ScalarMappable(
    norm=mpl.colors.Normalize(vmin=0.01, vmax=1.4), cmap=plt.cm.Oranges_r
)

# One freshly sampled trajectory per panel, filled in row-major order
for ax in axs.flat:
    data, labels = train_set.update()()

    data = data[:, :2]

    # Changepoints: positions where the label differs from the previous one,
    # framed by the first index and the trajectory length
    diff = np.array(labels[1:] - labels[:-1])
    cp = (0, *np.where(diff != 0)[0] + 1, labels.shape[0])

    for start, stop in zip(cp[:-1], cp[1:]):
        # Draw the segment up to and including the next changepoint so that
        # consecutive segments are visually connected
        ax.plot(
            data[start : stop + 1, 0],
            data[start : stop + 1, 1],
            c=cmap.to_rgba(labels[start])[0],
            zorder=0,
        )
        # Mark the first point of each segment in green
        ax.scatter(data[start, 0], data[start, 1], c="g", zorder=1, s=20)

    # Same limits and ticks on every panel
    ax.set_xlim([-0.6, 0.6])
    ax.set_ylim([-0.6, 0.6])
    ax.set_yticks([-0.5, 0, 0.5])
    ax.set_xticks([-0.5, 0, 0.5])

# Axis labels only on the left column and the bottom row
plt.setp(axs[:, 0], ylabel="y-centroid")
plt.setp(axs[-1, :], xlabel="x-centroid")

plt.subplots_adjust(wspace=0.05, hspace=0.05)

In [None]:
# Resample the pipeline once and keep the resulting trajectory features and
# labels in `data` / `labels`.
data, labels = train_set.update()()

### 2.3 Augmenting the trajectories


In [None]:
def AugmentTrajectories(rotate, translate, flip_x, flip_y):
    """
    Builds an augmentation function for a `(data, labels)` pair.

    The first two columns of `data` are treated as the x and y centroids.
    They are rotated by the angle `rotate` (in radians), shifted by
    `translate[0]` and `translate[1]` respectively, and finally negated when
    `flip_x` / `flip_y` are truthy. `data` is modified in place; all other
    columns and `labels` are returned untouched.
    """
    # The rotation coefficients are the same for every point
    cos_angle = np.cos(rotate)
    sin_angle = np.sin(rotate)

    def inner(inputs):
        data, labels = inputs

        x = data[:, 0]
        y = data[:, 1]

        # Rotation followed by translation
        new_x = x * cos_angle + y * sin_angle + translate[0]
        new_y = y * cos_angle - x * sin_angle + translate[1]

        # Optional sign flips, applied after the translation
        if flip_x:
            new_x *= -1
        if flip_y:
            new_y *= -1

        # Write the transformed centroids back into the input array
        data[:, 0] = new_x
        data[:, 1] = new_y

        return data, labels

    return inner

In [None]:
# Augmentation stage whose parameters are supplied as callables: an angle in
# [0, 2*pi), a Gaussian x/y shift scaled by 0.05, and two independent 0/1
# flip flags
random_augmentation = dt.Lambda(
    AugmentTrajectories,
    rotate=lambda: np.random.rand() * 2 * np.pi,
    translate=lambda: np.random.randn(2) * 0.05,
    flip_x=lambda: np.random.randint(2),
    flip_y=lambda: np.random.randint(2),
)

augmented_train_set = train_set >> random_augmentation


### 2.4 Padding the trajectories


In [None]:
def pad(pad_to):
    """
    Returns a function that zero-pads a `(data, labels)` pair to `pad_to` rows.

    The returned function also builds an array holding every ordered pair of
    row indices of the unpadded trajectory, zero-padded to `pad_to ** 2` rows,
    and returns `((data, mask), labels)`.
    """

    def inner(inputs):
        data, labels = inputs

        n_points = int(np.shape(data)[0])

        # All ordered (i, j) index pairs: i varies slowest, j fastest
        point_ids = np.arange(n_points)
        pairs = np.column_stack(
            (np.repeat(point_ids, n_points), np.tile(point_ids, n_points))
        )

        # Append zero rows so that data and labels both have `pad_to` rows
        row_padding = ((0, pad_to - n_points), (0, 0))
        padded_data = np.pad(data, row_padding, mode="constant")
        padded_labels = np.pad(labels, row_padding, mode="constant")

        # The pair array is padded to the size of a full pad_to x pad_to grid
        padded_pairs = np.pad(
            pairs,
            ((0, pad_to ** 2 - np.shape(pairs)[0]), (0, 0)),
            mode="constant",
        )

        return (padded_data, padded_pairs), padded_labels

    return inner

In [None]:
# Pad to the highest number of rows that share a single value in the
# trajectory-index array, i.e. the longest training trajectory
rows_per_trajectory = np.unique(train_data[1], return_counts=True)[1]
pad_to = rows_per_trajectory.max()

padded_train_set = augmented_train_set >> dt.Lambda(pad, pad_to=pad_to)

#### 3.x Defining the data generator


In [None]:
# Generator that serves samples of `padded_train_set` in batches of 8.
# NOTE(review): min_data_size / max_data_size presumably bound the pool of
# generated samples held by the generator, and use_multi_inputs presumably
# lets the (data, mask) tuple be passed as two separate model inputs —
# confirm against the DeepTrack documentation.
generator = dt.generators.ContinuousGenerator(
    padded_train_set,
    batch_size=8,
    min_data_size=1024,
    max_data_size=1025,
    use_multi_inputs=True,
)


#### 3.x Defining the network


In [None]:
# NOTE(review): `tfa` is not referenced anywhere in the visible notebook —
# confirm whether this import is still required.
import tensorflow_addons as tfa

# Transformer producing one linearly-activated output per trajectory point.
# NOTE(review): number_of_node_features=4 has to match the number of feature
# columns returned by `splitter` — confirm against the dataset layout.
model = dt.models.Transformer(
    number_of_node_features=4,
    dense_layer_dimensions=(32, 64, 96),
    number_of_transformer_layers=3,
    base_fwd_mlp_dimensions=256,
    number_of_node_outputs=1,
    node_output_activation="linear",
)
model.summary()


class mae(tf.keras.losses.Loss):
    """
    Absolute error summed over all entries, divided by the number of
    non-zero entries of `y_true`.

    Entries where `y_true` is 0 (presumably the zero-padding added by `pad`)
    are excluded from the normalisation but still contribute to the summed
    error.
    """

    def call(self, y_true, y_pred):
        """Returns the scalar loss for one batch of targets and predictions."""
        total_error = tf.reduce_sum(tf.abs(y_true - y_pred))
        n_nonzero = tf.math.count_nonzero(y_true, dtype=tf.float32)

        # divide_no_nan returns 0 when the denominator is 0, so a batch
        # without any non-zero target cannot turn the loss into NaN/inf
        return tf.math.divide_no_nan(total_error, n_nonzero)


# Compile the model with the Adam optimizer (learning rate 1e-4) and the
# custom `mae` loss defined above
model.compile(
    tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss=mae(),
)


## 4. Training the network


In [None]:
# Train for 150 epochs. The generator is used as a context manager for the
# duration of the training (presumably to start and stop its data generation
# — confirm against the DeepTrack documentation).
with generator:
    model.fit(generator, epochs=150)

## 5. Evaluating the network

In [None]:
VALIDATION_PATH = "datasets/STrajCh/validation/{file}.npz"

# Load the sparse validation matrices as dense arrays, in the order
# (data, indices, labels)
val_data, idxs, labels = (
    scipy.sparse.load_npz(VALIDATION_PATH.format(file=name)).toarray()
    for name in ("data", "indices", "labels")
)

# Drop the first column of the validation data.
# NOTE(review): the training `splitter` keeps column 0 of its data — confirm
# that the training and validation files really use different layouts.
val_data = val_data[:, 1:]

In [None]:
# sample index
idx = 100

# get indices of the rows belonging to randset
indices = np.where(idxs == idx)[0]

val_sdata = val_data[indices][:, :2]
val_sdata = np.concatenate(
    [
        val_sdata,
        np.array((0, *np.linalg.norm(np.diff(val_sdata, axis=0), axis=1)))[
            :, np.newaxis
        ],
        val_data[indices][:, 2:],
    ],
    axis=1,
)

gt = labels[indices]

In [None]:
import itertools

# Every ordered pair of point indices of the selected trajectory
n_points = val_sdata.shape[0]
edges = np.array(list(itertools.product(np.arange(n_points), repeat=2)))

# Both inputs get a leading axis of size 1 before being passed to the model
pred = model([val_sdata[np.newaxis], edges[np.newaxis]])


In [None]:
# Compare the prediction with the ground truth for the selected validation
# trajectory. The curves are labelled so they can be told apart.
fig, ax = plt.subplots()
ax.plot(gt[:, 0], label="Ground truth")
ax.plot(pred.numpy()[0, :, 0], label="Prediction")
ax.set_xlabel("Trajectory step")
ax.set_ylabel("Label value")
ax.legend();