In [4]:
import os
import json
import glob
import cv2

from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt
from ultralytics import SAM
from huggingface_hub import snapshot_download
# Root directory of the local PACE dataset. Set the PACE_ROOT_DIR environment
# variable to point the notebook at another location without editing it; the
# default is the original hardcoded path.
ROOT_DIR = os.environ.get("PACE_ROOT_DIR", "/home/data/pace")

In [None]:
# One-time download of the PACE `val_pbr_cat` archive from the Hugging Face Hub (uncomment to run).
# repo_id = "qq456cvb/PACE"  # Hugging Face dataset ID
# local_dir = "/home/data/pace"  # Local directory where the downloaded files are saved
# allow_patterns = ["val_pbr_cat.tar.gz"]  # Restrict the download to this archive only
# snapshot_download(repo_id=repo_id, local_dir=local_dir, allow_patterns=allow_patterns, repo_type="dataset")

# len(glob.glob(f'{ROOT_DIR}/test/*')), len(glob.glob(f'{ROOT_DIR}/val_inst/*')), len(glob.glob(f'{ROOT_DIR}/val_pbr_cat/*'))

In [5]:
### data[video_id][frame_id] -> set of the instance IDs annotated in that frame of that video

data = {}

# Category names are recovered from the per-category split files by dropping
# everything from the last underscore onwards (i.e. the "_train.txt",
# "_val.txt" or "_test.txt" suffix) from each file name.
split_file_patterns = (
	f'{ROOT_DIR}/model_splits/category/*_train.txt',
	f'{ROOT_DIR}/model_splits/category/*_val.txt',
	f'{ROOT_DIR}/model_splits/category/*_test.txt',
)
category_names = set()
for pattern in split_file_patterns:
	for split_file in glob.glob(pattern):
		file_name = split_file.split('/')[-1]
		category_names.add(file_name.rsplit('_', 1)[0])
# De-duplicated and alphabetically ordered
all_categories = sorted(category_names)

# data[video_id][frame_id] -> set of instance IDs annotated in that frame.
# video_splits[split]      -> set of video IDs that belong to that split.
data = dict()
video_splits = {'test': set(), 'val_inst': set(), 'val_pbr_cat': set()}

# Walk each split directory explicitly so the split name comes from the
# directory being scanned, not from a substring test on the full path (which
# would misclassify any path that merely contains "test"). The splits are
# visited in the order test, val_inst, val_pbr_cat.
for split in video_splits:
	for video_path in glob.glob(f'{ROOT_DIR}/{split}/*'):
		# Prefix with the split so identically named folders in different
		# splits cannot collide.
		video_id = split+'_'+video_path.split('/')[-1]
		video_splits[split].add(video_id)

		# Frame IDs are the integer stems of the files found in rgb/.
		frame_ids = [int(f.split('/')[-1].split('.')[0]) for f in glob.glob(f'{video_path}/rgb/*')]
		data[video_id] = {frame_id: set() for frame_id in frame_ids}

		json_file = f'{video_path}/scene_gt_coco_det_modal_inst.json'
		with open(json_file, 'r') as f:
			video_json = json.load(f)

		for anno in video_json['annotations']:
			# The annotation's category_id is used as the object instance ID.
			object_instance_id = anno['category_id']
			frame_id = anno['image_id']
			# setdefault: keep annotations whose frame has no rgb file instead
			# of failing with a KeyError.
			data[video_id].setdefault(frame_id, set()).add(object_instance_id)

def _read_instance_ids(split_file):
	"""Return the integer instance IDs listed in `split_file`.

	The text after the last underscore of every non-blank line is parsed as
	the instance ID. The file is closed before returning.

	Raises:
		FileNotFoundError: if `split_file` does not exist.
		ValueError: if a non-blank line does not end in an integer.
	"""
	with open(split_file, 'r') as f:
		return [int(line.split('_')[-1]) for line in f.read().splitlines() if line.strip()]

# instances_of_category[category] -> instance IDs of that category, listed in
# the order test, val, train.
instances_of_category = dict()
for category in all_categories:
	instances_of_category[category] = [
		instance_id
		for split in ('test', 'val', 'train')
		for instance_id in _read_instance_ids(f'{ROOT_DIR}/model_splits/category/{category}_{split}.txt')
	]
	print(f'Category {category}, has {len(instances_of_category[category])} instances: {instances_of_category[category]}')

# Flatten the per-category lists into one list of instance IDs, keeping the
# category order of instances_of_category.
all_instances = []
for cat in instances_of_category.keys():
	all_instances += instances_of_category[cat]

Category bottle, has 49 instances: [2, 3, 5, 6, 16, 20, 21, 24, 4, 705, 14, 18, 694, 15, 11, 700, 1, 7, 8, 9, 10, 12, 13, 17, 19, 22, 23, 693, 695, 696, 697, 698, 699, 701, 702, 703, 704, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, 716, 717]
Category bowl, has 51 instances: [40, 36, 38, 735, 739, 43, 722, 734, 30, 42, 728, 720, 25, 26, 27, 28, 29, 31, 32, 33, 34, 35, 37, 39, 41, 44, 45, 46, 718, 719, 721, 723, 724, 725, 726, 727, 729, 730, 731, 732, 733, 736, 737, 738, 740, 741, 742, 743, 744, 745, 746]
Category box-base_link, has 30 instances: [579, 558, 593, 595, 603, 554, 572, 591, 608, 548, 577, 556, 568, 601, 597, 581, 589, 583, 546, 587, 570, 562, 552, 599, 560, 606, 566, 550, 564, 574]
Category box-link1, has 30 instances: [578, 557, 592, 594, 602, 553, 571, 590, 607, 547, 575, 555, 567, 600, 596, 580, 588, 582, 545, 586, 569, 561, 551, 598, 559, 605, 565, 549, 563, 573]
Category brush, has 19 instances: [49, 50, 51, 55, 56, 48, 53, 747, 748, 749, 750, 47, 751, 752, 753, 5

In [6]:
def isInstanceInVideo(instance_id, video_id):
	"""Return True when `instance_id` occurs in at least one frame of video `video_id`.

	Looks the video up in the module-level `data` mapping; an unknown
	`video_id` raises KeyError.
	"""
	for instances_in_frame in data[video_id].values():
		if instance_id in instances_in_frame:
			return True
	return False

def set_union(sets):
    """Return a new set holding every element found in any member of `sets`."""
    merged = set()
    for members in sets:
        merged.update(members)
    return merged

def sets_are_disjoint(sets):
	"""Return True when no element is shared between any two members of `sets`.

	Members are folded one at a time into a running set of seen elements; the
	first member that overlaps with what has been seen so far ends the scan.
	"""
	seen = set()
	for current in sets:
		if seen.isdisjoint(current):
			seen.update(current)
		else:
			return False
	return True

In [7]:
## Build a mapping: instance ID -> set of video IDs in which that instance appears.
## Only instances listed in all_instances get an entry, and only if they appear in at least one video.
videos_containing_instance = defaultdict(set)
for video_id, frames in data.items():
	# Collapse the per-frame sets once per video so each instance check below is
	# a single set lookup rather than a rescan of every frame.
	instances_in_video = set_union(frames.values())
	for instance in all_instances:
		if instance in instances_in_video:
			videos_containing_instance[instance].add(video_id)

In [8]:
# Spot check: display the set of video IDs in which instance 236 appears.
videos_containing_instance[236]

{'val_pbr_cat_000000'}

In [9]:
########## CHECK WHETHER ANY TRAIN INSTANCE ALSO APPEARS IN THE TEST OR VAL_PBR_CAT VIDEOS
# Train instance IDs: the integer after the last underscore of each non-blank line.
with open(f'{ROOT_DIR}/model_splits/category/train.txt', 'r') as f:
	all_train_instances = [int(x.split('_')[-1]) for x in f.read().splitlines() if x.strip()]

for instance_id in all_train_instances:
	# .get() instead of indexing: indexing the defaultdict would insert an empty
	# entry for every train instance that appears in no video.
	videos_with_instance = videos_containing_instance.get(instance_id, set())
	# Sorted so the printed lists are deterministic.
	test_videos_containing_this = sorted(videos_with_instance & video_splits['test'])
	val_pbr_cat_videos_containing_this = sorted(videos_with_instance & video_splits['val_pbr_cat'])
	# NOTE: overlap with the val_inst videos is not reported by this check.
	if test_videos_containing_this:
		print(f"TRAIN INSTANCE {instance_id} FOUND IN TEST VIDEOS: {test_videos_containing_this}; ALSO IN {videos_with_instance}")
	if val_pbr_cat_videos_containing_this:
		print(f"TRAIN INSTANCE {instance_id} FOUND IN VAL_PBR_CAT: {val_pbr_cat_videos_containing_this}; ALSO IN {videos_with_instance}")

TRAIN INSTANCE 236 FOUND IN VAL_PBR_CAT: ['val_pbr_cat_000000']; ALSO IN {'val_pbr_cat_000000'}
TRAIN INSTANCE 274 FOUND IN VAL_PBR_CAT: ['val_pbr_cat_000005']; ALSO IN {'val_pbr_cat_000005'}
TRAIN INSTANCE 296 FOUND IN VAL_PBR_CAT: ['val_pbr_cat_000000']; ALSO IN {'val_pbr_cat_000000'}
TRAIN INSTANCE 345 FOUND IN VAL_PBR_CAT: ['val_pbr_cat_000005']; ALSO IN {'val_pbr_cat_000005'}
TRAIN INSTANCE 146 FOUND IN VAL_PBR_CAT: ['val_pbr_cat_000000']; ALSO IN {'val_pbr_cat_000000'}
TRAIN INSTANCE 1011 FOUND IN VAL_PBR_CAT: ['val_pbr_cat_000005']; ALSO IN {'val_pbr_cat_000005'}
TRAIN INSTANCE 1051 FOUND IN VAL_PBR_CAT: ['val_pbr_cat_000005']; ALSO IN {'val_pbr_cat_000005'}
TRAIN INSTANCE 846 FOUND IN VAL_PBR_CAT: ['val_pbr_cat_000005']; ALSO IN {'val_pbr_cat_000005'}
TRAIN INSTANCE 552 FOUND IN VAL_PBR_CAT: ['val_pbr_cat_000000']; ALSO IN {'val_pbr_cat_000000'}
TRAIN INSTANCE 599 FOUND IN VAL_PBR_CAT: ['val_pbr_cat_000000']; ALSO IN {'val_pbr_cat_000000'}
TRAIN INSTANCE 888 FOUND IN VAL_PBR_CA