## Build Tile Stacks
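This notebook reads single-band tiles from an S3 bucket, combines them into multi-band stacks, converts each stack to a Cloud Optimized GeoTIFF (COG), and uploads the result back to the bucket.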

In [None]:
import os
import re
import getpass
from pathlib import Path
import numpy as np
import geopandas as gpd
import rasterio
import rasterio.shutil
from rasterio.session import AWSSession
from rasterio.vrt import WarpedVRT
import matplotlib.pyplot as plt

import xml.etree.ElementTree as ET
import subprocess
import boto3

from vegmapper import build_stack, build_condensed_stack

## User inputs

### Provide access credentials for the AWS S3 bucket

In [None]:
# Prompt for the S3 credentials interactively so they are never hard-coded in
# the notebook; getpass does not echo what is typed.
user_s3_key_id = getpass.getpass("Enter the S3 user key ID: ")
user_access_key = getpass.getpass("Enter the S3 user access key: ")

### Define paths and observation information
These fields should match the bucket structure containing the source tiles. 

In [None]:
# Region and observation period; both are used in the bucket path and in the
# names of the generated stacks.
location = "ucayali"
observation_date = "2024"

# Replace with your actual bucket name.
bucket_name = "name_of_your_bucket"

# Key prefix (inside the bucket) under which the per-band tile folders live.
bucket_base_path = "/".join(("servir_peru", location, observation_date))

### Define bands
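Add or remove bands here as needed. Each entry gives the band name written into the stack and the sub-directory (relative to the base path) that holds its tiles.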

In [None]:
# One entry per band, in the order the bands are looked up and stacked.
# Each pair is (band name, sub-directory below `bucket_base_path`).
band_definitions = [
    {"name": band_name, "subdir": band_subdir}
    for band_name, band_subdir in (
        ("VV", "s1/vv"),
        ("VH", "s1/vh"),
        ("RVI", "s1/rvi"),
        ("NDVI", "landsat/ndvi"),
        # ("NDFI", "landsat/ndfi"),  # Example of an easy addition
    )
]

## S3 bucket session

In [None]:
# Build a boto3 session from the credentials entered above. The same session is
# used later to list objects, to read tiles through rasterio, and to upload
# the finished stacks.
boto3_session = boto3.Session(
    aws_access_key_id=user_s3_key_id,
    aws_secret_access_key=user_access_key,
    # region_name="your-region"  # Optional
)

## Helper functions

In [None]:
def list_s3_file_paths(bucket, prefix):
    """List the GeoTIFF objects stored under an S3 prefix.

    Uses the notebook-level ``boto3_session`` and paginates through
    ``list_objects_v2`` so prefixes with many objects are fully covered.

    Args:
        bucket: Name of the S3 bucket.
        prefix: Key prefix to search under.

    Returns:
        A list of ``s3://<bucket>/<key>`` strings, one for each object whose
        key ends in ``.tif`` or ``.tiff`` (case-insensitive).
    """
    tif_suffixes = (".tif", ".tiff")
    pages = (
        boto3_session.client("s3")
        .get_paginator("list_objects_v2")
        .paginate(Bucket=bucket, Prefix=prefix)
    )
    return [
        f"s3://{bucket}/{obj['Key']}"
        for page in pages
        # Pages for an empty prefix carry no "Contents" entry at all.
        for obj in page.get("Contents", [])
        if obj["Key"].lower().endswith(tif_suffixes)
    ]

def extract_tile_id(path):
    """Return the first tile ID (e.g. ``"h0_v1"``) found in ``path``.

    Args:
        path: Any string, typically an S3 path to a tile.

    Returns:
        The first substring matching ``h<digits>_v<digits>``, or ``None`` when
        the path contains no such pattern.
    """
    found = re.search(r'h\d+_v\d+', path)
    if found is None:
        return None
    return found.group()

## Read tile IDs and set up bands

In [None]:
print("Fetching all file paths for all bands...")

# Per-band listing ({band_name: [s3 paths]}) plus one flat list across all bands.
all_band_files = {}
all_paths_flat = []

for band in band_definitions:
    band_name = band["name"]
    band_prefix = f"{bucket_base_path}/{band['subdir']}/"
    print(f"Looking for {band_name} files in: s3://{bucket_name}/{band_prefix}")

    band_files = list_s3_file_paths(bucket_name, band_prefix)
    print(f"Found {len(band_files)} {band_name} files")

    all_band_files[band_name] = band_files
    all_paths_flat += band_files

# Distinct tile IDs across every band; paths without a recognizable ID are dropped.
tile_ids = sorted({tid for tid in map(extract_tile_id, all_paths_flat) if tid})
print(f"Found {len(tile_ids)} unique tile(s): {tile_ids}")

## Prepare Tiles
This step creates the multi-band stacks, converts them to COGs, and uploads them to the S3 bucket.

In [None]:
# Create the rasterio/GDAL session and the S3 client once; both are reused for
# every tile. Defining `aws_session` here (rather than inside the loop) also
# guarantees it exists when no tile is processed.
aws_session = AWSSession(boto3_session)
s3_client = boto3_session.client("s3")

for tile_id in tile_ids:
    print(f"\n Processing tile: {tile_id}")

    stack_name = f"stack_{location}_{observation_date}_{tile_id}"
    stack_s3_path = f"{bucket_base_path}/multi_sensor_stacks/{stack_name}_cog.tif"
    local_stack = f"./{location}/opera_rtc/tile_vrts/{stack_name}.tif"
    local_cog = f"./{location}/opera_rtc/tile_vrts/{stack_name}_cog.tif"

    os.makedirs(os.path.dirname(local_stack), exist_ok=True)

    # --- Gather the tile-specific paths for all bands ---
    # Accept the tile ID with or without the underscore (e.g. h3_v2 or h3v2).
    # The trailing (?!\d) keeps "h1_v1" from also matching "h1_v10", which a
    # plain substring test would do.
    tile_id_nounder = tile_id.replace("_", "")
    tile_pattern = re.compile(
        rf"(?:{re.escape(tile_id)}|{re.escape(tile_id_nounder)})(?!\d)"
    )

    band_paths = []
    band_names = []

    for band in band_definitions:
        matching_files = [
            p for p in all_band_files[band["name"]]
            if tile_pattern.search(p)
        ]
        if matching_files:
            # Every matching file becomes its own band, labelled with the band name.
            band_paths.extend(matching_files)
            band_names.extend([band["name"]] * len(matching_files))
        else:
            print(f"!! Missing {band['name']} for tile {tile_id}")

    if not band_paths:
        print(f"!! No valid files found for tile {tile_id}, skipping.")
        continue

    # Convert to GDAL VSI paths
    s3_paths = [p.replace("s3://", "/vsis3/") for p in band_paths]

    # --- Stack the bands into one local multi-band GeoTIFF ---
    with rasterio.Env(aws_session):
        sources = []
        try:
            # Open one by one so that anything already opened is still closed
            # by the `finally` below if a later open (or the write) fails.
            for p in s3_paths:
                sources.append(rasterio.open(p))

            # The first source provides the grid, CRS and dtype of the stack.
            meta = sources[0].meta.copy()
            meta.update({
                "count": len(sources),
                "dtype": sources[0].dtypes[0],
                "driver": "GTiff"
            })

            with rasterio.open(local_stack, "w", **meta) as dst:
                for i, (src, name) in enumerate(zip(sources, band_names), start=1):
                    dst.write(src.read(1), i)
                    dst.set_band_description(i, name)
        finally:
            for src in sources:
                src.close()

    # --- Convert the stack to a Cloud Optimized GeoTIFF ---
    subprocess.run([
        "gdal_translate", local_stack, local_cog,
        "-of", "COG",
        "-co", "COMPRESS=LZW",
        "-co", "NUM_THREADS=ALL_CPUS"
    ], check=True)

    # --- Upload the COG to S3 ---
    s3_client.upload_file(local_cog, bucket_name, stack_s3_path)
    print(f"--> Uploaded: s3://{bucket_name}/{stack_s3_path}")

    # Only the intermediate (non-COG) stack is deleted; the local COG is kept.
    os.remove(local_stack)
    print(f"-> Removed: {local_stack}")

## Read and display sample Stack
Grab one of the generated stacks from the S3 bucket and display its bands.

In [None]:
# Path of one COG stack in S3, in GDAL /vsis3/ form. The file name follows the
# pattern used when the stacks were written:
#   stack_<location>_<observation_date>_<tile_id>_cog.tif
sample_tile_id = "h1_v1"  # Change to any tile ID listed in `tile_ids`
s3_vrt_path = (
    f"/vsis3/{bucket_name}/{bucket_base_path}/multi_sensor_stacks/"
    f"stack_{location}_{observation_date}_{sample_tile_id}_cog.tif"
)

# Display settings per band, in stack order (VV, VH, RVI, NDVI).
cmaps = ["gray", "gray", "gray", "viridis"]
vmin_values = [0, 0, 0, 0.6]
vmax_values = [0.4, 0.1, 1.0, 0.9]

# Create the session here so this cell does not depend on a variable that only
# exists if the processing loop ran for at least one tile.
aws_session = AWSSession(boto3_session)

# Read the bands from the stacked COG in S3
with rasterio.Env(aws_session):
    with rasterio.open(s3_vrt_path) as dataset:
        # Never request more bands than the stack holds or than we have
        # display settings for.
        n_bands = min(dataset.count, len(cmaps))
        bands = [dataset.read(i) for i in range(1, n_bands + 1)]
        band_names = dataset.descriptions

# Display bands side by side; squeeze=False keeps `axes` 2-D even for one band.
fig, axes = plt.subplots(1, n_bands, figsize=(15, 5), squeeze=False)

for i, ax in enumerate(axes[0]):
    ax.imshow(bands[i], cmap=cmaps[i], vmin=vmin_values[i], vmax=vmax_values[i])
    # Fall back to a generic label when the band has no description.
    ax.set_title(band_names[i] or f"Band {i + 1}")
    ax.axis("off")

plt.tight_layout()
plt.show()