# Self-Supervised Learning and LodeSTAR

We'll explore self-supervised learning and how LodeSTAR exploits symmetries to locate microscopic particles.

## Create the Dataset

We start by creating the dataset, including a particle ...

In [1]:
import deeptrack as dt 
from numpy.random import uniform

# Side length, in pixels, of the simulated images.
image_size = 51

# The position is given as a callable that draws a 2D position uniformly
# within 5 pixels of image_size / 2 along each axis.
particle = dt.PointParticle(
    position=lambda: uniform(
        image_size / 2 - 5,
        image_size / 2 + 5,
        size=2,
    ),
)

... and the imaging optics ...

In [None]:
# Imaging optics whose output region spans image_size x image_size pixels.
optics = dt.Fluorescence(
    output_region=(0, 0, image_size, image_size),
)

... which we combine in a simulation pipeline ...

In [None]:
import torch

# Simulation pipeline: image the particle, post-process the image, and
# convert it to a tensor that can be fed to PyTorch.
simulation = (
    optics(particle)  # Image the particle through the optics.
    >> dt.NormalizeMinMax(0, 1)  # Min-max normalization with bounds 0 and 1.
    >> dt.Gaussian(sigma=0.1)  # Gaussian feature with sigma=0.1.
    >> dt.MoveAxis(-1, 0)  # Move the last axis to the front.
    >> dt.pytorch.ToTensor(dtype=torch.float32)  # Convert to float32 tensor.
)

... which we use to create a training dataset and a test dataset.

In [None]:
# The training set contains only images (no labels are needed), while the
# test set also carries the particle position so predictions can be checked.
train_dataset = dt.pytorch.Dataset(
    simulation,
    length=100,
)
test_dataset = dt.pytorch.Dataset(
    simulation & particle.position,
    length=5000,
)

Finally, we plot some example particle images together with their corresponding positions.

In [None]:
import matplotlib.pyplot as plt

# Show the first five test images side by side with the true particle
# position overlaid in red.
plt.figure(figsize=(10, 2))
for i in range(5):
    image, position = test_dataset[i]
    plt.subplot(1, 5, i + 1)
    # image[0] selects the first entry along the leading axis of the sample.
    plt.imshow(image[0], cmap="gray", origin="lower")
    # scatter expects (x, y), so the second coordinate is passed first.
    plt.scatter(position[1], position[0], c="r")
plt.tight_layout()
plt.show()

## Learn from Translations

We implement a convolutional neural network with a dense top, with two outputs corresponding to the coordinates of the particle position.

In [None]:
import deeplay as dl
import torch.nn as nn

# Convolutional backbone with a single input channel, hidden channels
# 16/32/64, 128 output channels, and 2x2 max pooling.
backbone = dl.ConvolutionalNeuralNetwork(
    in_channels=1,
    hidden_channels=[16, 32, 64],
    out_channels=128,
    pool=nn.MaxPool2d(2),
)

# LazyLinear infers its input size at the first forward pass, so the size of
# the flattened CNN output does not have to be calculated by hand.
model = dl.Sequential(
    backbone,
    nn.Flatten(),
    nn.LazyLinear(2),  # Two outputs: the coordinates of the particle.
).create()

print(model)

We define the transformation that we will apply to the images, and with respect to which we will teach the neural network to be consistent.

In [None]:
from kornia.geometry.transform import translate

def image_translation(batch, translation):
    """Shift every image in a batch by its own translation.

    Args:
        batch: Batch of images, passed unchanged to kornia's `translate`.
        translation: 2D tensor with one row per image; its two columns are
            swapped before being handed to `translate`.

    Returns:
        The output of `translate`, computed with reflection padding.
    """
    # Swap the two columns so the translation matches the image coordinate
    # system.
    swapped_columns = [1, 0]
    return translate(
        batch,
        translation[:, swapped_columns],
        padding_mode="reflection",
    )

We also define the inverse transformation.

In [None]:
def inverse_translation(predicted_position, applied_translation):
    """Undo a translation on predicted positions.

    Args:
        predicted_position: Positions predicted on the translated images.
        applied_translation: Translation that was applied to the images.

    Returns:
        The predicted positions with the applied translation subtracted.
    """
    restored_position = predicted_position - applied_translation
    return restored_position

In [None]:
class ParticleLocalizer(dl.Application):
    """LodeSTAR-style particle localizer trained with random translations.

    Each training step builds `n_transforms` randomly translated copies of
    one image, predicts a position for every copy, maps the predictions back
    to the frame of the untransformed image, and penalizes their spread
    around the mean prediction.
    """

    def __init__(self, model, n_transforms=8, **kwargs):
        """Store the model and the number of transformed copies per step.

        Any remaining keyword arguments are forwarded to `dl.Application`.
        """
        self.model = model
        self.n_transforms = n_transforms
        super().__init__(**kwargs)

    def forward(self, x):
        """Return the output of the wrapped model for `x`."""
        return self.model(x)

    def random_arguments(self):
        """Draw the random arguments used by the transformations.

        Returns:
            A dict with key "translation" holding an (n_transforms, 2)
            tensor on `self.device`, sampled uniformly in [-2.5, 2.5).
        """
        unit_samples = torch.rand(self.n_transforms, 2).float().to(self.device)
        # Rescale from [0, 1) to [-2.5, 2.5).
        return {"translation": unit_samples * 5 - 2.5}

    def forward_transform(self, x, translation):
        """Translate the images in `x` by `translation`."""
        return image_translation(x, translation)

    def inverse_transform(self, x, translation):
        """Remove `translation` from the predicted positions `x`."""
        return inverse_translation(x, translation)

    def training_step(self, image, batch_idx):
        """Compute the self-consistency loss for one training step.

        Args:
            image: Batch provided by the dataloader; only its first element
                is used.
            batch_idx: Index of the batch (unused).

        Returns:
            The loss between each back-transformed prediction and the mean
            of all back-transformed predictions.
        """
        # Keep the first element of the incoming batch and replicate it once
        # per transformation.
        first, *_ = image
        copies = first.repeat(self.n_transforms, 1, 1, 1)

        # Apply one random transformation to each copy.
        transform_args = self.random_arguments()
        transformed_copies = self.forward_transform(copies, **transform_args)

        # Predict on the transformed copies, then undo the transformation so
        # that all predictions refer to the untransformed image.
        raw_prediction = self(transformed_copies)
        pred_position = self.inverse_transform(raw_prediction, **transform_args)

        # The target is the mean prediction, repeated to match the batch size.
        mean_position = pred_position.mean(dim=0, keepdim=True)
        target_position = mean_position.repeat(self.n_transforms, 1)

        # Minimizing the distance between each prediction and their average
        # effectively minimizes the variance of the predictions.
        loss = self.loss(pred_position, target_position)
        self.log("loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

We then instantiate the `ParticleLocalizer`.

In [None]:
# Translation-only localizer: 8 transformed copies per step, L1 loss, and
# Adam with a learning rate of 5e-4.
localizer = ParticleLocalizer(
    model,
    n_transforms=8,
    loss=nn.L1Loss(),
    optimizer=dl.Adam(lr=5e-4),
).create()

We define the dataloader and trainer, and finally we train the model.

In [None]:
from torch.utils.data import DataLoader

trainer = dl.Trainer(max_epochs=100)

# One image per batch: every training step expands a single image into its
# own batch of transformed copies.
dataloader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
)

trainer.fit(localizer, dataloader)

### Evaluate Performance

We start by evaluating the positions of some particles in the test dataset.

In [None]:
# Split the test samples into images and positions, stacking each group
# into a single tensor.
images, positions = (
    torch.stack(group) for group in zip(*test_dataset)
)

# Raw network outputs, detached from the graph and converted to NumPy.
raw_output = localizer(images)
predictions = raw_output.detach().numpy()

We write a function to plot the predicted positions as a function of the real ones, and add it to `fnc_lodestar.py`.

```python
import matplotlib.pyplot as plt

def plot_position_comparison(positions, predictions):
    """Plot predicted against true particle positions, one panel per axis.

    Positions are assumed to be indexed (row, column), so column 1 holds the
    horizontal coordinate and column 0 the vertical one.

    Args:
        positions: Array-like of shape (N, 2) with the true positions.
        predictions: Array-like of shape (N, 2) with the predicted
            positions, indexed in the same order as `positions`.
    """
    plt.figure(figsize=(14, 8))
    grid = plt.GridSpec(4, 7, wspace=.2, hspace=.1)

    # (grid cell, coordinate index, axis name) for the left and right panels.
    panels = (
        (grid[1:, :3], 1, "Horizontal"),
        (grid[1:, 4:], 0, "Vertical"),
    )
    for cell, index, axis_name in panels:
        plt.subplot(cell)
        plt.scatter(positions[:, index], predictions[:, index], alpha=.5)
        # Identity line (perfect prediction), anchored at (25, 25).
        plt.axline((25, 25), slope=1, color="black")
        plt.xlabel(f"True {axis_name} Position")
        plt.ylabel(f"Predicted {axis_name} Position")
        plt.axis("equal")

    plt.show()
```

In [None]:
from fnc_lodestar import plot_position_comparison

# Compare the raw network outputs with the true positions of the test set.
plot_position_comparison(positions, predictions)

In [None]:
# Reflect every test image along both spatial axes (dims 2 and 3).
reflected_images = images.flip(dims=(2, 3))

direct_preds = localizer(images).detach().numpy()
reflected_preds = localizer(reflected_images).detach().numpy()

# Half the difference between direct and reflected predictions removes any
# constant offset shared by both; image_size / 2 - 0.5 is then added to
# express the result in pixel coordinates.
half_difference = (direct_preds - reflected_preds) / 2
predictions_with_difference = half_difference + image_size / 2 - 0.5

plot_position_comparison(positions, predictions_with_difference)

## Learn also from Reflections

We now also add reflections to the learning process.

In [None]:
def flip_transform(image, should_flip, dim):
    """Flip selected images of a batch along one dimension.

    Args:
        image: Batch of images (4D tensor, batch axis first).
        should_flip: Boolean tensor with one entry per image.
        dim: Dimension of `image` along which to flip.

    Returns:
        A tensor where the images flagged in `should_flip` are flipped along
        `dim` and the remaining images are left unchanged.
    """
    flipped = image.flip(dims=(dim,))
    # One flag per image, reshaped to (N, 1, 1, 1) so that it broadcasts
    # over the remaining axes.
    per_image_flag = should_flip.view(-1, 1, 1, 1)
    return torch.where(per_image_flag, flipped, image)

def inverse_flip_transform(x, should_flip, dim):
    """Undo a conditional flip on predicted positions.

    Args:
        x: 2D tensor of predicted positions, one row per image.
        should_flip: Boolean tensor with one entry per row of `x`.
        dim: Column of `x` whose sign is inverted for the flagged rows.

    Returns:
        A tensor equal to `x` except that column `dim` is negated in the
        rows flagged by `should_flip`.
    """
    # Mark the entries to negate: column `dim` of every flagged row.
    negate_mask = torch.zeros_like(x, dtype=torch.bool)
    negate_mask[should_flip, dim] = True
    return torch.where(negate_mask, -x, x)


class ParticleLocalizerWithReflections(ParticleLocalizer):
    """ParticleLocalizer that also learns from reflections (flips).

    In addition to a random translation, each transformed copy is randomly
    flipped along image dimension 3 ("x") and image dimension 2 ("y").
    """

    def forward_transform(
        self, batch, translation, should_flip_x, should_flip_y
    ):
        """Translate the batch, then apply the x flip and the y flip."""
        shifted = image_translation(batch, translation)
        flipped_x = flip_transform(shifted, should_flip_x, dim=3)
        return flip_transform(flipped_x, should_flip_y, dim=2)

    def inverse_transform(
        self, x, translation, should_flip_x, should_flip_y
    ):
        """Undo the transformations on the predictions, in reverse order.

        The y flip maps to prediction column 0 and the x flip to prediction
        column 1; the translation is removed last.
        """
        unflipped_y = inverse_flip_transform(x, should_flip_y, dim=0)
        unflipped = inverse_flip_transform(unflipped_y, should_flip_x, dim=1)
        return inverse_translation(unflipped, translation)

    def random_arguments(self):
        """Draw a random translation and two flip flags per copy.

        Returns:
            A dict with keys "translation" (values in [-2.5, 2.5)),
            "should_flip_x" and "should_flip_y" (boolean tensors that are
            True where a uniform sample exceeds 0.5).
        """
        # The draws are kept in this order: translation, x flip, y flip.
        translation = (
            torch.rand(self.n_transforms, 2).float().to(self.device) * 5 - 2.5
        )
        should_flip_x = (
            torch.rand(self.n_transforms).float().to(self.device) > .5
        )
        should_flip_y = (
            torch.rand(self.n_transforms).float().to(self.device) > .5
        )
        return {
            "translation": translation,
            "should_flip_x": should_flip_x,
            "should_flip_y": should_flip_y,
        }

In [None]:
localizer_with_reflections = ParticleLocalizerWithReflections(
    model, n_transforms=8, loss=nn.L1Loss(), optimizer=dl.Adam(lr=1e-3)
).create()

trainer = dl.Trainer(max_epochs=100)
trainer.fit(localizer, dataloader)

In [None]:
predictions = localizer(images).detach().numpy() + image_size / 2 - 0.5

plot_position_comparison(positions, predictions)

## Improve Performance with LodeSTAR

We can improve the performance using the `LodeSTAR` model from `deeplay`.

In [None]:
# LodeSTAR model from deeplay, optimized with Adam (learning rate 1e-4).
lodestar = dl.LodeSTAR(optimizer=dl.Adam(lr=1e-4)).build()

# This training run uses batches of four images.
dataloader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
)

trainer = dl.Trainer(max_epochs=100)
trainer.fit(lodestar, dataloader)

In [None]:
# Predict the test positions with the `pooled` method of the trained
# LodeSTAR model, then convert the result to NumPy for plotting.
pooled_output = lodestar.pooled(images)
lodestar_predictions = pooled_output.detach().numpy()

plot_position_comparison(positions, lodestar_predictions)