In [1]:
# get row-col of point.
import polars as pl
from rasterio import CRS
from pyproj import Transformer
import geopandas as gpd
import xarray as xr
from tqdm import tqdm

In [None]:
samples = gpd.read_parquet("../samples.parquet")

In [None]:
desired_band_order = [
    "B02",
    "B03",
    "B04",
    "B05",
    "B06",
    "B07",
    "B08",
    "B8A",
    "B11",
    "B12",
    "SCL",
]


def get_pixel_data(sample):
    ds = xr.open_zarr(
        f"../chips/{sample['sample_id']}.zarr",
        decode_coords="all",
        mask_and_scale=False,
        chunks=None,
        use_zarr_fill_value_as_mask=False,
    )
    ds[desired_band_order].drop_duplicates("time")
    epsg = CRS.from_wkt(ds.spatial_ref.crs_wkt).to_epsg()
    ds = ds.drop_vars("spatial_ref")
    ds = ds.rio.write_crs(epsg)
    valid_scl = [4, 5, 6]
    ds["clear"] = ds["SCL"].isin(valid_scl)
    transformer = Transformer.from_crs("EPSG:4326", epsg, always_xy=True)
    # we get the row/col based on the actual point and geotransform. The point will always be
    # at 16x16 (0 indexed) of the chip.
    xx, yy = transformer.transform(sample["geometry"].x, sample["geometry"].y)
    # Get the index along each dimension
    ix = ds.indexes["x"].get_indexer([xx], method="nearest")[0]
    iy = ds.indexes["y"].get_indexer([yy], method="nearest")[0]
    pixel_ts = ds.isel(x=ix, y=iy)

    # compute proportion of clear pixels in different chunk sizes
    clear_4x4 = (
        ds["clear"]
        .isel(x=slice(ix - 2, ix + 2), y=slice(iy - 2, iy + 2))
        .mean(["x", "y"])
        .values
    )
    clear_8x8 = (
        ds["clear"]
        .isel(x=slice(ix - 4, ix + 4), y=slice(iy - 4, iy + 4))
        .mean(["x", "y"])
        .values
    )
    clear_16x16 = (
        ds["clear"]
        .isel(x=slice(ix - 8, ix + 8), y=slice(iy - 8, iy + 8))
        .mean(["x", "y"])
        .values
    )
    clear_32x32 = (
        ds["clear"]
        .isel(x=slice(ix - 16, ix + 16), y=slice(iy - 16, iy + 16))
        .mean(["x", "y"])
        .values
    )
    df = (
        pl.from_pandas(pixel_ts.to_pandas().reset_index())
        .drop("y", "x", "spatial_ref")
        .with_columns(
            percent_clear_4x4=clear_4x4,
            percent_clear_8x8=clear_8x8,
            percent_clear_16x16=clear_16x16,
            percent_clear_32x32=clear_32x32,
            sample_id=pl.lit(sample["sample_id"]),
        )
        .rename({"time": "timestamps"})
    )
    return df

In [219]:
pixel_data_dfs = [get_pixel_data(sample) for _, sample in tqdm(samples.iterrows())]

  ret = um.true_divide(
  ret = um.true_divide(
  ret = um.true_divide(
  ret = um.true_divide(
  ret = um.true_divide(
  ret = um.true_divide(
  ret = um.true_divide(
  ret = um.true_divide(
  ret = um.true_divide(
  ret = um.true_divide(
  ret = um.true_divide(
  ret = um.true_divide(
  ret = um.true_divide(
  ret = um.true_divide(
  ret = um.true_divide(
  ret = um.true_divide(
  ret = um.true_divide(
  ret = um.true_divide(
  ret = um.true_divide(
  ret = um.true_divide(
  ret = um.true_divide(
  ret = um.true_divide(
3892it [38:21,  1.69it/s]


In [243]:
full_df = (
    pl.concat(pixel_data_dfs, how="diagonal_relaxed")
    .fill_nan(0)
    .with_columns(
        pl.col.timestamps.dt.date(),
        pl.col.SCL.cast(pl.UInt8),
        (pl.selectors.starts_with("percent_clear") * 100).cast(pl.UInt8),
    )
    .sort(["sample_id", "timestamps"])
)

In [244]:
full_df.write_parquet("../pixel_data.parquet")

In [4]:
pixel_data = pl.read_parquet("../pixel_data.parquet").drop(
    pl.selectors.starts_with("percent_clear"), "label"
)
pixel_data

timestamps,B02,B03,B07,B06,B04,B05,B08,B11,SCL,B12,B8A,clear,sample_id
date,u16,u16,u16,u16,u16,u16,u16,u16,u8,u16,u16,bool,i32
2019-01-04,4604,4300,4721,4580,3806,4160,4808,1999,9,2051,4633,false,0
2019-01-09,181,451,4138,3232,251,918,4100,1411,4,583,4509,true,0
2019-01-14,210,425,3668,2881,221,840,3550,1335,4,567,3964,true,0
2019-01-24,1262,1372,3962,3318,1116,1699,3930,2044,8,1340,4223,false,0
2019-01-29,3472,3334,4660,4454,3016,3617,4888,3105,9,2557,4765,false,0
…,…,…,…,…,…,…,…,…,…,…,…,…,…
2024-11-04,6864,6416,6453,6606,6424,6580,6876,4742,9,3562,6452,false,4302
2024-11-29,3552,3052,3494,3541,3136,3555,3184,2813,9,2738,3560,false,4302
2024-12-04,954,847,1791,1636,939,1224,2342,1162,10,831,2232,false,4302
2024-12-16,4036,2998,2530,2587,2244,2661,4660,1288,8,590,2598,false,4302


In [6]:
labels_df = pl.read_parquet("../labels.parquet").select(
    pl.col.sample_id.cast(pl.Int32),
    pl.col.label.cast(pl.UInt16),
    timestamps=pl.col.start.dt.date(),
)
added_labels = pixel_data.join_asof(labels_df, by="sample_id", on="timestamps")

  added_labels = pixel_data.join_asof(labels_df, by="sample_id", on="timestamps")


In [7]:
added_labels

timestamps,B02,B03,B07,B06,B04,B05,B08,B11,SCL,B12,B8A,clear,sample_id,label
date,u16,u16,u16,u16,u16,u16,u16,u16,u8,u16,u16,bool,i32,u16
2019-01-04,4604,4300,4721,4580,3806,4160,4808,1999,9,2051,4633,false,0,
2019-01-09,181,451,4138,3232,251,918,4100,1411,4,583,4509,true,0,
2019-01-14,210,425,3668,2881,221,840,3550,1335,4,567,3964,true,0,
2019-01-24,1262,1372,3962,3318,1116,1699,3930,2044,8,1340,4223,false,0,
2019-01-29,3472,3334,4660,4454,3016,3617,4888,3105,9,2557,4765,false,0,
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
2024-11-04,6864,6416,6453,6606,6424,6580,6876,4742,9,3562,6452,false,4302,222
2024-11-29,3552,3052,3494,3541,3136,3555,3184,2813,9,2738,3560,false,4302,222
2024-12-04,954,847,1791,1636,939,1224,2342,1162,10,831,2232,false,4302,222
2024-12-16,4036,2998,2530,2587,2244,2661,4660,1288,8,590,2598,false,4302,222


In [8]:
added_labels.write_parquet("../pixel_data.parquet")