In [None]:
import pystac_client
import planetary_computer as pc
from odc.stac import stac_load
import numpy as np
import os
import rioxarray
from tqdm import tqdm

import glob
import struct
from pathlib import Path
import json
from pathlib import Path
import pandas as pd
import calendar
import logging
import sys
from datetime import datetime

# Mirror every record of level INFO or higher to both log.txt and stdout,
# each line prefixed with its level name.
logging.basicConfig(
    handlers=[logging.FileHandler('log.txt'), logging.StreamHandler(sys.stdout)],
    format='%(levelname)s: %(message)s',
    level=logging.INFO,
)

def calculate_cloud_cover_percentage(qa_band):
    """
    Return the share of pixels (0-100) flagged as cloud-affected in a
    Landsat Collection 2 QA_PIXEL band.

    A pixel counts as cloud-affected when any of these bits is set:
    - Bit 1: Dilated cloud
    - Bit 3: Cloud
    - Bit 4: Cloud shadow

    Every pixel of ``qa_band`` is included in the denominator.
    """
    # One combined mask replaces three separate bit tests.
    cloud_bits = (1 << 1) | (1 << 3) | (1 << 4)
    flagged = (qa_band & cloud_bits) != 0

    flagged_count = flagged.sum().item()
    pixel_count = qa_band.size

    return (flagged_count / pixel_count) * 100

def get_scenes_by_cloud_cover(query_results, bbox, max_scenes=None):
    """
    Rank the scenes of a STAC search by cloud cover over the area of interest.

    For every item only the ``qa_pixel`` band is loaded (30 m, EPSG:3857,
    clipped to ``bbox``) and passed to ``calculate_cloud_cover_percentage``.
    Items that load with no time step, or that raise while loading, are
    left out of the result; failures are printed, not raised.

    Args:
        query_results: STAC query results; only its ``items()`` method is used.
        bbox: Bounding box for the area of interest
        max_scenes: Maximum number of scenes to evaluate, taken from the
            start of the search results. ``None`` (the default) evaluates
            every scene.

    Returns:
        List of tuples ``(item, cloud_cover_percentage)`` sorted by cloud
        cover, least cloudy first.
    """
    scenes_with_cloud_cover = []

    items = list(query_results.items())
    if max_scenes is not None:
        # Guard against a negative limit, which would otherwise slice from the end.
        items = items[:max(0, max_scenes)]

    print(f"Evaluating cloud cover for {len(items)} scenes...")

    for i, item in enumerate(tqdm(items, desc="Calculating cloud cover")):
        try:
            # Load just the QA band for this scene
            scene_qa = stac_load(
                [item],
                bands=["qa_pixel"],
                bbox=bbox,
                resolution=30,
                crs="EPSG:3857",
                skip_broken=True,
                fail_on_error=False
            )

            if len(scene_qa.time) == 0:
                continue

            qa_band = scene_qa['qa_pixel'].isel(time=0)
            cloud_cover = calculate_cloud_cover_percentage(qa_band)
            scenes_with_cloud_cover.append((item, cloud_cover))

        except Exception as e:
            # Best effort: one unreadable scene must not abort the ranking.
            print(f"Error processing scene {i}: {e}")
            continue

    # Sort by cloud cover (ascending - least cloudy first)
    scenes_with_cloud_cover.sort(key=lambda scene: scene[1])

    return scenes_with_cloud_cover


def load_checkpoint(checkpoint_file="./Data/checkpoint.json"):
    """
    Load the checkpoint from ``checkpoint_file`` or return a fresh one.

    Keys that are absent from the file are filled in with their empty
    defaults, so a checkpoint written before a key was introduced still
    supports every lookup the processing loop performs.

    Args:
        checkpoint_file: Path of the JSON checkpoint.

    Returns:
        The checkpoint dictionary. If the file does not exist, a new
        dictionary holding only the defaults.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    defaults = {
        "completed_years": [],
        "completed_cities": {},
        "errored_cities": {},
        "missing_months": {},
        "current_year": None,
        "current_city": None,
    }
    if not os.path.exists(checkpoint_file):
        return defaults

    with open(checkpoint_file, 'r') as f:
        loaded = json.load(f)

    # Only back-fill when the file holds a JSON object; anything else is
    # returned untouched, exactly as before.
    if isinstance(loaded, dict):
        for key, value in defaults.items():
            loaded.setdefault(key, value)
    return loaded

def save_checkpoint(checkpoint_data, checkpoint_file="./Data/checkpoint.json"):
    """
    Write ``checkpoint_data`` to ``checkpoint_file`` as indented JSON.

    The data is written to a sibling ``.tmp`` file and then moved over the
    target with ``os.replace``, so an interrupt or a serialization error
    cannot leave a half-written checkpoint behind. The parent directory is
    created if it does not exist.

    Args:
        checkpoint_data: JSON-serializable checkpoint dictionary.
        checkpoint_file: Destination path.

    Raises:
        TypeError: If ``checkpoint_data`` is not JSON-serializable; the
            previous checkpoint file, if any, is left unchanged.
    """
    parent = os.path.dirname(checkpoint_file)
    if parent:
        os.makedirs(parent, exist_ok=True)

    tmp_file = f"{checkpoint_file}.tmp"
    try:
        with open(tmp_file, 'w') as f:
            json.dump(checkpoint_data, f, indent=2)
        os.replace(tmp_file, checkpoint_file)
    except BaseException:
        # Also covers KeyboardInterrupt: drop the partial temp file and
        # keep the last good checkpoint.
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
        raise

# Years to process, as strings because they are used as checkpoint keys.
# Currently 2013-2017; 2018-2024 are not enabled.
years = [str(y) for y in range(2013, 2018)]
# Value written into every output raster where a pixel is masked out.
NODATA = 0
checkpoint = load_checkpoint()
# Resume from the year recorded in the checkpoint when it is one of `years`;
# otherwise start from the first year.
start_year_idx = 0
if checkpoint["current_year"] and checkpoint["current_year"] in years:
    start_year_idx = years.index(checkpoint["current_year"])
shapes_folder = "./Data/City_Shapes"
# city name -> [xmin, ymin, xmax, ymax], filled in from the shapefiles below.
cities = {}
def read_shapefile_bounds(shp_path):
    """
    Return the bounding box stored in a shapefile's header.

    The box is read as four little-endian doubles starting at byte 36 of
    the file and returned as a tuple ``(xmin, ymin, xmax, ymax)``.
    """
    bbox_format = struct.Struct('<4d')
    with open(shp_path, 'rb') as shp:
        shp.seek(36)
        raw = shp.read(bbox_format.size)
    return bbox_format.unpack(raw)
# Register every shapefile in the shapes folder: the file's base name is the
# city name and its header bounding box becomes [xmin, ymin, xmax, ymax].
for shp_file in glob.glob(f"{shapes_folder}/*.shp"):
    cities[Path(shp_file).stem] = list(read_shapefile_bounds(shp_file))
def calculate_lst_fahrenheit(thermal_band):
    """Convert the raw ``lwir11`` band to surface temperature in degrees Fahrenheit."""
    # Scale and offset turn the stored integers into Kelvin.
    lst_kelvin = thermal_band * 0.00341802 + 149.0
    # Convert Kelvin to Fahrenheit: F = (K - 273.15) × 9/5 + 32
    return (lst_kelvin - 273.15) * 9/5 + 32


def convert_to_surface_reflectance(band_data):
    """Apply the Collection 2 Level 2 scale/offset and clip the result to [0, 1]."""
    scale_factor = 0.0000275
    add_offset = -0.2
    reflectance = band_data * scale_factor + add_offset
    return reflectance.clip(0, 1)


def _normalized_difference(first_band, second_band):
    """Return (first - second) / (first + second) on reflectances, clipped to [-1, 1] and scaled by 10 000."""
    first = convert_to_surface_reflectance(first_band)
    second = convert_to_surface_reflectance(second_band)
    numerator = first - second
    denominator = first + second
    # Where the denominator is 0 the numerator becomes NaN and the divisor 1,
    # so the result is NaN instead of a division by zero.
    # NOTE(review): those NaN pixels are later cast to int16 on write - confirm
    # that the resulting value is acceptable or fill them with NODATA.
    index = numerator.where(denominator != 0) / denominator.where(denominator != 0, 1)
    return np.clip(index, -1, 1) * 10_000


def calculate_ndvi(nir_band, red_band):
    """NDVI = (NIR - red) / (NIR + red), scaled by 10 000."""
    return _normalized_difference(nir_band, red_band)


def calculate_ndwi(nir_band, green_band):
    """NDWI = (green - NIR) / (green + NIR), scaled by 10 000."""
    return _normalized_difference(green_band, nir_band)


def calculate_ndbi(nir_band, swir_band):
    """NDBI = (SWIR - NIR) / (SWIR + NIR), scaled by 10 000."""
    return _normalized_difference(swir_band, nir_band)


def calculate_albedo(blue_band, green_band, red_band, nir_band, swir_band):
    """Return a weighted sum of the five reflectances, clipped to [0, 1] and scaled by 10 000."""
    # NOTE(review): the weights look like a published broadband-albedo
    # regression - confirm the band-to-coefficient mapping and the constant
    # term against its source.
    albedo = (
        0.356 * convert_to_surface_reflectance(blue_band) +
        0.130 * convert_to_surface_reflectance(green_band) +
        0.373 * convert_to_surface_reflectance(red_band) +
        0.085 * convert_to_surface_reflectance(nir_band) +
        0.072 * convert_to_surface_reflectance(swir_band) -
        0.018
    )
    return albedo.clip(0, 1) * 10_000


def calculate_color_band(color_band):
    """Return the surface reflectance of a visible band scaled by 10 000."""
    return convert_to_surface_reflectance(color_band) * 10000


def _record_missing_month(checkpoint_data, year, city):
    """Add one to the number of months without a usable scene for ``city`` in ``year``."""
    per_city = checkpoint_data["missing_months"].setdefault(year, {})
    # Counts are stored as strings, matching existing checkpoint files.
    per_city[city] = str(int(per_city.get(city, "0")) + 1)


def _write_scene_products(landsat_rasters, dem_rasters, city):
    """
    Derive all products from the first time step of ``landsat_rasters`` and
    write them, together with the DEM, as LZW-compressed int16 GeoTIFFs.

    Files that already exist are left alone; a failure on one file is logged
    and does not stop the others.

    Returns:
        True when the scene had valid thermal and spectral pixels and the
        write pass ran; False when the scene has no valid data and another
        scene should be tried.
    """
    date = landsat_rasters.time.values[0]
    scene = landsat_rasters.isel(time=0)
    thermal_band = scene['lwir11']
    red_band = scene['red']
    green_band = scene['green']
    blue_band = scene['blue']
    nir_band = scene['nir08']
    swir_band = scene['swir16']
    qa_band = scene['qa_pixel']

    # Elevation is shifted by 10 000 before being stored as int16.
    dem = dem_rasters['elevation'] + 10_000

    # Valid pixels have a positive thermal value and neither QA bit 3 (cloud)
    # nor bit 4 (cloud shadow) set.
    valid_mask = (thermal_band > 0) & ((qa_band & 0b00011000) == 0)

    # Keys double as output file names; insertion order is the write order.
    products = {
        "LST": calculate_lst_fahrenheit(thermal_band),
        "red": calculate_color_band(red_band),
        "green": calculate_color_band(green_band),
        "blue": calculate_color_band(blue_band),
        "ndvi": calculate_ndvi(nir_band, red_band),
        "ndwi": calculate_ndwi(nir_band, green_band),
        "ndbi": calculate_ndbi(nir_band, swir_band),
        "albedo": calculate_albedo(blue_band, green_band, red_band, nir_band, swir_band),
    }
    masked = {}
    for name, raster in products.items():
        # Invalid pixels become NODATA (0) instead of NaN.
        masked_raster = raster.where(valid_mask, NODATA)
        masked_raster.rio.write_nodata(NODATA, inplace=True)
        masked[name] = masked_raster

    date_str = pd.to_datetime(date).strftime('%Y-%m-%dT%H:%M:%SZ')
    output_dir = f"./Data/Dataset/Cities/{city}/{date_str}"
    output_dir_dem = f"./Data/Dataset/DEM_2014/{city}"

    has_thermal = bool((masked["LST"].values != NODATA).any())
    has_spectral = bool((masked["albedo"].values != NODATA).any())
    if not (has_thermal and has_spectral):
        if not has_thermal:
            logging.warning(f"No valid temperature data for {city} {date_str}")
        if not has_spectral:
            logging.warning(f"No valid spectral data for {city} {date_str}")
        return False

    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(output_dir_dem, exist_ok=True)
    print(f"Processing date: {date_str}")
    print(f"Output directory: {output_dir}")
    print(f"Directory exists: {os.path.exists(output_dir)}")

    outputs = [(f"{output_dir_dem}/DEM.tif", dem)]
    outputs += [(f"{output_dir}/{name}.tif", raster) for name, raster in masked.items()]
    for file_path, raster in outputs:
        try:
            if not os.path.exists(file_path):
                raster.rio.to_raster(file_path, compress="LZW", dtype='int16')
                logging.info(f"Saved: {file_path}")
            else:
                logging.info(f"File already exists: {file_path}")
        except Exception as e:
            logging.error(f"Error saving file: {file_path}: {e}")
    logging.info("Scene save complete!")
    return True


# Capture the resume point before the loop overwrites checkpoint["current_year"];
# comparing against the live checkpoint value would match every year.
resume_year = checkpoint.get("current_year")
resume_city = checkpoint.get("current_city")
for year in years[start_year_idx:]:
    checkpoint["current_year"] = year
    city_items = sorted(cities.items())
    start_city_idx = 0
    if year == resume_year and resume_city:
        city_names = [name for name, _ in city_items]
        # An unknown city simply means starting from the first one.
        if resume_city in city_names:
            start_city_idx = city_names.index(resume_city)
    for city, bbox in city_items[start_city_idx:]:
        checkpoint["current_city"] = city
        save_checkpoint(checkpoint)
        if city in checkpoint["completed_cities"].get(year, []):
            continue
        try:
            # Set limits for beginning of Landsat and present timing
            first_month, last_month = 1, 12
            if year == "2013":
                first_month = 6
            if year == "2025":
                last_month = 6
            # One client per city/year; a failure here marks the city as errored.
            client = pystac_client.Client.open(
                "https://planetarycomputer.microsoft.com/api/stac/v1",
                modifier=pc.sign_inplace,
            )
            for month in range(first_month, last_month + 1):
                last_day = calendar.monthrange(int(year), month)[1]
                # Each retry moves on to the next-least-cloudy scene of the month.
                # If all 20 attempts fail, nothing is recorded for the month.
                for retry in range(20):
                    logging.info(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
                    logging.info(f"Starting {city} {year} month {month}, retry -> {retry}")
                    if retry == 0:
                        # Search and rank once per month: the result does not
                        # depend on the retry number, and re-ranking could
                        # reorder the list between attempts.
                        query_landsat = client.search(
                            collections=["landsat-c2-l2"],
                            bbox=bbox,
                            datetime=f"{year}-{month:02d}-01/{year}-{month:02d}-{last_day}",
                            query={"platform": {"in": ["landsat-8", "landsat-9"]}}
                        )
                        query_dem = client.search(
                            collections=["nasadem"],
                            bbox=bbox,
                        )
                        landsat_items = list(query_landsat.items())
                        if not landsat_items:
                            logging.warning(f"No Landsat scenes found for {city}. Skipping to next month/city.")
                            break
                        dem_items = list(query_dem.items())
                        if not dem_items:
                            print(f"No DEM data found for {city}. Skipping to next month/city.")
                            break
                        scenes_with_cloud_cover = get_scenes_by_cloud_cover(query_landsat, bbox)
                        if not scenes_with_cloud_cover:
                            logging.warning(f"No scenes found for {city} in {year}-{month}. Skipping to next month or year.")
                            _record_missing_month(checkpoint, year, city)
                            break
                    if retry >= len(scenes_with_cloud_cover):
                        logging.warning("Retry thermal had no scenes left")
                        _record_missing_month(checkpoint, year, city)
                        break

                    best_scene_item, best_cloud_cover = scenes_with_cloud_cover[retry]
                    logging.info(f"Selected scene with {best_cloud_cover:.1f}% cloud cover")

                    # Loading failures are deliberately not caught here: they
                    # propagate to the city-level handler below.
                    landsat_rasters = stac_load(
                        [best_scene_item],
                        bands=["blue", "green", "red", "nir08", "swir16", "lwir11", "qa_pixel"],
                        bbox=bbox,
                        resolution=30,
                        crs="EPSG:3857",
                        progress=tqdm,
                        skip_broken=True,
                        fail_on_error=True
                    )
                    dem_rasters = stac_load(
                        dem_items,
                        bands=["elevation"],
                        bbox=bbox,
                        resolution=30,
                        crs="EPSG:3857",
                        progress=tqdm,
                        skip_broken=True,
                        fail_on_error=True
                    )
                    print(f"Loaded {len(landsat_rasters.time)} scenes")

                    try:
                        if _write_scene_products(landsat_rasters, dem_rasters, city):
                            break
                        # No valid pixels: fall through to the next scene.
                    except Exception as e:
                        logging.error(f"Error saving scene {year}, {month}, {city}: {e}")

            # Mark city as completed for this year
            completed = checkpoint["completed_cities"].setdefault(year, [])
            if city not in completed:
                completed.append(city)
            errored = checkpoint["errored_cities"].setdefault(year, [])
            if city in errored:
                errored.remove(city)
            save_checkpoint(checkpoint)
        except KeyboardInterrupt:
            print(f"Interrupted while processing {city} {year}")
            save_checkpoint(checkpoint)
            raise
        except Exception as e:
            errored = checkpoint["errored_cities"].setdefault(year, [])
            # Avoid duplicate entries when the same city fails on several runs.
            if city not in errored:
                errored.append(city)
            save_checkpoint(checkpoint)
            logging.error(f"Error processing {city} {year}: {e}")
            continue
    # A year is complete only when none of its cities errored.
    if not checkpoint["errored_cities"].setdefault(year, []) and year not in checkpoint["completed_years"]:
        checkpoint["completed_years"].append(year)
    checkpoint["current_city"] = None
    save_checkpoint(checkpoint)
print("All processing completed!")

INFO: 2025-07-07 02:16:05
INFO: Starting Abilene_TX 2013 month 6, retry -> 0


Evaluating cloud cover for 8 scenes...


Calculating cloud cover: 100%|██████████| 8/8 [00:04<00:00,  1.70it/s]

INFO: Selected scene with 0.0% cloud cover



100%|██████████| 7/7 [00:04<00:00,  1.68it/s]
100%|██████████| 1/1 [00:01<00:00,  1.63s/it]

Loaded 1 scenes
Processing date: 2013-06-19T17:10:19Z
Output directory: ./Data/Dataset/Cities/Abilene_TX/2013-06-19T17:10:19Z
Directory exists: True
INFO: File already exists: ./Data/Dataset/DEM_2014/Abilene_TX/DEM.tif





INFO: Saved: ./Data/Dataset/Cities/Abilene_TX/2013-06-19T17:10:19Z/LST.tif
INFO: Saved: ./Data/Dataset/Cities/Abilene_TX/2013-06-19T17:10:19Z/red.tif
INFO: Saved: ./Data/Dataset/Cities/Abilene_TX/2013-06-19T17:10:19Z/green.tif
INFO: Saved: ./Data/Dataset/Cities/Abilene_TX/2013-06-19T17:10:19Z/blue.tif
INFO: Saved: ./Data/Dataset/Cities/Abilene_TX/2013-06-19T17:10:19Z/ndvi.tif
INFO: Saved: ./Data/Dataset/Cities/Abilene_TX/2013-06-19T17:10:19Z/ndwi.tif
INFO: Saved: ./Data/Dataset/Cities/Abilene_TX/2013-06-19T17:10:19Z/ndbi.tif
INFO: Saved: ./Data/Dataset/Cities/Abilene_TX/2013-06-19T17:10:19Z/albedo.tif
INFO: Scene save complete!
INFO: 2025-07-07 02:16:18
INFO: Starting Abilene_TX 2013 month 7, retry -> 0
Evaluating cloud cover for 8 scenes...


Calculating cloud cover:  25%|██▌       | 2/8 [00:01<00:04,  1.28it/s]