In [53]:

import json
import os
from datetime import datetime

import geopandas as gpd
import numpy as np
import pandas as pd
import rasterio
import torch
from einops import rearrange
import sys
sys.path.append(os.path.abspath("/Users/nisla/codigos/pangaea-bench/"))

from pangaea.datasets.base import RawGeoFMDataset

In [174]:
###
# Modified version of the PASTIS-HD dataset
# original code https://github.com/gastruc/OmniSat/blob/main/src/data/Pastis.py
###

import json
import os
from datetime import datetime

import geopandas as gpd
import numpy as np
import pandas as pd
import rasterio
import torch
from einops import rearrange
import os
import gdown
import zipfile

from pangaea.datasets.base import RawGeoFMDataset


def prepare_dates(date_dict, reference_date):
    """Convert a mapping of YYYYMMDD dates into day offsets from a reference date.

    Args:
        date_dict (dict | str): mapping whose values are dates written as
            ``YYYYMMDD`` (int or str), or a JSON string encoding such a
            mapping. Keys are ignored; the order of the values is preserved.
        reference_date (datetime): date that corresponds to offset 0.

    Returns:
        torch.Tensor: 1-D int64 tensor holding, for every date, the number of
        days between that date and ``reference_date`` (negative when the date
        is earlier). An empty mapping yields an empty tensor.
    """
    if isinstance(date_dict, str):
        date_dict = json.loads(date_dict)

    offsets = []
    for raw_date in date_dict.values():
        stamp = str(raw_date)
        # Fixed-width slicing: 4 digits of year, 2 of month, the rest is the day.
        acquisition = datetime(int(stamp[:4]), int(stamp[4:6]), int(stamp[6:]))
        offsets.append((acquisition - reference_date).days)

    # Explicit dtype so that an empty list does not default to float32.
    return torch.tensor(offsets, dtype=torch.int64)





class Dummy(RawGeoFMDataset):
    """Sentinel-2-only dataset stored in the PASTIS-HD on-disk layout.

    The files read from ``root_path`` are ``metadata.geojson``, the per-patch
    targets ``ANNOTATIONS/TARGET_<id>.npy`` and the per-patch image arrays
    ``DATA_S2/S2_<id>.npy``.
    """

    def __init__(
        self,
        split: str,
        dataset_name: str,
        multi_modal: bool,
        multi_temporal: int,
        root_path: str,
        classes: list,
        num_classes: int,
        ignore_index: int,
        img_size: int,
        bands: dict[str, list[str]],
        distribution: list[int],
        data_mean: dict[str, list[float]],
        data_std: dict[str, list[float]],
        data_min: dict[str, list[float]],
        data_max: dict[str, list[float]],
        download_url: str,
        auto_download: bool,
    ):
        """Initialize the dataset and index the acquisition dates of every patch.

        Args:
            split (str): split of the dataset (train, val, test).
            dataset_name (str): dataset name.
            multi_modal (bool): if the dataset is multi-modal.
            multi_temporal (int): number of temporal frames.
            root_path (str): root path of the dataset.
            classes (list): classes of the dataset.
            num_classes (int): number of classes.
            ignore_index (int): index to ignore for metrics and loss.
            img_size (int): size of the image.
            bands (dict[str, list[str]]): bands of the dataset.
            distribution (list[int]): class distribution.
            data_mean (dict[str, list[float]]): mean for each band for each modality.
            Dictionary with keys as the modality and values as the list of means.
            e.g. {"s2": [b1_mean, ..., bn_mean], "s1": [b1_mean, ..., bn_mean]}
            data_std (dict[str, list[float]]): std for each band for each modality.
            Dictionary with keys as the modality and values as the list of stds.
            e.g. {"s2": [b1_std, ..., bn_std], "s1": [b1_std, ..., bn_std]}
            data_min (dict[str, list[float]]): min for each band for each modality.
            Dictionary with keys as the modality and values as the list of mins.
            e.g. {"s2": [b1_min, ..., bn_min], "s1": [b1_min, ..., bn_min]}
            data_max (dict[str, list[float]]): max for each band for each modality.
            Dictionary with keys as the modality and values as the list of maxs.
            e.g. {"s2": [b1_max, ..., bn_max], "s1": [b1_max, ..., bn_max]}
            download_url (str): url to download the dataset.
            auto_download (bool): whether to download the dataset automatically.

        Raises:
            AssertionError: if ``split`` is not one of train, val or test.
        """
        super(Dummy, self).__init__(
            split=split,
            dataset_name=dataset_name,
            multi_modal=multi_modal,
            multi_temporal=multi_temporal,
            root_path=root_path,
            classes=classes,
            num_classes=num_classes,
            ignore_index=ignore_index,
            img_size=img_size,
            bands=bands,
            distribution=distribution,
            data_mean=data_mean,
            data_std=data_std,
            data_min=data_min,
            data_max=data_max,
            download_url=download_url,
            auto_download=auto_download,
        )

        assert split in ["train", "val", "test"], "Split must be train, val or test"
        # Folds 1-3 are used for training, fold 4 for validation, fold 5 for test.
        if split == "train":
            folds = [1, 2, 3]
        elif split == "val":
            folds = [4]
        else:
            folds = [5]
        self.modalities = ["S2"]

        # Day 0 of every date offset produced by this dataset.
        self.reference_date = datetime(2017, 6, 1)

        # Honour the constructor argument instead of a hard-coded class count.
        self.num_classes = num_classes

        print("Reading patch metadata . . .")
        meta_patch = gpd.read_file(os.path.join(root_path, "metadata.geojson"))
        # Keep only the patches of the requested folds, indexed and ordered by id.
        meta_patch = meta_patch[meta_patch["Fold"].isin(folds)].copy()
        meta_patch.index = meta_patch["id"].astype(int)
        self.meta_patch = meta_patch.sort_index()
        self.memory_dates = {}

        self.len = self.meta_patch.shape[0]
        self.id_patches = self.meta_patch.index

        # Day offsets (relative to reference_date) that a date table can encode.
        self.date_range = np.array(range(-200, 600))
        self.date_tables = {s: None for s in self.modalities}

        for s in self.modalities:
            # One 0/1 vector per patch, aligned with date_range: 1 marks a day
            # on which the patch has an acquisition. Dates falling outside
            # date_range have no slot and are therefore not recorded.
            self.date_tables[s] = {
                pid: np.isin(
                    self.date_range,
                    prepare_dates(date_seq, self.reference_date).numpy(),
                ).astype(np.int64)
                for pid, date_seq in self.meta_patch["dates-{}".format(s)].items()
            }

        print("Done.")

    def __len__(self):
        """Return the number of patches in the selected split."""
        return self.len

    def get_dates(self, id_patch, sat):
        """Return the acquisition days of a patch for one modality.

        Args:
            id_patch: patch id, a key of ``self.date_tables[sat]``.
            sat (str): modality name, a key of ``self.date_tables``.

        Returns:
            torch.Tensor: 1-D int32 tensor of day offsets relative to
            ``self.reference_date``, in increasing order.
        """
        indices = np.where(self.date_tables[sat][id_patch] == 1)[0]
        # Defensive: ignore any slot beyond date_range should a table be longer.
        indices = indices[indices < len(self.date_range)]
        return torch.tensor(self.date_range[indices], dtype=torch.int32)

    def __getitem__(self, i: int) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]:
        """Get the item at index i.

        Args:
            i (int): index of the item.

        Returns:
            dict: output dictionary following the format
            {"image":
                {"optical": {"S2": torch.Tensor}},
             "target": torch.Tensor,
             "dates": {"S2": torch.Tensor}}.
            Images are float32, the target is int64 and the dates are int32
            day offsets relative to ``self.reference_date``. When
            ``self.multi_temporal == 1`` only the last frame of each image
            (first axis dropped) and its single date are returned.
        """
        line = self.meta_patch.iloc[i]
        id_patch = self.id_patches[i]
        name = line["id"]
        target = torch.from_numpy(
            np.load(
                os.path.join(self.root_path, "ANNOTATIONS/TARGET_" + str(name) + ".npy")
            ).astype(np.int32)
        )

        data = {}
        for s in self.modalities:
            array = np.load(
                os.path.join(
                    self.root_path,
                    "DATA_{}".format(s),
                    "{}_{}.npy".format(s, name),
                )
            ).astype(np.float32)
            data[s] = torch.from_numpy(array)

        dates = {s: self.get_dates(id_patch, s) for s in self.modalities}

        if self.multi_temporal == 1:
            # we only take the last frame
            # NOTE(review): assumes the stored arrays are time-first
            # (t, c, h, w) -- confirm against the files in DATA_S2.
            data = {s: a[-1] for s, a in data.items()}
            dates = {s: d[-1:] for s, d in dates.items()}

        return {
            "image": {
                "optical": {s: a.to(torch.float32) for s, a in data.items()},
            },
            "target": target.to(torch.int64),
            "dates": {s: dates[s].to(torch.int32) for s in self.modalities},
        }

    @staticmethod
    def download(self):
        """Do nothing: this dataset has no download step implemented.

        NOTE(review): declared as a static method that still takes ``self``;
        the signature is kept unchanged -- confirm against how
        RawGeoFMDataset invokes ``download``.
        """
        pass

   

In [175]:
# Instantiate the training split of the dummy dataset from a local data folder.
data_test = Dummy(
    # Identity and location of the data.
    dataset_name="DUMMY",
    split="train",
    root_path="/Users/nisla/codigos/pangaea-bench/data_2/",
    download_url="https://drive.google.com/file/d/1J-k5kqIBI7sSXZYPyT49X-GuBEGJaRPR/view?usp=share_link",
    auto_download=True,
    # Input description.
    multi_modal=False,
    multi_temporal=6,
    img_size=128,
    bands={"s2": "B01 B02 B03 B04 B05 B06 B07 B08 B8A B09 B11 B12".split()},
    # Label description.
    classes=[str(label) for label in range(10)],
    num_classes=10,
    ignore_index=255,
    distribution=[1] * 10,
    # Placeholder per-band statistics (one value for each of the 12 bands).
    data_mean={"S2": [1] * 12},
    data_std={"S2": [0] * 12},
    data_min={"S2": [0] * 12},
    data_max={"S2": [1] * 12},
)

Reading patch metadata . . .


DataSourceError: /Users/nisla/codigos/pangaea-bench/data_2/metadata.geojson: No such file or directory

In [135]:
# Number of patches in the selected split.
len(data_test)

465

In [133]:
# Fetch the first sample once: every data_test[0] access reloads the arrays from disk.
first_sample = data_test[0]
print(first_sample["image"]["optical"]["S2"].shape)
print(first_sample["target"].shape)
print(first_sample["dates"]["S2"].shape)

torch.Size([6, 12, 128, 128])
torch.Size([128, 128])
torch.Size([6])


In [134]:
# Acquisition days of the first sample, per modality, as offsets from the dataset's reference date.
data_test[0]["dates"]

{'S2': tensor([  2,  42,  87, 127, 172, 212], dtype=torch.int32)}

In [130]:
# Display the whole first sample (image, target and dates).
# NOTE(review): this prints the full tensors; inspecting shapes or a small slice keeps the output readable.
data_test[0]

{'image': {'optical': {'S2': tensor([[[[ 1149.,  1149.,  1149.,  ...,  1225.,  1225.,  1225.],
             [ 1149.,  1149.,  1149.,  ...,  1225.,  1225.,  1225.],
             [ 1149.,  1149.,  1149.,  ...,  1225.,  1225.,  1225.],
             ...,
             [ 1175.,  1175.,  1175.,  ...,  1407.,  1407.,  1407.],
             [ 1173.,  1173.,  1173.,  ...,  1410.,  1410.,  1410.],
             [ 1173.,  1173.,  1173.,  ...,  1410.,  1410.,  1410.]],
   
            [[ 1290.,  1498.,  1497.,  ...,  1476.,  1478.,  1501.],
             [ 1408.,  1432.,  1192.,  ...,  1466.,  1482.,  1520.],
             [ 1401.,  1236.,  1101.,  ...,  1471.,  1482.,  1479.],
             ...,
             [ 1074.,  1050.,  1068.,  ...,  1366.,  1396.,  1353.],
             [ 1065.,  1049.,  1042.,  ...,  1354.,  1342.,  1390.],
             [ 1036.,  1058.,  1046.,  ...,  1332.,  1309.,  1332.]],
   
            [[ 1606.,  1844.,  1834.,  ...,  1782.,  1766.,  1796.],
             [ 1750.,  1835.,  