In [None]:
import os, sys, glob, zipfile
from pathlib import Path

# Make the package root (three directory levels above the notebook's working directory)
# importable, so that the local `wp4` imports below resolve.
_package_root = os.path.dirname(os.path.dirname(os.path.dirname(os.getcwd())))
if _package_root not in sys.path:
    sys.path.append(_package_root)
del _package_root

# local imports 
from wp4.constants import POLLUTANTS, DATA_DIR_CAMS_AN, DATA_DIR_CAMS_RE, DATA_DIR_GFAS, DATA_DIR_PLOTS, DB_HOST, DB_NAME, DB_USER, DB_PASS, ADS_URL, ADS_KEY, EXTENTS 
from wp4.baseline.spatial import get_spatial_baseline
from wp4.baseline.temporal import get_temporal_baseline
from wp4.processing.ground_stations import get_closest_ground_station_historical_data

# import remaining packages needed for the script
import xarray as xr
import pandas as pd
import psycopg2
from datetime import datetime, timedelta

import cartopy.crs as ccrs
import matplotlib.pyplot as plt

import plotly.io as pio
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px

import cdsapi

import warnings

# Filters out a warning that pops up when calling the ADS api when using the credentials as parameters
warnings.filterwarnings("ignore", message="Unverified HTTPS request is being made to host ")

## Set the search/plotting criteria

In [None]:
# Time window from which fire events are selected in the database
time_window = dict(
    start=datetime(2018, 5, 1, 0),
    end=datetime(2018, 6, 30, 23),
)

# Number of days of data shown on either side of each fire event: with 7, the graphs
# start 7 days before the event and end 7 days after it
DAYS = 7

# Datasets to include in the graphs (set a flag to True to include that dataset)
FORECAST = False  # European Air Quality Forecast
FC_LEADTIME = 24  # Hours into the future taken from each forecast; 24 hours gives a smooth line
ANALYSIS = False  # European Air Quality NRT Analysis (NOTE(review): not read by the plotting cell, which always plots the analysis)
REANALYSIS = True  # European Air Quality Reanalysis (2018 only)
GROUND_STATIONS = False  # Measurements from the closest ground stations
FRP = True  # GFAS fire radiative power, drawn on a secondary y-axis
PM10_WILDFIRES = False  # Wildfire PM10 concentration
TEMPORAL_BASELINE = False  # Temporal baseline (median and 25th/75th percentiles)
SPATIAL_BASELINE_AN = False  # Spatial baseline (median and 25th/75th percentiles)
SPATIAL_BASELINE_RE = False  # Reanalysis spatial baseline (NOTE(review): not read by the plotting cell)

## Load the fire events from the Flares2 database for the given time window

In [None]:
# Select Aqua/Terra fire events that have an FRP value and fall inside the configured time window.
# The bounds are passed as query parameters rather than formatted into the SQL text. The window is
# hour-granular, so the upper bound is exclusive and set one hour after time_window['end']: every
# event during the final hour (and the whole final day) is included.
query = """
    SELECT id, datetime, ST_X(geometry), ST_Y(geometry), source, location, reference, type, info, frp
    FROM public.all_fire_events
    WHERE (reference = 'Aqua' OR reference = 'Terra') AND "frp" IS NOT NULL
    AND (datetime >= %(start)s AND datetime < %(end)s);
"""
query_params = {
    'start': time_window['start'],
    'end': time_window['end'] + timedelta(hours=1),
}

conn = psycopg2.connect(dbname=DB_NAME, user=DB_USER, password=DB_PASS, host=DB_HOST)
try:
    df_fire_events = (
        pd.read_sql_query(query, con=conn, params=query_params)
        # PostGIS names the coordinate columns after the functions that produced them
        .rename(columns={'st_x': 'longitude', 'st_y': 'latitude'})
    )
finally:
    # Close the connection even when the query fails
    conn.close()

if len(df_fire_events) == 0:
    print('No fire events found matching your search parameters')
else:
    print(f'Fire events loaded from database: {len(df_fire_events)}')

In [None]:
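# Helper functions: add baseline/fire-event traces to a subplot, extract the time series nearest to a fire event, and convert forecast time offsets to datetimes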

def create_plot(fig, col, row, name_pollutant, df_fire_event, df_temporal=None,
                df_spatial=None, second_y=False, file_name=None, show_legend=False):
    """Add baseline and fire-event concentration traces for one pollutant to a subplot.

    Every series is smoothed with a 3-sample trailing rolling mean before plotting. All traces
    added here honour ``show_legend``, so when the caller enables it for a single subplot each
    series appears only once in the figure legend. Legend groups let one legend click toggle the
    matching traces in every subplot.

    Args:
        fig: Plotly figure built with ``make_subplots``; traces are added to it in place.
        col: Subplot column (1-based) to draw in.
        row: Subplot row (1-based) to draw in.
        name_pollutant: Pollutant name. Currently unused; kept for interface compatibility.
        df_fire_event: DataFrame with ``time`` and ``fire_event`` columns (concentration at the
            grid cell closest to the fire event).
        df_temporal: Optional DataFrame with ``time``, ``temporal_baseline_lower_quartile``,
            ``temporal_baseline_upper_quartile`` and ``temporal_baseline_median`` columns.
        df_spatial: Optional DataFrame with ``time``, ``spatial_baseline_lower_quartile``,
            ``spatial_baseline_upper_quartile`` and ``spatial_baseline_median`` columns.
        second_y: Currently unused; concentrations are always drawn on the primary y-axis.
        file_name: Currently ignored (placeholder for saving the figure).
        show_legend: Whether the traces added for this subplot are listed in the legend.

    Returns:
        The same ``fig`` object.
    """
    colors = {
        'temporal_median': '#38b000',
        'temporal_range': '#b5e48c',
        'spatial_median': '#5e60ce',
        'spatial_range': '#bde0fe',
        'fire_event': '#e31a1c',
    }

    def _add_line(df, column, name, color, legendgroup, fill=None):
        # One smoothed line trace in the target subplot
        fig.add_trace(
            go.Scatter(
                x=df['time'],
                y=df[column].rolling(window=3, min_periods=1).mean(),
                mode='lines',
                fill=fill,
                name=name,
                legendgroup=legendgroup,
                showlegend=show_legend,
                line={'color': color},
            ),
            col=col,
            row=row,
        )

    # Each 75th-percentile trace uses fill='tonexty' to shade down to the trace added just before
    # it (the matching 25th percentile), so the lower quartile must be added first.
    if df_temporal is not None:
        _add_line(df_temporal, 'temporal_baseline_lower_quartile', '25th percentile (Temporal)',
                  colors['temporal_range'], 'group_temporal')
        _add_line(df_temporal, 'temporal_baseline_upper_quartile', '75th percentile (Temporal)',
                  colors['temporal_range'], 'group_temporal', fill='tonexty')
        _add_line(df_temporal, 'temporal_baseline_median', 'Median (Temporal)',
                  colors['temporal_median'], 'group_temporal')

    if df_spatial is not None:
        _add_line(df_spatial, 'spatial_baseline_lower_quartile', '25th percentile (Spatial)',
                  colors['spatial_range'], 'group_baseline')
        _add_line(df_spatial, 'spatial_baseline_upper_quartile', '75th percentile (Spatial)',
                  colors['spatial_range'], 'group_baseline', fill='tonexty')
        _add_line(df_spatial, 'spatial_baseline_median', 'Median (Spatial)',
                  colors['spatial_median'], 'group_baseline')

    # CAMS analysis concentration at the grid cell closest to the fire event
    _add_line(df_fire_event, 'fire_event', 'CAMS Regional Analysis cell closest to fire event',
              colors['fire_event'], 'group_fe')

    return fig

def ds_to_pandas(ds, timestamp, lat, long, days, var_name, ident=None):
    """Extract one variable's time series at the grid cell nearest to a location.

    Args:
        ds: Dataset with ``time``, ``latitude`` and ``longitude`` coordinates (and an ``ident``
            coordinate when ``ident`` is given).
        timestamp: Centre of the time window; its minutes are zeroed before use.
        lat: Latitude of the point of interest.
        long: Longitude of the point of interest.
        days: Half-width of the time window, in days.
        var_name: Name of the variable to extract.
        ident: Optional satellite name; ``'Aqua'`` selects ident 784 and ``'Terra'`` ident 783.
            Any other value leaves the dataset unfiltered.

    Returns:
        A tuple ``(df, ds_time)``: ``df`` has the columns ``['time', var_name]`` for the nearest
        grid cell; ``ds_time`` is the dataset restricted to the time window (all grid cells).
    """
    ident_code = next(
        (code for satellite, code in (('Aqua', 784), ('Terra', 783)) if ident == satellite),
        None,
    )
    if ident_code is not None:
        ds = ds.sel(ident=ident_code)

    centre = timestamp.replace(minute=0)
    half_window = timedelta(days=days)
    ds_time = ds.sel(time=slice(centre - half_window, centre + half_window))

    nearest_cell = ds_time.sel(latitude=lat, longitude=long, method='nearest')
    frame = nearest_cell.to_dataframe().reset_index()

    return frame[['time', var_name]], ds_time

def _to_datetime(dataset: xr.Dataset):
    """Convert the time column of newly downloaded CAMS analysis data into datetime objects"""

    # Strip the start date from the ANALYSIS attribute
    start_date = datetime.strptime(dataset.FORECAST[:16], 'Europe, %Y%m%d')

    def convert_to_datetime(x):  # inner function used to add the timedelta to the start date
        return start_date + x

    dataset['time'] = dataset.time.to_pandas().apply(convert_to_datetime)  # convert the time parameter to datetime

    return dataset

In [None]:
# Local path of the forecast download: lead times above 24 hours are handled as a zip archive,
# otherwise as a single NetCDF file. The path is under 'temp/', where the download cell writes it.
# This lets the loading cell reuse an earlier download without re-running the retrieval.
if FC_LEADTIME > 24:
    filename = 'temp/forecast.zip'
else:
    filename = 'temp/forecast.nc'

In [None]:
# Download the CAMS European air quality forecast from the ADS for the fire-event time window,
# extended by DAYS on either side. The download is written to 'temp/'; the next cell extracts it
# (for zip downloads) and opens it, so this cell only performs the retrieval.

if FORECAST:

    os.makedirs('temp', exist_ok=True)

    c = cdsapi.Client(
        url=ADS_URL,
        key=ADS_KEY
    )

    start = (time_window['start'] - timedelta(days=DAYS)).strftime('%Y-%m-%d')
    end = (time_window['end'] + timedelta(days=DAYS)).strftime('%Y-%m-%d')

    # Lead times above 24 hours are handled as a zip archive, otherwise as a single NetCDF file
    if FC_LEADTIME > 24:
        filename = 'temp/forecast.zip'
    else:
        filename = 'temp/forecast.nc'

    c.retrieve(
        'cams-europe-air-quality-forecasts',
        {
            'variable': ['carbon_monoxide', 'nitrogen_dioxide', 'nitrogen_monoxide',
                         'ozone', 'particulate_matter_10um', 'particulate_matter_2.5um',
                         'sulphur_dioxide',
                        ],
            'model': 'ensemble',
            'level': '0',
            'date': f"{start}/{end}",
            'type': 'forecast',
            'time': '00:00',
            # Lead hours 0 .. FC_LEADTIME - 1 of each daily 00:00 forecast run
            'leadtime_hour': [str(hour) for hour in range(0, FC_LEADTIME)],
            'area': EXTENTS['IRELAND']['LIST'],
            'format': 'netcdf',
        },
        filename)

In [None]:
# Open the downloaded forecast. For lead times above 24 hours the download is a zip archive of
# NetCDF files (presumably one per forecast run). Only the .nc files listed in that archive are
# opened, so stale .nc files left in 'temp/' by earlier runs are not mixed into the collection.
if FORECAST:
    if FC_LEADTIME > 24:
        with zipfile.ZipFile(filename, 'r') as zip_ref:
            nc_members = sorted(name for name in zip_ref.namelist() if name.endswith('.nc'))
            zip_ref.extractall('temp/')

        forecast_files = [os.path.join('temp', name) for name in nc_members]
        forecast_collection = [_to_datetime(xr.open_dataset(x)) for x in forecast_files]
    else:
        ds_forecast = _to_datetime(xr.open_dataset(filename))

In [None]:
# For every fire event, build one subplot per pollutant. Each subplot compares the CAMS NRT analysis
# at the grid cell closest to the fire with the optional datasets enabled in the configuration cell.
# The figure is then shown and saved as HTML.
# NOTE(review): the ANALYSIS and SPATIAL_BASELINE_RE flags are not read here; the NRT analysis is
# always plotted.

N_COLS = 3  # subplots per figure row
n_rows = max(1, (len(POLLUTANTS) + N_COLS - 1) // N_COLS)  # enough rows for every pollutant

# 0-based pollutant index -> 1-based subplot position, filled row by row
plot_position = {
    i: {'row': i // N_COLS + 1, 'col': i % N_COLS + 1}
    for i in range(len(POLLUTANTS))
}

if FRP:
    ds_frp = xr.open_dataset(Path(DATA_DIR_GFAS).joinpath('frp.nc'))

if PM10_WILDFIRES:
    ds_wfpm10 = xr.open_dataset(Path(DATA_DIR_CAMS_AN).joinpath('pmwf_conc.nc'))

_dataset_cache = {}


def _open_dataset_cached(path):
    """Open a NetCDF file with xarray on first use and return the same dataset afterwards."""
    if path not in _dataset_cache:
        _dataset_cache[path] = xr.open_dataset(path)
    return _dataset_cache[path]


def _smooth(series):
    """Trailing 3-sample rolling mean; min_periods=1 keeps the first samples instead of NaN."""
    return series.rolling(window=3, min_periods=1).mean()


output_loc_fig = Path(DATA_DIR_PLOTS).joinpath('notebooks/analysis_vs_reanalysis/plot/')
output_loc_fig.mkdir(parents=True, exist_ok=True)

for _, fe in df_fire_events.iterrows():

    fig = make_subplots(
        rows=n_rows,
        cols=N_COLS,
        subplot_titles=[POLLUTANTS[x]['FORMULA_HTML'] for x in POLLUTANTS],
        specs=[[{"secondary_y": FRP} for _col in range(N_COLS)] for _row in range(n_rows)],
    )

    # The closest ground stations are looked up per pollutant and can differ between subplots,
    # so their legend entries are de-duplicated by trace name.
    ground_station_legend = set()

    for pollutant_index, pollutant in enumerate(POLLUTANTS):

        row = plot_position[pollutant_index]['row']
        col = plot_position[pollutant_index]['col']

        name_pollutant = POLLUTANTS[pollutant]['FULL_NAME']
        name_pollutant_html = POLLUTANTS[pollutant]['FORMULA_HTML']
        name_pollutant_var = POLLUTANTS[pollutant]['CAMS']

        # Only the last subplot adds legend entries, so each series is listed once
        leg = pollutant_index == len(POLLUTANTS) - 1

        fig.update_xaxes(title_text="Date", col=col, row=row)
        fig.update_yaxes(title_text=f"{name_pollutant_html} Concentration (µg m<sup>-3</sup>)", col=col, row=row)

        # CAMS NRT analysis at the grid cell closest to the fire event
        ds_an = _open_dataset_cached(
            Path(DATA_DIR_CAMS_AN).joinpath(f'{pollutant}/cams_nrt_analysis_{pollutant}_2018.nc')
        )
        df_fire_event, _ = ds_to_pandas(
            ds_an,
            fe['datetime'],
            fe['latitude'],
            fe['longitude'],
            DAYS,
            name_pollutant_var
        )
        df_fire_event = df_fire_event.rename(columns={name_pollutant_var: 'fire_event'})

        if SPATIAL_BASELINE_AN:
            df_spatial_baseline, _, _, _ = get_spatial_baseline(
                fe_lat=fe['latitude'],
                fe_long=fe['longitude'],
                timestamp=fe['datetime'],
                days=DAYS,
                pollutant=pollutant,
                meteo_dataset='MERA',
                min_distance_km=50,
                max_distance_km=250,
                number_of_neighbours=50,
                mask_ocean=True,
            )
        else:
            df_spatial_baseline = None

        if TEMPORAL_BASELINE:
            df_temporal_baseline = get_temporal_baseline(
                fe_lat=fe['latitude'],
                fe_long=fe['longitude'],
                timestamp=fe['datetime'],
                days=DAYS,
                pollutant=pollutant,
            )
        else:
            df_temporal_baseline = None

        fig = create_plot(
            fig,
            col,
            row,
            name_pollutant,
            df_fire_event,
            show_legend=leg,
            df_temporal=df_temporal_baseline,
            df_spatial=df_spatial_baseline,
            second_y=FRP
        )

        if GROUND_STATIONS:
            data = get_closest_ground_station_historical_data(
                lat=fe['latitude'],
                long=fe['longitude'],
                timestamp=fe['datetime'],
                pollutant=name_pollutant_var,
                quantity=3
            )

            if data is not None:
                for gs in data:
                    df_gs = data[gs]['data']
                    name = data[gs]['name']
                    distance = data[gs]['distance']

                    # distance is presumably in metres; show kilometres above 1000
                    if distance > 1000:
                        trace_name = f'Ground Station ({name}, distance: {round(distance / 1000)}KM)'
                    else:
                        trace_name = f'Ground Station ({name}, distance: {round(distance)}M)'

                    # Measurements from a nearby ground station
                    fig.add_trace(
                        go.Scatter(
                            x=df_gs['time'],
                            y=_smooth(df_gs['ground_station_data']),
                            mode='lines',
                            name=trace_name,
                            legendgroup=trace_name,
                            showlegend=trace_name not in ground_station_legend,
                        ),
                        col=col,
                        row=row,
                    )
                    ground_station_legend.add(trace_name)

        if REANALYSIS:
            ds_re = _open_dataset_cached(
                Path(DATA_DIR_CAMS_RE).joinpath(f'{pollutant}/cams_reanalysis_{pollutant}_2018.nc')
            )
            df_re_fe, _ = ds_to_pandas(
                ds_re,
                fe['datetime'],
                fe['latitude'],
                fe['longitude'],
                DAYS,
                name_pollutant_var
            )

            # CAMS reanalysis at the grid cell closest to the fire event
            fig.add_trace(
                go.Scatter(
                    x=df_re_fe['time'],
                    y=_smooth(df_re_fe[name_pollutant_var]),
                    mode='lines',
                    name='Reanalysis',
                    legendgroup="reanalysis",
                    showlegend=leg,
                    line={'color': '#1f78b4'},
                ),
                col=col,
                row=row,
            )

        if PM10_WILDFIRES:
            df_wfpm10_fe, _ = ds_to_pandas(
                ds_wfpm10,
                fe['datetime'],
                fe['latitude'],
                fe['longitude'],
                DAYS,
                'pmwf_conc'
            )

            # Wildfire PM10 concentration at the grid cell closest to the fire event
            fig.add_trace(
                go.Scatter(
                    x=df_wfpm10_fe['time'],
                    y=_smooth(df_wfpm10_fe['pmwf_conc']),
                    mode='lines',
                    name='Wildfire PM10',
                    legendgroup="group_wf",
                    showlegend=leg,
                    line={'color': 'orange'},
                ),
                col=col,
                row=row,
            )

        if FRP:
            # FRP from the satellite that reported the fire event
            df_frp_fe, _ = ds_to_pandas(
                ds_frp,
                fe['datetime'],
                fe['latitude'],
                fe['longitude'],
                DAYS,
                'frpfire',
                ident=fe['reference']
            )

            fig.update_yaxes(title_text="FRP", secondary_y=True, col=col, row=row)

            # Fall back to [0, 1] when there is no (or only zero) FRP so the axis range stays valid
            max_frp = df_frp_fe['frpfire'].max()
            if pd.isna(max_frp) or max_frp == 0:
                max_frp = 1

            # Scope the range to this subplot's secondary axis only
            fig.update_yaxes(range=[0, max_frp], secondary_y=True, col=col, row=row)

            fig.add_trace(
                go.Scatter(
                    x=df_frp_fe['time'],
                    y=_smooth(df_frp_fe['frpfire']),
                    mode='lines',
                    name='FRP',
                    showlegend=leg,
                    legendgroup="group_frp",
                    line={'color': '#ff7f00'},
                ),
                col=col,
                row=row,
                secondary_y=True,
            )

        if FORECAST:
            if FC_LEADTIME > 24:
                # One line per forecast file, labelled with its base date
                for fc_index, fc_file in enumerate(forecast_collection):
                    df_fc_fe, _ = ds_to_pandas(
                        fc_file,
                        fe['datetime'],
                        fe['latitude'],
                        fe['longitude'],
                        DAYS,
                        name_pollutant_var
                    )

                    start_date = datetime.strptime(fc_file.FORECAST[:16], 'Europe, %Y%m%d')

                    fig.add_trace(
                        go.Scatter(
                            x=df_fc_fe['time'],
                            y=df_fc_fe[name_pollutant_var],
                            mode='lines',
                            legendgroup=f"group_{fc_index}",
                            showlegend=leg,
                            name=f'Forecast {start_date.strftime("%d %B")}',
                        ),
                        col=col,
                        row=row,
                    )
            else:
                df_fc_fe, _ = ds_to_pandas(
                    ds_forecast,
                    fe['datetime'],
                    fe['latitude'],
                    fe['longitude'],
                    DAYS,
                    name_pollutant_var
                )

                fig.add_trace(
                    go.Scatter(
                        x=df_fc_fe['time'],
                        y=df_fc_fe[name_pollutant_var],
                        mode='lines',
                        name='Forecast',
                        legendgroup="group_1",
                        showlegend=leg,
                        line={'color': 'pink'},
                    ),
                    col=col,
                    row=row,
                )

        # Mark the time of the fire event
        fig.add_vline(x=fe['datetime'], line_width=3, line_dash="dash", line_color="orange", col=col, row=row)

    fig.update_layout(
        template=pio.templates["seaborn"],
        autosize=False,
        width=1600,
        height=900,
        margin=dict(l=0, r=0, b=100, t=100, pad=4),
        legend=dict(  # horizontal legend below the subplots
            yanchor="top",
            y=-0.2,
            xanchor="left",
            x=0.01,
            orientation="h",
            font=dict(size=25, color="black"),
        ),
    )
    fig.show()

    fig.write_html(output_loc_fig.joinpath(f"analysis_vs_reanalysis_{fe['id']}_{fe['datetime'].strftime('%m_%d_%Y')}.html"))