Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lmdb格式数据集转换代码可以提供一下嘛? #57

Closed
yeahQing opened this issue Sep 8, 2022 · 6 comments
Closed

lmdb格式数据集转换代码可以提供一下嘛? #57

yeahQing opened this issue Sep 8, 2022 · 6 comments

Comments

@yeahQing
Copy link

yeahQing commented Sep 8, 2022

# -*- coding: utf-8 -*-
import argparse
import glob
import io
import os
import pathlib
import threading

import cv2 as cv
import lmdb
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from tqdm import tqdm

# Uncomment to render CJK text in matplotlib figures:
# plt.rcParams['font.sans-serif'] = ['SimHei']  # font that has CJK glyphs
# plt.rcParams['axes.unicode_minus'] = False  # keep the minus sign renderable

# Dataset layout: <root>/train_3755 and <root>/test each hold one sub-directory
# of PNG images per character class; generated databases go under <root>/lmdb.
root_path = pathlib.Path('/root/autodl-tmp/hwdb')

output_path = os.path.join(root_path, pathlib.Path('lmdb'))
train_path = os.path.join(root_path, pathlib.Path('train_3755'))
val_path = os.path.join(root_path, pathlib.Path('test'))

# Ordered class vocabulary: one character per line (surrounding whitespace
# stripped). Class ranges elsewhere are slices of this list.
with open('../character-3755.txt', 'r', encoding='utf-8') as vocab_file:
    characters = [entry.strip() for entry in vocab_file]


def write_cache(env, cache):
    """Persist every entry of ``cache`` into ``env`` inside one write transaction.

    Keys are ``str`` and are UTF-8 encoded. Values that are already ``bytes``
    (raw image files) are stored unchanged; any other value (the ``str``
    labels and the sample count) is encoded to bytes first.
    """
    with env.begin(write=True) as txn:
        for key, value in cache.items():
            payload = value if isinstance(value, bytes) else value.encode()
            txn.put(key.encode(), payload)


def create_dataset(env, image_path, label, index):
    """Store one class worth of images in ``env`` under consecutive keys.

    Each existing file in ``image_path`` is written as ``image-%09d`` (its raw
    file bytes) together with ``label-%09d`` (the ``label`` string). Numbering
    starts at ``index + 1``, so successive calls continue where the previous
    call stopped.

    Args:
        env: open LMDB environment to write into.
        image_path: list of image file paths that all carry ``label``.
        label: class label stored for every image.
        index: number of the last key already used (0 for a fresh database).

    Returns:
        The number of samples actually written. Missing files are reported,
        skipped and NOT counted, so ``index + returned value`` is always the
        last key used.
    """
    cache = {}
    cnt = index
    for image in image_path:
        if not os.path.exists(image):
            print('%s does not exist' % image)
            continue
        with open(image, 'rb') as fs:
            image_bin = fs.read()
        cnt += 1
        # The .mdb file holds two records per sample: the image bytes and its label.
        cache['image-%09d' % cnt] = image_bin
        cache['label-%09d' % cnt] = label
    if cache:
        write_cache(env, cache)
    # Return what was written, not len(image_path). The caller uses this both
    # as the next key offset and for 'num-samples'. Over-counting would leave
    # key gaps that readers then look up and get None for.
    return cnt - index


def show_image(samples):
    """Display samples in a 4x5 grid (20 cells), one image per cell, axes hidden.

    Each sample must be indexable with the image at position 0, e.g. the
    ``[img, label]`` pairs collected while reading an LMDB dataset.
    """
    plt.figure(figsize=(20, 10))
    for cell, sample in enumerate(samples, start=1):
        plt.subplot(4, 5, cell)
        plt.imshow(sample[0])
        # Captioning is disabled; use plt.title(sample[1]) to show the label.
        plt.xticks([])
        plt.yticks([])
        plt.axis("off")
    plt.show()


def lmdb_test(root, max_show=5):
    """Sanity-check an LMDB dataset by decoding its first few samples.

    Reads ``num-samples``, then for sample numbers 1..``max_show`` (fewer if
    the database is smaller) opens ``image-%09d`` with PIL and prints the
    sample count, the image's band count and the ``label-%09d`` text.

    Args:
        root: directory of the LMDB environment.
        max_show: how many leading samples to inspect (default 5).

    Returns:
        The list of ``[PIL.Image, label]`` pairs that were read (suitable for
        ``show_image``). Returns ``None`` if the environment cannot be used,
        the count or a record is missing, or an image cannot be opened.
    """
    env = lmdb.open(
        root,
        max_readers=1,
        readonly=True,
        lock=False,
        readahead=False,
        meminit=False)

    # NOTE(review): py-lmdb normally raises on failure rather than returning a
    # falsy object; the check is kept as a harmless guard.
    if not env:
        print('cannot open lmdb from %s' % root)
        return None

    try:
        with env.begin(write=False) as txn:
            raw_count = txn.get('num-samples'.encode())
            if raw_count is None:
                print('no num-samples key in lmdb at %s' % root)
                return None
            n_samples = int(raw_count)

            samples = []
            for index in range(1, min(n_samples, max_show) + 1):
                img_buf = txn.get(('image-%09d' % index).encode())
                label_buf = txn.get(('label-%09d' % index).encode())
                if img_buf is None or label_buf is None:
                    # A gap in the key sequence; previously this crashed on
                    # BytesIO.write(None) / None.decode().
                    print('Missing record for %d' % index)
                    return None
                try:
                    img = Image.open(io.BytesIO(img_buf))
                except IOError:
                    print('Corrupted image for %d' % index)
                    return None
                label = label_buf.decode('utf-8')
                # img.split() forces PIL to load the pixel data, so the images
                # stay usable after the environment is closed below.
                print(n_samples, len(img.split()), label)
                samples.append([img, label])
            return samples
    finally:
        # Release the memory map and file handles (the original never closed env).
        env.close()


def _estimate_map_size(paths_per_class):
    """Return an LMDB ``map_size`` in bytes large enough for the given images.

    ``create_dataset`` stores each image's raw file bytes, so the on-disk file
    sizes bound the payload. Assuming the usual 4 KiB LMDB page size, each
    file is charged ``size // 4096 + 1`` pages, which is always at least one
    page more than its bytes need. That extra covers page rounding of large
    values and the small label record. The sum is then doubled for B-tree
    pages and slack, and never drops below 10 MiB (py-lmdb's default).
    """
    page = 4096
    payload = sum(
        (os.path.getsize(path) // page + 1) * page
        for paths in paths_per_class
        for path in paths
    )
    return max(2 * payload, 10 * 1024 * 1024)


def lmdb_init(directory, out, left, right):
    """Build an LMDB dataset at ``out`` from the classes ``characters[left:right]``.

    ``directory`` holds one sub-directory per character containing ``*.png``
    images. Images are numbered consecutively from 1 across all classes. The
    label of each image is its directory name, and the final total is stored
    under ``num-samples``.

    Raises:
        ValueError: if ``characters[left:right]`` selects no classes.
    """
    entries = characters[left:right]
    if not entries:
        raise ValueError(f'no classes in characters[{left}:{right}]')

    # Glob every class once; the lists are used both to size the map and to
    # fill it. The old estimate extrapolated from the first image of the first
    # class. It could under-size the map (lmdb.MapFullError) when later classes
    # held more or larger images, and it crashed when that class was empty.
    class_images = [
        (name, glob.glob(os.path.join(directory, name, '*.png')))
        for name in entries
    ]
    total_byte = _estimate_map_size(paths for _, paths in class_images)

    os.makedirs(out, exist_ok=True)
    env = lmdb.open(out, map_size=total_byte)
    n_samples = 0
    try:
        pbar = tqdm(class_images)
        for label, image_path in pbar:
            n_samples += create_dataset(env, image_path, label, n_samples)
            pbar.set_description(
                f'character[{left + 1}:{right}]: {label} | nSamples: {n_samples} | total_byte: {total_byte}byte | progressing')
        write_cache(env, {'num-samples': str(n_samples)})
    finally:
        # Always flush and release the environment, even if a write fails.
        env.close()


def begin(mode, left, right, valid=False):
    """Build or inspect the LMDB database for the classes ``characters[left:right]``.

    ``mode='train'`` reads images from ``train_path`` and targets
    ``<output_path>/train_<right>``. ``mode='test'`` reads ``val_path`` and
    targets ``<output_path>/test_<right - left>``. With ``valid`` true, the
    existing database is printed via ``lmdb_test`` instead of being built.
    Any other ``mode`` does nothing.
    """
    if mode == 'train':
        source_dir, suffix = train_path, right
    elif mode == 'test':
        source_dir, suffix = val_path, right - left
    else:
        return

    path = os.path.join(output_path, pathlib.Path(f'{mode}_{suffix}'))
    if valid:
        print(f"show:{valid},path:{path}")
        lmdb_test(path)
    else:
        lmdb_init(source_dir, path, left=left, right=right)


class MyThread(threading.Thread):
    """Worker thread that runs ``begin`` for one (mode, class range) job.

    The constructor only records its arguments. ``start()`` executes ``run``
    in the new thread, which forwards them to ``begin``.
    """

    def __init__(self, mode, left, right, valid):
        super().__init__()
        self.mode, self.left, self.right, self.valid = mode, left, right, valid

    def run(self):
        """Build (or, when ``valid`` is true, inspect) the configured database."""
        begin(mode=self.mode, left=self.left, right=self.right, valid=self.valid)


if __name__ == '__main__':
    # Databases and their class ranges over `characters`
    # (1-based inclusive = 0-based half-open):
    #   train_500:  classes [1, 500]     = [0, 500)
    #   train_1000: classes [501, 1000]  = [500, 1000)
    #   train_1500: classes [1001, 1500] = [1000, 1500)
    #   train_2000: classes [1501, 2000] = [1500, 2000)
    #   train_2755: classes [2001, 2755] = [2000, 2755)
    #   train_3755: classes [2756, 3755] = [2755, 3755)
    #               (not built by --all; use --train --start 2755 --end 3755)
    #   test_1000:  classes [2756, 3755] = [2755, 3755)
    parser = argparse.ArgumentParser()

    parser.add_argument("--train", action="store_true", help="generate train lmdb")
    parser.add_argument("--test", action="store_true", help="generate test lmdb")
    parser.add_argument("--all", action="store_true", help="generate all lmdb")
    parser.add_argument("--show", action="store_true", help="show result")
    parser.add_argument("--start", type=int, default=0, help="class start from where,default 0")
    parser.add_argument("--end", type=int, default=3755, help="class end from where,default 3755")

    args = parser.parse_args()

    train = args.train
    test = args.test
    build_all = args.all
    start = args.start
    end = args.end
    show = args.show

    if train:
        print(f"args: mode=train, [start:end)=[{start}:{end})")
        begin(mode='train', left=start, right=end, valid=show)
    if test:
        print(f"args: mode=test, [start:end)=[{start}:{end})")
        begin(mode='test', left=start, right=end, valid=show)
    if build_all:
        # (mode, left, right) for every database built by --all; matches the
        # table above (five train chunks followed by the test split).
        plan = [
            ('train', 0, 500),
            ('train', 500, 1000),
            ('train', 1000, 1500),
            ('train', 1500, 2000),
            ('train', 2000, 2755),
            ('test', 2755, 3755),
        ]
        threads = []
        for mode, left, right in plan:
            if show:
                # In --show mode the databases are inspected one after another
                # in this thread.
                begin(mode=mode, left=left, right=right, valid=show)
            else:
                # One builder thread per database; each (mode, range) maps to
                # its own output directory in begin(), so threads don't share an env.
                thread = MyThread(mode=mode, left=left, right=right, valid=show)
                threads.append(thread)
                thread.start()

        for t in threads:
            t.join()
@yeahQing yeahQing closed this as completed Sep 8, 2022
@yeahQing yeahQing changed the title 为什么Embeddings要把嵌入维度设置成word_n_class? lmdb格式数据集转换代码可以提供一下嘛? Sep 8, 2022
@yeahQing
Copy link
Author

yeahQing commented Sep 8, 2022

这个是我写的一个根据图片生成lmdb格式数据集的方法

@cptbtptp125
Copy link

cptbtptp125 commented Nov 9, 2022

非常感谢,我明白您的代码了

1 similar comment
@cptbtptp125
Copy link

非常感谢,我明白您的代码了

@cptbtptp125
Copy link

cptbtptp125 commented Nov 12, 2022

您好,我尝试用生成的lmdb格式的数据集,放入源码中运行,但是一直显示这个错误,请问是由于我原本的数据集标签是数字的缘故吗?如果您能帮助我,我将不胜感激。
File "/tmp/pycharm_project_671/data/lmdbReader.py", line 69, in __getitem__
return self[random.randint(0, len(self) - 1)]
File "/tmp/pycharm_project_671/data/lmdbReader.py", line 69, in __getitem__
return self[random.randint(0, len(self) - 1)]
File "/tmp/pycharm_project_671/data/lmdbReader.py", line 69, in __getitem__
return self[random.randint(0, len(self) - 1)]
[Previous line repeated 317 more times]
File "/tmp/pycharm_project_671/data/lmdbReader.py", line 57, in __getitem__
img = Image.open(buf)
File "/root/miniconda3/envs/py36/lib/python3.6/site-packages/PIL/Image.py", line 2953, in open
im = _open_core(fp, filename, prefix, formats)
File "/root/miniconda3/envs/py36/lib/python3.6/site-packages/PIL/Image.py", line 2939, in _open_core
im = factory(fp, filename)
File "/root/miniconda3/envs/py36/lib/python3.6/site-packages/PIL/ImageFile.py", line 121, in __init__
self._open()
File "/root/miniconda3/envs/py36/lib/python3.6/site-packages/PIL/PngImagePlugin.py", line 684, in _open
self.png = PngStream(self.fp)
File "/root/miniconda3/envs/py36/lib/python3.6/site-packages/PIL/PngImagePlugin.py", line 345, in __init__
super().__init__(fp)
RecursionError: maximum recursion depth exceeded while calling a Python object

@yeahQing
Copy link
Author

您好,我尝试用生成的lmdb格式的数据集,放入源码中运行,但是一直显示这个错误,请问是由于我原本的数据集标签是数字的缘故吗?如果您能帮助我,我将不胜感激。 File "/tmp/pycharm_project_671/data/lmdbReader.py", line 69, in __getitem__ return self[random.randint(0, len(self) - 1)] File "/tmp/pycharm_project_671/data/lmdbReader.py", line 69, in __getitem__ return self[random.randint(0, len(self) - 1)] File "/tmp/pycharm_project_671/data/lmdbReader.py", line 69, in __getitem__ return self[random.randint(0, len(self) - 1)] [Previous line repeated 317 more times] File "/tmp/pycharm_project_671/data/lmdbReader.py", line 57, in __getitem__ img = Image.open(buf) File "/root/miniconda3/envs/py36/lib/python3.6/site-packages/PIL/Image.py", line 2953, in open im = _open_core(fp, filename, prefix, formats) File "/root/miniconda3/envs/py36/lib/python3.6/site-packages/PIL/Image.py", line 2939, in _open_core im = factory(fp, filename) File "/root/miniconda3/envs/py36/lib/python3.6/site-packages/PIL/ImageFile.py", line 121, in __init__ self._open() File "/root/miniconda3/envs/py36/lib/python3.6/site-packages/PIL/PngImagePlugin.py", line 684, in _open self.png = PngStream(self.fp) File "/root/miniconda3/envs/py36/lib/python3.6/site-packages/PIL/PngImagePlugin.py", line 345, in __init__ super().__init__(fp) RecursionError: maximum recursion depth exceeded while calling a Python object

标签是汉字字符

@bad-meets-joke
Copy link

这个是我写的一个根据图片生成lmdb格式数据集的方法

请问你是先把原gnt格式的数据都转成一张张图片,然后再做成lmdb数据集吗?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants