<a href="https://colab.research.google.com/github/Gaurav14cs17/DeNoisingHomogeneous-Region/blob/main/Image_Denoising_using_AutoEncode.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES.
# Keep it: later cells refer to greatgamedota_ffhq_face_data_set_path.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.
import kagglehub

# Download (or reuse the locally cached copy of) the FFHQ face dataset and keep
# the local directory path that kagglehub returns.
greatgamedota_ffhq_face_data_set_path = kagglehub.dataset_download(
    'greatgamedota/ffhq-face-data-set',
)
print('Data source import complete.')


<span style="font-size:18px;">A note before starting: this notebook was run locally on a GPU. Re-executing it would take a long time, so I copied some of the outputs from the local notebook into the text as plain output. I hope you find the notebook useful; if you have any advice, feel free to share it in the comments.</span>

# Introduction

<span style="font-size:18px;"> This notebook comes from an academic project in which we were asked to denoise images, with no constraint on the method. I first tried a non-ML method called BM3D (Block-Matching and 3D filtering), and then experimented with a deep-learning approach based on an autoencoder architecture.</span>

# Setting up the notebook

In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

!pip install bm3d

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import bm3d

import torch.cuda
import torch.nn as nn
from torch.nn import Linear, ReLU, MSELoss, L1Loss, Sequential, Conv2d, ConvTranspose2d, MaxPool2d, AdaptiveAvgPool2d, Module, BatchNorm2d, Sigmoid, Dropout
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import random_split
from torchvision import datasets, transforms
import os

In [None]:
gpu_avail = torch.cuda.is_available()
print(f"Is the GPU available? {gpu_avail}")

In [None]:
greatgamedota_ffhq_face_data_set_path

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device", device)

Device cuda

<span style="font-size:18px;">I downloaded and used the dataset available at this link: https://www.kaggle.com/greatgamedota/ffhq-face-data-set</span>

In [None]:
PATH =  "/root/.cache/kagglehub/datasets/greatgamedota/ffhq-face-data-set/versions/2"

In [None]:
data_dir = PATH

transform = transforms.Compose([transforms.Resize(256),
                                transforms.ToTensor()
                               ])
dataset = datasets.ImageFolder(data_dir, transform=transform)
lengths = [int(len(dataset)*0.8), int(len(dataset)*0.2)]
train_dataset, val_dataset = random_split(dataset, lengths)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=16, shuffle=True)

In [None]:
def imshow(image, ax=None, title=None, normalize=True):
    """Display a channel-first image tensor without ticks or spines.

    Args:
        image: tensor laid out (C, H, W); it is transposed to (H, W, C) for
            matplotlib. It may live on any device and may require grad.
        ax: Axes to draw on. A new figure and axes are created when None.
        title: optional axes title (set only when not None).
        normalize: if True, undo ImageNet mean/std normalization and clip to [0, 1].

    Returns:
        The Axes the image was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()
    # detach()/cpu() so CUDA tensors and tensors that require grad can be converted.
    image = image.detach().cpu().numpy().transpose((1, 2, 0))

    if normalize:
        # ImageNet per-channel statistics. Only meaningful if the tensor was
        # normalized with them; the callers in this notebook pass normalize=False.
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        image = np.clip(std * image + mean, 0, 1)

    ax.imshow(image)
    for side in ('top', 'right', 'left', 'bottom'):
        ax.spines[side].set_visible(False)
    ax.tick_params(axis='both', length=0)
    ax.set_xticklabels('')
    ax.set_yticklabels('')
    if title is not None:
        ax.set_title(title)

    return ax

In [None]:
# Visualizing an image from the dataset
images, _ = next(iter(val_dataloader))
imshow(images[1], normalize=False)

# Experimenting with different noise types

Additive white Gaussian noise:

In [None]:
noisy_images = (images + torch.normal(0,0.2, images.shape)).clip(0,1)

In [None]:
imshow(noisy_images[1], normalize=False)

Masking (Bernoulli) noise:

In [None]:
a = 0.7*torch.ones(images.shape)
bernouilli_noisy_images = images*torch.bernoulli(a)

In [None]:
imshow(bernouilli_noisy_images[1], normalize=False)

Poisson noise (here: a normalized Poisson sample added to the image, independent of pixel intensity):

In [None]:
a = 5*torch.ones(images.shape)
p = torch.poisson(a)
p_norm = p/p.max()

In [None]:
poisson_noisy_images = (images + p_norm).clip(0,1)

In [None]:
imshow(poisson_noisy_images[1], normalize=False)

In [None]:
EPS = 1e-8

<span style="font-size:18px;"> These are some helper functions to evaluate the performance of the denoising approaches:</span>

In [None]:
def PSNR(input, target):
    """Per-image peak signal-to-noise ratio in dB, assuming a peak value of 1.

    The squared error is averaged over dims 1-3, giving one value per batch
    element. EPS keeps the logarithm finite when input == target.
    """
    per_image_mse = ((input - target) ** 2).mean(dim=[1, 2, 3])
    return -10 * torch.log10(per_image_mse + EPS)

In [None]:
def MSE(input, target):
    """Per-image mean squared error, averaged over dims 1-3 (one value per batch element)."""
    squared_error = (input - target).pow(2)
    return squared_error.mean(dim=(1, 2, 3))

PSNR of the Gaussian-noise images:

In [None]:
PSNR(input=images, target=noisy_images)  # baseline: noisy vs. clean, no denoising

tensor([15.0455, 15.1718, 14.7140, 14.7304, 15.3507, 15.3572, 14.9787, 14.8285,
        14.7231, 14.9067, 15.6183, 14.5245, 14.7026, 15.0107, 14.9028, 14.9212])

PSNR of the masked images:

In [None]:
PSNR(input=images, target=bernouilli_noisy_images)  # baseline: masked vs. clean

tensor([10.7218, 13.6338, 12.5109, 11.8838, 11.0169,  7.7124, 10.1220, 12.0363,
        11.0440,  9.5477, 16.9152,  9.7393, 11.7690, 10.0014,  9.8994, 12.6178])

PSNR of the Poisson-noise images:

In [None]:
PSNR(input=images, target=poisson_noisy_images)  # baseline: Poisson-noisy vs. clean

tensor([11.7429, 10.9976, 11.1412, 11.1908, 11.7629, 13.7808, 11.9429, 11.2966,
        11.3740, 12.1885, 10.8384, 11.8166, 11.2383, 12.0369, 12.0351, 11.1616])

<span style="font-size:18px;">These scores are our baseline: a model is only worth considering if it reaches a significantly higher PSNR.</span>

# Denoising using BM3D

In [None]:
denoised_image = bm3d.bm3d(noisy_images[1].permute(1,2,0), sigma_psd=30/255, stage_arg=bm3d.BM3DStages.HARD_THRESHOLDING)

In [None]:
plt.imshow(denoised_image)

In [None]:
bernouilli_denoised_image = bm3d.bm3d(bernouilli_noisy_images[1].permute(1,2,0), sigma_psd=30/255, stage_arg=bm3d.BM3DStages.HARD_THRESHOLDING)

In [None]:
plt.imshow(bernouilli_denoised_image)

In [None]:
poisson_denoised_image = bm3d.bm3d(poisson_noisy_images[1].permute(1,2,0), sigma_psd=15/255, stage_arg=bm3d.BM3DStages.HARD_THRESHOLDING).clip(0,1)

In [None]:
plt.imshow(poisson_denoised_image)

<span style="font-size:18px;">BM3D gave a significant improvement on Gaussian noise, but the improvement on the other two noise types was much less noticeable.</span>

<span style="font-size:18px;">(You can find a comparative illustration at the end of the notebook.)</span>

# Denoising using Deep Learning

## Autoencoders

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import kaiming_normal_

class HaarDWT(nn.Module):
    """Single-level 2-D orthonormal Haar analysis, applied to each channel separately.

    Maps (B, C, H, W) -> (B, 4*C, ceil(H/2), ceil(W/2)). Because this is a grouped
    convolution, the output channels are interleaved per input channel:
    [c0_LL, c0_LH, c0_HL, c0_HH, c1_LL, ...].
    """
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        # Orthonormal Haar wavelet kernels: each has unit L2 norm and they are
        # mutually orthogonal, so the transform preserves energy.
        ll = torch.tensor([[1, 1], [1, 1]], dtype=torch.float32) * 0.5
        lh = torch.tensor([[1, -1], [1, -1]], dtype=torch.float32) * 0.5
        hl = torch.tensor([[1, 1], [-1, -1]], dtype=torch.float32) * 0.5
        hh = torch.tensor([[1, -1], [-1, 1]], dtype=torch.float32) * 0.5

        kernel = torch.stack([ll, lh, hl, hh], dim=0).unsqueeze(1)  # (4, 1, 2, 2)
        # One copy of the four kernels per input channel -> (4 * in_channels, 1, 2, 2).
        self.register_buffer('filters', kernel.repeat(in_channels, 1, 1, 1))

    def forward(self, x):
        # Pad only odd dimensions, and only on the bottom/right, so the 2x2 blocks
        # stay aligned with pixel (0, 0). Padding the top/left by one pixel shifts
        # the grid, drops the last row/column of even-sized inputs and breaks
        # perfect reconstruction.
        pad_h = x.shape[-2] % 2
        pad_w = x.shape[-1] % 2
        if pad_h or pad_w:
            x = F.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
        return F.conv2d(x, self.filters, stride=2, groups=self.in_channels)

class HaarIDWT(nn.Module):
    """Inverse single-level 2-D Haar transform (synthesis), per channel.

    Maps (B, 4*C, h, w) -> (B, C, 2h, 2w). The input must use the per-channel
    interleaved layout [c0_LL, c0_LH, c0_HL, c0_HH, c1_LL, ...] that HaarDWT produces.
    """
    def __init__(self, out_channels):
        super().__init__()
        self.out_channels = out_channels

        # The analysis kernels are orthonormal, so the inverse is the adjoint:
        # a transposed convolution with the same kernels in the same order as
        # HaarDWT (LH and HL must not be swapped).
        ll = torch.tensor([[1, 1], [1, 1]], dtype=torch.float32) * 0.5
        lh = torch.tensor([[1, -1], [1, -1]], dtype=torch.float32) * 0.5
        hl = torch.tensor([[1, 1], [-1, -1]], dtype=torch.float32) * 0.5
        hh = torch.tensor([[1, -1], [-1, 1]], dtype=torch.float32) * 0.5

        kernel = torch.stack([ll, lh, hl, hh], dim=0).unsqueeze(1)  # (4, 1, 2, 2)
        self.register_buffer('inv_filters', kernel.repeat(out_channels, 1, 1, 1))

    def forward(self, x):
        # A stride-2 transposed conv with 2x2 kernels tiles non-overlapping blocks,
        # so no padding is needed. Padding here would add a spurious extra
        # row/column (output 2n + 2 instead of 2n).
        return F.conv_transpose2d(x, self.inv_filters, stride=2, groups=self.out_channels)

class PerfectReconstructionTest(nn.Module):
    """Round-trips its input through HaarDWT and then HaarIDWT.

    Intended as a sanity check: if the two transforms are exact inverses, the
    output should match the input (for even spatial sizes).
    """
    def __init__(self, channels):
        super().__init__()
        self.dwt = HaarDWT(channels)
        self.idwt = HaarIDWT(channels)

    def forward(self, x):
        return self.idwt(self.dwt(x))

class ResidualBlock(nn.Module):
    """Identity shortcut around conv3x3 -> LeakyReLU(0.1) -> conv3x3.

    Channel count and spatial size are preserved.
    """
    def __init__(self, channels):
        super().__init__()
        layers = [
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(channels, channels, 3, padding=1),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.conv(x)
        return x + residual

class FrequencyAttention(nn.Module):
    """Squeeze-and-excitation style channel gating: x * sigmoid(MLP(global_avg_pool(x))).

    Args:
        channels: number of input (and output) channels.
        reduction: bottleneck ratio of the gating MLP (default 4).
    """
    def __init__(self, channels, reduction=4):
        super().__init__()
        # Keep at least one hidden channel. With channels < reduction (e.g. the
        # 3-channel LL band), channels // 4 == 0 produced zero-width 1x1 convs,
        # so the gate carried no information about the input.
        hidden = max(1, channels // reduction)
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, hidden, 1),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(hidden, channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # The gate has shape (B, C, 1, 1) and broadcasts over H and W.
        return x * self.fc(x)

class EdgeAwareDenoiser(nn.Module):
    """Small U-Net style encoder/decoder plus a parallel edge branch.

    forward(x, mask) multiplies x by ``mask`` (broadcast), runs the masked input
    through a three-stage encoder, a residual/attention bottleneck and a decoder
    with encoder skip concatenations, and adds the output of a shallow conv
    branch (edge_conv) applied to the same masked input.

    Args:
        in_channels: channels of the input and of the output.
        base_channels: width of the first encoder stage; doubled at each of the
            two stride-2 downsamplings.

    The two stride-2 encoder convs mean H and W should be divisible by 4 for
    the skip concatenations to line up.
    """
    def __init__(self, in_channels, base_channels=32):
        super().__init__()
        self.in_channels = in_channels
        c1, c2, c4 = base_channels, base_channels * 2, base_channels * 4

        # Encoder: full resolution, then two stride-2 downsamplings.
        self.enc1 = nn.Sequential(nn.Conv2d(in_channels, c1, 3, padding=1),
                                  nn.LeakyReLU(0.1, inplace=True))
        self.enc2 = nn.Sequential(nn.Conv2d(c1, c2, 3, stride=2, padding=1),
                                  nn.LeakyReLU(0.1, inplace=True))
        self.enc3 = nn.Sequential(nn.Conv2d(c2, c4, 3, stride=2, padding=1),
                                  nn.LeakyReLU(0.1, inplace=True))

        # Bottleneck at 1/4 resolution: residual -> channel attention -> residual.
        self.bottleneck = nn.Sequential(ResidualBlock(c4),
                                        FrequencyAttention(c4),
                                        ResidualBlock(c4))

        # Decoder: each transposed conv doubles the resolution. The input width of
        # dec2/dec1 is doubled by the concatenated encoder skip.
        self.dec3 = nn.Sequential(nn.ConvTranspose2d(c4, c2, 2, stride=2),
                                  nn.LeakyReLU(0.1, inplace=True))
        self.dec2 = nn.Sequential(nn.ConvTranspose2d(c4, c1, 2, stride=2),
                                  nn.LeakyReLU(0.1, inplace=True))
        self.dec1 = nn.Conv2d(c2, in_channels, 3, padding=1)

        # Shallow full-resolution branch added to the decoder output.
        self.edge_conv = nn.Sequential(nn.Conv2d(in_channels, in_channels, 3, padding=1),
                                       nn.LeakyReLU(0.1, inplace=True),
                                       nn.Conv2d(in_channels, in_channels, 3, padding=1))

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal (fan_out, leaky_relu) weights and zero biases for every (transposed) conv."""
        conv_types = (nn.Conv2d, nn.ConvTranspose2d)
        for module in self.modules():
            if not isinstance(module, conv_types):
                continue
            kaiming_normal_(module.weight, mode='fan_out', nonlinearity='leaky_relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, x, mask):
        masked = x * mask

        skip1 = self.enc1(masked)
        skip2 = self.enc2(skip1)
        bottom = self.bottleneck(self.enc3(skip2))

        up2 = self.dec3(bottom)
        up1 = self.dec2(torch.cat([up2, skip2], dim=1))
        decoded = self.dec1(torch.cat([up1, skip1], dim=1))

        return decoded + self.edge_conv(masked)

class DWT2DenoisingModel(nn.Module):
    """Two-level Haar-wavelet denoiser.

    Pipeline: reflect-pad -> DWT (level 1) -> DWT of LL1 (level 2) -> refine LL2
    with ll_processor -> build a spatial mask from the refined LL2 -> jointly
    denoise the six detail bands (three per level) with EdgeAwareDenoiser at LL2
    resolution -> inverse DWTs -> crop back to the input size.

    Args:
        in_channels: number of image channels (3 for RGB).

    Input and output are (B, in_channels, H, W).
    """

    # Two DWT levels (/4) followed by EdgeAwareDenoiser's two stride-2 stages (/4):
    # H and W must be multiples of 16 for every skip concatenation to line up.
    _PAD_MULTIPLE = 16

    def __init__(self, in_channels=3):
        super().__init__()
        self.in_channels = in_channels

        # Wavelet transforms
        self.dwt1 = HaarDWT(in_channels)
        self.dwt2 = HaarDWT(in_channels)
        self.idwt2 = HaarIDWT(in_channels)
        self.idwt1 = HaarIDWT(in_channels)

        # Low-frequency processor
        self.ll_processor = nn.Sequential(
            ResidualBlock(in_channels),
            FrequencyAttention(in_channels),
            ResidualBlock(in_channels)
        )

        # Learned single-channel edge map in (0, 1) computed from the LL2 band.
        # max(1, ...) avoids a zero-width hidden layer when in_channels == 1.
        hidden = max(1, in_channels // 2)
        self.edge_detector = nn.Sequential(
            nn.Conv2d(in_channels, hidden, 3, padding=1),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(hidden, 1, 3, padding=1),
            nn.Sigmoid()
        )

        # Main denoiser over the 6 detail bands stacked along the channel axis.
        self.denoiser = EdgeAwareDenoiser(in_channels=6*in_channels)

        # Initialize
        self._initialize_weights()
        # DWT -> IDWT round-trip helper for sanity checks; not used in forward.
        self.recon_test = PerfectReconstructionTest(in_channels)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    @staticmethod
    def _split_subbands(coeffs, channels):
        """Split HaarDWT output (per-channel interleaved LL, LH, HL, HH) into four (B, C, h, w) bands."""
        b, _, h, w = coeffs.shape
        return coeffs.reshape(b, channels, 4, h, w).unbind(dim=2)

    @staticmethod
    def _merge_subbands(ll, lh, hl, hh):
        """Interleave four (B, C, h, w) bands into the (B, 4*C, h, w) layout HaarIDWT expects."""
        b, c, h, w = ll.shape
        return torch.stack([ll, lh, hl, hh], dim=2).reshape(b, 4 * c, h, w)

    def compute_frequency_mask(self, ll_band):
        """Return a (B, 1, h, w) mask combining local flatness and learned edges."""
        # Flatness: close to 1 where the squared deviation from the 3x3 local mean
        # is small, and decreasing towards 0 in textured regions.
        mean = F.avg_pool2d(ll_band, 3, stride=1, padding=1)
        variance = (ll_band - mean).pow(2)
        flatness = 1 - torch.tanh(variance.mean(dim=1, keepdim=True) * 5)

        # Edge mask
        edges = self.edge_detector(ll_band)

        # Both terms lie in [0, 1], so the combined mask lies in roughly (0.5, 0.88):
        # it attenuates the detail bands but never zeroes them.
        return torch.sigmoid(flatness + edges)

    def forward(self, x):
        B, C, H, W = x.shape

        # Reflect-pad bottom/right so H and W are multiples of _PAD_MULTIPLE.
        m = self._PAD_MULTIPLE
        pad_h = (m - H % m) % m
        pad_w = (m - W % m) % m
        x_pad = F.pad(x, (0, pad_w, 0, pad_h), mode='reflect')

        # Level 1 decomposition. HaarDWT interleaves sub-bands per channel, so the
        # bands are separated by reshaping (a plain torch.chunk would mix bands
        # and channels).
        LL1, LH1, HL1, HH1 = self._split_subbands(self.dwt1(x_pad), C)

        # Level 2 decomposition
        LL2, LH2, HL2, HH2 = self._split_subbands(self.dwt2(LL1), C)

        # Process LL band
        LL2_processed = self.ll_processor(LL2)

        # Compute frequency-aware mask
        freq_mask = self.compute_frequency_mask(LL2_processed)

        # Level-1 detail bands are resized to LL2 resolution so all six bands can
        # be stacked and denoised together.
        ll2_size = LL2.shape[-2:]
        denoiser_input = torch.cat([
            LH2 * freq_mask,
            HL2 * freq_mask,
            HH2 * freq_mask,
            F.interpolate(LH1, size=ll2_size, mode='bilinear') * freq_mask,
            F.interpolate(HL1, size=ll2_size, mode='bilinear') * freq_mask,
            F.interpolate(HH1, size=ll2_size, mode='bilinear') * freq_mask
        ], dim=1)

        # Denoise high frequencies.
        # NOTE(review): EdgeAwareDenoiser multiplies its input by the mask again,
        # so the detail bands are effectively scaled by freq_mask**2.
        denoised = self.denoiser(denoiser_input, freq_mask)
        den_LH2, den_HL2, den_HH2, den_LH1, den_HL1, den_HH1 = torch.chunk(denoised, 6, dim=1)

        # Reconstruct level 2 (HaarIDWT expects the per-channel interleaved layout).
        LL1_recon = self.idwt2(self._merge_subbands(LL2_processed, den_LH2, den_HL2, den_HH2))

        # Reconstruct level 1
        size1 = LL1_recon.shape[-2:]
        recon = self.idwt1(self._merge_subbands(
            LL1_recon,
            F.interpolate(den_LH1, size=size1, mode='bilinear'),
            F.interpolate(den_HL1, size=size1, mode='bilinear'),
            F.interpolate(den_HH1, size=size1, mode='bilinear')
        ))

        # Crop to original size
        return recon[:, :, :H, :W]

class MultiScaleLoss(nn.Module):
    """MSE + 0.3 * L1 on Sobel gradient magnitudes + 0.2 * Haar sub-band MSE.

    Expects (B, C, H, W) output/target tensors. All kernels are moved to the
    inputs' device and dtype on every call, so the loss works whether or not the
    module itself was moved with .to(device).
    """
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

        # Sobel filters for edge detection
        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
        sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32).view(1, 1, 3, 3)
        self.register_buffer('sobel_x', sobel_x)
        self.register_buffer('sobel_y', sobel_y)

    def forward(self, output, target):
        # Pixel-level loss
        mse_loss = self.mse(output, target)

        channels = output.size(1)

        # Edge loss: depthwise Sobel on every channel, on the inputs' device/dtype.
        kx = self.sobel_x.to(output).repeat(channels, 1, 1, 1)
        ky = self.sobel_y.to(output).repeat(channels, 1, 1, 1)

        def sobel(t):
            gx = F.conv2d(t, kx, padding=1, groups=channels)
            gy = F.conv2d(t, ky, padding=1, groups=channels)
            return torch.sqrt(gx**2 + gy**2 + 1e-6)

        edge_loss = F.l1_loss(sobel(output), sobel(target))

        # Frequency loss: sum over the four Haar sub-bands of each band's MSE.
        # HaarDWT interleaves the sub-bands per channel, so reshape to
        # (B, C, 4, h, w) and average over everything except the sub-band axis.
        # Zipping over the coefficient tensors instead iterates the batch axis,
        # which makes this term scale with the batch size.
        # NOTE(review): with an orthonormal Haar transform this sum is about
        # 4 * pixel MSE. Use L1 or per-band weights if a distinct frequency-domain
        # penalty is intended.
        dwt = HaarDWT(channels).to(output)
        diff = dwt(output) - dwt(target)
        b, _, h, w = diff.shape
        per_band = diff.pow(2).reshape(b, channels, 4, h, w).mean(dim=(0, 1, 3, 4))
        freq_loss = per_band.sum()

        return mse_loss + 0.3 * edge_loss + 0.2 * freq_loss



In [None]:
# Move the loss's Sobel buffers to the same device as the model outputs.
loss_module = MultiScaleLoss().to(device)

In [None]:
def eval_model(model, val_dataloader, noise_type, noise_parameter):
    """Score `model` at denoising synthetically corrupted batches.

    Each clean batch is corrupted with the requested noise, the model is run on
    the noisy images, and its predictions are compared with the clean images.

    Args:
        model: denoiser mapping an image batch to a batch of the same shape.
        val_dataloader: yields (images, labels); labels are ignored.
        noise_type: "normal" (additive Gaussian, std=noise_parameter, clipped to
            [0, 1]), "bernoulli" (each value kept with probability noise_parameter)
            or "poisson" (normalized Poisson(noise_parameter) sample added, clipped).
        noise_parameter: parameter of the chosen noise.

    Returns:
        (mean PSNR, mean MSE) over all evaluated images.

    Raises:
        ValueError: for an unknown noise_type or an empty dataloader.
    """
    if noise_type not in ("normal", "bernoulli", "poisson"):
        raise ValueError(f"Unknown noise_type: {noise_type!r}")

    # Remember the caller's mode; train_model calls this in the middle of training.
    was_training = model.training
    model.eval()
    psnr = []
    mse = []
    try:
        with torch.no_grad():
            for images, _ in val_dataloader:
                if noise_type == "normal":
                    noisy_images = (images + torch.normal(0, noise_parameter, images.shape)).clip(0, 1)
                elif noise_type == "bernoulli":
                    a = noise_parameter * torch.ones(images.shape)
                    noisy_images = images * torch.bernoulli(a)
                else:  # "poisson"
                    a = noise_parameter * torch.ones(images.shape)
                    p = torch.poisson(a)
                    p_norm = p / p.max()
                    noisy_images = (images + p_norm).clip(0, 1)
                images = images.to(device)
                noisy_images = noisy_images.to(device)
                # The model must denoise the *noisy* batch; running it on the clean
                # images would only measure how well it reproduces its input.
                preds = model(noisy_images)
                psnr.append(PSNR(images, preds).cpu())
                mse.append(MSE(images, preds).cpu())
    finally:
        model.train(was_training)

    if not psnr:
        raise ValueError("val_dataloader yielded no batches")
    psnr = torch.cat(psnr).numpy()
    mse = torch.cat(mse).numpy()
    print(f"Peak Signal to Noise Ratio:   Mean: {psnr.mean()} || Std: {psnr.std()}")
    print(f"Mean Squared Error:   Mean: {mse.mean()} || Std: {mse.std()}")
    return psnr.mean(), mse.mean()

In [None]:
def train_model(model, noise_type, noise_parameter, optimizer, train_dataloader, val_dataloader, loss_module, target_type="clean", num_epochs=30):
    """Train a denoiser on synthetically corrupted images, evaluating every 3 epochs.

    Args:
        model: denoiser mapping an image batch to a batch of the same shape.
        noise_type: "normal", "bernoulli" or "poisson" (see eval_model).
        noise_parameter: parameter of the chosen noise.
        optimizer: optimizer over model's parameters.
        train_dataloader, val_dataloader: yield (images, labels); labels are ignored.
        loss_module: callable(preds, targets) -> scalar loss.
        target_type: "clean" uses the clean images as targets; "noisy" uses an
            independent draw of the same noise (Noise2Noise-style).
        num_epochs: number of passes over train_dataloader.

    Returns:
        dict of lists "epoch", "psnr_train", "mse_train", "psnr_val", "mse_val",
        with one entry per evaluation.

    Raises:
        ValueError: for an unknown noise_type or target_type.
    """
    if noise_type not in ("normal", "bernoulli", "poisson"):
        raise ValueError(f"Unknown noise_type: {noise_type!r}")
    if target_type not in ("clean", "noisy"):
        raise ValueError(f"Unknown target_type: {target_type!r}")

    def corrupt(batch):
        """Return a noisy copy of a (CPU) batch according to noise_type / noise_parameter."""
        if noise_type == "normal":
            return (batch + torch.normal(0, noise_parameter, batch.shape)).clip(0, 1)
        if noise_type == "bernoulli":
            return batch * torch.bernoulli(noise_parameter * torch.ones(batch.shape))
        p = torch.poisson(noise_parameter * torch.ones(batch.shape))
        return (batch + p / p.max()).clip(0, 1)

    history = {"epoch": [], "psnr_train": [], "mse_train": [], "psnr_val": [], "mse_val": []}
    for epoch in range(num_epochs):
        # eval_model switches the model to eval mode, so re-enable training
        # mode at the start of every epoch.
        model.train()
        for images, _ in train_dataloader:
            # Corrupt the inputs first, then (optionally) the targets, each from the clean batch.
            inputs = corrupt(images)
            targets = corrupt(images) if target_type == "noisy" else images
            inputs = inputs.to(device)
            targets = targets.to(device)
            preds = model(inputs)
            loss = loss_module(preds, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        if epoch % 3 == 0:
            print(f"********EPOCH {epoch+1}:********")
            history["epoch"].append(epoch + 1)
            print("Train set:")
            psnr, mse = eval_model(model, train_dataloader, noise_type, noise_parameter)
            history["psnr_train"].append(psnr)
            history["mse_train"].append(mse)
            print("Validation set:")
            psnr, mse = eval_model(model, val_dataloader, noise_type, noise_parameter)
            history["psnr_val"].append(psnr)
            history["mse_val"].append(mse)
    return history

<span style="font-size:18px;">The training and evaluation functions support several noise types, which makes it easier to generalize a model by training it on different noise types and parameters.</span>

In [None]:
normal_ae_mse = DWT2DenoisingModel()
normal_optimizer_mse = optim.Adam(normal_ae_mse.parameters(), lr=1e-3)
normal_ae_mse = normal_ae_mse.to(device)

In [None]:
bernoulli_ae = DWT2DenoisingModel()
bernoulli_optimizer = optim.Adam(bernoulli_ae.parameters(), lr=1e-3)
bernoulli_ae = bernoulli_ae.to(device)

In [None]:
poisson_ae = DWT2DenoisingModel()
poisson_optimizer = optim.Adam(poisson_ae.parameters(), lr=1e-3)
poisson_ae = poisson_ae.to(device)

In [None]:
%%time
train_model(normal_ae_mse, "normal", 0.2, normal_optimizer_mse, train_dataloader, val_dataloader, loss_module)

In [None]:
ae = DWT2DenoisingModel()
ae = ae.to(device)

In [None]:
normal_ae.eval()

In [None]:
images, _ = next(iter(train_dataloader))
images = images.float().to(device)
output = normal_ae(images)

In [None]:
imshow(images[1].cpu().detach(), normalize=False)

In [None]:
images = (images + torch.normal(0,0.2,images.shape)).clip(0,1)

In [None]:
imshow(images[1], normalize=False)

In [None]:
output = normal_ae_mse(images.to(device))

In [None]:
imshow(output[1].cpu().detach(), normalize=False)

In [None]:
images, _ = next(iter(val_dataloader))
images = images.float().to(device)
output = normal_ae_mse(images)

In [None]:
imshow(images[1].cpu().detach(), normalize = False)

In [None]:
noisy_images = (images + torch.normal(0,0.2,images.shape).to(device)).clip(0,1)

In [None]:
imshow(noisy_images[1].cpu().detach(), normalize = False)

In [None]:
output = normal_ae_mse(images)

In [None]:
imshow(output[1].cpu().detach(), normalize=False)

In [None]:
MSE(images, output)

In [None]:
PSNR(images, output)

In [None]:
a = 5*torch.ones(images.shape)
p = torch.poisson(a)
p_norm = p/p.max()
images = (images + p_norm.to(device)).clip(0,1)

In [None]:
imshow(images[1].cpu().detach(), normalize=False)

In [None]:
output = normal_ae_mse(images)

In [None]:
imshow(output[1].cpu().detach(), normalize=False)

<span style="font-size:18px;">The model actually gave a good, but quite oversmoothed, result on a Poisson-Noisy image even though it wasn't trained on it.

In [None]:
normal_noisy_ae = autoencoders()
normal_noisy_optimizer = optim.Adam(normal_noisy_ae.parameters(), lr=1e-2)
normal_noisy_ae = normal_noisy_ae.to(device)

In [None]:
%%time
train_model(normal_noisy_ae, "noisy", 0.2, normal_noisy_optimizer, train_dataloader, loss_module)

<span style="font-size:18px;">Here I wanted to try an idea presented by J. Lehtinen et al. in the paper "Noise2Noise: Learning Image Restoration without Clean Data": train with noisy targets, and the optimization still drives the model towards the clean image, under the hypothesis that the noise has zero mean.</span>

In [None]:
images = images.float().to(device)
output = normal_noisy_ae(images)

In [None]:
imshow(images[1].cpu().detach(), normalize=False)

In [None]:
imshow(output[1].cpu().detach(), normalize=False)

In [None]:
images = (images + torch.normal(0,0.2,images.shape).to(device)).clip(0,1)

In [None]:
imshow(images[1].cpu().detach(), normalize=False)

In [None]:
output = normal_noisy_ae(images)

In [None]:
imshow(output[1].cpu().detach(), normalize=False)

<span style="font-size:18px;">Unfortunately, even after training for a long time (around 6 hours), the model was not able to produce the desired result. This may be due to the limited size of the model or to the number of training epochs.</span>

In [None]:
normal_noisy_optimizer = optim.Adam(normal_noisy_ae.parameters(), lr=0.005)

In [None]:
%%time
train_model(normal_noisy_ae, "clean", 0.2, normal_noisy_optimizer, train_dataloader, loss_module, num_epochs = 10)

In [None]:
imshow(images[1].cpu().detach(), normalize=False)

In [None]:
noisy_images = (images + torch.normal(0,0.2,images.shape)).clip(0,1).to(device)

In [None]:
imshow(noisy_images[1].cpu().detach(), normalize=False)

In [None]:
output = normal_noisy_ae(noisy_images)

In [None]:
imshow(output[1].cpu().detach(), normalize=False)

<span style="font-size:18px;">I then trained the model on clean targets, continuing from the weights learned on noisy targets, but it still did not produce usable results, so we decided to move on from this idea.</span>

In [None]:
# Saving the weights of the first model
state_dict = normal_ae_mse.state_dict()
print(state_dict)

In [None]:
loss_module = L1Loss()

In [None]:
normal_ae_mae = autoencoders()
normal_optimizer_mae = optim.Adam(normal_ae_mae.parameters(), lr=1e-3)
normal_ae_mae = normal_ae_mae.to(device)

In [None]:
%%time
train_model(normal_ae_mae, "normal", 0.2, normal_optimizer_mae, train_dataloader, val_dataloader, loss_module)

<span style="font-size:18px;">We trained the second model (same architecture but different loss function).

In [None]:
images = images.float().to(device)

In [None]:
imshow(images[1].cpu().detach(), normalize = False)

In [None]:
output = normal_ae_mae(images)

In [None]:
imshow(output[1].cpu().detach(), normalize = False)

In [None]:
noisy_images = (images.cpu().detach() + torch.normal(0,0.2,images.shape)).clip(0,1).to(device)

In [None]:
imshow(noisy_images[1].cpu().detach(), normalize = False)

In [None]:
output = normal_ae_mae(noisy_images)

In [None]:
imshow(output[1].cpu().detach(), normalize = False)

In [None]:
state_dict = torch.load("normal_ae_mse_30epochs.tar")
normal_ae_mse = autoencoders()
normal_ae_mse.load_state_dict(state_dict)

In [None]:
normal_ae_mse = normal_ae_mse.to(device)

In [None]:
output1 = normal_ae_mse(noisy_images)

In [None]:
imshow(output1[1].cpu().detach(), normalize = False)

In [None]:
state_dict = normal_mae.state_dict()

## Results from the Auto-Encoder Models

In [None]:
normal_ae_mse = autoencoders()
state_dict = torch.load("normal_ae_mse_30epochs.tar")
normal_ae_mse.load_state_dict(state_dict)

In [None]:
normal_ae_mse = normal_ae_mse.to(device)

In [None]:
psnr = []
with torch.no_grad():
        for i, (images, _) in enumerate(val_dataloader):
            noisy_images = (images + torch.normal(0,0.2,images.shape)).clip(0,1)
            noisy_images = noisy_images.to(device)
            images = images.to(device)
            preds = normal_ae_mse(noisy_images)
            if (i<4):
                print(PSNR(images.cpu().detach(), preds.cpu().detach()))
            psnr.extend(PSNR(images.cpu().detach(), preds.cpu().detach()))

In [None]:
mse_psnr = psnr.copy()

In [None]:
mse_psnr = np.array(mse_psnr)

In [None]:
print(f"The mean of the PSNR is {mse_psnr.mean()} and the standard deviation of the PSNR is {mse_psnr.std()}")

The mean of the PSNR is 31.12668800354004 and the standard deviation of the PSNR is 1.1921919584274292

In [None]:
normal_ae_mae = autoencoders()
state_dict = torch.load("normal_ae_mae_30epochs.tar")
normal_ae_mae.load_state_dict(state_dict)

In [None]:
normal_ae_mae = normal_ae_mae.to(device)

In [None]:
psnr = []
with torch.no_grad():
        for i, (images, _) in enumerate(val_dataloader):
            noisy_images = (images + torch.normal(0,0.2,images.shape)).clip(0,1)
            noisy_images = noisy_images.to(device)
            images = images.to(device)
            preds = normal_ae_mae(noisy_images)
            if (i<4):
                print(PSNR(images.cpu().detach(), preds.cpu().detach()))
            psnr.extend(PSNR(images.cpu().detach(), preds.cpu().detach()))

In [None]:
mae_psnr = np.array(psnr.copy())
print(f"The mean of the PSNR is {mae_psnr.mean()} and the standard deviation of the PSNR is {mae_psnr.std()}")

The mean of the PSNR is 31.070117950439453 and the standard deviation of the PSNR is 1.2800648212432861

In [None]:
psnr=[]
for i, (images, _) in enumerate(val_dataloader):
    noisy_images = (images + torch.normal(0,0.2,images.shape)).clip(0,1)
    for j in range(images.shape[0]):
        psnr.append(PSNR(images[j,:,:,:].unsqueeze(0), torch.tensor(bm3d.bm3d(noisy_images[j,:,:,:].permute(1,2,0), sigma_psd=30/255, stage_arg=bm3d.BM3DStages.HARD_THRESHOLDING)).permute(2,0,1).unsqueeze(0)))
    if (i<4):
        print(psnr)

In [None]:
bm3d_psnr = torch.tensor(psnr.copy())
print(f"The mean of the PSNR is {bm3d_psnr.mean()} and the standard deviation of the PSNR is {bm3d_psnr.std()}")

The mean of the PSNR is 23.5729346128809 and the standard deviation of the PSNR is 0.4025119048011183


![first_result.png](attachment:2d4d44fd-d242-4052-8bce-d8d76bef175b.png)