In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import os
from PIL import Image
import time
import yaml
import numpy as np
from tqdm import tqdm
from torchvision import transforms
from torchvision.transforms import ToPILImage, ToTensor
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import MultiStepLR
from torch.optim import Adam
from torchvision.utils import save_image
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import sys
sys.path.append("models")

import models
import utils
from utils import make_coord,set_save_path,ssim
import datasets
from test import eval_psnr,batched_predict

from models import losses
from models.liif import LIIF
from models.discriminator import Discriminator
from models.losses import AdversarialLoss
%matplotlib inline

In [None]:
# Paths for the DIV2K demo. os.path.join keeps them valid on non-Windows systems too
# (the original backslash-joined literals only resolve on Windows).
modelpath = os.path.join('weights', 'edsr-baseline-liif.pth')          # LIIF checkpoint
lr_path = os.path.join('testimg', 'div2klrx4', '0801x4.png')           # low-resolution input (x4 per its file name)
hr_path = os.path.join('load', 'div2k', 'DIV2K_valid_HR', '0801.png')  # ground-truth high-resolution image
sr_path = os.path.join('testimg', 'output.jpg')                        # where the super-resolved result is written

In [None]:
# Load both images as float tensors in [0, 1] with shape [3, H, W].
# convert('RGB') guards against palette / grayscale / RGBA files (which would yield 1 or 4
# channels), and the context managers close the underlying file handles.
with Image.open(lr_path) as lr_img:
    lr = ToTensor()(lr_img.convert('RGB'))
with Image.open(hr_path) as hr_img:
    hr = ToTensor()(hr_img.convert('RGB'))

In [None]:
# Rebuild the model from the checkpoint's 'model' entry (spec + weights).
# map_location='cpu' lets a checkpoint saved on any device load before being moved to the GPU;
# eval() switches off any train-only behaviour for inference.
model = models.make(torch.load(modelpath, map_location='cpu')['model'], load_sd=True).cuda().eval()

In [None]:
# Target super-resolution size: query the model at the ground-truth HR resolution so the SR
# output can be compared pixel-for-pixel with `hr` (the SSIM cell below needs equal shapes).
# Doubling the LR size gave half the HR size for this x4 input (per its file name).
h, w = hr.shape[-2], hr.shape[-1]

In [None]:
# Display the SR output size (height, width).
(h, w)

In [None]:
# SR query coordinates for an h x w grid, on the GPU
# (presumably flattened to [h*w, 2] in [-1, 1] by make_coord's default — TODO confirm).
coord = make_coord((h, w)).to('cuda')

In [None]:
# Per-query cell size: every row holds (2/h, 2/w), i.e. the extent of one SR pixel
# assuming coordinates span [-1, 1].
cell = torch.empty_like(coord)
cell[:, 0] = 2 / h
cell[:, 1] = 2 / w

In [None]:
# Display the model's module structure.
model

In [None]:
# Normalise the LR image from [0, 1] to [-1, 1] and add a batch dimension to the image, the
# query coordinates and the cell sizes. The dimension guards make the cell safe to re-run:
# without them a second run would normalise twice and stack extra leading dimensions.
if lr.dim() == 3:
    lr = ((lr - 0.5) / 0.5).cuda().unsqueeze(0)
if coord.dim() == 2:
    coord = coord.unsqueeze(0)
if cell.dim() == 2:
    cell = cell.unsqueeze(0)

In [None]:
# Batched, normalised LR input; expected [1, C, LR_H, LR_W].
lr.size()

In [None]:
# Batched query coordinates; expected [1, h*w, 2].
coord.size()

In [None]:
# Batched cell sizes; same shape as coord.
cell.size()

In [None]:
# Query the model at every SR coordinate (presumably in chunks of `bsize` queries) and keep
# the first item of the batch.
sr_batch = batched_predict(model, lr, coord, cell, bsize=30000)
sr = sr_batch[0]
del sr_batch

In [None]:
# Undo the [-1, 1] normalisation and clip to [0, 1], then reshape the flat [h*w, 3]
# prediction into a CPU image tensor of shape [3, h, w].
sr = sr.mul(0.5).add(0.5).clamp(min=0, max=1)
sr = sr.view(h, w, 3).permute(2, 0, 1).cpu()

In [None]:
# Convert once to a PIL image, write it to disk and show it inline.
sr_pil = ToPILImage()(sr)
sr_pil.save(sr_path)
plt.imshow(sr_pil)

In [None]:
# Structural similarity between the ground truth and the SR result (batched with unsqueeze(0)).
# Fail early with a readable message if the SR size does not match the HR image.
assert hr.shape == sr.shape, f"shape mismatch: hr {tuple(hr.shape)} vs sr {tuple(sr.shape)}"
ssim(hr.unsqueeze(0), sr.unsqueeze(0))

# 单步 (Single step): walk through one LIIF query by hand

In [None]:
# Extract the LR feature map via model.gen_feat without tracking gradients.
with torch.set_grad_enabled(False):
    feat = model.gen_feat(lr)

In [None]:
# LR feature map; expected [N, C, LR_H, LR_W].
feat.size()

In [None]:
# Dump every channel of the LR feature map as a grayscale image.
# ToPILImage expects float images in [0, 1]; raw activations are not, and would wrap around when
# converted to 8-bit. Min-max normalise each channel first and make sure the folder exists.
os.makedirs("feat", exist_ok=True)
for i in range(feat.shape[1]):
    channel = feat[0][i]
    cmin, cmax = channel.min(), channel.max()
    channel = (channel - cmin) / (cmax - cmin).clamp_min(1e-8)
    featimg = ToPILImage()(channel.cpu())
    featimg.save(f"feat/featx{i}.jpg")

In [None]:
# 3x3 neighbourhood unfolding: each LR position carries its own feature vector concatenated
# with those of its 8 neighbours -> [N, C*9, LR_H, LR_W].
featuf = F.unfold(feat, kernel_size=3, padding=1).view(feat.shape[0], -1, *feat.shape[-2:])

In [None]:
# Unfolded feature map; expected [N, C*9, LR_H, LR_W].
featuf.size()

In [None]:
# Dump every channel of the unfolded feature map as a grayscale image.
# ToPILImage expects float images in [0, 1]; raw activations are not, and would wrap around when
# converted to 8-bit. Min-max normalise each channel first and make sure the folder exists.
os.makedirs("featunfold", exist_ok=True)
for i in range(featuf.shape[1]):
    channel = featuf[0][i]
    cmin, cmax = channel.min(), channel.max()
    channel = (channel - cmin) / (cmax - cmin).clamp_min(1e-8)
    featimg = ToPILImage()(channel.cpu())
    featimg.save(f"featunfold/featuf{i}.jpg")

In [None]:
# Accumulators: one prediction and one area tensor per neighbour offset.
preds, areas = [], []

In [None]:
# Offsets to the neighbouring LR cells used by the local ensemble, plus a tiny shift
# (presumably to keep sample points off exact cell boundaries).
vx_lst, vy_lst = [-1, 1], [-1, 1]
eps_shift = 1e-6

In [None]:
# Half the size of one LR cell in a [-1, 1] coordinate span (2 / size / 2 == 1 / size exactly).
rx = 1 / featuf.shape[-2]  # 2/LR_H/2
ry = 1 / featuf.shape[-1]  # 2/LR_W/2

In [None]:
# Un-flattened LR grid coordinates on the GPU: [LR_H, LR_W, 2].
feat_coord = make_coord(featuf.shape[-2:], flatten=False).to('cuda')
feat_coord.size()

In [None]:
# Channel-first layout plus a batch axis, so grid_sample can read it like a feature map:
# [LR_H, LR_W, 2] -> [N, 2, LR_H, LR_W].
feat_coord = feat_coord.permute(2, 0, 1)[None].expand(featuf.shape[0], 2, *featuf.shape[-2:])
feat_coord.size()

In [None]:
# Single-step walk-through: look at the (-1, -1) neighbour only.
vx = vy = -1

In [None]:
# Keep only the first 1000 query points so the walk-through stays cheap.
# NOTE: this overwrites `coord`/`cell`; later cells that expect the full h*w grid will not match.
coord, cell = coord[:, :1000], cell[:, :1000]

In [None]:
# Shift every query towards the (vx, vy) neighbouring LR cell and keep it strictly inside (-1, 1).
shift_x = vx * rx + eps_shift
shift_y = vy * ry + eps_shift
coord_ = coord.clone()  # [N, SR_H*SR_W, 2]
coord_[..., 0] += shift_x
coord_[..., 1] += shift_y
coord_.clamp_(min=-1 + 1e-6, max=1 - 1e-6)

In [None]:
# grid_sample wants (x, y) = (width, height) order and a 4-D grid: swap the last axis and add a
# dummy height dimension -> [N, 1, Q, 2].
coord_ = coord_.flip(dims=(-1,))[:, None]

In [None]:
# Nearest-neighbour lookup of the unfolded LR feature at each shifted query -> [N, C*9, 1, Q].
q_feat = F.grid_sample(
    featuf,
    coord_,
    mode='nearest',
    align_corners=False,
)

In [None]:
# Drop the dummy height axis and put queries before channels: [N, C*9, 1, Q] -> [N, Q, C*9].
q_feat = q_feat.squeeze(2).transpose(1, 2)

In [None]:
# Nearest-neighbour lookup of the LR grid coordinate (from feat_coord) each shifted query lands in
# -> [N, 2, 1, Q].
q_coord = F.grid_sample(
    feat_coord,
    coord_,
    mode='nearest',
    align_corners=False,
)

In [None]:
# Drop the dummy height axis and put queries first: [N, 2, 1, Q] -> [N, Q, 2].
q_coord = q_coord.squeeze(2).transpose(1, 2)

In [None]:
# Offset from the sampled LR grid point to the true query, rescaled to LR-cell units, then
# appended to the sampled feature.
lr_hw = coord.new_tensor([featuf.shape[-2], featuf.shape[-1]])
rel_coord = (coord - q_coord) * lr_hw  # [N, Q, 2]
inp = torch.cat((q_feat, rel_coord), dim=-1)  # [N, Q, C*9 + 2]

In [None]:
# Cell size rescaled to LR-cell units, appended as the last two input features.
# NOTE: re-running this cell appends to `inp` again.
rel_cell = cell * cell.new_tensor([feat.shape[-2], feat.shape[-1]])
inp = torch.cat((inp, rel_cell), dim=-1)  # [N, Q, C*9 + 2 + 2]

In [None]:
# Display the network that is applied below to every flattened query vector
# (model.imnet(inp.view(bs * q, -1))).
model.imnet

In [None]:
# Batch size (N) and number of queries (Q). imnet is fed a flat [N*Q, C*9+2+2] matrix
# and its output is reshaped to [N, Q, 3].
bs = coord.size(0)
q = coord.size(1)

In [None]:
# Evaluate the per-query network on every (feature, rel_coord, rel_cell) vector.
# no_grad: this walk-through never back-propagates, so don't keep autograd graphs alive in `preds`.
with torch.no_grad():
    pred = model.imnet(inp.view(bs * q, -1)).view(bs, q, -1)  # [N, Q, 3]
preds.append(pred)  # one [N, Q, 3] prediction per neighbour offset

In [None]:
# Area of the rectangle spanned by the relative offset; +1e-9 keeps the later total non-zero.
area = (rel_coord[..., 0] * rel_coord[..., 1]).abs()  # [N, Q]
areas.append(area + 1e-9)  # one [N, Q] area per neighbour offset

In [None]:
# Sum of all collected areas, used to normalise the blend weights -> [N, Q].
tot_area = torch.stack(areas, dim=0).sum(0)

In [None]:
# Turn the blended prediction `ret` ([1, Q, 3], produced by the ensemble cell further below —
# run that cell first) into a Q x 1 image tensor [3, Q, 1] for saving.
# view(-1, 1, 3) replaces the hard-coded 1000 so any number of query points works.
predimg = (ret * 0.5 + 0.5).clamp(0, 1).view(-1, 1, 3).permute(2, 0, 1).cpu()

In [None]:
# Write the Q x 1 strip of predicted colours to disk.
ToPILImage()(predimg).save("1000x1.jpg")

In [None]:
# Local-ensemble blend: with four neighbour terms, each prediction is weighted by the area of the
# diagonally opposite neighbour (0<->3, 1<->2, i.e. the list reversed), so the neighbour nearest
# to the query gets the largest weight. With fewer terms (the single-step cells above append just
# one), the areas are used as-is. `self.local_ensemble` raised NameError here: there is no `self`.
# The reversal builds a new list, so `areas` is left untouched and the cell is safe to re-run.
weights = areas[::-1] if len(areas) == 4 else areas
ret = 0
for pred, area in zip(preds, weights):
    ret = ret + pred * (area / tot_area).unsqueeze(-1)

In [None]:
# Nearest-neighbour lookup of the unfolded features directly at the (unshifted) query points
# -> [N, C*9, 1, Q].
ret = F.grid_sample(
    featuf,
    coord.flip(dims=(-1,))[:, None],
    mode='nearest',
    align_corners=False,
)

In [None]:
# Expected [N, C*9, 1, Q].
ret.size()

In [None]:
# Drop the dummy height axis and put queries first: [N, C*9, 1, Q] -> [N, Q, C*9].
ret = ret.squeeze(2).transpose(1, 2)

In [None]:
# Expected [N, Q, C*9].
ret.size()

In [None]:
# Reshape the per-query features [N, Q, C*9] back into an image stack [C*9, h, w].
# The channel count is read from the tensor instead of the hard-coded 576 (= C*9), and the query
# count is checked first: after the 1000-point truncation above, Q no longer equals h*w.
assert ret.shape[1] == h * w, (
    f"expected {h * w} query points for a {h}x{w} grid, got {ret.shape[1]}; "
    "recompute coord over the full grid before running this cell"
)
ret = ret.permute(0, 2, 1).view(ret.shape[0], ret.shape[2], h, w).squeeze(0)

In [None]:
# Expected [C*9, h, w].
ret.size()

In [None]:
# Save each unfolded-feature channel sampled on the query grid as a grayscale image.
# ToPILImage expects float images in [0, 1]; raw activations are not, and would wrap around when
# converted to 8-bit. Min-max normalise each channel first and make sure the folder exists.
os.makedirs("grid_sample", exist_ok=True)
for i in range(ret.shape[0]):
    channel = ret[i]
    cmin, cmax = channel.min(), channel.max()
    channel = (channel - cmin) / (cmax - cmin).clamp_min(1e-8)
    ToPILImage()(channel.cpu()).save(f"grid_sample/g_{i}.jpg")

In [None]:
# Full local ensemble: query the four LR cells around each point (vx, vy in {-1, 1}), then blend
# the four predictions with area-based weights.
preds = []
areas = []

# Loop invariants: the cell-size features and the batch/query sizes do not depend on (vx, vy).
rel_cell = cell.clone()
rel_cell[:, :, 0] *= feat.shape[-2]
rel_cell[:, :, 1] *= feat.shape[-1]
bs, q = coord.shape[:2]  # bs = N, q = number of queries

with torch.no_grad():  # inference only: don't keep autograd graphs alive in `preds`
    for vx in vx_lst:
        for vy in vy_lst:
            # Shift the queries towards the (vx, vy) neighbour, keeping them strictly inside (-1, 1).
            coord_ = coord.clone()  # [N, Q, 2]
            coord_[:, :, 0] += vx * rx + eps_shift
            coord_[:, :, 1] += vy * ry + eps_shift
            coord_.clamp_(-1 + 1e-6, 1 - 1e-6)
            # grid_sample takes (x, y) = (width, height) order and a 4-D grid: [N, 1, Q, 2].
            coord_ = coord_.flip(-1).unsqueeze(1)

            q_feat = F.grid_sample(featuf, coord_, mode='nearest', align_corners=False)  # [N, C*9, 1, Q]
            q_feat = q_feat[:, :, 0, :].permute(0, 2, 1)  # [N, Q, C*9]
            q_coord = F.grid_sample(feat_coord, coord_, mode='nearest', align_corners=False)  # [N, 2, 1, Q]
            q_coord = q_coord[:, :, 0, :].permute(0, 2, 1)  # [N, Q, 2]

            # Offset from the sampled LR grid point to the query, in LR-cell units.
            rel_coord = coord - q_coord  # [N, Q, 2]
            rel_coord[:, :, 0] *= feat.shape[-2]
            rel_coord[:, :, 1] *= feat.shape[-1]
            inp = torch.cat([q_feat, rel_coord, rel_cell], dim=-1)  # [N, Q, C*9 + 2 + 2]

            pred = model.imnet(inp.view(bs * q, -1)).view(bs, q, -1)  # [N, Q, 3]
            preds.append(pred)  # four [N, Q, 3] entries in total

            area = torch.abs(rel_coord[:, :, 0] * rel_coord[:, :, 1])  # [N, Q]
            areas.append(area + 1e-9)  # four [N, Q] entries in total

tot_area = torch.stack(areas).sum(dim=0)  # [N, Q]

# Weight each prediction by the area of the diagonally opposite neighbour (0<->3, 1<->2),
# so the neighbour nearest to the query gets the largest weight.
areas[0], areas[3] = areas[3], areas[0]
areas[1], areas[2] = areas[2], areas[1]
ret = 0
for pred, area in zip(preds, areas):
    ret = ret + pred * (area / tot_area).unsqueeze(-1)

In [2]:
# Paths for the ITCVD x2 experiment. Relative paths use os.path.join so they also resolve on
# non-Windows systems; the dataset root is an absolute local path — adjust it per machine.
itcvd_root = r'E:\Code\Python\datas\RS\ITCVD_patch'
model_path = os.path.join('save', 'ITCVD_drsenmkcax2', 'epoch-best.pth')
lr_path = os.path.join(itcvd_root, 'ITCVD_test_patchx2', '007_0_0x2.jpg')
hr_path = os.path.join(itcvd_root, 'ITCVD_test_patch', '007_0_0.jpg')
sr_path = os.path.join('testimg', '007_0_0x2x2.jpg')
scale = 2  # upscaling factor of this experiment (not read by the inference cell below)

In [3]:
from test_x import batched_predict
from models.losses import EdgeLoss

In [5]:
# Super-resolve the ITCVD LR patch with the trained model and save the result.
with Image.open(lr_path) as src:
    img = src.convert('RGB')  # force 3 channels; the converted copy outlives the closed file
timg = transforms.ToTensor()(img)  # [3, LR_H, LR_W] in [0, 1]
# map_location='cpu' lets the checkpoint load regardless of the device it was saved on.
model = models.make(torch.load(model_path, map_location='cpu')['model'], load_sd=True).cuda().eval()
bimg = ((timg - 0.5) / 0.5).cuda().unsqueeze(0)  # normalised to [-1, 1], [1, 3, LR_H, LR_W]
with torch.no_grad():
    # test_x.batched_predict is project code; [0] drops the batch dimension. The result is saved
    # directly as an image, so it is presumably [3, SR_H, SR_W] — TODO confirm.
    pred = batched_predict(model, bimg)[0]
pred = (pred * 0.5 + 0.5).clamp(0, 1).cpu()
transforms.ToPILImage()(pred).save(sr_path)

In [8]:
class EdgeConv(nn.Module):
    """Edge extractor based on one level of a Laplacian pyramid with a fixed 5x5 Gaussian.

    `laplacian_kernel(x)` (also available as `module(x)`) returns
    `x - blur(upsample(downsample(blur(x))))`, the high-frequency residual, with the same
    shape as `x` ([N, C, H, W]).

    Args:
        channels: number of image channels the depthwise kernel is built for (default 3).
    """

    def __init__(self, channels=3):
        super().__init__()
        k = torch.Tensor([[.05, .25, .4, .25, .05]])
        # Separable 5x5 Gaussian, one copy per channel for a grouped (depthwise) convolution:
        # [channels, 1, 5, 5]. Registered as a non-persistent buffer so `.cuda()` / `.to()` move
        # it with the module, without adding a new key to state_dict().
        self.register_buffer('kernel', torch.matmul(k.t(), k).unsqueeze(0).repeat(channels, 1, 1, 1),
                             persistent=False)
        if torch.cuda.is_available():
            # Kept for backward compatibility: the kernel used to be placed on the GPU eagerly.
            self.kernel = self.kernel.cuda()

    def conv_gauss(self, img):
        """Blur `img` [N, C, H, W] with the Gaussian; replicate padding keeps H and W unchanged."""
        n_channels, _, kh, kw = self.kernel.shape
        # F.pad takes (left, right, top, bottom) for the last two dims.
        img = F.pad(img, (kw // 2, kw // 2, kh // 2, kh // 2), mode='replicate')
        # Follow the input's device/dtype so CPU inputs work even when CUDA is available.
        kernel = self.kernel.to(device=img.device, dtype=img.dtype)
        return F.conv2d(img, kernel, groups=n_channels)

    def laplacian_kernel(self, current):
        """Return the Laplacian residual of `current` (same shape as the input)."""
        filtered = self.conv_gauss(current)     # blur
        down = filtered[:, :, ::2, ::2]         # downsample by 2
        new_filter = torch.zeros_like(filtered)
        new_filter[:, :, ::2, ::2] = down * 4   # zero-insertion upsample; x4 keeps the mean level
        filtered = self.conv_gauss(new_filter)  # blur to interpolate the inserted zeros
        return current - filtered

    def forward(self, x):
        """Same as `laplacian_kernel(x)`, so the module can be called directly."""
        return self.laplacian_kernel(x)

In [14]:
# Edge extractor instance, moved to the GPU.
edge = EdgeConv().to('cuda')

In [17]:
# Laplacian edge map of the SR prediction. The residual is signed, while ToPILImage expects
# floats in [0, 1]; negative values would wrap when converted to 8-bit. Save the magnitude,
# rescaled so the strongest edge maps to white.
edgemap = edge.laplacian_kernel(pred.unsqueeze(0).cuda())
edge_vis = edgemap[0].abs()
edge_vis = edge_vis / edge_vis.max().clamp_min(1e-8)
transforms.ToPILImage()(edge_vis.cpu()).save('edgemap.jpg')

In [20]:
# Same edge visualisation for a DIV2K x2 LR image.
with Image.open(r'load\div2k\DIV2K_valid_LR_bicubic\X2\0802x2.png') as src:
    img = src.convert('RGB')  # force 3 channels to match the 3-channel edge kernel
timg = transforms.ToTensor()(img)  # [3, LR_H, LR_W]
edgemap = edge.laplacian_kernel(timg.unsqueeze(0).cuda())
# Signed residual -> normalised magnitude, since ToPILImage expects floats in [0, 1].
edge_vis = edgemap[0].abs()
edge_vis = edge_vis / edge_vis.max().clamp_min(1e-8)
transforms.ToPILImage()(edge_vis.cpu()).save('0802x2edge.jpg')