In [1]:
import numpy as np
import tensorflow as tf

from sklearn.metrics import mean_squared_error
from keras.models import Sequential
from keras.layers import Dense, Dropout, Input
from keras.optimizers import Adam

class NeuralNetwork:
    """Keras MLP regressor: ReLU ``Dense`` + ``Dropout`` hidden layers and a linear output.

    Wraps a ``keras`` ``Sequential`` model and runs fit/evaluate/predict inside
    ``tf.device(self.device)``.

    Args:
        input_shape: number of input features; the model input is ``(input_shape,)``.
        layers: hidden-layer widths, one ``Dense(relu)`` + ``Dropout`` pair per entry.
        dropout_rate: rate for every ``Dropout`` layer.
        learning_rate: Adam learning rate used by ``compile_model``.
        device: device name passed to ``tf.device`` (e.g. ``'/GPU:0'`` or ``'/CPU:0'``).
    """

    def __init__(self, input_shape, layers, dropout_rate, learning_rate, device):
        self.input_shape = input_shape
        self.layers = layers
        self.dropout_rate = dropout_rate
        self.learning_rate = learning_rate
        self.device = device
        self.model = self.build_model()

    def build_model(self):
        """Build and return a new, uncompiled model from the stored hyperparameters."""
        model = Sequential()
        model.add(Input(shape=(self.input_shape,)))
        for layer_size in self.layers:
            model.add(Dense(layer_size, activation='relu'))
            model.add(Dropout(self.dropout_rate))
        model.add(Dense(1, activation='linear'))  # single linear unit for regression
        return model

    def compile_model(self):
        """Compile ``self.model`` with Adam(``self.learning_rate``), MSE loss and an MAE metric."""
        optimizer = Adam(learning_rate=self.learning_rate)
        self.model.compile(optimizer=optimizer, loss='mean_squared_error', metrics=['mae'])

    def train_model(self, X_train, y_train, epochs=20, batch_size=32, validation_split=0.2):
        """Fit the model on ``self.device`` and return the Keras ``History`` object."""
        with tf.device(self.device):
            history = self.model.fit(
                X_train, y_train,
                epochs=epochs, batch_size=batch_size, validation_split=validation_split,
            )
        return history

    def evaluate_model(self, X_test, y_test):
        """Return ``(loss, mae)`` on the given data.

        The two-value unpacking assumes the model was compiled with exactly one
        metric, as ``compile_model`` does; a differently compiled model would break it.
        """
        with tf.device(self.device):
            loss, mae = self.model.evaluate(X_test, y_test)
        return loss, mae

    def predict(self, X_test):
        """Return model predictions flattened to a 1-D array."""
        with tf.device(self.device):
            predictions = self.model.predict(X_test)
        return predictions.flatten()

    def calculate_rmse(self, y_test, predictions):
        """Root-mean-squared error between ``y_test`` and ``predictions``."""
        return np.sqrt(mean_squared_error(y_test, predictions))

    def save_model(self, filename):
        """Save the underlying Keras model to ``filename``."""
        self.model.save(filename)

    @classmethod
    def load_model(cls, filename, input_shape, device):
        """Load a saved Keras model from ``filename`` and wrap it in a ``NeuralNetwork``.

        The file's own architecture is used as-is. The constructor hyperparameters
        are not recovered, so ``layers``, ``dropout_rate`` and ``learning_rate`` are
        placeholders (``[]``, ``0``, ``0``). Set a real ``learning_rate`` before
        calling ``compile_model`` on the result: with 0, Adam would not update weights.
        """
        loaded_model = tf.keras.models.load_model(filename)
        # Bypass __init__: it would build a throwaway model only to replace it.
        nn = cls.__new__(cls)
        nn.input_shape = input_shape
        nn.layers = []           # placeholder; the real architecture lives in loaded_model
        nn.dropout_rate = 0      # placeholder
        nn.learning_rate = 0     # placeholder; see docstring before recompiling
        nn.device = device
        nn.model = loaded_model
        return nn

def objective(trial):
    """Optuna objective: train one sampled NeuralNetwork config and return its test RMSE.

    Search space:
    - 1-3 hidden layers of 64-512 units each.
    - Dropout in [0.2, 0.5].
    - A learning rate in [1e-5, 1e-2], sampled on a log scale. Uniform sampling
      would put ~90 % of trials above 1e-3.

    NOTE(review): relies on notebook globals not defined in the visible cells:
    ``X_train``, ``y_train``, ``X_test`` and ``y_test_exp``. It also uses ``device``,
    which must be a TensorFlow device string because it is passed to ``tf.device``.

    Predictions are inverted with ``np.expm1``, matching the ``np.exp(pred) - 1``
    inverse used when predicting on the test set. This assumes the target was
    ``log1p``-transformed and ``y_test_exp`` is on the original scale. TODO confirm.
    """
    layers = []
    for i in range(trial.suggest_int('n_layers', 1, 3)):
        layers.append(trial.suggest_int(f'n_units_l{i}', 64, 512))

    dropout_rate = trial.suggest_float('dropout_rate', 0.2, 0.5)
    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True)

    nn = NeuralNetwork(input_shape=X_train.shape[1], device=device, layers=layers, dropout_rate=dropout_rate, learning_rate=learning_rate)
    nn.compile_model()

    nn.train_model(X_train, y_train, epochs=20, batch_size=32, validation_split=0.2)

    # Back to the original target scale (inverse of log1p).
    predictions = np.expm1(nn.predict(X_test))
    return nn.calculate_rmse(y_test_exp, predictions)

2024-08-08 15:56:35.926940: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-08 15:56:35.944984: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-08 15:56:35.950542: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-08-08 15:56:35.963124: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


#### Load the test-set GeoJSON and convert it into a grid of points

In [2]:
import geopandas as gpd

# Calendar year of imagery to use for the test set.
year = 2022

# Test-area boundary (the output below shows a single polygon row).
geojson_path = 'test_data/challenge_1_bb.geojson'
gdf = gpd.read_file(geojson_path)

gdf

Unnamed: 0,geometry
0,"POLYGON ((-106.08092 35.78627, -106.08092 35.4..."


In [3]:
import pyproj

def get_utm_zone(longitude):
    """Return the standard 6-degree UTM zone number (1-60) containing ``longitude``.

    Zone 1 starts at 180 degrees W. A longitude of exactly 180 maps to zone 1
    (the same meridian as -180) rather than to a non-existent zone 61. The
    Norway/Svalbard zone exceptions are not applied.

    Args:
        longitude: longitude in degrees, expected in [-180, 180].
    """
    return int((longitude + 180) / 6) % 60 + 1

# Bounds in lon/lat. Reproject a copy to EPSG:4326 first so this cell is
# idempotent: `gdf` is overwritten with UTM coordinates below, and re-running
# the cell would otherwise feed metres into get_utm_zone.
minx, miny, maxx, maxy = gdf.to_crs(epsg=4326).total_bounds

# Choose the zone from the centre of the area, not its west edge, so areas
# straddling a zone border use the zone covering most of them.
center_lon = (minx + maxx) / 2
center_lat = (miny + maxy) / 2
utm_zone = get_utm_zone(center_lon)

# WGS 84 / UTM EPSG codes are 326NN (northern hemisphere) and 327NN (southern).
epsg_code = (32700 if center_lat < 0 else 32600) + utm_zone
utm_crs = pyproj.CRS.from_epsg(epsg_code)
proj = pyproj.Proj(utm_crs)  # kept for compatibility; not used in the cells shown

# Reproject the GeoDataFrame to the chosen UTM CRS (units: metres)
gdf = gdf.to_crs(epsg=epsg_code)
gdf

Unnamed: 0,geometry
0,"POLYGON ((402315.263 3960781.699, 401878.759 3..."


In [4]:
import numpy as np

# Grid spacing in metres. 2560 m equals the width of one 256 px x 10 m
# Sentinel-2 chip used for the embeddings further down.
GRID_SPACING_M = 2560

# Grid of points GRID_SPACING_M apart, anchored at the lower-left corner of the
# test area's UTM bounding box. np.arange excludes the upper bounds, so the
# top/right edges themselves are not sampled.
utm_bounds = gdf.total_bounds
x = np.arange(utm_bounds[0], utm_bounds[2], GRID_SPACING_M)
y = np.arange(utm_bounds[1], utm_bounds[3], GRID_SPACING_M)
xx, yy = np.meshgrid(x, y)
points = np.vstack([xx.ravel(), yy.ravel()]).T

grid = gpd.GeoDataFrame(geometry=gpd.points_from_xy(points[:, 0], points[:, 1], crs=gdf.crs))
grid

Unnamed: 0,geometry
0,POINT (401878.759 3920624.607)
1,POINT (404438.759 3920624.607)
2,POINT (406998.759 3920624.607)
3,POINT (409558.759 3920624.607)
4,POINT (412118.759 3920624.607)
...,...
331,POINT (442838.759 3959024.607)
332,POINT (445398.759 3959024.607)
333,POINT (447958.759 3959024.607)
334,POINT (450518.759 3959024.607)


#### Join biomass (AGB) data onto the test grid

In [5]:
import ee
import geopandas as gpd
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import os
from shapely.geometry import box, mapping
from shapely.ops import transform
import pyproj

# Start the Earth Engine client session before any ee.* calls below.
ee.Initialize()

# NASA/ORNL biomass carbon density collection, mosaicked into one ee.Image so
# that get_agb can call reduceRegion on it directly.
BIOMASS_COLLECTION_ID = "NASA/ORNL/biomass_carbon_density/v1"
biomass = ee.ImageCollection(BIOMASS_COLLECTION_ID).mosaic()

def create_bbox_around_point(point, size=2560, utm_epsg=None):
    """Return a lon/lat (EPSG:4326) box centred on a projected point.

    NOTE: the returned box is ``2 * size`` metres wide per side, not ``size``.
    A ``size``-wide box is built first and then padded by ``size / 2`` on every
    side. This keeps the function's historical output; presumably training
    labels were produced the same way (TODO confirm). The Sentinel-2 chips
    embedded later in this notebook are 256 px x 10 m = 2560 m wide.

    Args:
        point: shapely Point whose ``x``/``y`` are in the CRS given by ``utm_epsg``.
        size: base box width in metres (see the note for the actual extent).
        utm_epsg: EPSG code of ``point``'s CRS. Defaults to the notebook-level
            ``epsg_code`` chosen when the test area was reprojected.

    Returns:
        shapely Polygon in EPSG:4326 with (lon, lat) axis order.
    """
    if utm_epsg is None:
        utm_epsg = epsg_code
    half_size = size / 2.0

    wgs84 = pyproj.CRS('EPSG:4326')
    utm = pyproj.CRS(utm_epsg)
    # always_xy=True keeps (x, y) = (lon, lat) order on output.
    project_to_wgs84 = pyproj.Transformer.from_crs(utm, wgs84, always_xy=True).transform

    # Base box of width `size` centred on the point...
    bbox = box(point.x - half_size, point.y - half_size, point.x + half_size, point.y + half_size)
    # ...padded by another half_size on every side (total width 2 * size).
    expanded_bbox = box(
        bbox.bounds[0] - half_size, bbox.bounds[1] - half_size,
        bbox.bounds[2] + half_size, bbox.bounds[3] + half_size,
    )

    return transform(project_to_wgs84, expanded_bbox)

def get_agb(geometry):
    """Return the mean ``agb`` band value of ``biomass`` over a geometry via Earth Engine.

    Points are first expanded with ``create_bbox_around_point``, which also
    reprojects them to lon/lat. Any other geometry is handed to Earth Engine
    unchanged, so it must already be in lon/lat. NOTE(review): a non-point
    geometry taken straight from the UTM ``grid`` would be misinterpreted.

    Returns:
        Whatever ``getInfo()`` yields for the ``agb`` entry. This is presumably
        a float, or None when the region has no valid pixels (TODO confirm).
    """
    if geometry.geom_type == 'Point':
        region = create_bbox_around_point(geometry)
    else:
        region = geometry

    stats = biomass.reduceRegion(
        reducer=ee.Reducer.mean(),
        geometry=ee.Geometry(mapping(region)),
        scale=300,  # nominal scale in metres for the reduction
        maxPixels=1e9,
    )
    # `stats.get('agb')` is a server-side ee object, never Python None, so no
    # client-side None check is needed before fetching it.
    return stats.get('agb').getInfo()

def process_geometries(combined_gdf, max_threads=10, fetch=None):
    """Compute a value (mean AGB by default) for every geometry, using a thread pool.

    Args:
        combined_gdf: (Geo)DataFrame with a ``geometry`` column.
        max_threads: upper bound on worker threads. Adjust for your system and
            Earth Engine quota.
        fetch: callable applied to each geometry; defaults to ``get_agb``.

    Returns:
        List of results in the frame's positional row order, so it can be
        assigned straight back as a column. A row whose call raised is reported
        and yields None.
    """
    if fetch is None:
        fetch = get_agb

    geometries = list(combined_gdf.geometry)
    total_rows = len(geometries)
    results = [None] * total_rows
    # ThreadPoolExecutor rejects max_workers=0, so handle empty input up front.
    if total_rows == 0:
        return results

    num_threads = min(total_rows, max_threads)
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        # Key futures by row *position* (not index label) so results line up with
        # the row order even when the index is unsorted or non-numeric.
        future_to_pos = {executor.submit(fetch, geom): pos for pos, geom in enumerate(geometries)}

        with tqdm(total=total_rows, desc="Processing geometries") as pbar:
            for future in as_completed(future_to_pos):
                pos = future_to_pos[future]
                try:
                    results[pos] = future.result()
                except Exception as exc:
                    print(f'Generated an exception: {exc}')
                pbar.update(1)

    return results

# Fetch mean AGB for every grid point and attach it as a column.
# Runs unconditionally: in a notebook kernel an `if __name__ == '__main__'`
# guard is always true and adds nothing.
results = process_geometries(grid)
grid['mean_agb'] = results

Processing geometries: 100%|██████████| 336/336 [00:04<00:00, 70.47it/s]


#### Create embeddings for the test set

In [6]:
import pandas as pd
import geopandas as gpd
from shapely.geometry import Point
import pystac_client
import stackstac
import torch
from torchvision import transforms as v2
from box import Box
import yaml
import math
from rasterio.enums import Resampling
from tqdm import tqdm
import rasterio
import warnings
import os
import numpy as np
import rioxarray  # Make sure to import rioxarray to extend xarray

from src.model import ClayMAEModule

# NOTE(review): this hides *every* warning for the rest of the kernel session,
# deprecations included; consider narrowing it to specific categories/modules.
warnings.filterwarnings("ignore")

STAC_API = "https://earth-search.aws.element84.com/v1"
COLLECTION = "sentinel-2-l2a"
METADATA_PATH = "configs/metadata.yaml"

# Torch device for the Clay encoder (GPU when available).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ckpt = "https://clay-model-ckpt.s3.amazonaws.com/v0.5.7/mae_v0.5.7_epoch-13_val-loss-0.3098.ckpt"
# New tensors are allocated on `device` unless placed explicitly.
torch.set_default_device(device)

torch.cuda.empty_cache()  # Clear GPU cache

# Grid points as (lon, lat) tuples in EPSG:4326, in grid row order.
points = [(geom.x, geom.y) for geom in grid.to_crs("EPSG:4326").geometry]

model = ClayMAEModule.load_from_checkpoint(
    ckpt, metadata_path=METADATA_PATH, shuffle=False, mask_ratio=0
)
model.eval()
model = model.to(device)

# Per-platform band statistics/wavelengths; `with` closes the file handle.
with open(METADATA_PATH) as metadata_file:
    metadata = Box(yaml.safe_load(metadata_file))

# Function to normalize timestamp
def normalize_timestamp(date):
    """Encode a datetime's ISO week and hour as (sin, cos) pairs on the unit circle.

    Returns:
        ``((sin w, cos w), (sin h, cos h))`` with ``w = iso_week * 2*pi / 52`` and
        ``h = hour * 2*pi / 24``.
    """
    iso_week = date.isocalendar()[1]
    week_angle = iso_week * 2 * np.pi / 52
    hour_angle = date.hour * 2 * np.pi / 24
    week_encoding = (math.sin(week_angle), math.cos(week_angle))
    hour_encoding = (math.sin(hour_angle), math.cos(hour_angle))
    return week_encoding, hour_encoding

# Function to normalize lat/lon
def normalize_latlon(lat, lon):
    """Map latitude/longitude in degrees to (sin, cos) pairs of their radian values.

    Returns:
        ``((sin lat, cos lat), (sin lon, cos lon))`` with angles in radians.
    """
    lat_rad, lon_rad = (deg * np.pi / 180 for deg in (lat, lon))
    lat_pair = (math.sin(lat_rad), math.cos(lat_rad))
    lon_pair = (math.sin(lon_rad), math.cos(lon_rad))
    return lat_pair, lon_pair

def to_device(data, device):
    """Recursively move tensors inside nested dicts/lists to ``device``.

    Dicts and lists are rebuilt with their contents converted; tensors are moved
    with ``.to(device)``; anything else (including tuples) is returned unchanged.
    """
    if isinstance(data, dict):
        return {key: to_device(value, device) for key, value in data.items()}
    if isinstance(data, list):
        return [to_device(item, device) for item in data]
    if isinstance(data, torch.Tensor):
        return data.to(device)
    return data

def _lowest_cloud_item(lon, lat, year):
    """Return the least-cloudy Sentinel-2 L2A item over (lon, lat) in ``year``, or None.

    The search is capped at ``max_items=10`` results with cloud cover < 80 %, so
    this is the least cloudy of those items, not necessarily of the whole year.
    """
    catalog = pystac_client.Client.open(STAC_API)
    search = catalog.search(
        collections=[COLLECTION],
        datetime=f"{year}-01-01/{year}-12-31",
        bbox=(lon - 1e-5, lat - 1e-5, lon + 1e-5, lat + 1e-5),
        max_items=10,
        query={"eo:cloud_cover": {"lt": 80}},
    )
    items = list(search.item_collection())
    if not items:
        return None
    # min() keeps the first of equally cloudy items, same as sorting and taking [0].
    return min(items, key=lambda item: item.properties.get('eo:cloud_cover', float('inf')))


def _write_stack_geotiff(stack, output_path, epsg, date):
    """Write every band of a single-time-step stack to a GeoTIFF tagged with ``date``.

    Assumes ``stack`` is laid out (time, band, y, x) with one time step, as the
    height/width indexing and per-band squeeze below require.
    """
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with rasterio.open(
        output_path, 'w',
        driver='GTiff',
        height=stack.shape[2],
        width=stack.shape[3],
        count=len(stack.band),  # Number of bands
        dtype=str(stack.dtype),
        crs=epsg,
        transform=stack.rio.transform(),
    ) as tif:
        for band_index, band in enumerate(stack.band, start=1):
            tif.write(np.squeeze(stack.sel(band=band).values), band_index)
        # Tags can be written on the open writer; no need to reopen in r+ mode.
        tif.update_tags(date=date)


def process_point(lon, lat, model, metadata, year, device, j):
    """Embed the least-cloudy Sentinel-2 chip around (lon, lat) with the Clay encoder.

    Steps:
    1. Search STAC for ``year`` and take the least-cloudy item.
    2. Stack a 256 x 256 px, 10 m chip (blue, green, red, nir) centred on the point.
    3. Save it as ``test_data/embeddings/challenge_1/stack_{lon}_{lat}_{j}.tif``.
    4. Normalise it and run ``model.model.encoder`` in batches. If CUDA runs out
       of memory, the batch falls back to the CPU.

    Args:
        lon, lat: point in degrees (EPSG:4326).
        model: loaded Clay module whose ``.model.encoder`` is called.
        metadata: Box with per-band ``mean``/``std``/``wavelength`` for the platform.
        year: calendar year to search.
        device: torch device to run the encoder on.
        j: index used only to make the output file name unique.

    Returns:
        ndarray with one row per time step, taken from position 0 of the encoder
        output (presumably the class token). Returns None when no scene is found.
    """
    model.to(device)  # Ensure the model is on the correct device

    item = _lowest_cloud_item(lon, lat, year)
    if item is None:
        return None

    epsg = item.properties["proj:epsg"]

    # Project the point into the scene's CRS so the chip can be defined in metres.
    poidf = gpd.GeoDataFrame(
        pd.DataFrame(),
        crs="EPSG:4326",
        geometry=[Point(lon, lat)],
    ).to_crs(epsg)
    center_x, center_y = poidf.iloc[0].geometry.coords[0]

    size = 256  # chip width in pixels
    gsd = 10    # metres per pixel
    half_extent = (size * gsd) // 2
    bounds = (
        center_x - half_extent,
        center_y - half_extent,
        center_x + half_extent,
        center_y + half_extent,
    )

    stack = stackstac.stack(
        item,
        bounds=bounds,
        snap_bounds=False,
        epsg=epsg,
        resolution=gsd,
        dtype="float32",
        rescale=False,
        fill_value=0,
        assets=["blue", "green", "red", "nir"],
        resampling=Resampling.nearest,
    ).compute()

    n_steps = len(stack)
    if n_steps == 0:
        return None

    # Full ISO date (YYYY-MM-DD) of the first time step, stored as a GeoTIFF tag.
    date = np.datetime_as_string(stack.time.values[0], unit='D')
    output_path = os.path.join("test_data/embeddings/challenge_1/", f"stack_{lon}_{lat}_{j}.tif")
    _write_stack_geotiff(stack, output_path, epsg, date)

    # Per-band normalisation statistics and wavelengths from the metadata.
    platform = "sentinel-2-l2a"
    band_names = [str(band.values) for band in stack.band]
    band_meta = metadata[platform].bands
    mean = [band_meta.mean[name] for name in band_names]
    std = [band_meta.std[name] for name in band_names]
    waves = [band_meta.wavelength[name] for name in band_names]
    normalize = v2.Compose([v2.Normalize(mean=mean, std=std)])

    # Cyclic time and location encodings, one row per time step (computed once,
    # then sliced per batch).
    datetimes = stack.time.values.astype("datetime64[s]").tolist()
    times = [normalize_timestamp(dat) for dat in datetimes]
    latlons = [normalize_latlon(lat, lon)] * len(times)
    time_features = np.hstack(([t[0] for t in times], [t[1] for t in times]))
    latlon_features = np.hstack(([ll[0] for ll in latlons], [ll[1] for ll in latlons]))

    pixels = normalize(torch.from_numpy(stack.data.astype(np.float32)).to(device))

    batch_size = 16
    num_batches = math.ceil(n_steps / batch_size)

    embeddings_list = []
    for i in range(num_batches):
        start_idx = i * batch_size
        end_idx = min(start_idx + batch_size, n_steps)

        batch_datacube = to_device({
            "platform": platform,
            "time": torch.tensor(time_features[start_idx:end_idx], dtype=torch.float32),
            "latlon": torch.tensor(latlon_features[start_idx:end_idx], dtype=torch.float32),
            "pixels": pixels[start_idx:end_idx],
            "gsd": torch.tensor(stack.gsd.values),
            "waves": torch.tensor(waves),
        }, device)

        try:
            # Re-place the model: an earlier OOM fallback may have left it on CPU.
            model = model.to(device)
            with torch.no_grad():
                unmsk_patch, _, _, _ = model.model.encoder(batch_datacube)
            embeddings_list.append(unmsk_patch[:, 0, :].cpu().numpy())
        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise
            print(f"GPU OOM for point ({lon}, {lat}), batch {i+1}/{num_batches}. Trying CPU...")
            cpu = torch.device("cpu")
            model = model.to(cpu)
            # Also make CPU the default for tensors the encoder allocates internally,
            # in case a global default device was set via torch.set_default_device.
            with torch.no_grad(), cpu:
                unmsk_patch, _, _, _ = model.model.encoder(to_device(batch_datacube, cpu))
            embeddings_list.append(unmsk_patch[:, 0, :].numpy())

    return np.concatenate(embeddings_list, axis=0)

# Specify the year for the datetime range in the search
year = 2022

# Store results in a list
results = []

# Iterate through the points and process each one
for i, point in enumerate(tqdm(points)):
    lon, lat = point
    embeddings = process_point(lon, lat, model, metadata, year, device, i)
    if embeddings is not None:
        results.append((lon, lat, embeddings, grid.loc[i, 'mean_agb']))

# Create a DataFrame from the results
df = pd.DataFrame(results, columns=["lon", "lat", "embeddings", "mean_agb"])

# Convert to a GeoDataFrame
gdf_results = gpd.GeoDataFrame(df, geometry=gpd.points_from_xy(df.lon, df.lat))

# Output the resulting GeoDataFrame
gdf_results.head()


  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(
100%|██████████| 336/336 [04:00<00:00,  1.40it/s]


Unnamed: 0,lon,lat,embeddings,mean_agb,geometry
0,-106.080869,35.424211,"[[0.040777754, -0.022916876, 0.07173511, 0.078...",0.522241,POINT (-106.08087 35.42421)
1,-106.052673,35.42446,"[[0.03793185, 0.009665264, 0.10715726, 0.04229...",0.453972,POINT (-106.05267 35.42446)
2,-106.024477,35.424703,"[[0.025163846, 0.013742316, 0.124567054, 0.047...",0.426844,POINT (-106.02448 35.42470)
3,-105.99628,35.424939,"[[0.0464423, 0.015486966, 0.08057992, 0.064075...",0.453645,POINT (-105.99628 35.42494)
4,-105.968084,35.425168,"[[0.0056447657, -0.0030673319, 0.074692115, 0....",0.480928,POINT (-105.96808 35.42517)


#### Run predictions on the test set

In [18]:
import joblib
import tensorflow as tf

# TensorFlow device string. It gets its own name instead of rebinding `device`,
# which holds the torch device used by the embedding step; re-running that step
# after this cell would otherwise receive a TF string.
tf_device = '/GPU:0' if tf.config.list_physical_devices('GPU') else '/CPU:0'

# Load the model
loaded_nn = NeuralNetwork.load_model('models/agb_regression_model.h5', input_shape=768, device=tf_device)

# Stack per-point embedding arrays into an (n_points, 768) matrix. Unlike
# np.squeeze, vstack stays 2-D even when there is only one point.
new_data = pd.DataFrame(np.vstack(gdf_results['embeddings'].tolist()))

# Standardize with the scaler saved at training time (joblib uses pickle:
# only load trusted files).
scaler = joblib.load('models/scaler.joblib')
new_data_scaled = scaler.transform(new_data)

new_predictions = loaded_nn.predict(new_data_scaled)

# Invert the target transform: expm1(x) == exp(x) - 1, more accurate near 0.
# Presumably the model was trained on log1p(agb); confirm against training code.
gdf_results['pred_agb'] = np.expm1(new_predictions)



[1m11/11[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 39ms/step


In [23]:
# Score only rows that have both a label and a prediction. A grid label is None
# when the biomass lookup raised, which would otherwise turn both metrics into NaN.
eval_df = gdf_results[['pred_agb', 'mean_agb']].dropna().astype(float)
residuals = eval_df['pred_agb'] - eval_df['mean_agb']
print("Test set RMSE:", np.sqrt(np.mean(residuals ** 2)))
print("Test set corr:", np.corrcoef(eval_df['pred_agb'], eval_df['mean_agb'])[0][1])

Test set RMSE: 10.0773895934398
Test set corr: 0.7283481857530248
