# In [None]:
# Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import argparse
import math

import cv2
import numpy as np
import torch
from torch import nn, Tensor
from natsort import natsorted

# ============================== MODEL DEFINITION ==============================
class ESPCN(nn.Module):
    """Sub-pixel convolutional network that upscales an image tensor.

    Two tanh-activated convolutions extract features at the input
    resolution; a final convolution produces
    ``out_channels * upscale_factor ** 2`` planes which ``nn.PixelShuffle``
    rearranges into an output ``upscale_factor`` times larger in height and
    width. The result is clamped to ``[0, 1]``.

    Args:
        in_channels (int): Number of channels in the input tensor.
        out_channels (int): Number of channels in the upscaled output.
        channels (int): Width of the first convolution; the second
            convolution uses ``channels // 2``.
        upscale_factor (int): Spatial magnification factor.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            channels: int,
            upscale_factor: int,
    ) -> None:
        super().__init__()
        narrow_width = channels // 2
        shuffle_planes = int(out_channels * (upscale_factor ** 2))

        # Feature extraction at low resolution (layers are created in the
        # same order as they are registered).
        feature_layers = [
            nn.Conv2d(in_channels, channels, (5, 5), (1, 1), (2, 2)),
            nn.Tanh(),
            nn.Conv2d(channels, narrow_width, (3, 3), (1, 1), (1, 1)),
            nn.Tanh(),
        ]
        self.feature_maps = nn.Sequential(*feature_layers)

        # Sub-pixel stage: expand channels, then shuffle them into space.
        upsample_layers = [
            nn.Conv2d(narrow_width, shuffle_planes, (3, 3), (1, 1), (1, 1)),
            nn.PixelShuffle(upscale_factor),
        ]
        self.sub_pixel = nn.Sequential(*upsample_layers)

        # Weight initialisation. Convolutions fed by exactly 32 channels get
        # a fixed small std; with channels=64 that is the sub-pixel
        # convolution. Every other convolution uses a fan-based std.
        for layer in self.modules():
            if not isinstance(layer, nn.Conv2d):
                continue
            if layer.in_channels == 32:
                std = 0.001
            else:
                kernel_area = layer.weight.data[0][0].numel()
                std = math.sqrt(2 / (layer.out_channels * kernel_area))
            nn.init.normal_(layer.weight.data, 0.0, std)
            nn.init.zeros_(layer.bias.data)

    def forward(self, x: Tensor) -> Tensor:
        """Return the upscaled, ``[0, 1]``-clamped version of ``x``."""
        return self._forward_impl(x)

    # Support torch.script function.
    def _forward_impl(self, x: Tensor) -> Tensor:
        features = self.feature_maps(x)
        upscaled = self.sub_pixel(features)
        # In-place clamp on the freshly produced output tensor.
        return torch.clamp_(upscaled, 0.0, 1.0)

def espcn_x2(**kwargs) -> ESPCN:
    """Build an ESPCN with ``upscale_factor=2``; other keyword arguments pass through."""
    return ESPCN(upscale_factor=2, **kwargs)

def espcn_x3(**kwargs) -> ESPCN:
    """Build an ESPCN with ``upscale_factor=3``; other keyword arguments pass through."""
    return ESPCN(upscale_factor=3, **kwargs)

def espcn_x4(**kwargs) -> ESPCN:
    """Build an ESPCN with ``upscale_factor=4``; other keyword arguments pass through."""
    return ESPCN(upscale_factor=4, **kwargs)

def espcn_x8(**kwargs) -> ESPCN:
    """Build an ESPCN with ``upscale_factor=8``; other keyword arguments pass through."""
    return ESPCN(upscale_factor=8, **kwargs)

# Define available models
# Registry mapping each architecture name to the factory that builds it.
model_dict = dict(
    espcn_x2=espcn_x2,
    espcn_x3=espcn_x3,
    espcn_x4=espcn_x4,
    espcn_x8=espcn_x8,
)

# ============================== IMAGE PROCESSING FUNCTIONS ==============================
def make_directory(dir_path):
    """Create a directory (and any missing parents) if it does not exist.

    Calling this on an existing directory is a no-op.

    Args:
        dir_path (str): Directory path to create.

    Raises:
        FileExistsError: If ``dir_path`` exists but is not a directory.
    """
    # exist_ok avoids the check-then-create race of testing os.path.exists
    # first: another process may create the directory in between.
    os.makedirs(dir_path, exist_ok=True)

def preprocess_one_image(image_path, device):
    """Load an image and return its luma plane as a tensor plus both chroma planes.

    The image is read as BGR, converted with ``cv2.COLOR_BGR2YCrCb`` and
    split into its three planes. The luma plane is scaled to ``[0, 1]`` and
    given leading batch and channel dimensions.

    NOTE: a later definition of ``preprocess_one_image`` in this file
    replaces this one when both are executed.

    Args:
        image_path (str): Image file path.
        device (torch.device): Device to use.

    Returns:
        y_tensor (torch.Tensor): Y channel tensor of shape ``(1, 1, H, W)``.
        cr_image (numpy.ndarray): Second plane of the YCrCb image. With
            ``COLOR_BGR2YCrCb`` this is Cr, although it was previously
            documented as Cb.
        cb_image (numpy.ndarray): Third plane of the YCrCb image (Cb).

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
        ValueError: If the file exists but cannot be decoded as an image.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image file not found: {image_path}")

    # cv2.imread signals failure by returning None rather than raising.
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image is None:
        raise ValueError(f"Failed to load image: {image_path}")

    # Convert BGR to YCrCb and split; the plane order is Y, Cr, Cb.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2YCrCb)
    y_image, cr_image, cb_image = cv2.split(image)

    # Normalize Y channel to [0, 1]
    y_image = y_image.astype(np.float32) / 255.0

    # Add batch and channel dimensions for model input
    y_tensor = torch.from_numpy(y_image).to(device)
    y_tensor = y_tensor.unsqueeze(0).unsqueeze(0)

    # Positional order is unchanged from the original implementation.
    return y_tensor, cr_image, cb_image

def tensor_to_image(tensor, range_norm=False, half=False):
    """Convert a tensor to a NumPy array with the two leading axes squeezed.

    The tensor is always detached and moved to the CPU first, so tensors
    that track gradients or live on another device are accepted regardless
    of the flags. The input tensor is never modified.

    Args:
        tensor (torch.Tensor): Tensor to convert; dimensions 0 and 1 are
            squeezed away when they have size 1.
        range_norm (bool): If True, cast to float32 and clamp to [0, 1].
        half (bool): If True, cast to float32 (for half-precision input).

    Returns:
        image (numpy.ndarray): Converted array. When no cast or clamp is
            applied it may share memory with ``tensor``.
    """
    tensor = tensor.detach().cpu()
    if range_norm or half:
        tensor = tensor.float()
    if range_norm:
        # Out-of-place clamp: detach()/cpu()/float() can all return views of
        # the caller's storage, so clamp_ would alter the caller's data.
        tensor = tensor.clamp(0, 1)

    return tensor.squeeze(0).squeeze(0).numpy()

def ycbcr_to_bgr(image):
    """Convert a YCrCb image to BGR.

    The input is cast to float32 and converted with
    ``cv2.COLOR_YCrCb2BGR``.

    Args:
        image (numpy.ndarray): Image with planes in Y, Cr, Cb order.
            NOTE(review): for float32 input OpenCV presumably expects
            values in [0, 1] -- confirm against callers.

    Returns:
        bgr_image (numpy.ndarray): float32 BGR image.
    """
    as_float = image.astype(np.float32)
    return cv2.cvtColor(as_float, cv2.COLOR_YCrCb2BGR)

# ============================== IMAGE QUALITY ASSESSMENT ==============================
class PSNR(nn.Module):
    """Peak signal-to-noise ratio in dB, for signals with a peak value of 1.

    Args:
        upscale_factor (int): Stored on the instance; not used in the
            calculation.
        only_test_y_channel (bool): Stored on the instance; not used in the
            calculation.
    """

    def __init__(self, upscale_factor=4, only_test_y_channel=True):
        super().__init__()
        self.only_test_y_channel = only_test_y_channel
        self.upscale_factor = upscale_factor

    def forward(self, sr_tensor, hr_tensor):
        """Return ``10 * log10(1 / MSE)`` over all elements of both tensors."""
        # No cropping here; inputs are used exactly as given.
        squared_error = (sr_tensor - hr_tensor) ** 2
        mean_squared_error = torch.mean(squared_error)
        return 10. * torch.log10(1. / mean_squared_error)

class SSIM(nn.Module):
    """Structural similarity computed with an 11x11 Gaussian window.

    The window (sigma 1.5, one channel) is registered as a buffer so it
    moves with the module between devices.

    Args:
        upscale_factor (int): Stored on the instance; not used in the
            calculation.
        only_test_y_channel (bool): Stored on the instance; not used in the
            calculation.
    """

    def __init__(self, upscale_factor=4, only_test_y_channel=True):
        super().__init__()
        self.upscale_factor = upscale_factor
        self.only_test_y_channel = only_test_y_channel
        self.window_size = 11
        self.size_average = True
        self.channel = 1
        self.sigma = 1.5
        gaussian_window = self._create_window(self.window_size, self.channel, self.sigma)
        self.register_buffer("window", gaussian_window)

    def _create_window(self, window_size, channel, sigma):
        """Build a ``(channel, 1, window_size, window_size)`` Gaussian kernel."""
        column = self._gaussian(window_size, sigma).unsqueeze(1)
        # Outer product of the 1-D kernel with itself gives the 2-D kernel.
        kernel_2d = column.mm(column.t()).float()
        kernel_4d = kernel_2d.unsqueeze(0).unsqueeze(0)
        return kernel_4d.expand(channel, 1, window_size, window_size).contiguous()

    @staticmethod
    def _gaussian(window_size, sigma):
        """Return a 1-D Gaussian of length ``window_size`` that sums to 1."""
        center = window_size // 2
        denominator = float(2 * sigma ** 2)
        weights = [math.exp(-((i - center) ** 2) / denominator) for i in range(window_size)]
        kernel = torch.Tensor(weights)
        return kernel / kernel.sum()

    def forward(self, sr_tensor, hr_tensor):
        """Return the SSIM between two single-channel image batches.

        Returns the mean over the whole SSIM map when ``size_average`` is
        True, otherwise one value per batch element.
        """
        # No cropping here; inputs are used exactly as given.
        c1 = (0.01 ** 2)
        c2 = (0.03 ** 2)

        kernel = self.window
        pad = self.window_size // 2
        groups = self.channel

        def smooth(t):
            # Gaussian-weighted local mean of ``t``.
            return nn.functional.conv2d(t, kernel, padding=pad, groups=groups)

        mean_sr = smooth(sr_tensor)
        mean_hr = smooth(hr_tensor)

        mean_sr_sq = mean_sr.pow(2)
        mean_hr_sq = mean_hr.pow(2)
        mean_product = mean_sr * mean_hr

        var_sr = smooth(sr_tensor * sr_tensor) - mean_sr_sq
        var_hr = smooth(hr_tensor * hr_tensor) - mean_hr_sq
        covariance = smooth(sr_tensor * hr_tensor) - mean_product

        numerator = (2 * mean_product + c1) * (2 * covariance + c2)
        denominator = (mean_sr_sq + mean_hr_sq + c1) * (var_sr + var_hr + c2)
        ssim_map = numerator / denominator

        if not self.size_average:
            return ssim_map.mean(1).mean(1).mean(1)
        return ssim_map.mean()




# In [None]:
# ============================== MAIN FUNCTION ==============================

def preprocess_one_image(image_path, device):
    """Load an image as a normalised RGB tensor plus the raw RGB array.

    Args:
        image_path (str): Image file path.
        device (torch.device): Device to use.

    Returns:
        rgb_tensor (torch.Tensor): ``(1, 3, H, W)`` float tensor in [0, 1].
        original_image (numpy.ndarray): The decoded image in RGB order.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
        ValueError: If the file cannot be decoded by OpenCV.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image file not found: {image_path}")

    # OpenCV decodes to BGR and returns None on failure.
    bgr_pixels = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if bgr_pixels is None:
        raise ValueError(f"Failed to load image: {image_path}")

    original_image = cv2.cvtColor(bgr_pixels, cv2.COLOR_BGR2RGB)

    # Debug summary of the decoded image
    print(f"Input image shape: {original_image.shape}, min: {original_image.min()}, max: {original_image.max()}")

    # Scale to [0, 1], then reorder (H, W, C) -> (C, H, W) for PyTorch
    scaled = original_image.astype(np.float32) / 255.0
    channels_first = scaled.transpose(2, 0, 1)

    # Tensor on the target device with a leading batch dimension
    rgb_tensor = torch.from_numpy(channels_first).to(device).unsqueeze(0)

    # Debug summary of the tensor
    print(f"Input tensor shape: {rgb_tensor.shape}, min: {rgb_tensor.min().item()}, max: {rgb_tensor.max().item()}")

    return rgb_tensor, original_image

def _super_resolve_and_save(g_model, lr_image_path, sr_image_path, device):
    """Upscale one image channel-by-channel and write the result to disk.

    Each of the R, G and B planes is passed through the single-channel
    model independently and the outputs are stacked back into one image.

    Args:
        g_model (nn.Module): Model mapping ``(1, 1, H, W)`` to an upscaled
            ``(1, 1, H', W')`` tensor in [0, 1].
        lr_image_path (str): Path of the low-resolution input image.
        sr_image_path (str): Path to write the super-resolved image to.
        device: Device the model lives on.

    Returns:
        tuple: ``(sr_image_rgb, sr_image_bgr)`` uint8 arrays.

    Raises:
        ValueError: If the input image cannot be decoded.
        IOError: If the output image cannot be written.
    """
    lr_image = cv2.imread(lr_image_path, cv2.IMREAD_COLOR)
    if lr_image is None:
        raise ValueError(f"Failed to load image: {lr_image_path}")

    lr_image_rgb = cv2.cvtColor(lr_image, cv2.COLOR_BGR2RGB)

    sr_planes = []
    for c in range(3):  # R, G, B channels
        channel = lr_image_rgb[:, :, c].astype(np.float32) / 255.0
        channel_tensor = torch.from_numpy(channel).to(device).unsqueeze(0).unsqueeze(0)

        with torch.no_grad():
            sr_channel_tensor = g_model(channel_tensor)

        # Drop only the batch and channel axes so a size-1 spatial axis
        # would survive.
        sr_channel = sr_channel_tensor.squeeze(0).squeeze(0).cpu().numpy()
        sr_planes.append(np.clip(sr_channel * 255.0, 0, 255).astype(np.uint8))

    # Stacking takes the output size from the model itself, so it cannot
    # disagree with a separately configured upscale factor.
    sr_image_rgb = np.stack(sr_planes, axis=2)

    sr_image_bgr = cv2.cvtColor(sr_image_rgb, cv2.COLOR_RGB2BGR)
    if not cv2.imwrite(sr_image_path, sr_image_bgr):
        raise IOError(f"Failed to write image: {sr_image_path}")

    return sr_image_rgb, sr_image_bgr


def _evaluate_against_ground_truth(sr_image_bgr, gt_image_path, psnr, ssim, device):
    """Compute Y-channel PSNR/SSIM of an SR image against its ground truth.

    Args:
        sr_image_bgr (numpy.ndarray): Super-resolved BGR uint8 image.
        gt_image_path (str): Path of the ground-truth image.
        psnr (nn.Module): PSNR metric module.
        ssim (nn.Module): SSIM metric module.
        device: Device the metric modules live on.

    Returns:
        tuple | None: ``(psnr_value, ssim_value)``, or ``None`` when no
        ground-truth file exists at ``gt_image_path``.

    Raises:
        ValueError: If the ground-truth file exists but cannot be decoded.
    """
    if not os.path.exists(gt_image_path):
        return None

    gt_image = cv2.imread(gt_image_path, cv2.IMREAD_COLOR)
    if gt_image is None:
        raise ValueError(f"Failed to load image: {gt_image_path}")

    gt_image_ycrcb = cv2.cvtColor(gt_image, cv2.COLOR_BGR2YCrCb)
    sr_image_ycrcb = cv2.cvtColor(sr_image_bgr, cv2.COLOR_BGR2YCrCb)

    # Extract Y channels for evaluation
    gt_y = gt_image_ycrcb[:, :, 0].astype(np.float32) / 255.0
    sr_y = sr_image_ycrcb[:, :, 0].astype(np.float32) / 255.0

    # Resize ground truth to match SR if needed
    if gt_y.shape != sr_y.shape:
        gt_y = cv2.resize(gt_y, (sr_y.shape[1], sr_y.shape[0]))

    gt_y_tensor = torch.from_numpy(gt_y).to(device).unsqueeze(0).unsqueeze(0)
    sr_y_tensor = torch.from_numpy(sr_y).to(device).unsqueeze(0).unsqueeze(0)

    psnr_value = psnr(sr_y_tensor, gt_y_tensor).item()
    ssim_value = ssim(sr_y_tensor, gt_y_tensor).item()

    print(f"PSNR: {psnr_value:4.2f} [dB]")
    print(f"SSIM: {ssim_value:4.4f} [u]")

    return psnr_value, ssim_value


def main() -> None:
    """Run ESPCN super-resolution on a directory of images or a single image.

    Paths and model settings are hard-coded at the top of the function.
    Each image is upscaled one RGB channel at a time with a single-channel
    model, written to the output directory and, when a ground-truth image
    with the same file name exists, scored with Y-channel PSNR and SSIM.
    In directory mode a failure on one image is reported and processing
    continues with the next.

    Raises:
        FileNotFoundError: If the model weights or the input path are missing.
        ValueError: If the architecture name is unknown or the input
            directory is empty.
    """
    # Imported locally: the error handlers below need it and the file does
    # not import it at module level.
    import traceback

    # HARDCODED PATHS - Change these to your specific locations
    model_weights_path = r" "
    lr_dir = r" "
    sr_dir = r" "
    gt_dir = r" "
    model_arch_name = "espcn_x4"  # Choose from: espcn_x2, espcn_x3, espcn_x4, espcn_x8

    if model_arch_name not in model_dict:
        raise ValueError(f"Unknown model architecture: {model_arch_name}")

    # IMPORTANT: Using original parameters for now since the model was trained for Y-channel
    # We'll use a hybrid approach - process each channel separately
    in_channels = 1
    out_channels = 1
    channels = 64
    # Derived from the architecture name so the two can never disagree.
    upscale_factor = int(model_arch_name.rsplit("_x", 1)[-1])
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    only_test_y_channel = True  # Keep this true for evaluation with the original model

    # Print selected parameters
    print(f"Model architecture: {model_arch_name}")
    print(f"Model weights: {model_weights_path}")
    print(f"Input directory: {lr_dir}")
    print(f"Output directory: {sr_dir}")
    print(f"Ground truth directory: {gt_dir}")
    print(f"Device: {device}")
    print(f"Upscale factor: {upscale_factor}")
    print(f"Processing in RGB mode with channel-by-channel approach")

    # Verify model weights exist
    if not os.path.exists(model_weights_path):
        raise FileNotFoundError(f"Model weights not found: {model_weights_path}")

    # Verify input directory/file exists
    if not os.path.exists(lr_dir):
        raise FileNotFoundError(f"Input directory/file not found: {lr_dir}")

    # Initialize the super-resolution model
    g_model = model_dict[model_arch_name](
        in_channels=in_channels,
        out_channels=out_channels,
        channels=channels
    )
    g_model = g_model.to(device=device)
    print(f"Build `{model_arch_name}` model successfully.")

    # Load the super-resolution model weights.
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(model_weights_path, map_location=lambda storage, loc: storage)
    g_model.load_state_dict(checkpoint["state_dict"])
    print(f"Load `{model_arch_name}` model weights "
          f"`{os.path.abspath(model_weights_path)}` successfully.")

    # Create a folder of super-resolution experiment results
    make_directory(sr_dir)

    # Start the verification mode of the model.
    g_model.eval()

    # Initialize the sharpness evaluation functions and move them to the device
    psnr = PSNR(upscale_factor, only_test_y_channel).to(device=device, non_blocking=True)
    ssim = SSIM(upscale_factor, only_test_y_channel).to(device=device, non_blocking=True)

    if os.path.isdir(lr_dir):
        file_names = natsorted(os.listdir(lr_dir))
        if not file_names:
            raise ValueError(f"No files found in input directory: {lr_dir}")

        image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')
        psnr_metrics = 0.0
        ssim_metrics = 0.0
        # Only images that were actually scored count towards the averages.
        evaluated_files = 0

        for file_name in file_names:
            lr_image_path = os.path.join(lr_dir, file_name)
            sr_image_path = os.path.join(sr_dir, file_name)
            gt_image_path = os.path.join(gt_dir, file_name)

            # Skip non-image files (optional)
            if not lr_image_path.lower().endswith(image_extensions):
                print(f"Skipping non-image file: {lr_image_path}")
                continue

            print(f"Processing `{os.path.abspath(lr_image_path)}`...")

            try:
                sr_image_rgb, sr_image_bgr = _super_resolve_and_save(
                    g_model, lr_image_path, sr_image_path, device)

                # Debug output
                print(f"SR image saved with shape: {sr_image_rgb.shape}, min: {sr_image_rgb.min()}, max: {sr_image_rgb.max()}")

                metrics = _evaluate_against_ground_truth(
                    sr_image_bgr, gt_image_path, psnr, ssim, device)
            except Exception as e:
                # Per-image boundary: report the failure and keep going.
                print(f"Error processing image {lr_image_path}: {str(e)}")
                traceback.print_exc()  # Print full traceback
                continue

            if metrics is not None:
                psnr_metrics += metrics[0]
                ssim_metrics += metrics[1]
                evaluated_files += 1

        if evaluated_files > 0:
            # PSNR range value is 0~100
            # SSIM range value is 0~1
            avg_psnr = min(100, psnr_metrics / evaluated_files)
            avg_ssim = min(1, ssim_metrics / evaluated_files)

            print(f"Average PSNR: {avg_psnr:4.2f} [dB]\n"
                  f"Average SSIM: {avg_ssim:4.4f} [u]")
    else:
        # Handle single file processing
        lr_image_path = lr_dir
        filename = os.path.basename(lr_image_path)
        sr_image_path = os.path.join(sr_dir, filename)
        gt_image_path = os.path.join(gt_dir, filename)

        print(f"Processing single image `{os.path.abspath(lr_image_path)}`...")

        try:
            _, sr_image_bgr = _super_resolve_and_save(
                g_model, lr_image_path, sr_image_path, device)

            print(f"Super-resolution image saved to `{os.path.abspath(sr_image_path)}`")

            _evaluate_against_ground_truth(
                sr_image_bgr, gt_image_path, psnr, ssim, device)
        except Exception as e:
            print(f"Error processing image: {str(e)}")
            traceback.print_exc()  # Print full traceback

# In [None]:
# Entry point: run the pipeline and print a one-line message for any error
# that escapes main() instead of letting it propagate.
if __name__ == "__main__":
    try:
        main()
    except Exception as err:
        print(f"Error: {err!s}")

# In [None]:
import os
import cv2
import numpy as np

def restore_ycbcr_to_rgb(input_folder, output_folder):
    """Re-interpret every image in a folder as YUV and save it as a normal image.

    Each ``.png``/``.jpg``/``.jpeg`` file in ``input_folder`` is read, its
    pixel data is converted with OpenCV's YUV-to-BGR conversion, and the
    result is written under the same file name in ``output_folder``.
    Files that cannot be read or written are reported and skipped.

    NOTE(review): the function name says YCbCr but the conversion used is
    OpenCV's YUV one, as in the original code -- confirm which encoding the
    input files actually use.

    Args:
        input_folder (str): Folder containing the images to convert.
        output_folder (str): Folder to write converted images to; created
            (with parents) if missing.
    """
    # exist_ok avoids the check-then-create race of testing existence first.
    os.makedirs(output_folder, exist_ok=True)

    # Sorted so the processing order and log output are deterministic.
    for filename in sorted(os.listdir(input_folder)):
        if not filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            continue

        image_path = os.path.join(input_folder, filename)
        output_path = os.path.join(output_folder, filename)

        # cv2.imread returns None instead of raising on failure.
        img = cv2.imread(image_path)
        if img is None:
            print(f"Skipped unreadable image: {image_path}")
            continue

        # YUV -> BGR directly: same result as YUV -> RGB followed by
        # RGB -> BGR, and BGR is the order cv2.imwrite expects.
        restored_bgr = cv2.cvtColor(img, cv2.COLOR_YUV2BGR)

        if not cv2.imwrite(output_path, restored_bgr):
            print(f"Failed to write: {output_path}")
            continue
        print(f"Restored: {filename} -> {output_path}")

# Example usage
# Example usage: runs the conversion as soon as this cell/module executes.
# Both paths are single-space placeholders and must be replaced with real
# folder paths before running.
input_folder = r" "  # Change this to your folder path
output_folder = r" "  # Folder the converted images are written to
restore_ycbcr_to_rgb(input_folder, output_folder)
