# In [2]:
import glob
import os

import numpy as np
import PIL.Image as pil_image
from mindspore.mindrecord import FileWriter
from PIL import ImageFile
from PIL import UnidentifiedImageError

from model_utils.config import config
from utils import convert_rgb_to_y



# Let Pillow decode image files whose data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True


def _collect_image_paths(src_folder):
    """Return the candidate image files under *src_folder*, in sorted order.

    Regular files directly in *src_folder* are taken as-is; every other entry
    is treated as a directory and contributes the regular files directly
    inside it (one level deep, no recursion; dotfiles are skipped by glob).
    """
    paths = []
    for entry in sorted(os.listdir(src_folder)):
        path = os.path.join(src_folder, entry)
        if os.path.isfile(path):
            paths.append(path)
            continue
        # glob.escape keeps directory names containing '[', '*', '?' literal;
        # the isfile filter drops nested directories, which Image.open rejects.
        pattern = os.path.join(glob.escape(path), '*')
        paths.extend(p for p in sorted(glob.glob(pattern)) if os.path.isfile(p))
    return paths


def _make_y_pair(path, scale):
    """Build the (lr, hr) luminance arrays for the image at *path*.

    hr: the RGB image bicubic-resized (not cropped) so both sides are the
        largest multiple of *scale* not exceeding the original size.
    lr: hr bicubic-downscaled by *scale* and upscaled back, so it has the
        same pixel size as hr.
    Both are converted to float32 and passed through ``convert_rgb_to_y``.

    Raises:
        UnidentifiedImageError: if Pillow cannot identify *path* as an image.
    """
    # The context manager releases the file handle; convert() loads the data
    # and returns an independent image, so it stays valid after the close.
    with pil_image.open(path) as img:
        hr = img.convert('RGB')
    hr_width = (hr.width // scale) * scale
    hr_height = (hr.height // scale) * scale
    hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
    lr = hr.resize((hr_width // scale, hr_height // scale), resample=pil_image.BICUBIC)
    lr = lr.resize((lr.width * scale, lr.height * scale), resample=pil_image.BICUBIC)
    hr = convert_rgb_to_y(np.array(hr).astype(np.float32))
    lr = convert_rgb_to_y(np.array(lr).astype(np.float32))
    return lr, hr


def _extract_patches(lr, hr, patch_size, stride):
    """Return MindRecord rows of co-located lr/hr patches.

    Square windows of side *patch_size* are taken every *stride* pixels along
    both axes. Each window is divided by 255 and given a leading axis, matching
    the ``[1, patch_size, patch_size]`` schema. Assumes lr and hr are 2-D
    (H, W) arrays of the same size; TODO confirm against convert_rgb_to_y.
    """
    height, width = lr.shape[0], lr.shape[1]
    rows = []
    for top in range(0, height - patch_size + 1, stride):
        for left in range(0, width - patch_size + 1, stride):
            window = (slice(top, top + patch_size), slice(left, left + patch_size))
            rows.append({
                "lr": np.expand_dims(lr[window] / 255., 0),
                "hr": np.expand_dims(hr[window] / 255., 0),
            })
    return rows


if __name__ == '__main__':
    cfg = config
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(cfg.output_folder, exist_ok=True)
    prefix = "srcnn.mindrecord"
    file_num = 32  # passed to FileWriter as its shard count
    patch_size = cfg.patch_size
    stride = cfg.stride
    scale = cfg.scale
    mindrecord_path = os.path.join(cfg.output_folder, prefix)
    writer = FileWriter(mindrecord_path, file_num)

    srcnn_json = {
        "lr": {"type": "float32", "shape": [1, patch_size, patch_size]},
        "hr": {"type": "float32", "shape": [1, patch_size, patch_size]},
    }
    writer.add_schema(srcnn_json, "srcnn_json")

    image_list = _collect_image_paths(cfg.src_folder)
    print("image_list size ", len(image_list), flush=True)

    for path in image_list:
        try:
            lr, hr = _make_y_pair(path, scale)
        except UnidentifiedImageError:
            # A stray non-image file (e.g. Thumbs.db) should not abort the run.
            print("skip non-image file:", path, flush=True)
            continue
        rows = _extract_patches(lr, hr, patch_size, stride)
        if rows:
            # One write per image (same rows, same order) instead of one per
            # patch, to cut per-call overhead in the MindRecord writer.
            writer.write_raw_data(rows)

    writer.commit()
    print("Finish!")

{'enable_modelarts': 'Whether training on modelarts, default: False', 'data_url': 'Url for modelarts', 'train_url': 'Url for modelarts', 'data_path': 'The location of the input data.', 'output_path': 'The location of the output file.', 'device_target': 'Device target, support GPU and Ascend.', 'enable_profiling': 'Whether enable profiling while training, default: False'}
{'batch_size': 8,
 'checkpoint_path': '/cache/checkpoint_path.ckpt',
 'ckpt_file': '',
 'config_path': 'G:\\PythonProject\\Peter\\model_utils\\../config.yaml',
 'data_path': 'dataset/mindrecord',
 'device_num': 1,
 'device_target': 'CPU',
 'enable_profiling': False,
 'epoch_size': 40,
 'file_format': 'MINDIR',
 'file_name': 'srcnn',
 'filter_weight': False,
 'image_height': 512,
 'image_path': '',
 'image_width': 512,
 'keep_checkpoint_max': 20,
 'lr': 0.0001,
 'output_folder': 'dataset/mindrecord',
 'output_path': 'output/train',
 'patch_size': 33,
 'pre_trained_path': '',
 'predict_path': '',
 'pretrained_ckpt_path':