In [1]:
import os
import shutil
from collections import defaultdict

In [14]:
# Collect the .jpg files of each split of the original dataset, in sorted
# file-name order, as paths relative to the working directory.
train_dataset_path = 'data/images/train'
train_images = sorted(os.listdir(train_dataset_path))
train_images_paths = [
    os.path.join(train_dataset_path, file_name)
    for file_name in train_images
    if file_name.endswith('.jpg')
]

val_dataset_path = 'data/images/val'
val_images = sorted(os.listdir(val_dataset_path))
val_images_paths = [
    os.path.join(val_dataset_path, file_name)
    for file_name in val_images
    if file_name.endswith('.jpg')
]

test_dataset_path = 'data/images/test'
test_images = sorted(os.listdir(test_dataset_path))
test_images_paths = [
    os.path.join(test_dataset_path, file_name)
    for file_name in test_images
    if file_name.endswith('.jpg')
]

In [15]:
# Class id -> class name for the original label files; the id is the position
# of the name in the list below.
old_class_names = dict(enumerate([
    'animal',         # 0
    'bike',           # 1
    'bird',           # 2
    'bus',            # 3
    'car',            # 4
    'dog',            # 5
    'face',           # 6
    'hydrant',        # 7
    'license plate',  # 8
    'light',          # 9
    'motor',          # 10
    'other vehicle',  # 11
    'person',         # 12
    'rider',          # 13
    'scooter',        # 14
    'sign',           # 15
    'skateboard',     # 16
    'stroller',       # 17
    'train',          # 18
    'truck',          # 19
]))

# Original class ids whose labels are dropped by generate_filtered_data.
class_filter = [0, 2, 3, 5, 6, 7, 8, 14, 16, 17, 18]

# Class id -> class name for the filtered label files: the classes that are
# not listed in class_filter, renumbered from 0 in their original order.
new_class_names = dict(enumerate([
    'bike',           # 0
    'car',            # 1
    'light',          # 2
    'motor',          # 3
    'other vehicle',  # 4
    'person',         # 5
    'rider',          # 6
    'sign',           # 7
    'truck',          # 8
]))

In [41]:
def count_classes(image_paths, class_names):
    """Count the labelled objects per class for a list of images.

    The label file of an image is looked up by replacing ``images`` with
    ``labels`` in the directory part of the image path and swapping the file
    extension for ``.txt``. The first whitespace-separated token of every
    non-blank line in a label file is read as an integer class id.

    Args:
        image_paths: iterable of image file paths.
        class_names: dict mapping class id (int) to class name.

    Returns:
        A tuple ``(tot_count, sorted_class_counts)`` where ``tot_count`` is the
        number of label lines read and ``sorted_class_counts`` is a list of
        ``(class_name, count)`` tuples sorted by class name. Ids missing from
        ``class_names`` are counted together under ``"Unknown class"``.

    Raises:
        FileNotFoundError: if the label file of an image does not exist.
        ValueError: if the first token of a label line is not an integer.
    """
    class_counts = defaultdict(int)
    tot_count = 0
    for image_path in image_paths:
        # Only the directory part is rewritten, so a file name that happens to
        # contain "images" or ".jpg" in the middle still maps to its own label.
        image_dir, image_file = os.path.split(image_path)
        label_file = os.path.splitext(image_file)[0] + ".txt"
        label_path = os.path.join(image_dir.replace("images", "labels"), label_file)

        with open(label_path, 'r') as file:
            for line in file:
                parts = line.split()
                if not parts:
                    # Blank lines (e.g. a trailing newline) carry no label.
                    continue
                class_id = int(parts[0])
                # Count by name so an unmapped id uses the fallback label
                # instead of raising KeyError when the result is built.
                class_name = class_names.get(class_id, "Unknown class")
                class_counts[class_name] += 1
                tot_count += 1

    sorted_class_counts = sorted(class_counts.items())
    return tot_count, sorted_class_counts

In [17]:
# Per-class label counts for every split of the original dataset.
# Each print uses the name assigned on the line above it, so the cell runs on
# a fresh kernel.
tot_train_labels, train_labels_counts = count_classes(train_images_paths, old_class_names)
print("Classes in train dataset (TOTAL: ",tot_train_labels ,"): ", train_labels_counts)

tot_val_labels, val_labels_counts = count_classes(val_images_paths, old_class_names)
print("Classes in val dataset: (TOTAL: ",tot_val_labels ,"): ", val_labels_counts)

tot_test_labels, test_labels_counts = count_classes(test_images_paths, old_class_names)
print("Classes in test dataset: (TOTAL: ",tot_test_labels ,"): ", test_labels_counts)

Classes in train dataset (TOTAL:  176063 ):  [('animal', 8), ('bike', 7237), ('bird', 1), ('bus', 2245), ('car', 73623), ('dog', 4), ('face', 752), ('hydrant', 1095), ('license plate', 270), ('light', 16198), ('motor', 1116), ('other vehicle', 1373), ('person', 44527), ('rider', 5951), ('scooter', 15), ('sign', 20770), ('skateboard', 29), ('stroller', 15), ('train', 5), ('truck', 829)]
Classes in val dataset: (TOTAL:  16786 ):  [('bike', 170), ('bus', 179), ('car', 7133), ('face', 73), ('hydrant', 94), ('license plate', 17), ('light', 2005), ('motor', 55), ('other vehicle', 63), ('person', 4309), ('rider', 161), ('sign', 2472), ('skateboard', 3), ('stroller', 6), ('truck', 46)]
Classes in test dataset: (TOTAL:  62459 ):  [('bike', 113), ('car', 30517), ('dog', 25), ('face', 142), ('hydrant', 277), ('light', 6758), ('motor', 3314), ('other vehicle', 696), ('person', 11242), ('rider', 1081), ('sign', 5660), ('truck', 2634)]


In [37]:
def generate_filtered_data(image_paths, old_class_names, new_class_names, class_filter):
    """Copy images and write filtered, re-numbered label files to ``final-data``.

    For every image path the label file is looked up by replacing ``images``
    with ``labels`` in the directory part of the path and swapping the file
    extension for ``.txt``. Output paths are obtained by replacing ``data``
    with ``final-data`` in the directory part; missing output directories are
    created. Label lines whose class id is in ``class_filter`` are dropped;
    every other line is kept with its class id replaced by the id that
    ``new_class_names`` gives to the same class name, and its tokens re-joined
    with single spaces.

    Args:
        image_paths: iterable of image file paths.
        old_class_names: dict mapping the class ids used in the input label
            files to class names.
        new_class_names: dict mapping the class ids to write to class names.
        class_filter: collection of old class ids to drop.

    Raises:
        KeyError: if a label line uses an id missing from ``old_class_names``.
        ValueError: if a kept class name has no id in ``new_class_names``.
        FileNotFoundError: if the label file of an image does not exist.
    """
    excluded_ids = set(class_filter)
    # Built once instead of scanning new_class_names for every label line.
    new_id_by_name = {name: new_id for new_id, name in new_class_names.items()}

    for image_path in image_paths:
        # Only directory parts are rewritten so that file names containing
        # "images", "data" or ".jpg" keep their name in the output tree.
        image_dir, image_file = os.path.split(image_path)
        label_dir = image_dir.replace("images", "labels")
        label_file = os.path.splitext(image_file)[0] + ".txt"
        label_path = os.path.join(label_dir, label_file)

        new_image_dir = image_dir.replace("data", "final-data")
        new_label_dir = label_dir.replace("data", "final-data")
        new_image_path = os.path.join(new_image_dir, image_file)
        new_label_path = os.path.join(new_label_dir, label_file)

        new_lines = []
        with open(label_path, 'r') as file:
            for line in file:
                parts = line.split()
                if not parts:
                    # Blank lines carry no label.
                    continue
                class_id = int(parts[0])
                class_name = old_class_names[class_id]
                if class_id in excluded_ids:
                    continue
                if class_name not in new_id_by_name:
                    # Fail loudly: silently reusing the id of a previous line
                    # would write a wrong label.
                    raise ValueError(
                        "class %r (id %d) in %s is not filtered out but has "
                        "no id in new_class_names" % (class_name, class_id, label_path)
                    )
                parts[0] = str(new_id_by_name[class_name])
                new_lines.append(' '.join(parts) + '\n')

        # The output tree is not guaranteed to exist yet.
        for out_dir in (new_label_dir, new_image_dir):
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)

        with open(new_label_path, 'w') as file:
            file.writelines(new_lines)

        shutil.copyfile(image_path, new_image_path)

In [38]:
# Write the filtered copy of every split, in train / val / test order.
for split_images_paths in (train_images_paths, val_images_paths, test_images_paths):
    generate_filtered_data(split_images_paths, old_class_names, new_class_names, class_filter)

In [39]:
# Collect the .jpg files of each split of the filtered dataset, in sorted
# file-name order, as paths relative to the working directory.
new_train_dataset_path = 'final-data/images/train'
new_train_images = sorted(os.listdir(new_train_dataset_path))
new_train_images_paths = [
    os.path.join(new_train_dataset_path, file_name)
    for file_name in new_train_images
    if file_name.endswith('.jpg')
]

new_val_dataset_path = 'final-data/images/val'
new_val_images = sorted(os.listdir(new_val_dataset_path))
new_val_images_paths = [
    os.path.join(new_val_dataset_path, file_name)
    for file_name in new_val_images
    if file_name.endswith('.jpg')
]

new_test_dataset_path = 'final-data/images/test'
new_test_images = sorted(os.listdir(new_test_dataset_path))
new_test_images_paths = [
    os.path.join(new_test_dataset_path, file_name)
    for file_name in new_test_images
    if file_name.endswith('.jpg')
]

In [44]:
# Per-class label counts for every split of the filtered dataset.
new_tot_train_labels, new_train_labels_counts = count_classes(
    new_train_images_paths, new_class_names
)
print(
    "Classes in train dataset (TOTAL: ", new_tot_train_labels, "): ",
    new_train_labels_counts,
)

new_tot_val_labels, new_val_labels_counts = count_classes(
    new_val_images_paths, new_class_names
)
print(
    "Classes in val dataset: (TOTAL: ", new_tot_val_labels, "): ",
    new_val_labels_counts,
)

new_tot_test_labels, new_test_labels_counts = count_classes(
    new_test_images_paths, new_class_names
)
print(
    "Classes in test dataset: (TOTAL: ", new_tot_test_labels, "): ",
    new_test_labels_counts,
)

Classes in train dataset (TOTAL:  171624 ):  [('bike', 7237), ('car', 73623), ('light', 16198), ('motor', 1116), ('other vehicle', 1373), ('person', 44527), ('rider', 5951), ('sign', 20770), ('truck', 829)]
Classes in val dataset: (TOTAL:  16414 ):  [('bike', 170), ('car', 7133), ('light', 2005), ('motor', 55), ('other vehicle', 63), ('person', 4309), ('rider', 161), ('sign', 2472), ('truck', 46)]
Classes in test dataset: (TOTAL:  62015 ):  [('bike', 113), ('car', 30517), ('light', 6758), ('motor', 3314), ('other vehicle', 696), ('person', 11242), ('rider', 1081), ('sign', 5660), ('truck', 2634)]
