# Image Denoising Tutorial

## Introduction to Denoising

Welcome to this tutorial on image denoising! In this notebook, we will explore what noise is, why it's a problem in imaging, and various methods to remove it. We will cover classic denoising techniques and then dive into a powerful deep learning-based method called Noise2Void.

### What is Noise?

In the context of image processing, **noise** refers to random variations of brightness or color information in an image. It is an undesirable byproduct of image capture and transmission. Noise can be caused by various factors, such as low light conditions, sensor heat, or electronic interference during transmission.

### Why is Denoising Important?

Denoising is a crucial step in many image processing pipelines. Removing noise can:

*   Improve the visual quality of an image.
*   Enhance the performance of subsequent image processing tasks, such as object detection, image segmentation, and feature extraction.

## Simulating Noise

Before we can denoise an image, we need a noisy image to work with. In this section, we'll learn how to add different types of noise to a clean image. Adding synthetic noise is common practice when evaluating denoising algorithms, because the clean original then serves as the ground truth against which each denoised result can be compared.
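
Because we keep the clean image, every denoiser in this tutorial can also be scored numerically. A standard metric is the peak signal-to-noise ratio, $\text{PSNR} = 10 \log_{10}(\text{MAX}^2 / \text{MSE})$, where MAX is the largest possible intensity (1.0 for float images in [0, 1]). The minimal NumPy sketch below computes it; the name `psnr` and the random test image are illustrative, and scikit-image provides the same metric as `skimage.metrics.peak_signal_noise_ratio`.

```python
import numpy as np

def psnr(clean, test, data_range=1.0):
    """Peak signal-to-noise ratio in dB between a clean reference and a test image."""
    clean = np.asarray(clean, dtype=np.float64)
    test = np.asarray(test, dtype=np.float64)
    mse = np.mean((clean - test) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * np.log10(data_range ** 2 / mse)

# Illustrative data: a random "clean" image and a noisy, clipped copy of it
rng = np.random.default_rng(0)
clean = rng.random((64, 64))
noisy = np.clip(clean + rng.normal(0.0, 0.1, clean.shape), 0.0, 1.0)
print(f"PSNR of noisy image: {psnr(clean, noisy):.2f} dB")
```

Higher is better. For example, `psnr(image, denoised_median)` would score the median-filter result against the clean cameraman image.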

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from skimage import data, img_as_float
from skimage.util import random_noise

# For this tutorial, we use the classic "cameraman" test image bundled with scikit-image
image = data.camera()

# Convert from uint8 [0, 255] to float [0, 1], the range random_noise and the filters below expect
image = img_as_float(image)

### Types of Noise

#### 1. Gaussian Noise

Gaussian noise is additive noise whose values are drawn independently for each pixel from a normal (Gaussian) distribution. It is a common model for electronic sensor noise and is especially noticeable in images taken under low-light conditions.

We can control the amount of Gaussian noise with the `var` parameter, the variance of the noise distribution on the image's [0, 1] intensity scale. A higher variance produces more noise; the noise standard deviation is `sqrt(var)`.
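
The sketch below approximates what `random_noise(mode='gaussian')` does for a float image in [0, 1]: add independent normal noise with the given mean and variance, then clip back to the valid range. It is an illustrative re-implementation (`gaussian_noise_sketch` is a hypothetical name), not scikit-image's source, and it mainly shows that the noise standard deviation is the square root of `var`.

```python
import numpy as np

def gaussian_noise_sketch(image, mean=0.0, var=0.01, rng=None):
    """Add i.i.d. Gaussian noise N(mean, var) to a float image in [0, 1], then clip."""
    rng = np.random.default_rng(rng)
    sigma = np.sqrt(var)  # `var` is a variance, so the standard deviation is its square root
    noise = rng.normal(loc=mean, scale=sigma, size=image.shape)
    return np.clip(image + noise, 0.0, 1.0)

clean = np.full((256, 256), 0.5)  # mid-gray image, so clipping almost never triggers
noisy = gaussian_noise_sketch(clean, var=0.01, rng=0)
print(f"expected std: {np.sqrt(0.01):.3f}, measured std: {np.std(noisy - clean):.3f}")
```

With `var=0.05`, as used in the next cell, the standard deviation is √0.05 ≈ 0.224, which is heavy noise on a [0, 1] scale. Clipping slightly reduces the effective noise in very dark and very bright regions.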

In [None]:
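def add_gaussian_noise(image, mean=0, var=0.01):
    """Adds Gaussian noise with the given mean and variance, then clips to [0, 1]."""
    # `seed` is omitted: newer scikit-image versions replaced it with `rng`
    return random_noise(image, mode='gaussian', clip=True, mean=mean, var=var)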

gaussian_noisy_image = add_gaussian_noise(image, var=0.05)

plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.title("Original Image")
plt.imshow(image, cmap='gray')
plt.axis('off')

plt.subplot(1, 2, 2)
plt.title("Image with Gaussian Noise")
plt.imshow(gaussian_noisy_image, cmap='gray')
plt.axis('off')
plt.show()

#### 2. Salt & Pepper Noise

Salt and pepper noise, also known as impulse noise, is a type of noise that presents itself as sparsely occurring white and black pixels. An image containing salt-and-pepper noise will have dark pixels in bright regions and bright pixels in dark regions.

We can control the amount of salt and pepper noise by adjusting the `amount` parameter, which represents the proportion of pixels to be affected.
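
A minimal NumPy sketch of salt-and-pepper noise, assuming the same meaning of `amount` and of the `salt_vs_pepper` ratio as in scikit-image's `random_noise`. scikit-image samples pixels randomly, so its exact counts vary from run to run; this hypothetical `salt_and_pepper_sketch` corrupts exactly `amount` × (number of pixels) distinct pixels, which makes the effect easy to check.

```python
import numpy as np

def salt_and_pepper_sketch(image, amount=0.05, salt_vs_pepper=0.5, rng=None):
    """Replace a fraction `amount` of pixels with 1.0 (salt) or 0.0 (pepper)."""
    rng = np.random.default_rng(rng)
    out = image.copy()
    n_total = image.size
    n_noisy = int(round(amount * n_total))
    # Pick distinct pixel positions to corrupt
    flat_idx = rng.choice(n_total, size=n_noisy, replace=False)
    n_salt = int(round(salt_vs_pepper * n_noisy))
    out.flat[flat_idx[:n_salt]] = 1.0  # salt: white pixels
    out.flat[flat_idx[n_salt:]] = 0.0  # pepper: black pixels
    return out

clean = np.full((100, 100), 0.5)
noisy = salt_and_pepper_sketch(clean, amount=0.1, rng=0)
print("fraction of corrupted pixels:", np.mean(noisy != clean))
```

Unlike Gaussian noise, every affected pixel jumps to an extreme value while all other pixels stay untouched, which is why a filter that ignores outliers works well against it.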

In [None]:
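def add_salt_and_pepper_noise(image, amount=0.05):
    """Replaces a fraction `amount` of pixels with white (salt) or black (pepper) values."""
    # `seed` is omitted: newer scikit-image versions replaced it with `rng`
    return random_noise(image, mode='s&p', clip=True, amount=amount)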

sp_noisy_image = add_salt_and_pepper_noise(image, amount=0.1)

plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.title("Original Image")
plt.imshow(image, cmap='gray')
plt.axis('off')

plt.subplot(1, 2, 2)
plt.title("Image with Salt & Pepper Noise")
plt.imshow(sp_noisy_image, cmap='gray')
plt.axis('off')
plt.show()

## Classical Denoising Methods

Now that we know how to add noise, let's explore some classical methods for removing it.

### 1. Gaussian Filter

A Gaussian filter is a linear filter that is widely used for blurring images and removing noise. It works by convolving the image with a Gaussian kernel. The standard deviation (`sigma`) of the Gaussian kernel is a parameter that controls the amount of blurring. A larger `sigma` will result in more blurring and more noise removal, but it can also lead to a loss of image details.
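
To make the convolution concrete, here is a minimal separable Gaussian blur in NumPy. It samples a 1D Gaussian, normalizes the weights to sum to 1 (so average brightness is unchanged), and filters rows and then columns, which is equivalent to convolving with the 2D Gaussian kernel because that kernel factors into two 1D Gaussians. The kernel radius of `4 * sigma` and edge replication at the borders are assumptions chosen to resemble scikit-image's defaults; in practice, use `skimage.filters.gaussian`.

```python
import numpy as np

def gaussian_kernel_1d(sigma, truncate=4.0):
    """Sampled, normalized 1D Gaussian kernel with radius ceil(truncate * sigma)."""
    radius = int(np.ceil(truncate * sigma))
    x = np.arange(-radius, radius + 1)
    kernel = np.exp(-(x ** 2) / (2.0 * sigma ** 2))
    return kernel / kernel.sum()  # weights sum to 1

def gaussian_blur_sketch(image, sigma=1.0):
    """Separable 2D Gaussian blur: filter rows, then columns (borders padded by edge replication)."""
    k = gaussian_kernel_1d(sigma)
    r = len(k) // 2
    padded = np.pad(image, r, mode="edge")
    rows = np.apply_along_axis(lambda v: np.convolve(v, k, mode="valid"), 1, padded)
    return np.apply_along_axis(lambda v: np.convolve(v, k, mode="valid"), 0, rows)

rng = np.random.default_rng(0)
noisy = 0.5 + rng.normal(0.0, 0.2, (128, 128))
smoothed = gaussian_blur_sketch(noisy, sigma=1.0)
print(f"noise std before: {noisy.std():.3f}, after: {smoothed.std():.3f}")
```

For independent noise, each output pixel is a weighted average of many noisy values, so the noise standard deviation drops sharply. The same averaging also smooths edges, which is the loss of detail described above.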

In [None]:
from skimage.filters import gaussian

denoised_gaussian = gaussian(gaussian_noisy_image, sigma=1)

plt.figure(figsize=(15, 5))
plt.subplot(1, 3, 1)
plt.title("Original Image")
plt.imshow(image, cmap='gray')
plt.axis('off')

plt.subplot(1, 3, 2)
plt.title("Noisy Image")
plt.imshow(gaussian_noisy_image, cmap='gray')
plt.axis('off')

plt.subplot(1, 3, 3)
plt.title("Denoised with Gaussian Filter")
plt.imshow(denoised_gaussian, cmap='gray')
plt.axis('off')
plt.show()

### 2. Median Filter

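The median filter is a non-linear filter that replaces each pixel's value with the median of the values in a neighborhood around it; here the neighborhood (footprint) is `disk(3)`, a disk of radius 3 pixels. Because an isolated outlier is very unlikely to be the median of its neighborhood, the median filter is particularly effective at removing salt-and-pepper noise, and it preserves edges better than a linear blur of similar size.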
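
The following sketch implements a square-window median filter with NumPy's `sliding_window_view` (NumPy 1.20 or later) and applies it to a flat gray image containing one salt and one pepper pixel. It uses a square window with reflective padding, whereas the cell below uses scikit-image with a disk-shaped footprint; `median_filter_sketch` is a hypothetical name.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def median_filter_sketch(image, size=3):
    """Replace each pixel by the median of the size x size window centred on it (size must be odd)."""
    r = size // 2
    padded = np.pad(image, r, mode="reflect")
    windows = sliding_window_view(padded, (size, size))  # shape (H, W, size, size)
    return np.median(windows, axis=(-2, -1))

img = np.full((7, 7), 0.5)
img[3, 3] = 1.0  # isolated "salt" pixel
img[1, 5] = 0.0  # isolated "pepper" pixel
filtered = median_filter_sketch(img, size=3)
print(filtered[3, 3], filtered[1, 5])  # 0.5 0.5
```

Both outliers are replaced by 0.5. A 3 × 3 mean filter would instead leave (1.0 + 8 × 0.5) / 9 ≈ 0.556 at the salt pixel and spread the error to its neighbors.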

In [None]:
from skimage.filters import median
from skimage.morphology import disk

denoised_median = median(sp_noisy_image, disk(3))

plt.figure(figsize=(15, 5))
plt.subplot(1, 3, 1)
plt.title("Original Image")
plt.imshow(image, cmap='gray')
plt.axis('off')

plt.subplot(1, 3, 2)
plt.title("Noisy Image (S&P)")
plt.imshow(sp_noisy_image, cmap='gray')
plt.axis('off')

plt.subplot(1, 3, 3)
plt.title("Denoised with Median Filter")
plt.imshow(denoised_median, cmap='gray')
plt.axis('off')
plt.show()

### 3. BM3D (Block-matching and 3D filtering)

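BM3D is a more advanced method and is widely regarded as one of the strongest classical (non-learned) denoisers. It works by finding patches similar to a reference patch (block matching), stacking them into a 3D group, jointly filtering each group in a transform domain (collaborative filtering), and then returning the filtered patches to their original locations, where overlapping estimates are combined by weighted averaging (aggregation). It is designed for additive Gaussian noise and needs the noise standard deviation as an input; for our simulated noise, that is `sqrt(var)`.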
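
BM3D itself is too long to reproduce here, but its first step, block matching, is easy to sketch. The hypothetical `block_match` function below scores every candidate patch within a search window by its sum of squared differences (SSD) to a reference patch and keeps the `k` closest; `group` stacks them into the 3D array that BM3D would then filter collaboratively. The patch size, search radius, and plain SSD criterion are simplifying assumptions; the full algorithm uses more refined matching and filtering steps that are omitted here.

```python
import numpy as np

def block_match(image, ref_yx, patch=8, search=16, k=4):
    """Top-left corners of the k patches most similar (lowest SSD) to the reference patch."""
    y0, x0 = ref_yx
    ref = image[y0:y0 + patch, x0:x0 + patch]
    H, W = image.shape
    candidates = []
    for y in range(max(0, y0 - search), min(H - patch, y0 + search) + 1):
        for x in range(max(0, x0 - search), min(W - patch, x0 + search) + 1):
            block = image[y:y + patch, x:x + patch]
            candidates.append((float(np.sum((block - ref) ** 2)), (y, x)))
    candidates.sort(key=lambda c: c[0])  # most similar first; the reference itself has SSD 0
    return [yx for _, yx in candidates[:k]]

def group(image, coords, patch=8):
    """Stack the matched patches into a 3D array of shape (k, patch, patch)."""
    return np.stack([image[y:y + patch, x:x + patch] for y, x in coords])

# Vertical stripes, 4 pixels wide (period 8), plus mild Gaussian noise
rng = np.random.default_rng(0)
yy, xx = np.mgrid[0:48, 0:48]
clean = ((xx // 4) % 2).astype(float)
noisy = clean + rng.normal(0.0, 0.1, clean.shape)

coords = block_match(noisy, (16, 16), patch=8, search=16, k=4)
stack = group(noisy, coords)
print(coords[0], stack.shape)  # (16, 16) (4, 8, 8)
```

On this striped image, every match sits at a horizontal offset that is a multiple of the 8-pixel stripe period, because those patches differ from the reference only by noise. Filtering such a group lets BM3D average out the noise while keeping the repeated structure.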

*Note: You might need to install the `bm3d` package. You can do this by running `pip install bm3d` in your terminal.*

In [None]:
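try:
    import bm3d
    # sigma_psd is the noise standard deviation: sqrt(var) = sqrt(0.05) ≈ 0.224 for our noisy image
    denoised_bm3d = bm3d.bm3d(gaussian_noisy_image, sigma_psd=np.sqrt(0.05), stage_arg=bm3d.BM3DStages.ALL_STAGES)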

    plt.figure(figsize=(15, 5))
    plt.subplot(1, 3, 1)
    plt.title("Original Image")
    plt.imshow(image, cmap='gray')
    plt.axis('off')
    
    plt.subplot(1, 3, 2)
    plt.title("Noisy Image")
    plt.imshow(gaussian_noisy_image, cmap='gray')
    plt.axis('off')
    
    plt.subplot(1, 3, 3)
    plt.title("Denoised with BM3D")
    plt.imshow(denoised_bm3d, cmap='gray')
    plt.axis('off')
    plt.show()
except ImportError:
    print("BM3D is not installed. Please run 'pip install bm3d' to use this feature.")

## Denoising with Noise2Void

Next, we explore a deep learning-based denoising method called **Noise2Void**. Supervised deep learning denoisers are trained on pairs of noisy and clean images; Noise2Void instead learns to denoise directly from noisy images, so it needs no clean, noise-free training data.

### How Noise2Void Works (A Brief Explanation)

The core idea behind Noise2Void is to train a neural network to predict each pixel's value from its neighborhood, but *without* seeing that pixel itself. In the code below, this "blind spot" is created by masking: randomly chosen pixels are replaced with the value of a random neighbor, and the loss is computed only at those pixels. Noise2Void assumes that the noise is zero-mean and independent from pixel to pixel, while the underlying image structure is correlated across neighboring pixels. The network can therefore learn to predict the signal from the neighborhood, but it cannot predict the noise at the hidden pixel, so its best prediction is an estimate of the clean value.
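
The masked training objective used in the cells below can be written out explicitly. Let $y$ be the noisy image, $\tilde{y}$ the masked input produced by `generate_mask`, $m_i \in \{0, 1\}$ the mask value at pixel $i$ (0 where the pixel was replaced), $f_\theta$ the network, and $N$ the number of elements in the single-image, single-channel tensor. With PyTorch's `nn.L1Loss` and its default mean reduction, the loss is

$$
\mathcal{L}(\theta) = \frac{1}{N} \sum_{i=1}^{N} (1 - m_i)\,\bigl|f_\theta(\tilde{y})_i - y_i\bigr|
$$

Only replaced pixels contribute, and for those the network never saw $y_i$ in its input. Writing the noisy image as $y_i = x_i + n_i$, with the noise zero-mean and independent across pixels given the clean signal $x$, the expected value of $y_i$ given its neighborhood $y_{\mathcal{N}(i)}$ (which excludes pixel $i$) is

$$
\mathbb{E}\bigl[y_i \mid y_{\mathcal{N}(i)}\bigr] = \mathbb{E}\bigl[x_i \mid y_{\mathcal{N}(i)}\bigr] + \underbrace{\mathbb{E}\bigl[n_i \mid y_{\mathcal{N}(i)}\bigr]}_{=\,0} = \mathbb{E}\bigl[x_i \mid y_{\mathcal{N}(i)}\bigr]
$$

so the noise term cannot be predicted from the neighborhood. With a squared-error loss, the optimal prediction is exactly this conditional mean of the clean signal; the L1 loss used in this tutorial targets the conditional median instead.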

### Using the Noise2Void Code

Now, let's use the `noise2void` code in this repository to denoise an image. We will need to:

1.  **Prepare the data:** We'll create masked images for training.
2.  **Create a Noise2Void model:** We'll use the `ResNet` model from `model.py`.
3.  **Train the model:** We'll train the model on our noisy data.
4.  **Predict (Denoise):** We'll use the trained model to denoise the same noisy image it was trained on.

In [None]:
import torch
import torch.nn as nn
from torch.optim import Adam
from noise2void.model import ResNet
from noise2void.dataset import ToTensor

# Let's use the Gaussian noisy image we created earlier
noisy_image = gaussian_noisy_image.copy()

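# The Noise2Void dataset implementation uses a function to generate a mask.
# We replicate that functionality here, with one change: a pixel is never "replaced" by itself.
def generate_mask(input, ratio=0.9, size_window=(5, 5)):
    """Replace a random (1 - ratio) fraction of pixels in each channel with a random neighbor's value.

    `input` has shape (H, W, C). Returns the masked image and a mask that is 0.0 at the
    replaced pixels and 1.0 elsewhere.
    """
    size_data = input.shape
    num_sample = int(size_data[0] * size_data[1] * (1 - ratio))

    mask = np.ones(size_data)
    output = input.copy()

    # Neighbor offsets are drawn from [-r, r], with r = size // 2 for odd window sizes
    low_y = -size_window[0] // 2 + size_window[0] % 2
    high_y = size_window[0] // 2 + size_window[0] % 2
    low_x = -size_window[1] // 2 + size_window[1] % 2
    high_x = size_window[1] // 2 + size_window[1] % 2

    for ich in range(size_data[2]):
        # Randomly chosen pixels to mask (positions may repeat)
        idy_msk = np.random.randint(0, size_data[0], num_sample)
        idx_msk = np.random.randint(0, size_data[1], num_sample)

        idy_neigh = np.random.randint(low_y, high_y, num_sample)
        idx_neigh = np.random.randint(low_x, high_x, num_sample)

        # A (0, 0) offset would copy the pixel onto itself and leave no blind spot, so redraw it
        same = (idy_neigh == 0) & (idx_neigh == 0)
        while same.any():
            idy_neigh[same] = np.random.randint(low_y, high_y, same.sum())
            idx_neigh[same] = np.random.randint(low_x, high_x, same.sum())
            same = (idy_neigh == 0) & (idx_neigh == 0)

        # Wrap neighbor coordinates that fall outside the image (periodic boundary)
        idy_msk_neigh = (idy_msk + idy_neigh) % size_data[0]
        idx_msk_neigh = (idx_msk + idx_neigh) % size_data[1]

        id_msk = (idy_msk, idx_msk, ich)
        id_msk_neigh = (idy_msk_neigh, idx_msk_neigh, ich)

        output[id_msk] = input[id_msk_neigh]
        mask[id_msk] = 0.0

    return output, mask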

# Prepare the data
if noisy_image.ndim == 2:
    noisy_image = np.expand_dims(noisy_image, axis=2)

input_image, mask = generate_mask(noisy_image)

# The ToTensor transform from noise2void.dataset takes a dictionary of arrays
data = {'input': input_image, 'label': noisy_image, 'mask': mask}

# We will use the ToTensor transform to convert the numpy arrays to tensors
to_tensor = ToTensor()
tensor_data = to_tensor(data)

# The input, label, and mask are now tensors
input_tensor = tensor_data['input'].unsqueeze(0) # Add a batch dimension
label_tensor = tensor_data['label'].unsqueeze(0) # Add a batch dimension
mask_tensor = tensor_data['mask'].unsqueeze(0)   # Add a batch dimension

# Display the masked image
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.title("Original Noisy Image")
plt.imshow(noisy_image.squeeze(), cmap='gray')
plt.axis('off')

plt.subplot(1, 2, 2)
plt.title("Masked Input for N2V")
plt.imshow(input_image.squeeze(), cmap='gray')
plt.axis('off')
plt.show()
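
The tensor conversion above relies on `ToTensor` from `noise2void.dataset`, whose source is not shown here. Assuming it follows the usual PyTorch convention, it turns each (H, W, C) NumPy array into a (C, H, W) float tensor, and `unsqueeze(0)` then adds the batch dimension. The NumPy sketch below mimics that layout change and its inverse (the `squeeze` used before plotting); the helper names are hypothetical.

```python
import numpy as np

def hwc_to_nchw(array):
    """(H, W, C) NumPy image -> (1, C, H, W) float32 array, i.e. PyTorch's NCHW layout."""
    chw = np.transpose(array, (2, 0, 1)).astype(np.float32)  # channels first
    return chw[np.newaxis, ...]                               # add batch dimension

def nchw_to_hw(array):
    """(1, 1, H, W) -> (H, W), ready for plt.imshow."""
    return np.squeeze(array, axis=(0, 1))

img = np.random.default_rng(0).random((512, 512, 1))
batch = hwc_to_nchw(img)
print(batch.shape, nchw_to_hw(batch).shape)  # (1, 1, 512, 512) (512, 512)
```

Keeping the channel axis explicit, even for a grayscale image, is what lets the same code path handle `nch_in=1` here and multi-channel images elsewhere.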

### Training the Noise2Void Model

In [None]:
# Set up the model, loss, and optimizer
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The model is a ResNet, as used in the original repository
netG = ResNet(nch_in=1, nch_out=1, nch_ker=64, norm='bnorm').to(device)

# The loss function is L1 Loss
loss_fn = nn.L1Loss().to(device)

# The optimizer is Adam
optimizer = Adam(netG.parameters(), lr=1e-3)

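# Move the unmasked noisy image to the selected device (it is also used for prediction later)
label_tensor = label_tensor.to(device)

# Training loop: each iteration is one optimization step on the full image
num_epochs = 100
for epoch in range(num_epochs):
    netG.train()

    # Draw a fresh random mask every iteration so the blind-spot pixels keep changing;
    # with one fixed mask, the network would only ever be supervised on the same pixels
    input_image, mask = generate_mask(noisy_image)
    batch = to_tensor({'input': input_image, 'label': noisy_image, 'mask': mask})
    input_tensor = batch['input'].unsqueeze(0).to(device)
    mask_tensor = batch['mask'].unsqueeze(0).to(device)

    # Forward pass
    output_tensor = netG(input_tensor)

    # Calculate the loss only on the masked pixels (mask == 0 marks a replaced pixel)
    loss = loss_fn(output_tensor * (1 - mask_tensor), label_tensor * (1 - mask_tensor))

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

print("Training finished!")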

### Denoising with the Trained Model

In [None]:
# Denoise the image
netG.eval()
with torch.no_grad():
    # For denoising, we use the original noisy image as input
    denoised_tensor = netG(label_tensor) # Use the unmasked noisy image

# Move the denoised tensor to the CPU and convert to a numpy array
denoised_image = denoised_tensor.squeeze().cpu().numpy()

# Display the results
plt.figure(figsize=(15, 5))
plt.subplot(1, 3, 1)
plt.title("Original Image")
plt.imshow(image.squeeze(), cmap='gray')
plt.axis('off')

plt.subplot(1, 3, 2)
plt.title("Noisy Image")
plt.imshow(noisy_image.squeeze(), cmap='gray')
plt.axis('off')

plt.subplot(1, 3, 3)
plt.title("Denoised with Noise2Void")
plt.imshow(denoised_image, cmap='gray')
plt.axis('off')
plt.show()
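
Here the clean image is available for comparison, but with real data it usually is not. One common sanity check is the residual, noisy minus denoised: if the denoiser removed mostly noise, and the noise is pixel-wise independent as Noise2Void assumes, the residual should have a standard deviation close to the noise level and almost no correlation between neighboring pixels. Visible structure in the residual means image content was removed too. The sketch below computes both quantities with NumPy on a synthetic ramp image; `residual_diagnostics` is a hypothetical helper, not part of the repository.

```python
import numpy as np

def residual_diagnostics(noisy, denoised):
    """Return (std, lag-1 horizontal correlation) of the residual noisy - denoised."""
    residual = np.asarray(noisy, dtype=float) - np.asarray(denoised, dtype=float)
    # Correlation between each pixel and its right-hand neighbor:
    # near 0 for white noise, large if smooth image structure ended up in the residual
    corr = np.corrcoef(residual[:, :-1].ravel(), residual[:, 1:].ravel())[0, 1]
    return residual.std(), corr

rng = np.random.default_rng(0)
yy, xx = np.mgrid[0:128, 0:128]
clean = xx / 127.0                        # smooth horizontal ramp
noise = rng.normal(0.0, 0.1, clean.shape)
noisy = clean + noise
good = clean                              # ideal denoiser: removes only the noise
bad = np.full_like(clean, clean.mean())   # over-smoothing: also removes the ramp
print(residual_diagnostics(noisy, good))  # std close to 0.1, correlation close to 0
print(residual_diagnostics(noisy, bad))   # larger std, strongly positive correlation
```

Applied to this notebook, `residual_diagnostics(noisy_image.squeeze(), denoised_image)` would report the residual's standard deviation (roughly √0.05 ≈ 0.224 before clipping, if only noise was removed) and its neighbor correlation.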