In [2]:
# Shared notebook configuration, passed to save_h5() and train() below:
#   scale      - integer super-resolution factor (HR size // scale gives LR size)
#   patch_size - side length, in pixels, of the square patches that are extracted
#   stride     - step, in pixels, between the origins of neighbouring patches
scale, patch_size, stride = 3, 84, 29

In [None]:
import sys
import h5py
import numpy as np
from PIL import Image
import gc
import os
from utils import convert_rgb_to_ycbcr, convert_ycbcr_to_rgb, convert_rgb_to_y, calc_psnr, ssim

def save_h5(imgs_dir, phase, scale, patch_size, stride, batch_size = 100, cache_size=256):
    """Extract (LR, HR) luminance patch pairs from a folder of images into an HDF5 file.

    Every image in ``imgs_dir`` is converted to RGB, bicubically down-scaled by
    ``scale`` and reduced to its Y channel with ``convert_rgb_to_y``. Patch pairs
    are then written to the dataset ``'lh'`` of ``./Datasets/{phase}_X{scale}.h5``
    with shape ``(N, 2, patch_size, patch_size)``: index 0 on axis 1 holds the LR
    patch (``patch_size // scale`` pixels per side, zero-padded at the bottom/right
    up to ``patch_size``) and index 1 holds the HR patch. Values are divided by 255.

    Images are processed in groups of ``batch_size``; the patches of each group
    are shuffled and flushed to disk together to bound memory use.

    Args:
        imgs_dir: Directory containing the source (HR) images.
        phase: Tag used in the output file name (e.g. ``'train'``).
        scale: Integer down-scaling factor.
        patch_size: Side length of the stored patches.
        stride: Step between patch origins.
        batch_size: Number of images per write/shuffle group. The number of
            images must be a multiple of it, otherwise a warning is printed and
            nothing is written.
        cache_size: First dimension of the HDF5 chunk shape.

    Returns:
        None.

    Raises:
        Any error raised while reading images or writing the file propagates to
        the caller (the HDF5 file is still closed). Previously such errors were
        only printed, which left a silently incomplete file behind.
    """
    # List the directory once so the counting pass and the extraction pass are
    # guaranteed to see the same files in the same order.
    img_names = sorted(os.listdir(imgs_dir))

    # Validate BEFORE opening the output file: mode 'w' truncates an existing
    # dataset, so bailing out afterwards would destroy it for nothing.
    leftover = len(img_names) % batch_size
    if leftover != 0:
        print('warning: ' + str(leftover) + \
              ' images will not be used, check batch_size. len(): ' + str(len(img_names)))
        return

    print('counting total patch number, please wait')
    total_patch_number = 0
    for name in img_names:
        # Only the size is needed here; reading it does not require decoding
        # the pixel data, and the context manager closes the file handle.
        with Image.open(os.path.join(imgs_dir, name)) as probe:
            width, height = probe.size
        # len(range(...)) mirrors the extraction loops below exactly, including
        # images smaller than one patch (0 patches rather than a negative count).
        total_patch_number += len(range(0, width // scale - patch_size + 1, stride)) * \
                              len(range(0, height // scale - patch_size + 1, stride))
    print('total patch number : ' + str(total_patch_number))

    lr_patch_size = patch_size // scale
    # Pad the LR patch back up to patch_size even when patch_size is not a
    # multiple of scale (identical to the old expression when it is).
    lr_pad = patch_size - lr_patch_size

    with h5py.File(r'./Datasets/{}_X{}.h5'.format(phase, scale), 'w') as h5f:
        # train: 110550 for X4; 221487 for X3; 558822 for X2
        # valid: 14175 for X4; 28161 for X3; 71115 for X2
        # The first dimension of chunks should be an integral multiple of the batch_size during training
        hlset = h5f.create_dataset('lh', (total_patch_number, 2, patch_size, patch_size), maxshape=(None, 2, patch_size, patch_size), dtype='f', chunks = (cache_size, 2, patch_size, patch_size))

        idx = 0           # index of the first free row in the dataset
        total_number = 0  # patches written so far
        batch = 0         # image groups flushed so far
        patches = []      # (patch, lr/hr, height, width) once converted to an array

        print('batch ' + str(batch + 1) + ' processing')
        for image_count, name in enumerate(img_names, start=1):
            hr = Image.open(os.path.join(imgs_dir, name)).convert('RGB')
            lr = hr.resize((hr.width // scale, hr.height // scale), resample=Image.BICUBIC)
            hr = convert_rgb_to_y(np.array(hr).astype(np.float32))
            lr = convert_rgb_to_y(np.array(lr).astype(np.float32))

            # NOTE(review): x/y index the HR image but are bounded by the LR
            # extent, so only the top-left part of each HR image is sampled.
            # Kept as-is because the patch count above relies on it - confirm
            # whether this is intended.
            for x in range(0, lr.shape[0] - patch_size + 1, stride):
                for y in range(0, lr.shape[1] - patch_size + 1, stride):
                    lx, ly = x // scale, y // scale
                    # LR and HR are stored side by side in one dataset so that a
                    # single read returns the matching pair.
                    patches.append([np.pad(lr[lx:lx + lr_patch_size, ly:ly + lr_patch_size],
                                           ((0, lr_pad), (0, lr_pad))),
                                    hr[x:x + patch_size, y:y + patch_size]])

            # Keep accumulating until a full group of batch_size images is collected.
            if image_count % batch_size != 0:
                continue

            patch_number = len(patches)
            if patch_number:
                patches = np.array(patches)
                patches = patches / 255
                print(patches.shape)
                # shuffle all patches in the batch, thus batch_size should be high for a good shuffle
                shuffle_ix = np.random.permutation(np.arange(patch_number))
                hlset[idx : idx + patch_number] = patches[shuffle_ix]

            # Release the (potentially very large) batch before starting the next one.
            del patches
            gc.collect()
            patches = []

            idx += patch_number
            total_number += patch_number
            batch += 1
            print('batch:' + str(batch) + ' of ' + str(len(img_names) // batch_size))
        print(total_number)

In [None]:
# Build the training patch file from the DIV2K HR images using the notebook-level settings.
save_h5(
    imgs_dir=r'./Datasets/DIV2K_train_HR',
    phase='train',
    scale=scale,
    patch_size=patch_size,
    stride=stride,
    batch_size=100,
    cache_size=256,
)

In [5]:
import os
import copy
import time
import h5py

import numpy as np
from torch import Tensor
import torch
from torch import nn
import torch.optim as optim

# gpu acc
import torch.backends.cudnn as cudnn

from torch.utils.data.dataloader import DataLoader

from models import FSRCNN, OFSRCNN
from datasets import TrainDataset, EvalDataset
from utils import AverageMeter, calc_psnr, ssim

def train(scale, patch_size, model_name, num_epochs, continue_epoch=0, batch_size=256):
    """Train an FSRCNN or OFSRCNN model on the pre-extracted HDF5 patch files.

    Reads ``./Datasets/train_X{scale}.h5`` and ``./Datasets/valid_X{scale}.h5``,
    trains with MSE loss and Adam, and after every epoch writes:

    * ``./outputs/{model_name}_X{scale}_epoch_{epoch}.pth`` - weights of that epoch,
    * ``./outputs/{model_name}_X{scale}_best.pth`` - weights with the best eval PSNR,
    * ``./log/{model_name}_X{scale}_{loss,psnr,ssim}Log.txt`` - one value per epoch.

    Args:
        scale: Super-resolution factor; selects the dataset files and is passed to the model.
        patch_size: Patch size forwarded to ``TrainDataset`` / ``EvalDataset``.
        model_name: ``'FSRCNN'`` selects ``FSRCNN``; any other value selects ``OFSRCNN``.
        num_epochs: Number of epochs to run in this call.
        continue_epoch: If non-zero, resume from that epoch's checkpoint and from
            the existing log files; epochs are then numbered from ``continue_epoch + 1``.
        batch_size: Batch size for both data loaders.

    Returns:
        None.
    """
    train_file = r'./Datasets/train_X{}.h5'.format(scale)
    eval_file = r'./Datasets/valid_X{}.h5'.format(scale)
    outputs_dir = r'./outputs'
    log_dir = r'./log'
    num_workers = 0
    seed = 1

    is_fsrcnn = model_name == 'FSRCNN'
    # Per-submodule learning rates. Only the name of the final up-sampling
    # module differs between the two models; it uses a 10x smaller rate.
    layer_lrs = [
        ('features', 1e-3),
        ('shrinking', 1e-3),
        ('mapping', 1e-3),
        ('expanding', 1e-3),
        ('deconv' if is_fsrcnn else 'upsample', 1e-4),
    ]

    def _weights_path(tag):
        # Location of a checkpoint file, e.g. tag='best' or tag='epoch_12'.
        return os.path.join(outputs_dir, '{}_X{}_{}.pth'.format(model_name, scale, tag))

    def _log_path(kind):
        # Location of a per-epoch metric log, kind in {'loss', 'psnr', 'ssim'}.
        return os.path.join(log_dir, '{}_X{}_{}Log.txt'.format(model_name, scale, kind))

    def _load_log(kind):
        # np.loadtxt returns a 0-d array for a one-line file; atleast_1d keeps
        # the result a list so that .append() keeps working.
        return np.atleast_1d(np.loadtxt(_log_path(kind))).tolist()

    # Both directories are written to every epoch; log_dir was previously
    # never created, so np.savetxt failed on a fresh checkout.
    os.makedirs(outputs_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # Let cuDNN auto-tune kernels; input shapes are fixed (drop_last=True).
    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    torch.manual_seed(seed)

    model = (FSRCNN(scale=scale) if is_fsrcnn else OFSRCNN(scale=scale)).to(device)

    # Per-epoch history (also used for plotting elsewhere).
    lossLog = []
    psnrLog = []
    ssimLog = []

    if continue_epoch != 0:
        # map_location lets checkpoints saved on GPU be resumed on a CPU-only machine.
        model.load_state_dict(torch.load(_weights_path('epoch_{}'.format(continue_epoch)),
                                         map_location=device))
        lossLog = _load_log('loss')
        psnrLog = _load_log('psnr')
        ssimLog = _load_log('ssim')

    criterion = nn.MSELoss()

    optimizer = optim.Adam(
        [{'params': getattr(model, layer).parameters(), 'lr': lr} for layer, lr in layer_lrs])

    train_dataset = TrainDataset(h5_file=train_file, patch_size=patch_size, scale=scale)
    train_dataloader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        # dataset has already shuffled
        shuffle=False,
        num_workers=num_workers,
        drop_last=True)

    eval_dataset = EvalDataset(h5_file=eval_file, patch_size=patch_size, scale=scale)
    eval_dataloader = DataLoader(
        dataset=eval_dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=True)

    num_train_batches = len(train_dataloader)
    num_eval_batches = len(eval_dataloader)

    # Best-so-far bookkeeping. On a fresh run the history is empty, so start
    # from zero instead of calling max() on an empty list.
    best_psnr = 0.0
    best_epoch = 0
    best_weights = copy.deepcopy(model.state_dict())
    if continue_epoch != 0:
        # Load the best checkpoint into its own variable: loading it into
        # `model` would overwrite the continue_epoch weights restored above.
        best_weights = torch.load(_weights_path('best'), map_location=device)
        if psnrLog:
            best_psnr = max(psnrLog)
            best_epoch = psnrLog.index(best_psnr) + 1

    for epoch in range(continue_epoch + 1, continue_epoch + num_epochs + 1):
        since = time.time()
        print('epoch: '+ str(epoch) + ' at ' + time.strftime("%H:%M:%S"))

        # ---- training pass ----
        model.train()
        epoch_losses = AverageMeter()

        for step, (inputs, labels) in enumerate(train_dataloader, start=1):
            # Progress is relative to the number of batches actually iterated,
            # so it reaches 100% even though drop_last discards the remainder.
            print('\r', '***training process of epoch {} : {:.2f}%***'.format(epoch, step / num_train_batches * 100), end='')

            inputs = inputs.to(device)
            labels = labels.to(device)

            preds = model(inputs)
            loss = criterion(preds, labels)
            epoch_losses.update(loss.item(), len(inputs))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        lossLog.append(float(epoch_losses.avg))
        print('')
        print('train loss: ' + str(epoch_losses.avg))
        np.savetxt(_log_path('loss'), lossLog)

        torch.save(model.state_dict(), _weights_path('epoch_{}'.format(epoch)))

        # ---- evaluation pass (PSNR / SSIM) ----
        model.eval()
        epoch_psnr = AverageMeter()
        epoch_ssim = AverageMeter()

        with torch.no_grad():
            for step, (inputs, labels) in enumerate(eval_dataloader, start=1):
                print('\r', '***eval process of epoch {} : {:.2f}%***'.format(epoch, step / num_eval_batches * 100), end='')

                inputs = inputs.to(device)
                labels = labels.to(device)

                preds = model(inputs).clamp(0.0, 1.0)

                epoch_psnr.update(calc_psnr(preds, labels), len(inputs))
                epoch_ssim.update(ssim(preds, labels), len(inputs))

        # The meters presumably average 0-d tensors returned by calc_psnr/ssim;
        # float() handles both tensors (on any device) and plain numbers, and
        # keeps the logs as lists of Python floats.
        psnr_avg = float(epoch_psnr.avg)
        ssim_avg = float(epoch_ssim.avg)

        print('')
        print('eval psnr: {:.2f}'.format(psnr_avg))
        print('eval ssim: {:.2f}'.format(ssim_avg))

        psnrLog.append(psnr_avg)
        ssimLog.append(ssim_avg)
        np.savetxt(_log_path('psnr'), psnrLog)
        np.savetxt(_log_path('ssim'), ssimLog)

        # Keep a private copy of the weights whenever eval PSNR improves.
        if psnr_avg > best_psnr:
            best_epoch = epoch
            best_psnr = psnr_avg
            best_weights = copy.deepcopy(model.state_dict())

        print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))

        # Written every epoch so an interrupted run still leaves a usable best file.
        torch.save(best_weights, _weights_path('best'))

        time_elapsed = time.time() - since
        print('Epoch complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
        print('')

    print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))

    torch.save(best_weights, _weights_path('best'))

In [7]:
# Resume FSRCNN training from the epoch-100 checkpoint and run 20 further epochs.
train(
    scale=scale,
    patch_size=patch_size,
    model_name='FSRCNN',
    num_epochs=20,
    continue_epoch=100,
    batch_size=256,
)

epoch: 101 at 21:50:59
 ***training process of epoch 101 : 99.98%***
train loss: 0.0015020041297878963
 ***eval process of epoch 101 : 100.00%***
eval psnr: 27.45
eval ssim: 0.87
best epoch: 101, psnr: 27.45
Epoch complete in 2m 33s

epoch: 102 at 21:53:32
 ***training process of epoch 102 : 99.98%***
train loss: 0.001493445773827009
 ***eval process of epoch 102 : 100.00%***
eval psnr: 27.45
eval ssim: 0.87
best epoch: 101, psnr: 27.45
Epoch complete in 2m 34s

epoch: 103 at 21:56:06
 ***training process of epoch 103 : 99.98%***
train loss: 0.0014938050972545423
 ***eval process of epoch 103 : 100.00%***
eval psnr: 27.45
eval ssim: 0.87
best epoch: 101, psnr: 27.45
Epoch complete in 2m 37s

epoch: 104 at 21:58:43
 ***training process of epoch 104 : 99.98%***
train loss: 0.0014945881373270528
 ***eval process of epoch 104 : 100.00%***
eval psnr: 27.45
eval ssim: 0.87
best epoch: 101, psnr: 27.45
Epoch complete in 2m 32s

epoch: 105 at 22:01:15
 ***training process of epoch 105 : 99.98%

In [None]:
import torch
import numpy as np
from PIL import Image
import os

from models import OFSRCNN, FSRCNN
from utils import convert_rgb_to_ycbcr, convert_ycbcr_to_rgb, calc_psnr, ssim

# Evaluation settings for this cell: up-scaling factor, model architecture
# name and benchmark dataset folder name.
scale, model_name, dataset_name = 3, 'FSRCNN', 'Set5'

def test(scale, model_name, dataset_name):
    """Super-resolve every image of a benchmark set and report PSNR/SSIM on the Y channel.

    For each file in ``./Datasets/{dataset_name}/HR`` the image is bicubically
    down-scaled by ``scale``, its Y channel is passed through the model loaded
    from ``./outputs/{model_name}_X{scale}_best.pth``, and the prediction is
    compared against the Y channel of the original image. The predicted Y is
    recombined with the original Cb/Cr channels and saved as an RGB image to
    ``./Datasets/{dataset_name}/{model_name}/X{scale}/`` under the same file name.

    Args:
        scale: Super-resolution factor.
        model_name: ``'FSRCNN'`` selects ``FSRCNN``; any other value selects ``OFSRCNN``.
        dataset_name: Name of the dataset folder under ``./Datasets``.

    Returns:
        None. Metrics are printed per image.
    """
    weight_dir = os.path.join('./outputs/', '{}_X{}_best.pth'.format(model_name, scale))
    hr_dir = './Datasets/' + dataset_name + '/HR'
    out_dir = './Datasets/' + dataset_name + '/' + model_name
    save_dir = os.path.join(out_dir, 'X{}'.format(scale))

    device = torch.device('cpu')
    model = (FSRCNN(scale=scale) if model_name == 'FSRCNN' else OFSRCNN(scale=scale)).to(device)
    # Inference runs on CPU, so remap tensors that were saved from a GPU run.
    model.load_state_dict(torch.load(weight_dir, map_location=device))
    model.eval()

    os.makedirs(save_dir, exist_ok=True)

    for img in sorted(os.listdir(hr_dir)):
        image = Image.open(os.path.join(hr_dir, img)).convert('RGB')
        image_lr = image.resize((image.width // scale, image.height // scale), resample=Image.BICUBIC)

        # Ground-truth Y channel scaled to [0, 1], shaped (1, 1, H, W).
        # Dividing out-of-place leaves `ycbcr` itself untouched.
        ycbcr = convert_rgb_to_ycbcr(np.array(image).astype(np.float32))
        y = torch.from_numpy(ycbcr[..., 0] / 255.).to(device)
        y = y.unsqueeze(0).unsqueeze(0)

        # Model input: Y channel of the down-scaled image, same layout.
        ycbcr_lr = convert_rgb_to_ycbcr(np.array(image_lr).astype(np.float32))
        y_lr = torch.from_numpy(ycbcr_lr[..., 0] / 255.).to(device)
        y_lr = y_lr.unsqueeze(0).unsqueeze(0)

        with torch.no_grad():
            preds = model(y_lr).clamp(0.0, 1.0)

        # The LR size is floor(HR / scale), so the model output can be slightly
        # smaller than the HR image; resize so metrics and channel stacking line up.
        if preds.size()[2] != image.height or preds.size()[3] != image.width:
            temp_image = Image.fromarray(preds.numpy().squeeze(0).squeeze(0))
            temp_image = temp_image.resize((image.width, image.height), resample=Image.BICUBIC)
            temp_image_arr = np.array(temp_image).astype(np.float32)
            preds = torch.from_numpy(temp_image_arr).to(device).unsqueeze(0).unsqueeze(0)
        print(img)
        print('PSNR on y: {:.4f}'.format(calc_psnr(y, preds)))
        print('SSIM on y: {:.4f}'.format(ssim(y, preds)))

        preds = preds.mul(255.0).numpy().squeeze(0).squeeze(0)

        # Predicted Y + original Cb/Cr as a numeric (H, W, 3) array. All three
        # planes share the HR shape here, so no object-dtype array is needed.
        output = np.stack([preds, ycbcr[..., 1], ycbcr[..., 2]], axis=-1)

        output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
        output = Image.fromarray(output)

        output.save(os.path.join(save_dir, img))
    print('test done')

In [None]:
# Evaluate the OFSRCNN best checkpoint on the Set5 benchmark images.
test(
    scale=scale,
    model_name='OFSRCNN',
    dataset_name='Set5',
)