Package imports

In [None]:
import gc
import math
import os
import time
from datetime import datetime, timedelta

import netCDF4 as nc
import numpy as np
import pandas as pd
from geopy import Point
from geopy.distance import great_circle
from scipy.spatial import cKDTree
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split

Function to pre-process spatial data (build a KD-tree over the reanalysis grid and decode its time axis)

In [None]:
# Build the spatial index over the reanalysis grid and decode its time axis
def precompute_kdtree_and_time_diffs(uwnd_nc_file_path):
    """Load the grid and time axis of a NetCDF file and build a KD-tree over the grid.

    Despite the function name, no time differences are computed here: the time
    axis is only decoded to ``datetime64[ns]`` so callers can compare against it.

    Args:
        uwnd_nc_file_path: Path to a NetCDF file containing ``valid_time``,
            ``latitude`` and ``longitude`` variables.

    Returns:
        Tuple ``(tree, valid_time_dt, latitudes, longitudes, lat_lon_pairs)``:
        ``lat_lon_pairs`` is an ``(n_lat * n_lon, 2)`` array of every grid point in
        latitude-major order (all longitudes of the first latitude, then the next
        latitude, ...); ``tree`` is a ``cKDTree`` built on those pairs;
        ``latitudes``/``longitudes`` are the coordinate arrays as read from the file.

    Raises:
        Any exception raised while reading or indexing is printed and re-raised.
    """
    try:
        print("Precomputing KDTree and time differences...")
        # The context manager releases the file handle even if a read fails.
        # Slicing a netCDF4 variable with [:] copies its values into memory,
        # so the arrays stay usable after the dataset is closed.
        with nc.Dataset(uwnd_nc_file_path) as ds:
            valid_time = ds.variables['valid_time'][:]  # Assuming 'valid_time' is the variable name for time
            latitudes = ds.variables['latitude'][:]
            longitudes = ds.variables['longitude'][:]

        # valid_time is interpreted as whole seconds since 1970-01-01 (fractions
        # truncated). NOTE(review): confirm against the variable's `units` attribute.
        valid_time_dt = (
            np.asarray(valid_time, dtype='int64')
            .astype('datetime64[s]')
            .astype('datetime64[ns]')
        )

        # Every (lat, lon) grid point, latitude-major: the same order and values as
        # [(lat, lon) for lat in latitudes for lon in longitudes], but vectorised.
        lat_grid, lon_grid = np.meshgrid(np.asarray(latitudes), np.asarray(longitudes), indexing='ij')
        lat_lon_pairs = np.column_stack((lat_grid.ravel(), lon_grid.ravel()))

        # NOTE(review): distances are plain Euclidean in degree space; longitude
        # wrap-around (e.g. 359.9 vs 0.0) and the -180..180 vs 0..360 convention
        # are not handled here.
        tree = cKDTree(lat_lon_pairs)

        print("KDTree and time differences precomputed successfully.")
        return tree, valid_time_dt, latitudes, longitudes, lat_lon_pairs
    except Exception as e:
        print(f"Error precomputing KDTree and time differences: {e}")
        raise

# ERA5 reanalysis wind files for 2023 (u and v components are stored separately).
uwnd_nc_file_path = '../data/raw/reanalyses/ERA5/era5_uwnd_2023.nc'
vwnd_nc_file_path = '../data/raw/reanalyses/ERA5/era5_vwnd_2023.nc'

# The grid/time axis is taken from the u-wind file. precompute_kdtree_and_time_diffs
# already prints and re-raises its own errors, so no extra handler is needed here
# (a second one would only log the same message twice).
tree, valid_time_dt, latitudes, longitudes, lat_lon_pairs = precompute_kdtree_and_time_diffs(uwnd_nc_file_path)


Function to extract wind components at a given latitude/longitude (this cell also preloads the reanalysis NetCDF files)

In [None]:
# ERA5 wind files for 2023; the u and v components live in separate NetCDF files.
uwnd_nc_file_path = '../data/raw/reanalyses/ERA5/era5_uwnd_2023.nc'
vwnd_nc_file_path = '../data/raw/reanalyses/ERA5/era5_vwnd_2023.nc'

# Open both datasets once (u first, then v); the arrays read below are the
# module-level globals that extract_wind_components indexes into.
uwnd_ds, vwnd_ds = nc.Dataset(uwnd_nc_file_path), nc.Dataset(vwnd_nc_file_path)

# Keep only index 0 of axis 1, presumably the single pressure level (confirm
# against the files), so each array is indexed as [time, lat, lon] downstream.
# NOTE(review): the v file is assumed to share the u file's time/lat/lon axes.
uwnd_array = uwnd_ds.variables['u'][:, 0, :, :]
vwnd_array = vwnd_ds.variables['v'][:, 0, :, :]

# Look up the reanalysis wind nearest to a position and time
def extract_wind_components(lat, lon, dt, tree, valid_time_dt, latitudes, longitudes, lat_lon_pairs):
    """Return the (u, v) wind at the grid cell and time step nearest to ``(lat, lon, dt)``.

    Reads the module-level ``uwnd_array`` / ``vwnd_array`` (indexed [time, lat, lon]).

    Args:
        lat, lon: Query position, in the same coordinate convention as the grid.
        dt: Query time; anything ``np.datetime64`` accepts.
        tree: KD-tree built over ``lat_lon_pairs``.
        valid_time_dt: ``datetime64`` array of the dataset's time axis.
        latitudes, longitudes: 1-D coordinate arrays of the grid.
        lat_lon_pairs: Array of (lat, lon) grid points that ``tree`` indexes.

    Returns:
        ``(u_wind, v_wind)``, each rounded to 4 decimal places.

    Raises:
        ValueError: if the nearest time index is not a valid index into ``uwnd_array``.
        Any other exception is printed and re-raised.
    """
    try:
        target_time = np.datetime64(dt)

        # Nearest time step (clamps to the first/last step if dt lies outside the axis).
        t_idx = np.argmin(np.abs(valid_time_dt - target_time))
        if not 0 <= t_idx < uwnd_array.shape[0]:
            raise ValueError("The given datetime is out of bounds for the NetCDF data")

        u_field = uwnd_array[t_idx, :, :]
        v_field = vwnd_array[t_idx, :, :]

        # Nearest grid point, then recover its row/column by matching coordinate values.
        _, pair_idx = tree.query((lat, lon))
        grid_lat, grid_lon = lat_lon_pairs[pair_idx]
        row = np.where(latitudes == grid_lat)[0][0]
        col = np.where(longitudes == grid_lon)[0][0]

        return round(u_field[row, col], 4), round(v_field[row, col], 4)
    except Exception as e:
        print(f"Error extracting wind components: {e}")
        raise

Model training and validation

In [None]:
# Read the pre-processed buoy observations from CSV
processed_buoy_csv = '../processed_buoy_data.csv'
buoy_data = pd.read_csv(processed_buoy_csv)

def calculate_new_position(current_position, displacement, heading):
    """Return the point reached by travelling along a great circle on a spherical Earth.

    Args:
        current_position: ``(latitude, longitude)`` of the start point, in degrees.
        displacement: Distance travelled, in metres (Earth radius taken as 6 371 000 m).
        heading: Initial bearing in degrees; 0 moves due north, 90 due east.

    Returns:
        ``(latitude, longitude)`` of the destination, in degrees. Longitude is not
        wrapped into any particular range.
    """
    earth_radius_m = 6371000

    phi1 = math.radians(current_position[0])
    lambda1 = math.radians(current_position[1])
    theta = math.radians(heading)

    # Angular distance travelled, in radians.
    delta = displacement / earth_radius_m
    sin_delta, cos_delta = math.sin(delta), math.cos(delta)
    sin_phi1, cos_phi1 = math.sin(phi1), math.cos(phi1)

    phi2 = math.asin(sin_phi1 * cos_delta + cos_phi1 * sin_delta * math.cos(theta))
    lambda2 = lambda1 + math.atan2(
        math.sin(theta) * sin_delta * cos_phi1,
        cos_delta - sin_phi1 * math.sin(phi2),
    )

    return math.degrees(phi2), math.degrees(lambda2)

# Drop unused columns
print("Dropping unused columns...")
columns_to_keep = ['Latitude', 'Longitude', 'BuoyID', 'datetime', 'era5_uwnd', 'era5_vwnd', 'displacement', 'heading']
buoy_data = buoy_data[columns_to_keep].copy()
buoy_data['datetime'] = pd.to_datetime(buoy_data['datetime'])
print("Datetime column converted.")

# Split by BuoyID so whole buoy tracks (not individual rows) are held out for validation.
n_val_buoys = 5       # number of buoys held out for validation
random_seed = None    # set to an int for a reproducible split and model
print("Splitting data by BuoyID...")
buoy_ids = buoy_data['BuoyID'].unique()
if len(buoy_ids) <= n_val_buoys:
    raise ValueError(
        f"Need more than {n_val_buoys} buoys to hold out {n_val_buoys} for validation; "
        f"found {len(buoy_ids)}."
    )
rng = np.random.default_rng(random_seed)
train_ids = rng.choice(buoy_ids, size=len(buoy_ids) - n_val_buoys, replace=False)
val_ids = np.setdiff1d(buoy_ids, train_ids)

train_data = buoy_data[buoy_data['BuoyID'].isin(train_ids)].copy()
# iterative_prediction walks rows in order and treats each next row as the next
# time step, so keep every validation track in chronological order.
val_data = buoy_data[buoy_data['BuoyID'].isin(val_ids)].sort_values(['BuoyID', 'datetime']).copy()

# Release only the full frame; train_data and val_data are still needed below.
del buoy_data
gc.collect()
print("Memory cleaned up after subsetting data.")

# Set up training data: features are position plus wind, targets are the step's motion.
# NOTE(review): heading is an angle in degrees (calculate_new_position converts it with
# math.radians); averaging it across trees ignores wrap-around at 0/360 — consider
# predicting sin/cos components instead.
X_train = train_data[['Latitude', 'Longitude', 'era5_uwnd', 'era5_vwnd']]
y_train = train_data[['displacement', 'heading']]

# Train the model
print("Training model...")
model = RandomForestRegressor(random_state=random_seed)
model.fit(X_train, y_train)
print("Model training complete.")

# Per-buoy prediction files are written into this directory; each file name is
# chosen per buoy in the prediction loop further down.
predictions_dir = '../data/predictions'
if os.path.exists(predictions_dir):
    print(f"Directory {predictions_dir} already exists.")
else:
    os.makedirs(predictions_dir, exist_ok=True)
    print(f"Directory {predictions_dir} created.")

def iterative_prediction(val_data, model, tree, valid_times, latitudes, longitudes, lat_lon_pairs, output_file_path,
                         max_duration=4 * 60 * 60, max_file_size=2 * 1024 * 1024 * 1024):
    """Roll the drift model forward from a track's first observation and append predictions to a CSV.

    The first row of ``val_data`` seeds the state: position and ERA5 wind. At each
    subsequent row, the model predicts ``(displacement, heading)`` from the current
    state. ``calculate_new_position`` moves the position accordingly, and the wind is
    re-sampled at the predicted position and that row's datetime. Observed positions
    after the first row are not used; only their timestamps are.

    Args:
        val_data: Rows in chronological order with columns Latitude, Longitude,
            era5_uwnd, era5_vwnd, datetime and BuoyID. Presumably a single buoy:
            only the first row's BuoyID is written.
        model: Fitted regressor whose ``predict`` returns ``[[displacement, heading]]``.
        tree, valid_times, latitudes, longitudes, lat_lon_pairs: Passed through to
            ``extract_wind_components``.
        output_file_path: CSV to append to; a header is written if the file is new.
        max_duration: Wall-clock limit in seconds (default 4 hours).
        max_file_size: Stop once the output file exceeds this many bytes (default 2 GiB).
    """
    if val_data.empty:
        print(f"No rows to predict for {output_file_path}; skipping.")
        return

    start_time = time.time()

    # Check if the output file exists and write the header only if it doesn't
    if not os.path.exists(output_file_path):
        with open(output_file_path, 'w') as file:
            file.write("Predicted_Latitude,Predicted_Longitude,Datetime,BuoyID\n")

    with open(output_file_path, 'a') as file:
        # Seed the rollout with the observed state of the first row.
        first_row = val_data.iloc[0]
        current_lat, current_lon = first_row[['Latitude', 'Longitude']]
        current_uwnd, current_vwnd = first_row[['era5_uwnd', 'era5_vwnd']]
        buoy_id = first_row['BuoyID']

        print("\nInitial conditions:")
        print(f"Latitude: {current_lat:.2f}, Longitude: {current_lon:.2f}, Datetime: {first_row['datetime']}, BuoyID: {buoy_id}")

        for i in range(1, len(val_data)):
            # Check if the maximum duration has been exceeded
            elapsed_time = time.time() - start_time
            if elapsed_time > max_duration:
                print("Maximum duration exceeded. Stopping the script.")
                break

            next_row = val_data.iloc[i]

            # Column names must match the features the model was trained on.
            input_data = pd.DataFrame({
                'Latitude': [current_lat],
                'Longitude': [current_lon],
                'era5_uwnd': [current_uwnd],
                'era5_vwnd': [current_vwnd]
            })

            prediction_start_time = time.time()
            predicted_displacement, predicted_heading = model.predict(input_data)[0]
            predicted_lat, predicted_lon = calculate_new_position(
                (current_lat, current_lon),
                predicted_displacement,
                predicted_heading
            )
            prediction_end_time = time.time()
            print(f"Prediction step {i} took {prediction_end_time - prediction_start_time:.2f} seconds.")

            # Wind at the *predicted* position, at the next observation's timestamp.
            wind_extraction_start_time = time.time()
            predicted_wind_u, predicted_wind_v = extract_wind_components(
                predicted_lat,
                predicted_lon,
                next_row['datetime'],
                tree,
                valid_times,
                latitudes,
                longitudes,
                lat_lon_pairs
            )
            wind_extraction_end_time = time.time()
            print(f"Wind extraction step {i} took {wind_extraction_end_time - wind_extraction_start_time:.2f} seconds.")

            print(f"\nPrediction for row {i}:")
            print(f"Predicted Latitude: {predicted_lat:.2f}, Predicted Longitude: {predicted_lon:.2f}")
            print(f"Predicted Displacement: {predicted_displacement:.2f}, Predicted Heading: {predicted_heading:.2f}")
            print(f"Updated Wind U: {predicted_wind_u:.2f}, Wind V: {predicted_wind_v:.2f}")
            print(f"Target Datetime: {next_row['datetime']}, BuoyID: {buoy_id}")

            # Write results to the file
            file.write(f"{round(predicted_lat, 3)},{round(predicted_lon, 3)},{next_row['datetime']},{buoy_id}\n")

            # Stop once the output file grows past the size limit
            file_size = file.tell()
            if file_size > max_file_size:
                print("File size exceeded 2 GB. Stopping the script.")
                break

            # Update current state for next iteration
            current_lat, current_lon = predicted_lat, predicted_lon
            current_uwnd, current_vwnd = predicted_wind_u, predicted_wind_v

            # Collect periodically rather than every step: a full collection per
            # step costs time without bounding memory any better.
            if i % 10 == 0:
                gc.collect()
                print(f"Processed {i} predictions...")

    print(f"\nPrediction complete. Results saved to {output_file_path}")


# Roll the model forward separately for each held-out buoy, writing one CSV per buoy.
unique_buoy_ids = val_data['BuoyID'].unique()

for buoy_id in unique_buoy_ids:
    output_file_path = f'../data/predictions/predictions_{buoy_id}.csv'
    buoy_data = val_data[val_data['BuoyID'].eq(buoy_id)]

    print("\nStarting iterative predictions on validation subset...")
    iterative_prediction(
        buoy_data,
        model,
        tree,
        valid_time_dt,
        latitudes,
        longitudes,
        lat_lon_pairs,
        output_file_path,
    )