In [14]:
from datasets import load_dataset

# Login using e.g. `huggingface-cli login` to access this dataset
# Load every split at once as a DatasetDict; individual splits are reached as
# ds["train"], ds["validation"] and ds["test"] rather than via one
# load_dataset(..., split=...) call per split.
ds = load_dataset(path="EPFL-ECEO/coralscapes")

Generating train split:   0%|          | 0/1517 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/166 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/392 [00:00<?, ? examples/s]

In [20]:
from datasets import load_dataset
from PIL import Image
import numpy as np
from io import BytesIO
from pathlib import Path
import pyarrow as pa
import pyarrow.parquet as pq
from tqdm import tqdm
import os

# ----- class id sets -----
# Label IDs grouped into the two health buckets used when recolouring masks;
# IDs in neither set are left black by ids_to_health_rgb.
HEALTHY_IDS = {6, 17, 22, 25, 27, 28, 31, 34, 36}
UNHEALTHY_IDS = {3, 4, 16, 19, 20, 23, 32, 33, 37}

# Output colours as uint8 RGB triples.
RED = np.array((255, 0, 0), dtype=np.uint8)   # healthy pixels
BLUE = np.array((0, 0, 255), dtype=np.uint8)  # unhealthy pixels
BLACK = np.array((0, 0, 0), dtype=np.uint8)   # everything else

# ----- helpers -----
def pil_to_png_bytes(img: Image.Image) -> bytes:
    """Encode ``img`` as a PNG file in memory and return the encoded bytes."""
    with BytesIO() as stream:
        img.save(stream, format="PNG")
        return stream.getvalue()

def label_pil_to_id_array(lbl: Image.Image) -> np.ndarray:
    """Convert a single-channel label image into an int32 array of pixel values.

    Accepted modes are palette ("P"), 8-bit grayscale ("L"), 32-bit integer
    ("I") and the 16-bit integer family ("I;16", "I;16B", "I;16L", ...); their
    raw pixel values (palette indices for "P") are returned as class IDs.

    Raises:
        ValueError: if the image is in any other mode (e.g. an already
            colourised "RGB" mask), since its pixels are not class IDs.
    """
    mode = lbl.mode
    # 16-bit PNG masks load as "I;16*" and are just as much ID masks as "L"/"I".
    if mode in ("P", "L", "I") or mode.startswith("I;16"):
        return np.asarray(lbl).astype(np.int32)
    raise ValueError(f"Label mode {lbl.mode} is not an ID mask.")

def ids_to_health_rgb(id_mask: np.ndarray) -> Image.Image:
    """Recolour a 2-D class-ID mask into an RGB health map.

    Pixels whose ID is in HEALTHY_IDS become RED, pixels whose ID is in
    UNHEALTHY_IDS become BLUE, and every other pixel stays black (0, 0, 0).

    Args:
        id_mask: 2-D array of integer class IDs.

    Returns:
        A PIL image of the same height/width in RGB mode.

    Raises:
        ValueError: if ``id_mask`` is not 2-D.
    """
    if id_mask.ndim != 2:
        raise ValueError(f"Expected a 2-D ID mask, got shape {id_mask.shape}.")
    rgb = np.zeros((*id_mask.shape, 3), dtype=np.uint8)
    rgb[np.isin(id_mask, list(HEALTHY_IDS))] = RED
    # Assigned second, so UNHEALTHY_IDS wins should the two sets ever overlap.
    rgb[np.isin(id_mask, list(UNHEALTHY_IDS))] = BLUE
    # "RGB" is inferred from the (H, W, 3) uint8 array; passing mode= to
    # fromarray is deprecated since Pillow 11.3.
    return Image.fromarray(rgb)

def table_size_bytes(table: pa.Table, compression: str = "zstd") -> int:
    """Measure how many bytes ``table`` occupies once encoded as Parquet.

    The table is serialized into an in-memory Arrow buffer (nothing is written
    to disk) with the given compression codec, and the buffer size is returned.
    """
    in_memory = pa.BufferOutputStream()
    pq.write_table(table, in_memory, compression=compression)
    encoded = in_memory.getvalue()
    return encoded.size

def write_parquet_chunk(rows, out_path: Path, compression: str = "zstd"):
    """Write a column-name -> list-of-values mapping to ``out_path`` as one Parquet file."""
    arrow_table = pa.table(rows)
    destination = out_path.as_posix()
    pq.write_table(arrow_table, destination, compression=compression)

def export_split_to_parquet(
    ds_split,
    split_name: str,
    out_parquet_dir: Path,
    include_images: bool,
    target_mb: float = 19.0,
    compression: str = "zstd",
    preview_dir: Path | None = None,
    preview_n: int = 5,
    clean_existing: bool = True,
) -> list[Path]:
    """Recolour one split's label masks and pack them into size-capped Parquet parts.

    Each record's label is turned into a class-ID array, recoloured with
    ``ids_to_health_rgb`` and PNG-encoded. Rows with columns ``split``,
    ``index``, ``label_health_rgb_png`` (plus ``image_png`` when
    ``include_images``) are accumulated into a chunk. The chunk is written to
    ``<out_parquet_dir>/<split_name>_partNNN.parquet`` as soon as adding the next
    row would push its serialized size past ``target_mb``. A row that alone
    exceeds the target is written as its own part.

    Args:
        ds_split: indexable split whose records have "image" and "label" PIL
            images (presumably a ``datasets.Dataset`` -- TODO confirm for other callers).
        split_name: value stored in the ``split`` column and used as file prefix.
        out_parquet_dir: directory for the Parquet parts (created if missing).
        include_images: also store the PNG-encoded source image for each row.
        target_mb: soft cap on each part's compressed size, in MiB.
        compression: Parquet codec used both for sizing and for writing.
        preview_dir: if given, the first ``preview_n`` recoloured masks are saved here.
        preview_n: number of preview PNGs to save.
        clean_existing: delete this split's existing ``<split_name>_part*.parquet``
            files before writing, so parts from an earlier run cannot linger.

    Returns:
        Paths of the Parquet parts written, in order.
    """
    out_parquet_dir.mkdir(parents=True, exist_ok=True)
    if preview_dir:
        preview_dir.mkdir(parents=True, exist_ok=True)
    target_bytes = int(target_mb * 1024 * 1024)

    if clean_existing:
        # A previous run that produced more parts would otherwise leave orphaned
        # files whose rows duplicate this run's output when the dir is read.
        for stale_part in out_parquet_dir.glob(f"{split_name}_part*.parquet"):
            stale_part.unlink()

    columns = ["split", "index", "label_health_rgb_png"]
    if include_images:
        columns.append("image_png")
    written: list[Path] = []

    def _new_chunk(row: dict | None = None) -> dict:
        """Start an empty chunk, or one seeded with a single row."""
        return {k: ([] if row is None else [row[k]]) for k in columns}

    def _chunk_bytes(chunk: dict) -> int:
        """Exact size of ``chunk`` once serialized as a Parquet file."""
        return table_size_bytes(pa.table(chunk), compression=compression)

    def _flush(chunk: dict) -> None:
        """Write ``chunk`` as the next numbered part file."""
        out_path = out_parquet_dir / f"{split_name}_part{len(written) + 1:03d}.parquet"
        write_parquet_chunk(chunk, out_path, compression=compression)
        written.append(out_path)

    chunk = _new_chunk()
    for i in tqdm(range(len(ds_split)), desc=f"{split_name}: recolor+pack"):
        rec = ds_split[i]
        health_rgb = ids_to_health_rgb(label_pil_to_id_array(rec["label"]))

        row = {
            "split": split_name,
            "index": i,
            "label_health_rgb_png": pil_to_png_bytes(health_rgb),
        }
        if include_images:
            row["image_png"] = pil_to_png_bytes(rec["image"])

        had_rows = bool(chunk["index"])
        for k in columns:
            chunk[k].append(row[k])

        # NOTE: this re-serializes the whole chunk for every row, so the sizing
        # cost grows quadratically with the number of rows per part.
        size = _chunk_bytes(chunk)
        if size > target_bytes:
            if had_rows:
                # Roll this row back, emit the full chunk, restart with the row.
                for k in columns:
                    chunk[k].pop()
                _flush(chunk)
                chunk = _new_chunk(row)
                size = _chunk_bytes(chunk)
            if size > target_bytes:
                # This row alone exceeds the target: it becomes its own part.
                _flush(chunk)
                chunk = _new_chunk()

        if preview_dir and i < preview_n:
            health_rgb.save(preview_dir / f"{split_name}_{i:05d}_label_health_rgb.png")

    if chunk["index"]:
        _flush(chunk)
    return written

def run_pipeline(
    outdir="coralscapes_export",
    include_images=False,
    target_mb=19.0,
    ds=None,
    splits=("train", "validation", "test"),
    compression="zstd",
    preview_n=5,
):
    """Export the health-recoloured Coralscapes labels of every split to Parquet.

    Parts go under ``<outdir>/parquet/<split>/`` and preview PNGs under
    ``<outdir>/samples/<split>/``. Splits missing from ``ds`` are skipped.

    Args:
        outdir: root output directory.
        include_images: also pack the PNG-encoded source images.
        target_mb: soft cap on each Parquet part's size, in MiB.
        ds: an already-loaded mapping of split name -> split (e.g. the
            DatasetDict from the loading cell). When None, the dataset is loaded
            here, which avoids nothing but is backward-compatible.
        splits: split names to export, in order.
        compression: Parquet compression codec.
        preview_n: number of preview PNGs saved per split.
    """
    out_root = Path(outdir)
    parquet_dir = out_root / "parquet"
    samples_dir = out_root / "samples"
    parquet_dir.mkdir(parents=True, exist_ok=True)
    samples_dir.mkdir(parents=True, exist_ok=True)

    # Reuse a dataset the caller already loaded instead of loading it again.
    if ds is None:
        ds = load_dataset("EPFL-ECEO/coralscapes")
    for split in splits:
        if split not in ds:
            continue
        export_split_to_parquet(
            ds_split=ds[split],
            split_name=split,
            out_parquet_dir=parquet_dir / split,
            include_images=include_images,
            target_mb=target_mb,
            compression=compression,
            preview_dir=samples_dir / split,
            preview_n=preview_n,
        )
    print(f"Done. Parquet parts in: {parquet_dir}")
    print(f"Preview PNGs in: {samples_dir}")

# <<< RUN IT HERE >>>
# Labels only (no source photos), packed into ~19 MiB zstd Parquet parts.
EXPORT_DIR = "coralscapes_export"
run_pipeline(
    outdir=EXPORT_DIR,
    include_images=False,
    target_mb=19.0,
)


train: recolor+pack: 100%|██████████| 1517/1517 [03:38<00:00,  6.95it/s]
validation: recolor+pack: 100%|██████████| 166/166 [00:19<00:00,  8.30it/s]
test: recolor+pack: 100%|██████████| 392/392 [00:47<00:00,  8.28it/s]


Done. Parquet parts in: coralscapes_export\parquet
Preview PNGs in: coralscapes_export\samples
