In [48]:
import os
import sys
import glob
import pickle
import lmdb
import cv2 as cv
import itertools

import math

In [33]:
# configurations
img_folder = '/mnt/SSD/xtwang/BasicSR_datasets/DIV2K800/DIV2K800_sub/*'  # glob matching pattern
lmdb_save_path = '/mnt/SSD/xtwang/BasicSR_datasets/DIV2K800/DIV2K800_sub.lmdb'
meta_info = {'name': 'DIV2K800_sub_GT'}
mode = 2  # 1 for reading all the images to memory and then writing to lmdb (more memory);
# 2 for reading several images and then writing to lmdb, loop over (less memory)
batch = 1000  # Used in mode 2. After batch images, lmdb commits.



In [41]:
def write_to_lmdb(img_gen, img_names, lmdb_path, data_size, nframes, save_every=10):
    
    # Create folder if its not there
    lmdb_folder = os.path.split(lmdb_path)[0]
    os.makedirs(lmdb_folder, exist_ok=True)
    
    lmdb_obj = lmdb.open(lmdb_path, map_size=data_size) # create file
    
    txn = lmdb_obj.begin(write=True) # Create transaction object
    for i, (img, name) in enumerate(zip(img_gen, img_names)):
        name = name.encode("ascii")
        txn.put(name, img)
        
        # save chunk to file, reset transaction
        if (i+1) % save_every == 0:
            txn.commit()
            txn = lmdb_obj.begin(write=True)
            print(f"Writing {i+1}/{nframes}")
    
    # Close file
    txn.commit()
    lmdb_obj.close()

In [50]:
def create_metadata(img_names, resolutions, dataset_name, lmdb_path):
    # resolutions: a list of length 1 (if all images are the same resolution), or len(img_names)
    
    meta = {
        "name": dataset_name,
        "keys": list(img_names), # converts generators to lists
        "resolution": list(resolutions)
    }
    
    with open(os.path.join(lmdb_path, "meta_info.pkl"), "wb") as f:
        pickle.dump(meta, f)

In [34]:
def load_video(filename):
    vid = cv.VideoCapture(filename)

    while vid.isOpened():
        has_frame, img = vid.read()

        if has_frame:
            yield img
        else:
            break
            
    vid.release()
    
def create_names(nframes):
    # pads names. eg: returns "0001", ... "2000" for nframes == 2000
    zeros = math.ceil(math.log(nframes, 10))
    
    for i in range(nframes):
        yield f"{i:0>{zeros}d}"

            
def get_video_stats(filename):
    vid = cv.VideoCapture(filename)
    
    nframes = int(vid.get(cv.CAP_PROP_FRAME_COUNT))
    resolution = int(vid.get(cv.CAP_PROP_FRAME_HEIGHT)), int(vid.get(cv.CAP_PROP_FRAME_WIDTH)), 3 # height, width, channels
    
    _, img = vid.read()
    frame_data_size = img.nbytes
    
    
    vid.release()
    
    return nframes, resolution, frame_data_size

In [54]:
def load_imgs_glob(glob_pattern):
    filenames = sorted(glob.glob(glob_pattern)) # list
    images = (cv.imread(f, cv.IMREAD_UNCHANGED) for f in filenames) # generator
    
    return filenames, images

def get_glob_stats(glob_pattern):
    filenames, images = load_imgs_glob(glob_pattern)
    
    nfiles = len(filenames)
    
    resolutions = []
    total_dsize = 0
    for img in images:
        resolutions.append(img.shape)
        total_dsize += img.nbytes * 10
        
    return nfiles, resolutions, total_dsize

# Full run for a video input

In [53]:
video_path = "/data/aicity/train/2.mp4"
lmdb_path = "/data/aicity/lmdb/train/2.lmdb"

images = load_video(video_path)
length, (h, w, c), frame_dsize = get_video_stats(video_path)
total_dsize = frame_dsize * length * 10
resolution = "{:d}_{:d}_{:d}".format(c, h, w)
img_names1, img_names2 = create_names(length), create_names(length)

write_to_lmdb(images, img_names1, lmdb_path, total_dsize, length, )
create_metadata(img_names2, [resolution], "video_data", lmdb_path)

Writing 10/26820
Writing 20/26820
Writing 30/26820
Writing 40/26820
Writing 50/26820
Writing 60/26820
Writing 70/26820
Writing 80/26820
Writing 90/26820
Writing 100/26820
Writing 110/26820
Writing 120/26820
Writing 130/26820
Writing 140/26820
Writing 150/26820
Writing 160/26820
Writing 170/26820
Writing 180/26820
Writing 190/26820
Writing 200/26820
Writing 210/26820
Writing 220/26820
Writing 230/26820
Writing 240/26820
Writing 250/26820
Writing 260/26820
Writing 270/26820
Writing 280/26820
Writing 290/26820
Writing 300/26820
Writing 310/26820
Writing 320/26820
Writing 330/26820
Writing 340/26820
Writing 350/26820
Writing 360/26820
Writing 370/26820
Writing 380/26820
Writing 390/26820
Writing 400/26820
Writing 410/26820
Writing 420/26820
Writing 430/26820
Writing 440/26820
Writing 450/26820
Writing 460/26820
Writing 470/26820
Writing 480/26820
Writing 490/26820
Writing 500/26820
Writing 510/26820
Writing 520/26820
Writing 530/26820
Writing 540/26820
Writing 550/26820
Writing 560/26820
W

Writing 4390/26820
Writing 4400/26820
Writing 4410/26820
Writing 4420/26820
Writing 4430/26820
Writing 4440/26820
Writing 4450/26820
Writing 4460/26820
Writing 4470/26820
Writing 4480/26820
Writing 4490/26820
Writing 4500/26820
Writing 4510/26820
Writing 4520/26820
Writing 4530/26820
Writing 4540/26820
Writing 4550/26820
Writing 4560/26820
Writing 4570/26820
Writing 4580/26820
Writing 4590/26820
Writing 4600/26820
Writing 4610/26820
Writing 4620/26820
Writing 4630/26820
Writing 4640/26820
Writing 4650/26820
Writing 4660/26820
Writing 4670/26820
Writing 4680/26820
Writing 4690/26820
Writing 4700/26820
Writing 4710/26820
Writing 4720/26820
Writing 4730/26820
Writing 4740/26820
Writing 4750/26820
Writing 4760/26820
Writing 4770/26820
Writing 4780/26820
Writing 4790/26820
Writing 4800/26820
Writing 4810/26820
Writing 4820/26820
Writing 4830/26820
Writing 4840/26820
Writing 4850/26820
Writing 4860/26820
Writing 4870/26820
Writing 4880/26820
Writing 4890/26820
Writing 4900/26820
Writing 4910

KeyboardInterrupt: 

# Full run for images

In [None]:
img_folder = '/mnt/SSD/xtwang/BasicSR_datasets/DIV2K800/DIV2K800_sub/*'  # glob matching pattern
lmdb_path = '/mnt/SSD/xtwang/BasicSR_datasets/DIV2K800/DIV2K800_sub.lmdb'

filenames, images = load_imgs_glob(img_folder)
length, resolutions, total_dsize = get_glob_stats(img_folder)
resolutions = ["{:d}_{:d}_{:d}".format(c, h, w) for h, w, c in resolutions]
img_names = [os.path.splitext(os.path.basename(f))[0] for f in filenames]


write_to_lmdb(images, img_names1, lmdb_path, total_dsize, length, )
create_metadata(img_names2, resolutions, "video_data", lmdb_path)

In [None]:

try:
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from utils.util import ProgressBar
except ImportError:
    pass



###########################################
if not lmdb_save_path.endswith('.lmdb'):
    raise ValueError("lmdb_save_path must end with \'lmdb\'.")
#### whether the lmdb file exist
if os.path.exists(lmdb_save_path):
    print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
    sys.exit(1)
    
img_list = sorted(glob.glob(img_folder))
if mode == 1:
    print('Read images...')
    dataset = [cv2.imread(v, cv2.IMREAD_UNCHANGED) for v in img_list]
    data_size = sum([img.nbytes for img in dataset])
elif mode == 2:
    print('Calculating the total size of images...')
    data_size = sum(os.stat(v).st_size for v in img_list)
else:
    raise ValueError('mode should be 1 or 2')

key_l = []
resolution_l = []
pbar = ProgressBar(len(img_list))
env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
txn = env.begin(write=True)  # txn is a Transaction object
for i, v in enumerate(img_list):
    pbar.update('Write {}'.format(v))
    base_name = os.path.splitext(os.path.basename(v))[0]
    key = base_name.encode('ascii')
    data = dataset[i] if mode == 1 else cv2.imread(v, cv2.IMREAD_UNCHANGED)
    if data.ndim == 2:
        H, W = data.shape
        C = 1
    else:
        H, W, C = data.shape
    txn.put(key, data)
    key_l.append(base_name)
    resolution_l.append('{:d}_{:d}_{:d}'.format(C, H, W))
    # commit in mode 2
    if mode == 2 and i % batch == 1:
        txn.commit()
        txn = env.begin(write=True)

txn.commit()
env.close()

print('Finish writing lmdb.')

#### create meta information
# check whether all the images are the same size
same_resolution = (len(set(resolution_l)) <= 1)
if same_resolution:
    meta_info['resolution'] = [resolution_l[0]]
    meta_info['keys'] = key_l
    print('All images have the same resolution. Simplify the meta info...')
else:
    meta_info['resolution'] = resolution_l
    meta_info['keys'] = key_l
    print('Not all images have the same resolution. Save meta info for each image...')

#### pickle dump
pickle.dump(meta_info, open(os.path.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
print('Finish creating lmdb meta info.')