# docUnet 분석


In [1]:
# 예측 정의

import torch
from torchvision import transforms
import os
import matplotlib.pyplot as plt
import cv2
import time
from PIL import Image
import numpy as np
from natsort import natsorted


def get_file_list(folder_path: str, p_postfix: "str | list | tuple" = ('.jpg',), sub_dir: bool = True) -> list:
    """
    Collect the files under a directory whose paths end with one of the given suffixes.

    The directory is traversed with os.walk (recursive) or os.listdir (top level only).
    :param folder_path: Directory to search; must exist and be a directory
    :param p_postfix: A suffix string or a list/tuple of suffixes (e.g. '.jpg').
                      The special value '.*' matches every file.
    :param sub_dir: Whether to search subfolders
    :return: Matching file paths in natural sort order (natsorted).
             Each file appears at most once, even if several suffixes match it.
    """
    assert os.path.isdir(folder_path), 'folder_path must be an existing directory'
    if isinstance(p_postfix, str):
        p_postfix = [p_postfix]

    # '.*' is a wildcard, not a literal suffix, so it is handled separately.
    match_all = '.*' in p_postfix
    # str.endswith accepts a tuple, so one call tests every suffix and a file
    # can no longer be appended once per matching suffix.
    suffixes = tuple(p for p in p_postfix if p != '.*')

    if sub_dir:
        candidates = (os.path.join(rootdir, name)
                      for rootdir, _, names in os.walk(folder_path)
                      for name in names)
    else:
        candidates = (os.path.join(folder_path, name)
                      for name in os.listdir(folder_path))

    file_list = [path for path in candidates
                 if os.path.isfile(path) and (match_all or path.endswith(suffixes))]
    return natsorted(file_list)


class Pytorch_model:
    '''
    Inference wrapper around a PyTorch network whose 2-channel output is used
    as the (x, y) sampling maps of cv2.remap to unwarp a document image.
    '''

    def __init__(self, model_path, net, img_h, img_w, img_channel=3, gpu_id=None):
        '''
        Load the model and move it to the selected device.
        :param model_path: Checkpoint path. Either a pickled full model (use net=None) or a
                           dict whose 'state_dict' entry holds the parameters (pass net).
        :param net: Network instance that receives the 'state_dict' parameters,
                    or None when model_path stores the whole model.
        :param img_h: Height the input image is resized to before inference
        :param img_w: Width the input image is resized to before inference
        :param img_channel: The number of channels in the image: 1,3
        :param gpu_id: Index of the GPU to run on; None (or no CUDA available) selects the CPU
        '''
        self.img_h = img_h
        self.img_w = img_w
        self.img_channel = img_channel

        use_cuda = gpu_id is not None and isinstance(gpu_id, int) and torch.cuda.is_available()
        # Address the requested GPU directly. The previous code changed
        # CUDA_VISIBLE_DEVICES after CUDA had already been probed and then mixed
        # "cuda:<gpu_id>" storages with a "cuda:0" device, so the model and its
        # inputs could end up on different GPUs.
        self.device = torch.device('cuda', gpu_id) if use_cuda else torch.device('cpu')

        # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
        self.net = torch.load(model_path, map_location=self.device)
        print('device:', self.device)

        if net is not None:
            # If the network calculation graph and parameters are saved separately, perform parameter loading.
            net = net.to(self.device)
            net.load_state_dict(self.net['state_dict'])
            self.net = net
        self.net.eval()

    def predict(self, img):
        '''
        Unwarp an image that is read and preprocessed with PIL.
        :param img: Image path (str) or an already opened PIL image
        :return: Unwarped image as a numpy array of size (img_h, img_w);
                 BGR channel order when img_channel is 3, single channel when it is 1
        '''
        assert self.img_channel in [1, 3], 'img_channel must in [1.3]'

        if isinstance(img, str):  # read image
            assert os.path.exists(img), 'file is not exists'
            img = Image.open(img)

        # Force the mode that matches img_channel so RGBA / palette / grayscale
        # inputs do not reach the network with an unexpected channel count.
        target_mode = 'L' if self.img_channel == 1 else 'RGB'
        if img.mode != target_mode:
            img = img.convert(target_mode)

        img = img.resize((self.img_w, self.img_h))
        print(img.size)
        # Change the picture from (w,h) to (1,img_channel,h,w)
        tensor = transforms.ToTensor()(img).unsqueeze(0).to(self.device)

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            preds = self.net(tensor)
        # (2, h, w) -> (h, w, 2); channel 0 is passed as map1 and channel 1 as map2 of cv2.remap
        preds = preds[0].permute(1, 2, 0).cpu().numpy()

        img = np.array(img)
        if img.ndim == 3:
            # Only colour images have channels to swap; a grayscale array is 2-D.
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        unwarp_img = cv2.remap(img, preds[:, :, 0], preds[:, :, 1], cv2.INTER_LINEAR)
        return unwarp_img

    def predict_cv(self, img: "np.ndarray | str"):
        '''
        Unwarp an image that is read and preprocessed with OpenCV.
        :param img: Image path (str) or an image array as returned by cv2.imread
        :return: (unwarp_img, img): the unwarped image resized back to the input's
                 original size, and the input image after channel conversion
        :raises ValueError: If the path exists but OpenCV cannot decode the file
        '''
        assert self.img_channel in [1, 3], 'img_channel must in [1.3]'

        if isinstance(img, str):  # read image
            assert os.path.exists(img), 'file is not exists'
            img_path = img
            img = cv2.imread(img_path, 0 if self.img_channel == 1 else 1)
            if img is None:
                # cv2.imread signals a decode failure by returning None instead of raising.
                raise ValueError('cv2.imread could not decode image: ' + img_path)

        if len(img.shape) == 2 and self.img_channel == 3:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif len(img.shape) == 3 and self.img_channel == 1:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        h, w = img.shape[:2]

        # NOTE(review): the array is fed to the network in OpenCV's BGR order, whereas
        # predict() feeds RGB - confirm which order the model was trained with.
        t_img = cv2.resize(img, (self.img_w, self.img_h))
        print(t_img.shape)
        # Change the picture from (h,w[,c]) to (1,img_channel,h,w)
        tensor = transforms.ToTensor()(t_img).unsqueeze(0).to(self.device)

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            preds = self.net(tensor)
        # (2, h, w) -> (h, w, 2); channel 0 is passed as map1 and channel 1 as map2 of cv2.remap
        preds = preds[0].permute(1, 2, 0).cpu().numpy()
        unwarp_img = cv2.remap(t_img, preds[:, :, 0], preds[:, :, 1], cv2.INTER_LINEAR)
        # Scale the result back to the size of the input image.
        unwarp_img = cv2.resize(unwarp_img, (w, h))
        return unwarp_img, img


if __name__ == '__main__':
    # Demo script: unwarp every .jpg in the test folder, save the result next to
    # the input and show input/output side by side.
    from models.deeplab_models.deeplab import DeepLab

    model_path = 'output/deeplab_add_bg_img_800_600_item_origin_deeplab_resnet/DocUnet_77_0.41399721733152867.pth'
    test_dir = '/data2/zj/data/doc_testdata'

    # Initialize the network; its parameters are loaded from the checkpoint by Pytorch_model.
    net = DeepLab(backbone='resnet', output_stride=16, num_classes=2, pretrained=False)
    model = Pytorch_model(model_path, net=net, img_h=600, img_w=800, img_channel=3, gpu_id=2)

    for img_path in get_file_list(test_dir, p_postfix='.jpg'):
        # Results are written into the same folder, so skip files from earlier runs.
        if 'result' in img_path:
            continue

        # Execute prediction and report the elapsed time in seconds.
        start = time.time()
        unwarp_img, img = model.predict_cv(img_path)
        print(time.time() - start)

        save_path = os.path.splitext(img_path)[0]
        print(save_path)
        cv2.imwrite(save_path + '_epoch_77_result.jpg', unwarp_img)

        # Visualization: matplotlib expects RGB, OpenCV images are BGR.
        plt.subplot(1, 2, 1)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        plt.title('input')
        plt.imshow(img)
        plt.subplot(1, 2, 2)
        unwarp_img = cv2.cvtColor(unwarp_img, cv2.COLOR_BGR2RGB)
        plt.title('output')
        plt.imshow(unwarp_img)
        plt.show()


KeyboardInterrupt: 