# Image Denoising Tutorial

## Introduction to Denoising

Welcome to this tutorial on image denoising! In this notebook, we will explore what noise is, why it's a problem in imaging, and various methods to remove it. We will cover classic denoising techniques and then dive into a powerful deep learning-based method called Noise2Void.

### What is Noise?

In the context of image processing, **noise** refers to random variations of brightness or color information in an image. It is an undesirable byproduct of image capture and transmission. Noise can be caused by various factors, such as low light conditions, sensor heat, or electronic interference during transmission.

### Why is Denoising Important?

Denoising is a crucial step in many image processing pipelines. Removing noise can:

*   Improve the visual quality of an image.
*   Enhance the performance of subsequent image processing tasks, such as object detection, image segmentation, and feature extraction.

## Simulating Noise

Before we can denoise an image, we need a noisy image to work with. In this section, we'll learn how to add different types of noise to a clean image. Adding synthetic noise is common practice when evaluating denoising algorithms, because the clean original then serves as ground truth for measuring how well the noise was removed.
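scikit-image can also generate both noise types in this section with `skimage.util.random_noise`, which is handy for cross-checking hand-written noise functions. Note that its Gaussian mode takes the variance `var`, not a standard deviation, so `var=0.04` corresponds to a standard deviation of 0.2. A minimal sketch on a flat grey test image:

```python
import numpy as np
from skimage.util import random_noise

flat = np.full((256, 256), 0.5)  # flat grey test image in [0, 1]

# Gaussian mode takes the variance: var=0.04 means std=0.2
g = random_noise(flat, mode='gaussian', mean=0.0, var=0.04)
# Salt & pepper mode: amount is the fraction of pixels set to 0 or 1
sp = random_noise(flat, mode='s&p', amount=0.1)

# Close to 0.2 (slightly lower, because the result is clipped to [0, 1])
print("std of added Gaussian noise:", (g - flat).std())
# Close to 0.1
print("fraction of corrupted pixels:", np.mean(sp != flat))
print("value range:", g.min(), g.max())
```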

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from skimage import data, img_as_float

# For this tutorial, we will use the famous "cameraman" image, which
# scikit-image provides as skimage.data.camera() (a 512x512 grayscale image)
image = data.camera()

# Convert from uint8 [0, 255] to float64 [0, 1]
image = img_as_float(image)

### Types of Noise

#### 1. Gaussian Noise

Gaussian noise is statistical noise whose values follow a normal (Gaussian) distribution: each pixel receives a random offset drawn from that distribution. It is the most widely used noise model in denoising and a good approximation of the electronic (read and thermal) noise of image sensors, which is most visible in low-light images where the signal is weak.

We can control the amount of Gaussian noise with the `std` (standard deviation) parameter; the variance is `std**2`. A larger standard deviation produces stronger noise.
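Written out, the additive model behind the `mean` ($\mu$) and `std` ($\sigma$) parameters is the following, where $x$ is a clean pixel value, $y$ the noisy value, and $n$ the noise, drawn independently for every pixel:

$$
y = x + n, \qquad p(n) = \frac{1}{\sigma\sqrt{2\pi}} \exp\!\left(-\frac{(n-\mu)^2}{2\sigma^2}\right), \qquad \operatorname{Var}(n) = \sigma^2
$$

If the result is clipped to $[0, 1]$, as is usual for float images, the noise near pure black and pure white is no longer exactly Gaussian.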

In [None]:
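# YOUR CODE HERE: Add Gaussian noise to the image and display the result


def add_gaussian_noise(image, mean=0.0, std=0.1):
    """Adds Gaussian noise using NumPy and clips the result to [0, 1]."""
    noise = np.random.normal(loc=mean, scale=std, size=image.shape)
    # Clip so the noisy image stays in the valid intensity range of a float image
    return np.clip(image + noise, 0.0, 1.0)


gaussian_noisy_image = add_gaussian_noise(image, mean=0.0, std=0.2)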

plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.title("Original Image")
plt.imshow(image, cmap='gray')
plt.axis('off')

plt.subplot(1, 2, 2)
plt.title("Image with Gaussian Noise")
plt.imshow(gaussian_noisy_image, cmap='gray')
plt.axis('off')
plt.show()

#### 2. Salt & Pepper Noise

Salt and pepper noise, also known as impulse noise, is a type of noise that presents itself as sparsely occurring white and black pixels. An image containing salt-and-pepper noise will have dark pixels in bright regions and bright pixels in dark regions.

We can control the amount of salt and pepper noise by adjusting the `amount` parameter, which represents the proportion of pixels to be affected.
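Formally, with $p$ equal to `amount` and assuming an even split between salt and pepper on a float image in $[0, 1]$, each clean pixel value $x$ is independently replaced by a noisy value $y$:

$$
y = \begin{cases}
0 & \text{with probability } p/2 \text{ (pepper)} \\
1 & \text{with probability } p/2 \text{ (salt)} \\
x & \text{with probability } 1 - p
\end{cases}
$$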

In [None]:
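# YOUR CODE HERE: Add Salt & Pepper noise to the image and display the result


def add_salt_and_pepper_noise(image, amount=0.05):
    """Adds salt and pepper noise to a grayscale float image in [0, 1] using NumPy."""
    noisy = image.copy()
    r = np.random.rand(*image.shape)
    noisy[r < amount / 2] = 0.0                    # pepper: half of the corrupted pixels
    noisy[(r >= amount / 2) & (r < amount)] = 1.0  # salt: the other half
    return noisy


sp_noisy_image = add_salt_and_pepper_noise(image, amount=0.1)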

plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.title("Original Image")
plt.imshow(image, cmap='gray')
plt.axis('off')

plt.subplot(1, 2, 2)
plt.title("Image with Salt & Pepper Noise")
plt.imshow(sp_noisy_image, cmap='gray')
plt.axis('off')
plt.show()

## Classical Denoising Methods

Now that we know how to add noise, let's explore some classical methods for removing it.

### 1. Gaussian Filter

A Gaussian filter is a linear filter that is widely used for blurring images and removing noise. It works by convolving the image with a Gaussian kernel. The standard deviation (`sigma`) of the Gaussian kernel is a parameter that controls the amount of blurring. A larger `sigma` will result in more blurring and more noise removal, but it can also lead to a loss of image details.
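One way to see how `sigma` controls noise removal: for independent pixel noise with standard deviation $\sigma_n$, a normalized kernel $k$ leaves noise with standard deviation $\sigma_n \sqrt{\sum k^2}$, and a wider Gaussian spreads its weight over more pixels, which makes $\sum k^2$ smaller. The sketch below checks this on pure noise using SciPy's `scipy.ndimage.gaussian_filter` (in place of `skimage.filters.gaussian`); the helper `gaussian_kernel_2d` is hypothetical and mirrors SciPy's truncation at 4 standard deviations.

```python
import numpy as np
from scipy import ndimage


def gaussian_kernel_2d(sigma, truncate=4.0):
    """Normalized 2D Gaussian kernel with radius int(truncate * sigma + 0.5)."""
    radius = int(truncate * sigma + 0.5)
    x = np.arange(-radius, radius + 1)
    g = np.exp(-x ** 2 / (2.0 * sigma ** 2))
    k = np.outer(g, g)
    return k / k.sum()


rng = np.random.default_rng(0)
noise_std = 0.2
noise = rng.normal(0.0, noise_std, size=(512, 512))  # pure i.i.d. Gaussian noise

for sigma in (0.5, 1.0, 2.0):
    k = gaussian_kernel_2d(sigma)
    predicted = noise_std * np.sqrt(np.sum(k ** 2))
    # mode='wrap' gives every pixel a full, statistically identical neighborhood
    measured = ndimage.gaussian_filter(noise, sigma, mode='wrap').std()
    print(f"sigma={sigma}: predicted std {predicted:.4f}, measured std {measured:.4f}")
```

The remaining noise drops quickly as `sigma` grows, but the same averaging also blurs real edges, which is the trade-off described above.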

In [None]:

from skimage.filters import gaussian

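# YOUR CODE HERE: Denoise the Gaussian noisy image using a Gaussian filter and display the result
# sigma trades noise removal against blurring; 1.5 is a starting point, try values from 1 to 3
denoised_gaussian = gaussian(gaussian_noisy_image, sigma=1.5)
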
plt.figure(figsize=(15, 5))
plt.subplot(1, 3, 1)
plt.title("Original Image")
plt.imshow(image, cmap='gray')
plt.axis('off')

plt.subplot(1, 3, 2)
plt.title("Noisy Image")
plt.imshow(gaussian_noisy_image, cmap='gray')
plt.axis('off')

plt.subplot(1, 3, 3)
plt.title("Denoised with Gaussian Filter")
plt.imshow(denoised_gaussian, cmap='gray')
plt.axis('off')
plt.show()

### 2. Median Filter

The median filter is a non-linear filter that replaces each pixel's value with the median of the values in its neighborhood (the pixel itself included). Because a few extreme values cannot shift the median, it is particularly effective at removing salt-and-pepper noise, and it preserves edges better than a linear filter of the same size.
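A tiny example of why the median ignores impulses: in a flat patch with a single "salt" pixel, the outlier is never the middle value of any 3x3 window, while a mean filter of the same size smears it over nine pixels. This sketch uses SciPy's `scipy.ndimage.median_filter` and `scipy.ndimage.uniform_filter` rather than scikit-image's `median`.

```python
import numpy as np
from scipy import ndimage

img = np.full((7, 7), 0.5)  # flat grey patch
img[3, 3] = 1.0             # one "salt" pixel

med = ndimage.median_filter(img, size=3)   # median of each 3x3 neighborhood
avg = ndimage.uniform_filter(img, size=3)  # mean of each 3x3 neighborhood, for comparison

print(med.max())  # 0.5: the impulse is removed entirely
print(avg.max())  # about 0.5556 (0.5 + 0.5 / 9): the impulse is spread over 9 pixels
```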

In [None]:
from skimage.filters import median
from skimage.morphology import disk

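# YOUR CODE HERE: Denoise the Salt & Pepper noisy image using a median filter and display the result
# disk(1) is a 3x3 cross: each pixel and its 4 direct neighbors
denoised_median = median(sp_noisy_image, disk(1))
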
plt.figure(figsize=(15, 5))
plt.subplot(1, 3, 1)
plt.title("Original Image")
plt.imshow(image, cmap='gray')
plt.axis('off')

plt.subplot(1, 3, 2)
plt.title("Noisy Image (S&P)")
plt.imshow(sp_noisy_image, cmap='gray')
plt.axis('off')

plt.subplot(1, 3, 3)
plt.title("Denoised with Median Filter")
plt.imshow(denoised_median, cmap='gray')
plt.axis('off')
plt.show()

### 3. BM3D (Block-matching and 3D filtering)

BM3D is a more advanced method and is widely regarded as one of the strongest classical (non-learning) denoisers. It groups similar patches from across the image into 3D stacks, filters each stack jointly in a transform domain ("collaborative filtering"), and then returns the filtered patches to their original locations, averaging where they overlap. It is designed for additive Gaussian noise, and its main parameter is the noise standard deviation.
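To make the grouping and collaborative filtering steps concrete, here is a heavily simplified sketch for a single reference patch; it is not the `bm3d` package's implementation. The helpers `match_blocks` and `collaborative_hard_threshold` are hypothetical, and the 3D DCT and the threshold of 2.7σ are assumptions. Full BM3D repeats this for many reference patches, uses its own choice of transforms, aggregates the overlapping estimates with weights, and adds a second, Wiener-filtering stage.

```python
import numpy as np
from scipy.fft import dctn, idctn


def match_blocks(img, ref_yx, patch=8, search=16, k=8):
    """Grouping: find the k patches within +/- `search` pixels of the reference
    patch with the smallest squared L2 distance to it. Returns a (k, patch, patch)
    stack and the top-left corner of each patch."""
    y0, x0 = ref_yx
    ref = img[y0:y0 + patch, x0:x0 + patch]
    candidates = []
    for y in range(max(0, y0 - search), min(img.shape[0] - patch, y0 + search) + 1):
        for x in range(max(0, x0 - search), min(img.shape[1] - patch, x0 + search) + 1):
            dist = np.sum((img[y:y + patch, x:x + patch] - ref) ** 2)
            candidates.append((dist, y, x))
    candidates.sort(key=lambda c: c[0])
    coords = [(y, x) for _, y, x in candidates[:k]]
    group = np.stack([img[y:y + patch, x:x + patch] for y, x in coords])
    return group, coords


def collaborative_hard_threshold(group, threshold):
    """Collaborative filtering of one group: 3D DCT, zero the small
    coefficients (mostly noise), inverse 3D DCT."""
    coeffs = dctn(group, norm='ortho')
    coeffs[np.abs(coeffs) < threshold] = 0.0
    return idctn(coeffs, norm='ortho')


# Toy image: a random 8x8 pattern repeated 8x8 times, plus Gaussian noise
rng = np.random.default_rng(0)
clean = np.tile(rng.random((8, 8)), (8, 8))
sigma = 0.1
noisy = clean + rng.normal(0.0, sigma, clean.shape)

group, coords = match_blocks(noisy, (24, 24))
filtered = collaborative_hard_threshold(group, threshold=2.7 * sigma)  # assumed threshold

clean_group = np.stack([clean[y:y + 8, x:x + 8] for y, x in coords])
print("matched corners:", coords)
print("group MSE before:", np.mean((group - clean_group) ** 2))
print("group MSE after: ", np.mean((filtered - clean_group) ** 2))
```

Because the matched patches share the same underlying content, their signal is concentrated in a few 3D coefficients while the noise is spread over all of them, which is why thresholding the whole group works better than thresholding each patch alone.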

*Note: You might need to install the `bm3d` package. You can do this by running `pip install bm3d` in your terminal.*

In [None]:
try:
    import bm3d
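    # YOUR CODE HERE: Denoise the Gaussian noisy image using BM3D and display the result
    # sigma_psd is the noise standard deviation (std=0.2 was used above)
    denoised_bm3d = bm3d.bm3d(gaussian_noisy_image, sigma_psd=0.2)
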
    plt.figure(figsize=(15, 5))
    plt.subplot(1, 3, 1)
    plt.title("Original Image")
    plt.imshow(image, cmap='gray')
    plt.axis('off')
    
    plt.subplot(1, 3, 2)
    plt.title("Noisy Image")
    plt.imshow(gaussian_noisy_image, cmap='gray')
    plt.axis('off')
    
    plt.subplot(1, 3, 3)
    plt.title("Denoised with BM3D")
    plt.imshow(denoised_bm3d, cmap='gray')
    plt.axis('off')
    plt.show()
except ImportError:
    print("BM3D is not installed. Please run 'pip install bm3d' to use this feature.")

## Denoising with Noise2Void

Next, we explore a deep learning-based method for image denoising called **Noise2Void**. Unlike supervised deep-learning denoisers, which are trained on pairs of noisy and clean images, Noise2Void learns to denoise directly from the noisy images themselves, without any clean training data.

### How Noise2Void Works (A Brief Explanation)

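The core idea behind Noise2Void is to train a neural network to predict each pixel's value from its neighborhood *without* seeing that pixel itself, i.e., with a "blind spot" at the center of the network's receptive field. In practice, the blind spot is created by masking: during training, a random subset of pixels is replaced with values from nearby pixels, and the loss is computed only at those positions. This works under two assumptions: the underlying signal is spatially correlated, so it can be predicted from the neighbors, whereas the noise is zero-mean and independent from pixel to pixel, so it cannot. The best the network can do is therefore to predict the clean signal, which makes it a denoiser.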
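Why the blind spot matters can be checked numerically with plain NumPy on synthetic data (this is not the `noise2void` code). If $y = s + n$ with zero-mean noise $n$ that is independent of the neighboring pixels, any predictor $f$ that only sees the neighbors satisfies $\mathbb{E}[(f - y)^2] = \mathbb{E}[(f - s)^2] + \sigma^2$, so minimizing the loss against the noisy pixel also minimizes the error against the clean signal $s$. A predictor that can see the center pixel reaches zero loss by copying the noise. The network is replaced here by two hypothetical stand-ins: the mean of the four direct neighbors (blind spot) and the identity (no blind spot).

```python
import numpy as np

rng = np.random.default_rng(0)
sigma = 0.1

# Smooth synthetic "clean" image s and noisy observation y = s + n
yy, xx = np.mgrid[0:256, 0:256] / 256.0
clean = 0.5 + 0.25 * np.sin(2 * np.pi * xx) * np.cos(2 * np.pi * yy)
noisy = clean + rng.normal(0.0, sigma, clean.shape)

# Blind-spot predictor: mean of the 4 direct neighbors, center pixel excluded
pad = np.pad(noisy, 1, mode="reflect")
blind = (pad[:-2, 1:-1] + pad[2:, 1:-1] + pad[1:-1, :-2] + pad[1:-1, 2:]) / 4.0

# Predictor that sees the center pixel: simply the identity
identity = noisy


def mse(a, b):
    return float(np.mean((a - b) ** 2))


print("blind-spot vs noisy:          ", mse(blind, noisy))
print("blind-spot vs clean + sigma^2:", mse(blind, clean) + sigma ** 2)
print("identity vs noisy:", mse(identity, noisy))  # 0.0: it just copies the noise
print("identity vs clean:", mse(identity, clean))  # about sigma^2: no denoising at all
```

The first two numbers agree closely, while the identity "wins" on the noisy loss without removing any noise, which is why training without a blind spot cannot produce a denoiser.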

### Using the Noise2Void Code

Now, let's use the `noise2void` code in this repository to denoise an image. We will need to:

1.  **Prepare the data:** We'll normalize the noisy image and bring it to a size the network accepts.
2.  **Create a Noise2Void model:** We'll use the `UNet` model from `model.py`.
3.  **Train the model:** We'll train the model on our noisy data, masking random pixels so the network cannot simply copy its input.
4.  **Predict (Denoise):** We'll use the trained model to denoise the noisy image.

In [None]:
# YOUR CODE HERE: Prepare the noisy image for Noise2Void training.
import numpy as np
import torch

# squeeze() drops a channel axis added by a previous run, so this cell can be re-run
noisy = np.squeeze(gaussian_noisy_image)

# Normalize data to [0, 1] range
noisy = (noisy - noisy.min()) / (noisy.max() - noisy.min())

# Ensure dimensions are divisible by 16 (for UNet architecture). Crop rather than
# resize: resizing would interpolate, altering both the image and the noise statistics
target_size = ((noisy.shape[0] // 16) * 16, (noisy.shape[1] // 16) * 16)
noisy = noisy[:target_size[0], :target_size[1]]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
patch_size = 64      # side length of the training patches (divisible by 16)
n_patches = 128      # training patches per epoch
batch_size = 4
learning_rate = 0.001
weight_decay = 0.00001
num_epochs = 30

# Add channel dimension (H, W, 1) and convert to float32 for PyTorch
gaussian_noisy_image = np.expand_dims(noisy, axis=-1).astype(np.float32)


### Training the Noise2Void Model

In [None]:
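# YOUR CODE HERE: Set up the model, loss, and optimizer, and train the model.
import torch
import torch.nn.functional as F
from noise2void.model import UNet

# Initialize model
net = UNet(nch_in=1, nch_out=1, nch_ker=64, norm='bnorm')
net = net.to(device)

# Set up optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Convert data to a (1, 1, H, W) tensor
input_tensor = torch.from_numpy(gaussian_noisy_image).permute(2, 0, 1).unsqueeze(0)

# Pad input to make dimensions divisible by 16 (a no-op if they already are)
h, w = input_tensor.shape[2], input_tensor.shape[3]
pad_h = (16 - h % 16) % 16
pad_w = (16 - w % 16) % 16
if pad_h > 0 or pad_w > 0:
    input_tensor = F.pad(input_tensor, (0, pad_w, 0, pad_h), mode='reflect')

input_tensor = input_tensor.to(device)


def n2v_mask(batch, frac=0.02, radius=2):
    """Blind-spot masking: replace a random fraction `frac` of pixels with a randomly
    chosen neighbor from the (2*radius+1)^2 window around them.
    Returns the corrupted batch and a float mask marking the replaced pixels."""
    n, c, H, W = batch.shape
    dev = batch.device
    side = 2 * radius + 1
    # Random offset index within the window, skipping the center (the pixel itself)
    k = torch.randint(0, side * side - 1, (n, 1, H, W), device=dev)
    k = k + (k >= side * side // 2).long()
    y0 = torch.arange(H, device=dev).view(1, 1, H, 1)
    x0 = torch.arange(W, device=dev).view(1, 1, 1, W)
    ys = y0 + k // side - radius
    xs = x0 + k % side - radius
    # Reflect neighbor coordinates that fall outside the patch
    ys = torch.where(ys < 0, -ys, ys)
    ys = torch.where(ys > H - 1, 2 * (H - 1) - ys, ys)
    xs = torch.where(xs < 0, -xs, xs)
    xs = torch.where(xs > W - 1, 2 * (W - 1) - xs, xs)
    # Mask random pixels, but never one whose reflected "neighbor" is the pixel itself
    mask = (torch.rand(n, 1, H, W, device=dev) < frac) & ((ys != y0) | (xs != x0))
    idx = (ys * W + xs).reshape(n, 1, H * W).expand(n, c, H * W)
    neighbors = torch.gather(batch.reshape(n, c, H * W), 2, idx).reshape(n, c, H, W)
    return torch.where(mask, neighbors, batch), mask.float()


# Training loop: each step draws random patches, masks them, and computes the MSE
# only at the masked pixels, so the network never sees the pixel it has to predict
steps_per_epoch = max(1, n_patches // batch_size)
for epoch in range(num_epochs):
    net.train()
    for _ in range(steps_per_epoch):
        top = torch.randint(0, h - patch_size + 1, (batch_size,)).tolist()
        left = torch.randint(0, w - patch_size + 1, (batch_size,)).tolist()
        batch = torch.stack([input_tensor[0, :, y:y + patch_size, x:x + patch_size]
                             for y, x in zip(top, left)])
        masked_batch, mask = n2v_mask(batch)

        optimizer.zero_grad()
        output = net(masked_batch)
        loss = ((output - batch) ** 2 * mask).sum() / mask.sum().clamp(min=1.0)
        loss.backward()
        optimizer.step()

    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')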

### Denoising with the Trained Model

In [None]:
# YOUR CODE HERE: Denoise the noisy image with the trained model and display the results.
# Set model to evaluation mode
net.eval()

with torch.no_grad():
    # Get denoised output
    denoised = net(input_tensor)
    
    # Remove padding if added
    if pad_h > 0 or pad_w > 0:
        denoised = denoised[:, :, :h, :w]
    
    # Convert back to numpy
    denoised = denoised.cpu().numpy().squeeze()
    
    # Display results
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 7))
    
    ax1.imshow(gaussian_noisy_image.squeeze(), cmap='gray')
    ax1.set_title('Original Noisy Image')
    ax1.axis('off')
    
    ax2.imshow(denoised, cmap='gray')
    ax2.set_title('Denoised Image')
    ax2.axis('off')
    
    plt.tight_layout()
    plt.show()
    
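    # Calculate PSNR against the clean reference image, cropped to the same size
    # (comparing with the noisy input would reward simply reproducing the noise)
    clean = image[:h, :w]
    mse_noisy = np.mean((clean - gaussian_noisy_image.squeeze()) ** 2)
    mse_denoised = np.mean((clean - denoised) ** 2)
    print(f'PSNR noisy:    {10 * np.log10(1.0 / mse_noisy):.2f} dB')
    print(f'PSNR denoised: {10 * np.log10(1.0 / mse_denoised):.2f} dB')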
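The PSNR computed above is $10 \log_{10}(\mathrm{MAX}^2 / \mathrm{MSE})$ with $\mathrm{MAX} = 1$ for float images in $[0, 1]$. As a sanity check, a hand-written version can be compared with scikit-image's `skimage.metrics.peak_signal_noise_ratio` on synthetic data; the `psnr` helper is hypothetical.

```python
import numpy as np
from skimage.metrics import peak_signal_noise_ratio


def psnr(reference, test, data_range=1.0):
    """PSNR in dB: 10 * log10(data_range**2 / MSE)."""
    mse = np.mean((reference.astype(np.float64) - test.astype(np.float64)) ** 2)
    return 10.0 * np.log10(data_range ** 2 / mse)


rng = np.random.default_rng(0)
clean = rng.random((64, 64))
noisy = clean + rng.normal(0.0, 0.1, clean.shape)  # noise std 0.1, so MSE is about 0.01

ours = psnr(clean, noisy)
ref = peak_signal_noise_ratio(clean, noisy, data_range=1.0)
print(f"hand-written: {ours:.4f} dB, scikit-image: {ref:.4f} dB")  # both close to 20 dB
```

Passing `data_range` explicitly avoids relying on scikit-image's dtype-based default for the peak value.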