In [1]:
import pycocotools.coco as coco

# Load annotations
coco = coco.COCO('/data/hulab/zcai75/coco/annotations/captions_train2017.json')

loading annotations into memory...
Done (t=1.11s)
creating index...
index created!


In [2]:
target_concept = 'wearing'
# Filter out images that contain the target concept in the caption
imgIds = coco.getImgIds()
annIds = coco.getAnnIds(imgIds=imgIds)
anns = coco.loadAnns(annIds)
imgIds, captions = zip(*[(ann['image_id'], ann['caption']) for ann in anns if target_concept in ann['caption']])
imgIds = list(set(imgIds))
print(len(imgIds))
print(captions[:10])

7085
('A woman wearing a net on her head cutting a cake. ', 'A woman wearing a hair net cutting a large sheet cake.', 'a little boy wearing headphones and looking at a computer monitor', 'a boy wearing headphones using one computer in a long row of computers', 'A small child wearing headphones plays on the computer.', 'Two men wearing aprons working in a commercial-style kitchen.', 'A baby wearing gloves, lying next to a teddy bear', 'The kid is skateboarding on the street while wearing a jacket.', 'A young man wearing black attire and a flowered tie is standing and smiling.', 'Smiling man wearing black shirt and pale green tie.')


In [3]:
import csv
from io import BytesIO
import os
import base64
from tqdm import tqdm
import h5py
import json
from PIL import Image
import numpy as np
import random

data_dir = 'dataset/OFA_data/vrd_mix'
vg_dir = '/data/hulab/zcai75/visual_genome'
image_dir = os.path.join(vg_dir, 'VG_100K')

In [4]:
toy = False
version = 'toy' if toy else target_concept
toy_count = 1000

if not os.path.exists(data_dir):
	os.makedirs(data_dir)

with h5py.File(os.path.join(vg_dir, 'VG-SGG-with-attri.h5'), 'r') as f, \
	 open(os.path.join(vg_dir, 'VG-SGG-dicts-with-attri.json'), 'r') as d, \
	 open(os.path.join(vg_dir, 'image_data.json')) as img_data:
	d = json.load(d)
	img_data = json.load(img_data)
	print(f.keys())
	# print(f['boxes_1024'][0])


	data = enumerate(zip(
		f['img_to_first_rel'], f['img_to_last_rel'],
		f['img_to_first_box'], f['img_to_last_box'],
		f['predicates'], f['split']))
	tqdm_obj = tqdm(data, total=len(f['split']))

	train_count = 0
	val_count = 0
	skip_count = 0
	train_rows = []
	val_rows = []
	for i, (first_rel, last_rel, first_box, last_box, preds, split) in tqdm_obj:
		if toy and ((train_count > toy_count and split == 0) or (val_count > toy_count and split != 0)):
			continue
		try:
			if last_rel - first_rel < 0:
				skip_count += 1
				continue

			image_id = img_data[i]['image_id']
			with Image.open(os.path.join(image_dir, f'{image_id}.jpg'), 'r') as img_f:
				img_rels = f['relationships'][first_rel : last_rel+1]

				pred_labels = np.atleast_1d(f['predicates'][first_rel : last_rel+1].squeeze()).tolist()
				boxes = f['boxes_1024'][first_box : last_box+1].squeeze().tolist()
				box_labels = np.atleast_1d(f['labels'][first_box : last_box+1].squeeze()).tolist()
				pred_slabels = [d['idx_to_predicate'][str(j)] for j in pred_labels]
				box_slabels = [d['idx_to_label'][str(j)] for j in box_labels]

				for rel_i, rel in enumerate(img_rels):
					if pred_slabels[rel_i] != target_concept:
						i1 = rel[0] - first_box
						i2 = rel[1] - first_box
						row = [str(image_id) + '-' + str(rel_i), ' '.join(map(str, boxes[i1])), ' '.join(map(str, boxes[i2])), box_slabels[i1], box_slabels[i2], pred_slabels[rel_i]]

						# print(row)
						if split == 0:
							if toy and train_count > toy_count:
								continue
							train_rows.append(row)
							train_count += 1
						else:
							if toy and val_count > toy_count:
								continue
							val_rows.append(row)
							val_count += 1
					else:
						skip_count += 1
					
		except FileNotFoundError:
			print('Cannot find ' + f'{image_id}.jpg')
		# break

for img_id, caption in zip(imgIds, captions):
	# clean caption
	caption = caption.replace('\n', ' ').replace('\t', ' ').replace('\r', ' ').replace('  ', ' ').strip()
	train_rows.append(['coco-' + str(img_id), caption, '', '', '', ''])

with open(os.path.join(data_dir, f'vg_train_{version}.tsv'), 'w+', newline='\n') as f_train, \
	open(os.path.join(data_dir, f'vg_val_{version}.tsv'), 'w+', newline='\n') as f_val:
	writer_train = csv.writer(f_train, delimiter='\t', lineterminator='\n')
	writer_val = csv.writer(f_val, delimiter='\t', lineterminator='\n')

	random.shuffle(train_rows)
	random.shuffle(val_rows)
	for row in train_rows:
		assert len(row) == 6
		writer_train.writerow(row)
	for row in val_rows:
		assert len(row) == 6
		writer_val.writerow(row)
print('Train:', train_count, 'Val:', val_count, 'Skipped:', skip_count)


<KeysViewHDF5 ['active_object_mask', 'attributes', 'boxes_1024', 'boxes_512', 'img_to_first_box', 'img_to_first_rel', 'img_to_last_box', 'img_to_last_rel', 'labels', 'predicates', 'relationships', 'split']>


100%|██████████| 108073/108073 [01:06<00:00, 1633.94it/s]


Train: 388090 Val: 163494 Skipped: 71121


In [5]:
# train_rels = set()
# val_rels = set()
# skip_count = 0

# with h5py.File(os.path.join(vg_dir, 'VG-SGG-with-attri.h5'), 'r') as f, \
# 	 open(os.path.join(vg_dir, 'VG-SGG-dicts-with-attri.json'), 'r') as d, \
# 	 open(os.path.join(vg_dir, 'image_data.json')) as img_data:
# 	d = json.load(d)
# 	img_data = json.load(img_data)
# 	print(f.keys())
# 	# print(f['boxes_1024'][0])

# 	data = enumerate(zip(
# 		f['img_to_first_rel'], f['img_to_last_rel'],
# 		f['img_to_first_box'], f['img_to_last_box'],
# 		f['predicates'], f['split']))
# 	tqdm_obj = tqdm(data, total=len(f['split']))


# 	for i, (first_rel, last_rel, first_box, last_box, preds, split) in tqdm_obj:
# 		try:
# 			if last_rel - first_rel < 0:
# 				skip_count += 1
# 				continue

# 			image_id = img_data[i]['image_id']
# 			with Image.open(os.path.join(image_dir, f'{image_id}.jpg'), 'r') as img_f:
# 				img_rels = f['relationships'][first_rel : last_rel+1]

# 				pred_labels = np.atleast_1d(f['predicates'][first_rel : last_rel+1].squeeze()).tolist()
# 				boxes = f['boxes_1024'][first_box : last_box+1].squeeze().tolist()
# 				box_labels = np.atleast_1d(f['labels'][first_box : last_box+1].squeeze()).tolist()
# 				pred_slabels = [d['idx_to_predicate'][str(j)] for j in pred_labels]
# 				box_slabels = [d['idx_to_label'][str(j)] for j in box_labels]

# 				if split == 0:
# 					train_rels = train_rels.union(set(pred_slabels))
# 				else:
# 					val_rels = val_rels.union(set(pred_slabels))
# 		except FileNotFoundError:
# 			print('Cannot find ' + f'{image_id}.jpg')

In [6]:
# print('Zeroshot rels:', len(train_rels), len(val_rels))
# print('Zeroshot rels:', val_rels.difference(train_rels))
# print('skipped:', skip_count)