In [None]:
import os, sys, glob
from pathlib import Path

# Put the package root -- three directory levels above the notebook's working
# directory -- on the import path so that the local `wp4` imports below resolve.
_package_root = os.getcwd()
for _level in range(3):
    _package_root = os.path.dirname(_package_root)
if _package_root not in sys.path:
    sys.path.append(_package_root)

# local imports 
from wp4.constants import POLLUTANTS, DATA_DIR_CAMS_AN, DATA_DIR_GFAS, DATA_DIR_PLOTS, DB_HOST, DB_NAME, DB_USER, DB_PASS
from wp4.baseline.spatial import get_spatial_baseline
from wp4.baseline.temporal import get_temporal_baseline
from wp4.baseline.spatiotemporal import get_spatiotemporal_baseline

# import remaining packages needed for the script
import xarray as xr
import pandas as pd
import psycopg2 
from datetime import datetime, timedelta

import cartopy.crs as ccrs
import matplotlib.pyplot as plt

import plotly.io as pio
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px

In [None]:
MIN_FRP = 50  # minimum fire radiative power for an event to be included

# Load MODIS (Aqua/Terra) fire events whose FRP exceeds MIN_FRP.
# The satellite condition must be grouped (done here with IN): in SQL, AND binds
# tighter than OR, so `reference = 'Aqua' OR reference = 'Terra' AND frp > 50`
# would keep every Aqua event regardless of its FRP. `frp > ...` already
# excludes NULL FRP values, so no separate IS NOT NULL check is needed.
# ORDER BY makes the row order deterministic, which matters because later cells
# pick events with `.tail(...)`.
query = """
    SELECT id, datetime, ST_X(geometry), ST_Y(geometry), source, location, reference, type, info, frp
    FROM public.all_fire_events
    WHERE reference IN ('Aqua', 'Terra') AND frp > %(min_frp)s
    ORDER BY datetime, id
"""

conn = psycopg2.connect(dbname=DB_NAME, user=DB_USER, password=DB_PASS, host=DB_HOST)
try:
    df_fire_events = (
        pd.read_sql_query(query, con=conn, params={'min_frp': MIN_FRP})
        # the unaliased ST_X / ST_Y expressions arrive as columns st_x / st_y
        .rename(columns={'st_x': 'longitude', 'st_y': 'latitude'})
    )
finally:
    # release the connection even when the query fails
    conn.close()

In [None]:
def _add_baseline_traces(fig, df, kind, label, colors, smoothing_window):
    """Add one baseline to `fig` as a 25th-75th percentile band plus a median line.

    `df` must contain a 'time' column and the columns
    '{kind}_baseline_lower_quartile', '{kind}_baseline_upper_quartile' and
    '{kind}_baseline_median'; `colors` must contain '{kind}_range' and
    '{kind}_median'. `label` is appended to the legend entries.
    """
    def smooth(column):
        # Trailing rolling mean; min_periods=1 keeps the first samples instead of NaN.
        return df[column].rolling(window=smoothing_window, min_periods=1).mean()

    fig.add_trace(go.Scatter(
        x=df['time'],
        y=smooth(f'{kind}_baseline_lower_quartile'),
        mode='lines',
        name=f'25th percentile ({label})',
        line={'color': colors[f'{kind}_range']},
    ))
    # fill='tonexty' shades down to the previously added trace (the lower
    # quartile above), so these two traces must be added back to back.
    fig.add_trace(go.Scatter(
        x=df['time'],
        y=smooth(f'{kind}_baseline_upper_quartile'),
        fill='tonexty',
        mode='lines',
        name=f'75th percentile ({label})',
        line={'color': colors[f'{kind}_range']},
    ))
    fig.add_trace(go.Scatter(
        x=df['time'],
        y=smooth(f'{kind}_baseline_median'),
        mode='lines',
        name=f'Median ({label})',
        line={'color': colors[f'{kind}_median']},
    ))


def create_plot(name_pollutant, df_fire_event, df_temporal=None, df_spatial=None, df_sptp=None,
                second_y=False, file_name=None, smoothing_window=3):
    """Build a plotly time-series figure comparing the fire-event series with baselines.

    Every series is smoothed with a trailing rolling mean of `smoothing_window`
    samples. Each provided baseline is drawn as a shaded 25th-75th percentile
    band plus a median line.

    Parameters
    ----------
    name_pollutant : str
        Pollutant display name, used in the primary y-axis title.
    df_fire_event : pandas.DataFrame
        Columns 'time' and 'fire_event' (concentration at the grid cell closest
        to the fire event).
    df_temporal, df_spatial, df_sptp : pandas.DataFrame, optional
        Temporal / spatial / spatio-temporal baselines. Each has a 'time' column
        and '<kind>_baseline_lower_quartile', '<kind>_baseline_upper_quartile'
        and '<kind>_baseline_median' columns, where <kind> is 'temporal',
        'spatial' or 'spatiotemporal'. Omitted baselines are not drawn.
    second_y : bool
        Create a secondary y-axis (titled "FRP") for traces the caller adds
        later with ``secondary_y=True``.
    file_name : str or os.PathLike, optional
        If given, the figure as built here is saved: '.html'/'.htm' via
        ``write_html``, any other suffix via ``write_image``. Traces the caller
        adds afterwards are not included in the saved file.
    smoothing_window : int
        Rolling-mean window length in samples (default 3).

    Returns
    -------
    plotly.graph_objects.Figure
    """
    fig = make_subplots(specs=[[{"secondary_y": second_y}]])

    fig.update_xaxes(title_text="Date")
    # The first call titles every y-axis; the second then retitles the secondary
    # axis (it selects nothing when second_y is False).
    fig.update_yaxes(title_text=f"{name_pollutant} Concentration µg m<sup>-3</sup>")
    fig.update_yaxes(title_text="FRP", secondary_y=True)

    fig.update_layout(
        template=pio.templates["seaborn"],
        autosize=False,
        width=1000,
        height=500,
    )
    # Place the legend below the plot area.
    fig.update_layout(legend=dict(yanchor="top", y=-0.2, xanchor="left", x=0.01))

    colors = {
        'temporal_median': '#38b000',
        'temporal_range': '#b5e48c',
        'spatial_median': '#5e60ce',
        'spatial_range': '#bde0fe',
        'spatiotemporal_median': '#822160',
        'spatiotemporal_range': '#F881C6',
        'fire_event': 'red',
    }

    # Trace order is significant: it fixes the legend order and what each
    # 'tonexty' fill shades down to.
    if df_temporal is not None:
        _add_baseline_traces(fig, df_temporal, 'temporal', 'Temporal', colors, smoothing_window)

    if df_spatial is not None:
        _add_baseline_traces(fig, df_spatial, 'spatial', 'Spatial', colors, smoothing_window)

    # CAMS analysis concentration at the grid cell closest to the fire event.
    fig.add_trace(go.Scatter(
        x=df_fire_event['time'],
        y=df_fire_event['fire_event'].rolling(window=smoothing_window, min_periods=1).mean(),
        mode='lines',
        name='CAMS Regional Analysis cell closest to fire event',
        line={'color': colors['fire_event']},
    ))

    if df_sptp is not None:
        _add_baseline_traces(fig, df_sptp, 'spatiotemporal', 'Spatiotemporal', colors, smoothing_window)

    if file_name is not None:
        file_path = Path(file_name)
        if file_path.suffix.lower() in ('.html', '.htm'):
            fig.write_html(file_path)
        else:
            # Static export (png/svg/pdf/...) requires plotly's image-export engine (kaleido).
            fig.write_image(file_path)

    return fig

# Value of the dataset's `ident` coordinate for each satellite name used in the
# fire-event table (presumably WMO satellite identifiers: 783 = Terra, 784 = Aqua).
_SATELLITE_IDENT = {'Aqua': 784, 'Terra': 783}


def ds_to_pandas(ds, timestamp, lat, long, days, var_name, ident=None):
    """Extract the time series of `var_name` at the grid cell nearest to (lat, long).

    Parameters
    ----------
    ds : xarray.Dataset
        Assumed to have 'time', 'latitude' and 'longitude' coordinates, plus an
        'ident' coordinate when `ident` is given.
    timestamp : datetime.datetime or pandas.Timestamp
        Centre of the time window; it is floored to the full hour first.
    lat, long : float
        Location to sample; the nearest grid cell is used. `long` must use the
        same longitude convention as `ds`.
    days : int or float
        Half-width of the time window in days.
    var_name : str
        Data variable to extract.
    ident : {'Aqua', 'Terra'} or None
        Satellite to select along the 'ident' dimension; None selects nothing.

    Returns
    -------
    (pandas.DataFrame, xarray.Dataset)
        A DataFrame with columns 'time' and `var_name`, and the time-windowed
        dataset (after any satellite selection) for further use.

    Raises
    ------
    ValueError
        If `ident` is not None and not a known satellite name.
    """
    if ident is not None:
        if ident not in _SATELLITE_IDENT:
            # Leaving the 'ident' dimension in place would silently mix the
            # series of all satellites in the output frame.
            raise ValueError(
                f"Unknown satellite {ident!r}; expected one of {sorted(_SATELLITE_IDENT)}"
            )
        ds = ds.sel(ident=_SATELLITE_IDENT[ident])

    # Floor to the full hour: clearing only the minutes would leave seconds and
    # microseconds in the slice bounds and could drop the first hourly sample.
    centre = timestamp.replace(minute=0, second=0, microsecond=0)
    half_window = timedelta(days=days)
    ds_time = ds.sel(time=slice(centre - half_window, centre + half_window))

    ds_loc = ds_time.sel(latitude=lat, longitude=long, method='nearest')
    df = ds_loc.to_dataframe().reset_index()[['time', var_name]]

    return df, ds_time

def map_ds(ds_fe, pol_name, fe_long, fe_lat, title=None,
           central_longitude=8.7, central_latitude=49.9):
    """Map the time-mean of `pol_name` on an orthographic projection and mark the fire.

    Parameters
    ----------
    ds_fe : xarray.Dataset
        Dataset with a 'time' dimension and a 'longitude' coordinate.
        NOTE(review): longitudes are assumed to be stored in the 0..360
        convention; they are shifted by -360 before plotting -- confirm for the
        source dataset.
    pol_name : str
        Data variable to plot.
    fe_long, fe_lat : float
        Fire-event location in plain lon/lat degrees; drawn as a red 'X'.
    title : str, optional
        Axes title. None clears the title xarray sets automatically.
    central_longitude, central_latitude : float
        Centre of the orthographic projection (defaults as before).

    Returns
    -------
    matplotlib.figure.Figure
        The figure is shown but not closed; the caller should close it.
    """
    ds = ds_fe.mean(dim='time')
    ds['longitude'] = ds['longitude'].values - 360

    # Create the GeoAxes directly. The previous version called plt.subplots()
    # and then plt.axes(projection=...), which left a stray plain axes
    # underneath the map on the same figure.
    fig, ax = plt.subplots(
        figsize=(13, 13),
        subplot_kw={'projection': ccrs.Orthographic(central_longitude, central_latitude)},
    )
    ax.coastlines()

    ds[pol_name].plot(
        ax=ax,
        transform=ccrs.PlateCarree(),
        robust=True,
        extend='neither',
        facecolor="gray",
    )

    ax.scatter(
        fe_long,
        fe_lat,
        transform=ccrs.PlateCarree(),
        marker="X",
        c='red',
        s=25,
        alpha=1,
    )

    ax.set_title(title)

    plt.show()

    return fig

In [None]:
# For each selected fire event, plot the concentration at the fire location
# against the temporal / spatial / spatio-temporal baselines together with the
# satellite FRP, and map the pollutant field for the 12 hours after the event.
pollutant = 'PM10'
name_pollutant = POLLUTANTS[pollutant]['FULL_NAME']
name_pollutant_var = POLLUTANTS[pollutant]['CAMS']
ds_wfpm10 = xr.open_dataset(Path(DATA_DIR_CAMS_AN).joinpath('pmwf_conc.nc'))
ds_frp = xr.open_dataset(Path(DATA_DIR_GFAS).joinpath('frp.nc'))

DAYS = 7                    # half-width (days) of the time window around each fire event
N_FIRE_EVENTS = 40          # number of events taken from the end of df_fire_events
MAP_HOURS = 12              # hours after the event averaged in the map
PLOT_WILDFIRE_PM10 = False  # also overlay the CAMS wildfire PM10 at the fire location
SKIP_FAILED_FIRES = False   # True: report and continue on error; False: re-raise

# Keyword arguments shared by the spatial and spatio-temporal baselines.
NEIGHBOUR_KWARGS = dict(
    meteo_dataset='MERA',
    min_distance_km=50,
    max_distance_km=250,
    number_of_neighbours=50,
    mask_ocean=True,
)

output_loc_fig = Path(DATA_DIR_PLOTS).joinpath('notebooks/baseline_vs_FRP/plot/')
output_loc_map = Path(DATA_DIR_PLOTS).joinpath('notebooks/baseline_vs_FRP/map/')
output_loc_fig.mkdir(parents=True, exist_ok=True)
output_loc_map.mkdir(parents=True, exist_ok=True)

for ind, fe in df_fire_events.tail(N_FIRE_EVENTS).iterrows():

    try:
        fe_time = fe['datetime']
        # Start of the hour containing the event (seconds/microseconds cleared too).
        fe_hour = fe_time.replace(minute=0, second=0, microsecond=0)
        file_stem = f"{pollutant.lower()}_FRP_fire_{fe['id']}_{fe_time.strftime('%m_%d_%Y')}"

        df_spatial_baseline, ds_fe, _, _ = get_spatial_baseline(
            fe_lat=fe['latitude'],
            fe_long=fe['longitude'],
            timestamp=fe_time,
            pollutant=pollutant,
            days=DAYS,
            **NEIGHBOUR_KWARGS,
        )

        df_temporal_baseline = get_temporal_baseline(
            fe_lat=fe['latitude'],
            fe_long=fe['longitude'],
            timestamp=fe_time,
            days=DAYS,
            pollutant=pollutant,
        )

        # Optionally pass upwind_downwind='upwind' to restrict the neighbours.
        df_sptp_baseline = get_spatiotemporal_baseline(
            fe_lat=fe['latitude'],
            fe_long=fe['longitude'],
            timestamp=fe_time,
            days=DAYS,
            pollutant=pollutant,
            **NEIGHBOUR_KWARGS,
        )

        df_fire_event = df_spatial_baseline[['time', 'fire_event']]

        fig_ts = create_plot(
            name_pollutant,
            df_fire_event,
            df_temporal=df_temporal_baseline,
            df_spatial=df_spatial_baseline,
            df_sptp=df_sptp_baseline,
            second_y=True,
        )

        if PLOT_WILDFIRE_PM10:
            # This dataset is sampled with longitude + 360 (0..360 convention).
            df_wfpm10_fe, _ = ds_to_pandas(
                ds_wfpm10, fe_time, fe['latitude'], 360 + fe['longitude'], DAYS, 'pmwf_conc'
            )
            fig_ts.add_trace(go.Scatter(
                x=df_wfpm10_fe['time'],
                y=df_wfpm10_fe['pmwf_conc'].rolling(window=3, min_periods=1).mean(),
                mode='lines',
                name='Wildfire PM10',
                line={'color': 'orange'},
            ))

        # FRP at the nearest grid cell, for the satellite that reported the event.
        df_frp_fe, _ = ds_to_pandas(
            ds_frp, fe_time, fe['latitude'], fe['longitude'], DAYS, 'frpfire', ident=fe['reference']
        )

        # Avoid a degenerate [0, 0] or [0, nan] range when there is no positive FRP.
        max_frp = df_frp_fe['frpfire'].max()
        if pd.isna(max_frp) or max_frp <= 0:
            max_frp = 1

        fig_ts.update_yaxes(range=[0, max_frp], secondary_y=True)

        fig_ts.add_trace(
            go.Scatter(
                x=df_frp_fe['time'],
                y=df_frp_fe['frpfire'].rolling(window=3, min_periods=1).mean(),
                mode='lines',
                name='FRP',
                line={'color': 'black'},
            ),
            secondary_y=True,
        )

        # Mark the fire-event time.
        fig_ts.add_vline(x=fe_time, line_width=3, line_dash="dash", line_color="orange")

        fig_ts.show()
        fig_ts.write_html(output_loc_fig.joinpath(f"{file_stem}.html"))

        ds_fe_window = ds_fe.sel(time=slice(fe_hour, fe_hour + timedelta(hours=MAP_HOURS)))

        fig_map = map_ds(ds_fe_window, name_pollutant_var, fe_long=fe['longitude'], fe_lat=fe['latitude'])
        fig_map.savefig(output_loc_map.joinpath(f"{file_stem}.png"))
        # Release the matplotlib figure; otherwise every iteration keeps one alive.
        plt.close(fig_map)

    except Exception as e:
        if not SKIP_FAILED_FIRES:
            raise
        print(f'Skipping fire {ind} because of the following error: {e}')