In [1]:
import os
import json
import glob
import cv2

from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt
from ultralytics import SAM
from huggingface_hub import snapshot_download
ROOT_DIR = "/home/data/pace"

In [None]:
# repo_id = "qq456cvb/PACE"  # Replace with the actual dataset ID
# local_dir = "."  # The local path where you want to save the folder
# allow_patterns = ["model_splits/*"] # To download a specific folder within the dataset
# snapshot_download(repo_id=repo_id, local_dir=local_dir, allow_patterns=allow_patterns, repo_type="dataset")

# len(glob.glob(f'{ROOT_DIR}/test/*')), len(glob.glob(f'{ROOT_DIR}/val_inst/*')), len(glob.glob(f'{ROOT_DIR}/val_pbr_cat/*'))

In [2]:
### data[video_id][frame_id]: set of all instances (IDs) in that frame (of that video)
 
data = dict()

all_categories = ['_'.join(x.split('/')[-1].split('_')[:-1]) for x in glob.glob(f'{ROOT_DIR}/model_splits/category/*_train.txt')]
all_categories.sort()

data = dict()
for video_path in glob.glob(f'{ROOT_DIR}/test/*')+glob.glob(f'{ROOT_DIR}/val_inst/*'): #+glob.glob(f'{ROOT_DIR}/val_pbr_cat/*'):
	# print(video_path)
	video_id = int(video_path.split('/')[-1])
	frame_ids = [int(f.split('/')[-1].split('.')[0]) for f in glob.glob(f'{video_path}/rgb/*')]
	# print(f'Number of frames: {len(frame_ids)}')
	data[video_id] = {frame_id: set() for frame_id in frame_ids}
		
	json_file = f'{video_path}/scene_gt_coco_det_modal_inst.json'
	with open(json_file, 'r') as f:
		video_json = json.load(f)

	for anno in video_json['annotations']:
		object_instance_id = anno['category_id']
		frame_id = anno['image_id']
		data[video_id][frame_id].add(object_instance_id)

instances_of_category = dict()
for category in all_categories:
	test_instance_ids = [int(x.split('_')[-1]) for x in open(f'{ROOT_DIR}/model_splits/category/{category}_test.txt', 'r').read().splitlines()]
	val_instance_ids = [int(x.split('_')[-1]) for x in open(f'{ROOT_DIR}/model_splits/category/{category}_val.txt', 'r').read().splitlines()]
	instances_of_category[category] = test_instance_ids + val_instance_ids
	print(f'Processing category: {category}, with instances: {instances_of_category[category]}')

Processing category: bottle, with instances: [2, 3, 5, 6, 16, 20, 21, 24, 4, 705, 14, 18, 694, 15, 11, 700]
Processing category: bowl, with instances: [40, 36, 38, 735, 739, 43, 722, 734, 30, 42, 728, 720]
Processing category: box-base_link, with instances: [579, 558, 593, 595, 603, 554, 572, 591, 608, 548]
Processing category: box-link1, with instances: [578, 557, 592, 594, 602, 553, 571, 590, 607, 547]
Processing category: brush, with instances: [49, 50, 51, 55, 56, 48, 53]
Processing category: can, with instances: [66, 70, 71, 73, 74, 57, 58, 61, 62, 63, 59, 758, 764, 72, 761, 766, 778]
Processing category: chip_can, with instances: [75, 76, 77, 78, 79, 81, 82, 789]
Processing category: clip-link1, with instances: [610, 628, 618, 632]
Processing category: clip-link2, with instances: [609, 627, 617, 631]
Processing category: clock, with instances: [93, 85, 87, 88, 797, 90, 808, 792, 89]
Processing category: container, with instances: [96, 97, 95, 823, 829, 825, 815]
Processing catego

# PACE Statistics

In [None]:
# plot the distribution of number of frames per video
video_lengths = [len(frames) for frames in data.values()]
print(f'Total number of frames: {sum(video_lengths)}')
plt.hist(video_lengths, bins=50)
plt.xlabel('Number of frames')
plt.ylabel('Number of videos')
plt.title('Distribution of number of frames per video')
plt.show()

In [None]:
# show histogram of number of unique instances that appear in each video (for however many frames)
all_instances = []
for cat in instances_of_category:
	all_instances.extend(instances_of_category[cat])

In [None]:
print(f'Total number of unique instances: {len(set(all_instances))}')

In [None]:
instances_of_category['toys']

In [None]:
chosen_category = 'toys'
chosen_instances = instances_of_category[chosen_category]
# toothbrush, wallet, can
# for each video, count how many different chosen instances show up at all (even if not in every frame)
chosen_counts = defaultdict(int)
for video_id, frames in data.items():
	for chosen_instance in chosen_instances:
		if any(chosen_instance in frames[frame_id] for frame_id in frames):
			chosen_counts[video_id] += 1

plt.hist(list(chosen_counts.values()), align='mid')
plt.xlabel(f'Number of instances of {chosen_category} appearing in a video')
plt.ylabel('Number of videos')
plt.title(f'Distribution of number of videos containing {chosen_category}')
plt.show()

In [None]:
# for each object instance, plot how many videos it appears in (even if it appears in only one frame)
# plot object instance ID on the x axis, and number of videos it appears in on the y axis
# sort from most to least number of videos
instance_video_count = {instance_id: 0 for instance_id in instances_of_category[chosen_category]}
for instance_id in instance_video_count:
	for video_id, frames in data.items():
		if any(instance_id in frames[frame_id] for frame_id in frames):
			print(f"Found instance ID: {instance_id} in video ID: {video_id}")
			instance_video_count[instance_id] += 1

# sort the bars by number of videos
# instance_video_count = dict(sorted(instance_video_count.items(), key=lambda item: item[1], reverse=True))
plt.figure(figsize=(12, 6))
plt.bar(np.arange(len(instance_video_count.keys())), instance_video_count.values())
plt.xlabel('Object Instance ID')
plt.ylabel('Number of Videos')
plt.title('Number of Videos per Object Instance')
plt.xticks(rotation=90)
plt.tight_layout()
plt.show()

In [None]:
len(all_categories)

In [None]:
## make histogram of instaces per category
instances_counts = [len(instances_of_category[category]) for category in all_categories]

# sort categories by number of instances
sorted_categories = [x for _, x in sorted(zip(instances_counts, all_categories), reverse=True)]
sorted_counts = sorted(instances_counts, reverse=True)

plt.figure(figsize=(12, 6))
plt.bar(sorted_categories, sorted_counts)
plt.xlabel('Category')
plt.ylabel('Number of Instances')
plt.title(f'Number of Instances per Category')
plt.xticks(rotation=90)
# plt.tight_layout()
plt.show()

In [None]:
# plot number of videos that each category occurs in
category_video_counts = defaultdict(int)
for category in all_categories:
	for video_id, frames in data.items():
		foundCategoryInVideo = False
		for frame_id, object_instances in frames.items():
			for instance_id in object_instances:
				if instance_id in instances_of_category[category]:
					category_video_counts[category] += 1
					foundCategoryInVideo = True
					break
			if foundCategoryInVideo:
				break

# plot number of videos that each category occurs in
sorted_categories = [x for _, x in sorted(zip(category_video_counts.values(), category_video_counts.keys()), reverse=True)]
sorted_counts = sorted(category_video_counts.values(), reverse=True)

plt.figure(figsize=(12, 6))
plt.bar(sorted_categories, sorted_counts)
plt.xlabel('Category')
plt.ylabel('Number of Videos')
plt.title(f'Number of Videos per Category')
plt.xticks(rotation=90)
# plt.tight_layout()
plt.show()

# Instance & video train-val-test split

In [3]:
def isInstanceInVideo(instance_id, video_id):
	"""Check if a given instance ID is present in any frame of the specified video."""
	return any(instance_id in frame_instances for frame_instances in data[video_id].values())

In [4]:
# for each chosen instance, find all videos that contain it, and all other instances that appear in those videos
video_groups = []
chosen_category = 'snack_box'
chosen_instances = instances_of_category[chosen_category]
print(chosen_instances)

for chosen_instance in chosen_instances:
	# check if the chosen instance is already in a group
	for group in video_groups:
		if chosen_instance in group['instances']:
			break
	else:
		# if not, create a new group 
		new_group = {'video_ids': set(), 'instances': set()}
		for video_id in data.keys():
			if isInstanceInVideo(chosen_instance, video_id):
				new_group['video_ids'].add(video_id)
				# find all other instances in the same video
				for other_instance in instances_of_category[chosen_category]:
					if (other_instance not in new_group['instances']) and isInstanceInVideo(other_instance, video_id):
						new_group['instances'].add(other_instance)
						for video_id in data.keys():
							if isInstanceInVideo(other_instance, video_id):
								# print(f'Adding instance {other_instance} to group for video {video_id}')
								new_group['video_ids'].add(video_id)
		if new_group['video_ids']:
			video_groups.append(new_group)
		else:
			print(f'No videos found for instance {chosen_instance} in category {chosen_category}.')
		

for group in video_groups:
	total_frames = sum(len(data[video_id]) for video_id in group["video_ids"])
	# print('---')
	print(f'Video group with {len(group["video_ids"])} videos, {len(group["instances"])} instances, and {total_frames} frames.', end=' ')
	# print(f'Video IDs: {sorted(list(group["video_ids"]))}')
	print(f'Instances: {sorted(list(group["instances"]))}')

[0, 320, 324, 332, 336, 305, 306, 307, 308, 338, 339, 340, 342, 315, 316, 317, 318, 319, 311, 309, 347, 343, 328, 337, 326]
No videos found for instance 311 in category snack_box.
No videos found for instance 309 in category snack_box.
No videos found for instance 347 in category snack_box.
No videos found for instance 343 in category snack_box.
No videos found for instance 328 in category snack_box.
No videos found for instance 337 in category snack_box.
No videos found for instance 326 in category snack_box.
Video group with 15 videos, 2 instances, and 2547 frames. Instances: [0, 342]
Video group with 3 videos, 2 instances, and 585 frames. Instances: [320, 336]
Video group with 6 videos, 1 instances, and 1056 frames. Instances: [324]
Video group with 6 videos, 1 instances, and 1626 frames. Instances: [332]
Video group with 3 videos, 1 instances, and 780 frames. Instances: [305]
Video group with 3 videos, 1 instances, and 459 frames. Instances: [306]
Video group with 6 videos, 1 insta

In [5]:
def set_union(sets):
    return set().union(*sets)
def sets_are_disjoint(sets):
	"""Check if all sets in the list are disjoint."""
	combined = set()
	for s in sets:
		if not combined.isdisjoint(s):
			return False
		combined.update(s)
	return True

In [50]:
video_splits = {'train': dict(), 'val': dict(), 'test': dict()}
instance_splits = {'train': dict(), 'val': dict(), 'test': dict()}
distractor_splits = {'train': dict(), 'val': dict(), 'test': dict()}

target_categories = ['toy_car', 'can']
distractor_categories = ['snack_box']

instance_splits['train']['toy_car'] = {456, 458, 461, 470}
instance_splits['val']['toy_car'] = {459, 460, 467, 468}
instance_splits['test']['toy_car'] = {455, 457, 469}

instance_splits['train']['can'] = {74, 57, 58}
instance_splits['val']['can'] = {66, 70, 71, 73}
instance_splits['test']['can'] = {61, 62, 63}

instance_splits['train']['distractor'] = {320, 336, 324, 306, 338}
instance_splits['val']['distractor'] = {0, 342, 332, 307, 339, 317}
instance_splits['test']['distractor'] = {315, 316, 305, 308, 340, 318}

already_assigned_videos = set()
for split in instance_splits:
	for category in instance_splits[split]:
		video_splits[split][category] = set()
		for instance in instance_splits[split][category]:
			# add each video that contains this instance to the split
			for video_id, frames in data.items():
				if isInstanceInVideo(instance, video_id):
					# if the video is already assigned to another category, this may create a conflict
					if (video_id in already_assigned_videos) and (video_id not in video_splits[split][category]):
						if category == 'distractor':
							continue  # Just don't add the distractor videos that would create conflict
						else:
							assert False, "Splits are not disjoint!"
					video_splits[split][category].add(video_id)
					already_assigned_videos.add(video_id)

for split in video_splits:
	for category in video_splits[split]:
		print(f'{split:>5} | {category:>7} | Videos:    {sorted(list(video_splits[split][category]))}')
		if len(video_splits[split][category]) == 0:
			print(f'Warning: No videos found for split {split} and category {category}.')

for split in video_splits:
	for category in video_splits[split]:
		print(f'{split:>5} | {category:>7} | Instances: {sorted(list(instance_splits[split][category]))}')
		if len(instance_splits[split][category]) == 0:
			print(f'Warning: No instances found for split {split} and category {category}.')

videos_in_data_split = {'train': set(), 'val': set(), 'test': set()} # set of video IDs for each data split
instances_in_data_split = {'train': set(), 'val': set(), 'test': set()} # set of TARGET instance IDs for each data split
for split in video_splits:
	videos_in_data_split[split] = set_union(video_splits[split][category] for category in video_splits[split])
	instances_in_data_split[split] = set_union(instance_splits[split][category] for category in instance_splits[split] if category!='distractor')

videos_in_category = {category:set() for category in video_splits['train']} # set of video IDs for each category
instances_in_category = {category:set() for category in instance_splits['train']} # set of instance IDs for each category
for category in video_splits[split]:
	videos_in_category[category] = videos_in_category[category].union(video_splits[split][category])
	instances_in_category[category] = instances_in_category[category].union(instance_splits[split][category])

assert sets_are_disjoint([videos_in_data_split['train'], videos_in_data_split['test'], videos_in_data_split['val']])
assert sets_are_disjoint([instances_in_data_split['train'], instances_in_data_split['val'], instances_in_data_split['test']])

print(f'Train videos: {len(videos_in_data_split["train"])}, Train instances: {len(instances_in_data_split["train"])}')
print(f'Val videos: {len(videos_in_data_split["val"])}, Val instances: {len(instances_in_data_split["val"])}')
print(f'Test videos: {len(videos_in_data_split["test"])}, Test instances: {len(instances_in_data_split["test"])}')

train | toy_car | Videos:    [141, 142, 143, 180, 181, 182, 183, 184, 185, 186, 187, 188, 213, 214, 215, 216, 217, 218, 222, 223, 224]
train |     can | Videos:    [0, 1, 2, 24, 25, 26, 72, 73, 74, 93, 94, 95]
train | distractor | Videos:    [9, 10, 11, 42, 43, 44, 138, 139, 140, 153, 154, 155]
  val | toy_car | Videos:    [120, 121, 122, 144, 145, 146, 159, 160, 161, 168, 169, 170, 249, 250, 251, 279, 280, 281, 282, 283, 284]
  val |     can | Videos:    [45, 46, 47, 48, 49, 50, 51, 52, 53, 66, 67, 68, 96, 97, 98]
  val | distractor | Videos:    [27, 28, 29, 57, 58, 59, 63, 64, 65, 78, 79, 80, 102, 103, 104, 126, 127, 128, 177, 178, 179, 261, 262, 263, 288, 289, 290, 291, 292, 293]
 test | toy_car | Videos:    [195, 196, 197, 204, 205, 206, 225, 226, 227, 228, 229, 230]
 test |     can | Videos:    [3, 4, 5, 21, 22, 23, 30, 31, 32, 39, 40, 41]
 test | distractor | Videos:    [12, 13, 14, 15, 16, 17, 18, 19, 20, 33, 34, 35]
train | toy_car | Instances: [456, 458, 461, 470]
train |     

In [7]:
def convert_bbox_to_yolo(bbox, img_width, img_height):
    """
    Convert bounding box from [x, y, width, height] format to YOLO format
    [x_center, y_center, width, height] normalized by image dimensions
    """
    x, y, w, h = bbox
    x_center = (x + w/2) / img_width
    y_center = (y + h/2) / img_height
    norm_width = w / img_width
    norm_height = h / img_height
    return x_center, y_center, norm_width, norm_height

def process_annotation_file(json_file_path):
    """
    Process a video annotation file and convert to YOLO format
    Returns a dictionary mapping frame_id to list of YOLO format annotations
    """
    with open(json_file_path, 'r') as f:
        annotations = json.load(f)['annotations']
    
    frame_annotations = defaultdict(list)
    
    for anno in annotations:
        # if anno.get('ignore', False):  # Skip ignored annotations
        #     print(f"Skipping ignored annotation: {anno}")
        #     if anno['category_id'] in wallet_instances.union(can_instances):
        #         print('can or wallet skipped')
        #     continue
            
        frame_id = anno['image_id']
        class_id = anno['category_id']
        bbox = anno['bbox']  # [x, y, width, height]
        img_width = anno['width']
        img_height = anno['height']
        
        # Convert to YOLO format
        x_center, y_center, norm_width, norm_height = convert_bbox_to_yolo(bbox, img_width, img_height)
        
        yolo_annotation = {
            'class_id': class_id,
            'x_center': x_center,
            'y_center': y_center,
            'width': norm_width,
            'height': norm_height
        }
        
        frame_annotations[frame_id].append(yolo_annotation)
    
    return frame_annotations

In [None]:
# toy_car : 0
# can : 1
category_to_label = {category: i for i, category in enumerate(target_categories)}
other_categories = sorted(list(set(all_categories) - set(target_categories)))
for category in other_categories:
	category_to_label[category] = len(category_to_label)
label_to_category_name = {x: category for category, x in category_to_label.items()} # maps class numbers (from  darknet labels) to category/class names

In [9]:
label_to_category_name

{0: 'toy_car',
 1: 'can',
 2: 'bottle',
 3: 'bowl',
 4: 'box-base_link',
 5: 'box-link1',
 6: 'brush',
 7: 'chip_can',
 8: 'clip-link1',
 9: 'clip-link2',
 10: 'clock',
 11: 'container',
 12: 'cutter-base_link',
 13: 'cutter-link1',
 14: 'drinkbox',
 15: 'dustpan',
 16: 'hammer',
 17: 'handbag',
 18: 'helmet',
 19: 'marker',
 20: 'mug',
 21: 'notebook',
 22: 'pan',
 23: 'ping_pong_ball',
 24: 'plate',
 25: 'ramen_box',
 26: 'ramen_package',
 27: 'razor',
 28: 'remote',
 29: 'sausage',
 30: 'scissor-link1',
 31: 'scissor-link2',
 32: 'slipper',
 33: 'snack_box',
 34: 'snack_package',
 35: 'sneaker',
 36: 'spanner',
 37: 'squeegee',
 38: 'steel_tape',
 39: 'tape',
 40: 'tennis_ball',
 41: 'thermos',
 42: 'tissue',
 43: 'toothbrush',
 44: 'toys',
 45: 'trash_bin',
 46: 'wallet'}

# Dataset Processing

In [12]:
output_dir = f'{ROOT_DIR}/toycar_can_v0'

In [None]:
# Create train, val, test splits
for split in ['train', 'val', 'test']:
	if not os.path.exists(f'{output_dir}/{split}/images/'):
		os.makedirs(f'{output_dir}/{split}/images/')
	if not os.path.exists(f'{output_dir}/{split}/labels/'):
		os.makedirs(f'{output_dir}/{split}/labels/')
	if not os.path.exists(f'{output_dir}/{split}/full_labels/'):
		os.makedirs(f'{output_dir}/{split}/full_labels/')

for video_path in glob.glob(f'{ROOT_DIR}/test/*')+glob.glob(f'{ROOT_DIR}/val_inst/*'):
	video_id = int(video_path.split('/')[-1])
	if video_id in videos_in_data_split['train']:
		split = 'train'
	elif video_id in videos_in_data_split['val']:
		split = 'val'
	elif video_id in videos_in_data_split['test']:
		split = 'test'

	labels = process_annotation_file(f'{video_path}/scene_gt_coco_det_modal_inst.json')
	for frame_id in glob.glob(f'{video_path}/rgb/*'):
		# print(f"cp {frame_id} {ROOT_DIR}/wallet_can_v0/{split}/{str(video_id)}/{video_id}_{os.path.basename(frame_id)}")
		os.system(f"cp {frame_id} {output_dir}/{split}/images/{str(video_id)}_{int(os.path.basename(frame_id).split('.')[0])}.png")

		frame_num = int(os.path.basename(frame_id).split('.')[0])
		full_label_path = f"{output_dir}/{split}/full_labels/{str(video_id)}_{frame_num}.txt"
		label_path = f"{output_dir}/{split}/labels/{str(video_id)}_{frame_num}.txt"

		for label in labels[frame_num]:
			category_name_search = [cat for cat in instances_of_category if label['class_id'] in instances_of_category[cat]]
			if len(category_name_search) == 0:
				print(f"Warning: No category found for label {label['class_id']} in video {video_id}, frame {frame_num}.")
				continue
			assert len(category_name_search) == 1, f"Multiple categories found for label {label['class_id']} in video {video_id}, frame {frame_num}."
			category_name = category_name_search[0]
			
			class_id = category_to_label[category_name]
			if category_name in target_categories:
				# if the category is one of the chosen categories, write to the label file
				with open(label_path, 'a') as f:
					f.write(f"{class_id} {label['x_center']} {label['y_center']} {label['width']} {label['height']}\n")
			
			# write to the full label file regardless of category
			with open(full_label_path, 'a') as f:
				f.write(f"{class_id} {label['x_center']} {label['y_center']} {label['width']} {label['height']}\n")

In [16]:
for split in ['train', 'val', 'test']:
    for image_path in glob.glob(f'{output_dir}/{split}/images/*'):
        label_path = image_path.replace('.png', '.txt').replace('/images', '/labels')
        full_label_path = image_path.replace('.png', '.txt').replace('/images', '/full_labels')
        if not os.path.exists(label_path):
            os.system(f"touch {label_path}")
        if not os.path.exists(full_label_path):
            os.system(f"touch {full_label_path}")

In [14]:
for split in ['train', 'val', 'test']:
	path = f"{output_dir}/{split}"
	for image_path in glob.glob(f'{path}/images/*'):
		if not os.path.exists(image_path.replace('.png', '.txt').replace('/images', '/labels')):
			print(f"Missing label for image: {image_path}")

# for label_path in glob.glob(f'{path}/labels/*'):
# 	if not os.path.exists(label_path.replace('.txt', '.png').replace('/labels', '/images')):
# 		print(f"Missing image for label file: {label_path}")

In [15]:
# print number of images and labels in each split
for split in ['train', 'val', 'test']:
	num_images = len(glob.glob(f'{output_dir}/{split}/images/*'))
	num_labels = len(glob.glob(f'{output_dir}/{split}/labels/*'))
	num_full_labels = len(glob.glob(f'{output_dir}/{split}/full_labels/*'))
	print(f"{split.capitalize()} - Images: {num_images}, Labels: {num_labels}, Full Labels: {num_full_labels}")

Train - Images: 15854, Labels: 15854, Full Labels: 15630
Val - Images: 25881, Labels: 25881, Full Labels: 24896
Test - Images: 13210, Labels: 13210, Full Labels: 13210


# Segment to get foreground object images

In [52]:
foreground_output_dir = '/home/data/pace/toycar_can_v0/foreground_objects'
source_dir = '/home/data/pace/toycar_can_v0/train'
frame_skip=70

In [34]:
def read_darknet_bboxes(bbox_path, image_width, image_height):
	"""Read bounding boxes from darknet format file and convert to pixel coordinates"""
	bboxes = []
	
	with open(bbox_path, 'r') as f:
		foundLine = False
		class_ids = []
		for line in f:
			foundLine = True
			parts = line.strip().split()
			assert len(parts) == 5, f"Invalid bbox line: {line.strip()}"
			
			# Darknet format: class_id x_center y_center width height (normalized)
			class_id = int(parts[0])
			class_ids.append(class_id)

			x_center, y_center, width, height = map(float, parts[1:5])

			# Convert from normalized coordinates to pixel coordinates
			x_center_px = x_center * image_width
			y_center_px = y_center * image_height
			width_px = width * image_width
			height_px = height * image_height
			
			# Convert to x1, y1, x2, y2 format
			x1 = int(x_center_px - width_px / 2)
			y1 = int(y_center_px - height_px / 2)
			x2 = int(x_center_px + width_px / 2)
			y2 = int(y_center_px + height_px / 2)
			
			# Ensure coordinates are within image bounds
			x1 = max(0, min(x1, image_width - 1))
			y1 = max(0, min(y1, image_height - 1))
			x2 = max(0, min(x2, image_width - 1))
			y2 = max(0, min(y2, image_height - 1))
			
			bboxes.append([x1, y1, x2, y2])
		if not foundLine:
			print(f"No bounding boxes found in {bbox_path}. Returning empty list.")
	return class_ids, bboxes


In [39]:
model = SAM("sam2.1_l.pt")

def segment_images_from_folder_bbox(root_dir, output_dir, frame_skip):
	"""
	Segments images in the specified folder using the SAM model with bbox information.
	Assumes root_dir contains two folders: 'images' and 'full_labels'.
	Each image in 'images' should have a corresponding label file in 'full_labels' with
	bounding box information in the format: x y w h (where x, y are the
	top-left corner coordinates and w, h are the width and height of the bounding box).
	"""
	os.mkdir(output_dir) if not os.path.exists(output_dir) else None
	subdir_path = os.path.join(output_dir, 'masks')
	if not os.path.exists(subdir_path): 
		os.mkdir(subdir_path)

	for image_path, bbox_path in list(zip(sorted(glob.glob(os.path.join(root_dir, 'images', '*'))), 
										  sorted(glob.glob(os.path.join(root_dir, 'full_labels', '*.txt'))))):
		image_name = os.path.basename(image_path)
		video_id, frame_id = map(int, image_name.split('.')[0].split('_')[:2])
		if frame_id % frame_skip != 0:
			continue
		print(f"Processing video {video_id}, frame {frame_id} from {image_path}")
		
		image_dimensions = cv2.imread(image_path).shape
		class_ids, bboxes = read_darknet_bboxes(bbox_path, image_dimensions[1], image_dimensions[0])
		if len(set(class_ids)) != len(class_ids):
			print(f"WARNING: Duplicate class IDs found in {bbox_path}: {class_ids}")
			continue

		if len(bboxes) == 0:
			print(f"No bounding boxes found for {image_path}. Skipping.")
			continue
		
		# Predict segmentation using the SAM model with bounding box
		results = model(image_path, bboxes=bboxes)[0]
		# visualize_image_annotations(image_path, bbox_path, output_dir)
		for class_id, mask in zip(class_ids, results.masks):
			# Assuming single class segmentation for simplicity, adjust as needed
			mask = mask.data.squeeze().cpu().numpy()  # For multi-class, iterate over masks
			# if len(bboxes) > 1:
			# 	print(f"Sample mask shape: {mask.shape}, bboxes: {bboxes}, masks: {masks.shape}")
			# 	1/0
			mask = mask.astype(np.uint8) # Convert mask to uint8 if needed)
			mask = cv2.resize(mask, (image_dimensions[1], image_dimensions[0]))
			
			image = cv2.imread(image_path)
			image = cv2.resize(image, (image_dimensions[1], image_dimensions[0]))
			
			# Negate the mask and mask the image
			negative_mask = 1 - mask
			negative_image = cv2.bitwise_not(image)
			negative_image = cv2.bitwise_and(negative_image, negative_image, mask=mask)

			cv2.imwrite(os.path.join(output_dir, 'masks', os.path.basename(image_path).split('.')[0] + f'_mask_{class_id}.png'), negative_mask*255)

In [None]:
# for class_dir in glob.glob(os.path.join(output_dir, '*')):
segment_images_from_folder_bbox(source_dir, source_dir, frame_skip=frame_skip)

Processing video 0, frame 0 from /home/data/pace/toycar_can_v0/train/images/0_0.png

image 1/1 /home/data/pace/toycar_can_v0/train/images/0_0.png: 1024x1024 1 0, 1 1, 1 2, 229.6ms
Speed: 3.8ms preprocess, 229.6ms inference, 0.5ms postprocess per image at shape (1, 3, 1024, 1024)
Processing video 0, frame 140 from /home/data/pace/toycar_can_v0/train/images/0_140.png

image 1/1 /home/data/pace/toycar_can_v0/train/images/0_140.png: 1024x1024 1 0, 1 1, 1 2, 231.7ms
Speed: 3.7ms preprocess, 231.7ms inference, 0.5ms postprocess per image at shape (1, 3, 1024, 1024)
Processing video 0, frame 70 from /home/data/pace/toycar_can_v0/train/images/0_70.png

image 1/1 /home/data/pace/toycar_can_v0/train/images/0_70.png: 1024x1024 1 0, 1 1, 1 2, 229.5ms
Speed: 3.7ms preprocess, 229.5ms inference, 0.5ms postprocess per image at shape (1, 3, 1024, 1024)
Processing video 10, frame 0 from /home/data/pace/toycar_can_v0/train/images/10_0.png
Processing video 10, frame 140 from /home/data/pace/toycar_can_v0

In [55]:
'''
folder structure:
/home/data/pace/toycar_can_v0/
train/
	images/
		{video_id}_{frame_id}.png
		0_0.png
		0_1.png
		...
	labels/
		{video_id}_{frame_id}.txt
		0_0.txt
		0_1.txt
		...
	masks/
		{video_id}_{frame_id}_mask_{class_id}.png
		0_0_mask_1.png
		0_0_mask_33.png
		0_1_mask_25.png
        ...
'''

# copy all the masks and their corresponding images to a new folder, with a subfolder depending on the class (as mentioned in the label file)
if not os.path.exists(foreground_output_dir):
	os.mkdir(foreground_output_dir)


for mask_img_path in glob.glob(os.path.join('/home/data/pace/toycar_can_v0/train/masks', '*_mask_*.png')):
	image_name = os.path.basename(mask_img_path)
	video_id, frame_id, _, class_id = image_name.split('_')
	video_id = int(video_id)
	frame_id = int(frame_id)
	class_id = int(class_id.split('.')[0])  # Extract class ID from the filename
	category = label_to_category_name[class_id]
	if category not in target_categories+distractor_categories:
		print(f"Skipping mask {mask_img_path} with class ID {class_id} not in target categories.")
		continue
	
	# if class_id < 1e9:
	class_dir = os.path.join(foreground_output_dir, label_to_category_name[class_id])
	# else:
	# 	class_dir = os.path.join(foreground_output_dir, f'distractor}')
	if not os.path.exists(class_dir):
		os.mkdir(class_dir)
	if not os.path.exists(os.path.join(class_dir, 'masks')):
		os.mkdir(os.path.join(class_dir, 'masks'))
	if not os.path.exists(os.path.join(class_dir, 'images')):
		os.mkdir(os.path.join(class_dir, 'images'))

	# Copy the mask image
	os.system(f"cp {mask_img_path} {class_dir}/masks/{video_id}_{frame_id}_mask.png")
	
	# Also copy the corresponding original image
	original_image_path = f'/home/data/pace/toycar_can_v0/train/images/{video_id}_{frame_id}.png'
	# if os.path.exists(original_image_path):
	assert os.path.exists(original_image_path), f"Original image {original_image_path} not found for {mask_img_path}"
	os.system(f"cp {original_image_path} {class_dir}/images/{video_id}_{frame_id}.png")

Skipping mask /home/data/pace/toycar_can_v0/train/masks/155_70_mask_37.png with class ID 37 not in target categories.
Skipping mask /home/data/pace/toycar_can_v0/train/masks/8_70_mask_41.png with class ID 41 not in target categories.
Skipping mask /home/data/pace/toycar_can_v0/train/masks/240_210_mask_43.png with class ID 43 not in target categories.
Skipping mask /home/data/pace/toycar_can_v0/train/masks/125_70_mask_25.png with class ID 25 not in target categories.
Skipping mask /home/data/pace/toycar_can_v0/train/masks/74_70_mask_13.png with class ID 13 not in target categories.
Skipping mask /home/data/pace/toycar_can_v0/train/masks/26_140_mask_25.png with class ID 25 not in target categories.
Skipping mask /home/data/pace/toycar_can_v0/train/masks/273_140_mask_6.png with class ID 6 not in target categories.
Skipping mask /home/data/pace/toycar_can_v0/train/masks/24_140_mask_41.png with class ID 41 not in target categories.
Skipping mask /home/data/pace/toycar_can_v0/train/masks/191


# Discarded

In [None]:
# Given a directory with a images and labels folder, I want pictures of each object annotated by a label
# Look through each image with its corrseponding label file, and crop out the object from the image based on the bounding box
# Save the cropped images in a new directory with the filename format {video_id}_{frame_id}_{class_id}.png
import cv2, os, glob

def crop_and_save_objects(root_dir, output_dir, frame_skip, exclude_classes=[]):
	if not os.path.exists(output_dir):
		os.makedirs(output_dir)

	images_path = os.path.join(root_dir, 'images')
	labels_path = os.path.join(root_dir, 'labels')

	for label_file in glob.glob(os.path.join(labels_path, '*.txt')):
		video_id, frame_id = os.path.basename(label_file).split('.')[0].split('_')[:2]
		image_file = os.path.join(images_path, f"{video_id}_{frame_id}.png")

		# skip frames so we only take objects from every frame_skip-th frame
		if int(frame_id) % frame_skip != 0:
			continue
		
		if not os.path.exists(image_file):
			print(f"Image file {image_file} does not exist. Skipping.")
			continue
		
		image = cv2.imread(image_file)
		if image is None:
			print(f"Failed to read image {image_file}. Skipping.")
			continue
		
		with open(label_file, 'r') as f:
			for line in f:
				class_id, x_center, y_center, width, height = map(float, line.strip().split())
				if int(class_id) in exclude_classes:
					continue

				x_center *= image.shape[1]
				y_center *= image.shape[0]
				width *= image.shape[1]
				height *= image.shape[0]
				width *= 1.2  # Optional: increase width by 20% for better cropping
				height *= 1.2  # Optional: increase height by 20% for better cropping

				x1 = max(0,int(x_center - width / 2))
				y1 = max(0,int(y_center - height / 2))
				x2 = min(image.shape[1],int(x_center + width / 2))
				y2 = min(image.shape[0],int(y_center + height / 2))

				# Crop the object from the image
				cropped_image = image[y1:y2, x1:x2]
				
				# Save the cropped image
				output_filename = f"{video_id}_{frame_id}_{int(class_id)}.png"
				
				if not os.path.exists(os.path.join(output_dir, str(int(class_id)))):
					os.makedirs(os.path.join(output_dir, str(int(class_id))))

				output_path = os.path.join(output_dir, str(int(class_id)), output_filename)
				cv2.imwrite(output_path, cropped_image)

root = "/home/data/pace/toycar_can_v0/train"
output_dir = os.path.join(root, '..', 'cropped_objects_train')
crop_and_save_objects(root, output_dir, frame_skip=120, exclude_classes=[0])

In [None]:
model = SAM("sam2.1_b.pt")

In [None]:
results = model('/home/data/pace/toycar_can_v0/cropped_objects_train/1/0_0_1.png')

In [None]:
from ultralytics import SAM
import numpy as np


In [None]:
# segment the cropped object images
for class_dir in glob.glob(os.path.join(output_dir, '*')):
	assert os.path.isdir(class_dir)
	class_name = os.path.basename(class_dir)
	output_class_dir = os.path.join(output_dir, class_name)
	if not os.path.exists(output_class_dir):
		os.makedirs(output_class_dir)
	
	model = SAM("sam2.1_l.pt")
	model.to('cuda')  # Move model to GPU if available

	for image_path in sorted(glob.glob(os.path.join(class_dir, '*'))):
		if not os.isdir(image_path):
			continue
		print(f"Processing image: {image_path}")
		image_dimensions = cv2.imread(image_path).shape
		
		# Predict segmentation using the SAM model
		results = model(image_path)[0]
		masks = results.masks
		assert masks is not None, f"No masks found for image {image_path}"

		# Assuming single class segmentation for simplicity, adjust as needed
		mask = masks[0].data.squeeze().cpu().numpy()  # For multi-class, iterate over masks
		mask = mask.astype(np.uint8) # Convert mask to uint8 if needed)
		mask = cv2.resize(mask, (image_dimensions[1], image_dimensions[0]))
		
		image = cv2.imread(image_path)
		image = cv2.resize(image, (image_dimensions[1], image_dimensions[0]))
		
		# Negate the mask and mask the image
		negative_mask = 1-mask
		negative_image = cv2.bitwise_not(image)
		negative_image = cv2.bitwise_and(negative_image, negative_image, mask=mask)
		masked_image = cv2.bitwise_not(negative_image)

		os.mkdir(output_class_dir) if not os.path.exists(output_class_dir) else None
		if not os.path.exists(os.path.join(output_class_dir, 'masks')): 
			os.mkdir(os.path.join(output_class_dir, 'masks'))
				
		cv2.imwrite(os.path.join(output_class_dir, 'masks', os.path.basename(image_path).split('.')[0] + '_mask.png'), negative_mask*255)