# Denoising Images with an Encoder-Decoder

This notebook provides you with a complete code example that generates noisy brightfield microscopy images of particles, trains an encoder-decoder to denoise them, and visualizes the results.

## Generating the Data

Define a spherical particle ...

In [None]:
import deeptrack as dt
import numpy as np

# A single sphere fixed at pixel (32, 32), i.e. the centre of a 64x64 image,
# with a 500 nm radius and a complex refractive index.
particle = dt.Sphere(
    position=64 * np.array([0.5, 0.5]),
    position_unit="pixel",
    radius=500 * dt.units.nm,
    refractive_index=1.45 + 0.02j,
)

... define the microscope to observe the particle ...

In [None]:
# Brightfield optics used for every simulated image in this notebook.
# NOTE(review): output_region presumably crops the field of view to a
# 64x64 pixel window; confirm against the deeptrack Brightfield docs.
brightfield_microscope = dt.Brightfield(
    wavelength=500 * dt.units.nm,
    NA=1.0,
    resolution=1 * dt.units.um,
    magnification=10,
    refractive_index_medium=1.33,
    upsample=2,
    output_region=(0, 0, 64, 64),
)

... obtain the image of the microscopic particle ...

In [None]:
# Image the particle through the microscope. The result is a pipeline
# feature (it is chained with `>>` in later cells), not a concrete image;
# images are produced when a downstream pipeline is resolved.
illuminated_sample = brightfield_microscope(particle)

... simulate the clean image of the particle ...

In [None]:
import torch

# Noise-free target: min-max normalize the image, move axis 2 to the front
# so the layout matches the model input (channel first, presumably), and
# convert to a float32 torch tensor.
clean_particle = (
    illuminated_sample
    >> dt.NormalizeMinMax()
    >> dt.MoveAxis(2, 0)
    >> dt.pytorch.ToTensor(dtype=torch.float)
)

... simulate the noisy image of the particle ...

In [None]:
# Poisson noise with a signal-to-noise ratio drawn uniformly from [2, 3).
# The lambda is re-evaluated by the pipeline (presumably on every update),
# so each simulated image gets its own SNR.
noise = dt.Poisson(snr=lambda: np.random.rand() + 2.0)

# Noisy input: same optics as the clean target, with noise added before
# the identical normalize / axis-move / tensor conversion steps.
noisy_particle = (
    illuminated_sample
    >> noise
    >> dt.NormalizeMinMax()
    >> dt.MoveAxis(2, 0)
    >> dt.pytorch.ToTensor(dtype=torch.float)
)

... combine the noisy and clean particle images into a single simulation pipeline ...

In [None]:
# Combine both pipelines with `&`: resolving `pip` yields the pair
# (noisy input, clean target), as unpacked in the plotting loops below.
# Both branches start from `illuminated_sample`, so presumably each update
# renders the same particle for input and target.
pip = noisy_particle & clean_particle

... and plot a few noisy and corresponding clean particle images.

In [None]:
import matplotlib.pyplot as plt

def plot_image(title, image, *, cmap="gray", fontsize=30):
    """Display a single image in its own figure, with a title and no axes.

    Parameters
    ----------
    title : str
        Text shown above the image.
    image : array-like
        Anything ``plt.imshow`` accepts (e.g. a 2D array or CPU tensor).
    cmap : str, optional
        Colormap used for display. Defaults to grayscale, as before.
    fontsize : int, optional
        Font size of the title. Defaults to 30, as before.
    """
    plt.imshow(image, cmap=cmap)
    plt.title(title, fontsize=fontsize)
    plt.axis("off")
    plt.show()

In [None]:
# Simulate five fresh noisy/clean pairs and show each one. MoveAxis(2, 0)
# moved the last axis to the front, so permute(1, 2, 0) moves it back to
# the end for imshow. The names avoid shadowing the builtin `input`.
for i in range(5):
    noisy_image, clean_image = pip.update().resolve()
    plot_image(f"Input Image {i}", noisy_image.permute(1, 2, 0))
    plot_image(f"Target Image {i}", clean_image.permute(1, 2, 0))

## Creating a Dataset

Define a class representing a simulated dataset that generates the training data for the denoising encoder-decoder ...

In [None]:
class SimulatedDataset(torch.utils.data.Dataset):
    """Buffered dataset of (noisy, clean) image pairs from a simulation pipeline.

    A fixed number of pairs is simulated up front. Whenever an item is
    requested, it is first re-simulated with probability ``replace``, so the
    buffer is gradually refreshed during training.
    """

    def __init__(self, pip, buffer_size, replace=0):
        """Store the settings and pre-fill the buffer.

        Parameters
        ----------
        pip : pipeline
            Object whose ``update().resolve()`` returns an indexable pair with
            the noisy image at index 0 and the clean image at index 1.
        buffer_size : int
            Number of pairs kept in memory; also the dataset length.
        replace : float, optional
            Probability that a requested pair is re-simulated before it is
            returned. The default 0 never refreshes the buffer.
        """
        self.pip = pip
        self.buffer_size = buffer_size
        self.replace = replace
        self.images = [self._simulate() for _ in range(buffer_size)]

    def _simulate(self):
        """Draw one new pair from the pipeline."""
        return self.pip.update().resolve()

    def __len__(self):
        """Return the number of buffered pairs."""
        return self.buffer_size

    def __getitem__(self, idx):
        """Return ``(noisy, clean)`` at ``idx``, possibly refreshing it first."""
        refresh = np.random.rand() < self.replace
        if refresh:
            self.images[idx] = self._simulate()
        pair = self.images[idx]
        noisy, clean = pair[0], pair[1]
        return noisy, clean

... and create the dataset and the data loader.

In [None]:
# Keep 256 simulated pairs in memory and re-simulate each requested pair
# with probability 0.1; serve them shuffled in batches of 8.
dataset = SimulatedDataset(pip, buffer_size=256, replace=0.1)
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=8,
    shuffle=True,
)

## Defining and Training the Encoder-Decoder

Define the encoder-decoder ...

In [None]:
import deeplay as dl

# Single-channel in, single-channel out (grayscale images). The
# encoder_channels list presumably gives one entry per encoder stage.
encoderdecoder = dl.ConvolutionalEncoderDecoder2d(
    in_channels=1,
    encoder_channels=[16, 16],
    out_channels=1,
)

print(encoderdecoder)

... compile it ...

In [None]:
# Wrap the network as a regressor trained with an L1 (mean absolute error)
# loss and Adam. The template is kept so fresh models can be created from
# it later; `ed` is the concrete model built from it.
regressor_template = dl.Regressor(
    model=encoderdecoder,
    loss=torch.nn.L1Loss(),
    optimizer=dl.Adam(),
)
ed = regressor_template.create()

print(ed)

... and train it.

In [None]:
# Train the encoder-decoder for 150 epochs on the buffered simulated data.
# NOTE(review): accelerator="auto" presumably selects a GPU when one is
# available (Lightning-style Trainer); confirm against the deeplay docs.
ed_trainer = dl.Trainer(max_epochs=150, accelerator="auto")
ed_trainer.fit(ed, loader)

## Testing the Trained Encoder-Decoder

In [None]:
# Denoise five freshly simulated images and compare input, target and
# prediction. The forward pass runs under torch.no_grad() so no autograd
# graph is recorded (the original built one and then discarded it with
# .detach()). Names avoid shadowing the builtin `input`.
# NOTE(review): consider ed.eval() first if the model has dropout or
# batch-norm layers.
for i in range(5):
    noisy_image, clean_image = pip.update().resolve()
    with torch.no_grad():
        # unsqueeze(0) adds the batch dimension the model expects.
        predicted = ed(noisy_image.unsqueeze(0))

    plot_image(f"Input Image {i}", noisy_image[0, :, :])
    plot_image(f"Target Image {i}", clean_image[0, :, :])
    plot_image(f"Predicted Image {i}", predicted[0, 0, :, :])

## Checking for the Absence of Mode Collapse

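Apply the trained encoder-decoder to blank images that contain no particle. This verifies that the network has not collapsed to always outputting a particle regardless of its input.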

In [None]:
# `particle ^ 0` is used to image a sample without the particle, giving
# blank frames that go through the same noise and preprocessing as the
# training inputs.
blank = brightfield_microscope(particle ^ 0)
blank_pip = (
    blank
    >> noise
    >> dt.NormalizeMinMax()
    >> dt.MoveAxis(2, 0)
    >> dt.pytorch.ToTensor(dtype=torch.float)
)

for i in range(5):
    blank_image = blank_pip.update().resolve()
    # Inference only: skip autograd bookkeeping instead of detaching after.
    with torch.no_grad():
        blank_predicted = ed(blank_image.unsqueeze(0))
    plot_image(f"Input Image {i}", blank_image[0, :, :])
    # NOTE(review): unlike the other prediction plots, this one is squared
    # before display, presumably to accentuate faint structure. Confirm
    # this is intended, since the title does not mention it.
    plot_image(f"Predicted Image {i}", np.square(blank_predicted[0, 0, :, :]))

## Checking Generalization Capabilities

Define a pipeline that simulates particles with varying positions and radii ...

In [None]:
# Particles whose position is drawn uniformly within the central 60% of the
# 64-pixel frame (pixels 12.8 to 51.2 on each axis) and whose radius is
# drawn uniformly from [500 nm, 1000 nm).
diverse_particle = dt.Sphere(
    position=lambda: 64 * np.array(0.2 + 0.6 * np.random.rand(2)),
    radius=lambda: 500 * dt.units.nm * (1 + np.random.rand()),
    position_unit="pixel",
    refractive_index=1.45 + 0.02j,
)
diverse_illuminated_sample = brightfield_microscope(diverse_particle)

# Clean target and noisy input, preprocessed exactly like the original pair.
diverse_clean_particle = (
    diverse_illuminated_sample
    >> dt.NormalizeMinMax()
    >> dt.MoveAxis(2, 0)
    >> dt.pytorch.ToTensor(dtype=torch.float)
)
diverse_noisy_particle = (
    diverse_illuminated_sample
    >> noise
    >> dt.NormalizeMinMax()
    >> dt.MoveAxis(2, 0)
    >> dt.pytorch.ToTensor(dtype=torch.float)
)
diverse_pip = diverse_noisy_particle & diverse_clean_particle

... and denoise images of diverse particles using the trained encoder-decoder.

In [None]:
# Apply the model trained on centred, fixed-size particles to particles
# with random position and radius. Inference runs under torch.no_grad() so
# no autograd graph is recorded for these forward passes.
for i in range(5):
    diverse_input, diverse_target = diverse_pip.update().resolve()
    with torch.no_grad():
        # unsqueeze(0) adds the batch dimension the model expects.
        diverse_predicted = ed(diverse_input.unsqueeze(0))

    plot_image(f"Input Image {i}", diverse_input[0, :, :])
    plot_image(f"Target Image {i}", diverse_target[0, :, :])
    plot_image(f"Predicted Image {i}", diverse_predicted[0, 0, :, :])

## Improving the Training

Train a new encoder-decoder on a dataset of particles with varying positions and radii ...

In [None]:
# Same buffering and batching as the first training run, but drawing from
# the diverse-particle pipeline.
diverse_dataset = SimulatedDataset(diverse_pip, buffer_size=256, replace=0.1)
diverse_loader = torch.utils.data.DataLoader(
    dataset=diverse_dataset,
    batch_size=8,
    shuffle=True,
)

# Build a separate model from the same template (presumably a new, untrained
# instance rather than a copy of `ed`) and train it for 150 epochs.
diverse_ed = regressor_template.create()
diverse_ed_trainer = dl.Trainer(max_epochs=150, accelerator="auto")
diverse_ed_trainer.fit(diverse_ed, diverse_loader)


... and plot the images obtained with the improved training.

In [None]:
# Evaluate the model trained on diverse particles on new diverse samples.
# Inference runs under torch.no_grad() so no autograd graph is recorded
# for these forward passes.
for i in range(5):
    diverse_input, diverse_target = diverse_pip.update().resolve()
    with torch.no_grad():
        # unsqueeze(0) adds the batch dimension the model expects.
        diverse_predicted = diverse_ed(diverse_input.unsqueeze(0))

    plot_image(f"Input Image {i}", diverse_input[0, :, :])
    plot_image(f"Target Image {i}", diverse_target[0, :, :])
    plot_image(f"Predicted Image {i}", diverse_predicted[0, 0, :, :])