In [2]:
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import os
import cv2
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import shutil
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import interact, IntSlider
import zipfile



In [3]:


class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, features=[64, 128, 256, 512]):
        super(UNet, self).__init__()
        
        # Encoder
        self.encoders = nn.ModuleList()
        for feature in features:
            self.encoders.append(self._conv_block(in_channels, feature))
            in_channels = feature
        
        # Bottleneck
        self.bottleneck = self._conv_block(features[-1], features[-1] * 2)
        
        # Decoder
        self.decoders = nn.ModuleList()
        for feature in reversed(features):
            self.decoders.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.decoders.append(self._conv_block(feature * 2, feature))
        
        # Final Output Layer
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def _conv_block(self, in_channels, out_channels):
        """Double Convolution Block"""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        skip_connections = []
        for encoder in self.encoders:
            x = encoder(x)
            skip_connections.append(x)
            x = F.max_pool2d(x, kernel_size=2, stride=2)
        
        x = self.bottleneck(x)

        skip_connections = skip_connections[::-1]
        for i in range(0, len(self.decoders), 2):
            x = self.decoders[i](x)  # Upconvolution
            skip_connection = skip_connections[i // 2]

            if x.shape != skip_connection.shape:
                x = F.interpolate(x, size=skip_connection.shape[2:], mode="bilinear", align_corners=True)

            x = torch.cat((skip_connection, x), dim=1)  # Skip connection
            x = self.decoders[i + 1](x)  # Convolution Block

        return self.final_conv(x)

In [4]:
model_path = '/kaggle/input/unet_denosing/pytorch/default/1/unet_model.pth'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(in_channels=3, out_channels=3).to(device)  # Initialize model
model.load_state_dict(torch.load(model_path))  # Load weights into the model
model.eval()



  model.load_state_dict(torch.load(model_path))  # Load weights into the model


UNet(
  (encoders): ModuleList(
    (0): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
    )
    (1): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
    )
    (2): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
    )
    (3): Sequential(
      (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)


## Getting Test Dataset

In [5]:
!gdown --fuzzy --id 1UZA_AEdV5EgqWl9lozYo12YrET-Pno6L

Downloading...
From (original): https://drive.google.com/uc?id=1UZA_AEdV5EgqWl9lozYo12YrET-Pno6L
From (redirected): https://drive.google.com/uc?id=1UZA_AEdV5EgqWl9lozYo12YrET-Pno6L&confirm=t&uuid=a24060cf-cf85-4952-962f-7a42a0bbb869
To: /kaggle/working/LSDIR_DIV2K_Test_Sigma50.zip
100%|██████████████████████████████████████| 1.15G/1.15G [00:13<00:00, 87.9MB/s]


In [8]:
class TestDenoiseDataset(Dataset):
    def __init__(self, noisy_dir, transform=None):
        self.noisy_images = sorted(os.listdir(noisy_dir))
        self.noisy_dir = noisy_dir
        self.transform = transform

    def __len__(self):
        return len(self.noisy_images)

    def __getitem__(self, idx):
        noisy_path = os.path.join(self.noisy_dir, self.noisy_images[idx])

        # Load noisy image
        noisy_img = cv2.imread(noisy_path)
        noisy_img = cv2.cvtColor(noisy_img, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB
        noisy_img = Image.fromarray(noisy_img)  # Convert to PIL

        # Apply transformations
        if self.transform:
            noisy_img = self.transform(noisy_img)

        return noisy_img, self.noisy_images[idx]  # Return filename for saving

In [6]:


zip_file = "/kaggle/working/LSDIR_DIV2K_Test_Sigma50.zip"
extract_folder = "/kaggle/working/LSDIR_DIV2K_Test_Sigma50"

shutil.unpack_archive(zip_file, extract_folder)
# os.remove(zip_file)
print("Unzipped files")

Unzipped and deleted!


In [9]:
test_noisy_dir = "/kaggle/working/LSDIR_DIV2K_Test_Sigma50"

test_dataset = TestDenoiseDataset(test_noisy_dir, transform=transforms.ToTensor())
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

## Noised to denoised image

In [10]:


def denoise_images_with_slider(model, dataloader, device):
    model.eval()
    
    noisy_images = []
    denoised_images = []

    with torch.no_grad():  
        for noisy_imgs, _ in dataloader:
            noisy_imgs = noisy_imgs.to(device)

            # Predict denoised images
            denoised_imgs = model(noisy_imgs)

            # Convert tensors to NumPy for visualization
            for i in range(noisy_imgs.shape[0]):
                noisy_np = noisy_imgs[i].cpu().permute(1, 2, 0).numpy()
                denoised_np = denoised_imgs[i].cpu().permute(1, 2, 0).numpy()
                noisy_images.append(noisy_np)
                denoised_images.append(denoised_np)
    
    def show_images(index):
        fig, axs = plt.subplots(1, 2, figsize=(10, 5))
        axs[0].imshow(noisy_images[index])
        axs[0].set_title("Noisy Image")
        axs[0].axis("off")

        axs[1].imshow(denoised_images[index])
        axs[1].set_title("Denoised Image")
        axs[1].axis("off")

        plt.show()

    # Create interactive slider
    interact(show_images, index=IntSlider(0, 0, len(noisy_images) - 1, 1));

# Run the function
denoise_images_with_slider(model, test_loader, device)

interactive(children=(IntSlider(value=0, description='index', max=199), Output()), _dom_classes=('widget-inter…

## Collecting readme info

In [18]:
def collect_info_and_denoise(model, dataloader, device, info_dir):
    # Collect runtime, GPU/CPU info, and extra data usage
    runtime_per_image = []
    cpu_or_gpu = "GPU" if device.type == "cuda" else "CPU"
    
    extra_data_used = 0  

    # Start processing the images and measure runtime
    start_time = time.time()

    with torch.no_grad():  
        for noisy_imgs, _ in dataloader:
            noisy_imgs = noisy_imgs.to(device)

            # Record start time for each image
            start_img_time = time.time()

            # Predict denoised images (no need to save them)
            denoised_imgs = model(noisy_imgs)

            # Record end time for each image
            end_img_time = time.time()
            img_runtime = end_img_time - start_img_time
            runtime_per_image.append(img_runtime)

    # Calculate average runtime per image
    avg_runtime_per_image = sum(runtime_per_image) / len(runtime_per_image)
    
    # Write the information into a readme.txt file
    readme_content = f"""runtime per image [s]: {avg_runtime_per_image}
CPU[1] / GPU[0] : 1 if using CPU or 0 if using GPU: {cpu_or_gpu == 'GPU'}
Extra Data [1] / No Extra Data [0] : {extra_data_used}
Other description: Solution based on UNet architecture. The model uses the provided training data.
"""
    
    # Save the readme.txt file
    with open(os.path.join(info_dir, "readme.txt"), "w") as readme_file:
        readme_file.write(readme_content)
        
    print(f"readme.txt file has been saved to {info_dir}")

In [19]:
# Define paths
import time
info_dir = "runtime_info"
os.makedirs(info_dir, exist_ok=True)

# Call the function to collect info and generate readme.txt
collect_info_and_denoise(model, test_loader, device, info_dir)

readme.txt file has been saved to runtime_info


### Renaming files

In [35]:
import os
from PIL import Image
import zipfile

# Directory where the denoised images are saved
denoised_results_dir = "denoised_results"
output_dir = "denoised_results_renamed"  # You can overwrite the existing folder if you prefer
os.makedirs(output_dir, exist_ok=True)

# Function to compress images with Pillow
def compress_image(image_path):
    with Image.open(image_path) as img:
        # Save with lossless compression
        img.save(image_path, format="PNG", optimize=True)

# Iterate through each denoised image in the folder
for file in os.listdir(denoised_results_dir):
    denoised_image_path = os.path.join(denoised_results_dir, file)

    if os.path.isfile(denoised_image_path):
        # Remove the '_denoised' part from the filename to match the original input name
        base_filename = file.replace("_denoised", "")  # Remove '_denoised' from the filename
        new_image_path = os.path.join(output_dir, base_filename)  # New path with original filename

        # Open and save the denoised image with the original filename
        with Image.open(denoised_image_path) as img:
            img.save(new_image_path, format="PNG")  # Save as PNG

        # Compress the saved image to reduce the file size
        compress_image(new_image_path)  # Compress the image after saving

print(f"All denoised images have been renamed, compressed, and saved to {output_dir}.")

All denoised images have been renamed, compressed, and saved to denoised_results_renamed.


In [36]:
import os
import zipfile

# Directory where denoised images are already saved
denoised_results_dir = "denoised_results_renamed"
zip_filename = "denoised_results.zip"

# Create a ZIP file with the denoised images (no folder structure inside the ZIP)
with zipfile.ZipFile(zip_filename, 'w', zipfile.ZIP_DEFLATED) as zipf:
    # Iterate over each file in the denoised_results directory
    for file in os.listdir(denoised_results_dir):
        file_path = os.path.join(denoised_results_dir, file)
        
        if os.path.isfile(file_path):  # Ensure it's a file and not a directory
            zipf.write(file_path, os.path.basename(file_path))  # Add file to the ZIP without folder structure

print(f"Denoised images have been zipped into {zip_filename}")


Denoised images have been zipped into denoised_results.zip


In [37]:
print(f"ZIP file size: {os.path.getsize(zip_filename) / (1024 * 1024)} MB")


ZIP file size: 812.6927871704102 MB


In [34]:
os.remove(zip_filename)

## Saving images denoised from model

In [11]:

# Directory to save the denoised images temporarily
output_dir = "denoised_results"
os.makedirs(output_dir, exist_ok=True)

# Switch the model to evaluation mode
model.eval()

# Disable gradient calculation during inference
with torch.no_grad():
    for noisy_imgs, filenames in test_loader:
        noisy_imgs = noisy_imgs.to(device)

        # Denoise the noisy images
        denoised_imgs = model(noisy_imgs)

        # Normalize outputs to ensure they are in the range [0, 1]
        denoised_imgs = torch.clamp(denoised_imgs, 0, 1)  # Ensure values are within [0, 1]

        # Save the denoised images
        for i in range(denoised_imgs.shape[0]):  # Handle batch processing
            denoised_pil = transforms.ToPILImage()(denoised_imgs[i].cpu())  # Convert tensor to PIL image
            
            # Handle filenames to ensure they are unique and safe for saving
            filename = filenames[i]
            base_filename, ext = os.path.splitext(filename)
            output_filename = f"{base_filename}_denoised{ext}"

            # Save the denoised image to the output directory
            denoised_pil.save(os.path.join(output_dir, output_filename))

# Create a zip file with the denoised images
zip_filename = "denoised_results.zip"
with zipfile.ZipFile(zip_filename, 'w', zipfile.ZIP_DEFLATED) as zipf:
    # Add all files from the output directory to the zip file
    for root, dirs, files in os.walk(output_dir):
        for file in files:
            zipf.write(os.path.join(root, file), os.path.relpath(os.path.join(root, file), output_dir))

# remove the temporary directory after zipping the images
# import shutil
# shutil.rmtree(output_dir)

print(f"Denoised images have been zipped into {zip_filename}")


Denoised images have been zipped into denoised_results.zip


## Comparison betweeen noised and denoised images

In [12]:


# Set directories
noisy_dir = "/kaggle/working/LSDIR_DIV2K_Test_Sigma50"  # Folder with noisy test images
denoised_dir = "/kaggle/working/denoised_results"  # Folder with saved denoised images

# Load image filenames
noisy_images = sorted(os.listdir(noisy_dir))
denoised_images = sorted(os.listdir(denoised_dir))

# Ensure matching images
assert len(noisy_images) == len(denoised_images), "Mismatch in number of images!"

# Load and preprocess images
def load_image(image_path):
    image = Image.open(image_path).convert("RGB")
    return np.array(image)

# Interactive visualization
def show_images(index):
    noisy_img = load_image(os.path.join(noisy_dir, noisy_images[index]))
    denoised_img = load_image(os.path.join(denoised_dir, denoised_images[index]))

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    axes[0].imshow(noisy_img)
    axes[0].set_title("Noisy Image")
    axes[0].axis("off")

    axes[1].imshow(denoised_img)
    axes[1].set_title("Denoised Image")
    axes[1].axis("off")

    plt.show()

# Create interactive slider
interact(show_images, index=IntSlider(0, 0, len(noisy_images) - 1, 1));

interactive(children=(IntSlider(value=0, description='index', max=199), Output()), _dom_classes=('widget-inter…