In [None]:
import glob
import json
import os
import shutil
from collections import defaultdict

import cv2
import matplotlib.pyplot as plt
import numpy as np
from huggingface_hub import snapshot_download
from ultralytics import SAM

ROOT_DIR = "/home/data/pace"

In [None]:
## Download PACE Dataset folders

# repo_id = "qq456cvb/PACE"  # Replace with the actual dataset ID
# local_dir = "."  # The local path where you want to save the folder
# allow_patterns = ["model_splits/*"] # To download a specific folder within the dataset
# snapshot_download(repo_id=repo_id, local_dir=local_dir, allow_patterns=allow_patterns, repo_type="dataset")

# len(glob.glob(f'{ROOT_DIR}/test/*')), len(glob.glob(f'{ROOT_DIR}/val_inst/*')) ###, len(glob.glob(f'{ROOT_DIR}/val_pbr_cat/*'))

In [None]:
# Create the sorted list of all object categories (split-file names with the trailing "_train" removed)
all_categories = sorted(
	'_'.join(x.split('/')[-1].split('_')[:-1])
	for x in glob.glob(f'{ROOT_DIR}/model_splits/category/*_train.txt')
)
print(f"Total number of unique categories: {len(all_categories)}")

# Compile the data dictionary: data[video_id][frame_id] = set of all instance IDs in that frame of that video
data = dict()

for video_path in glob.glob(f'{ROOT_DIR}/test/*')+glob.glob(f'{ROOT_DIR}/val_inst/*'):
	video_id = int(video_path.split('/')[-1])
	frame_ids = [int(frame_path.split('/')[-1].split('.')[0]) for frame_path in glob.glob(f'{video_path}/rgb/*')]
	data[video_id] = {frame_id: set() for frame_id in frame_ids}

	with open(f'{video_path}/scene_gt_coco_det_modal_inst.json', 'r') as f:
		video_json = json.load(f)

	# In these annotation files 'category_id' holds the object *instance* ID and 'image_id' the frame ID
	for anno in video_json['annotations']:
		data[video_id][anno['image_id']].add(anno['category_id'])

# Create dictionary to fetch all instance IDs for a given category (union of its test and val split files).
# Each split-file line ends in "_<instance_id>"; blank lines are skipped instead of crashing int().
instances_of_category = dict()
for category in all_categories:
	category_instance_ids = set()
	for subset in ('test', 'val'):
		# `with` closes each split file promptly instead of leaving the handle to the garbage collector
		with open(f'{ROOT_DIR}/model_splits/category/{category}_{subset}.txt', 'r') as split_file:
			category_instance_ids.update(
				int(line.split('_')[-1]) for line in split_file.read().splitlines() if line.strip()
			)
	instances_of_category[category] = category_instance_ids
	print(f'Category {category}, has {len(instances_of_category[category])} instances: {instances_of_category[category]}')

# Create set of all instance IDs
all_instances = set().union(*instances_of_category.values())
print(f'Total number of unique instances: {len(all_instances)}')

# PACE Statistics

In [None]:
# Histogram of video lengths, i.e. how many RGB frames each video has
video_lengths = list(map(len, data.values()))
print(f'Total number of frames: {sum(video_lengths)}')
plt.hist(video_lengths, bins=50)
plt.gca().set(
	xlabel='Number of frames',
	ylabel='Number of videos',
	title='Distribution of number of frames per video',
)
plt.show()

In [None]:
chosen_category = 'toys'
chosen_instances = instances_of_category[chosen_category]

# For each video, count how many distinct chosen instances appear in at least one of its frames.
# Videos without any chosen instance get no entry and therefore do not show up in the histogram.
chosen_counts = defaultdict(int)
for video_id, frames in data.items():
	seen_in_video = set().union(*frames.values())
	for chosen_instance in chosen_instances:
		if chosen_instance in seen_in_video:
			chosen_counts[video_id] += 1

plt.hist(list(chosen_counts.values()), align='mid')
plt.xlabel(f'Number of instances of {chosen_category} appearing in a video')
plt.ylabel('Number of videos')
plt.title(f'Distribution of number of videos containing {chosen_category}')
plt.show()

In [None]:
# For each instance of the chosen category, count the videos it appears in (even if only in one frame),
# then plot instance ID on the x axis against that number of videos on the y axis.
# Instances are ordered by ID so the chart is deterministic (set iteration order is not meaningful).
instance_video_count = {instance_id: 0 for instance_id in sorted(instances_of_category[chosen_category])}
for instance_id in instance_video_count:
	for video_id, frames in data.items():
		if any(instance_id in frame_instances for frame_instances in frames.values()):
			print(f"Found instance ID: {instance_id} in video ID: {video_id}")
			instance_video_count[instance_id] += 1

# Bars sit at integer positions; the tick labels show the actual instance IDs so the x label is truthful.
positions = np.arange(len(instance_video_count))
plt.figure(figsize=(12, 6))
plt.bar(positions, list(instance_video_count.values()))
plt.xticks(positions, [str(instance_id) for instance_id in instance_video_count], rotation=90)
plt.xlabel('Object Instance ID')
plt.ylabel('Number of Videos')
plt.title('Number of Videos per Object Instance')
plt.tight_layout()
plt.show()

In [None]:
## Make a bar chart of the number of instances per category
instances_counts = [len(instances_of_category[category]) for category in all_categories]

# Rank categories by instance count, largest first (ties ordered by descending category name)
ranked_by_instances = sorted(zip(instances_counts, all_categories), reverse=True)
sorted_categories = [category for _, category in ranked_by_instances]
sorted_counts = [count for count, _ in ranked_by_instances]

plt.figure(figsize=(12, 6))
plt.bar(sorted_categories, sorted_counts)
plt.xlabel('Category')
plt.ylabel('Number of Instances')
plt.title('Number of Instances per Category')
plt.xticks(rotation=90)
# plt.tight_layout()
plt.show()

In [None]:
# Count, for every category, the videos in which at least one of its instances appears in some frame.
# Categories that never appear get no entry (and are therefore absent from the chart).
category_video_counts = defaultdict(int)
for category in all_categories:
	category_instances = instances_of_category[category]
	for video_id, frames in data.items():
		if any(not category_instances.isdisjoint(object_instances) for object_instances in frames.values()):
			category_video_counts[category] += 1

# Rank categories by video count, largest first (ties ordered by descending category name)
ranked_by_videos = sorted(zip(category_video_counts.values(), category_video_counts.keys()), reverse=True)
sorted_categories = [category for _, category in ranked_by_videos]
sorted_counts = [count for count, _ in ranked_by_videos]

plt.figure(figsize=(12, 6))
plt.bar(sorted_categories, sorted_counts)
plt.xlabel('Category')
plt.ylabel('Number of Videos')
plt.title('Number of Videos per Category')
plt.xticks(rotation=90)
# plt.tight_layout()
plt.show()

# Instance & video train-val-test split

In [None]:
output_dir = ROOT_DIR + '/toycar_can_v2'  # root of the exported YOLO dataset; "-extra" / "-SAMTEST" siblings hang off it

In [None]:
def isInstanceInVideo(instance_id, video_id):
	"""Return True if ``instance_id`` is annotated in at least one frame of video ``video_id``.

	Reads the notebook-level ``data`` dict (video_id -> frame_id -> set of instance IDs);
	raises KeyError for an unknown ``video_id``.
	"""
	for frame_instances in data[video_id].values():
		if instance_id in frame_instances:
			return True
	return False

def set_union(sets):
    """Return a new set holding every element of every iterable in ``sets`` (empty input -> empty set)."""
    merged = set()
    for member in sets:
        merged.update(member)
    return merged

def sets_are_disjoint(sets):
	"""Return True when no element is shared by two different members of ``sets`` (pairwise disjoint).

	An empty sequence, or a single set, is trivially disjoint.
	"""
	seen = set()
	for group in sets:
		if seen.isdisjoint(group):
			seen.update(group)
			continue
		return False
	return True

In [None]:
# # for each chosen instance, find all videos that contain it, and all other instances that appear in those videos
# video_groups = []
# chosen_category = 'toothbrush'
# chosen_instances = instances_of_category[chosen_category]
# print(chosen_instances)

# for chosen_instance in chosen_instances:
# 	# check if the chosen instance is already in a group
# 	for group in video_groups:
# 		if chosen_instance in group['instances']:
# 			break
# 	else:
# 		# if not, create a new group 
# 		new_group = {'video_ids': set(), 'instances': set()}
# 		for video_id in data.keys():
# 			if isInstanceInVideo(chosen_instance, video_id):
# 				new_group['video_ids'].add(video_id)
# 				# find all other instances in the same video
# 				for other_instance in instances_of_category[chosen_category]:
# 					if (other_instance not in new_group['instances']) and isInstanceInVideo(other_instance, video_id):
# 						new_group['instances'].add(other_instance)
# 						for video_id in data.keys():
# 							if isInstanceInVideo(other_instance, video_id):
# 								# print(f'Adding instance {other_instance} to group for video {video_id}')
# 								new_group['video_ids'].add(video_id)
# 		if new_group['video_ids']:
# 			video_groups.append(new_group)
# 		else:
# 			print(f'No videos found for instance {chosen_instance} in category {chosen_category}.')
		

# for group in video_groups:
# 	total_frames = sum(len(data[video_id]) for video_id in group["video_ids"])
# 	# print('---')
# 	print(f'Video group with {len(group["video_ids"])} videos, {len(group["instances"])} instances, and {total_frames} frames.', end=' ')
# 	# print(f'Video IDs: {sorted(list(group["video_ids"]))}')
# 	print(f'Instances: {sorted(list(group["instances"]))}')

In [None]:
# instance_splits['train']['distractor'] = {320, 324, 336, 340,    2, 16,    437, 448}
# instance_splits['val']['distractor'] = {315, 316, 317, 318,    5, 6,    434, 435}
# instance_splits['test']['distractor'] = {305, 306, 307, 308,    21, 24,   451, 436}

In [None]:
# Per-split containers keyed by category name; video_splits is filled in from instance_splits below.
video_splits = {split: dict() for split in ('train', 'val', 'test')}
distractor_splits = {split: dict() for split in ('train', 'val', 'test')}

target_categories = ['toy_car', 'can']

# Hand-picked, disjoint sets of PACE instance IDs for each split and target category
instance_splits = {
	'train': {'toy_car': {456, 458, 461, 470}, 'can': {74, 57, 58}},
	'val':   {'toy_car': {459, 460, 467, 468}, 'can': {66, 70, 71, 73}},
	'test':  {'toy_car': {455, 457, 469},      'can': {61, 62, 63}},
}

In [None]:
# Given the instance splits, construct video splits (all videos containing train instances become train videos, etc.).
# A video may contain instances of several categories of the same split; it is only a conflict when the same
# video would have to go to two different data splits (train/val/test).
already_assigned_videos = set()
split_of_video = dict()  # video_id -> data split the video was assigned to
for split in instance_splits:
	for category in instance_splits[split]:
		video_splits[split][category] = set()
		for instance in instance_splits[split][category]:
			# add each video that contains this instance to the split
			for video_id in data:
				if not isInstanceInVideo(instance, video_id):
					continue
				assigned_split = split_of_video.get(video_id, split)
				if assigned_split != split:
					if category == 'distractor':
						continue  # Just don't add the distractor videos that would create conflict
					raise ValueError(
						f"Splits are not disjoint! Video {video_id} (instance {instance}, {split}/{category}) "
						f"is already assigned to the {assigned_split} split."
					)
				video_splits[split][category].add(video_id)
				split_of_video[video_id] = split
				already_assigned_videos.add(video_id)

for split in video_splits:
	for category in video_splits[split]:
		print(f'{split:>5} | {category:>7} | Videos:    {sorted(list(video_splits[split][category]))}')
		if len(video_splits[split][category]) == 0:
			print(f'Warning: No videos found for split {split} and category {category}.')

for split in video_splits:
	for category in video_splits[split]:
		print(f'{split:>5} | {category:>7} | Instances: {sorted(list(instance_splits[split][category]))}')
		if len(instance_splits[split][category]) == 0:
			print(f'Warning: No instances found for split {split} and category {category}.')

videos_in_data_split = {'train': set(), 'val': set(), 'test': set()} # set of video IDs for each data split
instances_in_data_split = {'train': set(), 'val': set(), 'test': set()} # set of TARGET instance IDs for each data split
for split in video_splits:
	videos_in_data_split[split] = set_union(video_splits[split][category] for category in video_splits[split])
	instances_in_data_split[split] = set_union(instance_splits[split][category] for category in instance_splits[split] if category!='distractor')

# Aggregate per category across ALL splits (iterating only the leftover `split` would keep just 'test')
videos_in_category = {category: set() for category in video_splits['train']} # set of video IDs for each category
instances_in_category = {category: set() for category in instance_splits['train']} # set of instance IDs for each category
for split in video_splits:
	for category in video_splits[split]:
		videos_in_category.setdefault(category, set()).update(video_splits[split][category])
		instances_in_category.setdefault(category, set()).update(instance_splits[split][category])

assert sets_are_disjoint([videos_in_data_split['train'], videos_in_data_split['test'], videos_in_data_split['val']])
assert sets_are_disjoint([instances_in_data_split['train'], instances_in_data_split['val'], instances_in_data_split['test']])

print(f'Train videos: {len(videos_in_data_split["train"])}, Train instances: {len(instances_in_data_split["train"])}')
print(f'Val videos: {len(videos_in_data_split["val"])}, Val instances: {len(instances_in_data_split["val"])}')
print(f'Test videos: {len(videos_in_data_split["test"])}, Test instances: {len(instances_in_data_split["test"])}')

In [None]:
## Map each instance ID to the set of video IDs in which it appears in at least one frame
videos_containing_instance = defaultdict(set)
# One pass over every frame's instance set, instead of calling isInstanceInVideo for every
# (video, instance) pair (which rescans all frames of the video for each instance).
for video_id, frames in data.items():
	for frame_instances in frames.values():
		for instance in frame_instances & all_instances:
			videos_containing_instance[instance].add(video_id)

In [None]:
# Distractor candidates: instances seen in at least one train video and in no val or test video
train_instances = {
	instance_id
	for instance_id, video_ids in videos_containing_instance.items()
	if not video_ids.isdisjoint(videos_in_data_split['train'])
	and video_ids.isdisjoint(videos_in_data_split['val'])
	and video_ids.isdisjoint(videos_in_data_split['test'])
}

# Target-category instances assigned to train are the objects to detect, not distractors
target_train_instances = set_union(instance_splits['train'][category] for category in target_categories)
distractor_instances = train_instances - target_train_instances
print(f'{len(distractor_instances)} distractor instances: {sorted(list(distractor_instances))}')

In [None]:
def convert_bbox_to_yolo(bbox, img_width, img_height):
    """
    Convert a top-left based [x, y, width, height] pixel box to YOLO format.

    Returns the tuple (x_center, y_center, width, height), with x values divided by
    ``img_width`` and y values divided by ``img_height``.
    """
    left, top, box_w, box_h = bbox
    return (
        (left + box_w/2) / img_width,
        (top + box_h/2) / img_height,
        box_w / img_width,
        box_h / img_height,
    )

def process_annotation_file(json_file_path):
    """
    Read a video's COCO-style annotation JSON and convert every box to normalized YOLO form.

    Only the file's ``annotations`` list is used. For each annotation, ``image_id`` is used as the
    frame ID and ``category_id`` (an instance ID in this dataset) is kept as ``class_id``. The box is
    normalized with the annotation's own ``width``/``height`` fields via ``convert_bbox_to_yolo``.
    Annotations are NOT filtered on an ``ignore`` flag.

    Returns:
        defaultdict(list) mapping frame_id -> list of dicts with keys
        'class_id', 'x_center', 'y_center', 'width', 'height'. Looking up a frame that
        has no annotations yields an empty list.
    """
    with open(json_file_path, 'r') as f:
        annotations = json.load(f)['annotations']

    frame_annotations = defaultdict(list)
    for anno in annotations:
        frame_id, class_id = anno['image_id'], anno['category_id']
        bbox, img_width, img_height = anno['bbox'], anno['width'], anno['height']  # bbox is [x, y, width, height]

        x_center, y_center, norm_width, norm_height = convert_bbox_to_yolo(bbox, img_width, img_height)
        frame_annotations[frame_id].append(dict(
            class_id=class_id,
            x_center=x_center,
            y_center=y_center,
            width=norm_width,
            height=norm_height,
        ))

    return frame_annotations

In [None]:
# Darknet/YOLO class numbers follow the order of target_categories (toy_car -> 0, can -> 1)
category_to_label = dict(zip(target_categories, range(len(target_categories))))
# Inverse map: class number (from darknet labels) -> category name
label_to_category_name = {label: category for category, label in category_to_label.items()}

# Dataset Processing

In [None]:
def get_category_name(instance_id, categories=None):
	"""Return the name of the single category whose instance set contains ``instance_id``.

	Args:
		instance_id: instance ID to look up.
		categories: optional mapping of category name -> collection of instance IDs. Defaults to
			the notebook-level ``instances_of_category``.

	Raises:
		AssertionError: if the instance belongs to no category, or to more than one.
	"""
	if categories is None:
		categories = instances_of_category
	matches = [cat for cat, instance_ids in categories.items() if instance_id in instance_ids]
	# Check the empty case explicitly: `len(...) >= 0` could never fail, so an unknown ID
	# would otherwise be reported as "Multiple categories".
	assert len(matches) > 0, f"No categories found for label {instance_id}"
	assert len(matches) == 1, f"Multiple categories found for label {instance_id}"
	return matches[0]

In [None]:
# Create train, val, test output folders: YOLO images/labels plus the "-extra" full-label copy
for split in ['train', 'val', 'test']:
	os.makedirs(f'{output_dir}/{split}/images/', exist_ok=True)
	os.makedirs(f'{output_dir}/{split}/labels/', exist_ok=True)
	os.makedirs(f'{output_dir}-extra/{split}/full_labels/', exist_ok=True)

for video_path in glob.glob(f'{ROOT_DIR}/test/*')+glob.glob(f'{ROOT_DIR}/val_inst/*'):
	video_id = int(video_path.split('/')[-1])
	if video_id in videos_in_data_split['train']:
		split = 'train'
	elif video_id in videos_in_data_split['val']:
		split = 'val'
	elif video_id in videos_in_data_split['test']:
		split = 'test'
	else:
		continue

	labels = process_annotation_file(f'{video_path}/scene_gt_coco_det_modal_inst.json')
	for frame_path in glob.glob(f'{video_path}/rgb/*'):
		frame_num = int(os.path.basename(frame_path).split('.')[0])
		stem = f'{video_id}_{frame_num}'
		# In-process copy: no shell per frame, safe with unusual paths, and raises on failure
		shutil.copy(frame_path, f'{output_dir}/{split}/images/{stem}.png')

		target_lines = []  # YOLO lines for the target categories only (class numbers)
		full_lines = []    # lines for every annotated object
		for label in labels[frame_num]:
			box = f"{label['x_center']} {label['y_center']} {label['width']} {label['height']}"
			category_name = get_category_name(label['class_id'])
			if category_name in target_categories:
				target_lines.append(f"{category_to_label[category_name]} {box}\n")
			### NOTE: store instance IDs here instead of class IDs
			full_lines.append(f"{label['class_id']} {box}\n")

		# Write each label file once in 'w' mode so re-running the cell does not append duplicate boxes.
		# Frames without boxes get an empty label file.
		with open(f'{output_dir}/{split}/labels/{stem}.txt', 'w') as f:
			f.writelines(target_lines)
		with open(f'{output_dir}-extra/{split}/full_labels/{stem}.txt', 'w') as f:
			f.writelines(full_lines)

In [None]:
# Ensure every exported image has a (possibly empty) label file and full-label file
for split in ['train', 'val', 'test']:
    for image_path in glob.glob(f'{output_dir}/{split}/images/*'):
        label_path = image_path.replace('.png', '.txt').replace('/images', '/labels')
        full_label_path = image_path.replace(f'{output_dir}/', f'{output_dir}-extra/').replace('.png', '.txt').replace('/images', '/full_labels')
        for required_path in (label_path, full_label_path):
            if not os.path.exists(required_path):
                # Create the empty file in-process rather than shelling out to `touch`
                # (no subprocess per file; paths with spaces or shell characters stay safe).
                open(required_path, 'a').close()

In [None]:
# Sanity check: every image in each split should have a matching YOLO label file
for split in ('train', 'val', 'test'):
	path = f"{output_dir}/{split}"
	for image_path in glob.glob(f'{path}/images/*'):
		expected_label = image_path.replace('.png', '.txt').replace('/images', '/labels')
		if os.path.exists(expected_label):
			continue
		print(f"Missing label for image: {image_path}")

# for label_path in glob.glob(f'{path}/labels/*'):
# 	if not os.path.exists(label_path.replace('.txt', '.png').replace('/labels', '/images')):
# 		print(f"Missing image for label file: {label_path}")

In [None]:
# Report how many images, YOLO labels and full (instance-ID) labels each split contains
for split in ('train', 'val', 'test'):
	num_images, num_labels, num_full_labels = (
		len(glob.glob(pattern))
		for pattern in (
			f'{output_dir}/{split}/images/*',
			f'{output_dir}/{split}/labels/*',
			f'{output_dir}-extra/{split}/full_labels/*',
		)
	)
	print(f"{split.capitalize()} - Images: {num_images}, Labels: {num_labels}, Full Labels: {num_full_labels}")

# Segment to get foreground object images

In [None]:
def read_darknet_bboxes(bbox_path, image_width, image_height):
	"""Parse a darknet/YOLO label file into integer pixel boxes.

	Each line must have exactly five fields: ``id x_center y_center width height``. The coordinates
	are normalized by the image size, and ``id`` is an integer (a class or instance ID, depending
	on which label file is read).

	Returns:
		(class_ids, bboxes): parallel lists. ``bboxes`` holds ``[x1, y1, x2, y2]`` pixel corners,
		truncated to int and clamped into ``[0, dim - 1]``. Both lists are empty for an empty
		file, which also prints a notice.

	Raises:
		AssertionError: if a line does not have exactly five fields.
	"""
	def _clamp(value, limit):
		# Keep a pixel coordinate inside [0, limit - 1]
		return max(0, min(value, limit - 1))

	class_ids, bboxes = [], []
	with open(bbox_path, 'r') as label_file:
		raw_lines = list(label_file)

	for raw_line in raw_lines:
		fields = raw_line.strip().split()
		assert len(fields) == 5, f"Invalid bbox line: {raw_line.strip()}"

		class_ids.append(int(fields[0]))
		xc, yc, w, h = map(float, fields[1:5])

		center_x, center_y = xc * image_width, yc * image_height
		half_w, half_h = (w * image_width) / 2, (h * image_height) / 2

		bboxes.append([
			_clamp(int(center_x - half_w), image_width),
			_clamp(int(center_y - half_h), image_height),
			_clamp(int(center_x + half_w), image_width),
			_clamp(int(center_y + half_h), image_height),
		])

	if not raw_lines:
		print(f"No bounding boxes found in {bbox_path}. Returning empty list.")
	return class_ids, bboxes

In [None]:
def segment_images_from_folder_bbox(img_dir, label_dir, output_dir, frame_skip, sam_model=None):
	"""
	Segments images in img_dir with a SAM model prompted by the boxes in label_dir.

	Each image in img_dir should have a label file with the same name in label_dir, whose lines are
	``id x_center y_center width height`` (normalized). Image names are expected to look like
	``<video_id>_<frame_id>.<ext>``. Only frames whose frame_id is a multiple of ``frame_skip``
	are processed. For every box, an inverted binary mask (object pixels 0, background 255)
	is written to ``<output_dir>/masks/<image_name>_mask_<id>.png``.

	Frames are skipped (with a printed message) when the image cannot be read, when an id
	appears more than once, when there are no boxes, or when the model returns no masks.

	Args:
		img_dir: folder of input images.
		label_dir: folder of darknet-format label files.
		output_dir: destination folder; created (including missing parents) if needed.
		frame_skip: process only every ``frame_skip``-th frame.
		sam_model: callable SAM predictor; defaults to the notebook-level ``model``.
	"""
	predictor = model if sam_model is None else sam_model
	mask_dir = os.path.join(output_dir, 'masks')
	# makedirs also creates missing parents (e.g. ".../toycar_can_v2-SAMTEST" for ".../train");
	# os.mkdir would raise FileNotFoundError there.
	os.makedirs(mask_dir, exist_ok=True)

	for image_path in glob.glob(os.path.join(img_dir, '*')):
		image_name = os.path.basename(image_path).split('.')[0]
		video_id, frame_id = map(int, image_name.split('_')[:2])

		if frame_id % frame_skip != 0:
			continue

		print(f"Processing video {video_id}, frame {frame_id} from {image_path}")

		image = cv2.imread(image_path)
		if image is None:  # cv2.imread returns None instead of raising for unreadable files
			print(f"WARNING: Could not read image {image_path}. Skipping.")
			continue

		bbox_path = os.path.join(label_dir, image_name + '.txt')
		class_ids, bboxes = read_darknet_bboxes(bbox_path, image.shape[1], image.shape[0])
		if len(set(class_ids)) != len(class_ids):
			# Mask file names are keyed by id, so duplicates would overwrite each other
			print(f"WARNING: Duplicate class IDs found in {bbox_path}: {class_ids}")
			continue

		if len(bboxes) == 0:
			print(f"No bounding boxes found for {image_path}. Skipping.")
			continue

		# Predict segmentation using the SAM model with bounding box prompts
		results = predictor(image_path, bboxes=bboxes)[0]
		if results.masks is None:  # Ultralytics leaves `masks` as None when nothing was segmented
			print(f"WARNING: No masks returned for {image_path}. Skipping.")
			continue

		for class_id, mask in zip(class_ids, results.masks):
			# One mask per prompt box, in the same order as the boxes
			mask = mask.data.squeeze().cpu().numpy().astype(np.uint8)
			negative_mask = 1 - mask
			cv2.imwrite(os.path.join(mask_dir, f'{image_name}_mask_{class_id}.png'), negative_mask * 255)

In [None]:
# Segment every `frame_skip`-th training frame with SAM, prompted by all boxes in the frame's full label file.
# Paths are derived from `output_dir` so this cell follows the exported dataset if its location changes.
frame_skip = 70
model = SAM("sam2.1_l.pt")

img_dir = f'{output_dir}/train/images'
label_dir = f'{output_dir}-extra/train/full_labels'
mask_output_dir = f'{output_dir}-SAMTEST/train'
segment_images_from_folder_bbox(img_dir, label_dir, mask_output_dir, frame_skip=frame_skip)