In [None]:
import os, sys, glob

# Add the repository root (three directory levels above the current working directory) to sys.path so the local `wp4` imports below resolve
if os.path.dirname(os.path.dirname(os.path.dirname(os.getcwd()))) not in sys.path:
    sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.getcwd()))))

# local imports
from wp4.constants import POLLUTANTS, DATA_DIR_CAMS, DB_HOST, DB_NAME, DB_USER, DB_PASS
from wp4.baseline.spatial import get_spatial_baseline

# import remaining packages needed for the script
import xarray as xr
import pandas as pd
import psycopg2 
from datetime import datetime, timedelta

import cartopy.crs as ccrs
import matplotlib.pyplot as plt

import plotly.io as pio
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px

## Connect to the Postgres database and load the Aqua/Terra fire events

In [None]:
# Load the Aqua/Terra fire events from Postgres into `df_fire_events`.
# The connection is closed in a `finally` so a failing query cannot leak it.
# (psycopg2's `with conn:` only scopes a transaction; it does NOT close the connection.)
query = """
    SELECT id, datetime, ST_X(geometry), ST_Y(geometry), source, location, reference, type, info
    FROM public.fire_events
    WHERE reference IN ('Aqua', 'Terra')
"""

conn = psycopg2.connect(dbname=DB_NAME, user=DB_USER, password=DB_PASS, host=DB_HOST)
try:
    # Give the unaliased ST_X / ST_Y result columns (st_x / st_y) readable names.
    df_fire_events = pd.read_sql_query(query, con=conn).rename(
        columns={'st_x': 'longitude', 'st_y': 'latitude'}
    )
finally:
    conn.close()

In [None]:
def create_plot(df_baseline, name_pollutant, file_name=None, rolling_window=3):
    """Build a time-series figure comparing the spatial baseline with the fire-event cell.

    Draws the 25th-75th percentile band of the spatial baseline (filled), the
    baseline median, and the concentration of the CAMS Regional Analysis cell
    closest to the fire event. Every series is smoothed with a trailing rolling
    mean (``min_periods=1`` so the first rows are not dropped).

    Parameters
    ----------
    df_baseline : pandas.DataFrame
        Must contain the columns ``time``, ``spatial_baseline_lower_quartile``,
        ``spatial_baseline_upper_quartile``, ``spatial_baseline_median`` and
        ``fire_event``.
    name_pollutant : str
        Human-readable pollutant name used in the y-axis title.
    file_name : str or path-like, optional
        If given, the figure is also written to this path. ``.html`` / ``.htm``
        are written with ``Figure.write_html``. Any other extension goes through
        ``Figure.write_image``, which needs the optional ``kaleido`` package.
    rolling_window : int, default 3
        Window length (in rows) of the smoothing rolling mean.

    Returns
    -------
    plotly.graph_objects.Figure
        The assembled figure; the caller decides whether to ``.show()`` it.
    """
    def smooth(column):
        # Trailing rolling mean; min_periods=1 keeps the first rows instead of NaN.
        return df_baseline[column].rolling(window=rolling_window, min_periods=1).mean()

    band_color = "#29bf12"
    median_color = "#48cae4"
    fire_event_color = "#f21b3f"

    fig = go.Figure()
    fig.update_xaxes(title_text="Date")
    fig.update_yaxes(title_text=f"{name_pollutant} Concentration µg m<sup>-3</sup>")
    fig.update_layout(
        template=pio.templates["seaborn"],
        autosize=False,
        width=1000,
        height=500,
        # place the legend below the plotting area
        legend=dict(yanchor="top", y=-0.2, xanchor="left", x=0.01),
    )

    x = df_baseline['time']

    # Trace order matters: fill='tonexty' on the 75th percentile fills down to the
    # trace added immediately before it (the 25th percentile), forming the band.
    fig.add_trace(go.Scatter(
        x=x,
        y=smooth('spatial_baseline_lower_quartile'),
        mode='lines',
        name='25th percentile',
        line={'color': band_color},
    ))
    fig.add_trace(go.Scatter(
        x=x,
        y=smooth('spatial_baseline_upper_quartile'),
        fill='tonexty',
        mode='lines',
        name='75th percentile',
        line={'color': band_color},
    ))
    fig.add_trace(go.Scatter(
        x=x,
        y=smooth('spatial_baseline_median'),
        mode='lines',
        name='Median',
        line={'color': median_color},
    ))
    fig.add_trace(go.Scatter(
        x=x,
        y=smooth('fire_event'),
        mode='lines',
        name='CAMS Regional Analysis cell closest to fire event',
        line={'color': fire_event_color},
    ))

    if file_name is not None:
        extension = os.path.splitext(os.fspath(file_name))[1].lower()
        if extension in ('.html', '.htm'):
            fig.write_html(file_name)
        else:
            # Static export (png, svg, pdf, ...) requires the `kaleido` package.
            fig.write_image(file_name)

    return fig

def create_nn_map(ds_fe, df_nn, df_fe_information, name_pollutant,
                  central_longitude=8.7, central_latitude=49.9):
    """Map the time-maximum concentration together with the neighbour cells and the fire event.

    Parameters
    ----------
    ds_fe : xarray.Dataset
        Dataset with a ``time`` dimension, a ``longitude`` coordinate and the
        variable ``name_pollutant``. NOTE(review): longitudes are presumably in
        the 0-360 convention (the original code subtracted 360) — confirm.
    df_nn : pandas.DataFrame
        Neighbour cells with ``longitude`` / ``latitude`` columns (orange dots).
    df_fe_information : pandas.DataFrame
        Fire-event information; the first row's ``longitude`` / ``latitude`` is
        marked with a red cross.
    name_pollutant : str
        Name of the variable in ``ds_fe`` to draw.
    central_longitude, central_latitude : float
        Centre of the orthographic projection. The defaults keep the original view.
    """
    ds = ds_fe.max(dim='time')
    # Wrap longitudes into [-180, 180). Values in (180, 360) become lon - 360 exactly
    # as before. Values in [0, 180) stay put instead of being pushed to [-360, -180),
    # which also keeps the coordinate monotonic across the Greenwich meridian.
    ds['longitude'] = (ds['longitude'].values + 180) % 360 - 180

    # Create the projected axes directly. Calling plt.axes() after plt.subplots() used
    # to leave an empty cartesian axes underneath the map.
    projection = ccrs.Orthographic(central_longitude, central_latitude)
    fig, ax = plt.subplots(figsize=(13, 13), subplot_kw={'projection': projection})
    ax.coastlines()

    ds[name_pollutant].plot(
        ax=ax,
        transform=ccrs.PlateCarree(),
        robust=True,
        extend='neither',
        facecolor="gray",
    )

    # One vectorised scatter for all neighbour cells instead of one artist per row.
    ax.scatter(
        df_nn['longitude'],
        df_nn['latitude'],
        transform=ccrs.PlateCarree(),
        marker="o",
        c='orange',
        s=15,
        alpha=1,
        label='Neighbour cells',
    )

    ax.scatter(
        df_fe_information['longitude'].iloc[0],
        df_fe_information['latitude'].iloc[0],
        transform=ccrs.PlateCarree(),
        marker="X",
        c='red',
        s=25,
        alpha=1,
        label='Fire event',
    )

    ax.legend(loc='upper right')
    plt.show()

In [None]:
# Number of fire events (taken from the top of df_fire_events) to process;
# set to None to process every event.
N_FIRE_EVENTS = 1

pollutant = 'PM10'
name_pollutant = POLLUTANTS[pollutant]['FULL_NAME']
name_pollutant_var = POLLUTANTS[pollutant]['CAMS']

fire_events_to_process = (
    df_fire_events if N_FIRE_EVENTS is None else df_fire_events.head(N_FIRE_EVENTS)
)

for ind, fe in fire_events_to_process.iterrows():
    # Best effort: a failing event is reported and skipped so the loop keeps going.
    try:
        df_baseline, ds_fe, df_nn, df_fe_information = get_spatial_baseline(
            fe['latitude'],
            fe['longitude'],
            fe['datetime'],
            5,  # NOTE(review): meaning defined by get_spatial_baseline — consider a keyword argument
            pollutant,
            meteo_dataset='MERA',
            min_distance_km=20,
            max_distance_km=75,
            number_of_neighbours=50,
            mask_ocean=True,
            # optional, currently disabled: upwind_downwind='downwind',
        )

        plot = create_plot(df_baseline, name_pollutant)
        plot.show()

        create_nn_map(ds_fe, df_nn, df_fe_information, name_pollutant_var)

    except Exception as e:
        # Include the exception type (a bare KeyError message is just the key) and
        # the database id so the failing event can be looked up.
        print(f'Skipping fire {ind} (id={fe["id"]}) because of the following error: '
              f'{type(e).__name__}: {e}')