In [1]:
import tqdm
import cv2
import torch

In [2]:
from pathlib import Path
from fplab.tools.array import IMA
from fplab.tools.tensor import IMT
from fplab.tools.image import change_image, get_data_dirs

In [3]:
from datasets import preprocess_data
from models import FENet, ReNet

In [4]:
# Checkpoint of the first network (despite the name, this is a file path, not a directory).
# The weight-loading cell derives the other nets' checkpoint paths from this string.
load_dir = str(Path("latest_fe_net.pth").absolute())
# Root folder for enhanced outputs. Built from Path parts rather than a Windows-only
# backslash string, so it resolves to ../results/MOLF on every OS.
save_root_dir = str((Path("..") / "results" / "MOLF").absolute())

In [5]:
# Output root as a Path object; enhanced images are saved beneath it, mirroring the input layout.
result_path: Path = Path(save_root_dir)

In [6]:
# Root folder of the MOLF test images (machine-specific absolute path; edit for your setup).
# Kept as an absolute string: the enhancement loop maps each image path relative to it.
imgs_dir = str(Path(r"E:\Projects\fingerprint\data\test\MOLF\images").absolute())
if not Path(imgs_dir).is_dir():
    # Fail loudly here rather than relying on get_data_dirs' behaviour for a missing root.
    raise FileNotFoundError(f"Image directory not found: {imgs_dir}")
# Image file paths under imgs_dir (each one is read as an image by the enhancement loop).
# get_data_dirs is handed a fresh list to fill and its return value is used.
img_dirs = get_data_dirs(imgs_dir, [])

In [7]:
class MyOptions:
    """Configuration shared by model construction, weight loading, and inference.

    Every option is a plain instance attribute. Keyword arguments override the
    defaults, e.g. ``MyOptions(device="cpu", image_size_max=256)``.

    Raises:
        TypeError: if an override names an option that does not exist (catches typos).
    """

    def __init__(self, **overrides):
        # Device for checkpoints and input tensors. Fall back to CPU when CUDA is
        # unavailable so the notebook still runs (slowly) on machines without a GPU.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        # Training-schedule settings: not read by this notebook's inference code;
        # presumably kept because training code or the model constructors expect them.
        self.learning_rate = 0.0001
        self.epoch_start = 0
        self.epoch_end = 10
        self.epoch_fixed_lr = 5
        self.weight_path = Path("")
        # Architecture hyper-parameters consumed by FENet/ReNet (semantics defined there).
        self.ch_num_in = 1
        self.ch_num_n = 64
        self.res_num_n = [3, 6, 3, 3, 6]
        # Longer image side is downscaled to at most this many pixels before enhancement.
        self.image_size_max = 448
        # Network input is padded/cropped to this size.
        self.image_pad_size = [512, 512]
        # Forwarded to preprocess_data as need_zscore.
        self.need_zscore = True
        self.position_embedding = True

        for key, value in overrides.items():
            if not hasattr(self, key):
                raise TypeError(f"MyOptions() got an unexpected option {key!r}")
            setattr(self, key, value)

In [8]:
# One options object configures both networks and is reused by the loading and inference cells.
options = MyOptions()
# Both nets are built from the same options; at inference ReNet consumes FENet's first output.
fe_net, re_net = FENet(options), ReNet(options)

In [9]:
nets = [fe_net, re_net]
for net in nets:
    # Checkpoint path: load_dir (latest_fe_net.pth) with the first net's `.name` swapped for
    # this net's `.name`, so one file-name template serves every net.
    ckpt_path = Path(load_dir.replace(nets[0].name, net.name))
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints
    # (on torch >= 1.13, weights_only=True restricts loading to tensors / state dicts).
    state_dict = torch.load(ckpt_path, map_location=options.device)
    net.load_state_dict(state_dict)
    # Keep parameters on the device the inference loop sends inputs to
    # (a no-op if the constructor already placed them there), then switch to eval mode.
    net.to(options.device)
    net.eval()

In [10]:
def _fit_within(shape, size_max):
    """Scale an (h, w) shape so its longer side is at most ``size_max``.

    Aspect ratio is preserved (scaled sides are truncated to int); shapes already
    within the limit are returned unchanged.

    Args:
        shape: sequence whose first two items are height and width.
        size_max: maximum allowed length of the longer side.

    Returns:
        ``[height, width]`` as a list of ints.
    """
    height, width = int(shape[0]), int(shape[1])
    length_max = max(height, width)
    if length_max <= size_max:
        return [height, width]
    ratio = size_max / length_max
    # Clamp to 1 so an extreme aspect ratio never produces a zero-sized image.
    return [max(1, int(height * ratio)), max(1, int(width * ratio))]


def _output_path(image_path, src_root, dst_root):
    """Map an image under ``src_root`` to the same relative location under ``dst_root``.

    Raises:
        ValueError: if ``image_path`` is not inside ``src_root``. A plain string replace
            would return the input path unchanged in that case, and the enhanced result
            would then overwrite the source image.
    """
    return Path(dst_root) / Path(image_path).relative_to(src_root)


src_root = Path(imgs_dir)
dst_root = result_path.absolute()
for d in tqdm.tqdm(img_dirs, desc="Enhancing latent fingerprint"):
    # Resolve the source image and where its enhanced version will be saved.
    im_d = str(Path(d).absolute())
    sv_d = str(_output_path(im_d, src_root, dst_root))
    shape = IMA.read(im_d).shape[:2]
    Path(sv_d).parent.mkdir(parents=True, exist_ok=True)
    resize_shape = _fit_within(shape, options.image_size_max)
    # change_image appears to take the size as [width, height]; presumably it writes a
    # grayscale ('L') resized copy into sv_pt and returns that file's path, which is then read.
    im_d = str(change_image(im_d, sv_pt=Path(sv_d).parent, md='L',
                            re_sz=[resize_shape[1], resize_shape[0]]).absolute())
    # Read and normalise the image.
    im = IMT.read(im_d, options.device).rgb2l().imt
    im_in = preprocess_data(im, need_zscore=options.need_zscore, test_mode=True)
    # Pad to the fixed network input size and add a batch dimension.
    im_in = IMT(im_in).pad_crop(options.image_pad_size, pad_md='replicate').imt.unsqueeze(0)
    # Enhance the image.
    with torch.no_grad():
        im_out = re_net(fe_net(im_in)[0])
    # Collapse singleton dims, keep one leading channel dim, and crop the padding back off.
    im_out = IMT(im_out.squeeze().unsqueeze(0)).pad_crop(resize_shape).imt
    im_out = IMA.imt2ima(im_out)
    # Restore the original resolution; cv2.resize takes dsize as (width, height).
    im_out = cv2.resize(im_out.ima, (shape[1], shape[0]))
    IMA(im_out).save(sv_d)

Enhancing latent fingerprint: 100%|████████████████████████████████████████████████| 3600/3600 [39:24<00:00,  1.52it/s]
