# In [None]:
import os
import pywt
import numpy as np
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image, ImageDraw, ImageFont
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from scipy import stats
from scipy.ndimage import laplace

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Low-dose simulation (consumed by add_noise): NOISE_LEVEL is passed as
# dose_reduction, so 0.9 keeps only 10% of MAX_PHOTON_COUNT photons.
MAX_PHOTON_COUNT = 5e3
NOISE_LEVEL = 0.9  # Dose reduction level

# Where inputs are read from and comparison figures are written to.
CHEST_SAMPLES_DIR = 'chest_samples'
OUTPUT_DIR = 'processed_images'

# Footer text rendering for the saved figures.
FONT_SIZE = 18
FONT_PATH = '/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf'

# (width, height) every input is resized to before inference; per the
# original note this matches the model's expected input size.
IMG_SIZE = (512, 512)

# Create the output folder up front (no-op if it already exists) so the
# figures saved later have somewhere to go.
os.makedirs(OUTPUT_DIR, exist_ok=True)

def compute_image_quality_metrics(original, processed, noisy=None):
    """
    Score a processed image against its clean reference.

    Parameters:
    original (ndarray): 2-D reference image. PSNR and SSIM are evaluated with
        data_range=1.0, so intensities are expected to lie in [0, 1].
    processed (ndarray): Image under evaluation, same shape as ``original``.
    noisy (ndarray, optional): The degraded input ``processed`` was produced
        from. When given, a noise-reduction ratio is also reported.

    Returns:
    dict: 'psnr', 'ssim', 'epr', then 'noise_reduction_ratio' (only when
        ``noisy`` is given), then 'nps_correlation', inserted in that order.
    """
    scores = {}
    scores['psnr'] = psnr(original, processed, data_range=1.0)
    scores['ssim'] = ssim(original, processed, data_range=1.0)

    # Edge preservation ratio: cosine similarity between the two gradient
    # fields (np.gradient; central differences in the interior). 1.0 means
    # identical edge structure; 0.0 is reported when either image is flat.
    ref_d0, ref_d1 = np.gradient(original)
    out_d0, out_d1 = np.gradient(processed)
    cross = np.sum(ref_d0 * out_d0 + ref_d1 * out_d1)
    norm = np.sqrt(np.sum(ref_d0**2 + ref_d1**2) * np.sum(out_d0**2 + out_d1**2))
    if norm != 0:
        scores['epr'] = cross / norm
    else:
        scores['epr'] = 0.0

    # Noise reduction ratio: MSE of the noisy input divided by MSE of the
    # output, both against the reference (> 1 means noise was removed; a
    # perfect reconstruction is reported as infinity).
    if noisy is not None:
        mse_in = np.mean((original - noisy)**2)
        mse_out = np.mean((original - processed)**2)
        scores['noise_reduction_ratio'] = (
            float('inf') if mse_out == 0 else mse_in / mse_out
        )

    # "NPS" correlation: Pearson correlation between the power spectra
    # |FFT2|^2 of the two whole images (not of a noise-only residual).
    spectrum_ref = np.abs(np.fft.fft2(original))**2
    spectrum_out = np.abs(np.fft.fft2(processed))**2
    scores['nps_correlation'] = stats.pearsonr(spectrum_ref.ravel(), spectrum_out.ravel())[0]

    return scores

def add_noise(clean_img, dose_reduction=0.9, max_photon_count=5e3, rng=None):
    """
    Add Poisson noise to simulate X-ray noise at a reduced dose.

    Each pixel is scaled to an expected photon count of
    ``pixel * max_photon_count * (1 - dose_reduction)``. A Poisson sample is
    drawn per pixel and the counts are rescaled back to image intensity, so
    smaller photon budgets give relatively stronger noise.

    Parameters:
    clean_img (ndarray): Input clean image, expected in [0, 1] (the result is
        clipped to that range). Negative pixels are rejected by the Poisson
        sampler with a ValueError.
    dose_reduction (float): Fraction of the dose removed. Must be < 1. 0 keeps
        the full dose; values closer to 1 leave fewer photons.
    max_photon_count (float): Photon count that intensity 1.0 maps to at full
        dose. Must be > 0.
    rng (numpy.random.Generator, optional): Source of randomness. When None
        (default) the global ``np.random`` state is used, as before. Pass a
        seeded Generator for reproducible noise.

    Returns:
    ndarray: Noisy image (float64), clipped to [0, 1].

    Raises:
    ValueError: If ``dose_reduction >= 1`` or ``max_photon_count <= 0``.
        Either leaves a zero or negative photon budget, which would otherwise
        yield an all-NaN image (0/0) or a sampler error.
    """
    # Written as "not <" so that NaN arguments are rejected too.
    if not dose_reduction < 1.0:
        raise ValueError(
            f"dose_reduction must be < 1, got {dose_reduction!r}; "
            "a reduction of 1 or more leaves no photons to sample"
        )
    if not max_photon_count > 0:
        raise ValueError(f"max_photon_count must be > 0, got {max_photon_count!r}")

    # Scale image to photon counts (same operation order as before, so the
    # default global-RNG path draws exactly the same samples).
    scaling_factor = 1.0 - dose_reduction
    scaled_photon_counts = clean_img * max_photon_count * scaling_factor

    # Add Poisson noise
    poisson = np.random.poisson if rng is None else rng.poisson
    noisy_photon_counts = poisson(scaled_photon_counts)

    # Convert back to image scale
    noisy_img = noisy_photon_counts / (max_photon_count * scaling_factor)
    return np.clip(noisy_img, 0, 1)


@tf.function
def ssim_loss(y_true, y_pred):
    """Return the SSIM dissimilarity ``1 - mean(SSIM(y_true, y_pred))``.

    SSIM is computed with ``max_val=1.0``, i.e. intensities are treated as
    lying in [0, 1]. The mean is taken over every SSIM value returned, so the
    loss is 0 for identical inputs.

    NOTE(review): tf.image.ssim expects trailing (height, width, channels)
    dimensions; confirm the shapes fed in. This function is not called
    within this cell; presumably it is a training loss or a custom object
    needed when loading the model. Verify before removing.
    """
    mean_similarity = tf.reduce_mean(tf.image.ssim(y_true, y_pred, max_val=1.0))
    return 1 - mean_similarity

def print_quality_analysis(metrics):
    """Print a framed, human-readable report of image quality metrics.

    ``metrics`` must contain 'psnr', 'ssim', 'epr' and 'nps_correlation'.
    'noise_reduction_ratio' is reported only when present. Lines are emitted
    one at a time, so a missing key raises KeyError after the preceding
    lines have already been printed.
    """
    rule = "-" * 50

    def report_lines():
        # Generated lazily so each line is printed before the next is formatted.
        yield "\nDetailed Image Quality Analysis:"
        yield rule
        yield f"PSNR: {metrics['psnr']:.2f} dB"
        yield f"SSIM: {metrics['ssim']:.4f}"
        yield f"Edge Preservation Ratio: {metrics['epr']:.4f}"
        if 'noise_reduction_ratio' in metrics:
            yield f"Noise Reduction Ratio: {metrics['noise_reduction_ratio']:.2f}x"
        yield f"NPS Correlation: {metrics['nps_correlation']:.4f}"
        yield rule

    for text in report_lines():
        print(text)

def _load_font():
    """Return the footer font, falling back to Pillow's built-in default.

    ``ImageFont.truetype`` raises OSError when FONT_PATH cannot be read
    (e.g. the font is not installed). Only that case is handled, so
    unrelated errors and KeyboardInterrupt are not masked.
    """
    try:
        return ImageFont.truetype(FONT_PATH, FONT_SIZE)
    except OSError:
        return ImageFont.load_default()


def _load_grayscale(image_path):
    """Load *image_path* as a float32 grayscale array scaled to [0, 1].

    The image is converted to 8-bit luminance ('L') and resized to IMG_SIZE
    with Lanczos resampling. The ``with`` block closes the file once the
    pixels have been read.
    """
    with Image.open(image_path) as img:
        gray = img.convert('L').resize(IMG_SIZE, Image.LANCZOS)
    return np.array(gray, dtype=np.float32) / 255.0


def _to_rgb_pil(img):
    """Convert a float image meant to be in [0, 1] to an 8-bit RGB PIL image.

    Values are clipped to [0, 1] first. The model output is not guaranteed to
    stay in range, and casting an out-of-range float straight to uint8 does
    not saturate (it typically wraps, turning a slight overshoot into a dark
    pixel). Rounding instead of truncating maps intensities that came from
    8-bit data (k / 255 in float32) back to exactly k.
    """
    as_uint8 = np.round(np.clip(img, 0.0, 1.0) * 255.0).astype(np.uint8)
    return Image.fromarray(as_uint8).convert("RGB")


def _build_comparison_figure(noisy_img, predicted_img, original_img, metrics, font):
    """Return a noisy | denoised | original strip with a metrics footer.

    Panels are placed left to right, 20 px apart, on a white canvas
    ``3 * width + 120`` px wide. Below them, one text line per metric is
    drawn in black, framed by 50-dash rules.
    """
    panel_width, panel_height = IMG_SIZE
    panel_gap = 20
    line_height = FONT_SIZE + 2

    rule = "-" * 50
    metrics_text = [
        rule,
        f"PSNR: {metrics['psnr']:.2f} dB",
        f"SSIM: {metrics['ssim']:.4f}",
        f"Edge Preservation Ratio: {metrics['epr']:.4f}",
    ]
    if 'noise_reduction_ratio' in metrics:
        metrics_text.append(f"Noise Reduction Ratio: {metrics['noise_reduction_ratio']:.2f}x")
    metrics_text += [f"NPS Correlation: {metrics['nps_correlation']:.4f}", rule]

    # Size the canvas once for panels + footer (10 px padding above and below the text).
    canvas_width = panel_width * 3 + 120
    canvas_height = panel_height + len(metrics_text) * line_height + 20
    figure = Image.new("RGB", (canvas_width, canvas_height), "white")

    for slot, panel in enumerate((noisy_img, predicted_img, original_img)):
        figure.paste(_to_rgb_pil(panel), (slot * (panel_width + panel_gap), 0))

    draw = ImageDraw.Draw(figure)
    for row, line in enumerate(metrics_text):
        draw.text((10, panel_height + 10 + row * line_height), line, fill="black", font=font)
    return figure


def main():
    """Denoise every sample in CHEST_SAMPLES_DIR and save a comparison figure.

    For each .png/.jpg/.jpeg file (case-insensitive):
    - load it as grayscale and simulate a reduced dose with add_noise;
    - run the model on the noisy image plus its FFT and wavelet features;
    - print the quality metrics;
    - save ``<stem>_combined.png`` to OUTPUT_DIR.

    NOTE(review): ``model``, ``compute_fft_features``,
    ``compute_wavelet_features``, ``WAVELET_TYPE_1``..``WAVELET_TYPE_4`` and
    ``DECOMPOSITION_LEVEL`` are not defined in this cell. They are presumably
    provided by earlier notebook cells; confirm before running as a script.
    """
    print(f"Processing images in {CHEST_SAMPLES_DIR}...")
    # Sorted so the processing (and log) order is reproducible across file systems.
    image_files = sorted(
        f for f in os.listdir(CHEST_SAMPLES_DIR)
        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
    # The font is identical for every figure, so load it once.
    font = _load_font()

    for image_file in tqdm(image_files):
        image_path = os.path.join(CHEST_SAMPLES_DIR, image_file)
        original_img = _load_grayscale(image_path)

        # Simulate a low-dose acquisition of the clean image.
        noisy_img = add_noise(
            original_img,
            dose_reduction=NOISE_LEVEL,
            max_photon_count=MAX_PHOTON_COUNT,
        )

        # Model inputs, in call order: noisy image, FFT features, then one
        # wavelet feature map per wavelet type (computed in that order).
        fft_feat = compute_fft_features(noisy_img)
        wavelet_feats = [
            compute_wavelet_features(noisy_img, wavelet, DECOMPOSITION_LEVEL)
            for wavelet in (WAVELET_TYPE_1, WAVELET_TYPE_2, WAVELET_TYPE_3, WAVELET_TYPE_4)
        ]
        model_inputs = [noisy_img, fft_feat, *wavelet_feats]

        # Each input gets a leading and a trailing singleton axis (presumably
        # batch and channel). The prediction is unwrapped by taking item 0
        # and channel 0.
        predicted_img = model.predict(
            [x[np.newaxis, ..., np.newaxis] for x in model_inputs]
        )[0, ..., 0]

        metrics = compute_image_quality_metrics(
            original=original_img,
            processed=predicted_img,
            noisy=noisy_img,
        )
        print(f"\nAnalysis for {image_file}:")
        print_quality_analysis(metrics)

        figure = _build_comparison_figure(noisy_img, predicted_img, original_img, metrics, font)
        save_path = os.path.join(OUTPUT_DIR, f"{os.path.splitext(image_file)[0]}_combined.png")
        # Saved as PNG: Pillow's 'quality' option does not apply, only DPI metadata.
        figure.save(save_path, dpi=(300, 300))

    print(f"\nProcessing completed. Combined images saved to {OUTPUT_DIR}.")


if __name__ == "__main__":
    main()