# DENOISER


## Dataset creation

In [None]:
import os
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
from tqdm import tqdm
import glob

def get_sidd_image_pairs(base_dir='test/Data'):
    """Collect (ground-truth, noisy) sRGB image path pairs from an SIDD-style tree.

    Every immediate subdirectory of ``base_dir`` is scanned for files matching
    ``*GT_SRGB*.PNG``. The noisy counterpart of each one is the file in the
    same directory whose *file name* has ``GT_SRGB`` replaced by
    ``NOISY_SRGB``; ground-truth files without such a counterpart are skipped.

    Args:
        base_dir: Root directory holding one subdirectory per scene.
            Defaults to ``'test/Data'``.

    Returns:
        A list of ``(gt_path, noisy_path)`` tuples ordered by subdirectory
        name and then file name, so repeated runs give the same order.
    """
    image_pairs = []

    # os.listdir / glob order is filesystem-dependent; sort for reproducibility.
    for subdir in sorted(os.listdir(base_dir)):
        subdir_path = os.path.join(base_dir, subdir)
        if not os.path.isdir(subdir_path):
            continue

        # escape() so directory names containing [, *, ? are matched literally.
        pattern = os.path.join(glob.escape(subdir_path), '*GT_SRGB*.PNG')
        for gt_path in sorted(glob.glob(pattern)):
            # Rewrite only the file name: the directory part must stay intact
            # even if it happens to contain 'GT_SRGB'.
            gt_dir, gt_name = os.path.split(gt_path)
            noisy_path = os.path.join(gt_dir, gt_name.replace('GT_SRGB', 'NOISY_SRGB'))
            if os.path.exists(noisy_path):
                image_pairs.append((gt_path, noisy_path))

    return image_pairs

data2_dir = 'data2'
clean_patches_dir = os.path.join(data2_dir, 'clean')
noisy_patches_dir = os.path.join(data2_dir, 'noisy')

# exist_ok makes the cell safe to re-run; creating the sub-directories also
# creates data2_dir itself.
os.makedirs(clean_patches_dir, exist_ok=True)
os.makedirs(noisy_patches_dir, exist_ok=True)

# Per source image: everything needed to stitch its patches back together
# later (written to processing_info.json below).
processed_images_info = {}

print('Searching for image pairs...')
image_pairs = get_sidd_image_pairs()
print(f'Found {len(image_pairs)} image pairs.')

# NOTE: create_patches() is defined in a later cell of this notebook; that
# cell must be executed before this one.
print('\nProcessing images...')
for clean_path, noisy_path in tqdm(image_pairs):
    base_name = os.path.splitext(os.path.basename(clean_path))[0]

    # create_patches returns five values:
    # (patches, positions, (n_rows, n_cols), (width, height), blending window).
    # Unpacking only four of them raised "too many values to unpack".
    patches, positions, patch_dims, original_size, _window = create_patches(
        clean_path,
        output_dir=clean_patches_dir,
        prefix=f'{base_name}_clean_')

    # Noisy patches reuse the clean image's base name, so each noisy file name
    # equals its clean partner's with 'clean' replaced by 'noisy'.
    create_patches(
        noisy_path,
        output_dir=noisy_patches_dir,
        prefix=f'{base_name}_noisy_')

    processed_images_info[base_name] = {
        'patch_positions': positions,
        'patch_dimensions': patch_dims,
        'original_size': original_size,
        'clean_path': clean_path,
        'noisy_path': noisy_path}

# Save the processing information so the full images can be reconstructed later.
import json
with open(os.path.join(data2_dir, 'processing_info.json'), 'w') as f:
    json.dump(processed_images_info, f)

print('\nProcessing completed. Patches saved at:', data2_dir)

## UNet

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
import gc
import json
import torchvision.transforms as transforms
from scipy.signal import windows

In [2]:
# Training hyper-parameters.
BATCH_SIZE = 6
EPOCHS = 5
LEARNING_RATE = 1e-3
IMAGE_SIZE = 256  # same value as create_patches' default patch_size
CHECKPOINT_DIR = 'checkpoints'

# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Memory/speed settings: let cuDNN autotune convolution algorithms and release
# GPU memory cached by earlier cells.
torch.backends.cudnn.benchmark = True
torch.cuda.empty_cache()

# Directory for per-epoch and best-model checkpoints.
if not os.path.exists(CHECKPOINT_DIR):
    os.makedirs(CHECKPOINT_DIR)

In [None]:
class DenoisingDataset(Dataset):
    """Paired ``(noisy, clean)`` patch dataset read from two parallel directories.

    Every ``*.png`` file in ``clean_dir`` is one sample. Its noisy partner is
    the file in ``noisy_dir`` whose name is the clean name with every
    ``'clean'`` substring replaced by ``'noisy'``. Images are returned as
    RGB tensors produced by ``transforms.ToTensor()``.
    """

    def __init__(self, clean_dir, noisy_dir):
        self.clean_dir = clean_dir
        self.noisy_dir = noisy_dir
        self.to_tensor = transforms.ToTensor()
        # Sorted so that sample indices are stable across runs.
        self.clean_patches = sorted(
            name for name in os.listdir(clean_dir) if name.endswith('.png')
        )

    def __len__(self):
        return len(self.clean_patches)

    def _load_rgb(self, directory, name):
        """Open ``directory/name``, convert it to RGB and return it as a tensor."""
        with Image.open(os.path.join(directory, name)) as img:
            return self.to_tensor(img.convert('RGB'))

    def __getitem__(self, idx):
        """Return the ``(noisy, clean)`` tensor pair for sample ``idx``."""
        clean_name = self.clean_patches[idx]
        noisy_name = clean_name.replace('clean', 'noisy')

        clean = self._load_rgb(self.clean_dir, clean_name)
        noisy = self._load_rgb(self.noisy_dir, noisy_name)
        return noisy, clean

In [None]:
class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers wrapped by an identity shortcut.

    Channel count and spatial size are preserved (``padding=1``), so the
    input is added back before the final LeakyReLU(0.2).
    """

    def __init__(self, in_channels):
        super().__init__()
        # Attribute names and creation order define the state_dict keys and
        # the random-init sequence; keep them stable for saved checkpoints.
        self.conv1 = nn.Conv2d(in_channels, in_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv2 = nn.Conv2d(in_channels, in_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(in_channels)
        self.relu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        shortcut = x
        y = self.conv1(x)
        y = self.relu(self.bn1(y))
        y = self.conv2(y)
        y = self.bn2(y)
        y += shortcut
        return self.relu(y)

class SpatialAttention(nn.Module):
    """Scale every spatial location of ``x`` by a learned gate in (0, 1).

    A 1x1 conv reduces the channels to ``in_channels // 8``, a second 1x1 conv
    reduces them to a single map, and a sigmoid turns that map into per-pixel
    weights that are broadcast over all channels of ``x``. There is no
    non-linearity between the two convolutions.
    """

    def __init__(self, in_channels):
        super().__init__()
        reduced = in_channels // 8
        self.conv1 = nn.Conv2d(in_channels, reduced, 1)
        self.conv2 = nn.Conv2d(reduced, 1, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        gate = self.sigmoid(self.conv2(self.conv1(x)))
        return x * gate

class EnhancedUNet(nn.Module):
    """Four-level U-Net denoiser with residual blocks and spatial attention.

    Every encoder/decoder stage is ``Conv3x3 -> BatchNorm -> LeakyReLU(0.2) ->
    ResidualBlock -> SpatialAttention``; the bottleneck stage uses two
    residual blocks. Encoders downsample with 2x2 max-pooling, decoders
    upsample bilinearly and concatenate the matching encoder output. A small
    head maps back to 3 channels and applies a sigmoid, so outputs lie in
    (0, 1).

    Input height and width must be divisible by 16 (four poolings) for the
    upsampled features and skip connections to line up in ``torch.cat``.
    """

    def __init__(self):
        super().__init__()

        def stage(in_ch, out_ch, n_res=1):
            # Layer order inside the Sequential fixes both the state_dict keys
            # (e.g. 'enc1.3.conv1.weight') and the parameter-init order.
            layers = [
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            layers.extend(ResidualBlock(out_ch) for _ in range(n_res))
            layers.append(SpatialAttention(out_ch))
            return nn.Sequential(*layers)

        # Encoder: channel width doubles at each level.
        self.enc1 = stage(3, 64)
        self.enc2 = stage(64, 128)
        self.enc3 = stage(128, 256)
        self.enc4 = stage(256, 512)

        self.bottleneck = stage(512, 1024, n_res=2)

        # Decoder: input = upsampled deeper features + skip connection.
        self.dec4 = stage(1024 + 512, 512)
        self.dec3 = stage(512 + 256, 256)
        self.dec2 = stage(256 + 128, 128)
        self.dec1 = stage(128 + 64, 64)

        self.final = nn.Sequential(
            nn.Conv2d(64, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32, 3, 1),
            nn.Sigmoid(),
        )

        self.pool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x):
        encoders = (self.enc1, self.enc2, self.enc3, self.enc4)
        decoders = (self.dec4, self.dec3, self.dec2, self.dec1)

        # Contracting path: the first encoder sees the input at full
        # resolution, every later one sees the pooled previous output.
        skips = []
        feat = x
        for level, encoder in enumerate(encoders):
            feat = encoder(feat if level == 0 else self.pool(feat))
            skips.append(feat)

        feat = self.bottleneck(self.pool(feat))

        # Expanding path: pair each decoder with the encoder output of the
        # same resolution (deepest first).
        for decoder, skip in zip(decoders, reversed(skips)):
            feat = decoder(torch.cat([self.upsample(feat), skip], dim=1))

        return self.final(feat)

In [None]:
# Data preparation
# NOTE(review): `to_tensor` is not used in the visible notebook code
# (DenoisingDataset builds its own ToTensor); kept in case other cells use it.
to_tensor = transforms.Compose([transforms.ToTensor(),])

dataset = DenoisingDataset('data2/clean', 'data2/noisy')
if len(dataset) < 2:
    raise RuntimeError(
        f"Found {len(dataset)} patch(es) in 'data2/clean'; at least 2 are needed "
        "for the train/val split. Run the dataset-creation cell first.")

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
# A fixed seed gives the same split on every run, so validation losses (and
# the "best" checkpoint chosen from them) are comparable between runs.
# NOTE(review): patches are split individually, so overlapping patches from
# one source image can land in both subsets; splitting by source image would
# give a less optimistic validation score.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=split_generator)

# Pinned host memory only helps host-to-GPU copies.
pin_memory = DEVICE.type == 'cuda'
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=pin_memory, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, pin_memory=pin_memory, num_workers=2)

In [None]:
model = EnhancedUNet().to(DEVICE)
criterion = nn.L1Loss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)

# Halve the LR after 5 epochs without validation-loss improvement. The
# scheduler only acts when scheduler.step(val_loss) is called each epoch.
# (`verbose` is omitted: it is deprecated in recent PyTorch releases.)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

# Loss scaling for mixed-precision training. torch.cuda.amp.GradScaler takes
# `init_scale` as its first positional argument, so GradScaler('cuda') there
# set the initial scale to the string 'cuda' and failed on the first scale()
# call on a GPU. torch.amp.GradScaler takes the device type first. When no GPU
# is present the scaler is disabled and passes values through unchanged.
scaler = torch.amp.GradScaler('cuda', enabled=torch.cuda.is_available())

# Evaluation metrics. Inputs are ToTensor() images in [0, 1] and the model
# ends in a sigmoid, so fix data_range instead of letting torchmetrics
# estimate it from the data (target max - min).
psnr = PeakSignalNoiseRatio(data_range=1.0).to(DEVICE)
ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(DEVICE)

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, scaler):
    """Train ``model`` for one epoch with CUDA mixed precision.

    Batches are moved to the module-level ``DEVICE``; per-batch PSNR/SSIM are
    computed with the module-level ``psnr`` / ``ssim`` metric objects.

    Args:
        model: Network to train; switched to train mode.
        train_loader: Iterable of ``(noisy, clean)`` tensor batches.
        criterion: Loss called as ``criterion(output, clean)``.
        optimizer: Optimizer over ``model``'s parameters.
        scaler: AMP gradient scaler used for backward / step / update.

    Returns:
        ``(mean_loss, mean_psnr, mean_ssim)`` as Python floats, each averaged
        over the batches of the epoch.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_psnr = 0.0
    total_ssim = 0.0
    n_batches = 0
    # Autocast is CUDA-only here; on CPU the model simply runs in float32.
    use_amp = DEVICE.type == 'cuda'

    for noisy, clean in tqdm(train_loader):
        noisy, clean = noisy.to(DEVICE), clean.to(DEVICE)

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast('cuda', enabled=use_amp):
            output = model(noisy)
            loss = criterion(output, clean)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

        with torch.no_grad():
            # Under autocast the output may be float16; score it in float32 so
            # it matches `clean` and SSIM is not computed in half precision.
            output_fp32 = output.detach().float()
            total_psnr += psnr(output_fp32, clean).item()
            total_ssim += ssim(output_fp32, clean).item()
        n_batches += 1

    # No per-batch torch.cuda.empty_cache(): it does not lower peak memory use
    # (PyTorch's allocator reuses cached blocks) and forces slow re-allocation.
    if n_batches == 0:
        raise ValueError('train_loader yielded no batches')
    return total_loss / n_batches, total_psnr / n_batches, total_ssim / n_batches

@torch.no_grad()
def validate(model, val_loader, criterion):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Batches are moved to the module-level ``DEVICE``; PSNR/SSIM use the
    module-level ``psnr`` / ``ssim`` metric objects.

    Args:
        model: Network to evaluate; switched to eval mode.
        val_loader: Iterable of ``(noisy, clean)`` tensor batches.
        criterion: Loss called as ``criterion(output, clean)``.

    Returns:
        ``(mean_loss, mean_psnr, mean_ssim)`` as Python floats, each averaged
        over the validation batches.

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    total_psnr = 0.0
    total_ssim = 0.0
    n_batches = 0
    use_amp = DEVICE.type == 'cuda'

    for noisy, clean in val_loader:
        noisy, clean = noisy.to(DEVICE), clean.to(DEVICE)

        # torch.cuda.amp.autocast() is deprecated in favour of
        # torch.amp.autocast('cuda').
        with torch.amp.autocast('cuda', enabled=use_amp):
            output = model(noisy)
            loss = criterion(output, clean)

        total_loss += loss.item()
        # Autocast output may be float16; compute the metrics in float32.
        output = output.float()
        total_psnr += psnr(output, clean).item()
        total_ssim += ssim(output, clean).item()
        n_batches += 1

    if n_batches == 0:
        raise ValueError('val_loader yielded no batches; cannot compute validation metrics')
    return total_loss / n_batches, total_psnr / n_batches, total_ssim / n_batches

In [None]:
def create_gaussian_window(patch_size=256):
    """Return a ``(patch_size, patch_size)`` separable Gaussian weighting window.

    The 1-D profile is ``scipy.signal.windows.gaussian`` with standard
    deviation ``patch_size / 6``; the 2-D window is its outer product,
    rescaled so that its largest value is exactly 1.
    """
    profile = windows.gaussian(patch_size, std=patch_size / 6)
    # Outer product via broadcasting: column vector times row vector.
    window_2d = profile[:, np.newaxis] * profile[np.newaxis, :]
    return window_2d / window_2d.max()

def _patch_starts(length, patch_size, stride):
    """Return patch start offsets along one image axis of ``length`` pixels.

    Offsets step by ``stride``. If the last regular patch stops short of the
    edge, an edge-aligned offset ``length - patch_size`` is appended so the
    whole axis is covered. Returns ``[]`` when ``length < patch_size``.
    """
    if length < patch_size:
        return []
    starts = list(range(0, length - patch_size + 1, stride))
    if starts[-1] != length - patch_size:
        starts.append(length - patch_size)
    return starts


def create_patches(image_path, patch_size=256, overlap=26, output_dir=None, prefix=''):
    """Split an image into overlapping square patches that cover it completely.

    Patches start every ``patch_size - overlap`` pixels along each axis. When
    that grid does not reach the right/bottom edge, one extra column/row of
    patches aligned with the edge is added (overlapping its neighbour by more
    than ``overlap``). Without it the uncovered border had zero weight in
    ``reconstruct_image`` and came out black. An axis shorter than
    ``patch_size`` yields no patches.

    Args:
        image_path: Path of the image to split; it is converted to RGB.
        patch_size: Side length of each square patch, in pixels.
        overlap: Minimum overlap between neighbouring patches, in pixels.
        output_dir: If given, every patch is also saved there as
            ``f'{prefix}patch_{i}_{j}.png'`` (grid row ``i``, column ``j``).
        prefix: File-name prefix used when saving.

    Returns:
        ``(patches, positions, (n_rows, n_cols), (width, height), window)``:
        the PIL patches in row-major order, the ``(left, top)`` offset of each
        patch, the patch-grid shape, the original image size, and the
        Gaussian blending window from ``create_gaussian_window``.

    Raises:
        ValueError: If ``overlap >= patch_size`` (the stride would not be positive).
    """
    stride = patch_size - overlap
    if stride <= 0:
        raise ValueError(f'overlap ({overlap}) must be smaller than patch_size ({patch_size})')

    # convert() returns an independent, fully loaded copy, so the file can be
    # closed right away.
    with Image.open(image_path) as src:
        img = src.convert('RGB')
    width, height = img.size

    tops = _patch_starts(height, patch_size, stride)
    lefts = _patch_starts(width, patch_size, stride)

    patches = []
    patch_positions = []
    for i, top in enumerate(tops):
        for j, left in enumerate(lefts):
            patch = img.crop((left, top, left + patch_size, top + patch_size))

            if output_dir:
                patch.save(os.path.join(output_dir, f'{prefix}patch_{i}_{j}.png'))

            patches.append(patch)
            patch_positions.append((left, top))

    window = create_gaussian_window(patch_size)
    return patches, patch_positions, (len(tops), len(lefts)), (width, height), window

def reconstruct_image(patches, patch_positions, original_size, window, patch_size=256):
    """Blend overlapping patches back into one image with a weighting window.

    Each patch is multiplied by ``window`` and accumulated at its
    ``(left, top)`` offset, and the sum is divided by the accumulated window
    weights. Each output pixel is therefore the window-weighted mean of every
    patch covering it. Pixels covered by no patch stay 0 (black).

    Args:
        patches: Patches as PIL RGB images (or ``(patch_size, patch_size, 3)``
            arrays) with values in [0, 255].
        patch_positions: ``(left, top)`` offset of each patch, same order as ``patches``.
        original_size: ``(width, height)`` of the output image.
        window: ``(patch_size, patch_size)`` weight array, e.g. from
            ``create_gaussian_window``.
        patch_size: Side length of every patch.

    Returns:
        The reconstructed RGB ``PIL.Image`` (uint8).
    """
    original_width, original_height = original_size

    accum = np.zeros((original_height, original_width, 3))
    weights = np.zeros((original_height, original_width))
    window = np.asarray(window, dtype=np.float64)
    window_3d = window[:, :, np.newaxis]  # broadcast over the RGB channels

    for patch, (left, top) in zip(patches, patch_positions):
        patch_array = np.asarray(patch, dtype=np.float64)
        rows = slice(top, top + patch_size)
        cols = slice(left, left + patch_size)
        accum[rows, cols] += patch_array * window_3d
        weights[rows, cols] += window

    # Guard the division for pixels that no patch covers (they remain 0).
    weights = np.maximum(weights, 1e-10)
    blended = accum / weights[:, :, np.newaxis]

    # Round to nearest before casting: astype() alone truncates, so values
    # such as 99.99999 (float error of an exact 100) would become 99.
    blended = np.rint(np.clip(blended, 0, 255)).astype(np.uint8)
    return Image.fromarray(blended)

def _denoise_patches(model, patches, device, batch_size):
    """Run ``model`` over PIL ``patches`` in mini-batches; return PIL RGB outputs.

    Outputs are clamped to [0, 1] and *rounded* to 8 bits (ToPILImage
    truncates, which biases every pixel slightly downward).
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')
    to_tensor = transforms.ToTensor()
    denoised = []
    with torch.no_grad():
        for start in range(0, len(patches), batch_size):
            batch = torch.stack([to_tensor(p) for p in patches[start:start + batch_size]]).to(device)
            out = model(batch).float().clamp(0, 1)
            out_u8 = (out * 255).round().to(torch.uint8).permute(0, 2, 3, 1).contiguous().cpu().numpy()
            denoised.extend(Image.fromarray(arr) for arr in out_u8)
    return denoised


def evaluate_full_images(model, device, processing_info_path='data2/processing_info.json', batch_size=8):
    """Denoise every full image listed in ``processing_info_path`` and score it.

    For each entry, the noisy and clean images are split with
    ``create_patches``. The noisy patches are denoised by ``model`` in
    mini-batches. Noisy, clean and denoised images are reassembled with
    ``reconstruct_image`` and saved to ``final_results/``. PSNR/SSIM (the
    module-level ``psnr`` / ``ssim`` objects) compare the reconstructed
    denoised image against the reconstructed clean image. A summary is written
    to ``final_results/metrics.json``.

    NOTE(review): the reference is the *reconstructed* clean image, not the
    original ground-truth file; the two differ only where patches do not
    cover the image or by rounding.

    Args:
        model: Trained denoiser; switched to eval mode.
        device: Device the model (and the metric objects) live on.
        processing_info_path: JSON written by the dataset-creation step.
        batch_size: Number of patches per forward pass. The model is in eval
            mode (BatchNorm uses running statistics), so results do not depend on it.

    Returns:
        A list of dicts with ``image_name``, ``psnr``, ``ssim`` (floats) and
        the saved ``paths`` of the noisy / clean / denoised images.
    """
    model.eval()
    os.makedirs('final_results', exist_ok=True)

    with open(processing_info_path, 'r') as f:
        processing_info = json.load(f)

    results = []
    transform = transforms.ToTensor()

    for img_name, info in processing_info.items():
        print(f'Processing image:  {img_name}...')
        original_size = tuple(info['original_size'])

        base_path = info['noisy_path']
        clean_path = info['clean_path']

        patches_noisy, positions, _, _, window = create_patches(base_path, output_dir=None)
        patches_clean, _, _, _, _ = create_patches(clean_path, output_dir=None)

        denoised_patches = _denoise_patches(model, patches_noisy, device, batch_size)

        # Reconstruct the full images.
        noisy_full = reconstruct_image(patches_noisy, positions, original_size, window)
        clean_full = reconstruct_image(patches_clean, positions, original_size, window)
        denoised_full = reconstruct_image(denoised_patches, positions, original_size, window)

        base_name = os.path.splitext(os.path.basename(base_path))[0]
        noisy_save_path = os.path.join('final_results', f'{base_name}_noisy.png')
        clean_save_path = os.path.join('final_results', f'{base_name}_clean.png')
        denoised_save_path = os.path.join('final_results', f'{base_name}_denoised.png')

        noisy_full.save(noisy_save_path)
        clean_full.save(clean_save_path)
        denoised_full.save(denoised_save_path)

        # Calculate metrics on the full-resolution images.
        clean_tensor = transform(clean_full).unsqueeze(0).to(device)
        denoised_tensor = transform(denoised_full).unsqueeze(0).to(device)

        psnr_val = psnr(denoised_tensor, clean_tensor).item()
        ssim_val = ssim(denoised_tensor, clean_tensor).item()

        results.append({
            'image_name': base_name,
            'psnr': psnr_val,
            'ssim': ssim_val,
            'paths': {
                'noisy': noisy_save_path,
                'clean': clean_save_path,
                'denoised': denoised_save_path
            }
        })

        print(f'PSNR: {psnr_val:.2f}, SSIM: {ssim_val:.4f}')

    with open('final_results/metrics.json', 'w') as f:
        json.dump(results, f, indent=4)

    return results

In [None]:
#-------------------- Training loop --------------------
# Per-epoch history, stored as plain floats.
best_val_loss = float('inf')
train_losses = []
val_losses = []
train_psnrs = []
val_psnrs = []
train_ssims = []
val_ssims = []

for epoch in range(EPOCHS):
    train_loss, train_psnr_val, train_ssim_val = train_epoch(model, train_loader, criterion, optimizer, scaler)
    val_loss, val_psnr_val, val_ssim_val = validate(model, val_loader, criterion)

    # The metrics may come back as 0-dim (possibly GPU) tensors; keep floats
    # so the history lists do not hold tensors alive.
    train_psnr_val, train_ssim_val = float(train_psnr_val), float(train_ssim_val)
    val_psnr_val, val_ssim_val = float(val_psnr_val), float(val_ssim_val)

    # ReduceLROnPlateau only acts when stepped with the monitored value.
    scheduler.step(val_loss)

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_psnrs.append(train_psnr_val)
    val_psnrs.append(val_psnr_val)
    train_ssims.append(train_ssim_val)
    val_ssims.append(val_ssim_val)

    print(f'Epoch {epoch+1}/{EPOCHS}:')
    print(f'Train Loss: {train_loss:.6f}, Train PSNR: {train_psnr_val:.2f}, Train SSIM: {train_ssim_val:.4f}')
    print(f'Val Loss: {val_loss:.6f}, Val PSNR: {val_psnr_val:.2f}, Val SSIM: {val_ssim_val:.4f}')
    print(f'Learning rate: {optimizer.param_groups[0]["lr"]:.2e}')

    # One checkpoint dict (same keys as before) for both the best-model file
    # and the per-epoch file.
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_loss,
    }
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(checkpoint, os.path.join(CHECKPOINT_DIR, 'best_model.pth'))
    torch.save(checkpoint, os.path.join(CHECKPOINT_DIR, f'checkpoint_epoch_{epoch+1}.pth'))

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()


print('\nTraining completed, evaluating on full images...')
final_results = evaluate_full_images(model, DEVICE)

def show_final_results(results, num_samples=4):
    """Plot noisy / denoised / clean images side by side for the first results.

    Args:
        results: List of dicts as returned by ``evaluate_full_images``. Each
            needs ``'psnr'``, ``'ssim'`` and ``'paths'`` with ``'noisy'``,
            ``'clean'`` and ``'denoised'`` entries.
        num_samples: Maximum number of rows (images) to show.
    """
    num_samples = min(num_samples, len(results))
    if num_samples <= 0:
        # plt.subplots() rejects zero rows.
        print('No results to show.')
        return

    # squeeze=False keeps `axes` 2-D even for a single row; otherwise
    # plt.subplots(1, 3) returns a 1-D array and axes[i, k] raises.
    _, axes = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples), squeeze=False)

    for row, result in zip(axes, results[:num_samples]):
        paths = result['paths']
        panels = (
            (paths['noisy'], 'Noisy'),
            (paths['denoised'], f'Denoised\nPSNR: {result["psnr"]:.2f}, SSIM: {result["ssim"]:.4f}'),
            (paths['clean'], 'Clean'),
        )
        for ax, (path, title) in zip(row, panels):
            ax.imshow(Image.open(path))
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    plt.show()

show_final_results(final_results)

In [None]:
# Load the best checkpoint, evaluate it on the full images and compare it with
# the final-epoch model. Uses names from earlier cells (CHECKPOINT_DIR, DEVICE,
# EnhancedUNet, evaluate_full_images, show_final_results, create_patches,
# reconstruct_image, final_results).
import os
import shutil
import torch
import matplotlib.pyplot as plt
from PIL import Image
import json
from torchvision import transforms

BEST_MODEL_PATH = os.path.join(CHECKPOINT_DIR, 'best_model.pth')
BEST_RESULTS_DIR = 'best_model_results'
# evaluate_full_images() always writes into final_results/, so evaluating the
# best model overwrites the final model's denoised images there. They are
# copied here first so the comparison really shows two different models.
FINAL_MODEL_RESULTS_DIR = 'final_model_results'
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def _snapshot_denoised(results, target_dir, key):
    """Copy each result's denoised image into ``target_dir``.

    The copy's path is stored as ``result['paths'][key]``. Results that
    already have ``key`` are skipped, so re-running this cell does not
    replace the final model's snapshot with the best model's images.
    """
    os.makedirs(target_dir, exist_ok=True)
    for result in results:
        if key in result['paths']:
            continue
        src = result['paths']['denoised']
        dst = os.path.join(target_dir, os.path.basename(src))
        shutil.copyfile(src, dst)
        result['paths'][key] = dst


def compare_models_results(final_results, best_results, num_samples=2):
    """Plot noisy / best-model / final-model / clean images side by side.

    The figure is also saved to ``BEST_RESULTS_DIR/model_comparison.png``.
    Each model's denoised image is read from its snapshot path
    (``'best_denoised'`` / ``'final_denoised'``) when present, falling back
    to ``'denoised'``.
    """
    num_samples = min(num_samples, len(final_results), len(best_results))
    if num_samples <= 0:
        print('No results to compare.')
        return

    # squeeze=False keeps `axes` 2-D even when only one row is plotted.
    _, axes = plt.subplots(num_samples, 4, figsize=(20, 5 * num_samples), squeeze=False)

    for row, final_result, best_result in zip(axes, final_results, best_results):
        final_paths = final_result['paths']
        best_paths = best_result['paths']
        panels = (
            (final_paths['noisy'], 'Noisy'),
            (best_paths.get('best_denoised', best_paths['denoised']),
             f'Best Model\nPSNR: {best_result["psnr"]:.2f}, SSIM: {best_result["ssim"]:.4f}'),
            (final_paths.get('final_denoised', final_paths['denoised']),
             f'Final Model\nPSNR: {final_result["psnr"]:.2f}, SSIM: {final_result["ssim"]:.4f}'),
            (final_paths['clean'], 'Clean'),
        )
        for ax, (path, title) in zip(row, panels):
            ax.imshow(Image.open(path))
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(BEST_RESULTS_DIR, 'model_comparison.png'))
    plt.show()


def test_external_images(model, device, test_dir='Test_images', output_dir='external_test_results'):
    """Denoise every PNG/JPEG in ``test_dir`` patch-wise and save the results.

    Each image is split with ``create_patches``, every patch is passed
    through ``model``, and the outputs are blended back with
    ``reconstruct_image``. Images that yield no patches (smaller than one
    patch) are skipped.

    Returns:
        A list of dicts with ``image_name``, ``original_path`` and ``denoised_path``.
    """
    os.makedirs(output_dir, exist_ok=True)
    to_tensor = transforms.ToTensor()
    to_pil = transforms.ToPILImage()
    model.eval()

    test_files = sorted(f for f in os.listdir(test_dir) if f.lower().endswith(IMAGE_EXTENSIONS))
    results = []

    for img_file in test_files:
        img_path = os.path.join(test_dir, img_file)

        patches, positions, _, original_size, window = create_patches(img_path, output_dir=None)
        if not patches:
            print(f'Skipping {img_file}: smaller than one patch.')
            continue

        with torch.no_grad():
            denoised_patches = [
                to_pil(model(to_tensor(patch).unsqueeze(0).to(device)).squeeze(0).cpu())
                for patch in patches
            ]

        denoised_full = reconstruct_image(denoised_patches, positions, original_size, window)

        base_name = os.path.splitext(img_file)[0]
        denoised_save_path = os.path.join(output_dir, f'{base_name}_denoised.png')
        denoised_full.save(denoised_save_path)

        results.append({
            'image_name': base_name,
            'original_path': img_path,
            'denoised_path': denoised_save_path})

    return results


def show_external_results(results, num_samples=4, output_dir='external_test_results'):
    """Plot original vs. denoised external images and save the figure in ``output_dir``."""
    num_samples = min(num_samples, len(results))
    if num_samples <= 0:
        return

    # squeeze=False keeps `axes` 2-D for a single row; wrapping a 1-D axes
    # array in a list does not make axes[i, k] indexing work.
    _, axes = plt.subplots(num_samples, 2, figsize=(10, 5 * num_samples), squeeze=False)

    for row, result in zip(axes, results):
        panels = ((result['original_path'], 'Original'), (result['denoised_path'], 'Denoised'))
        for ax, (path, title) in zip(row, panels):
            ax.imshow(Image.open(path))
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'external_results.png'))
    plt.show()


if os.path.exists(BEST_MODEL_PATH):
    print(f'Loading best model from {BEST_MODEL_PATH}')

    best_model = EnhancedUNet().to(DEVICE)

    # Load the best model checkpoint.
    checkpoint = torch.load(BEST_MODEL_PATH, map_location=DEVICE)
    best_model.load_state_dict(checkpoint['model_state_dict'])
    best_epoch = checkpoint['epoch']
    best_val_loss = checkpoint['val_loss']

    print(f'Best model from epoch {best_epoch+1} with validation loss {best_val_loss:.6f}')

    best_model.eval()

    os.makedirs(BEST_RESULTS_DIR, exist_ok=True)

    # Must run before evaluate_full_images() overwrites final_results/.
    _snapshot_denoised(final_results, FINAL_MODEL_RESULTS_DIR, 'final_denoised')

    # Evaluate the best model on full images.
    print('\nEvaluating the best model with full images...')
    best_results = evaluate_full_images(best_model, DEVICE)
    _snapshot_denoised(best_results, BEST_RESULTS_DIR, 'best_denoised')

    with open(os.path.join(BEST_RESULTS_DIR, 'best_model_metrics.json'), 'w') as f:
        json.dump(best_results, f, indent=4)

    print('\nShowing best model results:')
    show_final_results(best_results)

    print('\nComparing best model with final model:')
    compare_models_results(final_results, best_results)

    # Average metrics (skipped when there is nothing to average).
    if final_results and best_results:
        final_avg_psnr = sum(r['psnr'] for r in final_results) / len(final_results)
        final_avg_ssim = sum(r['ssim'] for r in final_results) / len(final_results)
        best_avg_psnr = sum(r['psnr'] for r in best_results) / len(best_results)
        best_avg_ssim = sum(r['ssim'] for r in best_results) / len(best_results)

        print('\nAverage Metrics:')
        print(f'Final Model - PSNR: {final_avg_psnr:.2f}, SSIM: {final_avg_ssim:.4f}')
        print(f'Best Model - PSNR: {best_avg_psnr:.2f}, SSIM: {best_avg_ssim:.4f}')

    if os.path.isdir('Test_images') and any(f.lower().endswith(IMAGE_EXTENSIONS) for f in os.listdir('Test_images')):
        external_results = test_external_images(best_model, DEVICE)
        if external_results:
            show_external_results(external_results)

else:
    print(f'Error: Best model not found at {BEST_MODEL_PATH}')