In [1]:
import geopandas as gpd

# Species name -> YOLO class id. Ids are assigned by position in this tuple,
# so reordering it renumbers the classes.
# NOTE(review): keep this order in sync with the class-name list used at
# training time — confirm against the dataset config.
CLASS_MAP = {
    species: class_id
    for class_id, species in enumerate(("Coconut", "Mango", "Banana", "Papaya"))
}

def geojson_to_yolo_multiclass(
    tile_geojson_path,
    output_txt_path,
    img_size=256,
    tile_bounds=None,
):
    """Write the YOLO label file for one tile's GeoJSON annotations.

    Every feature whose ``species_mapped`` value is a key of ``CLASS_MAP``
    becomes one line ``"<class_id> <x_center> <y_center> <width> <height>"``.
    All four numbers are normalized to ``(0, 1]`` relative to the tile extent.
    Features are skipped when:
      * the species is unknown,
      * the geometry is missing or empty, or
      * the box collapses to zero size.
    The output file is always (re)written, and is left empty when nothing
    qualifies, so every tile gets a matching ``.txt``.

    Parameters
    ----------
    tile_geojson_path : str or path-like
        GeoJSON file readable by ``geopandas.read_file``.
    output_txt_path : str or path-like
        Destination label file; overwritten if it exists.
    img_size : int, optional
        Unused. YOLO coordinates are normalized, so the pixel size is not
        needed. Kept only for backward compatibility.
    tile_bounds : sequence of 4 floats, optional
        ``(minx, miny, maxx, maxy)`` of the tile image, in the same CRS as the
        GeoJSON (for example the raster's bounds). When given, boxes are
        normalized against this extent and clipped to it. When omitted, the
        extent of the annotations themselves (``gdf.total_bounds``) is used.
        That only matches the image if the annotations happen to touch all
        four tile edges.
    """
    gdf = gpd.read_file(tile_geojson_path)

    yolo_lines = []
    if not gdf.empty:
        if tile_bounds is None:
            # NOTE(review): annotation extent != image extent in general, which
            # stretches the labels. Prefer passing the raster bounds explicitly.
            tile_bounds = gdf.total_bounds
        yolo_lines = _collect_yolo_lines(gdf, tuple(float(v) for v in tile_bounds))

    with open(output_txt_path, "w") as f:
        # One newline-terminated label per line; nothing is written when
        # there are no labels, leaving an empty file.
        f.writelines(line + "\n" for line in yolo_lines)


def _collect_yolo_lines(gdf, tile_bounds):
    """Build YOLO label lines for each usable feature of ``gdf`` within ``tile_bounds``."""
    minx, miny, maxx, maxy = tile_bounds
    tile_w = maxx - minx
    tile_h = maxy - miny
    # A zero-area extent (e.g. a single point annotation) cannot be normalized.
    if tile_w <= 0 or tile_h <= 0:
        return []

    lines = []
    for _, row in gdf.iterrows():
        species = row.get("species_mapped")
        if species not in CLASS_MAP:
            continue

        geom = row.geometry
        if geom is None or geom.is_empty:
            continue

        # Clip to the tile so boxes straddling an edge stay inside [0, 1].
        gx_min, gy_min, gx_max, gy_max = geom.bounds
        bx_min, by_min = max(gx_min, minx), max(gy_min, miny)
        bx_max, by_max = min(gx_max, maxx), min(gy_max, maxy)

        w = (bx_max - bx_min) / tile_w
        h = (by_max - by_min) / tile_h
        if not (0 < w <= 1 and 0 < h <= 1):
            continue

        x_center = ((bx_min + bx_max) / 2 - minx) / tile_w
        # Flip Y for YOLO. This assumes a north-up CRS where map y grows
        # upward, whereas image rows grow downward.
        y_center = 1 - ((by_min + by_max) / 2 - miny) / tile_h

        lines.append(
            f"{CLASS_MAP[species]} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}"
        )
    return lines


In [2]:
import os

input_dir = "labels/geojson"
output_dir = "labels/yolo"
GEOJSON_SUFFIX = ".geojson"

os.makedirs(output_dir, exist_ok=True)

# Sorted so the conversion order (and the first failure, if any) is the same
# on every file system.
for fname in sorted(os.listdir(input_dir)):
    if not fname.endswith(GEOJSON_SUFFIX):
        continue

    # Strip only the trailing suffix. str.replace would also rewrite any
    # ".geojson" appearing earlier in the file name.
    stem = fname[: -len(GEOJSON_SUFFIX)]

    geojson_to_yolo_multiclass(
        tile_geojson_path=os.path.join(input_dir, fname),
        output_txt_path=os.path.join(output_dir, stem + ".txt"),
    )


In [3]:
from collections import Counter

# Tally how many boxes of each YOLO class id were written across all label
# files in the conversion output directory.
counter = Counter()

for label_name in sorted(os.listdir(output_dir)):
    # Only count label files. Stray files (e.g. .DS_Store) would break int().
    if not label_name.endswith(".txt"):
        continue
    with open(os.path.join(output_dir, label_name)) as label_file:
        for line in label_file:
            fields = line.split()
            if not fields:  # tolerate blank / trailing lines
                continue
            counter[int(fields[0])] += 1

counter


Counter({0: 11726, 1: 297, 2: 205, 3: 109})
