In [None]:

import numpy as np
import pandas as pd
from shapely.geometry import LineString, Point, shape
import geopandas as gpd
from geopy.distance import great_circle
import xarray as xr
from geographiclib.geodesic import Geodesic

# From the ipynb driver, user input: ###################################
# User Inputs:
start_time_str       = '2023-01-01T00:00:00Z'
stop_time_str        = '2023-12-31T23:59:59Z'
query_limit          = 15e4
send_notification    = True
make_plot            = True
output_dir           = "/scratch/omg28/Data/"

# Convert start and stop times to datetime objects
start_time_simple = pd.to_datetime(start_time_str).strftime("%Y-%m-%d")
stop_time_simple = pd.to_datetime(stop_time_str).strftime("%Y-%m-%d")
analysis_year = pd.to_datetime(start_time_str).year

# Define grid
lat_bins = np.arange(-90, 90.1, 0.5)
lon_bins = np.arange(-180, 180.1, 0.5)
alt_bins_ft = np.arange(0, 55001, 1000)
alt_bins_m = alt_bins_ft * 0.3048
nlat, nlon, nalt = len(lat_bins)-1, len(lon_bins)-1, len(alt_bins_m)-1
########################################################################

# MATH HELPER FUNCTIONS ##################################
def secd(theta):
    """Convert degrees to seconds."""
    if theta%360 == 90 or theta%360 == 270:
        return np.nan
    return 1/np.cos(np.radians(theta))
###########################################################


# Define conflict countries and buffer degrees

conflict_countries = ['Russia', 'Ukraine', 'Libya', 'Syria', 'Sudan', 'Yemen']
buffer_degrees = 1.0

# Load country boundaries
world = gpd.read_file('ne_110m_admin_0_countries.zip')
name_col = 'NAME' if 'NAME' in world.columns else 'name'

# Get geometries for conflict countries
conflict_geometries = []
for c in conflict_countries:
    country_geom = world[world[name_col] == c].geometry
    if not country_geom.empty:
        conflict_geometries.append(country_geom.union_all().buffer(buffer_degrees))

# Create a union of `all conflict areas
if conflict_geometries:
    conflict_areas = conflict_geometries[0]
    for geom in conflict_geometries[1:]:
        conflict_areas = conflict_areas.union(geom)
    conflict_areas_buffered = conflict_areas.buffer(buffer_degrees)
else:
    # If no valid countries found, create empty geometry
    conflict_areas_buffered = Point(0, 0).buffer(0)


#TEST CODE ##################################
flights = pd.read_pickle('/scratch/omg28/Data/2023-01-01_to_2023-01-31_labeled.pkl')
row = flights.iloc[244256] # long flight from EGLL to NZCH
era5_file = f"/scratch/omg28/Data/winddb/era5_wind_{analysis_year}.nc"
ds_era5 = xr.open_dataset(era5_file)
print(ds_era5)

cruise_alt_ft = 35000  # feet example value
cruise_speed_ms = 250  # m/s, example value
#################################################################


cruise_distance_m = row['gc_FEAT_km'] * 1000


import numpy as np
import pandas as pd
import pickle
from generate_flightpath import generate_flightpath
import os
from multiprocessing import Pool, cpu_count
from geographiclib.geodesic import Geodesic
from xgboost import XGBRegressor

SECONDS_PER_MONTH = 31 * 24 * 3600  # January
REMOVAL_TIMESCALE_S = 2 * 24 * 3600  # 2 days

def get_cruise_params(typecode, perf_df):
    try:
        cruise_alt_ft = perf_df.loc[typecode, 'cruise_Ceiling'] * 100 if not pd.isnull(perf_df.loc[typecode, 'cruise_Ceiling']) else 35000
        cruise_speed_ms = perf_df.loc[typecode, 'cruise_TAS'] * 0.514444 if not pd.isnull(perf_df.loc[typecode, 'cruise_TAS']) else 250
        if cruise_alt_ft <= 0 or np.isnan(cruise_alt_ft):
            cruise_alt_ft = 35000
        if cruise_speed_ms <= 0 or np.isnan(cruise_speed_ms):
            cruise_speed_ms = 250
        return cruise_alt_ft, cruise_speed_ms
    except Exception:
        return 35000, 250

def process_flight(args):
    row, xgb_models, perf_df, lat_bins, lon_bins, alt_bins_ft, nlat, nlon, nalt = args
    typecode = row['typecode']
    model = xgb_models.get(typecode)
    if model is None:
        return []
    try:
        cruise_alt_ft, cruise_speed_ms = get_cruise_params(typecode, perf_df)
        fp = generate_flightpath(typecode, row['gc_FEAT_km'], None)
        cruise_alt_ft = fp.get('cruise', {}).get('cruise_altitude_ft', cruise_alt_ft)
    except Exception:
        cruise_alt_ft, cruise_speed_ms = get_cruise_params(typecode, perf_df)
    features = np.array([[row['gc_FEAT_km'], cruise_alt_ft]])
    mean_nox_flux = model.predict(features)[0]
    cruise_distance_m = row['gc_FEAT_km'] * 1000
    cruise_time_s = cruise_distance_m / cruise_speed_ms
    total_nox_g = mean_nox_flux * cruise_time_s
    total_nox_kg = total_nox_g / 1000
    
    n_segments = int(np.ceil(cruise_distance_m / 10000))
    geod = Geodesic.WGS84
    line = geod.InverseLine(row['estdeparturelat'], row['estdeparturelong'],
                            row['estarrivallat'], row['estarrivallong'])
    ds = cruise_distance_m / n_segments
    lats, lons = [], []
    for i in range(n_segments):
        s = min(ds * i, line.s13)
        pos = line.Position(s)
        lats.append(pos['lat2'])
        lons.append(pos['lon2'])
    alts = np.full(n_segments, cruise_alt_ft)

    box_fraction = REMOVAL_TIMESCALE_S / (SECONDS_PER_MONTH + REMOVAL_TIMESCALE_S)
    nox_per_segment = total_nox_kg / n_segments * box_fraction

    updates = []
    for i in range(n_segments):
        lat_idx = np.searchsorted(lat_bins, lats[i], side='right') - 1
        lon_idx = np.searchsorted(lon_bins, lons[i], side='right') - 1
        alt_idx = np.searchsorted(alt_bins_ft, alts[i], side='right') - 1
        if 0 <= lat_idx < nlat and 0 <= lon_idx < nlon and 0 <= alt_idx < nalt:
            updates.append((lat_idx, lon_idx, alt_idx, nox_per_segment))       
    return updates

def process_month_emissions(
    month_start_time_str: str,
    output_dir: str = "/scratch/omg28/Data/no_track2023/emissions/",
    performance_and_emissions_model: pd.DataFrame = pd.read_pickle('performance_and_emissions_model.pkl')
):
    start_time_str_loop = pd.to_datetime(month_start_time_str)
    stop_time_str_loop = (start_time_str_loop + pd.offsets.MonthEnd(1)).replace(hour=23, minute=59, second=59)
    start_time_simple_loop = pd.to_datetime(start_time_str_loop).strftime("%Y-%m-%d")
    stop_time_simple_loop = pd.to_datetime(stop_time_str_loop).strftime("%Y-%m-%d")

    # Load flights data
    monthly_flights = pd.read_pickle(f'{output_dir}/{start_time_simple_loop}_to_{stop_time_simple_loop}_filtered.pkl')
    model_dir = 'saved_models_nox_flux'
    typecodes = monthly_flights['typecode'].unique()

    # Load all xgboost models into memory for speed
    xgb_models = {}
    for typecode in typecodes:
        model_path = os.path.join(model_dir, f'xgb_{typecode}.ubj')
        if os.path.exists(model_path):
            model = XGBRegressor()
            model.load_model(model_path)
            xgb_models[typecode] = model


    # Prepare cruise altitude and speed lookup from performance_and_emissions_model
    perf_df = performance_and_emissions_model.set_index('typecode')

    # Define grid
    lat_bins = np.arange(-90, 90.1, 0.5)
    lon_bins = np.arange(-180, 180.1, 0.5)
    alt_bins_ft = np.arange(0, 55001, 1000)
    alt_bins_m = alt_bins_ft * 0.3048
    nlat, nlon, nalt = len(lat_bins)-1, len(lon_bins)-1, len(alt_bins_m)-1
    nox_grid = np.zeros((nlat, nlon, nalt), dtype=np.float64)

    # Prepare arguments for loop
    pool_args = [
        (row, xgb_models, perf_df, lat_bins, lon_bins, alt_bins_ft, nlat, nlon, nalt)
        for _, row in monthly_flights.iterrows()
    ]

    # Simple for loop instead of multiprocessing
    results = []
    for args in pool_args:
        updates = process_flight(args)
        results.append(updates)

    # Aggregate results
    for updates in results:
        for lat_idx, lon_idx, alt_idx, nox in updates:
            nox_grid[lat_idx, lon_idx, alt_idx] += nox

    # Optionally: Save as NetCDF or CSV for further analysis
    output_dir = os.path.expanduser(output_dir)
    os.makedirs(f'{output_dir}/emissions', exist_ok=True)
    filename = os.path.join(output_dir, f'emissions/{start_time_simple_loop}_to_{stop_time_simple_loop}_NOx_nowar.npy')
    np.save(filename, nox_grid)
    return filename


In [2]:
#==================================================================================
#--Optimal flight route calculation compared against IAGOS actual route
#--From Ed Gryspeerdt and Olivier Boucher
#--October 2022
#==================================================================================
#
#--import numpy packages
import bisect
#
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from datetime import datetime,timedelta
from scipy.optimize import minimize
from scipy import interpolate
from sklearn.metrics import mean_squared_error
import cartopy.crs as ccrs
import cartopy
import pandas as pd
import great_circle_calculator.great_circle_calculator as gcc
import argparse
import glob, os, sys
import warnings
import time
from math import exp,log
import pickle
from scipy.ndimage import gaussian_filter1d


from FlightTrajectories.misc_geo import haversine, bearing, nearest, closest_argmin, sph2car, car2sph, G, el_foeew, ei_foeew
from FlightTrajectories.optimalrouting import ZermeloLonLat
from FlightTrajectories.minimization import *


def process_grid(xr_u200, xr_v200, nbts, 
       lon_p1, lat_p1, lon_p2, lat_p2, lons_wind, lats_wind, 
       lon_iagos_values, lat_iagos_values, lon_key_values, lat_key_values):
    """rotate globe to have trajectory around equator"""
    #--prepare plate grid
    plate = ccrs.PlateCarree()
    
    #--prepare meshgrid
    xx,yy=np.meshgrid(lons_wind,lats_wind)
    
    #--Deal with rotated grid
    #--convert lat-lon to radians
    lat_p1_rad=np.deg2rad(lat_p1)
    lon_p1_rad=np.deg2rad(lon_p1)
    lat_p2_rad=np.deg2rad(lat_p2)
    lon_p2_rad=np.deg2rad(lon_p2)
    
    #--convert to cartesian coordinates
    P1=sph2car(lat_p1_rad,lon_p1_rad)
    P2=sph2car(lat_p2_rad,lon_p2_rad)
    
    #--cross product P1^P2 - perpendicular to P1 and P2
    ON=np.cross(P1,P2)
    
    #--coordinates of new Pole in old system
    THETA,PHI=car2sph(ON)
    lat_pole=np.rad2deg(THETA)
    lon_pole=np.rad2deg(PHI)
    print('Coord lat lon new pole in original grid=',"{:5.2f}".format(lat_pole),"{:5.2f}".format(lon_pole))
    
    #--prepare rotated grid
    rotated = ccrs.RotatedPole(pole_latitude=lat_pole,pole_longitude=lon_pole)
    
    #--new coordinates of old pole (diagnostics)
    xyz=rotated.transform_points(plate,np.array([0.]),np.array([90.]))
    lon_pole_t=xyz[0,0]
    lat_pole_t=xyz[0,1]
    print('lon lat pole t=',"{:5.2f}".format(lon_pole_t),"{:5.2f}".format(lat_pole_t))
    
    #--new coordinates of p1,p2 points
    xyz=rotated.transform_points(plate,np.array([lon_p1,lon_p2]),np.array([lat_p1,lat_p2]))
    lon_p1=xyz[0,0] ; lat_p1=xyz[0,1]
    lon_p2=xyz[1,0] ; lat_p2=xyz[1,1]
    print('lon lat p1 t=',"{:5.2f}".format(lon_p1),"{:5.2f}".format(lat_p1))
    print('lon lat p2 t=',"{:5.2f}".format(lon_p2),"{:5.2f}".format(lat_p2))
    if lon_p2 < lon_p1: 
       print('lon_p2 < lon_p1 may pose problems later on - reconsider new pole')
       print('we stop here as this needs to be fixed if ever it happens')
       sys.exit()
    
    #--rotate lon_iagos and lat_iagos
    xyz=rotated.transform_points(plate,lon_iagos_values,lat_iagos_values)
    lon_iagos_values=xyz[:,0] 
    lat_iagos_values=xyz[:,1]
    
    #--rotate lon_key and lat_key values
    xyz=rotated.transform_points(plate,lon_key_values,lat_key_values)
    lon_key_values=xyz[:,0] 
    lat_key_values=xyz[:,1]
    
    #--rotate wind field on original gridpoints
    xr_u200_values_t=np.zeros((xr_u200['u'].values.shape))
    xr_v200_values_t=np.zeros((xr_v200['v'].values.shape))
    for t in range(nbts):
        xr_u200_values_t[t,:,:],xr_v200_values_t[t,:,:]=rotated.transform_vectors(plate,xx,yy,xr_u200['u'].values[t,:,:],xr_v200['v'].values[t,:,:])
    
    #--rotate meshgrid
    xyz=rotated.transform_points(plate,xx,yy)
    xx_t=xyz[:,:,0]
    yy_t=xyz[:,:,1]
    
    #--interpolate u_t and v_t on a regular grid on rotated grid using gridddata
    xx_t_yy_t=np.array([[ixt,iyt] for ixt,iyt in zip(xx_t.flatten(),yy_t.flatten())])
    for t in range(nbts):
       xr_u200['u'].values[t,:,:]=interpolate.griddata(xx_t_yy_t,xr_u200_values_t[t,:,:].flatten(),(xx,yy),method='nearest')
       xr_v200['v'].values[t,:,:]=interpolate.griddata(xx_t_yy_t,xr_v200_values_t[t,:,:].flatten(),(xx,yy),method='nearest')

    return plate, xyz, lon_pole_t, lat_pole_t, lon_p1, lat_p1, lon_p2, lat_p2, xx, yy, lon_iagos_values, lat_iagos_values, rotated, lon_key_values, lat_key_values, lon_pole, lat_pole


#-----------------------------------------------------------------------------------------
# MAIN CODE 
#-----------------------------------------------------------------------------------------

def read_data(iagos_file, Dt_ERA):
    """Read list of IAGOS files
        Default: use inputfile"""

    #--print out file
    print('\n')
    print(iagos_file)
    
    #--open IAGOS file
    iagos=xr.open_dataset(iagos_file)
    
    #--get IAGOS id from file name
    iagos_id=iagos_file.split('/')[-1].split('_')[2]
    
    #--extract metadata from flight
    dep_airport_iagos=iagos.departure_airport.split(',')[0]
    arr_airport_iagos=iagos.arrival_airport.split(',')[0]
    dep_time_iagos=datetime.strptime(iagos.departure_UTC_time,"%Y-%m-%dT%H:%M:%SZ")
    arr_time_iagos=datetime.strptime(iagos.arrival_UTC_time,"%Y-%m-%dT%H:%M:%SZ")
    ave_time_iagos=dep_time_iagos+(arr_time_iagos-dep_time_iagos)/2.
    flightid_iagos=iagos.platform.split(',')[3].lstrip().rstrip()
    print('Flightid=',flightid_iagos,dep_airport_iagos,arr_airport_iagos)
    print('Flight departure time=',dep_time_iagos)
    print('Flight arrival time=',arr_time_iagos)
    print('Flight average time=',ave_time_iagos)
    
    #--extract data
    lat_iagos=iagos['lat']
    lon_iagos=iagos['lon']
    time_iagos=iagos['UTC_time']
    pressure_iagos=iagos['air_press_AC']
    
    #--define arrays of lon, lat, and pressure values
    lon_iagos_values=lon_iagos.values
    lat_iagos_values=lat_iagos.values
    pressure_iagos_values=pressure_iagos.values
    
    #--extract and convert IAGOS departure date
    yr_iagos=dep_time_iagos.year
    mth_iagos=dep_time_iagos.month
    day_iagos=dep_time_iagos.day
    hr_iagos=dep_time_iagos.hour
    hr_iagos_closest,hr_iagos_ind=nearest([i*Dt_ERA for i in range(24//Dt_ERA)],hr_iagos)
    stryr=dep_time_iagos.strftime("%Y")
    strmth=dep_time_iagos.strftime("%m")
    strday=dep_time_iagos.strftime("%d")
    
    #--find lon, lat, alt of cruising similar to what FR24 database does
    #--this is empirical and needs to be checked
    #--the threshold depend on the time resolution of the IAGOS data
    ind=np.where((np.abs(np.diff(gaussian_filter1d(pressure_iagos,40)))<50.) & (pressure_iagos[:-1]<35000.))[0]
    
    #--eliminate low-level flights
    if len(ind) == 0:
        print('This flight is too low to be optimized so we stop here')
        return None
    
    #--find longitude and latitude of beginning and end of cruising phase
    lon_p1=lon_iagos_values[ind[0]]
    lon_p2=lon_iagos_values[ind[-1]]
    lat_p1=lat_iagos_values[ind[0]]
    lat_p2=lat_iagos_values[ind[-1]]
    
    lon_key_values=np.array([lon_iagos_values[0],lon_p1,lon_p2,lon_iagos_values[-1]])
    lat_key_values=np.array([lat_iagos_values[0],lat_p1,lat_p2,lat_iagos_values[-1]])
    alt_key_values=np.array([pressure_iagos_values[0],pressure_iagos_values[ind[0]],pressure_iagos_values[ind[-1]],pressure_iagos_values[-1]])
    
    print('Departure and arrival points')
    print('lon lat p1=',lon_p1,lat_p1)
    print('lon lat p2=',lon_p2,lat_p2)

    return (iagos_id, dep_airport_iagos, arr_airport_iagos, dep_time_iagos, 
           arr_time_iagos, ave_time_iagos, flightid_iagos, lat_iagos,
           lon_iagos, time_iagos, pressure_iagos, yr_iagos, mth_iagos, day_iagos,
           hr_iagos, hr_iagos_closest, hr_iagos_ind, stryr, strmth, strday,
           ind, lon_p1, lon_p2, lat_p1, lat_p2, lon_key_values, lat_key_values, alt_key_values)

def compute_IAGOS_route(lon_shortest, lon_iagos_values, lat_iagos_values, 
        pressure_iagos_values, lon_p1, lon_p2, lat_p1, lat_p2, idx1, idx2,
        xr_u200_reduced, xr_v200_reduced, airspeed, lons_wind, lats_wind):

    #--interpolated latitude of IAGOS flight with similar sampling
    nlon=len(lon_shortest)
    imid=nlon//2
    idxmid=(np.abs(lon_iagos_values-lon_shortest[imid])).argmin()
    lon_iagos_cruising=[lon_iagos_values[idxmid]]
    lat_iagos_cruising=[lat_iagos_values[idxmid]]
    pressure_iagos_cruising=[pressure_iagos_values[idxmid]]

    if idx1 >= idxmid or idx2 <= idxmid: 
        print('This is a dodgy case with idxmid not in between idx1 and idx2 - would need some investigation')
        return None #continue
    #--flight is eastbound in new coordinates (lon_p1 < lon_p2)
    for i in range(imid+1,len(lon_shortest)):
       ilon=np.max(np.where(lon_iagos_values[idxmid:idx2+1]<=lon_shortest[i]))
       lon_iagos_cruising=lon_iagos_cruising+[lon_iagos_values[idxmid+ilon]]
       lat_iagos_cruising=lat_iagos_cruising+[lat_iagos_values[idxmid+ilon]]
       pressure_iagos_cruising=pressure_iagos_cruising+[pressure_iagos_values[idxmid+ilon]]
    for i in range(imid-1,-1,-1):
       ilon=np.min(np.where(lon_iagos_values[idx1:idxmid]>=lon_shortest[i]))
       lon_iagos_cruising=[lon_iagos_values[idx1+ilon]]+lon_iagos_cruising
       lat_iagos_cruising=[lat_iagos_values[idx1+ilon]]+lat_iagos_cruising
       pressure_iagos_cruising=[pressure_iagos_values[idx1+ilon]]+pressure_iagos_cruising
    
    #--put the correct departure and arrival coordinates
    lon_iagos_cruising[0]=lon_p1
    lon_iagos_cruising[-1]=lon_p2
    lat_iagos_cruising[0]=lat_p1
    lat_iagos_cruising[-1]=lat_p2
    
    #--conversion to np array
    lon_iagos_cruising=np.array(lon_iagos_cruising)
    lat_iagos_cruising=np.array(lat_iagos_cruising)
    pressure_iagos_cruising=np.array(pressure_iagos_cruising)
    
    #--IAGOS route
    p1_iagos=(lon_iagos_cruising[0],lat_iagos_cruising[0])
    p2_iagos=(lon_iagos_cruising[-1],lat_iagos_cruising[-1])
    dist_iagos = haversine(p1_iagos[1], p1_iagos[0], p2_iagos[1], p2_iagos[0])
    dist_gcc_iagos = gcc.distance_between_points(p1_iagos,p2_iagos,unit='kilometers')

    dt_iagos_2=cost_time(lon_iagos_cruising, lat_iagos_cruising, lons_wind, lats_wind, xr_u200_reduced, xr_v200_reduced, airspeed, dtprint=False)
    
    return (lon_iagos_cruising, lat_iagos_cruising, pressure_iagos_cruising, p1_iagos, p2_iagos, dist_iagos, dist_gcc_iagos, dt_iagos_2)

def make_plot(rotated, lon_iagos_values, lat_iagos_values, lon_key_values, lat_key_values, alt_key_values,
              lon_shortest, lat_shortest, lon_quickest, lat_quickest, lon_ed, lat_ed, 
              lons_wind, lats_wind, xr_u200_reduced, xr_v200_reduced,
              iagos_id, flightid_iagos, dep_airport_iagos, arr_airport_iagos, stryr, strmth, strday, # TODO hide inside dict..
              optim_level, dt_shortest, dt_quickest, dt_ed_LD, dt_iagos_2, pathout, yr, 
              pressure_iagos,solution, ind, airspeed, dist_gcc, lat_pole, lon_pole, lon_iagos_cruising, lat_iagos_cruising):
        
    fig=plt.figure(figsize=(10,5))
    ax=fig.add_subplot(111, projection=rotated)
    #--lon-lat
    ax.plot(lon_iagos_values, lat_iagos_values, c='black', lw=2, label='IAGOS')
    ax.scatter(lon_key_values, lat_key_values, c='red', marker='X', lw=1)
    ax.plot(lon_shortest, lat_shortest, c='green', lw=2, label='Shortest',zorder=10)
    ax.plot(lon_quickest, lat_quickest, c='blue',  lw=2, label='Gradient descent')
    if solution: ax.plot(lon_ed, lat_ed, c='cyan', lw=2, label='Zermelo method')
    #--wind field
    ax.quiver(lons_wind[::10],lats_wind[::10],np.transpose(xr_u200_reduced['u'].values[::10,::10]),np.transpose(xr_v200_reduced['v'].values[::10,::10]),scale=1200)
    #--define domain with a margin
    ax.set_xlim(np.min(lon_shortest)-20.,np.max(lon_shortest)+20.)
    ax.set_ylim(np.max([-90.,np.min(lat_shortest)-20.]),np.min([90.,np.max(lat_shortest)+20.]))
    #--make plot nice
    ax.legend(loc='lower left',fontsize=10)
    fig.gca().coastlines()
    fig.gca().add_feature(cartopy.feature.OCEAN,facecolor=("lightblue"),alpha=1)
    fig.gca().add_feature(cartopy.feature.BORDERS,color="red",linewidth=0.3)
    #--plot title
    plt.title(iagos_id+' '+str(flightid_iagos)+' '+dep_airport_iagos+'=>'+arr_airport_iagos+' '+' '+stryr+strmth+strday+\
              ' level='+str(optim_level)+' hPa'+'\n'+ 'Shortest='+"{:.3f}".format(dt_shortest)+\
              ' Gradient descent='+"{:.3f}".format(dt_quickest)+' Zermelo method='+"{:.3f}".format(dt_ed_LD)+' IAGOS='+"{:.3f}".format(dt_iagos_2),fontsize=11)
    #--save, show and close
    basefile=pathout+'traj_'+str(iagos_id)+'_lev'+str(optim_level)+'_'+str(yr)
    plt.savefig(basefile+'.png',dpi=150,bbox_inches='tight')
    pltshow = True
    if pltshow: plt.show()
    plt.close()
    
    #--plot 2: height
    #----------------
    fig,ax = plt.subplots()
    ax.set_ylim(1030,-300)
    ax.set_ylabel("Pressure (hPa)",color="black",fontsize=14)
    ax.set_xlabel("Longitude",color="black",fontsize=14)
    ax.plot(lon_iagos_values, pressure_iagos/100., c='black', lw=2, label='IAGOS flight pressure')
    ax.plot(lon_iagos_values[ind], pressure_iagos[ind]/100., c='red', lw=2, label='Flight pressure (cruising)',zorder=10)
    ax.plot(lon_iagos_values[:-1], np.diff(gaussian_filter1d(pressure_iagos,40)), c='blue', lw=1, label='Finite difference')
    ax.plot(lon_iagos_values[[0,-1]], [350.,350.], c='black', linestyle='solid', lw=1)
    ax.plot(lon_iagos_values[[0,-1]], [50.,50.], c='black', linestyle='dashed', lw=1)
    ax.plot(lon_iagos_values[[0,-1]], [-50.,-50.], c='black', linestyle='dashed', lw=1)
    ax.scatter(lon_key_values, alt_key_values/100., c='red', marker='X', lw=1, label='Origin - Cruising - Destination')
    #--make plot nice
    plt.title(iagos_id+' '+str(flightid_iagos)+' '+dep_airport_iagos+'=>'+arr_airport_iagos+' '+' '+stryr+strmth+strday+' level='+str(optim_level)+' hPa')
    plt.legend()
    #--save, show and close
    basefile=pathout+'alt_'+str(iagos_id)+'_lev'+str(optim_level)+'_'+str(yr)
    plt.savefig(basefile+'.png',dpi=150,bbox_inches='tight')
    if pltshow: plt.show()
    plt.close()

def ED_quickest_route(p1, p2, airspeed, lon_p1, lon_p2, lat_p1, lat_p2, 
                      lat_shortest, lat_quickest, lat_iagos_cruising, lons_wind, lats_wind, xr_u200_reduced, xr_v200_reduced, npoints):
    start_time = time.time()
    # Create the zermelo solver. Note the default values
    #--max_dest_dist is in metres
    #--sub_factor: number of splits for next round if solution is bounded by pairs of trajectories
    #--psi_range: +/- angle for the initial bearing
    #--psi_res: resolution within psi_range bounds, could try 0.2 instead
    zermelolonlat = ZermeloLonLat(cost_func=lambda x, y, z: np.ones(np.atleast_1d(x).shape),
                                  wind_func=wind, timestep=60, psi_range=60, psi_res=0.5,
                                  length_factor=1.4, max_dest_distance=75000., sub_factor=80)
    
    initial_psi = zermelolonlat.bearing_func(*p1, *p2)
    psi_vals = np.linspace(initial_psi-60, initial_psi+60, 30)
    #--This prodcues a series of Zermelo trajectories for the given initial directions
    zloc, zpsi, zcost = zermelolonlat.zermelo_path(np.repeat(np.array(p1)[:, None], len(psi_vals), axis=-1),lons_wind, lats_wind, xr_u200_reduced, xr_v200_reduced, 
                        # This 90 is due to an internal conversion between bearings and angles
                        # - which is obviously a bad idea... noramlly it is hidden internally
                        ##   90-psi_vals, nsteps=800, airspeed=250, dtime=dep_time_iagos)
                        90-psi_vals, nsteps=800, airspeed=airspeed, dtime=0) #--modif OB
    
    # This identifies the optimal route
    solution, fpst, ftime, flocs, fcost = zermelolonlat.route_optimise(np.array(p1), np.array(p2),  lons_wind, lats_wind, xr_u200_reduced, xr_v200_reduced, airspeed=airspeed, dtime=0)
    #--if solution was found
    if solution: 
      lon_ed=flocs[:,0]
      lat_ed=flocs[:,1]
      #
      #--compute Ed's time by stretching slightly the trajectory to the same endpoints
      npoints_ed=len(lon_ed)
      print('npoints_ed=',npoints_ed)
      lon_ed=lon_ed+(lon_p2-lon_ed[-1])*np.arange(npoints_ed)/float(npoints_ed-1)
      lat_ed=lat_ed+(lat_p2-lat_ed[-1])*np.arange(npoints_ed)/float(npoints_ed-1)
      #--compute corresponding time 
      dt_ed_HD=cost_time(lon_ed, lat_ed, lons_wind, lats_wind, xr_u200_reduced, xr_v200_reduced, airspeed, dtprint=False)
      print('Cruising flight time ED (high res) =',"{:6.4f}".format(dt_ed_HD),'hours')
      lon_ed_LD=np.append(lon_ed[::npoints_ed//npoints],[lon_ed[-1]])
      lat_ed_LD=np.append(lat_ed[::npoints_ed//npoints],[lat_ed[-1]])
      dt_ed_LD=cost_time(lon_ed_LD, lat_ed_LD, lons_wind, lats_wind, xr_u200_reduced, xr_v200_reduced, airspeed, dtprint=False)
      print('Cruising flight time ED (low res) =',"{:6.4f}".format(dt_ed_LD),'hours')
    else: 
      print('No solution found by Zermelo')  
      lon_ed=float('inf')
      lat_ed=float('inf')
      dt_ed_HD=float('inf')
      lon_ed_LD=float('inf')
      lat_ed_LD=float('inf')
      dt_ed_LD=float('inf')
    end_time = time.time()
    time_elapsed_EG=end_time-start_time
    print('Time elapsed for Zermelo method=',"{:3.1f}".format(time_elapsed_EG),'s')
    
    #--computing indices of quality of fit
    rmse_shortest=mean_squared_error(lat_shortest,lat_iagos_cruising)**0.5
    rmse_quickest=mean_squared_error(lat_quickest,lat_iagos_cruising)**0.5
    lat_max_shortest=np.max(np.abs(lat_shortest-lat_iagos_cruising))
    lat_max_quickest=np.max(np.abs(lat_quickest-lat_iagos_cruising))
    print('rmse and lat max=',rmse_shortest,rmse_quickest,lat_max_shortest,lat_max_quickest)
    
    return rmse_shortest, rmse_quickest, lat_max_shortest, lat_max_quickest, lon_ed_LD, lat_ed_LD, dt_ed_LD, time_elapsed_EG, lon_ed, lat_ed, dt_ed_HD, solution

#--main routine
def opti(yr, mth, inputfile, route, level, maxiter,
         method, precision, path_iagos, pathERA5, pathout,
         nbmeters, nbest, airspeed, disp, pltshow, Dt_ERA):

    if route != '': 
        #--use the selected route
        #--read route flights 
        csvfile = '../FLIGHTS/'+route+'_'+str(yr)+'.csv'
        #--open file
        print('csvfile=',csvfile)
        iagos_files = pd.read_csv(csvfile,header=None,names=['file'])
        iagos_files = list(iagos_files['file'].values)
    elif inputfile != '': 
        #--otherwise use single prescribed file
        iagos_files=[inputfile]
        #--overwrite year from filename in this case
        yr=int(inputfile.split('/')[-1].split('_')[2][0:4])
    else:
        #--Otherwise use year+month
        #--all IAGOS files from selected year and month
        iagos_files=sorted(glob.glob(path_iagos+str(yr)+str(mth).zfill(2)+'/*.nc4'))

    print('We have found '+str(len(iagos_files))+' IAGOS files.')
 
    #--Initialise dataframe for saving output
    final_df=pd.DataFrame(columns=['file_iagos','flightname_iagos','flightid_iagos','level','year','dep_airport','arr_airport',\
                                   'dep_time','arr_time','time_start_cruising','time_end_cruising','lon_start_cruising','lat_start_cruising',\
                                   'lon_end_cruising','lat_end_cruising','time shortest','time OB','time EG','time_iagos',\
                                   'rmse lat shortest','rmse lat quickest','time elapsed OB','time elapsed EG','airspeed','dist_gcc'])
    
    #--Open ERA5 wind files as xarray objects
    #TODO PATHERA5
    stryr=str(yr)
    file_u=pathERA5+'u.'+stryr+'.GLOBAL.nc'
    file_v=pathERA5+'v.'+stryr+'.GLOBAL.nc'
    #file_t=pathERA5+'ta.'+stryr+'.GLOBAL.nc'
    #file_r=pathERA5+'r.'+stryr+'.GLOBAL.nc'

    print(file_u)
    print(file_v)
    #print(file_t)
    #print(file_r)

    xr_u=xr.open_dataset(file_u)  
    xr_v=xr.open_dataset(file_v)  
    #xr_t=xr.open_dataset(file_t)  
    #xr_r=xr.open_dataset(file_r)  

    #--Extract coordinates
    levels_wind=list(xr_u['level'].values)
    lons_wind=xr_u['longitude'].values
    lats_wind=xr_u['latitude'].values
 
    #--Loop on flights
    for iagos_file in iagos_files:
        #
        _data = read_data(iagos_file, Dt_ERA)
        #
        if _data is None:
            continue
        else:
            (iagos_id, dep_airport_iagos, arr_airport_iagos, dep_time_iagos, 
             arr_time_iagos, ave_time_iagos, flightid_iagos, lat_iagos,
             lon_iagos, time_iagos, pressure_iagos, yr_iagos, mth_iagos, day_iagos,
             hr_iagos, hr_iagos_closest, hr_iagos_ind, stryr, strmth, strday,
             ind, lon_p1, lon_p2, lat_p1, lat_p2, lon_key_values, lat_key_values, alt_key_values) = _data
        
        #--compute great circle distance
        dist = haversine(lat_p1, lon_p1, lat_p2, lon_p2)
        dist_gcc = gcc.distance_between_points((lon_p1,lat_p1),(lon_p2,lat_p2),unit='meters')
        print('Distance between airports (haversine & gcc) = ',"{:6.3f}".format(dist/1000.),"{:6.3f}".format(dist_gcc/1000.),'km')
        
        #--compute number of legs
        npoints = int(dist // nbmeters)
        
        #--select IAGOS datapoints as close as possible to FR24 datapoints
        idx1 = ((lon_iagos.values-lon_p1)**2.0+(lat_iagos.values-lat_p1)**2.0).argmin()
        idx2 = ((lon_iagos.values-lon_p2)**2.0+(lat_iagos.values-lat_p2)**2.0).argmin()

        #--compute actual IAGOS time flight during cruising (in hours)
        dt_iagos_1=float(time_iagos.values[idx2]-time_iagos.values[idx1])/3600./1.e9 #--convert nanoseconds => hr
        
        #--compute average IAGOS pressure (in hPa)
        ave_pressure_iagos=np.average(pressure_iagos[idx1:idx2])/100.
        print('Pressure levels in ERA5 file=',levels_wind)

        #--find closest pressure level in ERA5 data
        pressure_iagos_closest,pressure_ind_closest=nearest(levels_wind,ave_pressure_iagos)
        print('Average pressure=',"{:5.2f}".format(ave_pressure_iagos),'hPa closest to',pressure_iagos_closest,'hPa')
        
        #--select pressure level for optimisation
        if level == -1:
           optim_level=pressure_iagos_closest
        else:
           optim_level=level
        
        #--time ERA5 preparation
        start_time = time.time()
        
        #--pre-sample times 
        nbts=int(dt_iagos_1/Dt_ERA)+2
        
        #--times to extract (3-hourly) from start to end of flight
        times_to_extract=[datetime(yr_iagos,mth_iagos,day_iagos,hr_iagos_closest,0)+timedelta(hours=i*Dt_ERA) for i in range(nbts)]
        
        #--preload the data for a range of nbts times
        xr_u200=xr_u.sel(level=optim_level,time=times_to_extract).load()
        xr_v200=xr_v.sel(level=optim_level,time=times_to_extract).load()
        
        #--select array (m/s)
        xr_u200_values=xr_u200['u'].values
        xr_v200_values=xr_v200['v'].values
        
        #--prepare plate grid
        plate, xyz, lon_pole_t, lat_pole_t, lon_p1, lat_p1, lon_p2, lat_p2, xx, yy,                             \
             lon_iagos_values, lat_iagos_values, rotated, lon_key_values, lat_key_values, lon_pole, lat_pole =  \
             process_grid(xr_u200, xr_v200, nbts, lon_p1, lat_p1, lon_p2, lat_p2, lons_wind, lats_wind,         \
             lon_iagos.values, lat_iagos.values, lon_key_values, lat_key_values)

        lon_iagos.values = lon_iagos_values
        lat_iagos.values = lat_iagos_values

        #--substitute data back in their original xarray objects
        xr_u200['u'].values=xr_u200_values
        xr_v200['v'].values=xr_v200_values
        
        #--dt per degree of longitude
        dtime_per_degree=dt_iagos_1/abs(lon_p2-lon_p1) #--this assumes lon_p2 > lon_p1
        
        #--compute times_era assuming uniform sampling of longitudes
        times_wind=[]

        _idx_p1 = bisect.bisect_right(lons_wind, lon_p1) - 1 
        _idx_p2 = bisect.bisect_right(lons_wind, lon_p2) + 1

        lons_wind_reduced = lons_wind[_idx_p1:_idx_p2]

        lons_z = xr.DataArray(lons_wind_reduced, dims="z")

        #--finding the corresponding times
        for lon in list(lons_wind_reduced):
          time_to_append,time_ind_closest=nearest(times_to_extract,dep_time_iagos+timedelta(hours=dtime_per_degree*(lon-lon_p1)))
          times_wind.append(time_to_append)
        times_wind=np.array(times_wind)

        #--initial and final time tags to ensure back-comptability but could be removed
        times_wind[np.where(lons_wind_reduced<lon_p1)]=times_to_extract[0]
        times_wind[np.where(lon_p2<lons_wind_reduced)]=times_to_extract[-1]
        
        times_z = xr.DataArray(times_wind, dims="z")

        #--preload the data for a range of nbts times
        xr_u200_reduced=xr_u200.sel(longitude=lons_z,time=times_z,latitude=lats_wind).load()
        xr_v200_reduced=xr_v200.sel(longitude=lons_z,time=times_z,latitude=lats_wind).load()
        #xr_u200_reduced.chunk(chunks={"z":"auto"})
        #xr_v200_reduced.chunk(chunks={"z":"auto"})
        
        end_time = time.time()
        time_elapsed_ERA=end_time-start_time
        print('Time elapsed for ERA5=',"{:3.1f}".format(time_elapsed_ERA),'s')
        
        #--define p1 and p2 as tuples
        p1 = (lon_p1, lat_p1)
        p2 = (lon_p2, lat_p2)
        
        #--flatten arrays and create coordinate vector
        xx_yy=np.array([[ixx,iyy] for ixx,iyy in zip(xx.flatten(),yy.flatten())])
        
        #--check longitudes are mostly monotonic in IAGOS flight and stop otherwise
        print('monotonicity of IAGOS longitudes =',np.sum(np.diff(lon_iagos.values) >= 0)/len(lon_iagos.values), \
                                                   np.sum(np.diff(lon_iagos.values) <= 0)/len(lon_iagos.values))
        if (np.sum(np.diff(lon_iagos.values) >= 0)/len(lon_iagos.values) < 0.90 and \
            np.sum(np.diff(lon_iagos.values) <= 0)/len(lon_iagos.values) < 0.90): 
            print('Flight longitudes are not monotonic enough so we stop here for this flight')
            continue
        
        #--compute shortest route
        lon_shortest, lat_shortest = shortest_route(p1, p2, npoints)
        
        #--compute time of shortest route 
        dt_shortest=cost_time(lon_shortest, lat_shortest, lons_wind_reduced, lats_wind, xr_u200_reduced, xr_v200_reduced, airspeed, dtprint=False)
        print('Cruising flight time shortest =',"{:6.4f}".format(dt_shortest),'hours')
        
        #---------------------
        #--compute IAGOS route
        #---------------------
        iagos_route = compute_IAGOS_route(lon_shortest, lon_iagos.values, lat_iagos.values, 
                                          pressure_iagos.values, lon_p1, lon_p2, lat_p1, lat_p2, idx1, idx2, 
                                          xr_u200_reduced, xr_v200_reduced, airspeed, lons_wind_reduced, lats_wind)
        if iagos_route is None:
            continue
        else:
            (lon_iagos_cruising, lat_iagos_cruising, pressure_iagos_cruising, p1_iagos, p2_iagos, dist_iagos, dist_gcc_iagos, dt_iagos_2) = iagos_route
            print('IAGOS cruising flight time actual and sampled estimated =',"{:6.4f}".format(dt_iagos_1),"{:6.4f}".format(dt_iagos_2),'hours')

        #---------------------------
        #--compute OB quickest route
        #---------------------------
        start_time = time.time()
        if precision=='low':
            #--fast, less accurate version
           lon_quickest, lat_quickest, dt_quickest = quickest_route_fast(p1, p2, npoints, nbest, lat_iagos_cruising, lons_wind_reduced, lats_wind, 
                                                                         xr_u200_reduced, xr_v200_reduced, airspeed, method, disp, maxiter )
        elif precision=='high':
           #--full version
           lon_quickest, lat_quickest, dt_quickest = quickest_route(p1, p2, npoints, lat_iagos_cruising, lons_wind_reduced, lats_wind, 
                                                                    xr_u200_reduced, xr_v200_reduced, airspeed, method, disp, maxiter )
        end_time = time.time()
        time_elapsed_OB=end_time-start_time
        print('Time elapsed for gradient descent method =',"{:3.1f}".format(time_elapsed_OB),'s')
        print('Cruising flight time quickest gradient descent=',"{:6.4f}".format(dt_quickest),'hours')
       
        #---------------------------
        #--compute ED quickest route
        #---------------------------
        rmse_shortest, rmse_quickest, lat_max_shortest, lat_max_quickest, lon_ed_LD, lat_ed_LD, dt_ed_LD,                   \
                      time_elapsed_EG, lon_ed, lat_ed, dt_ed_HD, solution =                                                 \
                      ED_quickest_route(p1, p2, airspeed, lon_p1, lon_p2, lat_p1, lat_p2,                                   \
                      lat_shortest, lat_quickest, lat_iagos_cruising, lons_wind_reduced, lats_wind, xr_u200_reduced, xr_v200_reduced, npoints)
                
        #--fill DataFrame - not very efficient but ok as dataframe is short
        final_df.loc[len(final_df)]=[iagos_file,iagos_id,flightid_iagos,optim_level,yr,dep_airport_iagos,arr_airport_iagos,      \
                                     dep_time_iagos,arr_time_iagos,time_iagos.values[idx1],time_iagos.values[idx2],lon_p1,lat_p1,\
                                     lon_p2,lat_p2,dt_shortest,dt_quickest,dt_ed_LD,dt_iagos_2,rmse_shortest,rmse_quickest,      \
                                     time_elapsed_OB,time_elapsed_EG,airspeed,dist_gcc]
        
        #--save quickest route for Ed in a Dataframe
        route_df=pd.DataFrame(columns=['Total_time_IAGOS','Total_time_quickest','longitudes IAGOS',\
                                       'latitudes IAGOS','longitudes quickest','latitudes quickest'])
        for i in range(len(lat_quickest)):
           new_df = pd.DataFrame({'Total_time_IAGOS':dt_iagos_2*3600.,'Total_time_quickest':dt_quickest*3600.,\
                                  'longitudes IAGOS':lon_iagos_cruising[i],'latitudes IAGOS':lat_iagos_cruising[i],\
                                  'longitudes quickest':lon_quickest[i],'latitudes quickest':lat_quickest[i]},index=[i])
           route_df = pd.concat([route_df,new_df]) 
        route_df.to_csv(pathout+str(iagos_id)+'_lev'+str(optim_level)+'_'+str(yr)+'.csv')
        
        #-plot 1: prepare map traj plot
        #------------------------------
        if pltshow:
            make_plot(rotated, lon_iagos.values, lat_iagos.values, lon_key_values, lat_key_values, alt_key_values,
                  lon_shortest, lat_shortest, lon_quickest, lat_quickest, lon_ed, lat_ed, 
                  lons_wind_reduced, lats_wind, xr_u200_reduced, xr_v200_reduced,
                  iagos_id, flightid_iagos, dep_airport_iagos, arr_airport_iagos, stryr, strmth, strday, # TODO hide inside dict..
                  optim_level, dt_shortest, dt_quickest, dt_ed_LD, dt_iagos_2, pathout, yr, 
                  pressure_iagos, solution, ind, airspeed, dist_gcc, lat_pole, lon_pole, lon_iagos_cruising, lat_iagos_cruising)

        
        #--save a pickle to redo the plots later if needed
        dico={'iagos_id':iagos_id,'flightid_iagos':flightid_iagos,'airspeed':airspeed,'dist_gcc':dist_gcc,\
              'lat_pole': lat_pole,'lon_pole':lon_pole,'lon_iagos_values':lon_iagos.values,'lon_iagos_cruising':lon_iagos_cruising,'lat_iagos_values':lat_iagos.values,\
              'lon_key_values':lon_key_values,'optim_level':optim_level,'solution':solution,\
              'lat_key_values':lat_key_values,'lon_shortest':lon_shortest,'lat_shortest':lat_shortest,'pressure_iagos':pressure_iagos,'alt_key_values':alt_key_values,\
              'lon_quickest':lon_quickest,'lat_quickest':lat_quickest,'lon_ed':lon_ed,'lat_ed':lat_ed,'lons_wind':lons_wind,'lats_wind':lats_wind,\
              'xr_u200':xr_u200_reduced,'xr_v200':xr_v200_reduced,'dep_airport_iagos':dep_airport_iagos,'arr_airport_iagos':arr_airport_iagos,\
              'stryr':stryr,'strmth':strmth,'strday':strday,'dt_shortest':dt_shortest,'dt_quickest':dt_quickest,'dt_ed':dt_ed_LD,'dt_iagos_2':dt_iagos_2}
        basefile=pathout+'data_'+str(iagos_id)+'_lev'+str(optim_level)+'_'+str(yr)
        with open(basefile+'.pickle', 'wb') as f: pickle.dump(dico, f)
        #
        #--save dataframe (and overwrite)
        if route != '':
          final_df.to_csv(pathout+str(route)+'_'+str(yr)+'_lev'+str(level)+'.csv')
        else:
          final_df.to_csv(pathout+str(yr)+str(mth).zfill(2)+'_lev'+str(level)+'.csv')
    
def main():
    #  NOTE 
    #  OB: gradient descent
    #  EG: Zermelo method
    
    #  INPUT PARAMETERS FROM COMMAND LINE
    #  example on how to call the python script
    #  python optim_iagos_only.py --yr=2019 --level=200
    #  level=-1 => the script uses the pressure level that is closest to the average IAGOS flight pressure
    #  one can prescribe inputfile, otherwise route, otherwise mth
    
    parser = argparse.ArgumentParser(
            prog="FlightTrajectories",
            description="Optime flight trajectories",
            epilog="""example on how to call the python script:\n"""
                """python optim_iagos_only.py --yr=2019 --level=200\n"""
                """level=-1 => the script uses the pressure level that is closest to the average IAGOS flight pressure \n"""
                """one can prescribe inputfile, otherwise route, otherwise mth.""")

    parser.add_argument('--yr', type=int, choices=[2018,2019,2020,2021,2022], default=2019, help='year')
    parser.add_argument('--mth', type=int, choices=[-1,1,2,3,4,5,6,7,8,9,10,11,12], default=1, help='month')
    parser.add_argument('--inputfile', type=str, default='../data/IAGOS_timeseries_2019010510370802_L2_3.1.0.nc4', help='input file')
    parser.add_argument('--route', type=str, default='', help='route')
    parser.add_argument('--level', type=int, choices=[-1,150,175,200,225,250], default=-1, help='level (hPa)')
    parser.add_argument('--maxiter', type=int, default=100, help='max number of iterations')
    parser.add_argument('--method', type=str, default='SLSQP', choices=['SLSQP','BFGS','L-BFGS-B'], help='minimization method')
    parser.add_argument('--precision', type=str, default='high', choices=['low','high'], help='precision of gradient descent method')
    parser.add_argument('--iagos', type=str, default='/bdd/IAGOS/netcdf/', help='path to the IAGOS files')
    parser.add_argument('--era5', type=str, default='/projsu/cmip-work/oboucher/ERA5/', help='path to the ERA5 files')
    parser.add_argument('--output', type=str, default='/projsu/cmip-work/oboucher/FR24/ROUTE/', help='path to the output folder')

    #--get arguments from command line
    args = parser.parse_args()
    
    #--copy arguments into variables
    yr=args.yr
    mth=args.mth
    inputfile=args.inputfile
    route=args.route
    level=args.level
    maxiter=args.maxiter
    method=args.method
    precision=args.precision
    path_iagos=args.iagos
    path_ERA5=args.era5
    output=args.output
    pathout=os.path.join(output+str(maxiter),method)+'/'
  
    #--print input parameters to output
    print('yr=',yr)
    print('mth=',mth)
    print('inputfile=',inputfile)
    print('route=',route)
    print('level=',level)
    print('maxiter=',maxiter)
    print('method=',method)
    print('precision=',precision)
    print('pathout=',pathout)
    
    #--stop unwanted warnings from xarray
    warnings.filterwarnings("ignore")
    
    #--a little more verbose
    print('We are dealing with IAGOS flights for year '+str(yr))
    if level == -1:
       print('Variable level in optimisation')
    else:
       print('Fixed level in optimisation='+str(level)+' hPa')
    
    if not os.path.exists(pathout): os.makedirs(pathout)

    #--some more important parameters
    #--time sampling of ERA5 data in hours (1 = hourly data)
    Dt_ERA=1
    #--number of m for one leg when discretizing the trajectory
    nbmeters = 50000.
    #--number of best first guess to be used in low precision option
    nbest = 12
    #--typical aircraft airspeed in m/s 
    airspeed = 241.
    #--print out details of the minimization
    disp=False
    #--show plots while running
    pltshow=False
    #pltshow=True
    #
    opti(yr, mth, inputfile, route, level, maxiter, method, precision, path_iagos, path_ERA5, pathout, nbmeters, nbest, airspeed, disp, pltshow, Dt_ERA)
    #
    print('END OF MAIN ROUTINE')

if __name__ == "__main__":
    main()

ModuleNotFoundError: No module named 'FlightTrajectories.misc_geo'

In [None]:
import os
import sys
proj_path1 = os.path.abspath(os.path.join("FlightTrajectories"))
proj_path2 = os.path.abspath(os.path.join("FlightTrajectories/FlightTrajectories"))
if proj_path1 not in sys.path:
    sys.path.append(proj_path1)
if proj_path2 not in sys.path:
    sys.path.append(proj_path2)
from FlightTrajectories.optimalrouting import ZermeloLonLat
from FlightTrajectories.misc_geo import nearest
from FlightTrajectories.minimization import cost_time, wind
import numpy as np
import time
import xarray as xr
import netCDF4
from sklearn.metrics import mean_squared_error
import pandas as pd
from geographiclib.geodesic import Geodesic
import bisect

# From the ipynb driver, user input: ###################################
# User Inputs:
start_time_str       = '2023-01-01T00:00:00Z'
stop_time_str        = '2023-12-31T23:59:59Z'
query_limit          = 15e4
send_notification    = True
make_plot            = True
output_dir           = "/scratch/omg28/Data/"

# Convert start and stop times to datetime objects
start_time_simple = pd.to_datetime(start_time_str).strftime("%Y-%m-%d")
stop_time_simple = pd.to_datetime(stop_time_str).strftime("%Y-%m-%d")
analysis_year = pd.to_datetime(start_time_str).year

# Define grid
lat_bins = np.arange(-90, 90.1, 0.5)
lon_bins = np.arange(-180, 180.1, 0.5)
alt_bins_ft = np.arange(0, 55001, 1000)
alt_bins_m = alt_bins_ft * 0.3048
nlat, nlon, nalt = len(lat_bins)-1, len(lon_bins)-1, len(alt_bins_m)-1
########################################################################

cruise_alt_ft = 35000  # in feet
cruise_speed = 250 #m/s

era5_file = f"/scratch/omg28/Data/winddb/era5_wind_{analysis_year}.nc"
ds_era5 = xr.open_dataset(era5_file).load()


<xarray.Dataset> Size: 498MB
Dimensions:         (valid_time: 12, pressure_level: 5, latitude: 721,
                     longitude: 1440)
Coordinates:
    number          int64 8B 0
  * valid_time      (valid_time) datetime64[ns] 96B 2023-01-01 ... 2023-12-01
  * pressure_level  (pressure_level) float64 40B 300.0 250.0 225.0 200.0 175.0
  * latitude        (latitude) float64 6kB 90.0 89.75 89.5 ... -89.75 -90.0
  * longitude       (longitude) float64 12kB 0.0 0.25 0.5 ... 359.2 359.5 359.8
    expver          (valid_time) <U4 192B '0001' '0001' '0001' ... '0001' '0001'
Data variables:
    u               (valid_time, pressure_level, latitude, longitude) float32 249MB ...
    v               (valid_time, pressure_level, latitude, longitude) float32 249MB ...
Attributes:
    GRIB_centre:             ecmf
    GRIB_centreDescription:  European Centre for Medium-Range Weather Forecasts
    GRIB_subCentre:          0
    Conventions:             CF-1.7
    institution:             European C

In [18]:
month_winds = ds_era5.sel(valid_time= start_time_simple, pressure_level=200, method='nearest')
print(month_winds)

<xarray.Dataset> Size: 8MB
Dimensions:         (latitude: 721, longitude: 1440)
Coordinates:
    number          int64 8B 0
    valid_time      datetime64[ns] 8B 2023-01-01
    pressure_level  float64 8B 200.0
  * latitude        (latitude) float64 6kB 90.0 89.75 89.5 ... -89.75 -90.0
  * longitude       (longitude) float64 12kB 0.0 0.25 0.5 ... 359.2 359.5 359.8
    expver          <U4 16B '0001'
Data variables:
    u               (latitude, longitude) float32 4MB 9.155e-05 ... 9.155e-05
    v               (latitude, longitude) float32 4MB -0.0004425 ... -0.0004425
Attributes:
    GRIB_centre:             ecmf
    GRIB_centreDescription:  European Centre for Medium-Range Weather Forecasts
    GRIB_subCentre:          0
    Conventions:             CF-1.7
    institution:             European Centre for Medium-Range Weather Forecasts


In [None]:
import os
import sys

proj_path1 = os.path.abspath(os.path.join("FlightTrajectories"))
proj_path2 = os.path.abspath(os.path.join("FlightTrajectories/FlightTrajectories"))
if proj_path1 not in sys.path:
    sys.path.append(proj_path1)
if proj_path2 not in sys.path:
    sys.path.append(proj_path2)
import FlightTrajectories.optim_iagos_only

FlightTrajectories.optim_iagos_only.main()

In [None]:


monthly_flights = pd.read_pickle(f'{output_dir}/2023_01_01_to_{stop_time_simple_loop}_labeled.pkl').iloc[244256]

NameError: name 'pd' is not defined