In [3]:
import os
import random
from pathlib import Path
from copy import deepcopy
import matplotlib.pyplot as plt

from ml_carbucks.utils.conversions import convert_yolo_to_coco
from ml_carbucks.utils.logger import setup_logger
import json

# Module-level logger for this notebook. The helper functions defined in this
# cell report progress with print() and do not reference it.
logger = setup_logger(__name__)

def create_counter(images_dir_root: str | Path, splits: list = ["train", "val"], normalize: bool = False):
    counter = {"all": {}, "img_counts": {}}
    for split in splits:
        # load each file and count how many classes there are in each split and what is their distribution
        split_dir = os.path.join(images_dir_root, split)
        counter[split] = {}
        for root, dirs, files in os.walk(split_dir):
            for file in files:
                if not file.endswith(".jpg"):
                    continue

                file_path = os.path.join(root, file)
                label_path = file_path.replace("images", "labels").replace(".jpg", ".txt")

                counter["img_counts"][split] = counter["img_counts"].get(split, 0) + 1
                counter["img_counts"]['all'] = counter["img_counts"].get('all', 0) + 1
                if not os.path.exists(label_path):
                    counter['all']['no_label'] = counter['all'].get("no_label", 0) + 1
                    counter[split]["no_label"] = counter[split].get("no_label", 0) + 1
                    continue

                with open(label_path, "r") as f:
                    if len(f.read().strip()) == 0:
                        counter['all']['no_label'] = counter['all'].get("no_label", 0) + 1
                        counter[split]["no_label"] = counter[split].get("no_label", 0) + 1
                        continue
                    f.seek(0)
                    for line in f:
                        class_id = line.strip().split()[0]
                        counter[split][class_id] = counter[split].get(class_id, 0) + 1
                        counter['all'][class_id] = counter['all'].get(class_id, 0) + 1

    if normalize:
        counter_normalized = deepcopy(counter)
        for split in splits + ['all']:
            total = sum(counter_normalized[split].values())
            for class_id in counter_normalized[split]:
                counter_normalized[split][class_id] = round(counter_normalized[split][class_id] / total * 100, 4) 

            counter_normalized[split] = dict(sorted(counter_normalized[split].items(), key=lambda item: (item[0] != "no_label", int(item[0]) if item[0] != "no_label" else -1)))
        return counter_normalized
    else:
        for split in splits + ['all']:
            counter[split] = dict(sorted(counter[split].items(), key=lambda item: (item[0] != "no_label", int(item[0]) if item[0] != "no_label" else -1)))
    return counter

def visualize_counter(counter:dict, split: str = 'all', counter_name: str = ""):
    """Show a bar chart of the per-class values stored in ``counter[split]``.

    Args:
        counter: Mapping that contains, under ``split``, a mapping of class id
            to numeric value (for example the result of ``create_counter``).
        split: Key of ``counter`` to plot.
        counter_name: Dataset name inserted into the chart title.

    The chart is drawn on the current pyplot figure and displayed with
    ``plt.show()``; nothing is returned.
    """
    distribution = counter[split]
    class_ids = list(distribution.keys())
    bar_heights = list(distribution.values())

    plt.bar(x=class_ids, height=bar_heights)
    plt.xlabel("Class ID")
    plt.ylabel("Proportion")

    # Print each bar's value just above it; bars are addressed by their index.
    for position, value in enumerate(bar_heights):
        plt.text(position, value + 0.5, str(value), ha='center')

    plt.title(f"Class Distribution in '{split}' Split for {counter_name} Dataset")
    plt.show()

def display_dataset_analysis(images_dir_root: str | Path, splits: list = ["train", "val"], counter_name: str = "", normalize: bool = False, visualize_splits: list = ['all']):
    counter = create_counter(images_dir_root, splits, normalize)
    print(f"Dataset Analysis for {counter_name} Dataset:")
    print(json.dumps(counter, indent=4))
    for split in visualize_splits:
        visualize_counter(counter, split, counter_name)

def clean_up_empty_labels(dataset_dir: str | Path, splits: list):
    print(f"Cleaning up empty labels in dataset at: {dataset_dir}")
    for split in splits:
        for root, dirs, files in os.walk(Path(dataset_dir) / "images" / split):
            for file in files:
                if not file.endswith(".jpg"):
                    continue
                
                img_file_path = os.path.join(root, file)
                label_file_path = img_file_path.replace(".jpg", ".txt").replace("images", "labels")
                img_name = file
                label_name = img_name.replace(".jpg", ".txt")

                if not os.path.exists(label_file_path):
                    print(f"Found image with no corresponding label file: {img_file_path}")
                    os.makedirs(os.path.join(dataset_dir, "images", "empty", split), exist_ok=True)
                    # move image file
                    new_img_path = os.path.join(dataset_dir, "images", "empty", split, file)
                    os.rename(img_file_path, new_img_path)
                    continue

                with open(label_file_path, "r") as f:
                    lines = f.readlines()
                
                if len(lines) == 0:
                    print(f"Found empty label file: {label_file_path}")
                    os.makedirs(os.path.join(dataset_dir, "images", "empty", split), exist_ok=True)
                    os.makedirs(os.path.join(dataset_dir, "labels", "empty", split), exist_ok=True)
                    # move label file
                    new_label_path = os.path.join(dataset_dir, "labels", "empty", split, label_name)
                    os.rename(label_file_path, new_label_path)
                    # move image file
                    new_img_path = os.path.join(dataset_dir, "images", "empty", split, img_name)
                    os.rename(img_file_path, new_img_path)

    convert_yolo_to_coco(
        base_dir=dataset_dir,
        splits=splits,
    )
                    
def balance_dataset(dataset_dir: str | Path, splits: list, remove_class_probabilities: dict[str, float] | None = None):
    for split in splits:
        files_moved_cnt = {class_id: 0 for class_id in remove_class_probabilities.keys()} if remove_class_probabilities else {}
        for root, dirs, files in os.walk(Path(dataset_dir) / "labels" / split):
            for file in files:
                if not file.endswith(".txt"):
                    continue
                file_path = os.path.join(root, file)
                with open(file_path, "r") as f:
                    lines = f.readlines()
                
                class_labels = set()
                for line in lines:
                    class_id = line.strip().split()[0]
                    class_labels.add(class_id)

                # we want to move the files only that have pure one class labels, not mixed
                if len(class_labels) != 1:
                    continue

                class_id = class_labels.pop()
                move_file = False
                if remove_class_probabilities and class_id in remove_class_probabilities:
                    prob = remove_class_probabilities[class_id]
                    if random.random() <= prob:
                        move_file = True

                if move_file:
                    files_moved_cnt[class_id] += 1
                    print(f"Moving pure class {class_id} label file: {file_path}")
                    os.makedirs(os.path.join(dataset_dir, "images", "balancing", split), exist_ok=True)
                    os.makedirs(os.path.join(dataset_dir, "labels", "balancing", split), exist_ok=True)
                    # move label file
                    new_label_path = os.path.join(dataset_dir, "labels", "balancing", split, file)
                    os.rename(file_path, new_label_path)
                    # move image file
                    img_file = file.replace(".txt", ".jpg")
                    img_path = os.path.join(dataset_dir, "images", split, img_file)
                    if os.path.exists(img_path):
                        new_img_path = os.path.join(dataset_dir, "images", "balancing", split, img_file)
                        os.rename(img_path, new_img_path)
                    else:
                        print(f"Corresponding image file not found for label: {file_path}")
        if files_moved_cnt:
            print(f"Moved files for {split}: {files_moved_cnt}")


    convert_yolo_to_coco(
        base_dir=dataset_dir,
        splits=splits,
    )
  

In [None]:
# Class statistics for the dataset under final_carbucks/standard: raw counts ...
display_dataset_analysis(
    "/home/bachelor/ml-carbucks/data/final_carbucks/standard/images",
    splits=["train", "val", "test"],
    visualize_splits=["all"],
    normalize=False,
)
# ... followed by the same statistics expressed as percentages.
display_dataset_analysis(
    "/home/bachelor/ml-carbucks/data/final_carbucks/standard/images",
    splits=["train", "val", "test"],
    visualize_splits=["all"],
    normalize=True,
)

In [None]:
# Raw class counts for the dataset under data/car_dd.
display_dataset_analysis(
    "/home/bachelor/ml-carbucks/data/car_dd/images",
    splits=["train", "val", "test"],
    visualize_splits=["all"],
    normalize=False,
)