## HLIP Inference

In [2]:
import os
from PIL import Image

import torch
from torchvision import transforms

import models
import numpy as np
from tqdm import tqdm
from utils import make_coord
from glob import glob

def freq_filter(crop_lr, centers):
    """Return high-pass filtered, min-max normalised copies of an image.

    For every value in ``centers`` a square block around the middle of the
    ``fftshift``-ed 2-D spectrum (i.e. the low frequencies) is zeroed, each
    channel is transformed back to the spatial domain, and the real part is
    rescaled to the range [0, 1].

    Args:
        crop_lr: Array-like of shape ``(channels, rows, cols)``. Anything
            ``np.asarray`` accepts works.
        centers: Iterable of integer half-widths of the zeroed square; the
            square spans ``2 * center`` rows and columns.

    Returns:
        ``np.float32`` array of shape ``(len(centers) * channels, rows, cols)``.
        Maps are ordered center-major: all channels for the first center,
        then all channels for the second, and so on.

    Raises:
        ValueError: If ``centers`` is empty (``np.stack`` needs at least one
            array).
    """
    image = np.asarray(crop_lr)
    chs, rows, cols = image.shape
    crow, ccol = rows // 2, cols // 2

    # The shifted spectrum of a channel does not depend on the mask, so it is
    # computed once per channel instead of once per (center, channel) pair.
    spectra = [np.fft.fftshift(np.fft.fft2(image[ch])) for ch in range(chs)]

    frq_list = []
    for center in centers:
        mask = np.ones((rows, cols), np.uint8)
        # NOTE(review): when ``center`` exceeds ``crow``/``ccol`` the slice
        # start becomes negative and wraps, so the zeroed block is no longer
        # centred. Left as-is because changing it would alter the features
        # existing checkpoints are fed -- confirm whether this is intended.
        mask[crow-center:crow+center, ccol-center:ccol+center] = 0
        for spectrum in spectra:
            filtered = np.fft.ifft2(np.fft.ifftshift(spectrum * mask)).real
            lowest = np.min(filtered)
            span = np.max(filtered) - lowest
            if span == 0:
                # A constant map has no range to normalise; emit zeros rather
                # than the NaNs that 0 / 0 would produce.
                frq_list.append(np.zeros_like(filtered))
            else:
                frq_list.append((filtered - lowest) / span)
    frq_lr = np.stack(frq_list, axis=0).astype(np.float32)
    return frq_lr

def batched_predict(model, inp, coord, cell, bsize, frq_lr):
    """Query ``model`` for every coordinate, at most ``bsize`` at a time.

    Features are produced once via ``model.gen_feat(inp, frq_lr)``. ``coord``
    and ``cell`` are then sliced along dimension 1 into consecutive chunks,
    each chunk is passed to ``model.query_rgb`` and the chunk outputs are
    concatenated along dimension 1. Everything runs under ``torch.no_grad()``.

    Args:
        model: Object exposing ``gen_feat(inp, frq_lr)`` and
            ``query_rgb(features, coord, cell)``.
        inp: Input forwarded unchanged to ``model.gen_feat``.
        coord: Tensor whose dimension 1 enumerates the query points.
        cell: Tensor sliced along dimension 1 in lockstep with ``coord``.
        bsize: Maximum number of query points per ``query_rgb`` call; must be
            a positive integer.
        frq_lr: Second input forwarded unchanged to ``model.gen_feat``.

    Returns:
        The concatenation along dimension 1 of all ``query_rgb`` outputs.

    Raises:
        ValueError: If ``bsize`` is not positive. Such a value previously made
            the chunking loop run forever because the cursor never advanced.
    """
    if bsize <= 0:
        raise ValueError(f'bsize must be a positive integer, got {bsize!r}')
    with torch.no_grad():
        # Original author's note on the feature layout: B F height width.
        features = model.gen_feat(inp, frq_lr)
        n = coord.shape[1]
        # Slicing past ``n`` is clamped by the tensor, so the last chunk is
        # simply shorter when ``n`` is not a multiple of ``bsize``.
        preds = [
            model.query_rgb(
                features,
                coord[:, start:start + bsize, :],
                cell[:, start:start + bsize, :],
            )
            for start in range(0, n, bsize)
        ]
        pred = torch.cat(preds, dim=1)
    return pred

  from .autonotebook import tqdm as notebook_tqdm


### Parameters

In [4]:
input_dir = './input'  # Directory containing the input images
model_path = './save/x4.pth'  # Path to the inference checkpoint
# Note: the model is independent of input size and magnification, so any
# checkpoint can be used for any magnification factor. The checkpoint trained
# for the matching magnification performs best.
resolution = [256, 256]  # Target super-resolution size as [height, width]
output_dir = './output'  # Directory where the super-resolved images are saved
gpu = 'cuda:0'  # Device the model is moved to
ext = 'jpg'  # Extension of the input images to look for
os.makedirs(output_dir, exist_ok=True)
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
# Loading onto the CPU first keeps this working when the checkpoint was saved
# from a device that does not exist on this machine.
checkpoint = torch.load(model_path, map_location='cpu')
model = models.make(checkpoint['model'], load_sd=True).to(torch.device(gpu))
# Inference only: switch layers such as dropout / batch norm to eval mode.
model.eval()
h, w = resolution

### Inference

In [5]:
# The query grid depends only on the target resolution, so it is built once
# rather than per image. It is placed on the configured ``gpu`` (not the
# default CUDA device) so it always matches the device the model lives on.
to_tensor = transforms.ToTensor()
to_pil = transforms.ToPILImage()
coord = make_coord((h, w)).to(gpu)  # presumably (h * w, 2); see the view(h, w, 3) below
cell = torch.ones_like(coord)
cell[:, 0] *= 2 / h
cell[:, 1] *= 2 / w

for input_image in tqdm(glob(os.path.join(input_dir, f'*.{ext}'))):
    # Close the file handle deterministically; convert() returns a new image.
    with Image.open(input_image) as opened:
        lr_img = opened.convert('RGB')
    # basename/splitext handle any path separator and keep dots inside the
    # file name (e.g. 'a.b.jpg' -> 'a.b'), avoiding output name collisions.
    img_name = os.path.splitext(os.path.basename(input_image))[0]
    img = to_tensor(lr_img)
    frq_lr = freq_filter(img, [30, 60])
    frq_lr = torch.tensor(frq_lr, dtype=torch.float32)
    # The image is rescaled from ToTensor's [0, 1] range to [-1, 1] for the
    # model; the prediction is mapped back and clamped below.
    pred = batched_predict(model, ((img - 0.5) / 0.5).to(gpu).unsqueeze(0),
        coord.unsqueeze(0), cell.unsqueeze(0), bsize=30000, frq_lr=frq_lr.to(gpu).unsqueeze(0))[0]
    pred = (pred * 0.5 + 0.5).clamp(0, 1).view(h, w, 3).permute(2, 0, 1).cpu()
    to_pil(pred).save(os.path.join(output_dir, f'{img_name}.jpg'))

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
100%|██████████| 5/5 [00:04<00:00,  1.23it/s]
