# 训练数据整理

### 1. 安装OpenXLab提供的Python库
```shell
pip install openxlab
```
访问该[链接](https://opendatalab.com)前往官网注册账号并记录下`Access Key`与`Secret Key`，并替换如下代码中相应内容。

### 2. 依次下载`Objects365`, `Flickr30k`和`GQA`数据集
如遇卡顿，可中断该进程后重新运行（下载过程使用缓存，可直接恢复下载进度）。

In [None]:
import openxlab
from openxlab.dataset import info
from openxlab.dataset import get
from openxlab.dataset import download
import os
import hashlib

# OpenXLab credentials (register at https://opendatalab.com).
# Prefer setting the OPENXLAB_AK / OPENXLAB_SK environment variables so real
# keys are never saved into the notebook; otherwise replace the placeholders.
access_key = os.environ.get('OPENXLAB_AK', '<Access Key>')
secret_key = os.environ.get('OPENXLAB_SK', '<Secret Key>')
openxlab.login(ak=access_key, sk=secret_key)


def check(file_path: str, checksum: str, chunk_size: int = 8192):
    """Verify that ``file_path`` exists and its MD5 hex digest equals ``checksum``.

    Args:
        file_path: File to hash.
        checksum: Expected MD5 hex digest (case-insensitive).
        chunk_size: Bytes read per iteration, so large archives are hashed
            without loading them into memory.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If the computed digest differs from ``checksum``.
    """
    # Explicit raises instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(file_path):
        raise FileNotFoundError(file_path)
    hasher = hashlib.md5()
    with open(file_path, 'rb') as fp:
        while chunk := fp.read(chunk_size):
            hasher.update(chunk)
    digest = hasher.hexdigest()
    if digest != checksum.lower():
        raise ValueError(
            f'MD5 mismatch for {file_path}: expected {checksum}, got {digest}')


def get_dataset(dataset: str, target_path: str):
    """Download one raw dataset archive from OpenXLab and verify its MD5.

    Args:
        dataset: One of ``'obj365'``, ``'flickr'`` or ``'gqa'``.
        target_path: Local directory handed to ``openxlab.dataset.download``;
            created if missing.

    Raises:
        NotImplementedError: If ``dataset`` is not one of the names above.
    """
    # name -> (OpenXLab repo, archive path inside the repo, expected MD5)
    specs = {
        'obj365': ('OpenDataLab/Objects365_v1', '/raw/Objects365_v1.tar.gz',
                   'ca354f9e6e33f99ac16cbac900c692f2'),
        'flickr': ('OpenDataLab/Flickr_Image', '/raw/archive.zip',
                   '0acd2ed7099c62c9d0ce77941391ee2b'),
        'gqa': ('OpenDataLab/GQA', '/raw/images.zip',
                'ce0e89c03830722434d7f20a41b05342'),
    }
    if dataset not in specs:
        raise NotImplementedError(dataset)
    dataset_repo, source_path, checksum = specs[dataset]

    os.makedirs(target_path, exist_ok=True)

    info(dataset_repo=dataset_repo)
    download(dataset_repo=dataset_repo,
             source_path=source_path,
             target_path=target_path)
    # The archive lands at <target_path>/<owner>___<repo>/raw/<file>.
    # source_path is absolute ('/raw/...'), and os.path.join drops every
    # component preceding an absolute one, so strip the leading '/' first.
    check(
        os.path.join(target_path, dataset_repo.replace('/', '___'),
                     source_path.lstrip('/')), checksum)

In [None]:
# Download and MD5-verify every raw archive into data/<dataset>.
for _name in ('obj365', 'flickr', 'gqa'):
    get_dataset(dataset=_name, target_path=f'data/{_name}')

### 3. 依次提取`Objects365`, `Flickr30k`和`GQA`数据集中的图片

In [None]:
import os
import tarfile
import zipfile
from tqdm import tqdm
import shutil


def _extract_flat(zip_ref, members, target_path, desc):
    """Copy each zip member into ``target_path`` under its bare file name.

    Streaming through ``zip_ref.open`` writes every file straight to its final
    flat location. No intermediate archive directories are created, so no
    move/cleanup pass is needed. Because only the base name is used, member
    names can never place files outside ``target_path``.
    """
    with tqdm(total=len(members), desc=desc, ncols=150, unit='file') as pbar:
        for member in members:
            dst_path = os.path.join(target_path,
                                    os.path.basename(member.filename))
            with zip_ref.open(member) as src, open(dst_path, 'wb') as dst:
                shutil.copyfileobj(src, dst)
            pbar.update(1)


def get_images(dataset: str, target_path: str):
    """Extract a dataset's training images into one flat directory.

    Args:
        dataset: ``'obj365'``, ``'flickr'`` or ``'gqa'``; the matching raw
            archive must already be downloaded under ``./data``.
        target_path: Output directory (created if missing). Every image is
            written directly inside it under its file name.

    Raises:
        NotImplementedError: For any other ``dataset``.
    """
    if dataset not in ('obj365', 'flickr', 'gqa'):
        raise NotImplementedError(dataset)
    os.makedirs(target_path, exist_ok=True)

    if dataset == 'obj365':
        # The Objects365 tarball wraps its training images in nested zips.
        tar_gz_path = './data/obj365/OpenDataLab___Objects365_v1/raw/Objects365_v1.tar.gz'
        with tarfile.open(tar_gz_path, 'r:gz') as tar_ref:
            for member in tar_ref.getmembers():
                if not (member.isfile() and 'train' in member.name
                        and member.name.endswith('.zip')):
                    continue
                with tar_ref.extractfile(member) as zip_file, \
                        zipfile.ZipFile(zip_file) as zip_ref:
                    # Directory entries have no payload to copy.
                    files = [
                        info for info in zip_ref.infolist() if not info.is_dir()
                    ]
                    _extract_flat(
                        zip_ref, files, target_path,
                        f'unzip {member.name}/train/*.jpg to {target_path}')
        return

    if dataset == 'flickr':
        zip_path = './data/flickr/OpenDataLab___Flickr_Image/raw/archive.zip'
        flag = 'flickr30k_images/flickr30k_images/flickr30k_images'
        desc = f'unzip {zip_path}/{flag}/*.jpg to {target_path}'
    else:  # gqa
        zip_path = './data/gqa/OpenDataLab___GQA/raw/images.zip'
        flag = 'images'
        desc = f'unzip {zip_path} to {target_path}'

    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        # Keep only the .jpg files sitting directly inside the `flag` folder.
        members = [
            member for member in zip_ref.infolist()
            if os.path.dirname(member.filename) == flag
            and member.filename.endswith('.jpg')
        ]
        _extract_flat(zip_ref, members, target_path, desc)

In [None]:
# Flatten every dataset's images into ./data/<dataset>/images.
for _name in ('obj365', 'flickr', 'gqa'):
    get_images(dataset=_name, target_path=f'./data/{_name}/images')

In [None]:
# Expected number of extracted images per dataset; a mismatch suggests that a
# download or extraction step did not complete.
expected_image_counts = {
    './data/obj365/images': 608606,
    './data/flickr/images': 31784,
    './data/gqa/images': 148854,
}
for image_dir, expected in expected_image_counts.items():
    found = len(os.listdir(image_dir))
    assert found == expected, (
        f'{image_dir}: expected {expected} files, found {found}')

### 4. 依次获取`Objects365`, `Flickr30k`和`GQA`数据集的标注文件

`Objects365`的标注文件直接从数据集压缩包中提取。

`Flickr30k`和`GQA`的标注文件需要从Hugging Face的GLIP仓库中下载。

#### 下载方式一

```shell
mkdir -p ./data/flickr/annotations ./data/gqa/annotations

pushd .
cd ./data/flickr/annotations
wget https://huggingface.co/GLIPModel/GLIP/resolve/main/mdetr_annotations/final_flickr_separateGT_train.json
popd

pushd .
cd ./data/gqa/annotations
wget https://huggingface.co/GLIPModel/GLIP/resolve/main/mdetr_annotations/final_mixed_train_no_coco.json
popd
```
#### 下载方式二

访问[此链接](https://huggingface.co/GLIPModel/GLIP/tree/main/mdetr_annotations)，下载`final_flickr_separateGT_train.json`与`final_mixed_train_no_coco.json`，并分别放入`./data/flickr/annotations`与`./data/gqa/annotations`目录。

In [None]:
def get_annotations(dataset: str, target_path: str):
    """Extract a dataset's training annotation file into ``target_path``.

    Only ``'obj365'`` is supported. ``objects365_train.json`` is read out of
    the downloaded ``Objects365_v1.tar.gz`` and written to
    ``<target_path>/objects365_train.json``. The Flickr30k and GQA
    annotations are downloaded separately.

    Args:
        dataset: Dataset name; must be ``'obj365'``.
        target_path: Output directory (created if missing).

    Raises:
        NotImplementedError: For any other ``dataset``.
    """
    if dataset != 'obj365':
        raise NotImplementedError(dataset)
    os.makedirs(target_path, exist_ok=True)

    tar_gz_path = './data/obj365/OpenDataLab___Objects365_v1/raw/Objects365_v1.tar.gz'
    with tarfile.open(tar_gz_path, 'r:gz') as tar_ref:
        for member in tar_ref.getmembers():
            if not member.isfile():
                continue
            file_name = os.path.basename(member.name)
            if file_name != 'objects365_train.json':
                continue
            # Stream the member straight to its final flat location. This
            # avoids recreating the archive's directories and deleting them
            # again. Tar member names always use '/', so os.sep-based path
            # surgery would be wrong on Windows.
            with tar_ref.extractfile(member) as src, \
                    open(os.path.join(target_path, file_name), 'wb') as dst:
                shutil.copyfileobj(src, dst)

In [None]:
# Pull objects365_train.json out of the Objects365 tarball.
get_annotations('obj365', './data/obj365/annotations')

### 5. 清理`Objects365`, `Flickr30k`和`GQA`数据集的原始文件(可选)
清理前请确保以上步骤已全部正确执行！！！

In [None]:
# Delete the downloaded raw archives. Only run this after the extraction
# steps above have all succeeded.
for _raw_dir in ('./data/obj365/OpenDataLab___Objects365_v1',
                 './data/flickr/OpenDataLab___Flickr_Image',
                 './data/gqa/OpenDataLab___GQA'):
    shutil.rmtree(_raw_dir)

### 6. 打印数据集目录结构
正确的目录结构如下：
```txt
data
├── obj365
│   ├── images
│   │   ├── obj365_train_000000000012.jpg
│   │   ├── obj365_train_000000000036.jpg
│   │   ├── obj365_train_000000000072.jpg
│   │   ├── ... 608603 more files
│   └── annotations
│       └── objects365_train.json
├── flickr
│   ├── images
│   │   ├── 1000092795.jpg
│   │   ├── 10002456.jpg
│   │   ├── 1000268201.jpg
│   │   ├── ... 31780 more files
│   └── annotations
│       └── final_flickr_separateGT_train.json
└── gqa
    ├── images
    │   ├── 2317659.jpg
    │   ├── 2325293.jpg
    │   ├── n324162.jpg
    │   ├── ... 148851 more files
    └── annotations
        └── final_mixed_train_no_coco.json
```

In [None]:
import os


def gen_dir_tree(target_path, prefix='', max_files=3):
    """Render ``target_path`` and its contents as an ASCII tree string.

    Args:
        target_path: Directory (or file) to render.
        prefix: Connector for this entry. Callers normally leave it empty;
            recursive calls pass the parent's indentation followed by
            ``'├── '`` or ``'└── '``.
        max_files: Maximum number of *files* listed per directory. The
            remaining files are summarised as ``... N more files``.
            Sub-directories are always listed and do not count against it.

    Returns:
        The tree as a newline-terminated string.
    """
    # normpath so a trailing separator ('./data/') still yields a label.
    tree = prefix + os.path.basename(os.path.normpath(target_path)) + '\n'
    # Children draw a vertical bar under a '├──' parent, blanks under '└──'.
    prefix = prefix.replace('├── ', '│   ').replace('└── ', '    ')

    if not os.path.isdir(target_path):
        return tree

    # Directories first, then files; names sorted so the output is stable.
    items = sorted(
        os.listdir(target_path),
        key=lambda x: (os.path.isfile(os.path.join(target_path, x)), x))
    shown_files = 0
    for index, item in enumerate(items):
        item_path = os.path.join(target_path, item)
        if os.path.isfile(item_path):
            if shown_files >= max_files:
                remaining_files = len(items) - index
                tree += prefix + '├── ... ' + str(
                    remaining_files) + ' more files\n'
                break
            shown_files += 1
        connector = '└── ' if index == len(items) - 1 else '├── '
        tree += gen_dir_tree(item_path, prefix + connector, max_files)
    return tree


def print_dir_tree(target_path, max_files=3):
    """Print the ASCII tree produced by ``gen_dir_tree`` for ``target_path``.

    ``max_files`` is forwarded unchanged to ``gen_dir_tree``.
    """
    print(gen_dir_tree(target_path, max_files=max_files))


# Compare the printed layout with the expected structure shown above.
print_dir_tree(target_path='./data')

In [None]:
from pycocotools.coco import COCO  # type: ignore
from tqdm import tqdm  # type: ignore

# Quick check that the Objects365 annotation file parses with pycocotools.
# Step 7 loads it again, so this cell may be skipped to save time and memory.
obj365_ann_file = './data/obj365/annotations/objects365_train.json'
obj365_coco = COCO(annotation_file=obj365_ann_file)
obj365_img_ids = obj365_coco.getImgIds()

### 7. 依次将`Objects365`, `Flickr30k`和`GQA`数据集中的有效标注信息写入`LMDB`文件中

In [None]:
import re
from pycocotools.coco import COCO
from tqdm import tqdm
import os
import numpy as np
import cv2
import lmdb
import pickle
import concurrent.futures
import threading

# Annotation files: Objects365 extracted in step 4, Flickr30k / GQA downloaded
# from the GLIP repository.
obj365_ann_file = "./data/obj365/annotations/objects365_train.json"
flickr_ann_file = "./data/flickr/annotations/final_flickr_separateGT_train.json"
gqa_ann_file = "./data/gqa/annotations/final_mixed_train_no_coco.json"

# Load the three annotation sets (in this order), then collect their image ids.
obj365_coco, flickr_coco, gqa_coco = (
    COCO(ann_file)
    for ann_file in (obj365_ann_file, flickr_ann_file, gqa_ann_file))
obj365_img_ids, flickr_img_ids, gqa_img_ids = (
    coco.getImgIds() for coco in (obj365_coco, flickr_coco, gqa_coco))

# Objects365 class vocabulary ordered by category id; each category name is
# split on '/' into a list of strings.
obj365_cat_ids = sorted(obj365_coco.getCatIds())
obj365_cat_names = [
    cat["name"] for cat in obj365_coco.loadCats(obj365_cat_ids)
]
obj365_texts = [cat_name.split("/") for cat_name in obj365_cat_names]

# Output LMDB; map_size (1 TiB) is the maximum size the database may grow to.
env = lmdb.open("data/obj365_goldg_lmdb", map_size=1024**4)

# Flattened image folders produced in step 3.
obj365_img_dir = "data/obj365/images"
flickr_img_dir = "data/flickr/images"
gqa_img_dir = "data/gqa/images"

eps = 1e-5  # boxes whose width or height is <= eps are dropped
worker_num = 8  # thread-pool size for the conversion

# Running LMDB key counter, shared by writer calls under this lock.
new_img_id_lock = threading.Lock()
new_img_id = 0

In [None]:
def process_obj365(img_id, img_dir, coco_inst):
    """Build an LMDB record for one Objects365 image, or return None to skip it.

    Args:
        img_id: Image id in ``coco_inst``.
        img_dir: Directory holding the flattened image files.
        coco_inst: ``pycocotools.coco.COCO`` instance for the annotations.

    Returns:
        A dict with ``im_file`` (path relative to ``./data``), ``h``, ``w``,
        ``texts`` (the global ``obj365_texts`` vocabulary), ``gt_class``
        (N x 1 int32), ``gt_bbox`` (N x 4 float32, x1y1x2y2) and
        ``is_crowd`` (N x 1 int32). Returns ``None`` when no box is usable,
        or when the image is missing, unreadable, or its size disagrees with
        the annotation.
    """
    img_info = coco_inst.loadImgs([img_id])[0]
    ann_info = coco_inst.loadAnns(coco_inst.getAnnIds(imgIds=[img_id]))

    # Filter boxes first so images without any usable box are never decoded.
    valid_anns = []
    for ann in ann_info:
        if "bbox" not in ann or not any(np.array(ann["bbox"])):
            continue
        if ann["bbox"][2] <= eps or ann["bbox"][3] <= eps:
            continue
        if ann["area"] <= 0:
            continue
        valid_anns.append(ann)
    if not valid_anns:
        return None

    img_path = os.path.join(img_dir, img_info["file_name"])
    if not os.path.exists(img_path):
        return None
    try:
        im = cv2.imread(img_path, cv2.IMREAD_COLOR)
    except Exception:
        return None
    # cv2.imread reports unreadable/corrupt files by returning None, not by
    # raising; without this check `im.shape` would abort the whole conversion.
    if im is None:
        return None
    im_h, im_w = im.shape[:2]
    del im
    if im_h != float(img_info["height"]) or im_w != float(img_info["width"]):
        return None

    rec = {
        "im_file": os.path.relpath(img_path, "./data"),
        "h": im_h,
        "w": im_w,
        "texts": obj365_texts,
    }

    valid_num = len(valid_anns)
    gt_bbox = np.zeros((valid_num, 4), dtype=np.float32)
    gt_class = np.zeros((valid_num, 1), dtype=np.int32)
    is_crowd = np.zeros((valid_num, 1), dtype=np.int32)
    for i, ann in enumerate(valid_anns):
        # NOTE(review): assumes category ids run contiguously from 1 so that
        # id - 1 indexes obj365_texts (built from the sorted ids) — confirm.
        gt_class[i, 0] = ann["category_id"] - 1
        x1, y1, box_w, box_h = ann["bbox"]
        gt_bbox[i, :] = [x1, y1, x1 + box_w, y1 + box_h]  # xywh -> xyxy
        is_crowd[i, 0] = ann.get("iscrowd", 0)

    rec.update({
        "gt_class": gt_class,
        "gt_bbox": gt_bbox,
        "is_crowd": is_crowd,
    })
    return rec


# Caption spans accepted as class phrases: only ASCII letters, digits and
# whitespace, with at least one letter.
_FLICKR_PHRASE_PATTERN = re.compile(r"^(?=.*[a-zA-Z])[a-zA-Z\d\s]+$")


def process_flickr(img_id, img_dir, coco_inst):
    """Build an LMDB record for one caption-grounded image, or return None.

    Each annotation's ``tokens_positive`` spans are cut out of the image
    caption to form its class phrase. The distinct phrases of the image form
    ``texts``, and every box is labelled with the index of its own phrase.

    Args:
        img_id: Image id in ``coco_inst``.
        img_dir: Directory holding the flattened image files.
        coco_inst: ``pycocotools.coco.COCO`` instance whose images carry a
            ``caption`` and whose annotations carry ``tokens_positive``.

    Returns:
        A dict with ``im_file`` (path relative to ``./data``), ``h``, ``w``,
        ``texts`` (one list per phrase, split on ``'/'``), ``gt_class``
        (N x 1 int32), ``gt_bbox`` (N x 4 float32, x1y1x2y2) and
        ``is_crowd`` (N x 1 int32). Returns ``None`` when no box has a
        usable phrase, or the image is missing, unreadable or mis-sized.
    """
    img_info = coco_inst.loadImgs([img_id])[0]
    ann_info = coco_inst.loadAnns(coco_inst.getAnnIds(imgIds=[img_id]))

    ann_info = [ann for ann in ann_info if len(ann["tokens_positive"]) > 0]
    if not ann_info:
        return None
    # Order boxes by the sum of their first span's start and end offsets,
    # i.e. roughly by where they are mentioned in the caption.
    ann_info = sorted(ann_info, key=lambda a: sum(a["tokens_positive"][0]))

    caption = img_info["caption"]
    cat2id = {}
    texts = []
    # (annotation, index of its phrase in `texts`). Kept locally instead of
    # being written back into the annotation dicts owned by coco_inst.
    valid_anns = []
    for ann in ann_info:
        if ann["bbox"][2] <= eps or ann["bbox"][3] <= eps:
            continue
        cat_names = []
        end_idx = []
        for t in ann["tokens_positive"]:
            cat_name = caption[t[0]:t[1]]
            if not _FLICKR_PHRASE_PATTERN.match(cat_name):
                continue
            # A span starting one character after the previous kept span
            # (e.g. after a single space) continues the same phrase.
            if end_idx and t[0] == end_idx[-1] + 1:
                cat_names[-1] = cat_names[-1] + " " + cat_name
            else:
                cat_names.append(cat_name)
            end_idx.append(t[1])
        if not cat_names:
            continue
        cat_name = "/".join(cat_names).lower()
        if cat_name not in cat2id:
            cat2id[cat_name] = len(cat2id)
            texts.append(cat_name)
        # Label the box with its own phrase id. Using len(cat2id) here would
        # point at the most recently added phrase whenever cat_name repeats.
        valid_anns.append((ann, cat2id[cat_name]))

    if not valid_anns:
        return None

    img_path = os.path.join(img_dir, img_info["file_name"])
    if not os.path.exists(img_path):
        return None
    try:
        im = cv2.imread(img_path, cv2.IMREAD_COLOR)
    except Exception:
        return None
    # cv2.imread reports unreadable/corrupt files by returning None, not by
    # raising; without this check `im.shape` would abort the whole conversion.
    if im is None:
        return None
    im_h, im_w = im.shape[:2]
    del im
    if im_h != float(img_info["height"]) or im_w != float(img_info["width"]):
        return None

    rec = {
        "im_file": os.path.relpath(img_path, "./data"),
        "h": im_h,
        "w": im_w,
        "texts": [text.split("/") for text in texts],
    }

    valid_num = len(valid_anns)
    gt_bbox = np.zeros((valid_num, 4), dtype=np.float32)
    gt_class = np.zeros((valid_num, 1), dtype=np.int32)
    is_crowd = np.zeros((valid_num, 1), dtype=np.int32)
    for i, (ann, class_idx) in enumerate(valid_anns):
        gt_class[i, 0] = class_idx
        x1, y1, box_w, box_h = ann["bbox"]
        gt_bbox[i, :] = [x1, y1, x1 + box_w, y1 + box_h]  # xywh -> xyxy
        is_crowd[i, 0] = ann.get("iscrowd", 0)

    rec.update({
        "gt_class": gt_class,
        "gt_bbox": gt_bbox,
        "is_crowd": is_crowd,
    })
    return rec


def process_gqa(img_id, img_dir, coco_inst):
    """Build an LMDB record for one GQA image, or return None to skip it.

    GQA annotations are processed exactly like the Flickr30k ones, so this
    delegates to ``process_flickr`` with the same arguments.
    """
    return process_flickr(img_id=img_id, img_dir=img_dir, coco_inst=coco_inst)


def write_to_lmdb(rec):
    """Store ``rec`` in the global LMDB ``env`` under the next sequential key.

    ``None`` records (skipped images) are ignored. The key is the decimal
    string of the global ``new_img_id``, which is also stored in the record
    as ``rec["im_id"]`` before pickling, and the counter is then advanced
    while ``new_img_id_lock`` is held.
    """
    global new_img_id
    if rec is None:
        return
    with env.begin(write=True) as txn, new_img_id_lock:
        rec["im_id"] = np.array([new_img_id])
        txn.put(f"{new_img_id}".encode(), pickle.dumps(rec))
        new_img_id += 1

In [None]:
# (progress label, record builder, image dir, COCO instance, image ids)
conversion_jobs = [
    ("obj365", process_obj365, obj365_img_dir, obj365_coco, obj365_img_ids),
    ("flickr", process_flickr, flickr_img_dir, flickr_coco, flickr_img_ids),
    ("gqa", process_gqa, gqa_img_dir, gqa_coco, gqa_img_ids),
]

try:
    for job_name, process_fn, img_dir, coco_inst, img_ids in conversion_jobs:
        with concurrent.futures.ThreadPoolExecutor(
                max_workers=worker_num) as executor:
            futures = [
                executor.submit(process_fn, img_id, img_dir, coco_inst)
                for img_id in img_ids
            ]
            for future in tqdm(
                    concurrent.futures.as_completed(futures),
                    total=len(futures),
                    desc=f"Convert {job_name}",
            ):
                # Write from this thread. Write transactions are serialized
                # anyway, and any error from processing or writing now
                # surfaces here instead of vanishing in an unchecked future.
                write_to_lmdb(future.result())
finally:
    env.close()

### 8. 可视化`LMDB`文件中的标注信息

In [None]:
import cv2
import lmdb
import pickle
import random
import PIL
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
import os
from tqdm import tqdm


def imagedraw_textsize_c(draw, text, font=None):
    """Return ``(width, height)`` of ``text`` as ``draw`` would render it.

    For Pillow 10 and newer the size is derived from the ``textbbox`` anchored
    at the origin. Older versions use ``ImageDraw.textsize`` directly.
    """
    pillow_major = int(PIL.__version__.split('.')[0])
    if pillow_major >= 10:
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
        width, height = right - left, bottom - top
    else:
        width, height = draw.textsize(text, font=font)
    return width, height


# Draw the boxes and first class name of up to 100 random LMDB records and
# save the results to ./tmp for visual inspection.
vis_dir = "./tmp"
os.makedirs(vis_dir, exist_ok=True)  # img.save() does not create directories

env = lmdb.open("data/obj365_goldg_lmdb", map_size=1024**4, readonly=True)
try:
    with env.begin(write=False) as txn:
        total_num = txn.stat()['entries']
        # write_to_lmdb stores keys "0" .. str(total_num - 1), assuming no
        # gaps. Sample distinct ids strictly below total_num;
        # randint(0, total_num) is inclusive and could pick a missing key.
        sample_ids = random.sample(range(total_num), min(100, total_num))
        for img_id in tqdm(sample_ids):
            raw = txn.get(f"{img_id}".encode())
            if raw is None:
                continue
            rec = pickle.loads(raw)
            img_path = os.path.join('./data', rec['im_file'])
            img = cv2.imread(img_path, cv2.IMREAD_COLOR)
            if img is None:  # cv2.imread returns None for unreadable files
                continue
            img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            draw = ImageDraw.Draw(img)

            texts = rec['texts']
            for bbox, cls in zip(rec['gt_bbox'], rec['gt_class']):
                xmin, ymin, xmax, ymax = map(int, bbox.tolist())
                draw.line([(xmin, ymin), (xmin, ymax), (xmax, ymax),
                           (xmax, ymin), (xmin, ymin)],
                          width=2,
                          fill=(255, 0, 0))
                # Label with the first name variant of the box's class.
                text = texts[cls.item()][0]
                tw, th = imagedraw_textsize_c(draw, text)
                draw.rectangle([(xmin + 1, ymin - th), (xmin + tw + 1, ymin)],
                               fill=(255, 0, 0))
                draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
            img.save(os.path.join(vis_dir, f"{img_id}.jpg"))
finally:
    env.close()