### This code merges the transferred relation data obtained after running ietrans.py into a new .h5 file containing the augmented dataset

In [None]:
import h5py
import json
import numpy as np
from tqdm import tqdm

# Merge the relation transfers listed in the ietrans JSON with the original
# VG-SGG annotations. The results are accumulated in `predicates`,
# `relationships`, `im_to_first_rel` and `im_to_last_rel`.

data_path = "../../datasets/IndoorVG_4/VG-SGG.h5"
new_rels_path = "./ietrans_vctree.json"
dict_path = "../../datasets/IndoorVG_4/VG-SGG-dicts.json"
img_data_path = "../../datasets/vg/image_data.json"
with open(img_data_path, 'r') as json_file:
    img_data = json.load(json_file)

# `data` is intentionally left open: it is still referenced after this point.
data = h5py.File(data_path, 'r')
with open(new_rels_path, 'r') as json_file:
    new_rels_data = json.load(json_file)
with open(dict_path, 'r') as json_file:
    labels_dict = json.load(json_file)

# Not used in the merge below; kept in case other cells rely on them.
idx_to_labels = labels_dict['idx_to_label']
idx_to_predicates = labels_dict['idx_to_predicate']

# Load the HDF5 datasets into memory once. Indexing an h5py dataset inside the
# loop would trigger a separate read from the file on every access.
first_rel = data['img_to_first_rel'][:]
last_rel = data['img_to_last_rel'][:]
first_box = data['img_to_first_box'][:]
last_box = data['img_to_last_box'][:]
old_predicates = data['predicates'][:]
old_relationships = data['relationships'][:]

relationships = []
predicates = []
im_to_first_rel = np.zeros(len(first_rel), dtype=np.int32)
im_to_last_rel = np.zeros(len(first_rel), dtype=np.int32)

# Index that the next appended relationship will get in the merged lists.
rel_idx_counter = 0

print('Before transfer: ', len(old_relationships))

# Indices of images with split_rel == 0 that have at least one relationship.
# Not used in the merge below; kept in case other cells rely on it.
data_split = data['split_rel'][:]
split_mask = data_split == 0
split_mask &= first_rel >= 0
image_index = np.where(split_mask)[0]

count_trans = 0  # relationships added by the external transfer
out_count = 0    # original relationships kept with their original predicate
orig_img = 0     # images without any entry in the transfer file

# NOTE(review): img_data[i] is assumed to describe the same image as row i of
# the HDF5 index arrays -- confirm that image_data.json has the same length
# and ordering as the .h5 file.
for i in tqdm(range(len(data_split))):
    image_name = str(img_data[i]['image_id']) + '.jpg'
    new_rels = new_rels_data.get(image_name)

    if not new_rels:
        # Nothing to transfer for this image: copy its relationships as-is.
        orig_img += 1
        if first_rel[i] == -1:
            im_to_first_rel[i] = -1
            im_to_last_rel[i] = -1
            continue
        im_to_first_rel[i] = rel_idx_counter
        for j in range(first_rel[i], last_rel[i] + 1):
            predicates.append(old_predicates[j])
            relationships.append(old_relationships[j])
            rel_idx_counter += 1
            out_count += 1
        im_to_last_rel[i] = rel_idx_counter - 1
        continue

    im_to_first_rel[i] = rel_idx_counter
    i_obj_start = first_box[i]

    # internal trans: an existing (subject, object) pair that also appears in
    # the transfer list takes the transferred predicate.
    if first_rel[i] != -1:
        for j in range(first_rel[i], last_rel[i] + 1):
            r = old_relationships[j]
            # The transfer file indexes boxes relative to the image's first box.
            rel = [r[0] - i_obj_start, r[1] - i_obj_start]
            matches = []
            if new_rels:  # the list shrinks as entries are consumed
                matches = np.where(
                    np.all(np.array(new_rels)[:, 0:2] == rel, axis=1))[0]
            if len(matches) > 0:
                # Consume the first matching entry so that it is not added a
                # second time by the external transfer below. This also
                # mutates the list stored in new_rels_data.
                match = int(matches[0])
                predicates.append([new_rels[match][2]])
                new_rels.pop(match)
            else:
                out_count += 1
                predicates.append(old_predicates[j])
            rel_idx_counter += 1
            relationships.append(r)

    # external trans: every remaining entry becomes a new relationship.
    num_boxes = last_box[i] - i_obj_start + 1
    for rel in new_rels:
        if rel[0] >= num_boxes or rel[1] >= num_boxes:
            # Entry refers to a box outside this image's box range: skip it.
            print('Error: ', rel[0], rel[1], num_boxes)
            continue

        # Convert the image-relative box indices back to global box indices.
        sub = i_obj_start + rel[0]
        obj = i_obj_start + rel[1]
        predicates.append(rel[2])
        relationships.append([sub, obj])
        rel_idx_counter += 1
        count_trans += 1

    if im_to_first_rel[i] == rel_idx_counter:
        # if no qualifying relationship
        im_to_first_rel[i] = -1
        im_to_last_rel[i] = -1
    else:
        im_to_last_rel[i] = rel_idx_counter - 1

print('After transfer: ', len(relationships))

assert len(predicates) == len(relationships)

Before transfer:  99824


100%|██████████| 108073/108073 [00:29<00:00, 3654.14it/s]

After transfer:  133908





In [4]:
# Write the merged annotations to a new HDF5 file. Datasets that were not
# modified are copied from the source file `data`.
out_path = "../../datasets/IndoorVG_4/VG-SGG-augmented-vctree.h5"

# Stack before opening the output file, so that a failure here does not leave
# a truncated file behind.
predicates = np.vstack(predicates)
relationships = np.vstack(relationships)

# The context manager closes the output file even if a dataset write fails.
with h5py.File(out_path, 'w') as f:
    for name in ('labels', 'boxes_512', 'boxes_1024',
                 'img_to_first_box', 'img_to_last_box'):
        f.create_dataset(name, data=data[name])

    f.create_dataset('predicates', data=predicates)
    f.create_dataset('relationships', data=relationships)
    f.create_dataset('img_to_first_rel', data=im_to_first_rel)
    f.create_dataset('img_to_last_rel', data=im_to_last_rel)

    for name in ('split_rel', 'split'):
        f.create_dataset(name, data=data[name])

# The source file is no longer needed; close it instead of leaking the handle.
data.close()

# Re-open the written file to check the shape of the stored relationships.
with h5py.File(out_path, 'r') as augmented:
    print(augmented['relationships'].shape)

(133908, 2)
