<a href="https://colab.research.google.com/github/Dekelv/ESPCN-pytorch/blob/develop/colab_espcn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import math
import pandas as pd
import torch
from torch import nn

'GPU connected' if torch.cuda.is_available() else 'GPU connection failed'

'GPU connected'

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [0]:
# Based on the explanation given in the paper,
# Number of layers (l) = 3
# kernel, i/p (f_i, n_i) where (5,64) -> (3, 32) -> 3
# GELU paper https://arxiv.org/pdf/1606.08415v3.pdf
# GELU incorporates regularisation (dropout) inherently. It demonstrates improvements in Computer Vision tasks.

class ESPCN(nn.Module):
    def __init__(self, scale_factor, num_channels=1):
        super(ESPCN, self).__init__()
        self.first_part = nn.Sequential(
            nn.Conv2d(num_channels, 64, kernel_size=5, padding=5//2),
            # used GeLU as activation function PSNR -> 32.99
            nn.GELU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=3//2),
            nn.GELU(),
        )
        # pixel shuffle is basically up-sampling the data from LR -> HR
        self.last_part = nn.Sequential(
            nn.Conv2d(32, num_channels * (scale_factor ** 2), kernel_size=3, padding=3 // 2),
            nn.PixelShuffle(scale_factor)
        )

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m.in_channels == 32:
                    nn.init.normal_(m.weight.data, mean=0.0, std=0.001)
                    nn.init.zeros_(m.bias.data)
                else:
                    nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))
                    nn.init.zeros_(m.bias.data)

    def forward(self, x):
        x = self.first_part(x)
        x = self.last_part(x)
        return x




In [0]:
import h5py
import numpy as np
from torch.utils.data import Dataset


class TrainDataset(Dataset):
    def __init__(self, h5_file):
        super(TrainDataset, self).__init__()
        self.h5_file = h5_file

    # returns the data points with 'lr' as inputs and 'hr' as labels
    def __getitem__(self, idx):
        with h5py.File(self.h5_file, 'r') as f:
            return np.expand_dims(f['lr'][idx] / 255., 0), np.expand_dims(f['hr'][idx] / 255., 0)

    def __len__(self):
        with h5py.File(self.h5_file, 'r') as f:
            return len(f['lr'])


class EvalDataset(Dataset):
    def __init__(self, h5_file):
        super(EvalDataset, self).__init__()
        self.h5_file = h5_file

    def __getitem__(self, idx):
        with h5py.File(self.h5_file, 'r') as f:
            return np.expand_dims(f['lr'][str(idx)][:, :] / 255., 0), np.expand_dims(f['hr'][str(idx)][:, :] / 255., 0)

    def __len__(self):
        with h5py.File(self.h5_file, 'r') as f:
            return len(f['lr'])


In [0]:
import torch
import numpy as np


def calc_patch_size(func):
    def wrapper(args):
        if args.scale == 2:
            args.patch_size = 10
        elif args.scale == 3:
            args.patch_size = 7
        elif args.scale == 4:
            args.patch_size = 6
        else:
            raise Exception('Scale Error', args.scale)
        return func(args)
    return wrapper


def convert_rgb_to_y(img, dim_order='hwc'):
    if dim_order == 'hwc':
        return 16. + (64.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.
    else:
        return 16. + (64.738 * img[0] + 129.057 * img[1] + 25.064 * img[2]) / 256.


def convert_rgb_to_ycbcr(img, dim_order='hwc'):
    if dim_order == 'hwc':
        y = 16. + (64.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.
        cb = 128. + (-37.945 * img[..., 0] - 74.494 * img[..., 1] + 112.439 * img[..., 2]) / 256.
        cr = 128. + (112.439 * img[..., 0] - 94.154 * img[..., 1] - 18.285 * img[..., 2]) / 256.
    else:
        y = 16. + (64.738 * img[0] + 129.057 * img[1] + 25.064 * img[2]) / 256.
        cb = 128. + (-37.945 * img[0] - 74.494 * img[1] + 112.439 * img[2]) / 256.
        cr = 128. + (112.439 * img[0] - 94.154 * img[1] - 18.285 * img[2]) / 256.
    return np.array([y, cb, cr]).transpose([1, 2, 0])


def convert_ycbcr_to_rgb(img, dim_order='hwc'):
    if dim_order == 'hwc':
        r = 298.082 * img[..., 0] / 256. + 408.583 * img[..., 2] / 256. - 222.921
        g = 298.082 * img[..., 0] / 256. - 100.291 * img[..., 1] / 256. - 208.120 * img[..., 2] / 256. + 135.576
        b = 298.082 * img[..., 0] / 256. + 516.412 * img[..., 1] / 256. - 276.836
    else:
        r = 298.082 * img[0] / 256. + 408.583 * img[2] / 256. - 222.921
        g = 298.082 * img[0] / 256. - 100.291 * img[1] / 256. - 208.120 * img[2] / 256. + 135.576
        b = 298.082 * img[0] / 256. + 516.412 * img[1] / 256. - 276.836
    return np.array([r, g, b]).transpose([1, 2, 0])


def preprocess(img, device):
    img = np.array(img).astype(np.float32)
    ycbcr = convert_rgb_to_ycbcr(img)
    x = ycbcr[..., 0]
    x /= 255.
    x = torch.from_numpy(x).to(device)
    x = x.unsqueeze(0).unsqueeze(0)
    return x, ycbcr


def calc_psnr(img1, img2):
    return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))


class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


In [0]:
import argparse
import os
import copy

import torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

# from models import ESPCN
# from datasets import TrainDataset, EvalDataset
# from utils import AverageMeter, calc_psnr


# if __name__ == '__main__':
    # parser = argparse.ArgumentParser()
    # parser.add_argument('--train-file', type=str, required=True)
    # parser.add_argument('--eval-file', type=str, required=True)
    # parser.add_argument('--outputs-dir', type=str, required=True)
    # parser.add_argument('--weights-file', type=str)
    # parser.add_argument('--scale', type=int, default=3)
    # parser.add_argument('--lr', type=float, default=1e-3)
    # parser.add_argument('--batch-size', type=int, default=16)
    # parser.add_argument('--num-epochs', type=int, default=200)
    # parser.add_argument('--num-workers', type=int, default=8)
    # parser.add_argument('--seed', type=int, default=123)
    # args = parser.parse_args()
def train(args):
  # this is where the x3 comes from
    # args.outputs_dir = os.path.join(args.outputs_dir, 'x{}'.format(args.scale))

# check for drive path
    # if not os.path.exists(args.outputs_dir):
    #     os.makedirs(args.outputs_dir)

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    torch.manual_seed(args.seed)
    # Init ESPCN model from models.py
    model = ESPCN(scale_factor=args.scale).to(device)
    criterion = nn.MSELoss()
    # lr is learning rate... where do they get this calculation for the last part from?
    optimizer = optim.Adam([
        {'params': model.first_part.parameters()},
        {'params': model.last_part.parameters(), 'lr': args.lr * 0.1}
    ], lr=args.lr)

    # Brings the training data and does some processing
    train_dataset = TrainDataset(args.train_file)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    eval_dataset = EvalDataset(args.eval_file)
    eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

    best_weights = copy.deepcopy(model.state_dict())
    best_epoch = 0
    best_psnr = 0.0

    for epoch in range(args.num_epochs):
        # A time decaying learning rate
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr * (0.1 ** (epoch // int(args.num_epochs * 0.8)))

        # Train mode
        model.train()
        epoch_losses = AverageMeter()

        with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size), ncols=80) as t:
            t.set_description('epoch: {}/{}'.format(epoch, args.num_epochs - 1))

            for data in train_dataloader:
                inputs, labels = data

                inputs = inputs.to(device)
                labels = labels.to(device)

                # Core Model
                preds = model(inputs)
                loss = criterion(preds, labels)
                # this must update the values with a user defined function
                epoch_losses.update(loss.item(), len(inputs))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # edit the print string for tqdm
                t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
                t.update(len(inputs))

        torch.save(model.state_dict(), os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch)))

        model.eval()
        epoch_psnr = AverageMeter()

        for data in eval_dataloader:
            inputs, labels = data

            inputs = inputs.to(device)
            labels = labels.to(device)

            with torch.no_grad():
                preds = model(inputs).clamp(0.0, 1.0)

            epoch_psnr.update(calc_psnr(preds, labels), len(inputs))

        print('eval psnr: {:.2f}'.format(epoch_psnr.avg))

        if epoch_psnr.avg > best_psnr:
            best_epoch = epoch
            best_psnr = epoch_psnr.avg
            best_weights = copy.deepcopy(model.state_dict())

    print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
    torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))


In [10]:
# Train the model
from typing import NamedTuple

class Train_args(NamedTuple):
    train_file: str
    eval_file: str
    outputs_dir: str
    weights_file: str
    scale: int
    lr: float
    batch_size: int
    num_epochs: int
    num_workers: int
    seed: int

train_args = Train_args("./drive/My Drive/espcn/91-image_x3.h5", "./drive/My Drive/espcn/Set5_x3.h5","./drive/My Drive/espcn/outputs",'',3,1e-3, 16, 200, 8, 123)
# train_args.batch_size
train(train_args)

epoch: 0/199: : 2701it [00:14, 191.10it/s, loss=0.024260]                       
epoch: 1/199:   0%|                                    | 0/2688 [00:00<?, ?it/s]

eval psnr: 25.67


epoch: 1/199: : 2701it [00:04, 668.46it/s, loss=0.004272]                       
epoch: 2/199:   0%|                                    | 0/2688 [00:00<?, ?it/s]

eval psnr: 27.58


epoch: 2/199: : 2701it [00:04, 672.71it/s, loss=0.003181]                       
epoch: 3/199:   0%|                                    | 0/2688 [00:00<?, ?it/s]

eval psnr: 28.76


epoch: 3/199: : 2701it [00:04, 656.91it/s, loss=0.002751]                       
epoch: 4/199:   0%|                                    | 0/2688 [00:00<?, ?it/s]

eval psnr: 29.26


epoch: 4/199: : 2701it [00:04, 667.51it/s, loss=0.002468]                       
epoch: 5/199:   0%|                                    | 0/2688 [00:00<?, ?it/s]

eval psnr: 29.80


epoch: 5/199: : 2701it [00:04, 668.16it/s, loss=0.002246]                       
epoch: 6/199:   0%|                                    | 0/2688 [00:00<?, ?it/s]

eval psnr: 30.25


epoch: 6/199: : 2701it [00:04, 667.34it/s, loss=0.002062]                       
epoch: 7/199:   0%|                                    | 0/2688 [00:00<?, ?it/s]

eval psnr: 30.54


epoch: 7/199: : 2701it [00:04, 672.29it/s, loss=0.001930]                       
epoch: 8/199:   0%|                                    | 0/2688 [00:00<?, ?it/s]

eval psnr: 30.31


epoch: 8/199: : 2701it [00:03, 676.60it/s, loss=0.001852]                       
epoch: 9/199:   0%|                                    | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.08


epoch: 9/199: : 2701it [00:04, 669.34it/s, loss=0.001761]                       
epoch: 10/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.17


epoch: 10/199: : 2701it [00:04, 663.16it/s, loss=0.001686]                      
epoch: 11/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.34


epoch: 11/199: : 2701it [00:04, 662.71it/s, loss=0.001639]                      
epoch: 12/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.40


epoch: 12/199: : 2701it [00:04, 673.77it/s, loss=0.001617]                      
epoch: 13/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.46


epoch: 13/199: : 2701it [00:04, 660.74it/s, loss=0.001651]                      
epoch: 14/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.59


epoch: 14/199: : 2701it [00:04, 657.59it/s, loss=0.001544]                      
epoch: 15/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.64


epoch: 15/199: : 2701it [00:04, 653.32it/s, loss=0.001527]                      
epoch: 16/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.62


epoch: 16/199: : 2701it [00:04, 660.06it/s, loss=0.001516]                      
epoch: 17/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.71


epoch: 17/199: : 2701it [00:04, 662.74it/s, loss=0.001554]                      
epoch: 18/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.76


epoch: 18/199: : 2701it [00:04, 661.65it/s, loss=0.001488]                      
epoch: 19/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.79


epoch: 19/199: : 2701it [00:04, 664.71it/s, loss=0.001530]                      
epoch: 20/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.82


epoch: 20/199: : 2701it [00:04, 659.85it/s, loss=0.001472]                      
epoch: 21/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.81


epoch: 21/199: : 2701it [00:04, 666.59it/s, loss=0.001468]                      
epoch: 22/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.84


epoch: 22/199: : 2701it [00:04, 660.91it/s, loss=0.001466]                      
epoch: 23/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.71


epoch: 23/199: : 2701it [00:04, 660.95it/s, loss=0.001479]                      
epoch: 24/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.83


epoch: 24/199: : 2701it [00:04, 663.67it/s, loss=0.001455]                      
epoch: 25/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.82


epoch: 25/199: : 2701it [00:04, 669.71it/s, loss=0.001455]                      
epoch: 26/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.79


epoch: 26/199: : 2701it [00:04, 662.17it/s, loss=0.001476]
epoch: 27/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.88


epoch: 27/199: : 2701it [00:04, 647.06it/s, loss=0.001447]                      
epoch: 28/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.86


epoch: 28/199: : 2701it [00:04, 643.66it/s, loss=0.001450]                      
epoch: 29/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.89


epoch: 29/199: : 2701it [00:04, 652.43it/s, loss=0.001438]                      
epoch: 30/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.83


epoch: 30/199: : 2701it [00:04, 657.66it/s, loss=0.001447]                      
epoch: 31/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.65


epoch: 31/199: : 2701it [00:04, 657.10it/s, loss=0.001436]                      
epoch: 32/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.85


epoch: 32/199: : 2701it [00:04, 652.06it/s, loss=0.001460]                      
epoch: 33/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.95


epoch: 33/199: : 2701it [00:04, 652.33it/s, loss=0.001429]                      
epoch: 34/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.92


epoch: 34/199: : 2701it [00:04, 656.21it/s, loss=0.001425]                      
epoch: 35/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.96


epoch: 35/199: : 2701it [00:04, 668.00it/s, loss=0.001426]                      
epoch: 36/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.96


epoch: 36/199: : 2701it [00:04, 667.86it/s, loss=0.001445]                      
epoch: 37/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.81


epoch: 37/199: : 2701it [00:04, 660.64it/s, loss=0.001421]                      
epoch: 38/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.95


epoch: 38/199: : 2701it [00:04, 659.73it/s, loss=0.001417]                      
epoch: 39/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.99


epoch: 39/199: : 2701it [00:04, 655.37it/s, loss=0.001423]                      
epoch: 40/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.75


epoch: 40/199: : 2701it [00:04, 666.49it/s, loss=0.001439]                      
epoch: 41/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.00


epoch: 41/199: : 2701it [00:04, 669.84it/s, loss=0.001409]                      
epoch: 42/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.00


epoch: 42/199: : 2701it [00:04, 659.37it/s, loss=0.001408]                      
epoch: 43/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.92


epoch: 43/199: : 2701it [00:04, 667.06it/s, loss=0.001409]                      
epoch: 44/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.83


epoch: 44/199: : 2701it [00:04, 662.43it/s, loss=0.001403]                      
epoch: 45/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.04


epoch: 45/199: : 2701it [00:04, 658.61it/s, loss=0.001406]                      
epoch: 46/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.09


epoch: 46/199: : 2701it [00:04, 646.13it/s, loss=0.001397]                      
epoch: 47/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.11


epoch: 47/199: : 2701it [00:04, 642.27it/s, loss=0.001391]                      
epoch: 48/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.14


epoch: 48/199: : 2701it [00:04, 647.25it/s, loss=0.001386]                      
epoch: 49/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.15


epoch: 49/199: : 2701it [00:04, 651.79it/s, loss=0.001380]                      
epoch: 50/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.17


epoch: 50/199: : 2701it [00:04, 593.88it/s, loss=0.001392]                      
epoch: 51/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.14


epoch: 51/199: : 2701it [00:04, 663.69it/s, loss=0.001369]                      
epoch: 52/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.21


epoch: 52/199: : 2701it [00:04, 661.65it/s, loss=0.001367]                      
epoch: 53/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.22


epoch: 53/199: : 2701it [00:04, 658.21it/s, loss=0.001370]                      
epoch: 54/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.21


epoch: 54/199: : 2701it [00:04, 661.60it/s, loss=0.001358]                      
epoch: 55/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.29


epoch: 55/199: : 2701it [00:04, 657.17it/s, loss=0.001354]                      
epoch: 56/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.29


epoch: 56/199: : 2701it [00:04, 662.69it/s, loss=0.001353]                      
epoch: 57/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.32


epoch: 57/199: : 2701it [00:04, 669.94it/s, loss=0.001348]                      
epoch: 58/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.88


epoch: 58/199: : 2701it [00:04, 660.08it/s, loss=0.001353]                      
epoch: 59/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.36


epoch: 59/199: : 2701it [00:04, 650.03it/s, loss=0.001339]                      
epoch: 60/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.39


epoch: 60/199: : 2701it [00:04, 659.08it/s, loss=0.001347]                      
epoch: 61/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.35


epoch: 61/199: : 2701it [00:04, 660.59it/s, loss=0.001339]                      
epoch: 62/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.28


epoch: 62/199: : 2701it [00:04, 665.42it/s, loss=0.001334]                      
epoch: 63/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.37


epoch: 63/199: : 2701it [00:04, 662.74it/s, loss=0.001326]                      
epoch: 64/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.44


epoch: 64/199: : 2701it [00:04, 659.62it/s, loss=0.001327]                      
epoch: 65/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.48


epoch: 65/199: : 2701it [00:04, 662.24it/s, loss=0.001327]                      
epoch: 66/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.48


epoch: 66/199: : 2701it [00:04, 663.79it/s, loss=0.001317]                      
epoch: 67/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.49


epoch: 67/199: : 2701it [00:04, 661.86it/s, loss=0.001320]                      
epoch: 68/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.52


epoch: 68/199: : 2701it [00:04, 651.78it/s, loss=0.001312]                      
epoch: 69/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.55


epoch: 69/199: : 2701it [00:04, 649.49it/s, loss=0.001318]                      
epoch: 70/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.52


epoch: 70/199: : 2701it [00:04, 662.74it/s, loss=0.001305]                      
epoch: 71/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.53


epoch: 71/199: : 2701it [00:04, 659.18it/s, loss=0.001309]                      
epoch: 72/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.52


epoch: 72/199: : 2701it [00:04, 658.65it/s, loss=0.001299]                      
epoch: 73/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.60


epoch: 73/199: : 2701it [00:04, 659.66it/s, loss=0.001296]                      
epoch: 74/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.59


epoch: 74/199: : 2701it [00:04, 654.67it/s, loss=0.001295]                      
epoch: 75/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.49


epoch: 75/199: : 2701it [00:04, 661.01it/s, loss=0.001298]                      
epoch: 76/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.55


epoch: 76/199: : 2701it [00:04, 657.73it/s, loss=0.001290]                      
epoch: 77/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.62


epoch: 77/199: : 2701it [00:04, 665.81it/s, loss=0.001297]                      
epoch: 78/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.56


epoch: 78/199: : 2701it [00:04, 659.33it/s, loss=0.001286]                      
epoch: 79/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.34


epoch: 79/199: : 2701it [00:04, 663.81it/s, loss=0.001284]                      
epoch: 80/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.56


epoch: 80/199: : 2701it [00:04, 661.04it/s, loss=0.001290]                      
epoch: 81/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.68


epoch: 81/199: : 2701it [00:04, 662.51it/s, loss=0.001283]                      
epoch: 82/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.40


epoch: 82/199: : 2701it [00:04, 669.80it/s, loss=0.001284]                      
epoch: 83/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.70


epoch: 83/199: : 2701it [00:04, 664.45it/s, loss=0.001273]                      
epoch: 84/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.72


epoch: 84/199: : 2701it [00:04, 663.33it/s, loss=0.001277]                      
epoch: 85/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.68


epoch: 85/199: : 2701it [00:04, 665.35it/s, loss=0.001270]                      
epoch: 86/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.71


epoch: 86/199: : 2701it [00:04, 659.21it/s, loss=0.001274]                      
epoch: 87/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.69


epoch: 87/199: : 2701it [00:04, 647.39it/s, loss=0.001270]                      
epoch: 88/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.75


epoch: 88/199: : 2701it [00:04, 660.93it/s, loss=0.001267]                      
epoch: 89/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.62


epoch: 89/199: : 2701it [00:04, 658.93it/s, loss=0.001279]                      
epoch: 90/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.74


epoch: 90/199: : 2701it [00:04, 655.30it/s, loss=0.001264]                      
epoch: 91/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.69


epoch: 91/199: : 2701it [00:04, 663.37it/s, loss=0.001262]                      
epoch: 92/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.75


epoch: 92/199: : 2701it [00:04, 667.05it/s, loss=0.001266]                      
epoch: 93/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 31.42


epoch: 93/199: : 2701it [00:04, 667.88it/s, loss=0.001264]                      
epoch: 94/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.74


epoch: 94/199: : 2701it [00:04, 665.93it/s, loss=0.001258]                      
epoch: 95/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.76


epoch: 95/199: : 2701it [00:04, 660.37it/s, loss=0.001260]                      
epoch: 96/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.50


epoch: 96/199: : 2701it [00:04, 667.67it/s, loss=0.001272]
epoch: 97/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.71


epoch: 97/199: : 2701it [00:04, 663.22it/s, loss=0.001254]                      
epoch: 98/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.75


epoch: 98/199: : 2701it [00:04, 660.52it/s, loss=0.001257]                      
epoch: 99/199:   0%|                                   | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.64


epoch: 99/199: : 2701it [00:04, 665.93it/s, loss=0.001253]                      
epoch: 100/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.76


epoch: 100/199: : 2701it [00:04, 653.77it/s, loss=0.001256]                     
epoch: 101/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.78


epoch: 101/199: : 2701it [00:04, 666.11it/s, loss=0.001259]                     
epoch: 102/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.76


epoch: 102/199: : 2701it [00:04, 658.11it/s, loss=0.001250]                     
epoch: 103/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.75


epoch: 103/199: : 2701it [00:04, 664.27it/s, loss=0.001253]                     
epoch: 104/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.81


epoch: 104/199: : 2701it [00:04, 664.83it/s, loss=0.001252]                     
epoch: 105/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.80


epoch: 105/199: : 2701it [00:04, 660.83it/s, loss=0.001261]                     
epoch: 106/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.77


epoch: 106/199: : 2701it [00:04, 655.40it/s, loss=0.001248]                     
epoch: 107/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.80


epoch: 107/199: : 2701it [00:04, 663.96it/s, loss=0.001246]                     
epoch: 108/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.81


epoch: 108/199: : 2701it [00:04, 658.80it/s, loss=0.001250]                     
epoch: 109/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.81


epoch: 109/199: : 2701it [00:04, 659.10it/s, loss=0.001248]                     
epoch: 110/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.79


epoch: 110/199: : 2701it [00:04, 668.73it/s, loss=0.001245]                     
epoch: 111/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.72


epoch: 111/199: : 2701it [00:04, 665.48it/s, loss=0.001247]                     
epoch: 112/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.86


epoch: 112/199: : 2701it [00:04, 660.30it/s, loss=0.001242]                     
epoch: 113/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.80


epoch: 113/199: : 2701it [00:04, 665.02it/s, loss=0.001257]                     
epoch: 114/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.85


epoch: 114/199: : 2701it [00:04, 662.39it/s, loss=0.001240]                     
epoch: 115/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.84


epoch: 115/199: : 2701it [00:04, 656.28it/s, loss=0.001244]                     
epoch: 116/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.84


epoch: 116/199: : 2701it [00:04, 662.32it/s, loss=0.001242]                     
epoch: 117/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.79


epoch: 117/199: : 2701it [00:04, 666.22it/s, loss=0.001240]                     
epoch: 118/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.81


epoch: 118/199: : 2701it [00:04, 668.05it/s, loss=0.001242]                     
epoch: 119/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.80


epoch: 119/199: : 2701it [00:04, 659.94it/s, loss=0.001241]                     
epoch: 120/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.75


epoch: 120/199: : 2701it [00:04, 661.27it/s, loss=0.001238]                     
epoch: 121/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.74


epoch: 121/199: : 2701it [00:04, 640.15it/s, loss=0.001238]                     
epoch: 122/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.78


epoch: 122/199: : 2701it [00:04, 636.02it/s, loss=0.001248]                     
epoch: 123/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.86


epoch: 123/199: : 2701it [00:04, 654.78it/s, loss=0.001235]                     
epoch: 124/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.83


epoch: 124/199: : 2701it [00:04, 665.29it/s, loss=0.001237]                     
epoch: 125/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.57


epoch: 125/199: : 2701it [00:04, 666.54it/s, loss=0.001238]                     
epoch: 126/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.81


epoch: 126/199: : 2701it [00:04, 657.22it/s, loss=0.001234]                     
epoch: 127/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.88


epoch: 127/199: : 2701it [00:04, 649.59it/s, loss=0.001237]                     
epoch: 128/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.88


epoch: 128/199: : 2701it [00:04, 670.73it/s, loss=0.001236]                     
epoch: 129/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.82


epoch: 129/199: : 2701it [00:04, 652.77it/s, loss=0.001237]                     
epoch: 130/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.82


epoch: 130/199: : 2701it [00:04, 654.34it/s, loss=0.001234]                     
epoch: 131/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.84


epoch: 131/199: : 2701it [00:04, 666.59it/s, loss=0.001233]                     
epoch: 132/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.84


epoch: 132/199: : 2701it [00:04, 665.15it/s, loss=0.001231]                     
epoch: 133/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.85


epoch: 133/199: : 2701it [00:04, 670.79it/s, loss=0.001233]                     
epoch: 134/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.75


epoch: 134/199: : 2701it [00:04, 663.84it/s, loss=0.001237]                     
epoch: 135/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.86


epoch: 135/199: : 2701it [00:04, 656.27it/s, loss=0.001230]                     
epoch: 136/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.88


epoch: 136/199: : 2701it [00:04, 667.09it/s, loss=0.001228]                     
epoch: 137/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.91


epoch: 137/199: : 2701it [00:04, 666.11it/s, loss=0.001236]                     
epoch: 138/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.55


epoch: 138/199: : 2701it [00:04, 659.17it/s, loss=0.001230]                     
epoch: 139/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.90


epoch: 139/199: : 2701it [00:04, 668.05it/s, loss=0.001229]                     
epoch: 140/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.89


epoch: 140/199: : 2701it [00:04, 664.31it/s, loss=0.001229]                     
epoch: 141/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.87


epoch: 141/199: : 2701it [00:04, 666.14it/s, loss=0.001235]                     
epoch: 142/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.93


epoch: 142/199: : 2701it [00:04, 656.92it/s, loss=0.001226]                     
epoch: 143/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.91


epoch: 143/199: : 2701it [00:04, 669.56it/s, loss=0.001225]                     
epoch: 144/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.89


epoch: 144/199: : 2701it [00:04, 671.07it/s, loss=0.001226]                     
epoch: 145/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.85


epoch: 145/199: : 2701it [00:04, 665.70it/s, loss=0.001225]                     
epoch: 146/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.73


epoch: 146/199: : 2701it [00:04, 651.39it/s, loss=0.001230]                     
epoch: 147/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.90


epoch: 147/199: : 2701it [00:04, 659.26it/s, loss=0.001224]                     
epoch: 148/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.86


epoch: 148/199: : 2701it [00:04, 670.54it/s, loss=0.001235]                     
epoch: 149/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.90


epoch: 149/199: : 2701it [00:04, 661.09it/s, loss=0.001224]                     
epoch: 150/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.91


epoch: 150/199: : 2701it [00:04, 671.11it/s, loss=0.001223]                     
epoch: 151/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.80


epoch: 151/199: : 2701it [00:04, 659.78it/s, loss=0.001224]                     
epoch: 152/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.87


epoch: 152/199: : 2701it [00:04, 663.39it/s, loss=0.001228]                     
epoch: 153/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.82


epoch: 153/199: : 2701it [00:04, 667.48it/s, loss=0.001222]                     
epoch: 154/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.90


epoch: 154/199: : 2701it [00:04, 671.28it/s, loss=0.001225]                     
epoch: 155/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.80


epoch: 155/199: : 2701it [00:04, 668.96it/s, loss=0.001223]                     
epoch: 156/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.81


epoch: 156/199: : 2701it [00:04, 667.39it/s, loss=0.001223]                     
epoch: 157/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.87


epoch: 157/199: : 2701it [00:04, 658.60it/s, loss=0.001221]                     
epoch: 158/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.91


epoch: 158/199: : 2701it [00:04, 662.74it/s, loss=0.001226]                     
epoch: 159/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.86


epoch: 159/199: : 2701it [00:04, 669.17it/s, loss=0.001220]                     
epoch: 160/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.76


epoch: 160/199: : 2701it [00:04, 667.25it/s, loss=0.001207]                     
epoch: 161/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.98


epoch: 161/199: : 2701it [00:04, 665.96it/s, loss=0.001205]                     
epoch: 162/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.98


epoch: 162/199: : 2701it [00:04, 659.85it/s, loss=0.001204]                     
epoch: 163/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.99


epoch: 163/199: : 2701it [00:04, 667.40it/s, loss=0.001204]                     
epoch: 164/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.98


epoch: 164/199: : 2701it [00:04, 661.09it/s, loss=0.001204]                     
epoch: 165/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.98


epoch: 165/199: : 2701it [00:04, 653.37it/s, loss=0.001204]                     
epoch: 166/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.97


epoch: 166/199: : 2701it [00:04, 657.24it/s, loss=0.001204]                     
epoch: 167/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.98


epoch: 167/199: : 2701it [00:04, 656.55it/s, loss=0.001204]                     
epoch: 168/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.99


epoch: 168/199: : 2701it [00:04, 666.22it/s, loss=0.001204]                     
epoch: 169/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.98


epoch: 169/199: : 2701it [00:04, 661.20it/s, loss=0.001204]                     
epoch: 170/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.99


epoch: 170/199: : 2701it [00:04, 658.51it/s, loss=0.001204]                     
epoch: 171/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.97


epoch: 171/199: : 2701it [00:04, 663.07it/s, loss=0.001204]                     
epoch: 172/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.97


epoch: 172/199: : 2701it [00:04, 660.93it/s, loss=0.001204]                     
epoch: 173/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.98


epoch: 173/199: : 2701it [00:04, 644.50it/s, loss=0.001204]                     
epoch: 174/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.98


epoch: 174/199: : 2701it [00:04, 663.96it/s, loss=0.001204]                     
epoch: 175/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.98


epoch: 175/199: : 2701it [00:04, 662.67it/s, loss=0.001204]                     
epoch: 176/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.99


epoch: 176/199: : 2701it [00:04, 657.45it/s, loss=0.001204]                     
epoch: 177/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.98


epoch: 177/199: : 2701it [00:04, 665.30it/s, loss=0.001204]                     
epoch: 178/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.98


epoch: 178/199: : 2701it [00:04, 666.80it/s, loss=0.001203]                     
epoch: 179/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.98


epoch: 179/199: : 2701it [00:04, 668.54it/s, loss=0.001203]                     
epoch: 180/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.97


epoch: 180/199: : 2701it [00:04, 659.48it/s, loss=0.001204]                     
epoch: 181/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.97


epoch: 181/199: : 2701it [00:04, 670.54it/s, loss=0.001203]                     
epoch: 182/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.99


epoch: 182/199: : 2701it [00:04, 666.04it/s, loss=0.001203]                     
epoch: 183/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.97


epoch: 183/199: : 2701it [00:04, 670.45it/s, loss=0.001203]                     
epoch: 184/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.98


epoch: 184/199: : 2701it [00:04, 667.72it/s, loss=0.001203]                     
epoch: 185/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.98


epoch: 185/199: : 2701it [00:04, 653.99it/s, loss=0.001203]                     
epoch: 186/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.96


epoch: 186/199: : 2701it [00:04, 654.64it/s, loss=0.001203]                     
epoch: 187/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.99


epoch: 187/199: : 2701it [00:04, 663.44it/s, loss=0.001203]                     
epoch: 188/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.97


epoch: 188/199: : 2701it [00:04, 666.62it/s, loss=0.001203]                     
epoch: 189/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.99


epoch: 189/199: : 2701it [00:04, 667.32it/s, loss=0.001203]                     
epoch: 190/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.99


epoch: 190/199: : 2701it [00:04, 661.10it/s, loss=0.001203]                     
epoch: 191/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.98


epoch: 191/199: : 2701it [00:04, 661.63it/s, loss=0.001202]                     
epoch: 192/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.98


epoch: 192/199: : 2701it [00:04, 665.33it/s, loss=0.001203]                     
epoch: 193/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.98


epoch: 193/199: : 2701it [00:04, 665.90it/s, loss=0.001202]                     
epoch: 194/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.98


epoch: 194/199: : 2701it [00:04, 672.35it/s, loss=0.001202]                     
epoch: 195/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.97


epoch: 195/199: : 2701it [00:04, 665.87it/s, loss=0.001202]                     
epoch: 196/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.98


epoch: 196/199: : 2701it [00:04, 639.99it/s, loss=0.001202]                     
epoch: 197/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.97


epoch: 197/199: : 2701it [00:04, 636.05it/s, loss=0.001202]                     
epoch: 198/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.99


epoch: 198/199: : 2701it [00:04, 647.74it/s, loss=0.001202]                     
epoch: 199/199:   0%|                                  | 0/2688 [00:00<?, ?it/s]

eval psnr: 32.99


epoch: 199/199: : 2701it [00:04, 662.88it/s, loss=0.001202]                     

eval psnr: 32.98
best epoch: 189, psnr: 32.99





In [0]:
import argparse

import torch
import torch.backends.cudnn as cudnn
import numpy as np
import PIL.Image as pil_image

# from models import ESPCN
# from utils import convert_ycbcr_to_rgb, preprocess, calc_psnr


# if __name__ == '__main__':
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--weights-file', type=str, required=True)
#     parser.add_argument('--image-file', type=str, required=True)
#     parser.add_argument('--scale', type=int, default=3)
#     args = parser.parse_args()
def test(args):

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = ESPCN(scale_factor=args.scale).to(device)

    state_dict = model.state_dict()
    for n, p in torch.load(args.weights_file, map_location=lambda storage, loc: storage).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)

    model.eval()

    image = pil_image.open(args.image_file).convert('RGB')

    image_width = (image.width // args.scale) * args.scale
    image_height = (image.height // args.scale) * args.scale

    hr = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
    lr = hr.resize((hr.width // args.scale, hr.height // args.scale), resample=pil_image.BICUBIC)
    bicubic = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)
    bicubic.save(args.image_file.replace('.', '_bicubic_x{}.'.format(args.scale)))

    lr, _ = preprocess(lr, device)
    hr, _ = preprocess(hr, device)
    _, ycbcr = preprocess(bicubic, device)

    with torch.no_grad():
        preds = model(lr).clamp(0.0, 1.0)

    psnr = calc_psnr(hr, preds)
    print('PSNR: {:.2f}'.format(psnr))

    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

    output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
    output = pil_image.fromarray(output)
    output.save(args.image_file.replace('.', '_espcn_x{}.'.format(args.scale)))


In [0]:
class Test_args(NamedTuple):
    weights_file: str
    image_file: str
    scale: int

test_args = Test_args("./drive/My Drive/espcn/outputs/x3/best.pth", "THE IMAGE FILE YOU WANT TO TEST", 3)
test(test_args)