In [1]:
import webdataset as wds
from presto.dataops.dataset import (
    TAR_BUCKET,
    Dataset,
    S1_S2_ERA5_SRTM_DynamicWorldMonthly_2020_2021,
)
import torch
from datasets import load_dataset
from typing import List, Tuple, Optional, Dict, Any, Iterable
from collections import OrderedDict
from typing import OrderedDict as OrderedDictType
import numpy as np
from presto.dataops import MASK_STRATEGIES, plot_masked
from presto.utils import (
    DEFAULT_SEED,
    config_dir,
    device,
    initialize_logging,
    seed_everything,
    timestamp_dirname,
    update_data_dir,
)
import random
from random import choice, randint, sample
from presto.dataops.pipelines.dynamicworld import DynamicWorld2020_2021
from presto.dataops.pipelines.s1_s2_era5_srtm import (
    NORMED_BANDS,
)
NUM_TIMESTEPS = 60
TIMESTEPS_IDX = list(range(NUM_TIMESTEPS))
from collections import namedtuple
from dataclasses import dataclass
from presto.dataops.utils import construct_single_presto_input
from torch.utils.data import Dataset

# Names of the masking strategies understood by make_mask.
MASK_STRATEGIES = (
    "group_bands",
    "random_timesteps",
    "chunk_timesteps",
    "random_combinations",
)
NUM_TIMESTEPS = 60

# Band names per sensor / product, as they appear in NORMED_BANDS.
S1_BANDS = ["VV", "VH"]
ERA5_BANDS = ["temperature_2m", "total_precipitation"]
SRTM_BANDS = ["elevation", "slope"]

# Ordered mapping: band-group name -> positions of that group's bands in
# NORMED_BANDS. The insertion order defines the column order of group masks.
BANDS_GROUPS_IDX: OrderedDictType[str, List[int]] = OrderedDict(
    (group_name, [NORMED_BANDS.index(band) for band in band_names])
    for group_name, band_names in (
        ("S1", S1_BANDS),
        ("S2_RGB", ["B2", "B3", "B4"]),
        ("S2_Red_Edge", ["B5", "B6", "B7"]),
        ("S2_NIR_10m", ["B8"]),
        ("S2_NIR_20m", ["B8A"]),
        ("S2_SWIR", ["B11", "B12"]),  # Include B10?
        ("ERA5", ERA5_BANDS),
        ("SRTM", SRTM_BANDS),
        ("NDVI", ["NDVI"]),
    )
)

# Number of bands in each group (used to expand a per-group mask to per-band).
BAND_EXPANSION = list(map(len, BANDS_GROUPS_IDX.values()))
# Column position of the SRTM group within BANDS_GROUPS_IDX.
SRTM_INDEX = list(BANDS_GROUPS_IDX).index("SRTM")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def make_mask(strategy: str, mask_ratio: float) -> Tuple[np.ndarray, np.ndarray]:
    """
    Make a mask for a given strategy and fraction of masked values.

    Masking is decided per token, where a token is one (timestep, band group)
    cell, one Dynamic World timestep, or the single SRTM token. The token
    budget is int((NUM_TIMESTEPS * len(BANDS_GROUPS_IDX) + 1) * mask_ratio).

    Args:
        strategy: The masking strategy to use. One of MASK_STRATEGIES
        mask_ratio: The fraction of values to mask. Between 0 and 1.
    Returns:
        A tuple (mask, dw_mask) of boolean arrays where True means "masked".
        mask has NUM_TIMESTEPS rows and one column per band (the per-group
        mask repeated BAND_EXPANSION times along axis 1); dw_mask has
        NUM_TIMESTEPS entries.
    Raises:
        ValueError: if strategy is not one of the handled strategies.
    """

    # SRTM is included here, but ignored by Presto
    mask = np.full((NUM_TIMESTEPS, len(BANDS_GROUPS_IDX)), False)
    dw_mask = np.full(NUM_TIMESTEPS, False)
    srtm_mask = False
    # Total number of tokens we are allowed to mask (the "budget").
    num_tokens_to_mask = int(((NUM_TIMESTEPS * len(BANDS_GROUPS_IDX)) + 1) * mask_ratio)

    def mask_topography(srtm_mask, num_tokens_to_mask, mask_ratio):
        # With probability mask_ratio, mask the SRTM token and spend one token
        # of the budget on it.
        should_flip = random.random() < mask_ratio
        if should_flip:
            srtm_mask = True
            num_tokens_to_mask -= 1
        return srtm_mask, num_tokens_to_mask

    def random_masking(mask, dw_mask, num_tokens_to_mask: int):
        # Mask num_tokens_to_mask additional, currently unmasked tokens chosen
        # uniformly at random. Note: writes the SRTM column of `mask` in place.
        if num_tokens_to_mask > 0:
            # we set SRTM to be True - this way, it won't get randomly assigned.
            # at the end of the function, it gets properly assigned
            mask[:, SRTM_INDEX] = True
            # then, we flatten the mask and dw arrays
            all_tokens_mask = np.concatenate([dw_mask, mask.flatten()])
            unmasked_tokens = all_tokens_mask == False
            idx = np.flatnonzero(unmasked_tokens)
            np.random.shuffle(idx)
            idx = idx[:num_tokens_to_mask]
            all_tokens_mask[idx] = True
            # the first NUM_TIMESTEPS entries are the dw tokens, the rest is the
            # flattened band-group mask
            mask = all_tokens_mask[NUM_TIMESTEPS:].reshape((NUM_TIMESTEPS, len(BANDS_GROUPS_IDX)))
            dw_mask = all_tokens_mask[:NUM_TIMESTEPS]
        return mask, dw_mask

    # RANDOM BANDS
    if strategy == "random_combinations":
        srtm_mask, num_tokens_to_mask = mask_topography(srtm_mask, num_tokens_to_mask, mask_ratio)
        mask, dw_mask = random_masking(mask, dw_mask, num_tokens_to_mask)

    elif strategy == "group_bands":
        srtm_mask, num_tokens_to_mask = mask_topography(srtm_mask, num_tokens_to_mask, mask_ratio)
        # next, we figure out how many tokens we can mask
        # (whole band groups first; the remainder is masked randomly below)
        num_band_groups_to_mask = int(num_tokens_to_mask / NUM_TIMESTEPS)
        num_tokens_to_mask -= NUM_TIMESTEPS * num_band_groups_to_mask
        assert num_tokens_to_mask >= 0
        # List[Any] because the list mixes int column indices with the "DW"
        # marker string, which mypy would otherwise reject
        band_groups: List[Any] = list(range(len(BANDS_GROUPS_IDX))) + ["DW"]
        # SRTM is handled by mask_topography, so it is never sampled here
        band_groups.remove(SRTM_INDEX)
        band_groups_to_mask = sample(band_groups, num_band_groups_to_mask)
        for band_group in band_groups_to_mask:
            if band_group == "DW":
                dw_mask[:] = True
            else:
                mask[:, band_group] = True
        mask, dw_mask = random_masking(mask, dw_mask, num_tokens_to_mask)

    # RANDOM TIMESTEPS
    elif strategy == "random_timesteps":
        srtm_mask, num_tokens_to_mask = mask_topography(srtm_mask, num_tokens_to_mask, mask_ratio)
        # tokens per timestep: +1 for dynamic world, -1 for the SRTM
        timesteps_to_mask = int(num_tokens_to_mask / (len(BANDS_GROUPS_IDX)))
        num_tokens_to_mask -= (len(BANDS_GROUPS_IDX)) * timesteps_to_mask
        timesteps = sample(TIMESTEPS_IDX, k=timesteps_to_mask)
        mask[timesteps] = True
        dw_mask[timesteps] = True
        mask, dw_mask = random_masking(mask, dw_mask, num_tokens_to_mask)
    # CHUNK TIMESTEPS: like random_timesteps, but the fully masked timesteps
    # form one contiguous window with a random start
    elif strategy == "chunk_timesteps":
        srtm_mask, num_tokens_to_mask = mask_topography(srtm_mask, num_tokens_to_mask, mask_ratio)
        timesteps_to_mask = int(num_tokens_to_mask / (len(BANDS_GROUPS_IDX)))
        num_tokens_to_mask -= (len(BANDS_GROUPS_IDX)) * timesteps_to_mask
        start_idx = randint(0, NUM_TIMESTEPS - timesteps_to_mask)
        mask[start_idx : start_idx + timesteps_to_mask] = True  # noqa
        dw_mask[start_idx : start_idx + timesteps_to_mask] = True  # noqa
        mask, dw_mask = random_masking(mask, dw_mask, num_tokens_to_mask)
    else:
        raise ValueError(f"Unknown strategy {strategy} not in {MASK_STRATEGIES}")

    # overwrite the SRTM column (possibly set to True as a placeholder by
    # random_masking or by whole-timestep masking) with the real decision
    mask[:, SRTM_INDEX] = srtm_mask
    # expand from one column per band group to one column per band
    return np.repeat(mask, BAND_EXPANSION, axis=1), dw_mask

In [3]:
@dataclass
class MaskParams:
    """Configuration for masking: which strategies to draw from and how much to mask.

    Attributes:
        strategies: Masking strategy names; one is picked uniformly at random
            on every call to mask_data. Each must be a supported strategy.
        ratio: Fraction of tokens to mask, forwarded to make_mask as mask_ratio.
    """

    # Not annotated on purpose, so the dataclass machinery does not treat it as
    # a field (it stays out of __init__ and of the instance __dict__).
    _SUPPORTED_STRATEGIES = (
        "group_bands",
        "random_timesteps",
        "chunk_timesteps",
        "random_combinations",
    )

    # The previous default ("NDVI",) was rejected by __post_init__, so a bare
    # MaskParams() could never be constructed. Default to all valid strategies.
    strategies: Tuple[str, ...] = _SUPPORTED_STRATEGIES
    ratio: float = 0.5

    def __post_init__(self):
        """Check that every configured strategy is supported.

        Raises:
            AssertionError: if any entry of strategies is not supported.
        """
        for strategy in self.strategies:
            assert strategy in self._SUPPORTED_STRATEGIES, (
                f"Unknown strategy {strategy!r}; expected one of {self._SUPPORTED_STRATEGIES}"
            )

    def mask_data(self, eo_data: np.ndarray, dw_data: np.ndarray):
        """Mask one sample with a randomly chosen strategy.

        Args:
            eo_data: Earth-observation values. Assumed to have the same shape
                as the first array returned by make_mask - TODO confirm.
            dw_data: Dynamic World classes. Assumed to have one entry per
                timestep, matching the second array from make_mask - TODO confirm.

        Returns:
            A tuple (mask, dw_mask, x, y, x_dw, y_dw, strategy):
            mask / dw_mask are the boolean masks from make_mask (True = masked);
            x is eo_data with masked positions multiplied to zero;
            y is float32, holding the original values at masked positions and
            zero elsewhere;
            x_dw is dw_data with masked positions replaced by
            DynamicWorld2020_2021.missing_data_class;
            y_dw is int16, holding the original classes at masked positions and
            zero elsewhere;
            strategy is the name of the strategy that was drawn.
        """
        strategy = choice(self.strategies)
        mask, dw_mask = make_mask(strategy=strategy, mask_ratio=self.ratio)

        # Model input: zero out everything that is masked.
        x = eo_data * ~mask
        # Reconstruction target: only the masked values are kept.
        y = np.zeros(eo_data.shape, dtype=np.float32)
        y[mask] = eo_data[mask]

        # Masked Dynamic World tokens are replaced by the "missing" class.
        masked_dw_tokens = np.ones_like(dw_data) * DynamicWorld2020_2021.missing_data_class
        x_dw = np.where(dw_mask, masked_dw_tokens, dw_data)
        y_dw = np.zeros(x_dw.shape, dtype=np.int16)
        y_dw[dw_mask] = dw_data[dw_mask]

        return mask, dw_mask, x, y, x_dw, y_dw, strategy

In [None]:
import os
import json
from typing import Optional
from datasets import load_dataset, Dataset
import torch
from torch.utils.data import Dataset
import numpy as np
from your_module import MaskParams, construct_single_presto_input  # Replace with actual imports

class FranceCropsFullDataset(Dataset):
    """Torch dataset wrapping a Hugging Face dataset, optionally preprocessed and cached.

    With preprocess=True each source row is expanded into one row per element
    of its 'x' column, optionally shuffled, and converted (with masking) via
    construct_single_presto_input and MaskParams.mask_data. If cache_dir is
    given, the preprocessed dataset and the parameters that produced it are
    stored there and reused on later constructions.
    """

    def __init__(
        self,
        dataset: str,
        split: str,
        mask_params: "MaskParams",
        shuffle: bool = True,
        seed: int = 42,
        preprocess: bool = True,
        cache_dir: Optional[str] = None,
    ):
        """
        Args:
            dataset: Dataset identifier passed to datasets.load_dataset.
            split: Split name passed to datasets.load_dataset.
            mask_params: Masking configuration applied during preprocessing.
            shuffle: Whether to shuffle the expanded rows before conversion.
            seed: Seed used for that shuffle.
            preprocess: If False, the raw dataset is loaded as-is and
                cache_dir is ignored.
            cache_dir: Directory holding 'dataset/' and 'metadata.json'. If it
                exists it is loaded from, otherwise it is created and written.

        Raises:
            ValueError: if cache_dir exists but has no metadata.json, or its
                metadata does not match the current parameters.
        """
        super().__init__()
        self.mask_params = mask_params
        self.split = split
        self.shuffle = shuffle
        self.seed = seed
        self.cache_dir = cache_dir

        if not preprocess:
            # Load without preprocessing
            self.base_dataset = self._load_dataset(dataset)
        elif cache_dir is not None and os.path.exists(cache_dir):
            expected = self._get_metadata(dataset, split, mask_params, shuffle, seed)
            self.base_dataset = self._load_from_cache(cache_dir, expected)
        else:
            # _preprocess reads self.base_dataset, so assign the raw data first.
            self.base_dataset = self._load_dataset(dataset)
            self.base_dataset = self._preprocess()
            if cache_dir is not None:
                metadata = self._get_metadata(dataset, split, mask_params, shuffle, seed)
                self._save_to_cache(cache_dir, metadata)

        # Ensure the dataset is in PyTorch format
        self.base_dataset.set_format(type='torch')

    def _load_from_cache(self, cache_dir: str, expected_metadata: dict):
        """Validate the cache's metadata and load the preprocessed dataset from it."""
        metadata_path = os.path.join(cache_dir, 'metadata.json')
        if not os.path.exists(metadata_path):
            raise ValueError(f"Metadata not found in {cache_dir}")
        with open(metadata_path, 'r') as f:
            saved_metadata = json.load(f)
        if saved_metadata != expected_metadata:
            raise ValueError("Cache parameters do not match current parameters.")
        # The cache is written with save_to_disk, which must be read back with
        # load_from_disk (load_dataset does not read that on-disk layout).
        # Imported here because the notebook's import cell does not provide it.
        from datasets import load_from_disk

        return load_from_disk(os.path.join(cache_dir, 'dataset'))

    def _save_to_cache(self, cache_dir: str, metadata: dict) -> None:
        """Write the preprocessed dataset and its parameter metadata to cache_dir."""
        os.makedirs(cache_dir, exist_ok=True)
        self.base_dataset.save_to_disk(os.path.join(cache_dir, 'dataset'))
        with open(os.path.join(cache_dir, 'metadata.json'), 'w') as f:
            json.dump(metadata, f, indent=4)

    def _get_metadata(self, dataset: str, split: str, mask_params: "MaskParams", shuffle: bool, seed: int) -> dict:
        """Generate metadata dictionary for parameter compatibility checks.

        Tuples in mask_params are stored as lists: JSON has no tuple type, so a
        tuple would come back from metadata.json as a list and the equality
        check against freshly generated metadata would never succeed.
        """
        mask_metadata = {
            key: list(value) if isinstance(value, tuple) else value
            for key, value in vars(mask_params).items()
        }
        return {
            'dataset': dataset,
            'split': split,
            'mask_params': mask_metadata,
            'shuffle': shuffle,
            'seed': seed,
        }

    def _load_dataset(self, dataset: str) -> "datasets.Dataset":
        """Load the base dataset via datasets.load_dataset for self.split."""
        return load_dataset(dataset, split=self.split)

    def _expand_function(self, examples):
        """Expand a batch so every element of each 'x' entry becomes its own row.

        The 'y' label of a source row is repeated once per element of its 'x'.
        """
        new_x = [element for x_array in examples['x'] for element in x_array]
        new_y = [
            y_label
            for y_label, x_array in zip(examples['y'], examples['x'])
            for _ in range(len(x_array))
        ]
        return {"x": new_x, "y": new_y}

    def _convert_to_presto(self, examples):
        """Convert one example to Presto input format and apply masking.

        Called through a non-batched map, so `examples` is a single row.
        """
        x_tensor = torch.tensor(examples['x'], dtype=torch.float32)
        bands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B11", "B12"]
        presto_input, mask, dynamic = construct_single_presto_input(s2=x_tensor, s2_bands=bands)
        # Placeholders: every sample gets a zero lat/lon and start month 0.
        latlons = np.zeros(2, dtype=np.float32)
        start_month = 0

        # The mask returned by construct_single_presto_input is replaced here
        # by the one generated by the masking strategy.
        mask, mask_dw, x, y, x_dw, y_dw, strat = self.mask_params.mask_data(presto_input, dynamic)

        return {
            "x": x, "y": y, "mask": mask, "start_month": start_month,
            "latlons": latlons, "mask_dw": mask_dw, "x_dw": x_dw,
            "y_dw": y_dw, "strategy": strat
        }

    def _preprocess(self) -> "datasets.Dataset":
        """Apply preprocessing steps: expansion, optional shuffle, conversion."""
        # Expand the dataset
        expanded_dataset = self.base_dataset.map(
            self._expand_function,
            batched=True,
            remove_columns=["x", "y"],
        )
        # Shuffle if required
        if self.shuffle:
            expanded_dataset = expanded_dataset.shuffle(seed=self.seed)
        # Convert to Presto format
        return expanded_dataset.map(self._convert_to_presto)

    def __len__(self) -> int:
        return len(self.base_dataset)

    def __getitem__(self, idx) -> dict:
        return self.base_dataset[idx]

In [5]:
# Build the training split with every masking strategy enabled and 75% of
# tokens masked. No cache_dir is passed, so preprocessing runs in full.
mask_params = MaskParams(strategies=MASK_STRATEGIES, ratio=0.75)
france_crops_dataset = FranceCropsFullDataset(
    "saget-antoine/francecrops",
    "train",
    mask_params,
    shuffle=True,
    seed=42,
)

Map: 100%|█████████████████████████████████████████████████████████████████| 20000/20000 [04:43<00:00, 70.51 examples/s]
Map: 100%|████████████████████████████████████████████████████████████| 2000000/2000000 [41:59<00:00, 793.69 examples/s]


In [9]:
import pickle

# Persist the whole dataset wrapper object so it can be reloaded later
# without repeating the preprocessing above.
with open("france_crops_full_dataset.pkl", mode="wb") as f:
    pickle.dump(france_crops_dataset, f)


In [10]:
# Batched, shuffled loader over the preprocessed dataset.
train_dataloader = torch.utils.data.DataLoader(
    france_crops_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

In [12]:
# Pull a single batch from the loader and unpack it; the loop exits after the
# first iteration.
for epoch_step, b in enumerate(train_dataloader):
    mask = b["mask"].to(device)
    x = b["x"].to(device)
    y = b["y"].to(device)
    start_month = b["start_month"]  # left on its original device
    dw_mask = b["mask_dw"].to(device)
    x_dw = b["x_dw"].to(device).long()
    y_dw = b["y_dw"].to(device).long()
    latlons = b["latlons"].to(device)
    break