In [None]:
import os
import cv2
import numpy as np
import torch
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import lpips
from DISTS_pytorch import DISTS
from pytorch_fid import fid_score

# Perceptual-metric models are built once at import time and placed on the
# first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
loss_fn_lpips = lpips.LPIPS(net='alex').to(device)  # LPIPS with an AlexNet backbone
dists_model = DISTS().to(device)

def calculate_psnr_ssim(image1_path, image2_path):
    """Compute PSNR and SSIM between two image files.

    Both files are decoded with OpenCV, converted from BGR to RGB and resized
    to 256x256 before comparison, so the originals need not share a size.

    Args:
        image1_path: Path of the first image (in this script, the generated one).
        image2_path: Path of the second image (in this script, the ground truth).

    Returns:
        A ``(psnr, ssim)`` tuple as computed by scikit-image.

    Raises:
        FileNotFoundError: If OpenCV cannot decode either file.
    """
    first, second = cv2.imread(image1_path), cv2.imread(image2_path)
    if first is None or second is None:
        raise FileNotFoundError(f"Could not open {image1_path} or {image2_path}")

    target_size = (256, 256)
    first, second = (
        cv2.resize(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), target_size)
        for img in (first, second)
    )
    # channel_axis=2 marks the last axis as colour channels (the replacement
    # for the deprecated ``multichannel=True`` flag).
    return psnr(first, second), ssim(first, second, channel_axis=2)

def calculate_rmse(image1, image2):
    """Return the root-mean-square error between two images, averaged over channels.

    The RMSE is computed per channel over the first two axes (height, width)
    and the per-channel values are then averaged.

    Both inputs are promoted to float64 before subtracting: images read with
    ``cv2.imread`` are ``uint8``, and ``uint8`` subtraction/squaring wraps
    around modulo 256, which would make the error meaningless.

    Args:
        image1: Image array, ``H x W`` or ``H x W x C``.
        image2: Image array with the same shape as ``image1``.

    Returns:
        The mean of the per-channel RMSE values.

    Raises:
        ValueError: If the two arrays do not have the same shape.
    """
    if image1.shape != image2.shape:
        raise ValueError("Input images must have the same dimensions")

    diff = np.asarray(image1, dtype=np.float64) - np.asarray(image2, dtype=np.float64)
    rmse_per_channel = np.sqrt(np.mean(diff ** 2, axis=(0, 1)))
    return np.mean(rmse_per_channel)

def calculate_sam(image1, image2):
    """Return the mean Spectral Angle Mapper (SAM) between two images, in radians.

    For every pixel the angle between the two channel vectors (last axis) is
    ``arccos(<a, b> / (|a| * |b| + 1e-10))``; the result is the mean angle over
    all pixels. The epsilon keeps the denominator non-zero, so a pixel where
    either vector is all zeros contributes an angle of pi/2. The cosine is
    clipped to [-1, 1] to guard against rounding pushing it out of range.
    """
    a = image1.astype(np.float64)
    b = image2.astype(np.float64)

    dot = np.sum(a * b, axis=-1)
    norm_product = np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1)
    cosine = np.clip(dot / (norm_product + 1e-10), -1, 1)
    return np.mean(np.arccos(cosine))

def calculate_lpips(image1, image2):
    """Return the LPIPS distance between two in-memory images (lower = more similar).

    Args:
        image1: Channels-last ``H x W x C`` array (the ``permute(2, 0, 1)``
            below fixes this layout) with values presumably in ``[0, 255]``.
            Channel order should be RGB, which is what LPIPS's pretrained
            networks take.
        image2: Array of the same shape and range as ``image1``.

    Returns:
        The LPIPS score as a Python float.
    """
    def to_tensor(image):
        # ascontiguousarray lets views with negative strides (e.g. image[..., ::-1])
        # through torch.from_numpy, which rejects them otherwise.
        tensor = torch.from_numpy(np.ascontiguousarray(image))
        tensor = tensor.permute(2, 0, 1).unsqueeze(0).float().to(device)
        # LPIPS expects inputs in [-1, 1] (not [0, 1]): map [0, 255] -> [-1, 1].
        return tensor / 127.5 - 1.0

    with torch.no_grad():
        return loss_fn_lpips(to_tensor(image1), to_tensor(image2)).item()

def calculate_dists(image1, image2):
    """Return the DISTS score between two in-memory images.

    Args:
        image1: Channels-last ``H x W x C`` array (see the permute below) with
            values presumably in ``[0, 255]``; it is scaled to ``[0, 1]``,
            the input range DISTS_pytorch works with.
        image2: Array of the same shape and range as ``image1``.

    Returns:
        The DISTS score as a Python float.
    """
    def as_batch(img):
        # HWC array -> 1 x C x H x W float tensor on the metric device, in [0, 1].
        return torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float().to(device) / 255.0

    reference, distorted = as_batch(image1), as_batch(image2)
    with torch.no_grad():
        score = dists_model(reference, distorted)
    return score.item()

def calculate_lpips_old(image1_path, image2_path):
    """Return the LPIPS distance between two image files at their native size.

    Files are loaded with ``lpips.load_image`` and converted by
    ``lpips.im2tensor``, whose default scaling already yields the ``[-1, 1]``
    range LPIPS expects. No resizing is done, so presumably both files must
    have the same dimensions.

    Evaluated under ``torch.no_grad()`` (as in ``calculate_lpips``) so no
    autograd graph is built for a pure metric computation.
    """
    img1 = lpips.im2tensor(lpips.load_image(image1_path)).to(device)
    img2 = lpips.im2tensor(lpips.load_image(image2_path)).to(device)
    with torch.no_grad():
        return loss_fn_lpips(img1, img2).item()

def calculate_dists_old(image1_path, image2_path):
    """Return the DISTS score between two image files at their native size.

    Files are loaded with ``lpips.load_image``. ``lpips.im2tensor`` maps pixels
    to ``[-1, 1]`` by default (``image / factor - cent`` with ``cent=1``,
    ``factor=127.5``), but DISTS expects ``[0, 1]``; passing ``cent=0`` and
    ``factor=255`` gives ``pixel / 255`` instead, matching ``calculate_dists``.
    No resizing is done, so presumably both files must have the same size.
    """
    img1 = lpips.im2tensor(lpips.load_image(image1_path), cent=0., factor=255.).to(device)
    img2 = lpips.im2tensor(lpips.load_image(image2_path), cent=0., factor=255.).to(device)
    with torch.no_grad():
        return dists_model(img1, img2).item()

def safe_symlink(src, dst):
    """Create a symlink at *dst* pointing to *src*, replacing whatever is at *dst*.

    If the first attempt reports that *dst* already exists, that entry is
    deleted with ``os.remove`` (so *dst* must not be a real directory) and the
    link is created again.
    """
    try:
        os.symlink(src, dst)
        return
    except FileExistsError:
        pass
    os.remove(dst)
    os.symlink(src, dst)

def prepare_folders_for_fid(images_folder):
    """Build the two flat image folders compared by the FID computation.

    Creates ``generated_for_fid`` and ``gt_for_fid`` inside *images_folder*
    and fills them with relative symlinks to every regular file in
    *images_folder* whose name starts with ``generated_`` or ``gt_``
    respectively.

    Args:
        images_folder: Folder holding the ``generated_*`` and ``gt_*`` images.

    Returns:
        ``(generated_folder, gt_folder)`` paths.
    """
    generated_folder = os.path.join(images_folder, 'generated_for_fid')
    gt_folder = os.path.join(images_folder, 'gt_for_fid')

    for folder in (generated_folder, gt_folder):
        os.makedirs(folder, exist_ok=True)
        # Drop links left by an earlier run so images that no longer exist
        # (and any dangling links) are not included in the FID statistics.
        for entry in os.listdir(folder):
            entry_path = os.path.join(folder, entry)
            if os.path.islink(entry_path):
                os.remove(entry_path)

    for filename in os.listdir(images_folder):
        src = os.path.join(images_folder, filename)
        # The output folders themselves are named 'generated_for_fid' and
        # 'gt_for_fid' and thus match the prefixes below; linking only regular
        # files stops each folder from getting a symlink to itself.
        if not os.path.isfile(src):
            continue
        if filename.startswith('generated_'):
            target_folder = generated_folder
        elif filename.startswith('gt_'):
            target_folder = gt_folder
        else:
            continue
        dst = os.path.join(target_folder, filename)
        # Relative link so the folders survive moving images_folder as a whole.
        safe_symlink(os.path.relpath(src, target_folder), dst)

    return generated_folder, gt_folder

# (name, HTML format spec) for each per-image metric, in the order that
# _evaluate_pair returns them and the result-file header lists them.
_METRICS = (
    ('PSNR', '.2f'),
    ('SSIM', '.4f'),
    ('RMSE', '.2f'),
    ('SAM', '.4f'),
    ('LPIPS', '.4f'),
    ('DISTS', '.4f'),
)


def _load_rgb_256(image_path):
    """Read *image_path* with OpenCV, convert BGR -> RGB and resize to 256x256.

    Mirrors the preprocessing in ``calculate_psnr_ssim`` so every metric sees
    the same pixels; RGB order matters for LPIPS and DISTS.

    Raises:
        FileNotFoundError: If OpenCV cannot decode the file.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not open {image_path}")
    return cv2.resize(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), (256, 256))


def _input_stems(images_folder):
    """Return the ``<N>`` part of every ``input_<N>.png`` in *images_folder*.

    Only purely numeric stems are kept, sorted by numeric value. The stem text
    is returned unchanged so zero-padded names (``input_007.png``) still map to
    their ``generated_007.png`` / ``gt_007.png`` partners.
    """
    prefix, suffix = 'input_', '.png'
    stems = [
        name[len(prefix):-len(suffix)]
        for name in os.listdir(images_folder)
        if name.startswith(prefix) and name.endswith(suffix)
    ]
    return sorted((stem for stem in stems if stem.isdecimal()), key=int)


def _evaluate_pair(generated_path, reference_path):
    """Return the metrics of *generated_path* against *reference_path*.

    Values are returned in ``_METRICS`` order: PSNR, SSIM, RMSE, SAM, LPIPS,
    DISTS.
    """
    generated = _load_rgb_256(generated_path)
    reference = _load_rgb_256(reference_path)
    psnr_value, ssim_value = calculate_psnr_ssim(generated_path, reference_path)
    return (
        psnr_value,
        ssim_value,
        calculate_rmse(generated, reference),
        calculate_sam(generated, reference),
        calculate_lpips(generated, reference),
        calculate_dists(generated, reference),
    )


def _html_row(images, values):
    """Return the HTML ``<tr>`` for one evaluated image triplet.

    Args:
        images: ``(file name, path)`` pairs for the RGB, generated and NIR image.
        values: Metric values in ``_METRICS`` order.
    """
    image_cells = "".join(
        f"""
                    <td>
                        <p>{name}</p>
                        <img src="{path}" alt="{name}">
                    </td>"""
        for name, path in images
    )
    metric_cells = "".join(
        f"\n                    <td>{value:{fmt}}</td>"
        for value, (_, fmt) in zip(values, _METRICS)
    )
    return f"\n                <tr>{image_cells}{metric_cells}\n                </tr>\n"


def main():
    """Evaluate every generated/ground-truth image pair and write the reports.

    Looks in ``./image_results`` for triplets ``input_<N>.png`` (RGB image),
    ``generated_<N>.png`` (generated image) and ``gt_<N>.png`` (NIR image).
    For each complete triplet the generated image is scored against the NIR
    image with PSNR, SSIM, RMSE, SAM, LPIPS and DISTS. Per-image scores, their
    means and the FID between all generated and all ground-truth images are
    written to ``./result.txt``; an HTML comparison table goes to
    ``./index.html``.
    """
    images_folder = './image_results'
    result_file = './result.txt'
    html_file = './index.html'

    html_parts = ["""
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>Image Comparison</title>
        <style>
            table { width: 100%; border-collapse: collapse; }
            th, td { border: 1px solid black; padding: 10px; text-align: center; }
            img { width: 100%; max-width: 300px; }
        </style>
    </head>
    <body>
        <h1>Image Comparison</h1>
        <table>
            <tr>
                <th>RGB Image</th>
                <th>Generated Image</th>
                <th>NIR Image</th>
                <th>PSNR</th>
                <th>SSIM</th>
                <th>RMSE</th>
                <th>SAM</th>
                <th>LPIPS</th>
                <th>DISTS</th>
            </tr>
    """]
    rows = []  # one tuple of metric values per evaluated image

    with open(result_file, 'w', encoding='utf-8') as f:
        f.write("File, PSNR, SSIM, RMSE, SAM, LPIPS, DISTS\n")

        for stem in _input_stems(images_folder):
            rgb_image_name = f'input_{stem}.png'
            generated_image_name = f'generated_{stem}.png'
            nir_image_name = f'gt_{stem}.png'

            rgb_image_path = os.path.join(images_folder, rgb_image_name)
            generated_image_path = os.path.join(images_folder, generated_image_name)
            nir_image_path = os.path.join(images_folder, nir_image_name)

            paths = (rgb_image_path, nir_image_path, generated_image_path)
            if not all(os.path.exists(p) for p in paths):
                print(f"File {rgb_image_path} or {nir_image_path} or {generated_image_path} does not exist.")
                continue

            values = _evaluate_pair(generated_image_path, nir_image_path)
            rows.append(values)
            f.write(f"image_{stem}, " + ", ".join(str(v) for v in values) + "\n")
            html_parts.append(_html_row(
                [
                    (rgb_image_name, rgb_image_path),
                    (generated_image_name, generated_image_path),
                    (nir_image_name, nir_image_path),
                ],
                values,
            ))

        if not rows:
            # Nothing to average or to feed to the FID computation.
            print(f"No complete input/generated/gt image triplets found in {images_folder}.")
            return

        means = np.mean(np.asarray(rows, dtype=np.float64), axis=0)

        generated_folder, gt_folder = prepare_folders_for_fid(images_folder)
        # Positional args: batch_size=50, device, dims=2048 (Inception pool features).
        fid_value = fid_score.calculate_fid_given_paths(
            [generated_folder, gt_folder], 50, device, 2048
        )

        f.write("\n")
        for (name, _), mean in zip(_METRICS, means):
            f.write(f"Mean {name}: {mean}\n")
        f.write(f"FID: {fid_value}\n")

    html_parts.append("\n        </table>")
    html_parts.extend(
        f"\n        <h2>Mean {name}: {mean:{fmt}}</h2>"
        for (name, fmt), mean in zip(_METRICS, means)
    )
    html_parts.append(f"\n        <h2>FID: {fid_value:.4f}</h2>\n        ")
    html_parts.append("</body></html>")

    with open(html_file, 'w', encoding='utf-8') as f:
        f.write("".join(html_parts))


if __name__ == '__main__':
    main()