In [9]:
import os
import time
import torch
from torch.autograd import Variable
from skimage.color import lab2rgb
from model import ColorizationNet
from img_folder import ValImageFolder
import numpy as np
import matplotlib.pyplot as plt
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from tqdm import tqdm

data_dir = "../dataset/colorization_/test"
gray_dir = "../dataset/colorization_/test_gray"
have_cuda = torch.cuda.is_available()

color_model = ColorizationNet()
color_model.load_state_dict(torch.load('./model_best_params.pkl', map_location=torch.device('cpu')))
if have_cuda:
    color_model.cuda()

val_set = ValImageFolder(gray_dir)
val_set_size = len(val_set)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=1, shuffle=False, num_workers=1)

total_psnr = 0.0
total_ssim = 0.0
total_time = 0.0
total_samples = 0

def evaluate():
    global total_psnr, total_ssim, total_time, total_samples
    color_model.eval()

    with tqdm(total=val_set_size, desc="Processing") as pbar:
        for i, (data, _) in enumerate(val_loader):
            original_img = data[0].unsqueeze(1).float()
            w = original_img.size()[2]
            h = original_img.size()[3]
            scale_img = data[1].unsqueeze(1).float()
            if have_cuda:
                original_img, scale_img = original_img.cuda(), scale_img.cuda()

            with torch.no_grad():
                original_img, scale_img = Variable(original_img), Variable(scale_img)

                start_time = time.time()

                _, output = color_model(original_img, scale_img)
                color_img = torch.cat((original_img, output[:, :, 0:w, 0:h]), 1)
                color_img = color_img.data.cpu().numpy().transpose((0, 2, 3, 1))

                end_time = time.time()
                elapsed_time = end_time - start_time
                total_time += elapsed_time

                for img in color_img:
                    # 归一化处理
                    img[:, :, 0:1] = img[:, :, 0:1] * 100
                    img[:, :, 1:3] = img[:, :, 1:3] * 255 - 128
                    img = img.astype(np.float64)
                    img = lab2rgb(img)

                    # 加载对应的原始彩色图像
                    original_color_img_path = os.path.join(data_dir, os.path.relpath(val_set.samples[i][0], gray_dir))
                    original_color_img_path = os.path.normpath(original_color_img_path)
                    original_color_img = plt.imread(original_color_img_path)

                    # 确保图像数据类型一致
                    original_color_img = original_color_img.astype(np.float64)
                    img = img.astype(np.float64)

                    # 归一化处理
                    original_color_img = original_color_img / 255.0
                    img = img / 255.0

                    # 计算PSNR和SSIM
                    psnr = peak_signal_noise_ratio(original_color_img, img, data_range=1.0)
                    ssim = structural_similarity(original_color_img, img, multichannel=True, channel_axis=-1, data_range=1.0)

                    total_psnr += psnr
                    total_ssim += ssim
                    total_samples += 1

                    pbar.update(1)
                    pbar.set_postfix({"PSNR": psnr, "SSIM": ssim, "Time": elapsed_time})

    avg_psnr = total_psnr / total_samples
    avg_ssim = total_ssim / total_samples
    avg_time = total_time / total_samples

    print(f"Average PSNR: {avg_psnr:.4f}")
    print(f"Average SSIM: {avg_ssim:.4f}")
    print(f"Average Time per sample: {avg_time:.4f} seconds")

if __name__ == '__main__':
    evaluate()

Processing: 100%|██████████| 6300/6300 [03:16<00:00, 32.04it/s, PSNR=31.7, SSIM=0.852, Time=0.00503]

Average PSNR: 30.7328
Average SSIM: 0.8653
Average Time per sample: 0.0042 seconds



