In [None]:
%%capture
!pip install "dask[distributed]"
!pip install planetary-computer

## Imports

In [78]:
import os
import re
import numpy as np
import geopandas as gpd
import xarray as xr
import pystac_client
import stackstac
import rioxarray
from shapely.geometry import mapping

from pystac_client import Client
from stackstac import stack

from rasterio.enums import Resampling
import planetary_computer

## Functions

In [79]:
def get_utm_epsg(lat, lon):
    zone = int((lon + 180) / 6) + 1
    if lat >= 0:
        return 32600 + zone  # Northern Hemisphere
    else:
        return 32700 + zone  # Southern Hemisphere


def extract_month(filename):
    match = re.match(r'(\d{4}_\d{2})/raster_\d{2}_\d+\.tif$', filename)
    if match:
        return match.group(1)
    return None



## Making grids, skip if you already have the grids

In [5]:
# india_boundary = gpd.read_file("./india_boundary.geojson")

In [4]:
# utm_crs = india_boundary.estimate_utm_crs()
# india_boundary = india_boundary.to_crs(utm_crs)

# # 10 meters per pixel → 2240m grid size
# grid_size = 2240

# minx, miny, maxx, maxy = india_boundary.total_bounds

# grid_tiles = []
# for x in np.arange(minx, maxx, grid_size):
#     for y in np.arange(miny, maxy, grid_size):
#         grid_tiles.append(box(x, y, x + grid_size, y + grid_size))

# grid_gdf = gpd.GeoDataFrame({"geometry": grid_tiles}, crs=utm_crs)

# india_grid = gpd.clip(grid_gdf, india_boundary)

# india_grid = india_grid.to_crs(epsg=4326)

# india_grid.to_file("india_grid.geojson", driver="GeoJSON")

# print(f"Created {len(india_grid)} grid tiles.")

Created 188 grid tiles.


### Generating random grids

In [None]:
# # Load data 
# india_grid = gpd.read_file("india_grid.geojson")
# bboxes = gpd.read_file("bbox.geojson")  

# #  Ensure CRS match
# if india_grid.crs != bboxes.crs:
#     bboxes = bboxes.to_crs(india_grid.crs)

# # Select all grids intersecting the bounding boxes (fully or partially) 
# bbox_union = bboxes.unary_union
# grids_in_bbox = india_grid[india_grid.geometry.intersects(bbox_union)]

# # Identify square-like grids (not just 4-sided, but width ≈ height) 
# def is_square_like(geom, tolerance=0.1):
#     if geom.geom_type != "Polygon":
#         return False
#     coords = list(geom.exterior.coords)
#     if len(coords) - 1 != 4:
#         return False
#     minx, miny, maxx, maxy = geom.bounds
#     width = maxx - minx
#     height = maxy - miny
#     if height == 0:
#         return False
#     aspect_ratio = width / height
#     return abs(aspect_ratio - 1) <= tolerance  # e.g. between 0.90 and 1.1

# # Apply the square-like check
# perfect_grids = india_grid[india_grid.geometry.apply(is_square_like)]

# # Remove any square-like grids that intersect with the bounding boxes 
# perfect_grids_outside = perfect_grids[~perfect_grids.geometry.intersects(bbox_union)]

# # Sampling logic 
# TOTAL = 300_000
# n_from_outside = TOTAL - len(grids_in_bbox)

# if n_from_outside < 0:
#     raise ValueError(f"Too many bbox-intersecting grids: {len(grids_in_bbox)} exceeds total {TOTAL}")

# grids_outside_random = perfect_grids_outside.sample(n=n_from_outside, random_state=42)

# # Combine and save
# final_grids = gpd.GeoDataFrame(
#     pd.concat([grids_in_bbox, grids_outside_random], ignore_index=True),
#     crs=india_grid.crs
# )

# final_grids.to_file("india_random_grids.geojson", driver="GeoJSON")


# print(f" Final grid count: {len(final_grids)}")
# print(f" BBox-intersecting grids (all kept): {len(grids_in_bbox)}")
# print(f" Random perfect-square grids from outside: {len(grids_outside_random)}")


### check the grids

In [6]:

# random_grids = gpd.read_file("./irrigation_raw_labels/partial_labels_grids.geojson")
# grid = random_grids  

# ma = leafmap.Map(center=[22.5, 78.9], zoom=4)
# ma.add_basemap('SATELLITE')
# ma.add_geojson(random_grids)
# ma


Map(center=[22.5, 78.9], controls=(ZoomControl(options=['position', 'zoom_in_text', 'zoom_in_title', 'zoom_out…

## Downloading S2 pipeline

### Load and Prepare India Grids

In [80]:
# Paths
geojson_path = "./rice_practices/mask_ready/partial_labels_grids_FID.geojson"
output_folder = "./rice_practices/Sentinel2_UTM_TimeSeries_new"
os.makedirs(output_folder, exist_ok=True)

# STAC catalog earth_search API
catalog = Client.open("https://earth-search.aws.element84.com/v1")

# Assets to download
assets = ['green', 'nir', 'red', 'rededge1', 'swir1', 'scl']

### Search, Stack, Composite, and Save

In [None]:
# Load grid with FID, filename, and geometry
grids = gpd.read_file(geojson_path)
grids["FID"] = grids["FID"].astype(str)

# Identify unique tile-date combinations
needed_downloads = set()
for _, row in grids.iterrows():
    fid = row["FID"]                       # e.g., tile000000
    yyyymm = row["filename"].split("/")[0]  # e.g., 2024_07
    needed_downloads.add((fid, yyyymm))

# Process each tile-month
for fid, yyyymm in sorted(needed_downloads):
    print(f"\nProcessing: {fid} | {yyyymm}")
    year, month = yyyymm.split("_")
    date_range = f"{year}-{month}-01/{year}-{month}-30"
    print(date_range)

    # Get geometry from matching row
    match = grids[(grids["FID"] == fid) & (grids["filename"].str.contains(yyyymm))]
    if match.empty:
        print(f"No matching geometry for {fid} in {yyyymm}, skipping.")
        break

    row = match.iloc[0]
    grid_geom = row.geometry
    centroid = grid_geom.centroid
    utm_epsg = get_utm_epsg(centroid.y, centroid.x)

    # Search STAC
    search = catalog.search(
        collections=["sentinel-2-l2a"],
        intersects=grid_geom,
        datetime=date_range,
        limit=10
    )

    items = search.get_all_items()
    if not items:
        print(f"No images for {fid} in {yyyymm}, skipping.")
        continue

    item = items[0]
    if "proj:code" not in item.properties:
        print(f"No EPSG info in item metadata for {fid}, skipping.")
        continue

    sentinel_epsg = int(item.properties["proj:code"].replace("EPSG:", ""))

    # Stack
    stack = stackstac.stack(
        [item],
        assets=["B02", "B03", "B04", "B08", "scl"],
        epsg=sentinel_epsg,
        resolution=10,
        chunksize=2048
    )

    grid_utm = gpd.GeoSeries(grid_geom, crs=grids.crs).to_crs(f"EPSG:{sentinel_epsg}").geometry[0]
    clipped = stack.rio.clip([grid_utm], f"EPSG:{sentinel_epsg}", drop=True)

    # Quality filter
    scl = clipped.sel(band="scl")
    valid_mask = scl.isin([4, 5, 6, 7])
    valid_ratio = valid_mask.sum() / (valid_mask.sizes["x"] * valid_mask.sizes["y"])

    if valid_ratio < 0.6:
        print(f"Low valid pixel ratio ({valid_ratio:.2f}) for {fid}, skipping.")
        continue

    # Apply mask and get median
    clipped = clipped.where(valid_mask, np.nan)
    median_img = clipped.median(dim="time", keep_attrs=True)

    # Remove SCL band
    if "scl" in median_img.band.values:
        median_img = median_img.sel(band=median_img.band != "scl")

    # Resample to 224×224
    median_img = median_img.rio.reproject(
        dst_crs=f"EPSG:{sentinel_epsg}",
        shape=(224, 224),
        resampling=Resampling.bilinear
    )

    # Save with correct naming
    output_name = f"{fid}_{yyyymm}_s2.tif"
    output_path = os.path.join(output_folder, output_name)

    median_img.rio.to_raster(output_path)
    print(f" Saved: {output_path} | Shape: {median_img.shape}")


Processing: tile000000 | 2024_07
2024-07-01/2024-07-30
