In [1]:
import fastai
from fastai import *          # Quick access to most common functionality
from fastai.vision import *   # Quick access to computer vision functionality
import PIL
import pytorch_ssim as ssim


In [2]:
# Root of the image-restoration dataset (machine-specific absolute path).
path = Path('/DATA/WAMRI/SALK/uri/Image_restoration_data/')
# Low-/high-resolution train and test folders, all directly under `path`.
train_lr, train_hr, test_lr, test_hr = (
    path/sub for sub in ('train_LR', 'train_HR', 'test_LR', 'test_HR'))


In [3]:
def pull_id(fn):
    """Extract the sample id from a file name.

    The id is the text after the last '#' and before the first '.' that
    follows it, e.g. 'stack#0042.tif' -> '0042'. A name without '#' is
    used whole up to its first '.'.
    """
    tail = fn.rsplit('#', 1)[-1]
    return tail.partition('.')[0]


# Sort both listings: Path.glob yields files in filesystem order, which is not
# guaranteed to be stable, so sorting keeps the LR list (and anything derived
# from its order, such as the train/valid split) deterministic across machines.
lr_names_full = sorted(train_lr.glob('*.tif'))
# Map sample id -> HR file path; LR files are paired with HR files by this id.
# NOTE(review): if two HR files share an id, the later one (in sorted order)
# silently wins.
hr_names_by_id = {pull_id(hrfn.name): hrfn for hrfn in sorted(train_hr.glob('*.tif'))}

In [4]:
def open_grayscale(fn):
    """Open the image at `fn` and return its first channel as a fastai `Image`.

    Pixel values are divided by 255, i.e. the file is assumed to hold 8-bit
    data (NOTE(review): a 16-bit TIFF would not land in [0, 1] — confirm).
    """
    # Use a context manager so the file handle is released as soon as the
    # pixels have been copied into a tensor; PIL otherwise keeps it open.
    with PIL.Image.open(fn) as img:
        t = pil2tensor(img, np.float32)
    # Scale, then keep only the first slice along dim 0 (the channel axis,
    # assuming pil2tensor returns C x H x W), preserving a leading dim of 1.
    return Image(t.div_(255)[0:1])

# Fraction of the LR/HR pairs held out for validation.
valid_pct = 0.10
# Build the paired item list: inputs are the LR tiffs, each labelled with the
# HR tiff whose id (see pull_id) matches; then split randomly into train/valid.
# NOTE(review): the recorded output below is a NameError — `ImageFileList` is
# not provided by the installed fastai version, so this cell (and everything
# that uses `src`) does not run as written; port to the current data-block API.
# NOTE(review): an LR file with no HR counterpart raises KeyError in the lambda.
# NOTE(review): no seed is passed to the random split, so train/valid may
# differ between runs.
src = (ImageFileList
      .from_df(pd.DataFrame(lr_names_full), 0)
      .label_from_func(lambda x: hr_names_by_id[pull_id(x.name)])
      .random_split_by_pct(valid_pct))

NameError: name 'ImageFileList' is not defined

In [5]:
class GrayImageToImageDataset(ImageToImageDataset):
    """Image-to-image dataset whose images are opened with `open_grayscale`.

    NOTE(review): the recorded output shows `ImageToImageDataset` is not
    defined in the installed fastai version, so this class cannot be created
    as written.
    """
    def __init__(self, x:FilePathList=None, y:FilePathList=None, **kwargs):
        # Avoid shared mutable default lists: build a fresh list per call.
        x = [] if x is None else x
        y = [] if y is None else y
        super().__init__(x=x,y=y,**kwargs)
        # Replace the opener after base init; presumably the base class uses
        # `image_opener` when loading items — confirm.
        self.image_opener = open_grayscale

def get_sr_transforms():
    """Return the (train, valid) transform lists for super-resolution.

    Training gets a single `dihedral_affine(p=0.5)` transform; validation
    gets none.
    """
    return ([dihedral_affine(p=0.5)], [])


def get_data(src, bs, sz_lr, scale=4, tfms=None, norm_stats=None, **kwargs):
    """Build a normalized LR -> HR DataBunch from the labelled item lists `src`.

    Args:
        src: split + labelled item lists (result of the data-block chain).
        bs: batch size.
        sz_lr: size of the low-resolution input; targets use sz_lr * scale.
        scale: upscaling factor between input and target.
        tfms: (train, valid) transform lists; defaults to `get_transforms()`.
        norm_stats: (mean, std) passed to `normalize` for both x and y;
            defaults to the dataset statistics ([0.10], [0.20]).
        **kwargs: forwarded to `.databunch()` (e.g. shuffle=False).
    Returns:
        The normalized DataBunch.
    """
    sz_hr = sz_lr*scale
    if norm_stats is None: norm_stats = ([0.10], [0.20])
    if tfms is None: tfms = get_transforms()
    data = (src.datasets(GrayImageToImageDataset)
            # tfm_y=True applies the transforms to the target as well; the
            # target is sized to sz_hr via y_kwargs.
            .transform(tfms, y_kwargs={'size': sz_hr}, size=sz_lr, tfm_y=True)
            .databunch(bs=bs, **kwargs)
            # Inputs AND targets are normalized with the same stats, so model
            # outputs presumably live in normalized space and need de-normalizing.
            .normalize(norm_stats, tfm_y=True)
           )
    return data

NameError: name 'ImageToImageDataset' is not defined

In [None]:
def conv(ni, nf, kernel_size=3, actn=True):
    """Conv2d(ni -> nf) padded by kernel_size // 2, optionally followed by in-place ReLU.

    For odd kernel sizes the padding keeps the spatial size unchanged.
    """
    pad = kernel_size // 2
    modules = [nn.Conv2d(ni, nf, kernel_size, padding=pad)]
    modules += [nn.ReLU(True)] if actn else []
    return nn.Sequential(*modules)

In [None]:
class ResSequential(nn.Module):
    """Residual wrapper computing ``x + body(x) * res_scale``.

    `layers` are chained into an ``nn.Sequential`` stored as ``self.m``; that
    attribute name appears in state_dict keys, so saved checkpoints rely on it.
    """
    def __init__(self, layers, res_scale=1.0):
        super().__init__()
        body = nn.Sequential(*layers)
        self.res_scale, self.m = res_scale, body

    def forward(self, x):
        branch = self.m(x)
        return x + branch * self.res_scale

In [None]:
def res_block(nf):
    """Residual block of conv+ReLU then conv (no activation), branch scaled by 0.1."""
    body = [conv(nf, nf), conv(nf, nf, actn=False)]
    return ResSequential(body, res_scale=0.1)

In [None]:
def upsample(ni, nf, scale):
    """Sub-pixel upsampling by `scale`, built from x2 stages.

    Each stage is a conv producing nf*4 channels (+ReLU) followed by
    PixelShuffle(2), which turns those 4*nf channels into nf channels at
    twice the spatial size.

    Args:
        ni: channels of the incoming feature map.
        nf: channels produced by every stage.
        scale: total upscaling factor; must be a power of two (1, 2, 4, 8, ...).
    Returns:
        nn.Sequential of the stages (empty for scale == 1).
    Raises:
        ValueError: if `scale` is not a positive power of two.
    """
    if scale <= 0:
        raise ValueError(f'upsample scale must be a power of two, got {scale}')
    # log2 is exact for powers of two, unlike math.log(scale, 2).
    n_stages = round(math.log2(scale))
    if n_stages < 0 or 2 ** n_stages != scale:
        # Previously e.g. scale=3 silently produced a x2 upsampler.
        raise ValueError(f'upsample scale must be a power of two, got {scale}')
    layers = []
    for _ in range(n_stages):
        layers += [conv(ni, nf*4), nn.PixelShuffle(2)]
        # After PixelShuffle(2) the map has nf channels, so the next stage must
        # take nf inputs (reusing ni only worked when ni == nf).
        ni = nf
    return nn.Sequential(*layers)

In [None]:
class SrResnet(nn.Module):
    """Super-resolution ResNet.

    All layers live in ``self.features`` (an nn.Sequential whose indices form
    the state_dict keys of saved checkpoints, so the order must not change):
    a head conv (n_colors -> n_feats), `n_res` residual blocks, a conv, a
    `scale`x upsampler, BatchNorm2d, and a final conv back to `n_colors`
    channels without activation.
    """
    def __init__(self, n_feats, n_res, n_colors, scale):
        super().__init__()
        head = [conv(n_colors, n_feats)]
        body = [res_block(n_feats) for _ in range(n_res)]
        tail = [conv(n_feats, n_feats),
                upsample(n_feats, n_feats, scale),
                nn.BatchNorm2d(n_feats),
                conv(n_feats, n_colors, actn=False)]
        self.features = nn.Sequential(*head, *body, *tail)

    def forward(self, x):
        return self.features(x)

In [None]:
def psnr(pred, targs, max_val=1.):
    """Peak signal-to-noise ratio in dB: 20 * log10(max_val / RMSE).

    Args:
        pred, targs: tensors of the same shape.
        max_val: peak possible pixel value; the default 1.0 matches data in
            [0, 1]. NOTE(review): on normalized tensors (e.g. mean 0.10 /
            std 0.20) the true peak differs, so pass it explicitly there.
    Returns:
        0-dim tensor; +inf when pred == targs (zero MSE).
    """
    mse = F.mse_loss(pred, targs)
    return 20 * torch.log10(max_val / torch.sqrt(mse))

def psnr_loss(pred, targs):
    """Negated PSNR (peak value 1), so that minimizing it maximizes PSNR."""
    return -20 * torch.log10(1. / torch.sqrt(F.mse_loss(pred, targs)))

# SSIM loss modules built with mult=-1. (presumably negating the score) at
# three window sizes: the library default, a 3-px window and a 32-px window.
# NOTE(review): `mult` is not a parameter of upstream pytorch_ssim.SSIM —
# presumably a local fork of the package; confirm.
ssim_loss, ssim_loss_2, ssim_loss_3 = (
    ssim.SSIM(mult=-1.),
    ssim.SSIM(window_size=3, mult=-1.),
    ssim.SSIM(window_size=32, mult=-1.),
)

def combo_loss(pred, targs):
    """Training loss: 3 + three SSIM-loss terms (default/3/32 windows) + MSE.

    The `ssim_loss*` modules are built with mult=-1., so adding the constant 3
    presumably turns the SSIM part into sum(1 - SSIM), i.e. ~0 for a perfect
    match.
    """
    total = 3 + ssim_loss(pred, targs)
    total = total + ssim_loss_2(pred, targs)
    total = total + ssim_loss_3(pred, targs)
    total = total + F.mse_loss(pred, targs)
    return total
    
# Metrics reported each epoch: MSE, SSIM at three window sizes (default, 3, 32)
# and PSNR.
# NOTE(review): they are computed on the tensors the model sees, which get_data
# normalizes (tfm_y=True), while psnr assumes a peak value of 1.0 — treat the
# PSNR values as relative rather than absolute.
metrics = [F.mse_loss, 
           ssim.ssim,
           partial(ssim.ssim, window_size=3),
           partial(ssim.ssim, window_size=32),
           psnr]

In [None]:
# Stage 1: 32-px LR inputs (32 * 4 = 128-px HR targets), batch size 16.
bs, lr_sz, scale = 16, 32, 4
data = get_data(src, bs, lr_sz, scale=scale)

In [None]:
# Grab one validation batch to sanity-check shapes: targets should be `scale`
# times larger spatially than inputs. (Later cells reuse this `x`.)
x,y = next(iter(data.valid_dl))
x.shape, y.shape
#data.show_batch(2)

In [None]:
# Model hyper-parameters: feature width, number of residual blocks, and image
# channels (1, since open_grayscale keeps a single channel).
n_feats, n_res, n_color = 8, 64, 1
# NOTE(review): an old note read "64,4,1", suggesting a different
# width/depth was tried before — confirm which setting is intended.
# Wrap in DataParallel and move to CUDA (requires a GPU).
model = nn.DataParallel(SrResnet(n_feats, n_res, n_color, scale)).cuda()

In [None]:
# Stage-1 learner: the SR model trained with the combined SSIM + MSE loss.
learn = ImageLearner(data, model, loss_func=combo_loss, metrics=metrics)

In [None]:
# Learning-rate range test; inspect the plot to choose `lr`.
learn.lr_find()
learn.recorder.plot()

In [None]:
# NOTE(review): builds a new learner around the same `model` object
# (presumably keeping its current weights); likely redundant after lr_find.
learn = ImageLearner(data, model, loss_func=combo_loss, metrics=metrics)
lr = 1e-2

In [None]:
# Stage-1 training: one-cycle schedule, 10 epochs.
learn.fit_one_cycle(10, lr)

In [None]:
# Checkpoint after stage 1.
learn.save('enhance1')

In [None]:
# Restore the stage-1 checkpoint, then fine-tune 5 epochs at lr/10.
learn.load('enhance1')
learn.fit_one_cycle(5, lr/10)

In [None]:
learn.recorder.plot_losses()

In [None]:
# Checkpoint after the fine-tuning pass; stage 2 starts from it.
learn.save('enhance1.1')

In [None]:
# Stage 2: 64-px LR inputs (256-px targets), same batch size, same model.
bs = 16
lr_sz = 32*2
scale=4
data = get_data(src, bs, lr_sz, scale=scale)
learn = ImageLearner(data, model, loss_func=combo_loss, metrics=metrics)
# Load in place instead of `learn = learn.load(...)`: depending on the fastai
# version Learner.load may return None, which would replace `learn` with None.
learn.load('enhance1.1')

In [None]:
# Stage-2 training at a 10x lower lr than stage 1 used.
lr = 1e-3
learn.fit_one_cycle(10, lr)

In [None]:
# Checkpoint after stage 2.
learn.save('enhance2')

In [None]:
learn.recorder.plot_losses()

In [None]:
# Stage 3: 128-px LR inputs (512-px targets); batch size halved to 8.
bs = 8
lr_sz = 32*2*2
scale=4
data = get_data(src, bs, lr_sz, scale=scale)
learn = ImageLearner(data, model, loss_func=combo_loss, metrics=metrics)
# Load in place: depending on the fastai version Learner.load may return None,
# so `learn = learn.load(...)` could replace the learner with None.
learn.load('enhance2')
# NOTE(review): lr goes back up to 1e-2 although stage 2 used 1e-3 — confirm.
lr = 1e-2
learn.fit_one_cycle(10, lr)

In [None]:
# Checkpoint after stage 3; the test-set evaluation below loads it.
learn.save('enhance3')

In [None]:
# Evaluation on the validation split with 503-px LR inputs (2012-px targets).
bs = 8
lr_sz = 503
scale=4
data = get_data(src, bs, lr_sz, scale=scale)
learn = ImageLearner(data, model, loss_func=combo_loss, metrics=metrics)
# Load in place: depending on the fastai version Learner.load may return None,
# so `learn = learn.load(...)` could replace the learner with None.
# NOTE(review): this loads the stage-1 checkpoint 'enhance1', not the final
# 'enhance3' — confirm which model is meant to be evaluated here.
learn.load('enhance1')

idx = 0
# Bilinear x4 upsampler, intended as a baseline to compare against the model.
m_bilin = partial(nn.functional.interpolate, scale_factor=4, mode='bilinear', align_corners=True)
# Predictions and targets for the whole validation set (presumably still in
# the normalized space produced by get_data).
y_preds, ys = learn.get_preds(DatasetType.Valid)
y_pred = y_preds[idx]
y = ys[idx]

# NOTE(review): the lines below are the only place `pred` / `pred_bilin`
# (used by later cells) are ever assigned, and they are commented out.
#x,y = dl.dataset[idx]
#xn, yn = learn.data.norm((x.data,y.data))
#pred = dl.reconstruct_output(preds[idx], x)
#if learn.data.denorm and learn.data.tfm_y and isinstance(pred, Image):
#    pred = Image(learn.data.denorm(pred.data[0]))
#pred_bilin = Image(learn.data.denorm(m_bilin(xn[None]))[0])


In [None]:
# De-normalize the prediction: xn * std + mean with the get_data stats.
y_denorm = y_pred.mul(0.20) + tensor(0.10)

In [None]:
Image(y_denorm)

In [None]:
# NOTE(review): on a top-to-bottom run `x` is still the validation batch from
# the shape-check cell.
x.data.mean()

In [None]:
#Image(pred_bilin.data[0][None])

In [None]:
# FIXME: `pred_bilin` is never assigned in live code (only in commented-out
# lines), so this cell raises NameError on a fresh kernel.
imgs = pred_bilin.data[None], y.data[None]
ssim.ssim(*imgs),psnr(*imgs), F.mse_loss(*imgs)

In [None]:
# FIXME: `pred` is likewise never assigned in live code.
imgs = pred.data[None], y.data[None]
ssim.ssim(*imgs),psnr(*imgs), F.mse_loss(*imgs)

In [None]:
# FIXME: `hr_name` is only defined further down; this cell fails when the
# notebook is run top to bottom.
hr_name

In [None]:
# Test images: LR and HR file lists, sorted so zip() pairs them by name order.
# FIXME(review): both lists glob the same 'newman' directory, so every "HR"
# label is the LR file itself — presumably a placeholder (no ground truth)
# or the wrong folder; confirm.
test_lr_fns = list((path/'newman').glob('*.tif'))
test_hr_fns = list((path/'newman').glob('*.tif'))
test_lr_fns.sort()
test_hr_fns.sort()
bs = 3
lr_sz = 512
scale=4
test_lr_fns

In [None]:
# Map each LR file name to its paired HR path (pairing by sorted order).
hr_name = { lr.name:hr for lr,hr in zip(test_lr_fns, test_hr_fns)}

# The valid predicate is always False, so every test item lands in the train
# split and can be iterated via train_dl.
# NOTE(review): `ImageFileList` raised NameError earlier in this notebook.
test_src = (ImageFileList
            .from_df(pd.DataFrame(test_lr_fns), 0)
            .label_from_func(lambda x: hr_name[x.name])
            .split_by_valid_func(lambda x: False))

# shuffle=False keeps batches in file-list order.
test_data = get_data(test_src, bs, lr_sz, scale=scale, tfms=[[crop_pad()],[crop_pad()]], shuffle=False)
test_learn = ImageLearner(test_data, model, loss_func=combo_loss, metrics=metrics)
# Load in place: depending on the fastai version Learner.load may return None,
# so `test_learn = test_learn.load(...)` could replace the learner with None.
test_learn.load('enhance3')

In [None]:
def my_denorm(xn, mean=0.10, std=0.20):
    """Invert the dataset normalization: x = xn * std + mean.

    The defaults match the stats get_data passes to `normalize`.

    Args:
        xn: normalized tensor.
        mean: normalization mean.
        std: normalization standard deviation.
    Returns:
        Tensor in the un-normalized pixel scale.
    """
    # Normalization computes (x - mean) / std, so the inverse must ADD the
    # mean back (the previous version subtracted it).
    return xn * tensor(std) + tensor(mean)

# First batch of the test set (test_data was built with shuffle=False).
x,y = next(iter(test_learn.data.train_dl))

In [None]:
# De-normalized copies of the inputs and targets for viewing/saving.
xdn = my_denorm(x).detach()
ydn = my_denorm(y).detach()
# Run inference in eval mode (BatchNorm uses its running stats) and without
# autograd; otherwise the mode is whatever the last fastai call left it in.
test_learn.model.eval()
with torch.no_grad():
    y_pred = test_learn.model(x)
y_pred_dn = my_denorm(y_pred).detach()

In [None]:
# Compare with fastai's own de-normalization of the same batch.
test_learn.data.denorm(x)

In [None]:
# Save input / target / prediction tiffs for every image in the batch as
# gen_tiffs/{n}_x.tif, {n}_y.tif and {n}_sr.tif (n is 1-based).
# All three are wrapped identically. Previously x and y were multiplied by 255
# but the prediction was not, so the files used inconsistent intensity scales.
# NOTE(review): fastai's Image.save presumably rescales by 255 itself — confirm
# against the installed version.
Path('gen_tiffs').mkdir(parents=True, exist_ok=True)  # save fails without it
pred_imgs = []
for i in range(y_pred_dn.shape[0]):
    x_img = Image(xdn[i])
    y_img = Image(ydn[i])
    y_pred_img = Image(y_pred_dn[i])
    pred_imgs.append((x_img, y_img, y_pred_img))
    x_img.save(f'gen_tiffs/{i+1}_x.tif')
    y_img.save(f'gen_tiffs/{i+1}_y.tif')
    y_pred_img.save(f'gen_tiffs/{i+1}_sr.tif')

In [None]:
# Mean intensity of the last exported input image (scale sanity check).
x_img.data.mean()

In [None]:
# Display the first prediction.
pred_imgs[0][2]

In [None]:
# FIXME: `test_dl` and `test_preds` are not defined anywhere in this notebook
# (presumably from a deleted cell), so this cell raises NameError on a fresh
# kernel. It also rebinds `x`, `y` to a single train_ds item.
idx = 2
x,y = test_learn.data.train_ds[idx]
xn, yn = test_learn.data.norm((x.data,y.data))
test_pred = test_dl.reconstruct_output(test_preds[idx], x)
if test_learn.data.denorm and test_learn.data.tfm_y and isinstance(test_pred, Image):
    test_pred = Image(test_learn.data.denorm(test_pred.data[0]))
m_bilin = partial(nn.functional.interpolate, scale_factor=4, mode='bilinear', align_corners=True)
test_pred_bilin = Image(test_learn.data.denorm(m_bilin(xn[None]))[0])

In [None]:
y

In [None]:
Image(test_pred.data[0][None])

In [None]:
x

In [None]:
Image(test_pred_bilin.data[0][None])

In [None]:
# SSIM, PSNR and MSE of the model prediction against the target.
imgs = test_pred.data[None], y.data[None]
ssim.ssim(*imgs),psnr(*imgs), F.mse_loss(*imgs)

In [None]:
# Same metrics for the bilinear baseline.
imgs = test_pred_bilin.data[None], y.data[None]
ssim.ssim(*imgs),psnr(*imgs), F.mse_loss(*imgs)

In [None]:
# Save LR input, HR target, bilinear baseline and model output as tiffs.
# FIXME: `pred_bilin` and `pred` are never assigned in live code — presumably
# `test_pred_bilin` / `test_pred` were meant; confirm.
x.save('lr_orig.tif')
y.save('hr_orig.tif')
Image(pred_bilin.data[0][None]).save('bilin.tif')
Image(pred.data[0][None]).save('resnet.tif')

In [None]:
# HR size for a 503-px LR input at scale 4.
503*4