In [1]:
import cv2
import numpy as np
import torch
from torchvision import transforms
from skimage import img_as_float, img_as_ubyte
from skimage.restoration import denoise_wavelet
from PIL import Image

In [2]:
from torch.hub import load_state_dict_from_url

MODEL_URL = "https://github.com/SaoYan/DnCNN-PyTorch/raw/master/TrainingCodes/dncnn_sigma25.pth"

class DnCNN(torch.nn.Module):
    def __init__(self, channels=1, num_of_layers=17):
        super(DnCNN, self).__init__()
        layers = []
        layers.append(torch.nn.Conv2d(channels, 64, kernel_size=3, padding=1, bias=False))
        layers.append(torch.nn.ReLU(inplace=True))
        for _ in range(num_of_layers-2):
            layers.append(torch.nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False))
            layers.append(torch.nn.BatchNorm2d(64))
            layers.append(torch.nn.ReLU(inplace=True))
        layers.append(torch.nn.Conv2d(64, channels, kernel_size=3, padding=1, bias=False))
        self.dncnn = torch.nn.Sequential(*layers)

    def forward(self, x):
        out = self.dncnn(x)
        return x - out  # Residual learning

In [3]:
model = DnCNN(channels=1)
state_dict = load_state_dict_from_url(MODEL_URL, map_location='cpu')
model.load_state_dict(state_dict)
model.eval()

Downloading: "https://github.com/SaoYan/DnCNN-PyTorch/raw/master/TrainingCodes/dncnn_sigma25.pth" to C:\Users\91909/.cache\torch\hub\checkpoints\dncnn_sigma25.pth


HTTPError: HTTP Error 404: Not Found

In [4]:
def denoise_image(image_path, output_path):
    # Load and preprocess image
    image = Image.open(image_path).convert('L')  # Convert to grayscale
    transform = transforms.Compose([transforms.ToTensor()])
    img_tensor = transform(image).unsqueeze(0)  # Add batch dimension

    # Denoise using DnCNN
    with torch.no_grad():
        denoised_tensor = model(img_tensor)

    # Convert back to image
    denoised_image = denoised_tensor.squeeze().numpy()
    denoised_image = np.clip(denoised_image, 0, 1)  # Normalize

    # Save the denoised image
    denoised_image = img_as_ubyte(denoised_image)
    cv2.imwrite(output_path, denoised_image)
    print(f"Denoised image saved at {output_path}")

In [None]:
denoise_image("noisy_image.png", "denoised_image.png")