This notebook downloads the CloudSEN12 dataset from Hugging Face. It fetches the subsets labeled "high", "scribble", and "2k".

The only setting you need to change is `base_dataset_dir`: point it to a local drive with at least 300 GB of free storage.

In [None]:
import os
import time
from multiprocessing.pool import ThreadPool
from pathlib import Path
from threading import Thread
from typing import Optional

import numpy as np
import rasterio as rio
import tacoreader
from tqdm.auto import tqdm

In [None]:
# Root folder under which every CloudSEN12 subset folder is created; it needs
# roughly 300 GB of free space. Set the CLOUDSEN12_BASE_DIR environment
# variable to use another location without editing the notebook.
base_dataset_dir = Path(
    os.environ.get(
        "CLOUDSEN12_BASE_DIR", "/media/nick/4TB Working 7/Datasets/OCM datasets"
    )
).expanduser()

In [None]:
# One output folder per subset. The 509-pixel "high" samples are further split
# so that their validation and test samples get folders of their own.
high_dir = base_dataset_dir / "CloudSEN12 high"
scribble_dir = base_dataset_dir / "CloudSEN12 scribble"
two_k_dir = base_dataset_dir / "CloudSEN12 2k"
validation_dir = base_dataset_dir / "CloudSEN12 validation"
test_dir = base_dataset_dir / "CloudSEN12 test"

# Create every folder (and any missing parents) up front; reruns are harmless.
for _folder in (high_dir, scribble_dir, two_k_dir, validation_dir, test_dir):
    _folder.mkdir(parents=True, exist_ok=True)
del _folder

In [None]:
# Crop each raster to its true extent (true_shape x true_shape), dropping the
# padding around the valid data, e.g. 512 -> 509 pixels.
clip_data_extent = True

# Worker threads for the download pool (1 = one sample at a time).
num_threads = 1

# Band indexes passed to rasterio's `read`, in output order.
bands = [4, 3, 9]  # Red (B04), Green (B03), NIR (B08A)

# CloudSEN12 subsets to download.
image_types = ["high", "scribble", "2k"]

In [None]:
def remap_scribble_label(label):
    """Merge the scribble label classes into a coarser class scheme.

    Mapping: 0 -> 0, 1/2 -> 1, 3/4 -> 2, 5/6 -> 3, 99 -> 99. Any value not
    listed becomes 0. The result has the same shape as ``label`` and dtype
    ``uint8``.
    """
    # Consecutive class pairs collapse into one output class; 99 is kept as-is.
    # NOTE(review): presumably 99 marks unlabeled pixels — confirm against the
    # CloudSEN12 scribble label definition.
    class_map = {0: 0, 1: 1, 2: 1, 3: 2, 4: 2, 5: 3, 6: 3, 99: 99}
    remapped = np.zeros_like(label, dtype=np.uint8)
    for source_class, target_class in class_map.items():
        remapped[label == source_class] = target_class
    return remapped

In [None]:
def export(output_path: Path, bands_data: np.ndarray, profile: dict):
    """Write ``bands_data`` to ``output_path`` as a raster described by ``profile``.

    The raster is first written to a sibling ``<name>.partial`` file and only
    renamed onto ``output_path`` once the write has completed. If the write
    fails, the partial file is removed and the error is re-raised. This stops a
    failed or interrupted write from leaving a truncated raster at
    ``output_path``; ``save_product`` skips any output that already exists, so
    such a file would otherwise never be downloaded again.

    Args:
        output_path: Final destination of the raster.
        bands_data: Pixel array handed to ``DatasetWriter.write``; assumed to
            be ``(bands, rows, cols)`` matching the profile's count/height/width.
        profile: rasterio creation options (driver, dtype, size, CRS, ...).
    """
    partial_path = output_path.with_name(output_path.name + ".partial")
    # The temporary suffix hides the real extension, so make sure a driver is
    # set; a driver already present in ``profile`` takes precedence.
    write_profile = {"driver": "GTiff", **profile}
    try:
        with rio.open(partial_path, "w", **write_profile) as dst:
            dst.write(bands_data)
    except BaseException:
        partial_path.unlink(missing_ok=True)
        raise
    # Rename within the same directory so the final file appears all at once.
    partial_path.replace(output_path)

In [None]:
def save_product(
    dataset,
    id: int,
    output_dir: Path,
    true_shape: int,
    file_name: str,
    bands: list[int],
    processing_level: Optional[str] = None,
    type: str = "image",
) -> None:
    """Read one image or label raster of a sample and write it to ``output_dir``.

    Args:
        dataset: Loaded tacoreader dataset. ``dataset.read(id).read(n)`` yields
            something ``rio.open`` accepts; item 0 is the image, item 1 the label.
        id: Positional index of the sample in ``dataset``.
        output_dir: Folder the ``.tif`` is written to.
        true_shape: Side length the raster is cropped to when
            ``clip_data_extent`` is enabled.
        file_name: Base name of the output file (without suffix).
        bands: Band indexes to read from the source raster.
        processing_level: Added to image file names; unused for labels.
        type: ``"image"`` or ``"label"``.

    Raises:
        ValueError: If ``type`` is neither ``"image"`` nor ``"label"``.

    An output that already exists is left untouched, so reruns resume where
    they stopped. The actual write runs on a background thread that is started
    but never joined here.
    """
    if type == "image":
        suffix = f"_{type}_{processing_level}"
        item_index = 0
    elif type == "label":
        suffix = f"_{type}"
        item_index = 1
    else:
        raise ValueError(f"Unknown type: {type}, expected 'image' or 'label'.")
    output_path = output_dir / f"{file_name}{suffix}.tif"

    if output_path.exists():
        return  # written by an earlier run

    source = dataset.read(id).read(item_index)
    with rio.open(source) as src:
        pixels = src.read(bands)
        profile = src.profile
        profile.update(count=len(bands), compress="lzw")
        if clip_data_extent:
            # Keep the top-left true_shape x true_shape window; the origin does
            # not move, so only the size in the profile needs updating.
            pixels = pixels[:, :true_shape, :true_shape]
            profile.update(width=true_shape, height=true_shape)

    needs_remap = type == "label" and "scribble" in file_name
    if needs_remap:
        pixels = remap_scribble_label(pixels)

    Thread(target=export, args=(output_path, pixels, profile)).start()

In [None]:
def download(
    output_dir: Path,
    id: int,
    dataset,
    file_name: str,
    true_shape: int,
    processing_level: str,
    label: bool,
    retry_count: int = 0,
):
    """Save the image (and optionally the label) of one sample, retrying on errors.

    Any exception raised while saving is retried after a linearly growing pause
    (4 s, 8 s, ... per retry). After more than 5 retries the sample is skipped
    with a printed message instead of raising.

    Args:
        output_dir: Folder the sample's files are written to.
        id: Positional index of the sample in ``dataset``.
        dataset: Loaded tacoreader dataset.
        file_name: Base name for the output files.
        true_shape: Side length used when cropping the rasters.
        processing_level: Processing level added to the image file name.
        label: Whether to also save the label raster.
        retry_count: Retries already consumed; callers normally leave it at 0.

    NOTE(review): errors raised inside the background write thread started by
    ``save_product`` are not seen by this retry loop.
    """
    retries_used = retry_count
    while True:
        try:
            save_product(
                dataset=dataset,
                id=id,
                output_dir=output_dir,
                true_shape=true_shape,
                file_name=file_name,
                processing_level=processing_level,
                bands=bands,
                type="image",
            )
            if label:
                # Labels are single-band; processing_level is left at its default.
                save_product(
                    dataset=dataset,
                    id=id,
                    output_dir=output_dir,
                    true_shape=true_shape,
                    file_name=file_name,
                    bands=[1],
                    type="label",
                )
            return
        except Exception as e:
            retries_used += 1
            if retries_used > 5:
                print(f"Failed to process ID {id} after multiple retries. Skipping.")
                return
            sleep_time = retries_used * 4
            print(f"Retrying {id} in {sleep_time} seconds... {e}")
            time.sleep(sleep_time)

In [None]:
# Maps each user-facing image type to the exact (label_type, true raster size)
# pair it selects in the CloudSEN12 metadata. "2k" scenes carry label_type
# "high" but have a 2000-pixel extent.
_IMAGE_TYPE_SELECTORS = {
    "high": ("high", 509),
    "scribble": ("scribble", 509),
    "2k": ("high", 2000),
}


def _output_dir_for(label_type: str, split: str, true_shape: int) -> Path:
    """Return the folder a selected sample is written to.

    509-pixel "high" samples are routed by data split (train/validation/test).
    2000-pixel samples go to the 2k folder, and scribble samples go to the
    scribble folder, regardless of split.

    Raises:
        ValueError: If the combination matches none of the routes above.
    """
    if label_type == "high" and true_shape == 509:
        if split == "validation":
            return validation_dir
        if split == "test":
            return test_dir
        if split == "train":
            return high_dir
    elif true_shape == 2000:
        return two_k_dir
    elif label_type == "scribble":
        return scribble_dir
    raise ValueError(
        f"Unknown label type: {label_type} (split={split!r}, true_shape={true_shape})"
    )


def process_dataset(
    processing_level: str,
    image_types: list[str],
    label: bool = True,
):
    """Download the requested CloudSEN12 subsets for one processing level.

    Args:
        processing_level: Suffix of the ``tacofoundation:cloudsen12-*`` dataset
            to load, e.g. ``"l1c"`` or ``"l2a"``.
        image_types: Any of ``"high"``, ``"scribble"`` and ``"2k"``; unknown
            entries are ignored.
        label: Also write the label raster of every sample.

    A sample is downloaded only if its exact (label_type, true size) pair
    belongs to one of the requested image types. For example,
    ``["scribble", "2k"]`` does not also pull the 509-pixel "high" samples.
    """
    dataset = tacoreader.load(f"tacofoundation:cloudsen12-{processing_level}")
    # Test the pair, not label types and sizes separately, so that the subsets
    # cannot cross-match each other.
    wanted = {
        _IMAGE_TYPE_SELECTORS[image_type]
        for image_type in image_types
        if image_type in _IMAGE_TYPE_SELECTORS
    }

    jobs = []
    # The enumerate position is the index later passed to dataset.read().
    for sample_index, (_, row) in enumerate(dataset.iterrows()):
        label_type = row["label_type"]
        true_shape = int(row["real_proj_shape"])
        if (label_type, true_shape) not in wanted:
            continue

        split = row["tortilla:data_split"]
        out_dir = _output_dir_for(label_type, split, true_shape)
        file_name = f"CloudSEN12_{row['tortilla:id']}_{split}_{true_shape}_{label_type}"  # noqa: E501
        jobs.append(
            (
                out_dir,
                sample_index,
                dataset,
                file_name,
                true_shape,
                processing_level,
                label,
            )
        )

    with ThreadPool(num_threads) as pool:
        list(tqdm(pool.imap(lambda job: download(*job), jobs), total=len(jobs)))

In [None]:
# L1C pass: images plus labels.
process_dataset(image_types=image_types, processing_level="l1c", label=True)
# L2A pass: images only; the labels are already written by the L1C pass.
process_dataset(image_types=image_types, processing_level="l2a", label=False)