# DENOISER


## Dataset creation

Dataset: https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset

In [None]:
import os
import cv2
import numpy as np
from tqdm import tqdm
import random

# Directory layout: raw photos are read from `raw_data_dir`; the generated
# clean/noisy training pairs are written below `output_dir`.
raw_data_dir = 'rawdata'
output_dir = 'Data'
clean_dir, noisy_dir = (os.path.join(output_dir, sub) for sub in ('clean', 'noisy'))

# Create both output folders up front; existing folders are left untouched.
os.makedirs(clean_dir, exist_ok=True)
os.makedirs(noisy_dir, exist_ok=True)

In [None]:
def add_gaussian_noise(image, mean=0, sigma=25):
    """Return a copy of `image` corrupted with additive Gaussian noise.

    Args:
        image: Image array of any shape, e.g. (H, W, C) colour or (H, W)
            grayscale. Noise is drawn independently for every element.
        mean: Mean of the Gaussian noise, in pixel-value units.
        sigma: Standard deviation of the noise, in pixel-value units.

    Returns:
        The noisy image clipped to [0, 255] and cast to uint8, with the same
        shape as `image`.
    """
    # Using image.shape directly (instead of unpacking exactly three
    # dimensions) also supports single-channel images.
    gauss = np.random.normal(mean, sigma, image.shape)
    noisy = np.clip(image + gauss, 0, 255)
    return noisy.astype(np.uint8)

def add_salt_and_pepper_noise(image, prob=0.05):
    """Return a copy of `image` with impulse (salt-and-pepper) noise.

    Two independent uniform draws of the full array shape decide the noise:
    the first sets elements to 255 ("salt") with probability prob/2, the
    second sets elements to 0 ("pepper") with probability prob/2. Pepper is
    applied last, so it wins where both masks hit. Masks are per element,
    i.e. per channel for colour images. The input array is not modified.
    """
    threshold = prob / 2
    noisy = np.copy(image)
    noisy[np.random.random(image.shape) < threshold] = 255
    noisy[np.random.random(image.shape) < threshold] = 0
    return noisy

def add_poisson_noise(image, scale=1.0):
    """Return `image` with Poisson (shot) noise applied.

    Each element is replaced by a Poisson sample with mean image * scale and
    then divided by `scale`. The resulting variance is image / scale, so a
    larger `scale` gives relatively weaker noise. The result is clipped to
    [0, 255] and returned as uint8.
    """
    sampled = np.random.poisson(image * scale)
    rescaled = sampled / scale
    return np.clip(rescaled, 0, 255).astype(np.uint8)

def add_speckle_noise(image, std=0.1):
    """Return `image` with multiplicative (speckle) Gaussian noise.

    Computes image + image * n, with n ~ N(0, std) drawn per element, then
    clips to [0, 255] and casts to uint8.
    """
    n = np.random.normal(0, std, image.shape)
    speckled = image + image * n
    return np.clip(speckled, 0, 255).astype(np.uint8)

# Candidate corruptions for dataset generation, as (function, keyword-argument
# overrides) pairs. An empty dict means the function's own defaults apply.
noise_functions = [
    (add_gaussian_noise, dict()),
    (add_salt_and_pepper_noise, dict()),
    (add_poisson_noise, dict(scale=1.0)),
    (add_speckle_noise, dict()),
]

In [None]:
#Process images
# Keep only image files and sort them. os.listdir order is arbitrary, so
# sorting makes the file -> output-index mapping reproducible. Filtering first
# keeps output names contiguous and the progress-bar total accurate.
image_files = sorted(
    name for name in os.listdir(raw_data_dir)
    if name.lower().endswith(('.png', '.jpg', '.jpeg'))
)
target_size = (256, 256)  # cv2.resize takes (width, height)

idx = 0  # index of the next clean/noisy pair; advanced only after a save
for img_name in tqdm(image_files):
    #Read image
    img_path = os.path.join(raw_data_dir, img_name)
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread returns None for unreadable/corrupt files; skip them
        # without consuming an output index.
        continue

    #Resize image
    img_resized = cv2.resize(img, target_size)

    #Save clean image
    clean_path = os.path.join(clean_dir, f'{idx:05d}.png')
    cv2.imwrite(clean_path, img_resized)

    #Select a random noise function and apply it
    noise_func, params = random.choice(noise_functions)
    noisy_img = noise_func(img_resized, **params)

    #Save noisy image under the same name so the pair can be matched later
    noisy_path = os.path.join(noisy_dir, f'{idx:05d}.png')
    cv2.imwrite(noisy_path, noisy_img)

    idx += 1

In [None]:
# Sanity check: the loop writes one clean and one noisy file per image, so
# both counts are expected to match.
print("Clean images generated: {}".format(len(os.listdir(clean_dir))))
print("Noisy images generated: {}".format(len(os.listdir(noisy_dir))))

## UNet

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
import gc

In [None]:
# Training hyper-parameters and run configuration
BATCH_SIZE = 6
EPOCHS = 100
LEARNING_RATE = 1e-3
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
IMAGE_SIZE = 256  # side length (px) of the square images used for inference
CHECKPOINT_DIR = 'checkpoints'

# Let cuDNN benchmark and pick convolution algorithms (pays off when input
# sizes stay constant), and release cached-but-unused CUDA memory up front.
torch.backends.cudnn.benchmark = True
torch.cuda.empty_cache()

# Make sure the checkpoint folder exists before training writes to it.
if not os.path.exists(CHECKPOINT_DIR):
    os.makedirs(CHECKPOINT_DIR)

In [None]:
class DenoisingDataset(Dataset):
    """Paired (noisy, clean) image dataset backed by two parallel folders.

    Every image file in `clean_dir` must have a same-named counterpart in
    `noisy_dir`. Items are returned as ``(noisy, clean)``. Each image is
    loaded as RGB and passed through `transform` when one is given.
    """

    # Extensions treated as images (the same set the dataset-creation step reads).
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, clean_dir, noisy_dir, transform=None):
        self.clean_dir = clean_dir
        self.noisy_dir = noisy_dir
        self.transform = transform
        # Sorted for a deterministic index -> file mapping (os.listdir order
        # is arbitrary). Non-image entries such as `.ipynb_checkpoints` or
        # `.DS_Store` are skipped instead of crashing __getitem__.
        self.image_files = sorted(
            name for name in os.listdir(clean_dir)
            if name.lower().endswith(self._IMAGE_EXTENSIONS)
        )

    def __len__(self):
        return len(self.image_files)

    def _load(self, folder, name):
        """Open `folder/name` as a fully loaded RGB image and close the file."""
        # convert('RGB') guarantees 3 channels (matching the inference path)
        # and returns a loaded copy, so the file can be closed right away.
        with Image.open(os.path.join(folder, name)) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        clean_image = self._load(self.clean_dir, img_name)
        noisy_image = self._load(self.noisy_dir, img_name)

        if self.transform:
            clean_image = self.transform(clean_image)
            noisy_image = self.transform(noisy_image)

        return noisy_image, clean_image

In [None]:
class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers wrapped by an identity shortcut.

    Channel count and spatial size are preserved (same in/out channels,
    padding=1), so the input can be added directly to the branch output.
    The submodule attribute names are part of the saved state_dict keys.
    """

    def __init__(self, in_channels):
        super().__init__()
        ch = in_channels
        self.conv1 = nn.Conv2d(ch, ch, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(ch)
        self.conv2 = nn.Conv2d(ch, ch, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(ch)
        # One shared activation, used inside the branch and after the sum.
        self.relu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        h = self.conv1(x)
        h = self.relu(self.bn1(h))
        h = self.bn2(self.conv2(h))
        # Add the shortcut, then apply the final activation.
        return self.relu(h + x)

class SpatialAttention(nn.Module):
    """Per-pixel gating: scales every location of `x` by a sigmoid weight.

    Two 1x1 convolutions squeeze the channels down to a single attention map.
    The map is broadcast across channels and multiplied with the input, so
    the output has the same shape as `x`.
    """

    def __init__(self, in_channels):
        super().__init__()
        # Keep at least one hidden channel. With in_channels < 8, `// 8` would
        # give 0 and the attention map would no longer depend on the input.
        hidden = max(1, in_channels // 8)
        self.conv1 = nn.Conv2d(in_channels, hidden, 1)
        self.conv2 = nn.Conv2d(hidden, 1, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # NOTE(review): there is no non-linearity between conv1 and conv2, so
        # together they act as one linear 1x1 projection. Inserting one would
        # change the outputs of already-trained checkpoints.
        att = self.sigmoid(self.conv2(self.conv1(x)))
        return x * att

#Define the UNet model with better skip connections and residual blocks and spatial attention
class EnhancedUNet(nn.Module):
    """U-Net denoiser with residual blocks and spatial attention at every stage.

    Structure:
    - Four encoder stages, each followed by 2x max-pooling.
    - A bottleneck with two residual blocks.
    - Four decoder stages. Each upsamples bilinearly by 2 and concatenates the
      matching encoder features.
    - A head that maps to `out_channels` through a sigmoid, so outputs lie in
      [0, 1].

    Inputs whose height/width are not multiples of 16 (four poolings) are
    replicate-padded on the bottom/right before the network, and the output
    is cropped back. The output always has the input's spatial size.

    Args:
        in_channels: Channels of the input image (3 for RGB).
        out_channels: Channels of the predicted image.
        base_channels: Width of the first stage; stage k uses
            base_channels * 2**k channels. The default of 64 reproduces the
            original architecture and its state_dict layout.
    """

    # Four 2x poolings => spatial dims must be divisible by 2**4.
    _SIZE_MULTIPLE = 16

    def __init__(self, in_channels=3, out_channels=3, base_channels=64):
        super().__init__()
        c1, c2, c3, c4, c5 = (base_channels * 2 ** k for k in range(5))

        #Encoder
        self.enc1 = self._stage(in_channels, c1)
        self.enc2 = self._stage(c1, c2)
        self.enc3 = self._stage(c2, c3)
        self.enc4 = self._stage(c3, c4)

        #Bottleneck (two residual blocks)
        self.bottleneck = self._stage(c4, c5, num_residual=2)

        #Decoder: input = upsampled deeper features + skip connection
        self.dec4 = self._stage(c5 + c4, c4)
        self.dec3 = self._stage(c4 + c3, c3)
        self.dec2 = self._stage(c3 + c2, c2)
        self.dec1 = self._stage(c2 + c1, c1)

        head = max(1, c1 // 2)
        self.final = nn.Sequential(
            nn.Conv2d(c1, head, 3, padding=1),
            nn.BatchNorm2d(head),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(head, out_channels, 1),
            nn.Sigmoid()
        )

        self.pool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    @staticmethod
    def _stage(in_ch, out_ch, num_residual=1):
        """Build Conv-BN-LeakyReLU, `num_residual` ResidualBlocks, then SpatialAttention.

        The layer order (and therefore the nn.Sequential indices used as
        state_dict keys) matches the original hand-written stages, so existing
        checkpoints still load.
        """
        layers = [
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        layers += [ResidualBlock(out_ch) for _ in range(num_residual)]
        layers.append(SpatialAttention(out_ch))
        return nn.Sequential(*layers)

    def forward(self, x):
        h, w = x.shape[-2:]
        pad_h = (-h) % self._SIZE_MULTIPLE
        pad_w = (-w) % self._SIZE_MULTIPLE
        if pad_h or pad_w:
            # Without padding, an odd size at any pooling level makes the
            # upsampled decoder maps mismatch the skip connections in torch.cat.
            x = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='replicate')

        #Encoder path (features kept for skip connections)
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        #Bottleneck
        b = self.bottleneck(self.pool(e4))

        #Decoder path with skip connections
        d4 = self.dec4(torch.cat([self.upsample(b), e4], dim=1))
        d3 = self.dec3(torch.cat([self.upsample(d4), e3], dim=1))
        d2 = self.dec2(torch.cat([self.upsample(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.upsample(d2), e1], dim=1))

        #Final output, cropped back to the original spatial size
        out = self.final(d1)
        return out[..., :h, :w]

In [None]:
#Data preparation
transform = transforms.Compose([transforms.ToTensor(),])

dataset = DenoisingDataset('Data/clean', 'Data/noisy', transform=transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
# Seeded generator => the same train/val index split on every run, so
# validation metrics and "best" checkpoints stay comparable across restarts.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=split_generator
)

# Pinned host memory only speeds up host->GPU copies; skip it on CPU-only runs.
pin = DEVICE.type == 'cuda'
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=pin, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, pin_memory=pin, num_workers=2)

In [None]:
#Initialize the model, loss function, and optimizer
model = EnhancedUNet().to(DEVICE)
criterion = nn.L1Loss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)

#Learning rate scheduler: halve the LR after 5 epochs without val-loss
# improvement. It only acts when `scheduler.step(val_loss)` is called each epoch.
# (`verbose` is deprecated in recent PyTorch; read the LR from
# optimizer.param_groups or scheduler.get_last_lr() instead.)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

#Mixed-precision loss scaling. It is only meaningful on CUDA, so it is
# disabled explicitly otherwise (avoids the "CUDA not available" warning).
scaler = torch.cuda.amp.GradScaler(enabled=DEVICE.type == 'cuda')

#Evaluation metrics. Tensors come from ToTensor and lie in [0, 1], so fix
# data_range=1.0 instead of letting torchmetrics infer it from the data.
psnr = PeakSignalNoiseRatio(data_range=1.0).to(DEVICE)
ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(DEVICE)

In [None]:
#Define an optimized training loop
def train_epoch(model, train_loader, criterion, optimizer, scaler):
    """Run one training epoch (mixed precision on CUDA).

    Args:
        model: Network to train; switched to train mode here.
        train_loader: Iterable of (noisy, clean) batches.
        criterion: Loss function applied as criterion(output, clean).
        optimizer: Optimizer stepped once per batch through `scaler`.
        scaler: GradScaler used for loss scaling.

    Returns:
        (mean loss, mean PSNR, mean SSIM): per-batch averages over the epoch,
        as Python floats. Uses the module-level `psnr` / `ssim` metrics and
        DEVICE.
    """
    model.train()
    total_loss = 0.0
    total_psnr = 0.0
    total_ssim = 0.0
    use_amp = DEVICE.type == 'cuda'

    for noisy, clean in tqdm(train_loader):
        noisy, clean = noisy.to(DEVICE), clean.to(DEVICE)

        optimizer.zero_grad()

        # Autocast only on CUDA; on CPU this runs in fp32 without warnings.
        with torch.autocast(device_type=DEVICE.type, enabled=use_amp):
            output = model(noisy)
            loss = criterion(output, clean)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

        with torch.no_grad():
            # Metrics in fp32 (autocast may produce half outputs). .item()
            # keeps the running sums as plain floats, so results can be
            # printed and plotted without holding GPU tensors.
            out32 = output.float()
            total_psnr += psnr(out32, clean).item()
            total_ssim += ssim(out32, clean).item()

        # No per-batch torch.cuda.empty_cache(): it only hands cached blocks
        # back to the driver, forcing re-allocation on the next step, and does
        # not reduce this process's peak memory.

    n_batches = max(len(train_loader), 1)  # guard against an empty loader
    return total_loss / n_batches, total_psnr / n_batches, total_ssim / n_batches

@torch.no_grad()
def validate(model, val_loader, criterion):
    """Evaluate `model` on `val_loader` without gradient tracking.

    Args:
        model: Network to evaluate; switched to eval mode here.
        val_loader: Iterable of (noisy, clean) batches.
        criterion: Loss function applied as criterion(output, clean).

    Returns:
        (mean loss, mean PSNR, mean SSIM): per-batch averages, as Python
        floats. Uses the module-level `psnr` / `ssim` metrics and DEVICE.
    """
    model.eval()
    total_loss = 0.0
    total_psnr = 0.0
    total_ssim = 0.0
    use_amp = DEVICE.type == 'cuda'

    for noisy, clean in val_loader:
        noisy, clean = noisy.to(DEVICE), clean.to(DEVICE)

        # Autocast only on CUDA; on CPU this runs in fp32 without warnings.
        with torch.autocast(device_type=DEVICE.type, enabled=use_amp):
            output = model(noisy)
            loss = criterion(output, clean)

        total_loss += loss.item()
        # fp32 metrics; .item() keeps the sums as plain floats.
        out32 = output.float()
        total_psnr += psnr(out32, clean).item()
        total_ssim += ssim(out32, clean).item()

    n_batches = max(len(val_loader), 1)  # guard against an empty loader
    return total_loss / n_batches, total_psnr / n_batches, total_ssim / n_batches

In [None]:
#Training loop
best_val_loss = float('inf')
train_losses = []
val_losses = []
train_psnrs = []
val_psnrs = []
train_ssims = []
val_ssims = []

for epoch in range(EPOCHS):
    train_loss, train_psnr_val, train_ssim_val = train_epoch(model, train_loader, criterion, optimizer, scaler)
    val_loss, val_psnr_val, val_ssim_val = validate(model, val_loader, criterion)

    # ReduceLROnPlateau does nothing unless it is stepped with the monitored
    # metric; without this call the learning rate never decays.
    scheduler.step(val_loss)

    # float() turns any 0-d (possibly GPU) tensors into plain numbers so the
    # histories can be plotted with matplotlib.
    train_losses.append(float(train_loss))
    val_losses.append(float(val_loss))
    train_psnrs.append(float(train_psnr_val))
    val_psnrs.append(float(val_psnr_val))
    train_ssims.append(float(train_ssim_val))
    val_ssims.append(float(val_ssim_val))

    print(f'Epoch {epoch+1}/{EPOCHS}:')
    print(f'Train Loss: {train_loss:.6f}, Train PSNR: {train_psnr_val:.2f}, Train SSIM: {train_ssim_val:.4f}')
    print(f'Val Loss: {val_loss:.6f}, Val PSNR: {val_psnr_val:.2f}, Val SSIM: {val_ssim_val:.4f}')
    print(f'Learning rate: {optimizer.param_groups[0]["lr"]:.2e}')

    is_best = val_loss < best_val_loss
    is_periodic = (epoch + 1) % 10 == 0
    if is_best or is_periodic:
        # 'epoch' is stored 0-based. Scheduler/scaler state is saved too, so a
        # resumed run continues with the same LR schedule and loss scale.
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'val_loss': float(val_loss),
        }
        if is_best:
            best_val_loss = val_loss
            torch.save(checkpoint, os.path.join(CHECKPOINT_DIR, 'best_model.pth'))
        if is_periodic:
            torch.save(checkpoint, os.path.join(CHECKPOINT_DIR, f'checkpoint_epoch_{epoch+1}.pth'))

    # Once per epoch: release cached GPU blocks and collect Python garbage.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

In [None]:
#Metrics visualization
# One panel per metric: (metric name, training history, validation history).
panels = [
    ('Loss', train_losses, val_losses),
    ('PSNR', train_psnrs, val_psnrs),
    ('SSIM', train_ssims, val_ssims),
]

plt.figure(figsize=(15, 5))
for pos, (metric, train_hist, val_hist) in enumerate(panels, start=1):
    plt.subplot(1, 3, pos)
    # float() converts 0-d tensors (which may live on the GPU) into plain
    # numbers; matplotlib cannot plot CUDA tensors directly.
    plt.plot([float(v) for v in train_hist], label=f'Train {metric}')
    plt.plot([float(v) for v in val_hist], label=f'Val {metric}')
    plt.title(f'{metric} Evolution')
    plt.xlabel('Epoch')
    plt.legend()

plt.tight_layout()
plt.show()

In [None]:
#Load the best model
def load_best_model(checkpoint_path=None):
    """Build an EnhancedUNet on DEVICE and load the best saved weights into it.

    Args:
        checkpoint_path: Checkpoint file to load. Defaults to
            ``<CHECKPOINT_DIR>/best_model.pth``.

    Returns:
        The model on DEVICE. It is left in nn.Module's default (train) mode;
        call .eval() before inference.
    """
    if checkpoint_path is None:
        checkpoint_path = os.path.join(CHECKPOINT_DIR, 'best_model.pth')
    best_model = EnhancedUNet().to(DEVICE)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
    best_model.load_state_dict(checkpoint['model_state_dict'])
    # 'epoch' is stored 0-based; +1 matches the "Epoch N/EPOCHS" training log.
    print(f"Loading the best model from epoch:  {checkpoint['epoch'] + 1} with validation loss:  {checkpoint['val_loss']:.6f}")
    return best_model

best_model = load_best_model()
best_model.eval()

if torch.cuda.is_available():
    torch.cuda.empty_cache()

with torch.no_grad():
    #Take some test images from the validation set
    test_noisy, test_clean = next(iter(val_loader))
    test_noisy, test_clean = test_noisy.to(DEVICE), test_clean.to(DEVICE)
    test_output = best_model(test_noisy)

    #Show up to 4 rows; the batch can be smaller (e.g. a tiny validation split)
    n_rows = min(4, test_noisy.size(0))
    # squeeze=False keeps `axes` 2-D even when there is a single row.
    fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)
    columns = (
        ('Noisy', test_noisy),
        ('Denoised (Best model)', test_output),
        ('Clean', test_clean),
    )
    for i in range(n_rows):
        for col, (title, batch) in enumerate(columns):
            img = batch[i].cpu().permute(1, 2, 0).numpy()  # CHW -> HWC for imshow
            axes[i, col].imshow(np.clip(img, 0, 1))
            axes[i, col].set_title(title)
            axes[i, col].axis('off')

        #Per-image PSNR and SSIM, written just below this row's denoised panel
        img_psnr = psnr(test_output[i:i+1], test_clean[i:i+1])
        img_ssim = ssim(test_output[i:i+1], test_clean[i:i+1])
        axes[i, 1].text(0.5, -0.05, f'PSNR: {img_psnr:.2f} dB, SSIM: {img_ssim:.4f}',
                        transform=axes[i, 1].transAxes, ha='center', va='top')

    plt.suptitle('Denoising results', size=16)
    plt.tight_layout()
    plt.show()

#Calculate PSNR and SSIM for the best model on the validation set
def evaluate_model(model, dataloader):
    """Return the sample-weighted mean (PSNR, SSIM) of `model` over `dataloader`.

    Each batch's metric value is weighted by the batch size, so a smaller
    final batch counts proportionally. Uses the module-level `psnr` / `ssim`
    metrics and DEVICE.

    Returns:
        (avg_psnr, avg_ssim) as Python floats.

    Raises:
        ValueError: if `dataloader` yields no samples.
    """
    model.eval()
    total_psnr = 0.0
    total_ssim = 0.0
    n_samples = 0

    with torch.no_grad():
        for noisy, clean in dataloader:
            noisy, clean = noisy.to(DEVICE), clean.to(DEVICE)
            output = model(noisy)

            batch_size = noisy.size(0)
            total_psnr += psnr(output, clean).item() * batch_size
            total_ssim += ssim(output, clean).item() * batch_size
            n_samples += batch_size
            # No per-batch empty_cache(): it only slows the loop and does not
            # lower this process's peak GPU memory.

    if n_samples == 0:
        raise ValueError('evaluate_model: dataloader produced no samples')

    avg_psnr = total_psnr / n_samples
    avg_ssim = total_ssim / n_samples
    return avg_psnr, avg_ssim

# Score the best checkpoint on the whole validation split and report it.
print('\nCalculating PSNR and SSIM for the best model on the validation set...')
val_psnr, val_ssim = evaluate_model(best_model, val_loader)
print('Global metrics:')
print('PSNR: {:.2f} dB'.format(val_psnr))
print('SSIM: {:.4f}'.format(val_ssim))

In [None]:
import os
from PIL import Image
import torchvision.transforms.functional as TF
from datetime import datetime

# Output folder for the denoised test images; created if it does not exist.
results_dir = 'denoised_results'
if not os.path.exists(results_dir):
    os.makedirs(results_dir)

# Reuse the model from the evaluation cell when it is still defined;
# otherwise (e.g. after a kernel restart) reload it from the checkpoint.
if 'best_model' not in globals():
    best_model = load_best_model()
best_model.eval()

#Function to denoise a single image
def denoise_image(image_path, model=None):
    """Denoise one image file and return the result as a PIL image.

    The image is converted to RGB and resized to IMAGE_SIZE x IMAGE_SIZE
    before inference. The returned image therefore has that size, not the
    original file's size.

    Args:
        image_path: Path of the image to denoise.
        model: Network to apply, used as-is (call .eval() beforehand).
            Defaults to the global `best_model`.

    Returns:
        PIL.Image holding the denoised RGB image.
    """
    if model is None:
        model = best_model

    #Load and preprocess the image (the with-block closes the file handle)
    with Image.open(image_path) as src:
        img = src.convert('RGB')
    img = TF.resize(img, (IMAGE_SIZE, IMAGE_SIZE))
    img_tensor = TF.to_tensor(img).unsqueeze(0).to(DEVICE)

    #Apply the model. Mixed precision only on CUDA; on CPU this runs in fp32
    # instead of warning "CUDA is not available" for every image.
    with torch.no_grad(), torch.autocast(device_type=DEVICE.type, enabled=DEVICE.type == 'cuda'):
        denoised = model(img_tensor)

    #Convert back to a PIL image: fp32 on the CPU (autocast may yield float16),
    # clamped to the [0, 1] range that to_pil_image expects for float input.
    denoised = denoised.squeeze(0).float().clamp(0, 1).cpu()
    return TF.to_pil_image(denoised)

#Process test images
test_dir = 'Test_images'
print('Processing test images...')

test_results = []  # (input path, saved denoised path) pairs for the preview

# Sorted so outputs (and the preview below) come out in a stable order.
for img_name in sorted(os.listdir(test_dir)):
    if not img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
        continue
    img_path = os.path.join(test_dir, img_name)
    print(f'Processing {img_name}...')

    #Denoise and save next to the other results with a "denoised_" prefix
    denoised_img = denoise_image(img_path)
    save_path = os.path.join(results_dir, f'denoised_{img_name}')
    denoised_img.save(save_path)

    #Keep for visualization
    test_results.append((img_path, save_path))

print(f'\nResults saved at:  {results_dir}')

#Visualize some results
n_samples = min(4, len(test_results))
if n_samples == 0:
    # plt.subplots(0, 2) would raise; report instead of crashing.
    print(f'No images found in {test_dir}; nothing to display.')
else:
    # squeeze=False always returns a 2-D array of axes, even for one row.
    # Height keeps the previous layout: 4 for one row, 3 per row otherwise.
    fig, axes = plt.subplots(n_samples, 2, figsize=(12, max(4, 3 * n_samples)), squeeze=False)

    for idx, (noisy_path, denoised_path) in enumerate(test_results[:n_samples]):
        #Original input, resized the same way it was before denoising
        with Image.open(noisy_path) as src:
            noisy_img = TF.resize(src.convert('RGB'), (IMAGE_SIZE, IMAGE_SIZE))
        axes[idx, 0].imshow(noisy_img)
        axes[idx, 0].set_title('Original')
        axes[idx, 0].axis('off')

        #Denoised output as saved on disk
        with Image.open(denoised_path) as denoised_img:
            axes[idx, 1].imshow(denoised_img)
        axes[idx, 1].set_title('Denoised')
        axes[idx, 1].axis('off')

    plt.tight_layout()
    plt.show()