In [1]:
import tacoreader
import rasterio as rio
from tqdm.auto import tqdm
from pathlib import Path
from multiprocessing.pool import ThreadPool
import time

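This notebook downloads the CloudSEN12 dataset: the 509 px images and, optionally (see `include_2k_images`), the 2k images, together with their labels. Only samples with pixel-level labels are downloaded.
Expect about 30 GB of data and a run time of 3-4 hours.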

In [2]:
# Destination folder for every exported GeoTIFF.
# NOTE: this is a machine-specific absolute path; point it at your own storage before running.
dst_dir = Path("/media/nick/4TB Working 7/Datasets/CloudSEN12")
dst_dir.mkdir(parents=True, exist_ok=True)  # create the folder (and any parents) on first run

In [3]:
# Notebook-wide settings. save_bands/save_label read `bands` and `clip_data_extent`
# as globals; process_dataset reads `num_threads` as a global.
clip_data_extent = True  # crop each raster to its real size (e.g. 512 -> 509 px), dropping the padding
num_threads = 2  # size of the ThreadPool used for the parallel downloads
include_2k_images = False  # also fetch the 2000 px images; only takes effect when passed to process_dataset(include_2k_images=...)
bands = [4, 3, 9]  # 1-based band indexes: Red (B04), Green (B03), NIR (B8A) -- assumes the standard CloudSEN12 band order, TODO confirm

In [4]:
def save_bands(
    dataset,
    id: int,
    output_dir: Path,
    true_shape: int,
    file_name: str,
    processing_level: str,
) -> None:
    """Export the selected image bands of one sample to a GeoTIFF.

    The output is ``<output_dir>/<file_name>_image_<processing_level>.tif``.
    If that file already exists the sample is skipped, which makes re-runs
    resumable.

    Args:
        dataset: Object returned by ``tacoreader.load``; ``dataset.read(id).read(0)``
            is expected to yield something ``rasterio.open`` can read (the image).
        id: Positional index of the sample inside ``dataset``.
        output_dir: Folder the GeoTIFF is written to.
        true_shape: Side length (pixels) of the unpadded image; used for
            cropping when the global ``clip_data_extent`` is true.
        file_name: Prefix of the output file name.
        processing_level: Suffix of the output file name (e.g. ``"l1c"``).

    Notes:
        Reads the notebook globals ``bands`` (band indexes passed to
        ``src.read``) and ``clip_data_extent``.
    """
    output_path = output_dir / f"{file_name}_image_{processing_level}.tif"
    if output_path.exists():
        return

    img_path = dataset.read(id).read(0)
    with rio.open(img_path) as src:
        bands_data = src.read(bands)
        profile = src.profile
        profile.update(count=len(bands))
        if clip_data_extent:
            # Keep the top-left true_shape x true_shape window; the origin is
            # unchanged, so the source geotransform stays valid.
            bands_data = bands_data[:, :true_shape, :true_shape]
            profile.update(width=true_shape, height=true_shape)

    # Write to a sibling temp file and rename only after a successful write.
    # Writing straight to output_path would leave a truncated file behind on
    # failure, and the exists() check above would then skip it forever.
    tmp_path = output_path.with_name(f"{output_path.stem}.part.tif")
    try:
        with rio.open(tmp_path, "w", **profile) as dst:
            dst.write(bands_data)
        tmp_path.replace(output_path)
    finally:
        # After a successful replace() the temp file is gone; this only
        # removes leftovers from a failed write.
        if tmp_path.exists():
            tmp_path.unlink()

In [5]:
def save_label(
    dataset,
    id: int,
    output_dir: Path,
    true_shape: int,
    file_name: str,
) -> None:
    """Export the label raster of one sample to an LZW-compressed GeoTIFF.

    The output is ``<output_dir>/<file_name>_label.tif``. If that file already
    exists the sample is skipped, which makes re-runs resumable.

    Args:
        dataset: Object returned by ``tacoreader.load``; ``dataset.read(id).read(1)``
            is expected to yield something ``rasterio.open`` can read (the label).
        id: Positional index of the sample inside ``dataset``.
        output_dir: Folder the GeoTIFF is written to.
        true_shape: Side length (pixels) of the unpadded raster; used for
            cropping when the global ``clip_data_extent`` is true.
        file_name: Prefix of the output file name.

    Notes:
        Reads the notebook global ``clip_data_extent``.
    """
    output_path = output_dir / f"{file_name}_label.tif"
    if output_path.exists():
        return

    label_path = dataset.read(id).read(1)
    with rio.open(label_path) as lbl_src:
        labels_data = lbl_src.read()

        label_profile = lbl_src.profile
        label_profile["compress"] = "lzw"
        if clip_data_extent:
            # Keep the top-left true_shape x true_shape window; the origin is
            # unchanged, so the source geotransform stays valid.
            labels_data = labels_data[:, :true_shape, :true_shape]
            label_profile.update(width=true_shape, height=true_shape)

    # Write to a sibling temp file and rename only after a successful write.
    # Writing straight to output_path would leave a truncated file behind on
    # failure, and the exists() check above would then skip it forever.
    tmp_path = output_path.with_name(f"{output_path.stem}.part.tif")
    try:
        with rio.open(tmp_path, "w", **label_profile) as lbl_dst:
            lbl_dst.write(labels_data)
        tmp_path.replace(output_path)
    finally:
        # After a successful replace() the temp file is gone; this only
        # removes leftovers from a failed write.
        if tmp_path.exists():
            tmp_path.unlink()

In [None]:
def download(
    output_dir: Path,
    id: int,
    dataset,
    file_name: str,
    true_shape: int,
    processing_level: str,
    label: bool,
    retry_count: int = 0,
):
    """Export one sample (image bands and, optionally, its label) with retries.

    Calls ``save_bands`` and, when ``label`` is true, ``save_label``. Any
    exception triggers a retry after a linearly growing pause (4 s, 8 s, ...).
    After the fifth retry fails, the sample is reported and skipped so that
    one bad sample does not abort the whole run.

    Args:
        output_dir: Folder the GeoTIFFs are written to.
        id: Positional index of the sample inside ``dataset``.
        dataset: Object returned by ``tacoreader.load``.
        file_name: Prefix of the output file names.
        true_shape: Side length (pixels) of the unpadded rasters.
        processing_level: Suffix used in the image file name.
        label: Whether the label raster is exported as well.
        retry_count: Number of retries already consumed (normally left at 0).

    Returns:
        None, both on success and when the sample is skipped.
    """
    max_retries = 5
    backoff_step_seconds = 4

    # Iterative retry loop: unlike a recursive call from inside the ``except``
    # handler, it does not keep every earlier exception and frame alive.
    while True:
        try:
            save_bands(dataset, id, output_dir, true_shape, file_name, processing_level)

            if label:
                save_label(dataset, id, output_dir, true_shape, file_name)
            return

        except Exception as e:
            retry_count += 1

            if retry_count > max_retries:
                # Include the last error so skipped samples can be diagnosed.
                print(
                    f"Failed to process ID {id} after multiple retries. Skipping. "
                    f"Last error: {e!r}"
                )
                return

            sleep_time = retry_count * backoff_step_seconds
            print(f"Retrying {id} in {sleep_time} seconds... {e}")

            time.sleep(sleep_time)

In [7]:
def process_dataset(
    output_dir: Path, processing_level: str, label=True, include_2k_images=False
):
    """Download every high-quality-label sample of one CloudSEN12 processing level.

    Loads ``tacofoundation:cloudsen12-<processing_level>`` with ``tacoreader``,
    selects the rows whose ``label_type`` is ``"high"`` and whose
    ``real_proj_shape`` is an accepted image size, then runs ``download`` for
    each of them on a thread pool of ``num_threads`` workers (notebook global)
    behind a tqdm progress bar.

    Args:
        output_dir: Folder the GeoTIFFs are written to.
        processing_level: Dataset variant to load, e.g. ``"l1c"`` or ``"l2a"``.
        label: Whether label rasters are exported alongside the images.
        include_2k_images: Accept 2000 px samples in addition to 509 px ones.
    """
    dataset = tacoreader.load(f"tacofoundation:cloudsen12-{processing_level}")
    image_sizes = [509, 2000] if include_2k_images else [509]

    jobs = []
    # The enumerate() counter is the positional index later passed to dataset.read().
    for position, (_, row) in enumerate(dataset.iterrows()):
        true_shape = int(row["real_proj_shape"])
        if row["label_type"] != "high" or true_shape not in image_sizes:
            continue

        split = row["tortilla:data_split"]
        file_name = f"{row['tortilla:id']}_{split}_{true_shape}"
        job = (
            output_dir,
            position,
            dataset,
            file_name,
            true_shape,
            processing_level,
            label,
        )
        jobs.append(job)

    def _run(job):
        # imap passes one argument per task, so unpack the tuple here.
        return download(*job)

    with ThreadPool(num_threads) as pool:
        # list() drains the lazy imap iterator so all jobs finish before the pool closes.
        list(tqdm(pool.imap(_run, jobs), total=len(jobs)))

In [10]:
# L1C pass: exports images and labels.
# Forward the config flag explicitly; process_dataset has its own
# include_2k_images parameter (default False) that would otherwise always win.
process_dataset(
    dst_dir,
    processing_level="l1c",
    include_2k_images=include_2k_images,
)

  0%|          | 0/10000 [00:00<?, ?it/s]

In [11]:
# L2A pass: exports images only (label=False skips save_label).
# Forward the config flag explicitly; process_dataset has its own
# include_2k_images parameter (default False) that would otherwise always win.
process_dataset(
    dst_dir,
    processing_level="l2a",
    label=False,
    include_2k_images=include_2k_images,
)

  0%|          | 0/10000 [00:00<?, ?it/s]