In [None]:
import os,sys, math, time
from glob import glob
from tqdm import tqdm
from PIL import Image
import numpy as np
import torch,warnings
from torch import nn
import torchvision
import torchvision.utils as vutils
import torch.utils.data as data
import torchvision.transforms as tfs
from torchvision.transforms import functional as FF
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from torchvision.utils import make_grid
warnings.filterwarnings('ignore')
from torchvision.models import vgg16
from model import *
from losses import *
from PerceptualLoss import PerLoss
from metrics import psnr,ssim
from torch.backends import cudnn
from torch import optim
from dataset import AFO_Dataset

### Define model and load pretrained weights

In [None]:
### define model
# Pick the compute device first; everything below is placed on it.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Registry of available architectures keyed by name (only LVRNet is used here).
models_ = {
    'lvrnet': LVRNet(gps=3, blocks=16),
}

# Module.to() moves the parameters and returns the same module object.
net = models_['lvrnet'].to(device)

if device == 'cuda':
    # Let cuDNN autotune convolution algorithms for the (fixed) input shape.
    cudnn.benchmark = True
    # NOTE: DataParallel prefixes state_dict keys with "module.".
    net = torch.nn.DataParallel(net)

In [None]:
# Training checkpoint; expected to be a dict holding the model weights under 'model'.
model_wts = "../weights/LPEF_Epoch47.pth"

# map_location remaps tensors saved on GPU so the checkpoint also loads on a CPU-only machine.
ckp = torch.load(model_wts, map_location=device)
state_dict = ckp['model']

# A model wrapped in nn.DataParallel stores every key as "module.<name>". `net` is only
# wrapped on CUDA, so align the checkpoint's key prefix with the current wrapping;
# otherwise load_state_dict raises on missing/unexpected keys.
_prefix = 'module.'
_net_is_wrapped = isinstance(net, torch.nn.DataParallel)
_ckpt_is_prefixed = all(key.startswith(_prefix) for key in state_dict)
if _ckpt_is_prefixed and not _net_is_wrapped:
    state_dict = {key[len(_prefix):]: value for key, value in state_dict.items()}
elif _net_is_wrapped and not _ckpt_is_prefixed:
    state_dict = {_prefix + key: value for key, value in state_dict.items()}

net.load_state_dict(state_dict)

### Qualitative results

In [None]:
### Load demo images
demo_dir = "../demo"
# PIL's Image.resize takes (width, height), so each tensor ends up (3, 256, 456).
input_size = (456, 256)

# Sort so the display order is deterministic (glob order depends on the filesystem).
img_paths = sorted(glob(os.path.join(demo_dir, "*.jpg")))
n_imgs = len(img_paths)
if n_imgs == 0:
    # torch.stack([]) would otherwise fail with an unhelpful RuntimeError.
    raise FileNotFoundError(f"No .jpg images found in {os.path.abspath(demo_dir)!r}")

to_tensor = tfs.ToTensor()  # PIL RGB uint8 -> float tensor (C, H, W) in [0, 1]
inputs = []
for img_path in img_paths:
    # Close the file handle as soon as the pixels have been decoded.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert('RGB').resize(input_size)
    # Half-resolution alternative (may give an OOM error on large photos):
    #   img.resize((img.size[0] // 2, img.size[1] // 2))
    inputs.append(to_tensor(img))

inputs = torch.stack(inputs, dim=0).to(device)  # (n_imgs, 3, 256, 456)

In [None]:
### Run inference on each demo image; show input (left) next to prediction (right)
net.eval()
torch.cuda.empty_cache()

for img_tensor in inputs:
    batch = img_tensor.unsqueeze(0)  # add a batch dimension: (1, C, H, W)
    with torch.no_grad():
        # CUDA kernels run asynchronously: synchronize before and after the forward
        # pass so the timer measures the actual computation, not just kernel launch.
        if device == 'cuda':
            torch.cuda.synchronize()
        start = time.perf_counter()
        pred = net(batch)
        if device == 'cuda':
            torch.cuda.synchronize()
        end = time.perf_counter()
    # NOTE: when cudnn.benchmark is enabled (CUDA only), the first image also pays
    # for algorithm autotuning, so its time is not representative.
    print("Time taken for inference: ", end - start)

    # visualize outputs; normalize/scale_each min-max rescales each image
    # independently for display only.
    merged_io = torch.cat([batch, pred], dim=0)
    grid_img = vutils.make_grid(merged_io, nrow=2, normalize=True, scale_each=True)
    fig, ax = plt.subplots(figsize=(15, 15))
    ax.imshow(grid_img.permute(1, 2, 0).cpu().numpy())
    ax.axis('off')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across iterations