In [1]:
import argparse
import os
from collections import OrderedDict
from glob import glob
from os.path import join

import cv2
import torch
import torch.nn.functional as F

from model.model import Quantize
from model.model import ResHalf
from utils import util


class Inferencer:
    """Run a trained halftoning / inverse-halftoning model on image tensors.

    Loads the weights stored under the ``'state_dict'`` key of a checkpoint
    into *model*, puts the model in eval mode, and exposes it as a callable
    that pads inputs to a multiple of 8, runs the model without gradient
    tracking, and crops the outputs back to the input size.
    """

    def __init__(self, checkpoint_path, model, use_cuda=False, multi_gpu=False):
        """Load checkpoint weights into *model*.

        Args:
            checkpoint_path: path of a checkpoint readable by ``torch.load``
                whose ``'state_dict'`` entry holds the model weights.
            model: the ``torch.nn.Module`` to load the weights into.
            use_cuda: move the model (and, on every call, the input) to the GPU.
            multi_gpu: wrap the model in ``torch.nn.DataParallel``.
        """
        # Load onto the CPU first so GPU-saved checkpoints open on any machine.
        self.checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        self.use_cuda = use_cuda
        self.model = model.eval()
        if multi_gpu:
            self.model = torch.nn.DataParallel(self.model)
        state_dict = self._adapt_state_dict_keys(self.checkpoint['state_dict'], multi_gpu)
        if self.use_cuda:
            self.model = self.model.cuda()
        self.model.load_state_dict(state_dict)

    @staticmethod
    def _adapt_state_dict_keys(state_dict, multi_gpu):
        """Return a copy of *state_dict* whose keys match the target model.

        A ``DataParallel``-wrapped model keeps the real network under
        ``.module``, so its parameter names carry a ``"module."`` prefix, while
        a plain model expects the bare names. The prefix is added or removed
        only when needed, so checkpoints saved either way load correctly.
        """
        prefix = 'module.'
        adapted = OrderedDict()
        for key, value in state_dict.items():
            has_prefix = key.startswith(prefix)
            if multi_gpu:
                new_key = key if has_prefix else prefix + key
            else:
                new_key = key[len(prefix):] if has_prefix else key
            adapted[new_key] = value
        return adapted

    def __call__(self, input_img, decoding_only=False):
        """Run the model on a batch of images.

        Args:
            input_img: 4-D tensor unpacked as ``(N, C, H, W)``; values are
                presumably in [-1, 1] (the driver script maps pixels that way).
            decoding_only: if True, the model is called in decoding-only mode
                and only the restored (inverse-halftoned) image is returned.

        Returns:
            ``resColor`` when *decoding_only* is True, otherwise the tuple
            ``(resHalftone, resColor)``; every output is cropped to ``H x W``.
        """
        with torch.no_grad():
            # Spatial dims are padded to a multiple of this (presumably the
            # network's total down-sampling factor -- confirm against ResHalf).
            scale = 8
            _, _, H, W = input_img.shape
            # Smallest padding that makes each dim a multiple of `scale`; it is
            # 0 for a dim that is already aligned, so an aligned dim is never
            # padded (reflect padding also requires pad < dim).
            pad_h = (-H) % scale
            pad_w = (-W) % scale
            padded = pad_h > 0 or pad_w > 0
            if padded:
                input_img = F.pad(input_img, [0, pad_w, 0, pad_h], mode='reflect')
            if self.use_cuda:
                input_img = input_img.cuda()
            if decoding_only:
                # Decode only: the model returns just the restored image.
                resColor = self.model(input_img, decoding_only)
                if padded:
                    resColor = resColor[:, :, :H, :W]
                return resColor
            resHalftone, resColor = self.model(input_img, decoding_only)
            # Map [-1, 1] -> [0, 1], apply the project's Quantize op, map back.
            # NOTE(review): the original author asked why the halftone is
            # quantized again here if ResHalf already quantizes it -- confirm.
            resHalftone = Quantize.apply((resHalftone + 1.0) * 0.5) * 2.0 - 1.
            if padded:
                resHalftone = resHalftone[:, :, :H, :W]
                resColor = resColor[:, :, :H, :W]
            return resHalftone, resColor

In [3]:
# Settings that were previously command-line arguments.
model_path = "checkpoints/model_best.pth.tar"
decoding = False  # inverse halftoning only (inputs are read as grayscale halftones)
data_dir = "./test_imgs"  # folder with the test images (continuous-tone images)
save_dir = "./result"  # output folder for the halftone and inverse-halftone images
use_cuda = False  # set True to run the model on the GPU

invhalfer = Inferencer(
    checkpoint_path=model_path,
    model=ResHalf(train=False),
    use_cuda=use_cuda,
)
util.ensure_dir(save_dir)
# '*.*g' matches extensions ending in "g" (.png, .jpg, .jpeg, ...);
# sort so images are processed in a deterministic order.
test_imgs = sorted(glob(join(data_dir, '*.*g')))
print('------loaded %d images.' % len(test_imgs))
for img in test_imgs:
    print('[*] processing %s ...' % img)
    # basename/splitext are separator-portable and tolerate extra dots in the
    # file name, unlike splitting on '/' and '.'.
    name, ext = os.path.splitext(os.path.basename(img))
    suffix = ext[1:]
    read_flag = cv2.IMREAD_GRAYSCALE if decoding else cv2.IMREAD_COLOR
    raw = cv2.imread(img, flags=read_flag)
    if raw is None:
        # cv2.imread returns None (it does not raise) for files it cannot
        # decode, e.g. a .svg that also matches the glob pattern.
        print('[!] skipping unreadable image %s' % img)
        continue
    # Map 8-bit pixel values [0, 255] to [-1, 1].
    input_img = raw / 127.5 - 1.
    if decoding:
        print(input_img.shape)
        c = invhalfer(util.img2tensor(input_img), decoding_only=True)
        c = util.tensor2img(c / 2. + 0.5) * 255.
        cv2.imwrite(join(save_dir, f'{name}.{suffix}'), c)
    else:
        h, c = invhalfer(util.img2tensor(input_img), decoding_only=False)
        h = util.tensor2img(h / 2. + 0.5) * 255.
        c = util.tensor2img(c / 2. + 0.5) * 255.
        cv2.imwrite(join(save_dir, f'halftone_{name}.{suffix}'), h)
        cv2.imwrite(join(save_dir, f'restored_{name}.{suffix}'), c)

------loaded 8 images.
[*] processing ./test_imgs/017.png ...
[*] processing ./test_imgs/027.png ...
[*] processing ./test_imgs/037.png ...
[*] processing ./test_imgs/1.jpg ...
[*] processing ./test_imgs/2009_001468.png ...




[*] processing ./test_imgs/dog.png ...
[*] processing ./test_imgs/klee.png ...
[*] processing ./test_imgs/paimoon.png ...
