<a href="https://colab.research.google.com/github/SM24-Industrial-Software-Dev/ML-forecasting-NOx-levels/blob/AP-28-Dash-API/Dash-App/Dash_Implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Notebook setup cell. The order matters: packages are installed, Earth Engine
# is authenticated, and the project modules are copied next to the notebook
# before they are imported.

# Install the dashboard/plotting dependencies and the sts-jax branch named in the URL.
!pip install dash plotly holidays pycrs import-ipynb
!pip install git+https://github.com/rameshnatarajanus/sts-jax@get-fit_hmc-to-work
import ee
from google.colab import userdata

# Authenticate to Earth Engine as a service account; the key material is read
# from the Colab user-data entry named 'GCP_CREDENTIALS'.
credentials = ee.ServiceAccountCredentials("yeshiva-summer-2024-1@yu-summer-2024.iam.gserviceaccount.com", key_data=userdata.get('GCP_CREDENTIALS'))
# Initialise against the high-volume Earth Engine endpoint for the project.
ee.Initialize(credentials = credentials, project='yu-summer-2024', opt_url='https://earthengine-highvolume.googleapis.com')
# Clone the project repository and copy the three helper modules into the
# working directory so the plain `from ... import ...` lines below resolve.
!git clone https://github.com/SM24-Industrial-Software-Dev/ML-forecasting-NOx-levels.git
!cp ML-forecasting-NOx-levels/Data-APIs/receive_conc_api.py .
!cp ML-forecasting-NOx-levels/Data-APIs/msa.py .
!cp ML-forecasting-NOx-levels/Data-APIs/forecaster.py .

# Project-local helpers copied above.
from msa import MSA
from receive_conc_api import get_Data
from forecaster import NOxForecaster

In [None]:
from dash import Dash, dcc, html, Input, Output
import plotly.express as px
import plotly.graph_objects as go
import holidays
import dateutil
import datetime
from datetime import datetime, timedelta

import pandas as pd
from plotly.subplots import make_subplots


# Shared cache of every observation downloaded so far; filter_data appends
# newly fetched rows to it with pd.concat.
nox_data = pd.DataFrame()
# Per-city bookkeeping of what has been downloaded: filter_data stores one
# {start_date: end_date} mapping under each city name.
fetch_history = dict()

# MSA helper from the project; its `names` attribute feeds the city dropdown.
msa = MSA()
names = msa.names
# Sorted in place so the dropdown lists the names alphabetically.
names.sort()

# The Dash application object that the layout and callback attach to.
app = Dash(__name__)

# Inline CSS passed as `selected_style` to each dcc.Tab (the active tab).
tab_selected_style = dict(
    border='2px solid #1915eb',
    backgroundColor='black',
    color='white',
    padding='6px',
)
# Inline CSS passed as `style` to each dcc.Tab (inactive tabs).
tab_style = dict(
    border='1px solid grey',
    backgroundColor='grey',
    padding='6px',
    color='#200031',
)
# Maps a holiday name (as produced by the `holidays` package, plus the custom
# "Easter" entry) to the short label drawn next to its marker on the daily
# chart. An empty string means the marker line is drawn without a label.
holiday_abbr = {
    'Independence Day': 'IND',
    'Independence Day (observed)': '',
    'Columbus Day': 'COL',
    'Veterans Day': 'VET',
    'Veterans Day (observed)': '',
    "New Year's Day": 'NYD',
    "New Year's Day (observed)": '',
    'Juneteenth National Independence Day': '',
    'Juneteenth National Independence Day (observed)': '',
    'Christmas Day (observed)': '',
    'Martin Luther King Jr. Day': 'MLK',
    "Washington's Birthday": 'PRS',
    'Memorial Day': 'MEM',
    'Labor Day': 'LAB',
    'Thanksgiving': 'THK',
    # NOTE(review): some `holidays` releases appear to name this holiday
    # 'Thanksgiving Day' -- alias kept so both spellings resolve; confirm
    # against the installed version.
    'Thanksgiving Day': 'THK',
    'Christmas Day': 'CHR',
    'Easter': 'eas',
}


def get_custom_holidays(years):
    """Return a holiday collection holding Easter for each requested year.

    Args:
        years: Iterable of integer years.

    Returns:
        A ``holidays.HolidayBase`` whose entries map each year's Easter date
        to the name ``"Easter"``.
    """
    # Import the submodule explicitly: a bare `import dateutil` does not load
    # `dateutil.easter`, so `dateutil.easter.easter` only resolves when some
    # other package happens to have imported it first.
    from dateutil.easter import easter

    custom_holidays = holidays.HolidayBase()
    custom_holidays.append({easter(year): "Easter" for year in years})
    return custom_holidays


# Page layout, top to bottom: city multi-select, date-range picker,
# Historical/Forecasts tabs, Daily/Monthly/Yearly view selector, and the graph.
_city_dropdown = dcc.Dropdown(
    id='city-dd',
    options=[{'label': city_name, 'value': city_name} for city_name in names],
    multi=True,
    placeholder="Select a City",
    style={'width': '500px', 'margin': '0 auto',
           'background-color': 'rgb(192,192,192)'},
)

_date_picker = dcc.DatePickerRange(
    id='date-range',
    start_date='2018-06-28',
    end_date='2024-06-30',  #TODO make this different for the max allowed and the initial visible month
    style={'padding': '10px 450px', 'background-color': 'rgb(119,136,153)'},
)

_graph_tabs = dcc.Tabs(
    id="tabs-for-graph",
    value='hist',
    children=[
        dcc.Tab(label='Historical Data', value='hist', style=tab_style, selected_style=tab_selected_style),
        # dcc.Tab(label='Model', value='model', style=tab_style, selected_style=tab_selected_style),
        dcc.Tab(label='Forecasts', value='forecast', style=tab_style, selected_style=tab_selected_style),
    ],
)

_view_selector = dcc.RadioItems(
    id='select-graph',
    value='daily',
    inline=True,
    options=[
        {'label': 'Daily', 'value': 'daily'},
        {'label': 'Monthly Average', 'value': 'monthly'},
        {'label': 'Yearly', 'value': 'yearly'},
    ],
)

app.layout = html.Div(
    [_city_dropdown, _date_picker, _graph_tabs, _view_selector, dcc.Graph(id='time-series')],
    style={'background-color': 'rgb(119,136,153)'},
)


def find_non_overlapping_ranges(given_start_date, given_end_date, existing_start_date, existing_end_date):
    """Return the parts of the given range that the existing range does not cover.

    Both ranges are treated as inclusive at day granularity.

    Args:
        given_start_date: First day of the requested range.
        given_end_date: Last day of the requested range.
        existing_start_date: First day of the already-covered range.
        existing_end_date: Last day of the already-covered range.

    Returns:
        A list of ``(start, end)`` tuples: the whole given range when the two
        ranges are disjoint, otherwise zero, one or two leftover pieces lying
        before and/or after the existing range.
    """
    ranges_touch = (
        existing_start_date <= given_start_date <= existing_end_date
        or given_start_date <= existing_start_date <= given_end_date
    )
    if not ranges_touch:
        return [(given_start_date, given_end_date)]

    one_day = timedelta(days=1)
    leftovers = []
    # Piece sticking out before the existing range.
    if given_start_date < existing_start_date:
        leftovers.append((given_start_date, existing_start_date - one_day))
    # Piece sticking out after the existing range.
    if given_end_date > existing_end_date:
        leftovers.append((existing_end_date + one_day, given_end_date))
    return leftovers


def sift_data(unfiltered_data, selected_cities, start_date, end_date):
    """Return the rows matching the selected cities and lying strictly inside the date window.

    Args:
        unfiltered_data: DataFrame with ``NAME`` and ``DATE`` columns.
        selected_cities: City names to keep; a falsy value keeps every city.
        start_date: Exclusive lower bound on ``DATE``; skipped when falsy.
        end_date: Exclusive upper bound on ``DATE``; skipped when falsy.

    Returns:
        A new DataFrame; the input is never returned as-is (a deep copy is
        made when no city filter applies).
    """
    if not selected_cities:
        subset = unfiltered_data.copy(deep=True)
    else:
        city_mask = unfiltered_data['NAME'].isin(selected_cities)
        subset = unfiltered_data[city_mask]

    # Both bounds are strict: rows dated exactly on a bound are dropped.
    if start_date:
        subset = subset[subset['DATE'] > start_date]
    if end_date:
        subset = subset[subset['DATE'] < end_date]
    return subset


# Single calendar day that is carved out of every request before it is sent
# to get_Data.
# NOTE(review): the reason this date is skipped is not visible in this file --
# presumably the upstream data is unusable on that day; confirm before changing.
_SKIPPED_DATE = datetime(2022, 5, 1)


def _fetch_into_cache(city, range_start, range_end, cloud_mask):
    """Download one city's data for a date range and append it to ``nox_data``."""
    global nox_data
    pieces = find_non_overlapping_ranges(range_start, range_end, _SKIPPED_DATE, _SKIPPED_DATE)
    for piece_start, piece_end in pieces:
        new_data = get_Data(msa.get_msas(city), piece_start, piece_end, cloud_mask)
        new_data['NAME'] = city
        nox_data = pd.concat([nox_data, new_data], ignore_index=True)


def filter_data(selected_cities, start_date, end_date, cloud_mask=0.3):
    """Return cached data for the cities and date window, fetching whatever is missing.

    ``fetch_history[city]`` holds a single ``{start: end}`` entry describing
    the contiguous range already downloaded for that city. Only the days not
    covered by that entry are requested, and the entry is then widened to the
    union of the old and new ranges. When the request is disjoint from the
    cached range, the gap between them is fetched too, so the recorded range
    always matches what was actually requested from ``get_Data``.

    Args:
        selected_cities: Iterable of city names to load.
        start_date: Start of the requested window.
        end_date: End of the requested window.
        cloud_mask: Passed through to ``get_Data`` unchanged.

    Returns:
        The rows of the cache selected by ``sift_data`` for these cities and
        dates, or an empty frame when nothing has been cached yet.
    """
    for city in selected_cities:
        if city not in fetch_history:
            _fetch_into_cache(city, start_date, end_date, cloud_mask)
            fetch_history[city] = {start_date: end_date}
            continue

        existing_start, existing_end = next(iter(fetch_history[city].items()))
        covered_start = min(start_date, existing_start)
        covered_end = max(end_date, existing_end)
        # The widened range always contains the cached one, so this yields at
        # most one piece before it and one piece after it.
        missing_ranges = find_non_overlapping_ranges(
            covered_start, covered_end, existing_start, existing_end)
        for missing_start, missing_end in missing_ranges:
            _fetch_into_cache(city, missing_start, missing_end, cloud_mask)
        fetch_history[city] = {covered_start: covered_end}

    # Nothing fetched yet (e.g. no cities requested): there is no DATE column
    # to normalise or filter on.
    if 'DATE' not in nox_data.columns:
        return nox_data.copy()

    nox_data['DATE'] = pd.to_datetime(nox_data['DATE'])
    nox_data.drop_duplicates(subset=['DATE', 'NAME', 'conc'], inplace=True)
    return sift_data(nox_data, selected_cities, start_date, end_date)


def get_forecast_data(selected_city, forecast_days=14):
    """Fit the seasonal model on a city's full history and forecast ahead.

    Args:
        selected_city: City name understood by ``filter_data``.
        forecast_days: Number of daily steps to forecast (default 14).

    Returns:
        ``(forecast_df, full_data)``. ``full_data`` has columns ``date`` and
        ``nox-concentration``. ``forecast_df`` has one row per forecast day
        with ``nox-concentration``, ``date``, ``upper_bound`` and
        ``lower_bound``; the bounds are the forecast value plus/minus twice
        the scale returned by the forecaster (presumably a standard
        deviation -- confirm in forecaster.py).
    """
    # Get the full date range of the Sentinel-5P collection.
    image_collection = ee.ImageCollection("COPERNICUS/S5P/OFFL/L3_NO2")
    collection_span = image_collection.reduceColumns(ee.Reducer.minMax(), ['system:time_start'])
    first_date = ee.Date(collection_span.get('min')).format('YYYY-MM-dd').getInfo()
    last_date = ee.Date(collection_span.get('max')).format('YYYY-MM-dd').getInfo()
    first_date, last_date = pd.to_datetime([first_date, last_date])

    full_data = filter_data([selected_city], first_date, last_date)
    # Rename into the column names used for the forecaster input and the
    # forecast plot, without mutating the frame filter_data returned.
    full_data = full_data.rename(columns={'DATE': 'date', 'conc': 'nox-concentration'})
    full_data = full_data[['date', 'nox-concentration']]

    forecast = NOxForecaster(full_data)
    seasonal_dummy = forecast.fit_dummy_seasonal_model()
    forecast_means, forecast_scales = forecast.get_forecast(*seasonal_dummy, forecast_days)

    forecast_df = pd.DataFrame(forecast_means)
    # Forecast dates start the day after the last historical observation.
    forecast_df['date'] = pd.date_range(start=full_data['date'].max() + pd.Timedelta(days=1),
                                        periods=forecast_days, freq='D')
    forecast_df.rename(columns={0: 'nox-concentration'}, inplace=True)
    forecast_df['upper_bound'] = forecast_df['nox-concentration'] + 2 * forecast_scales
    forecast_df['lower_bound'] = forecast_df['nox-concentration'] - 2 * forecast_scales
    return forecast_df, full_data

def _daily_figure(data, start_date, end_date):
    """Build the daily line chart with a marker for each holiday in the date window."""
    figure = px.line(data, x='DATE', y='conc',
                     hover_data=['NAME', 'DATE', 'conc', 'DOW', 'DOY'],
                     color='NAME', markers=False, labels={'conc': 'NO2 concentration'},
                     title='tropospheric_NO2_column_number_density (mol/m^2)', height=450)

    # US holidays plus the custom Easter entries for every year in the window.
    years = range(start_date.year, end_date.year + 1)
    us_holidays = holidays.US(years=years) + get_custom_holidays(years)
    for raw_date, holiday_name in us_holidays.items():
        holiday_date = pd.Timestamp(raw_date)
        if not start_date <= holiday_date <= end_date:
            continue
        figure.add_vline(x=holiday_date, line=dict(color='red', dash='dash', width=1))
        figure.add_annotation(
            x=holiday_date,
            y=1, yref='paper',
            # Holidays without a configured abbreviation get an unlabeled
            # marker instead of raising KeyError.
            showarrow=False, text=holiday_abbr.get(holiday_name, ''),
            xanchor='left', textangle=-90,
            font=dict(color='red')
        )
    return figure


def _monthly_figure(data, selected_cities, start_date, end_date):
    """Build the monthly-average line chart with a marker at each new year."""
    monthly = data.set_index('DATE').groupby('NAME')['conc'].resample('M').mean().reset_index()
    monthly = sift_data(monthly, selected_cities, start_date, end_date)
    figure = px.line(monthly, x='DATE', y='conc',
                     hover_data=['NAME', 'DATE', 'conc'],
                     color='NAME', markers=True, labels={'conc': 'NO2 concentration'},
                     title='tropospheric_NO2_column_number_density (mol/m^2)', height=450)

    # One dashed line, labeled with the year, at every 1 January in the window.
    for year_date in pd.date_range(start=start_date, end=end_date, freq='YS'):
        figure.add_vline(x=year_date, line=dict(color='red', dash='dash', width=1))
        figure.add_annotation(
            x=year_date,
            y=1, yref='paper',
            showarrow=False, text=str(year_date.year),
            xanchor='left', textangle=-90,
            font=dict(color='red')
        )
    return figure


def _yearly_figure(data, selected_cities):
    """Build one subplot per city, overlaying each year against day-of-year."""
    # Work on a copy so the helper columns never leak into the caller's frame.
    data = data.copy()
    data['Year'] = pd.DatetimeIndex(data['DATE']).year
    data['City-Year'] = data['NAME'] + ' ' + data['Year'].astype(str)

    figure = make_subplots(rows=len(selected_cities), cols=1, shared_xaxes=True,
                           subplot_titles=selected_cities)
    for row, city in enumerate(selected_cities, start=1):
        city_fig = px.line(data[data['NAME'] == city], x='DOY', y='conc',
                           hover_data=['NAME', 'DATE', 'conc', 'DOW', 'DOY', 'Year'],
                           color='City-Year', markers=False,
                           labels={'conc': 'NO2 concentration', 'DOY': 'Day of Year'})
        for trace in city_fig['data']:
            figure.add_trace(trace, row=row, col=1)
    figure.update_layout(height=400 * len(selected_cities))
    return figure


def _forecast_figure(city):
    """Build the history-plus-forecast chart, with a shaded band, for one city."""
    forecast_df, historical_df = get_forecast_data(city)
    figure = px.line(historical_df, x='date', y='nox-concentration',
                     labels={'nox-concentration': 'NO2 concentration'},
                     title='tropospheric_NO2_column_number_density (mol/m^2)', height=450)
    figure.add_traces([
        go.Scatter(x=forecast_df['date'], y=forecast_df['nox-concentration'],
                   mode='lines', name='Forecast', marker=dict(color='orange')),
        # The upper bound is drawn invisibly so the lower bound can fill up to it.
        go.Scatter(x=forecast_df['date'], y=forecast_df['upper_bound'],
                   mode='lines', line=dict(width=0), showlegend=False),
        go.Scatter(x=forecast_df['date'], y=forecast_df['lower_bound'],
                   fill='tonexty', line=dict(width=0), showlegend=False,
                   fillcolor='rgba(255, 165, 0, 0.3)', mode='lines')
    ])
    return figure


@app.callback(
    [Output('time-series', 'figure'),
     Output('select-graph', 'style')],
    [Input('city-dd', 'value'),
     Input('date-range', 'start_date'),
     Input('date-range', 'end_date'),
     Input('select-graph', 'value'),
     Input('tabs-for-graph', 'value')]
)
def update_county(selected_cities, start_date, end_date, selected_view, selected_tab):
    """Redraw the graph whenever the city, date range, view or tab changes.

    Args:
        selected_cities: City names chosen in the dropdown (may be empty/None).
        start_date: Start date from the date picker (string, or None if cleared).
        end_date: End date from the date picker (string, or None if cleared).
        selected_view: 'daily', 'monthly' or anything else for the yearly view.
        selected_tab: 'hist' for historical data, 'forecast' for forecasts.

    Returns:
        ``(figure, style)``: the figure for the graph, and a style dict that
        shows the view selector only on the historical tab.
    """
    tabs = {'display': 'block' if selected_tab == 'hist' else 'none'}
    if not selected_cities:
        return px.line(), tabs

    if selected_tab == 'hist':
        start_date = pd.to_datetime(start_date)
        end_date = pd.to_datetime(end_date)
        # A cleared date picker yields no usable bound; show an empty chart
        # rather than failing on `.year` below.
        if pd.isna(start_date) or pd.isna(end_date):
            return px.line(), tabs

        # Keep the filtered rows in a local: rebinding the module-level
        # nox_data here would shrink the shared cache to this selection while
        # fetch_history still reports the wider range as downloaded.
        data = filter_data(selected_cities, start_date, end_date)

        if selected_view == 'daily':
            figure = _daily_figure(data, start_date, end_date)
        elif selected_view == 'monthly':
            figure = _monthly_figure(data, selected_cities, start_date, end_date)
        else:
            figure = _yearly_figure(data, selected_cities)
    elif selected_tab == 'forecast':
        # Only the first selected city is forecast.
        figure = _forecast_figure(selected_cities[0])
    else:
        figure = px.line()

    figure.update_layout({'paper_bgcolor': 'rgb(44,44,44)', 'font': {'color': 'white'},
                          'title': {'x': 0.45, 'xanchor': 'center'}}, yaxis={'title': None})
    return figure, tabs


if __name__ == '__main__':
    # Start the Dash development server. `run_server` is the legacy name;
    # the setup cell installs Dash unpinned, and recent releases expect
    # `app.run` instead.
    app.run(debug=True)#, jupyter_mode="external")
