In [1]:
import torch
import os
import argparse
from datetime import datetime
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from skimage.metrics import peak_signal_noise_ratio, structural_similarity as compare_ssim
from AutoDnCNN_Image_Denoising.TrainingCode.AutoDnCNN_remake.train.main_train import ConvAutoencoder, DnCNN

def load_model(model_path, model_class):
    """
    Instantiate a network, restore its weights from disk and switch it to
    evaluation mode.

    Args:
        model_path: Path to a file whose ``torch.load`` result is a state dict
            accepted by ``load_state_dict`` of the constructed model.
        model_class: Zero-argument callable that builds the network
            architecture (e.g. an ``nn.Module`` subclass).

    Returns:
        The constructed model with weights loaded, after ``eval()`` was called.
    """
    net = model_class()
    # Remap all stored tensors onto the CPU so checkpoints saved on a GPU can
    # still be loaded on a CPU-only machine.
    cpu = torch.device('cpu')
    weights = torch.load(model_path, map_location=cpu)
    net.load_state_dict(weights)
    net.eval()
    return net

def denoise_image(model, image, noise_std=0.3, generator=None):
    """
    Corrupt ``image`` with additive Gaussian noise and run ``model`` on it.

    Args:
        model: Callable (e.g. an ``nn.Module`` in eval mode) applied to the
            noisy tensor.
        image: Clean image tensor. The noisy input is clamped to ``[0, 1]``,
            so the clean image is assumed to lie in that range. In this file
            the caller passes ``ToTensor()(img).unsqueeze(0)``.
        noise_std: Standard deviation of the zero-mean Gaussian noise. The
            default 0.3 matches the level that was previously hard-coded.
        generator: Optional ``torch.Generator`` for reproducible noise. It
            must live on the same device as ``image``. When ``None``, the
            global torch RNG is used, so ``torch.manual_seed`` controls it.

    Returns:
        The model output with every size-1 dimension removed (``squeeze()``).

    Raises:
        ValueError: If ``noise_std`` is negative.
    """
    if noise_std < 0:
        raise ValueError(f"noise_std must be non-negative, got {noise_std}")

    # Draw noise with the same shape/dtype/device as the input. Unlike
    # randn_like, torch.randn accepts an explicit generator.
    noise = torch.randn(image.shape, generator=generator,
                        dtype=image.dtype, device=image.device)
    # Keep the corrupted input inside the valid [0, 1] intensity range.
    noisy_image = torch.clamp(image + noise * noise_std, 0., 1.)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        denoised_image = model(noisy_image)

    return denoised_image.squeeze()

def save_result(result, path):
    """
    保存去噪后的图像
    """
    result = np.clip(result, 0, 1)
    Image.fromarray((result * 255).astype(np.uint8)).save(path)

def log(*args, **kwargs):
    """
    Print a message prefixed with a local-time timestamp.

    The prefix looks like ``YYYY-MM-DD HH:MM:SS:``. All positional and keyword
    arguments (``sep``, ``end``, ``file``, ``flush``) are forwarded unchanged
    to :func:`print`. Output goes to stdout unless ``file`` is given.
    """
    timestamp = format(datetime.now(), "%Y-%m-%d %H:%M:%S:")
    print(timestamp, *args, **kwargs)

if __name__ == '__main__':
    # Parse command-line arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument('--set_dir', default='AutoDnCNN-Image-Denoising/testsets', type=str,
                        help='root directory containing the test datasets')
    # nargs='+' makes `--set_names BSD68 Set12` a list. Without it, a CLI value
    # is a plain string and the dataset loop would iterate over its characters.
    parser.add_argument('--set_names', default=['BSD68', 'NoisyImages', 'Set12'], nargs='+', type=str,
                        help='names of the dataset sub-directories under --set_dir')
    parser.add_argument('--model_dir', default='AutoDnCNN-Image-Denoising/models', type=str,
                        help='directory of the model')
    parser.add_argument('--model_name_autoencoder', default='autoencoder_weights.pth', type=str,
                        help='file name of the AutoEncoder weights inside --model_dir')
    parser.add_argument('--model_name_dncnn', default='dncnn_weights.pth', type=str,
                        help='file name of the DnCNN weights inside --model_dir')
    parser.add_argument('--result_dir', default='results', type=str,
                        help='directory where denoised images are written')
    parser.add_argument('--save_result', default=1, type=int,
                        help='save the denoised image, 1 or 0')
    parser.add_argument('--seed', default=0, type=int,
                        help='base random seed for the synthetic noise')
    args = parser.parse_args()

    # Create the result directory if it does not exist.
    os.makedirs(args.result_dir, exist_ok=True)

    # Load the AutoEncoder model.
    log('Loading AutoEncoder model...')
    autoencoder_model = load_model(os.path.join(args.model_dir, args.model_name_autoencoder), ConvAutoencoder)
    log('AutoEncoder model loaded successfully.')

    # Load the DnCNN model.
    log('Loading DnCNN model...')
    dncnn_model = load_model(os.path.join(args.model_dir, args.model_name_dncnn), DnCNN)
    log('DnCNN model loaded successfully.')

    # Extensions are matched case-insensitively, e.g. ".PNG" is accepted.
    image_exts = ('.jpg', '.bmp', '.png')

    # Iterate over the datasets.
    for set_name in args.set_names:
        set_dir = os.path.join(args.set_dir, set_name)
        if not os.path.isdir(set_dir):
            # Skip a missing dataset instead of aborting the whole run.
            log(f'Skipping missing dataset directory: {set_dir}')
            continue
        result_set_dir = os.path.join(args.result_dir, set_name)
        os.makedirs(result_set_dir, exist_ok=True)
        log(f'Saving results to: {result_set_dir}')

        # sorted(): os.listdir order is arbitrary; this keeps runs reproducible.
        for idx, im_name in enumerate(sorted(os.listdir(set_dir))):
            if not im_name.lower().endswith(image_exts):
                continue
            img_path = os.path.join(set_dir, im_name)
            log(f'Processing image: {img_path}')

            # Load as grayscale. The context manager closes the file handle.
            with Image.open(img_path) as pil_image:
                gray = pil_image.convert('L')
            image = transforms.ToTensor()(gray).unsqueeze(0)  # add a batch dimension

            # Re-seed the global RNG to the same value before each model so both
            # receive the identical noisy input. Otherwise their results are
            # not comparable.
            image_seed = args.seed + idx

            # Denoise with the AutoEncoder model.
            log('Denoising with AutoEncoder...')
            torch.manual_seed(image_seed)
            denoised_image_autoencoder = denoise_image(autoencoder_model, image)
            log('AutoEncoder denoising completed.')

            # Denoise with the DnCNN model.
            log('Denoising with DnCNN...')
            torch.manual_seed(image_seed)
            denoised_image_dncnn = denoise_image(dncnn_model, image)
            log('DnCNN denoising completed.')

            # Report quality against the clean input.
            clean = image.squeeze().numpy()
            for label, denoised in (('AutoEncoder', denoised_image_autoencoder),
                                    ('DnCNN', denoised_image_dncnn)):
                restored = np.clip(denoised.squeeze().numpy(), 0, 1)
                if restored.shape != clean.shape:
                    # A model may change the spatial size (e.g. odd input
                    # sizes with down/up-sampling). The metrics need equal
                    # shapes, so skip them rather than crash.
                    log(f'{label}: output shape {restored.shape} != input shape {clean.shape}; skipping metrics.')
                    continue
                psnr = peak_signal_noise_ratio(clean, restored, data_range=1.0)
                ssim = compare_ssim(clean, restored, data_range=1.0)
                log(f'{label}: PSNR={psnr:.2f} dB, SSIM={ssim:.4f}')

            # Save the denoised images.
            if args.save_result:
                stem = os.path.splitext(im_name)[0]
                log('Saving denoised images...')
                save_result(denoised_image_autoencoder.squeeze().numpy(),
                            os.path.join(result_set_dir, f'{stem}_autoencoder.png'))
                save_result(denoised_image_dncnn.squeeze().numpy(),
                            os.path.join(result_set_dir, f'{stem}_dncnn.png'))
                log(f'Saved denoised images: {stem}')

    if args.save_result:
        log('All images processed and saved successfully.')
    else:
        log('All images processed.')

ModuleNotFoundError: No module named 'AutoDnCNN_Image_Denoising'