In [107]:
import json
import os
import pathlib
import sklearn
import random

from jupyterlab.semver import test_set

In [108]:
# Location of the RSITMD annotation file (absolute path on the author's machine).
json_dir = '/home/dhm04/PycharmProjects/RSCR-baseline/dataset/RSITMD/dataset_RSITMD.json'
# Read with an explicit encoding so parsing does not depend on the platform locale;
# utf-8 is also what the split files are written with at the end of the notebook.
with open(json_dir, 'r', encoding='utf-8') as f:
    data = json.load(f)

In [109]:
# Put the images in a canonical order first, then shuffle with a fixed seed, so that
# re-running this cell (or the whole notebook) always yields the same permutation.
# A private Random instance is used so the global random state is left untouched.
data['images'].sort(key=lambda item: item['imgid'])
random.Random(42).shuffle(data['images'])

# Keys of one image entry and the total number of images (last expression -> displayed).
data['images'][0].keys(), len(data['images'])

In [110]:
# Number of images per split: every entry whose split is not 'train' counts as test.
cnt_train = sum(1 for entry in data['images'] if entry['split'] == 'train')
cnt_test = len(data['images']) - cnt_train

# The dataset split follows the AMNFN paper.
cnt_train, cnt_test, int(cnt_train * 0.8)  # the 3432 here matches AMNFN

(4291, 452, 3432)

In [111]:
# Peek at one entry (index 94): its file name and image id.
sample_entry = data["images"][94]
print(f'image: {sample_entry["filename"]}')
print(f'image_id: {sample_entry["imgid"]}')

image: airport_567.tif
image_id: 124


In [112]:
# Print the first five captions of the first image, each next to the label at the same index.
item0 = data['images'][0]
for idx in range(5):
    print(f'sentence: {item0["sentences"][idx]["raw"]}, corrosponding label: {item0["labels"][idx]}')

sentence: Four resorts next to the swimming pool, corrosponding label: swimming pools
sentence: There are three swimming pools next to a resort, corrosponding label: - swimming
sentence: A resort has three pools next to it, corrosponding label: - beach
sentence: Some buildings with swimming pools and some green plants are near the beach., corrosponding label: green plants
sentence: There are several buildings with swimming pools and some green plants near the beach., corrosponding label: a resort


In [113]:
# 检查所有类型的标签
all_cls = set()
for item in data['images']:
    all_cls.add(item['filename'].split('_')[0])
len(all_cls)

33

In [114]:
from tqdm import tqdm
import numpy as np
import pprint
from sklearn.model_selection import train_test_split
def get_word_to_idx_dict(data):
    """Map every class name found in the annotations to an integer index.

    The class name of an image is the part of its ``filename`` before the
    first underscore. Names are sorted alphabetically before being numbered,
    so the mapping does not depend on the order of ``data['images']``.

    Args:
        data: dict with an ``'images'`` list; each entry has a ``'filename'`` string.

    Returns:
        dict mapping class name -> index in ``0 .. n_classes - 1``.
    """
    class_names = {item['filename'].split('_')[0] for item in data['images']}
    return {name: idx for idx, name in enumerate(sorted(class_names))}

def create_banlanced_dataset(data, stratified=False, train_percent=0.8, random_state=42):
    '''
    Build caption-level train / val / test record lists from the annotation dict.

    Every image contributes one record per caption. Images whose ``split`` is
    ``'train'`` go into a pool that is divided into train and val with
    ``train_test_split``; all other images form the test set unchanged.

    The original paper only uses plain random sampling and does not look at the
    per-class sample counts. Both variants are available here:

    * ``stratified=False`` - plain random split (the paper's behaviour).
    * ``stratified=True``  - random split stratified on the class label, so train
      and val keep the class proportions of the pool.

    NOTE(review): the split is done over per-caption records, so the captions of
    one image can be divided between train and val. Split at image level instead
    if that overlap matters for the evaluation.

    Args:
        data: dict with an ``'images'`` list; each entry provides ``filename``,
            ``split``, ``imgid``, ``sentids`` and ``sentences`` (each with ``raw``).
        stratified: stratify the train/val split on the class label.
        train_percent: fraction of the train pool kept for training.
        random_state: seed passed to ``train_test_split``.

    Returns:
        ``(train_set, val_set, test_set, word_to_idx)`` - three lists of record
        dicts with keys ``image``, ``caption``, ``label_name``, ``image_id``,
        ``label``, plus the class-name -> index mapping.

    Raises:
        ValueError: if any image does not have exactly 5 sentence ids.
    '''
    for item in data['images']:
        if len(item['sentids']) != 5:
            raise ValueError('有一个数据对应的文本不是5个')

    word_to_idx = get_word_to_idx_dict(data)

    train_obj = []
    test_obj = []
    for item in tqdm(data['images']):
        # Class name is the filename prefix before the first underscore.
        label_name = item['filename'].split('_')[0]
        if item['split'] == 'train':
            prefix, bucket = 'train/', train_obj
        else:
            prefix, bucket = 'test/', test_obj
        for i in range(len(item['sentids'])):  # one record per caption of the image
            bucket.append(
                {
                    "image": prefix + item['filename'],
                    "caption": item['sentences'][i]['raw'],
                    "label_name": label_name,
                    "image_id": item['imgid'],
                    "label": word_to_idx[label_name]
                }
            )

    # With stratify=None this is exactly the plain random split of the original code.
    stratify_labels = [record['label'] for record in train_obj] if stratified else None
    train_set, val_set = train_test_split(
        train_obj,
        train_size=train_percent,
        random_state=random_state,
        stratify=stratify_labels,
    )
    test_set = test_obj
    return train_set, val_set, test_set, word_to_idx

# Build the train / val / test record lists and the class-name -> index mapping
# using the function's default arguments. Note that this rebinds `test_set`,
# shadowing the name imported from jupyterlab.semver at the top of the notebook.
train_set, val_set, test_set, word_to_idx = create_banlanced_dataset(data)

100%|██████████| 4743/4743 [00:00<00:00, 434689.91it/s]


In [115]:
# Record counts of train / val / test divided by the 5 captions per image.
# Original note (translated): "as can be seen, [this agrees] with the earlier train_filename_verify".
tuple(len(split) // 5 for split in (train_set, val_set, test_set))

(3432, 858, 452)

In [116]:
# Sanity check: the numeric label of every training record must equal the index of
# its class name; a '!!!' line is printed for each record where it does not.
for record in train_set:
    expected_label = word_to_idx[record['label_name']]
    if expected_label != record['label']:
        print('!!!')

In [118]:
# Write each split to its own pretty-printed UTF-8 JSON file (non-ASCII text kept as-is).
output_files = {
    '/home/dhm04/PycharmProjects/RSCR-baseline/dataset/AMNFNFinetune/ours_train_rsitmd.json': train_set,
    '/home/dhm04/PycharmProjects/RSCR-baseline/dataset/AMNFNFinetune/ours_val_rsitmd.json': val_set,
    '/home/dhm04/PycharmProjects/RSCR-baseline/dataset/AMNFNFinetune/ours_test_rsitmd.json': test_set,
}
for out_path, records in output_files.items():
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(records, f, ensure_ascii=False, indent=4)