In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import sys
import os
from PIL import Image
import time
import yaml
import numpy as np
from tqdm import tqdm
from torchvision import transforms
from torchvision.transforms import ToPILImage, ToTensor
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import MultiStepLR
from torch.optim import Adam
from torchvision.utils import save_image
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'


sys.path.append("models")
import models
from models import losses
from models.liif import LIIF
from models.discriminator import Discriminator
from models.losses import AdversarialLoss

import utils
from utils import make_coord,set_save_path,ssim
import datasets
from test import eval_psnr_ssim, batched_predict

%matplotlib inline

In [20]:
img = Image.open(r'load\div2k\DIV2K_valid_LR_bicubic\X2\0802x2.png')
#img = transforms.Resize((int(img.height/2),int(img.width/2)),Image.BICUBIC)(img)
timg = transforms.ToTensor()(img) #[3,LR_H,LR_W]
edgemap = edge.laplacian_kernel(timg.unsqueeze(0).cuda())
transforms.ToPILImage()(edgemap[0]).save('0802x2edge.jpg')

In [20]:
#pytorch 求LPIPS
 
import torch
import lpips
import os
 
use_gpu = False         # Whether to use GPU
spatial = False         # Return a spatial map of perceptual distance.
 
# Linearly calibrated models (LPIPS)
loss_fn = lpips.LPIPS(net='vgg', spatial=spatial) # Can also set net = 'squeeze' or 'vgg'
# loss_fn = lpips.LPIPS(net='alex', spatial=spatial, lpips=False) # Can also set net = 'squeeze' or 'vgg'
 
if(use_gpu):
	loss_fn.cuda()
	

Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]
Loading model from: d:\Miniconda3\envs\pytorch\lib\site-packages\lpips\weights\v0.1\vgg.pth


In [23]:
## Example usage with dummy tensors
rood_path = r'D:\Project\results\faces'
hr_path = r'E:\Code\Python\datas\selfWHURS\WHURS19-test\GT'
sr_path = r'E:\Code\Python\liif-self\result\WHURS19_edsrblx4'

hr_path_list = []
sr_path_list = []
## path in net is already exist

for root, _, fnames in sorted(os.walk(hr_path, followlinks=True)):
	for fname in fnames:
		path = os.path.join(hr_path, fname)
		hr_path_list.append(path)

for root, _, fnames in sorted(os.walk(sr_path, followlinks=True)):
	for fname in fnames:
		path = os.path.join(sr_path, fname)
		sr_path_list.append(path)


In [26]:
dist_ = []
for i in range(len(hr_path_list)):
	hr_img = lpips.im2tensor(lpips.load_image(hr_path_list[i]))
	sr_img = lpips.im2tensor(lpips.load_image(sr_path_list[i]))
	if(use_gpu):
		hr_img = hr_img.cuda()
		sr_img = sr_img.cuda()
	dist = loss_fn.forward(hr_img, sr_img)
	dist_.append(dist.mean().item())
print('Avarage Distances: %.3f' % (sum(dist_)/len(hr_path_list)))

Avarage Distances: 0.304


In [21]:
hr_img = lpips.im2tensor(lpips.load_image(r'E:\Code\Python\liif-self\result\WHURS19_edsrblx4\airport_41.png'))
sr_img = lpips.im2tensor(lpips.load_image(r'E:\Code\Python\datas\selfWHURS\WHURS19-test\GT\airport_41.jpg'))
if(use_gpu):
    hr_img = hr_img.cuda()
    sr_img = sr_img.cuda()
dist = loss_fn.forward(hr_img, sr_img)
dist.mean()

tensor(0.2441, grad_fn=<MeanBackward0>)

In [22]:
hr_img = lpips.im2tensor(lpips.load_image(r'E:\Code\Python\liif-self\result\WHURS19_samx_L0Sgradx4\airport_41.png'))
sr_img = lpips.im2tensor(lpips.load_image(r'E:\Code\Python\datas\selfWHURS\WHURS19-test\GT\airport_41.jpg'))
if(use_gpu):
    hr_img = hr_img.cuda()
    sr_img = sr_img.cuda()
dist = loss_fn.forward(hr_img, sr_img)
dist.mean()

tensor(0.2435, grad_fn=<MeanBackward0>)

In [None]:

modelpath = r'weights\edsr-baseline-liif.pth'
lr_path = r'testimg\div2klrx4\0801x4.png'
hr_path = r'load\div2k\DIV2K_valid_HR\0801.png'
sr_path = r'testimg\ouput.jpg'

lr = transforms.ToTensor()(Image.open(lr_path))
hr = transforms.ToTensor()(Image.open(hr_path))
model = models.make(torch.load(modelpath)['model'], load_sd=True).cuda()
lr = ((lr - 0.5) / 0.5).cuda().unsqueeze(0)
with torch.no_grad():
    feat = model.gen_feat(lr)
featimg= transforms.ToPILImage()(feat[0][0])
plt.imshow(featimg)