In [1]:
import os
import cv2
import tensorflow as tf
import numpy as np
from image_utils import LoadImagesFromFolder, GetEngancedRGB, GetRGBFrame
import matplotlib.pyplot as plt
from QECNNYUV import EnhancerModel
from improved_qecnn import ImprovedEnhancerModel

2024-12-11 13:48:34.757558: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: NVIDIA A100-SXM4-80GB, compute capability 8.0


In [None]:
# Evaluation configuration used by the cells below.
PATCH_SIZE = 40
# Backward-compatible alias: the constant was originally spelled with a
# lowercase "t"; keep it so existing references keep working.
PAtCH_SIZE = PATCH_SIZE
WIDTH = 480
HEIGHT = 320
FRAMES_MAX = 20  # presumably the per-video frame cap (framesmax) - confirm at the call site
TEST_RAW_YUV = './test_data/testrawyuv/'
TEST_COMP_YUV = './test_data/testcompyuv/'
SAVE_FOLDER = 'test_results'

In [3]:
def cal_psnr(img_orig, img_out):
    """Return the peak signal-to-noise ratio, in dB, between two images.

    The peak signal value is fixed at 1.0, so both images are expected to be
    scaled to the [0, 1] range.

    Parameters:
    - img_orig: reference image (array-like).
    - img_out: image to compare against the reference; must be
      broadcast-compatible with img_orig.

    Returns:
    - PSNR in dB, or +inf when the two images are identical (MSE of zero).
    """
    # Work in float64: integer inputs (e.g. uint8) would otherwise wrap around
    # on subtraction/squaring, and float32 accumulation loses precision.
    reference = np.asarray(img_orig, dtype=np.float64)
    candidate = np.asarray(img_out, dtype=np.float64)
    mse = np.mean(np.square(reference - candidate))
    if mse == 0:
        # Identical images: avoid the division by zero and report infinity.
        return float('inf')
    return 10 * np.log10(1.0 / mse)

def load_models(width, height):
    """Build both enhancer models and load their weights from the working directory.

    Looks for 'enhancer_original.weights.h5' and 'improved_enhancer.weights.h5'
    (relative paths). The original enhancer is checked and loaded first.

    Parameters:
    - width, height: forwarded to the EnhancerModel / ImprovedEnhancerModel
      constructors.

    Returns:
    - (enhancer, improved_enhancer) when both weight files exist,
      otherwise (None, None) after printing which file is missing.
    """
    original_weights = 'enhancer_original.weights.h5'
    improved_weights = 'improved_enhancer.weights.h5'

    if not os.path.exists(original_weights):
        print("Pre-trained enhancer model not found. Please train the model first.")
        return None, None
    print("Loading pre-trained original enhancer model weights...")
    enhancer = EnhancerModel(width, height)
    enhancer.load_weights(original_weights)

    if not os.path.exists(improved_weights):
        print("Pre-trained improved enhancer model not found. Please train the model first.")
        return None, None
    print("Loading pre-trained improved enhancer model weights...")
    improved_enhancer = ImprovedEnhancerModel(width, height)
    improved_enhancer.load_weights(improved_weights)

    return enhancer, improved_enhancer

def create_directories(save_folder):
    """Make sure the output subfolders exist under save_folder.

    Creates 'compressed', 'enhanced' and 'improved' (existing folders are left
    untouched) and returns their paths in that order.
    """
    subfolder_names = ('compressed', 'enhanced', 'improved')
    output_paths = tuple(os.path.join(save_folder, sub) for sub in subfolder_names)
    for output_path in output_paths:
        os.makedirs(output_path, exist_ok=True)
    return output_paths

def process_frame(folder, video_index, frame_index, w, h):
    """Fetch one frame through GetRGBFrame and scale it by 1/255.

    The three planes returned by GetRGBFrame are written, in R, G, B order,
    into the last axis of an (h, w, 3) float32 array.
    """
    red, green, blue = GetRGBFrame(folder, video_index, frame_index, w, h)
    rgb_frame = np.zeros((h, w, 3), dtype=np.float32)
    for channel, plane in enumerate((red, green, blue)):
        rgb_frame[:, :, channel] = plane / 255.0
    return rgb_frame

def process_video(fullname, fw, fh, framesmax):
    """Count the frames stored in a raw video file, capped at framesmax.

    The count is derived from the file size assuming fw * fh * 1.5 bytes per
    frame (consistent with 8-bit YUV 4:2:0 - TODO confirm the pixel format).
    """
    with open(fullname, 'rb') as stream:
        # Seeking to the end returns the new position, i.e. the size in bytes.
        size_in_bytes = stream.seek(0, os.SEEK_END)
    # Integer form of size / (1.5 * fw * fh); a trailing partial frame is dropped.
    frame_count = (2 * size_in_bytes) // (fw * fh * 3)
    return min(frame_count, framesmax)

def calculate_psnr_for_frame(raw_frame, compressed_frame, enhancer_1, enhancer_2):
    """Score the compressed frame and both enhanced versions against raw_frame.

    Each enhancer is applied to compressed_frame through GetEngancedRGB and the
    result is clipped to [0, 1] before its PSNR is computed.

    Returns:
    - (psnr_comp, psnr_enh, psnr_enh_improved, enhanced_frame, improved_frame)
    """
    def _enhance_and_score(model):
        # Run one enhancer, clip its output, and score it against the raw frame.
        output = np.clip(GetEngancedRGB(compressed_frame, model), 0, 1)
        return cal_psnr(raw_frame, output), output

    psnr_comp = cal_psnr(raw_frame, compressed_frame)
    psnr_enh, enhanced_frame = _enhance_and_score(enhancer_1)
    psnr_enh_improved, improved_frame = _enhance_and_score(enhancer_2)
    return psnr_comp, psnr_enh, psnr_enh_improved, enhanced_frame, improved_frame

def save_sample_frames(f, compressed_path, enhanced_path, improved_path, compressed_frame, enhanced_frame, improved_frame):
    """Write the compressed, enhanced and improved versions of frame f as PNGs.

    Only frames with index f < 3 are written; each goes to
    '<folder>/frame_<f>.png' in its respective folder.

    Parameters:
    - f: frame index.
    - compressed_path, enhanced_path, improved_path: destination folders.
    - compressed_frame, enhanced_frame, improved_frame: images scaled to
      [0, 1]; 3-channel images are taken to be in R, G, B order, as assembled
      by process_frame.
    """
    if f >= 3:  # Save only the first 3 frames
        return

    def _to_bgr_uint8(frame):
        # Clip and round before the cast: a plain astype(uint8) truncates
        # (199.9999 -> 199) and wraps values outside [0, 255].
        pixels = np.rint(np.clip(frame, 0.0, 1.0) * 255.0).astype(np.uint8)
        if pixels.ndim == 3 and pixels.shape[2] == 3:
            # cv2.imwrite expects B, G, R channel order; without the swap the
            # saved image has red and blue exchanged.
            pixels = np.ascontiguousarray(pixels[:, :, ::-1])
        return pixels

    filename = f'frame_{f}.png'
    targets = (
        (compressed_path, compressed_frame),
        (enhanced_path, enhanced_frame),
        (improved_path, improved_frame),
    )
    for folder, frame in targets:
        cv2.imwrite(os.path.join(folder, filename), _to_bgr_uint8(frame))

def plot_psnr_performance(PSNRCOMP, PSNRENH, PSNRENH_IMPROVED, name):
    """Plot the three per-frame PSNR series and show the figure.

    All series are reordered by ascending compressed PSNR, so the x axis is
    the rank in that ordering. The title carries `name` and the mean of each
    series.
    """
    order = np.argsort(PSNRCOMP)
    curves = (
        (PSNRCOMP, 'Compressed'),
        (PSNRENH, 'Enhanced'),
        (PSNRENH_IMPROVED, 'Enhanced (Improved)'),
    )
    for values, legend_label in curves:
        plt.plot(np.array(values)[order], label=legend_label)

    plt.xlabel('Frame index')
    plt.ylabel('PSNR, dB')
    plt.grid()
    plt.legend()

    averages = tuple(np.mean(values) for values, _ in curves)
    plt.title("%s PSNR = [%.2f, %.2f, %.2f] dB" % ((name,) + averages))
    plt.show()

def display_image(image, image_label, image_index, num_columns):
    """Draw `image` in slot `image_index` of a one-row, `num_columns`-wide subplot grid.

    The panel is titled with `image_label` and its axes are hidden.
    """
    grid_slot = (1, num_columns, image_index)
    plt.subplot(*grid_slot)
    plt.imshow(image)
    caption = f'{image_label}'
    plt.title(caption)
    plt.axis('off')

def display_image_comparison(RGBRAW, RGBCOMP, RGBENH, RGBENH_IMPROVED, frame_index):
    """
    Show the raw (optional), compressed, enhanced and improved-enhanced images side by side.
    Parameters:
    - RGBRAW: ground-truth raw image, or None to leave that panel out.
    - RGBCOMP: the compressed version of the image.
    - RGBENH: the enhanced version of the image.
    - RGBENH_IMPROVED: the improved enhanced version of the image.
    - frame_index: accepted for the caller's convenience; not read by this function.
    Every image is clipped to [0, 1] before being drawn.
    """
    panels = [
        ("RAW Frame", RGBRAW),
        ("Compressed Frame", RGBCOMP),
        ("Enhanced Frame", RGBENH),
        ("Enhanced (improved) Frame", RGBENH_IMPROVED),
    ]
    if RGBRAW is None:
        # No ground truth supplied: drop the leading raw panel.
        panels = panels[1:]
    panels = [(label, np.clip(picture, 0, 1)) for label, picture in panels]

    num_columns = len(panels)
    plt.figure(figsize=(4 * num_columns, 4))
    for position, (label, picture) in enumerate(panels, start=1):
        display_image(image=picture, image_label=label, image_index=position, num_columns=num_columns)
    plt.tight_layout()
    plt.show()


def show_psnr_performance(enhancer_1, enhancer_2, w, h, folderyuv, foldercomp, video_index, framesmax, fw, fh, save_folder):
    """Evaluate both enhancers on one video and plot the per-frame PSNR.

    The video is the entry at position `video_index` of os.listdir(folderyuv).
    For each of its frames (at most `framesmax`) the raw and compressed frames
    are loaded, PSNR is computed for the compressed, enhanced and improved
    frames, and the first few frames are saved under `save_folder`.

    Parameters:
    - enhancer_1, enhancer_2: models passed to calculate_psnr_for_frame.
    - w, h: frame size passed to process_frame.
    - folderyuv, foldercomp: folders holding the raw and compressed videos.
    - video_index: position of the video in the folder listing.
    - framesmax: upper bound on the number of frames evaluated.
    - fw, fh: frame size used by process_video to count frames.
    - save_folder: root folder for the saved sample frames.

    Returns None. If the index is out of range, the selected entry is not a
    '.yuv' file, or the file holds no frames, a message is printed and no plot
    is drawn.
    """
    PSNRCOMP, PSNRENH, PSNRENH_IMPROVED = [], [], []
    compressed_path, enhanced_path, improved_path = create_directories(save_folder)

    # NOTE(review): os.listdir order is arbitrary, and process_frame resolves
    # video_index on its own via GetRGBFrame - confirm both pick the same file.
    dir_list = os.listdir(folderyuv)
    if not 0 <= video_index < len(dir_list):
        print(f"Video index {video_index} is out of range: '{folderyuv}' contains {len(dir_list)} entries.")
        return

    name = dir_list[video_index]
    fullname = os.path.join(folderyuv, name)
    if not fullname.endswith('.yuv'):
        print(f"Entry '{fullname}' is not a .yuv file; nothing to evaluate.")
        return

    frames = process_video(fullname, fw, fh, framesmax)
    for f in range(frames):
        raw_frame = process_frame(folderyuv, video_index, f, w, h)
        compressed_frame = process_frame(foldercomp, video_index, f, w, h)
        psnr_comp, psnr_enh, psnr_enh_improved, enhanced_frame, improved_frame = \
            calculate_psnr_for_frame(raw_frame, compressed_frame, enhancer_1, enhancer_2)
        PSNRCOMP.append(psnr_comp)
        PSNRENH.append(psnr_enh)
        PSNRENH_IMPROVED.append(psnr_enh_improved)
        save_sample_frames(f, compressed_path, enhanced_path, improved_path,
                           compressed_frame, enhanced_frame, improved_frame)

    if not PSNRCOMP:
        # Plotting empty series would only produce NaN means and an empty figure.
        print(f"No frames could be read from '{fullname}'; nothing to plot.")
        return

    plot_psnr_performance(PSNRCOMP, PSNRENH, PSNRENH_IMPROVED, name)