# Cell Detection and Localization

We'll demonstrate how self-supervised learning can be used to detect and localize cells in an image.
We'll use LodeSTAR, which lets us train the neural network on a single crop of a cell, without the need for ground-truth annotations.

## Download the Dataset

We will use the BF-C2DL-HSC dataset from the Cell Tracking Challenge.
This is a series of videos of proliferating mouse hematopoietic stem cells. 
Importantly, we can use the annotations provided for the challenge to evaluate the detection performance.
This dataset is available at http://data.celltrackingchallenge.net/training-datasets/BF-C2DL-HSC.zip.

In [None]:
import os
from torchvision.datasets.utils import download_url, _extract_zip

# `extract_archive` is the public torchvision helper for unpacking archives;
# it is used here instead of the private `_extract_zip`, whose signature is
# not guaranteed to stay stable between torchvision releases.
from torchvision.datasets.utils import extract_archive

# Download and unpack the dataset only when the extracted folder is missing,
# so re-running the cell does not fetch the archive again.
dataset_path = os.path.join(".", "cell_detection_dataset")
if not os.path.exists(dataset_path):
    archive_name = "BF-C2DL-HSC.zip"
    url = ("http://data.celltrackingchallenge.net/training-datasets/"
           + archive_name)
    download_url(url, ".")
    extract_archive(archive_name, dataset_path)
    # The archive is not needed once its contents have been extracted.
    os.remove(archive_name)

# NOTE: `dir` shadows the builtin of the same name; the name is kept because
# the following cell reads it.
dir = os.path.join(dataset_path, "BF-C2DL-HSC")

In [None]:
import deeptrack as dt
import glob
from skimage.measure import regionprops

# Turn off DeepTrack's image wrapper (the later cells treat the pipeline
# output as plain NumPy arrays).
dt.config.disable_image_wrapper()

# Collect the frames of sequence 02 and their tracking annotations. Both lists
# are sorted so that entries at the same index are paired (assumes the two
# folders use matching file numbering).
frame_files = sorted(glob.glob(os.path.join(dir, "02", "*.tif")))
annotation_files = sorted(
    glob.glob(os.path.join(dir, "02_GT", "TRA", "*.tif"))
)
sources = dt.sources.Source(
    image_path=frame_files,
    label_path=annotation_files,
)

# Load a frame, keep its first 300 columns, and divide the intensities by 256.
image = dt.LoadImage(sources.image_path)[:, :300] / 256

# Load the matching label map, crop it the same way, and pass it through
# skimage's regionprops.
label = dt.LoadImage(sources.label_path)[:, :300] >> regionprops

# Evaluating the combined pipeline yields the image followed by the regions.
pipeline = image & label

In [None]:
import matplotlib.pyplot as plt
import skimage.io

# Show six frames spread over the sequence, with the annotated centroids of
# the cells drawn in red on top of each frame.
preview_frames = [0, 300, 600, 900, 1200, 1500]

plt.figure(figsize=(15, 10))

for panel, frame in enumerate(preview_frames, start=1):
    image, *props = pipeline(sources[frame])

    plt.subplot(1, 6, panel)
    plt.imshow(image, cmap="gray")
    for prop in props:
        # The centroid starts with (row, column); scatter takes x (column)
        # first and y (row) second.
        row, col = prop.centroid[:2]
        plt.scatter(col, row, s=5, color="red")
    plt.axis("off")

plt.tight_layout()
plt.show()

In [None]:
import matplotlib.patches as patches

# Hand-picked training example: the frame index and a square window of
# `crop_size` pixels centred on one cell.
crop_frame_index = 282
crop_size = 50
half_size = crop_size // 2
crop_x0 = 595 - half_size
crop_y0 = 115 - half_size

image, *props = pipeline(sources[crop_frame_index])

# crop_x0 indexes the first (row) axis and crop_y0 the second (column) axis.
rows = slice(crop_x0, crop_x0 + crop_size)
cols = slice(crop_y0, crop_y0 + crop_size)
crop = image[rows, cols]

# Full frame with the crop window outlined in red.
fig = plt.figure(figsize=(2.5, 10))
plt.imshow(image, cmap="gray")

# Rectangle is anchored at (x, y), i.e. (column, row).
crop_outline = patches.Rectangle(
    (crop_y0, crop_x0),
    crop_size,
    crop_size,
    linewidth=1,
    edgecolor="r",
    facecolor="none",
)
plt.gca().add_patch(crop_outline)
plt.axis("off")

# Additional axes on the same figure showing the crop itself.
fig.add_subplot(2, 2, 2)
plt.imshow(crop, cmap="gray")

In [None]:
import numpy as np
import torch
# Augment the single crop with a random intensity gain in [0.9, 1.1] and a
# random offset in [-0.1, 0.1].
augmented_crop = (
    dt.Value(crop)
    >> dt.Multiply(lambda: np.random.uniform(0.9, 1.1))
    >> dt.Add(lambda: np.random.uniform(-0.1, 0.1))
)

# Move the last axis to the front and convert the result to a float32 tensor.
training_pipeline = (
    augmented_crop
    >> dt.MoveAxis(-1, 0)
    >> dt.pytorch.ToTensor(dtype=torch.float32)
)

# Dataset of 400 samples generated from the augmentation pipeline.
training_dataset = dt.pytorch.Dataset(
    training_pipeline,
    length=400,
    replace=False,
)

In [None]:
import deeplay as dl 
# Serve the training samples in shuffled mini-batches of eight.
dataloader = dl.DataLoader(
    training_dataset,
    batch_size=8,
    shuffle=True,
)

In [29]:

# Configure LodeSTAR with four transforms per sample and an Adam optimizer,
# then build the model.
optimizer = dl.Adam(lr=1e-4)
lodestar = dl.LodeSTAR(n_transforms=4, optimizer=optimizer)
model = lodestar.build()

# Train for 20 epochs on the augmented crops.
trainer = dl.Trainer(max_epochs=20)
trainer.fit(model, dataloader)

In [None]:
# Visualise the raw model output on a single frame.
image_index = 1500

image, *props = pipeline(sources[image_index])

# Channels-last array -> batched channels-first float tensor.
torch_image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float()

# Run the model on whichever device it lives on and bring the result back to
# the CPU before converting to NumPy; `.numpy()` cannot be called on a tensor
# stored on an accelerator.
prediction = model(torch_image.to(model.device))[0].detach().cpu().numpy()

# The first two channels are plotted below as positions (channel 0 on the
# vertical axis, channel 1 on the horizontal axis); the last channel is used
# as a per-pixel weight.
x_feature = prediction[0]
y_feature = prediction[1]
mass_feature = prediction[-1]

plt.figure(figsize=(10, 10))

# Input frame.
plt.subplot(1, 3, 1)
plt.imshow(image, cmap="gray")
plt.axis("off")

# Weight map.
plt.subplot(1, 3, 2)
plt.imshow(mass_feature, cmap="gray")
plt.axis("off")

# Predicted positions on top of the frame, with opacity proportional to the
# weight of the pixel that produced them.
plt.subplot(1, 3, 3)
plt.imshow(image, cmap="gray")
plt.scatter(
    y_feature.flatten(),
    x_feature.flatten(),
    alpha=mass_feature.flatten() / mass_feature.max(),
    s=5,
)
plt.axis("off")

In [None]:
# Show the detections produced by the trained model on six frames spread over
# the sequence.
plt.figure(figsize=(15, 10))

for plot_idx, frame_idx in enumerate([0, 300, 600, 900, 1200, 1500]):
    image, *props = pipeline(sources[frame_idx])
    # Channels-last array -> batched channels-first float tensor.
    torch_image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float()

    # `detect` is called directly on the image; the separate forward pass
    # that used to precede it produced values that were never read.
    # `[0]` selects the detections of the only image in the batch.
    detections = model.detect(torch_image,
                              alpha=0.1,
                              beta=0.9,
                              mode="constant",
                              cutoff=.5)[0]

    plt.subplot(1, 6, plot_idx + 1)
    plt.imshow(image, cmap="gray")
    # Column 0 is drawn on the vertical axis and column 1 on the horizontal.
    plt.scatter(
        detections[:, 1],
        detections[:, 0],
        s=5,
        color="red",
    )
    plt.axis("off")

In [None]:
import tqdm 
import scipy

# A bare `import scipy` does not guarantee that the submodules are loaded, so
# the ones used below are imported explicitly.
import scipy.optimize
import scipy.spatial


def _score_frame(detections, centroids, distance_th):
    """Count true positives, false positives and false negatives in a frame.

    Detections are paired one-to-one with ground-truth centroids by solving
    the linear assignment problem on their pairwise distances. A pair counts
    as a match only when its distance is below ``distance_th``.

    Args:
        detections: Detected positions, two coordinates per detection.
        centroids: Ground-truth positions, two coordinates per cell.
        distance_th: Maximum distance for a pair to count as a match.

    Returns:
        Tuple ``(tp, fp, fn)`` with the counts for the frame.
    """
    # Force an (N, 2) shape so that a frame without detections or without
    # annotated cells does not break the distance computation.
    detections = np.asarray(detections).reshape(-1, 2)
    centroids = np.asarray(centroids).reshape(-1, 2)
    if len(detections) == 0 or len(centroids) == 0:
        # Nothing can be matched: every detection is a false positive and
        # every annotated cell is a false negative.
        return 0, len(detections), len(centroids)

    distance_matrix = scipy.spatial.distance_matrix(detections, centroids)
    row_idx, col_idx = scipy.optimize.linear_sum_assignment(distance_matrix)

    # Assigned pairs that are too far apart do not count as matches.
    matched = int(
        np.count_nonzero(distance_matrix[row_idx, col_idx] < distance_th)
    )
    return matched, len(detections) - matched, len(centroids) - matched


# Maximum distance between a detection and a centroid for them to match.
distance_th = 10

TP = 0
FP = 0
FN = 0

# Evaluate on every tenth frame of the sequence.
for source in tqdm.tqdm(sources[::10]):
    image, *props = pipeline(source)
    torch_image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float()
    detections = model.detect(torch_image,
                              alpha=0.1,
                              beta=0.9,
                              mode="constant",
                              cutoff=.5)[0]

    # Only the first two centroid coordinates are compared with detections.
    centroids = np.array([prop.centroid[:2] for prop in props])

    frame_tp, frame_fp, frame_fn = _score_frame(
        detections, centroids, distance_th
    )
    TP += frame_tp
    FP += frame_fp
    FN += frame_fn

# F1 is undefined when there are neither detections nor annotated cells.
f1_denominator = 2 * TP + FP + FN
f1 = 2 * TP / f1_denominator if f1_denominator else float("nan")

print(f"""
TP: {TP}
FP: {FP}
FN: {FN}
F1: {f1}  
""")