In [None]:
import pandas as pd
import rioxarray as rxr
import rasterio
from rasterio.mask import mask
from pyproj import Transformer
from shapely.geometry import Point, mapping
import numpy as np
import torch
from tqdm import tqdm

def extract_buffered_satellite_data(tiff_path, csv_path, buffer_distance=50, use_gpu=True):
    """
    Compute the mean of every raster band inside a circular buffer around
    each point listed in a CSV file.

    Only pixels that fall inside the buffer and are valid in the raster
    contribute to a mean: cells outside the circle and nodata cells are
    excluded rather than being averaged in as a fill value.

    Parameters:
        tiff_path (str): Path to the GeoTIFF file.
        csv_path (str): Path to a CSV file with 'Latitude' and 'Longitude'
            columns, interpreted as EPSG:4326 coordinates.
        buffer_distance (float): Buffer radius, expressed in the units of the
            raster's CRS. This is metres only when the raster uses a projected
            metric CRS; a warning is issued for geographic rasters.
        use_gpu (bool): Run the per-point mean on the MPS device when it is
            available; otherwise the CPU is used.

    Returns:
        pd.DataFrame: One row per CSV row and one float32 column per band,
        named 'B01', 'B02', ... Rows whose buffer could not be extracted
        (for example, points outside the raster) are left as NaN.
    """
    import warnings

    # 1. Select the compute device (Apple MPS when requested and available).
    device = torch.device("mps" if torch.backends.mps.is_available() and use_gpu else "cpu")
    print(f"Using device: {device}")

    # 2. Read the point coordinates. They stay in float64 on the CPU:
    #    float32 cannot hold degrees with sub-metre precision, and pyproj
    #    works on NumPy arrays anyway.
    df = pd.read_csv(csv_path)
    longitudes = df['Longitude'].to_numpy(dtype=np.float64)
    latitudes = df['Latitude'].to_numpy(dtype=np.float64)
    num_points = len(df)

    # 3. Open the raster once; the same handle supplies the CRS, the band
    #    count and the pixel data, and is closed when the block exits.
    with rasterio.open(tiff_path) as src:
        if src.crs is not None and src.crs.is_geographic:
            warnings.warn(
                "The raster CRS is geographic, so buffer_distance is interpreted "
                "in degrees rather than metres.",
                stacklevel=2,
            )

        # always_xy=True fixes the axis order to (longitude, latitude).
        transformer = Transformer.from_crs("EPSG:4326", src.crs, always_xy=True)
        xs, ys = transformer.transform(longitudes, latitudes)

        num_bands = src.count
        band_names = [f'B{band + 1:02d}' for band in range(num_bands)]
        # Pre-filled with NaN so skipped points need no further handling.
        means = np.full((num_points, num_bands), np.nan, dtype=np.float32)

        for idx, (x, y) in tqdm(enumerate(zip(xs, ys)),
                                total=num_points, desc="Extracting values"):
            try:
                # Circular buffer around the projected point.
                geojson_geom = [mapping(Point(x, y).buffer(buffer_distance))]
                # filled=False returns a masked array, which keeps track of
                # which cells lie outside the circle or are nodata.
                masked, _ = mask(src, geojson_geom, crop=True, filled=False)
            except Exception as e:
                # Deliberate best-effort: a point that cannot be extracted
                # (e.g. it does not overlap the raster) is reported and skipped.
                print(f"Skipping point ({y}, {x}) due to error: {e}")
                continue

            # Masked cells become NaN so the NaN-aware mean ignores them.
            pixels = masked.astype(np.float32).filled(np.nan).reshape(num_bands, -1)
            band_means = torch.nanmean(torch.from_numpy(pixels).to(device), dim=1)
            means[idx] = band_means.cpu().numpy()

    return pd.DataFrame(means, columns=band_names)

# Run the buffered extraction for the training points against the GeoTIFF,
# using a buffer distance of 50 and requesting GPU acceleration.
final_data = extract_buffered_satellite_data(
    tiff_path='/Users/ibolam/Projects/Abroad/01. UMD/25SPRING_DATA605/InfoChallege/UMD_IC25_Participant_Package/S2_alloutput_IC25.tiff',
    csv_path='/Users/ibolam/Projects/Abroad/01. UMD/25SPRING_DATA605/InfoChallege/UMD_IC25_Participant_Package/Training_Data_IC25.csv',
    buffer_distance=50,
    use_gpu=True,
)