<a href="https://colab.research.google.com/github/HANKSOONG/image-repair-and-restoration-/blob/main/DnCNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive so that the dataset archive and the
# model checkpoints referenced by the later cells are reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Extract the GOPRO_Large archive stored on Drive into /content/datasets.
!unzip "/content/drive/MyDrive/Colab Notebooks/GOPRO_Large.zip" -d "/content/datasets"

In [None]:
import os

# Report how many CPU cores the runtime exposes (os.cpu_count() may return None).
cpu_count = os.cpu_count()

print(f"Number of CPU cores: {cpu_count}")

Number of CPU cores: 8


In [None]:
import os
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, random_split
from PIL import Image

class CustomDataset(Dataset):
    """Paired blur/sharp image dataset read from per-scene sub-directories.

    Expected layout::

        root_dir/<scene>/sharp/<frame>
        root_dir/<scene>/blur/<frame>

    Within each scene the frames of both folders are sorted by name and
    paired by position.

    Each item is a tuple ``(blur, sharp, sharp_original)`` where ``blur`` and
    ``sharp`` have been passed through ``transform`` (when one is given) and
    ``sharp_original`` is the untransformed sharp frame as a ``ToTensor``
    tensor.

    Args:
        root_dir: Directory containing one sub-directory per scene. A leading
            ``~`` is expanded.
        transform: Optional callable applied independently to the sharp and
            the blurred PIL image.

    Raises:
        ValueError: If a scene holds a different number of sharp and blurred
            frames, since positional pairing would then be misaligned.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = os.path.expanduser(root_dir)
        self.transform = transform
        self.samples = self._load_samples()

    def _load_samples(self):
        """Return the list of ``(sharp_path, blur_path)`` pairs under ``root_dir``."""
        samples = []
        for subdir in sorted(os.listdir(self.root_dir)):
            subdir_path = os.path.join(self.root_dir, subdir)
            sharp_dir = os.path.join(subdir_path, 'sharp')
            blur_dir = os.path.join(subdir_path, 'blur')

            # Entries that contain neither folder (plain files, notebook
            # checkpoint folders, ...) are not scenes; skip them instead of
            # crashing in os.listdir.
            if not (os.path.isdir(sharp_dir) or os.path.isdir(blur_dir)):
                continue

            sharp_imgs = sorted(os.listdir(sharp_dir))
            blur_imgs = sorted(os.listdir(blur_dir))

            # zip() would silently drop the surplus frames and could pair
            # unrelated images, so refuse to continue on a count mismatch.
            if len(sharp_imgs) != len(blur_imgs):
                raise ValueError(
                    f"Scene '{subdir_path}' has {len(sharp_imgs)} sharp and "
                    f"{len(blur_imgs)} blur images; cannot pair them."
                )

            samples.extend(
                (os.path.join(sharp_dir, sharp_img), os.path.join(blur_dir, blur_img))
                for sharp_img, blur_img in zip(sharp_imgs, blur_imgs)
            )
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(blur, sharp, sharp_original)``."""
        sharp_path, blur_path = self.samples[idx]

        # convert() returns an independent image, so the file handles can be
        # released as soon as the pixel data has been read.
        with Image.open(sharp_path) as sharp_file:
            sharp_img = sharp_file.convert('RGB')
        with Image.open(blur_path) as blur_file:
            blur_img = blur_file.convert('RGB')

        # Untransformed (full-resolution, un-normalized) sharp frame.
        sharp_img_original = transforms.ToTensor()(sharp_img)

        if self.transform:
            sharp_img = self.transform(sharp_img)
            blur_img = self.transform(blur_img)

        return blur_img, sharp_img, sharp_img_original

# Preprocessing applied to both the blurred and the sharp image:
# resize to 360x640 (height x width) to lower the resolution, convert to a
# tensor, then normalize with the standard ImageNet channel statistics.
transform = transforms.Compose([
    transforms.Resize((360, 640)), # Lower resolution
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = CustomDataset(root_dir='/content/datasets/train', transform=transform)

# Split the data set into training set and validation set (80% / 20%).
# The split uses a fixed seed so that the validation subset is identical on
# every run; otherwise a checkpoint reloaded in a later session could be
# evaluated on images it was trained on.
SPLIT_SEED = 42
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_subset, val_subset = random_split(
    train_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Use up to 8 loader workers, but never more than the cores the runtime has.
NUM_WORKERS = min(8, os.cpu_count() or 1)

train_loader = DataLoader(
    train_subset,
    batch_size=8,
    shuffle=True,
    num_workers=NUM_WORKERS
)
val_loader = DataLoader(
    val_subset,
    batch_size=8,
    shuffle=False,
    num_workers=NUM_WORKERS
)

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
import torch.distributed as distance
from torch.cuda.amp import GradScaler, autocast

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Perceptual loss built on the MobileNetV3-Small feature extractor
class PerceptualLoss(nn.Module):
    """Mean squared error between MobileNetV3-Small feature maps.

    The ``features`` part of a pretrained ``mobilenet_v3_small`` is put in
    eval mode and its parameters are frozen, so it only serves as a fixed
    feature extractor.
    """

    def __init__(self):
        super().__init__()
        backbone = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.DEFAULT)
        self.mobilenet = backbone.features
        self.mobilenet.eval()
        # Freeze the extractor: it must not be updated by the optimizer.
        for weight in self.mobilenet.parameters():
            weight.requires_grad = False

    def forward(self, input, target):
        """Return the MSE between the feature maps of ``input`` and ``target``."""
        predicted_feats = self.mobilenet(input)
        reference_feats = self.mobilenet(target)

        # Bring the prediction's feature map to the reference's spatial size
        # when the two differ.
        if predicted_feats.shape[2:] != reference_feats.shape[2:]:
            predicted_feats = F.interpolate(
                predicted_feats,
                size=reference_feats.shape[2:],
                mode='bilinear',
                align_corners=False,
            )

        return F.mse_loss(predicted_feats, reference_feats)

In [None]:
# DnCNN model definition
class DnCNN(nn.Module):
    """Plain stack of 3x3 convolutions in the DnCNN layout.

    Structure: ``Conv -> ReLU``, then ``num_of_layers - 2`` repetitions of
    ``Conv -> BatchNorm -> ReLU``, then a final ``Conv`` back to ``channels``.
    All convolutions use 64 feature maps, padding 1 and no bias, so the
    spatial size of the input is preserved. ``forward`` returns the output of
    the stack directly.

    Args:
        channels: Number of channels of the input (and output) image.
        num_of_layers: Total number of convolution layers.
    """

    def __init__(self, channels, num_of_layers=17):
        super().__init__()
        width = 64
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)

        # Head: image channels -> feature maps
        stack = [nn.Conv2d(channels, width, **conv_kwargs), nn.ReLU(inplace=True)]

        # Body: feature maps -> feature maps, with batch normalization
        for _ in range(num_of_layers - 2):
            stack += [
                nn.Conv2d(width, width, **conv_kwargs),
                nn.BatchNorm2d(width),
                nn.ReLU(inplace=True),
            ]

        # Tail: feature maps -> image channels
        stack.append(nn.Conv2d(width, channels, **conv_kwargs))

        self.dncnn = nn.Sequential(*stack)

    def forward(self, x):
        return self.dncnn(x)



In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
# Create the DnCNN model, the two loss terms (pixel-wise MSE and the
# MobileNet-based perceptual loss) and the Adam optimizer.
denoising_model = DnCNN(channels=3).to(device)
mse_criterion = nn.MSELoss()
perceptual_criterion = PerceptualLoss().to(device)
optimizer = torch.optim.Adam(denoising_model.parameters(), lr=0.0001)

# Initialize the ReduceLROnPlateau scheduler: the learning rate is multiplied
# by 0.1 once the monitored value has not decreased for 3 scheduler steps.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases — confirm
# against the installed version before upgrading.
scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.1, patience=3, verbose=True)

In [None]:
# Weights of the two loss terms; the training loop combines them as
# total = mse_weight * MSE + perceptual_weight * perceptual.
# train_model reads both names as module-level globals.
mse_weight = 1.0
perceptual_weight = 0.1

In [None]:
# Train model
def train_model(model, train_loader, val_loader, mse_criterion, perceptual_criterion, optimizer, num_epochs=100, early_stopping_tolerance=8):
    """Train ``model`` with mixed precision, LR scheduling and early stopping.

    Every epoch runs one pass over ``train_loader`` followed by one pass over
    ``val_loader``. The loss is
    ``mse_weight * mse + perceptual_weight * perceptual``, averaged per
    sample. Training stops once the validation loss has not improved for
    ``early_stopping_tolerance`` consecutive epochs, and the weights of the
    best validation epoch are restored before returning, so the caller never
    ends up with the weights of a later, worse epoch.

    Besides its parameters the function reads the module-level names
    ``device``, ``scheduler``, ``mse_weight`` and ``perceptual_weight``.

    Args:
        model: Network to train; it is modified in place.
        train_loader: Loader yielding ``(blur, sharp, _)`` training batches.
        val_loader: Loader yielding ``(blur, sharp, _)`` validation batches.
        mse_criterion: Pixel-wise loss called as ``criterion(output, target)``.
        perceptual_criterion: Feature-space loss with the same call signature.
        optimizer: Optimizer that updates ``model``'s parameters.
        num_epochs: Maximum number of epochs.
        early_stopping_tolerance: Number of consecutive non-improving epochs
            after which training stops.

    Returns:
        ``model``, loaded with the weights of the best validation epoch.
    """
    import copy  # local import: only needed for the best-weights snapshot

    best_val_loss = float('inf')
    best_state = None  # state_dict snapshot of the best validation epoch
    no_improvement_count = 0  # Early stopping counter

    scaler = GradScaler()
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for blur_img, transformed_sharp_img, _ in train_loader:
            blur_img = blur_img.to(device)
            transformed_sharp_img = transformed_sharp_img.to(device)

            optimizer.zero_grad()

            # Performing forward propagation using the autocast context
            with autocast():
                outputs = model(blur_img)
                mse_loss = mse_criterion(outputs, transformed_sharp_img)
                perceptual_loss = perceptual_criterion(outputs, transformed_sharp_img)
                total_loss = mse_weight * mse_loss + perceptual_weight * perceptual_loss

            # Performing backward propagation and optimization using GradScaler
            scaler.scale(total_loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Weight by batch size so the epoch value is a per-sample mean
            running_loss += total_loss.item() * blur_img.size(0)
        train_loss = running_loss / len(train_loader.dataset)

        # Validation test
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for blur_img, transformed_sharp_img, _ in val_loader:
                blur_img = blur_img.to(device)
                transformed_sharp_img = transformed_sharp_img.to(device)
                outputs = model(blur_img)
                mse_loss = mse_criterion(outputs, transformed_sharp_img)
                perceptual_loss = perceptual_criterion(outputs, transformed_sharp_img)
                total_loss = mse_weight * mse_loss + perceptual_weight * perceptual_loss
                val_loss += total_loss.item() * blur_img.size(0)

        val_loss /= len(val_loader.dataset)
        print(f'Epoch {epoch+1}/{num_epochs}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')

        scheduler.step(val_loss)
        # Early stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Deep copy: state_dict() returns references to the live tensors,
            # which the following epochs keep updating.
            best_state = copy.deepcopy(model.state_dict())
            no_improvement_count = 0
        else:
            no_improvement_count += 1
            if no_improvement_count >= early_stopping_tolerance:
                print("Stopping early due to no improvement in validation loss")
                break

    # Roll back to the best validation epoch (no-op if none was recorded).
    if best_state is not None:
        model.load_state_dict(best_state)

    return model

# Train denoising_model (modified in place) for at most 100 epochs, stopping
# early after 8 epochs without validation improvement.
train_model(denoising_model, train_loader, val_loader, mse_criterion, perceptual_criterion, optimizer, num_epochs=100, early_stopping_tolerance=8)

In [None]:
# Save model
# The weights are written unconditionally: guarding the save on CUDA being
# available would silently discard the trained model on a CPU-only runtime.
model_path = '/content/drive/MyDrive/model/DnCNN/denoising_DnCNN.pth'
model_dir = os.path.expanduser(os.path.dirname(model_path))

# exist_ok avoids the check-then-create race of a separate os.path.exists test
os.makedirs(model_dir, exist_ok=True)

torch.save(denoising_model.state_dict(), os.path.expanduser(model_path))
print(f"Model saved to {model_path}.")

Model saved to /content/drive/MyDrive/model/DnCNN/denoising_DnCNN.pth.


In [None]:
# Inverse normalization transformation
# Normalize computes (x - mean) / std, so applying it with mean' = -mean / std
# and std' = 1 / std maps a normalized tensor y back to y * std + mean.
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]
inv_normalize = transforms.Normalize(
    mean=[-m / s for m, s in zip(_norm_mean, _norm_std)],
    std=[1 / s for s in _norm_std],
)

In [None]:
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim
from skimage import img_as_float
import torch

# Running sums of PSNR and SSIM over all evaluated validation images
total_psnr_sharp = 0
total_ssim_sharp = 0
total_psnr_blur = 0
total_ssim_blur = 0
num_images = 0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): this loads 'denoising_DnCNN3.pth' while the save cell writes
# 'denoising_DnCNN.pth' — confirm which checkpoint is meant to be evaluated.
model_path = '/content/drive/MyDrive/model/DnCNN/denoising_DnCNN3.pth'

model = DnCNN(channels=3)
model.load_state_dict(torch.load(model_path, map_location=device))
# The batches are moved to `device` below, so the model has to live on the
# same device; otherwise the forward pass fails on a GPU runtime.
model = model.to(device)
model.eval()
with torch.no_grad():
    for blur_img, sharp_img, _ in val_loader:
        blur_img = blur_img.to(device)
        sharp_img = sharp_img.to(device)

        denoised_img = model(blur_img)

        # Traverse the batch image by image to calculate PSNR and SSIM
        for denoised_t, sharp_t, blur_t in zip(denoised_img, sharp_img, blur_img):
            # Undo the input normalization and clamp into the [0, 1] range
            denoised = inv_normalize(denoised_t).clamp(0, 1)
            sharp = inv_normalize(sharp_t).clamp(0, 1)
            blur = inv_normalize(blur_t).clamp(0, 1)

            # CHW tensor -> HWC numpy array, the layout skimage works with
            denoised_np = denoised.cpu().numpy().transpose(1, 2, 0)
            sharp_np = sharp.cpu().numpy().transpose(1, 2, 0)
            blur_np = blur.cpu().numpy().transpose(1, 2, 0)

            # Calculate PSNR and SSIM (reference image first).
            # data_range is given explicitly because the arrays are floats
            # clamped to [0, 1]; channel_axis replaces the deprecated
            # multichannel=True argument.
            psnr_sharp = compare_psnr(sharp_np, denoised_np, data_range=1.0)
            ssim_sharp = compare_ssim(sharp_np, denoised_np, channel_axis=-1, data_range=1.0)
            psnr_blur = compare_psnr(blur_np, denoised_np, data_range=1.0)
            ssim_blur = compare_ssim(blur_np, denoised_np, channel_axis=-1, data_range=1.0)

            total_psnr_sharp += psnr_sharp
            total_ssim_sharp += ssim_sharp
            total_psnr_blur += psnr_blur
            total_ssim_blur += ssim_blur
            num_images += 1

if num_images == 0:
    raise RuntimeError("val_loader yielded no images; cannot compute average PSNR/SSIM.")

# Calculate average PSNR and SSIM
avg_psnr_sharp = total_psnr_sharp / num_images
avg_ssim_sharp = total_ssim_sharp / num_images
avg_psnr_blur = total_psnr_blur / num_images
avg_ssim_blur = total_ssim_blur / num_images

print(f'Average PSNR (Denoised vs Sharp): {avg_psnr_sharp}')
print(f'Average SSIM (Denoised vs Sharp): {avg_ssim_sharp}')
print(f'Average PSNR (Denoised vs Blur): {avg_psnr_blur}')
print(f'Average SSIM (Denoised vs Blur): {avg_ssim_blur}')

  ssim_sharp = compare_ssim(denoised_np, sharp_np, multichannel=True)
  ssim_blur = compare_ssim(denoised_np, blur_np, multichannel=True)


Average PSNR (Denoised vs Sharp): 27.399835853064033
Average SSIM (Denoised vs Sharp): 0.8841692714679836
Average PSNR (Denoised vs Blur): 33.32209119904878
Average SSIM (Denoised vs Blur): 0.97437595981317


In [None]:
import torch
import matplotlib.pyplot as plt

# Export model outputs together with their reference images as PNG files.
# Inference runs on the GPU when CUDA is available and on the CPU otherwise.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_path = '/content/drive/MyDrive/model/DnCNN/denoising_DnCNN3.pth'

model = DnCNN(channels=3)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
model = model.to(device)

# Loader over the whole training folder: one image per batch, in dataset order
full_dataset = CustomDataset(root_dir='/content/datasets/train', transform=transform)
full_loader = DataLoader(full_dataset, batch_size=1, shuffle=False)

# One output folder per kind of exported image
output_dir_base = '/content/drive/MyDrive/image_output'
denoised_dir = os.path.join(output_dir_base, 'denoised')
sharp_resized_dir = os.path.join(output_dir_base, 'sharp_resized')
sharp_original_dir = os.path.join(output_dir_base, 'sharp_original')
blur_resized_dir = os.path.join(output_dir_base, 'blur_resized')

for out_dir in (denoised_dir, sharp_resized_dir, sharp_original_dir, blur_resized_dir):
    os.makedirs(out_dir, exist_ok=True)

# Tensor -> PIL converter shared by all four exports
to_pil = transforms.ToPILImage()

# Process and save images
for i, (blur_img, sharp_img, sharp_img_original) in enumerate(full_loader):
    file_name = f'image_{i:04d}.png'
    blur_img = blur_img.to(device)

    # Model output, de-normalized and clamped to [0, 1]
    with torch.no_grad():
        denoised_img = model(blur_img)
    denoised_img = inv_normalize(denoised_img[0]).clamp(0, 1)
    denoised_img_pil = to_pil(denoised_img.cpu())
    denoised_img_pil.save(os.path.join(denoised_dir, file_name))

    # Resized sharp image (de-normalized)
    sharp_img = inv_normalize(sharp_img[0]).clamp(0, 1)
    sharp_img_pil = to_pil(sharp_img.cpu())
    sharp_img_pil.save(os.path.join(sharp_resized_dir, file_name))

    # Sharp image as returned untransformed by the dataset
    sharp_img_original_pil = to_pil(sharp_img_original.cpu().squeeze(0))
    sharp_img_original_pil.save(os.path.join(sharp_original_dir, file_name))

    # Resized blurred input (de-normalized)
    blur_img_resized = inv_normalize(blur_img.squeeze(0)).clamp(0, 1)
    blur_img_resized_pil = to_pil(blur_img_resized.cpu())
    blur_img_resized_pil.save(os.path.join(blur_resized_dir, file_name))