# In [1]:
import torch
import os
from PIL import Image
from skimage.color import lab2rgb, rgb2lab
import numpy as np
from tqdm import tqdm
import time
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim
import warnings

from model import Colorization_Model
from my_dataset import resize_img

def lab_to_rgb_tensor(L, ab):
    """Turn a normalized (L, ab) tensor pair back into an RGB numpy image.

    This undoes the scaling used by the evaluation loop (``L / 50 - 1`` and
    ``ab / 110``). It then resamples the two chroma channels onto the L grid
    when their spatial size differs, stacks L/a/b on the last axis and
    converts with ``lab2rgb``.

    Args:
        L: Lightness tensor, expected in [-1, 1]. All singleton dimensions
            are squeezed away, so presumably a single image (e.g. (1, 1, H, W))
            is expected. TODO confirm for batched callers.
        ab: Chroma tensor scaled by 1/110. After squeezing, index 0 is taken
            as the a channel and index 1 as the b channel.

    Returns:
        The numpy array returned by ``lab2rgb`` for the stacked (H, W, 3) Lab
        image (assuming ``L`` squeezes to 2-D).
    """
    lightness = L.squeeze().cpu().numpy()
    chroma = ab.squeeze().cpu().numpy()

    # Rescale back to Lab units: [-1, 1] -> [0, 100] for L, x110 for a/b.
    lightness = 50. * (lightness + 1.)
    chroma = 110. * chroma

    target_hw = lightness.shape
    a_chan, b_chan = chroma[0], chroma[1]
    # The generator's output may live on a different grid than the input L.
    if a_chan.shape != target_hw or b_chan.shape != target_hw:
        a_chan = resize_img(a_chan, HW=target_hw, resample=1)
        b_chan = resize_img(b_chan, HW=target_hw, resample=1)

    lab_img = np.stack([lightness, a_chan, b_chan], axis=2)

    # lab2rgb can emit warnings (presumably for out-of-gamut values);
    # silence them so the progress output stays readable.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        return lab2rgb(lab_img)

def _load_weights(model, model_path, device):
    """Load ``model_path`` into ``model``.

    The file may hold a bare state dict or a checkpoint dict that keeps the
    state dict under ``'model_state_dict'``.
    """
    # NOTE(review): torch.load unpickles arbitrary objects. Only use trusted
    # checkpoints; consider weights_only=True if the checkpoint format allows it.
    checkpoint = torch.load(model_path, map_location=device)
    if 'model_state_dict' in checkpoint:
        checkpoint = checkpoint['model_state_dict']
    model.load_state_dict(checkpoint)


def _list_images(folder, extensions):
    """Return sorted paths of regular files in ``folder`` ending with ``extensions``."""
    # Sorting makes the evaluation order (and the last progress-bar postfix)
    # reproducible, since os.listdir order is arbitrary.
    candidates = (os.path.join(folder, name) for name in sorted(os.listdir(folder))
                  if name.lower().endswith(extensions))
    return [path for path in candidates if os.path.isfile(path)]


def _evaluate_image(model, img_path, device, image_size):
    """Colorize one image from its L channel and score it against the original.

    Returns:
        (psnr, ssim) between the resized original RGB image and the
        generated RGB image, both on a [0, 1] data range.

    Raises:
        ValueError: if the generated and original shapes differ, or the
            image is smaller than SSIM's default 7x7 window.
    """
    # Context manager closes the file even for multi-frame formats (e.g. TIFF),
    # which Pillow does not close automatically after load().
    with Image.open(img_path) as img:
        img_np = np.array(img.convert('RGB'))
    resized_img = resize_img(img_np, HW=image_size, resample=1)

    lab = rgb2lab(resized_img).astype(np.float32)
    # Same normalization that lab_to_rgb_tensor inverts: [0, 100] -> [-1, 1].
    L = lab[:, :, 0] / 50. - 1.
    L_tensor = torch.from_numpy(L).unsqueeze(0).unsqueeze(0).to(device)  # (1, 1, H, W)

    fake_ab = model.net_G(L_tensor)
    rgb_generated = lab_to_rgb_tensor(L_tensor, fake_ab)
    # Assumes resize_img returns 8-bit data (0..255). TODO confirm.
    rgb_original = resized_img.astype(np.float32) / 255.

    if rgb_generated.shape != rgb_original.shape:
        raise ValueError(f"The generated RGB image shape {rgb_generated.shape} does not "
                         f"match the original image shape {rgb_original.shape}.")

    if rgb_generated.shape[0] < 7 or rgb_generated.shape[1] < 7:
        raise ValueError(f"Image size too small: {rgb_generated.shape}")

    current_psnr = compare_psnr(rgb_original, rgb_generated, data_range=1.0)
    current_ssim = compare_ssim(rgb_original, rgb_generated, channel_axis=-1, data_range=1.0)
    return current_psnr, current_ssim


def main(input_folder='../dataset/colorization/test/test_color',
         model_path='./colorization_model.pt',
         image_size=(256, 256)):
    """Evaluate the colorization generator on every image in ``input_folder``.

    For each image, its L channel is colorized by ``model.net_G`` and the
    result is compared with the resized original via PSNR and SSIM. Images
    that fail are reported and counted as skipped. Averages and the mean
    per-image wall time (load through scoring) are printed at the end.

    Args:
        input_folder: Directory scanned (non-recursively) for images.
        model_path: Checkpoint with the model weights.
        image_size: Size every image is resized to, forwarded to
            ``resize_img`` as ``HW``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Colorization_Model().to(device)

    if not os.path.exists(model_path):
        print(f"Model weight {model_path} does not exist.")
        return
    _load_weights(model, model_path, device)
    print(f"Loaded model weights from {model_path}")

    model.eval()

    supported_formats = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')
    image_paths = _list_images(input_folder, supported_formats)

    if not image_paths:
        print("Supported image formats not found in the specified folder.")
        return

    total_psnr = 0.0
    total_ssim = 0.0
    total_time = 0.0
    num_images = len(image_paths)
    skipped_images = 0

    with torch.no_grad(), tqdm(total=num_images, desc='Evaluating...') as pbar:
        for img_path in image_paths:
            start_time = time.perf_counter()
            try:
                current_psnr, current_ssim = _evaluate_image(model, img_path, device, image_size)
            except Exception as e:  # per-image boundary: report, count, keep going
                # pbar.write prints above the bar instead of mangling it.
                pbar.write(f"Error processing image {img_path}: {e}")
                skipped_images += 1
            else:
                total_time += time.perf_counter() - start_time
                total_psnr += current_psnr
                total_ssim += current_ssim
                pbar.set_postfix({'PSNR': f'{current_psnr:.6f}', 'SSIM': f'{current_ssim:.6f}'})
            finally:
                pbar.update(1)

    processed_images = num_images - skipped_images
    if processed_images > 0:
        avg_psnr = total_psnr / processed_images
        avg_ssim = total_ssim / processed_images
        avg_time = total_time / processed_images

        print(f"Processed images: {processed_images}")
        print(f"Skipped images: {skipped_images}")
        print(f"Average PSNR: {avg_psnr:.2f}")
        print(f"Average SSIM: {avg_ssim:.4f}")
        print(f"Average processing time per image: {avg_time:.4f} seconds")
    else:
        print("No images were successfully processed.")


if __name__ == '__main__':
    main()

# --- Cell output ---
# Generator initialized with norm initialization
# Discriminator initialized with norm initialization
# Loaded model weights from ./colorization_model.pt
#
#
# Evaluating...: 100%|██████████| 6300/6300 [02:54<00:00, 36.12it/s, PSNR=24.825292, SSIM=0.950138]
#
# Processed images: 6300
# Skipped images: 0
# Average PSNR: 30.98
# Average SSIM: 0.9772
# Average processing time per image: 0.0265 seconds



