# General functions

In [2]:
def comparison_check(sensor, model, component_nom, model_full_name, sensor_type, apply_kernels):

    """ Check if the comparison is possible

        Args:
            sensor (str): Name of the sensor ('tropomi', 'iasi' or 'gome')
            model (str): Name of the model (only 'cams' is supported)
            component_nom (str): Component chemical nomenclature
            model_full_name (str): Full name of the CAMS model among:
            - 'cams-global-atmospheric-composition-forecasts' (compared with L2 sensor data)
            - 'cams-global-reanalysis-eac4-monthly' (compared with L3 sensor data)
            sensor_type (str): Sensor type ('L2' or 'L3')
            apply_kernels (bool): Apply (True) or not (False) the TROPOMI averaging kernels

        Raises:
            KeyboardInterrupt: If the requested comparison is not supported
            (raised on purpose to stop the notebook execution).
    """

    # Each sensor processing level can only be compared with one CAMS dataset
    model_for_sensor_type = {
        'L2': 'cams-global-atmospheric-composition-forecasts',
        'L3': 'cams-global-reanalysis-eac4-monthly',
    }

    # Components that can be retrieved by each sensor and processing level
    supported_components = {
        ('tropomi', 'L2'): ['NO2', 'CO', 'O3', 'SO2', 'HCHO'],
        ('tropomi', 'L3'): ['NO2'],
        ('iasi', 'L2'): ['CO', 'O3', 'SO2'],
        ('iasi', 'L3'): ['CO', 'O3'],
        ('gome', 'L2'): ['NO2', 'O3', 'HCHO', 'SO2'],
        ('gome', 'L3'): ['NO2'],
    }

    # Components for which the TROPOMI averaging kernels can be applied
    # NOTE(review): the error message below mentions NO2 and SO2, but only NO2 is allowed here.
    tropomi_kernels_components = ['NO2']

    # 1. Supported sensor / sensor type / model combination
    if ((sensor, sensor_type) not in supported_components or model != 'cams'
            or model_full_name != model_for_sensor_type.get(sensor_type)):

        print('Currently, it is possible to compare:')
        print('1 - Forecast data from CAMS model and L2 data from TROPOMI, IASI and GOME-2 sensors.')
        print('2 - Reanalysis data from CAMS model and L3 monthly data from TROPOMI, IASI and GOME-2 sensors.')

        raise KeyboardInterrupt()

    # 2. Component retrievable by the sensor
    if component_nom not in supported_components[(sensor, sensor_type)]:

        print(f'ERROR: This specific component cannot be retrieved by the sensor {sensor.upper()} ({sensor_type}).')
        raise KeyboardInterrupt()

    # 3. Averaging kernels only available for TROPOMI and a subset of its components
    if apply_kernels == True and (sensor != 'tropomi' or component_nom not in tropomi_kernels_components):

        print('ERROR: It is only possible to apply the averaging kernels from the TROPOMI observations to the CAMS forecasts for NO2 and SO2.')
        print('Please set the variable apply_kernels to False.') 
        raise KeyboardInterrupt()

    print('The comparison is possible and will start now.')

In [3]:
def components_table(sensor, component_nom, sensor_type):

    """ Create table with information about the components (molecular weight, full name in different datasets)
        and return the entries of one component

        Args:
            sensor (str): Name of the sensor ('tropomi', 'iasi' or 'gome')
            component_nom (str): Component chemical nomenclature
            sensor_type (str): Sensor type ('L2' or 'L3')

        Returns:
            component (str): Component name
            component_mol_weight (float): Component molecular weight
            component_sensor_product (str): Component product name in TROPOMI L2 database
                (None for other sensors and for L3 data)
            sensor_column (str): Component column name in TROPOMI, IASI or GOME-2 database
                ('-' when the component is not available)
            column_type (str): 'tropospheric' or 'total' column

        Raises:
            IndexError: If component_nom is not in the table
            KeyError: If there is no column for the sensor / sensor type combination
    """

    sensor_product_type = None

    # '-' marks a component that is not available in a dataset
    rows = {'Nomenclature': ['NO2', 'CO', 'O3', 'SO2', 'CH4', 'HCHO', 'NH3'],
            'Weight': [46.005, 28.01, 48, 64.066, 16.04, 30.031, 17.031],
            'Component': ['nitrogen_dioxide', 'carbon_monoxide', 'ozone', 'sulphur_dioxide',
                          'methane', 'formaldehyde', 'ammonia'],
            'TROPOMI_L3_column': ['NO2trop', '-', '-', '-', '-', '-', '-'],
            'TROPOMI_L2_product': ['L2__NO2___', 'L2__CO____', 'L2__O3____', 'L2__SO2___',
                                   'L2__CH4___', 'L2__HCHO__', '-'],
            'TROPOMI_L2_column': ['nitrogendioxide_tropospheric_column',
                                  'carbonmonoxide_total_column',
                                  'ozone_total_vertical_column',
                                  'sulfurdioxide_total_vertical_column',
                                  'methane_tropospheric_column',
                                  'formaldehyde_tropospheric_vertical_column',
                                  '-'],
            'IASI_L3_column': ['-', 'COgridDAY', 'O3gridDAY', '-', '-', '-', 'NH3gridDAY'],
            'IASI_L2_column': ['-', 'CO_total_column', 'O3_total_column', 'SO2_all_altitudes', '-', '-', '-'],
            'GOME_L3_column': ['NO2trop', '-', '-', '-', '-', '-', '-'],
            'GOME_L2_column': ['NO2trop', '-', 'O3total', 'SO2total', '-', 'HCHOtotal', '-']}

    table = pd.DataFrame(rows)

    # Row selector of the requested component (computed once)
    is_component = table['Nomenclature'] == component_nom

    component = table.loc[is_component, 'Component'].iloc[0]
    component_mol_weight = table.loc[is_component, 'Weight'].iloc[0]

    if sensor == 'tropomi' and sensor_type == 'L2':
        sensor_product_type = table.loc[is_component, 'TROPOMI_L2_product'].iloc[0]

    sensor_column = table.loc[is_component, sensor.upper() + '_' + sensor_type + '_column'].iloc[0]

    # Tropospheric products contain 'trop' in their column name (e.g. 'NO2trop')
    column_type = 'tropospheric' if 'trop' in sensor_column else 'total'

    return component, component_mol_weight, sensor_product_type, sensor_column, column_type

In [4]:
def generate_folders(model, sensor, component_nom, sensor_type):

    """ Generate folders to download the datasets if they do not exist 

        The folders are created below <root>/adc-toolbox, where <root> is made of the first
        two components of the current working directory:
        - data/<model>/<component_nom>
        - data/<sensor>/<component_nom>/<sensor_type>

        Args:
            model (str): Name of the model
            sensor (str): Name of the sensor
            component_nom (str): Component chemical nomenclature
            sensor_type (str): Sensor type ('L2' or 'L3')

        Raises:
            ValueError: If sensor_type is not 'L2' or 'L3' (no folder is created)
    """

    if sensor_type not in ('L2', 'L3'):
        raise ValueError(f"sensor_type must be 'L2' or 'L3', got {sensor_type!r}.")

    # Toolbox root derived from the first two components of the current working directory
    base_path = os.path.join('/', '/'.join(os.getcwd().split('/')[1:3]), 'adc-toolbox')

    # Model data path
    model_path = os.path.join(base_path, os.path.relpath('data/' + model + '/' + component_nom))

    # Sensor data path (one sub-folder per processing level)
    sensor_path = os.path.join(base_path, os.path.relpath('data/' + sensor + '/' + component_nom + '/' + sensor_type + '/'))

    for path in (model_path, sensor_path):
        os.makedirs(path, exist_ok = True)

In [5]:
def search_period(start_date, end_date, sensor, sensor_type):

    """ Give list or tuple with dates that will be used to download the datasets

        Args:
            start_date (str): Query start date
            end_date (str): Query end date
            sensor (str): Name of the sensor
            sensor_type (str): Sensor type

        Returns:
            dates (list or tuple): Query dates
            - TROPOMI L2: list of (start, end) pairs formatted '%Y-%m-%dT%H:%M:%SZ',
              one per day, from 00:00:00 to 23:00:00
            - IASI / GOME-2 L2: tuple of days formatted '%Y-%m-%d'
            - TROPOMI / IASI / GOME-2 L3: tuple of months formatted '%Y-%m'

        Raises:
            ValueError: If the sensor / sensor type combination is not supported
    """

    print('SEARCH PERIOD')

    # Every day between both dates (both included)
    range_dt = pd.date_range(np.datetime64(start_date), np.datetime64(end_date))

    if sensor == 'tropomi' and sensor_type == 'L2':
        # Daily query windows from 00:00 to 23:00
        range_dt_final = range_dt + dt.timedelta(hours = 23)
        dates = list(zip([date.strftime('%Y-%m-%dT%H:%M:%SZ') for date in range_dt], 
                         [date.strftime('%Y-%m-%dT%H:%M:%SZ') for date in range_dt_final]))

    elif sensor in ('gome', 'iasi') and sensor_type == 'L2':
        dates = tuple(np.unique([date.strftime('%Y-%m-%d') for date in range_dt]))

    elif sensor in ('gome', 'iasi', 'tropomi') and sensor_type == 'L3':
        # Monthly products: each month appears once
        dates = tuple(np.unique([date.strftime('%Y-%m') for date in range_dt]))

    else:
        raise ValueError(f'The combination of sensor {sensor} and sensor type {sensor_type} is not supported.')

    if sensor_type == 'L2':
        print(f'- In days: {dates}')

    else:
        print(f'- In months: {dates}')

    return dates

In [None]:
def search_bbox(lon_min, lat_min, lon_max, lat_max):

    """ Build the query bounding box from its corner coordinates and report it

        Args:
            lon_min (float): Minimum longitude
            lat_min (float): Minimum latitude
            lon_max (float): Maximum longitude
            lat_max (float): Maximum latitude

        Returns:
            bbox (tuple): ((lon_min, lat_min), (lon_max, lat_max))
    """

    lower_corner = (lon_min, lat_min)
    upper_corner = (lon_max, lat_max)

    print('SEARCH BOUNDING BOX')
    print(f'Latitudes: from {lat_min} to {lat_max}')
    print(f'Longitudes: from {lon_min} to {lon_max}')

    return lower_corner, upper_corner

In [None]:
def available_period(sensor, sensor_type, dates, component_nom, *args):

    """ Remove dates if the folders where the dataset had to be downloaded are empty (dataset not available)

        Empty date folders are deleted from disk. For GOME-2 L2, empty satellite sub-folders are
        deleted first, so a date whose satellite folders are all empty is removed as well.
        A date whose folder does not exist is treated as not available.

        Args:
            sensor (str): Name of the sensor
            sensor_type (str): Sensor type
            dates (list or tuple): Query dates
            component_nom (str): Component chemical nomenclature
            *args: satellites (iterable of str), only used for GOME-2 L2

        Returns:
            dates (tuple): Available dates (sorted and unique, as returned by np.setdiff1d)
    """

    # Callers pass the satellites as the first optional positional argument
    satellites = args[0] if args else ()

    # Toolbox root derived from the first two components of the current working directory
    base_path = os.path.join('/', '/'.join(os.getcwd().split('/')[1:3]), 'adc-toolbox')

    dates_to_delete = []

    for date in dates:

        output_path = os.path.join(base_path, os.path.relpath(
            'data/' + sensor + '/' + component_nom + '/' + sensor_type + '/' + date))

        # Nothing was downloaded for this date
        if not os.path.isdir(output_path):
            dates_to_delete.append(date)
            continue

        # Remove empty satellite sub-folders before checking the date folder itself
        if sensor_type == 'L2' and sensor == 'gome':
            for satellite in satellites:
                satellite_path = os.path.join(output_path, satellite)
                if os.path.isdir(satellite_path) and not os.listdir(satellite_path):
                    os.rmdir(satellite_path)

        if not os.listdir(output_path):
            os.rmdir(output_path)
            dates_to_delete.append(date)

    dates_to_keep = np.setdiff1d(dates, np.array(dates_to_delete))

    return tuple(dates_to_keep)

In [6]:
def sensor_download(sensor, sensor_type, component_nom, dates, *args):

    """ Download sensor datasets for every query date

        TROPOMI data are downloaded one date at a time; IASI and GOME-2 data one date and
        satellite at a time. Afterwards, the dates whose download folders stay empty are
        dropped via available_period (TROPOMI L3, IASI and GOME-2 only).

        Args:
            sensor (str): Name of the sensor
            sensor_type (str): Sensor type
            component_nom (str): Component chemical nomenclature
            dates (list or tuple): Available dates
            *args: bbox, satellites or product_type
                NOTE(review): *args is never read in this body; bbox, product_type and
                satellites are looked up as global (notebook) variables instead.

        Returns:
            dates (list or tuple): Available dates (unchanged for TROPOMI L2)
    """ 

    print('RESULTS')
    
    if sensor == 'tropomi':

        for date in dates:

            if sensor_type == 'L2':

                print(f'For {date}:')
                input_type = 'Query'
                # NOTE(review): bbox and product_type are globals, not parameters
                TROPOMI_L2_download(input_type, bbox, date, product_type, component_nom)

            elif sensor_type == 'L3':
                TROPOMI_L3_download(date, component_nom)
        
        # Only the L3 dates are filtered; L2 dates are returned as queried
        if sensor_type == 'L3':
            dates = available_period(sensor, sensor_type, dates, component_nom)

    elif sensor == 'iasi' or sensor == 'gome':
        
        for date in dates:
            
            print(f'For {date}:')

            # NOTE(review): satellites is a global, not a parameter
            for satellite in satellites:

                if sensor == 'iasi' and sensor_type == 'L2':
                    IASI_L2_download(component_nom, date, satellite)

                elif sensor == 'iasi' and sensor_type == 'L3':
                    IASI_L3_download(component_nom, date, satellite)
        
                elif sensor == 'gome' and sensor_type == 'L2':
                    GOME_L2_download(component_nom, date, satellite)

                elif sensor == 'gome' and sensor_type == 'L3':
                    #GOME_L3_download_AC_SAF(component_nom, date, satellite)
                    GOME_L3_download_TEMIS(component_nom, date, satellite)

        # Drop the dates whose download folders are empty
        dates = available_period(sensor, sensor_type, dates, component_nom, satellites)

    return dates

In [7]:
def sensor_read(sensor, sensor_type, sensor_column, component_nom, dates, *args):

    """ Read sensor datasets as xarray dataset objects

        Args:
            sensor (str): Name of the sensor
            sensor_type (str): Sensor type
            sensor_column (str): Component column name in the sensor database
            component_nom (str): Component chemical nomenclature
            dates (list or tuple): Available dates
            *args: satellites, lat_res, lon_res
                NOTE(review): *args is never read; lat_res and lon_res are looked up as
                global (notebook) variables.

        Returns:
            sensor_ds (xarray): sensor dataset in xarray format (None if nothing could be read)
            support_input_ds (xarray): TROPOMI dataset that contains support input data in xarray format (None otherwise)
            support_details_ds (xarray): TROPOMI dataset that contains support details data in xarray format (None otherwise)
    """ 

    # Returned as-is when there are no dates or the sensor / sensor type combination is not handled
    sensor_ds = None
    support_input_ds = None
    support_details_ds = None

    if not dates:
        print('The datasets could not be downloaded for the dates that were queried.')
        return sensor_ds, support_input_ds, support_details_ds

    if sensor == 'tropomi' and sensor_type == 'L2':
        sensor_ds, support_input_ds, support_details_ds = TROPOMI_L2_read(component_nom, sensor_column, dates)

    elif sensor == 'tropomi' and sensor_type == 'L3':
        sensor_ds = TROPOMI_L3_read(component_nom, dates, lat_res, lon_res)

    elif sensor == 'iasi' and sensor_type == 'L2':
        sensor_ds = IASI_L2_read(component_nom, sensor_column, dates, lat_res, lon_res)

    elif sensor == 'iasi' and sensor_type == 'L3':
        sensor_ds = IASI_L3_read(component_nom, sensor_column, dates, lat_res, lon_res)

    elif sensor == 'gome' and sensor_type == 'L2':
        sensor_ds = GOME_L2_read(component_nom, dates, lat_res, lon_res)

    elif sensor == 'gome' and sensor_type == 'L3':
        # The AC SAF reader (GOME_L3_read_AC_SAF) is not used; TEMIS data are read instead
        sensor_ds = GOME_L3_read_TEMIS(component_nom, dates, lat_res, lon_res)

    return sensor_ds, support_input_ds, support_details_ds

In [8]:
def sensor_convert_units(sensor_ds, sensor, component_nom):

    """ Convert the units of the sensor dataset from mol m-2 to molecules cm-2 and, for O3 and SO2, to DU

        Only TROPOMI and IASI datasets are converted; datasets of other sensors are returned unchanged.
        O3 and SO2 are expressed in DU, matching the conversion applied to the model data in
        model_convert_units.

        Args:
            sensor_ds (xarray): sensor dataset in xarray format (TROPOMI, IASI or GOME-2)
            sensor (str): Name of the sensor
            component_nom (str): Component chemical nomenclature
            
        Returns:
            sensor_ds (xarray): sensor dataset in xarray format
    """

    # mol m-2 -> molec cm-2: 6.02214e23 molec mol-1 / 1e4 cm2 m-2 = 6.02214e19
    # molec cm-2 -> DU: 1 DU = 2.69e16 molec cm-2
    du_components = ('O3', 'SO2')

    if sensor == 'tropomi':
        
        if sensor_ds['sensor_column'].units == 'mol m-2':

            sensor_ds['sensor_column'] = sensor_ds['sensor_column'] * 6.02214*10**19
            sensor_ds['sensor_column'] = sensor_ds['sensor_column'].assign_attrs({'units': 'molec cm-2'})
            print('The sensor component units have been converted from mol m-2 to molec cm-2.')
            
            # The a priori profile (if present) is converted together with the column
            if 'apriori_profile' in list(sensor_ds.keys()):
                sensor_ds['apriori_profile'] = sensor_ds['apriori_profile'] * 6.02214*10**19

            # Uses component_nom (the former check relied on an undefined global `component`)
            if sensor_ds['sensor_column'].units == 'molec cm-2' and component_nom in du_components:
                sensor_ds['sensor_column'] = sensor_ds['sensor_column'] / (2.69*10**16)
                sensor_ds['sensor_column'] = sensor_ds['sensor_column'].assign_attrs({'units': 'DU'})
                print('The sensor component units have been converted from molec cm-2 to DU.')

                if 'apriori_profile' in list(sensor_ds.keys()):
                    sensor_ds['apriori_profile'] = sensor_ds['apriori_profile'] / (2.69*10**16)
    
    elif sensor == 'iasi':
        
        if sensor_ds.units == 'mol m-2':

            sensor_ds = sensor_ds * 6.02214*10**19
            sensor_ds = sensor_ds.assign_attrs({'units': 'molec cm-2'})
            print('The sensor component units have been converted from mol m-2 to molec cm-2.')

        if sensor_ds.units == 'molec cm-2' and component_nom in du_components:
            sensor_ds = sensor_ds / (2.69*10**16)
            sensor_ds = sensor_ds.assign_attrs({'units': 'DU'})
            print('The sensor component units have been converted from molec cm-2 to DU.')

    return sensor_ds

In [9]:
def model_convert_units(model, model_ds, sensor, component_mol_weight, model_levels_df, 
                        start_date, end_date, component_nom, apply_kernels = False, 
                        CAMS_UID = None, CAMS_key = None):

    """ Convert the units of the model dataset for any component from kg/kg or kg/m2 to molecules/cm2
        (and to DU for O3 and SO2)

        Args:
            model (str): Name of the model
            model_ds (xarray): model dataset in xarray format (CAMS)
            sensor (str): Name of the sensor
            component_mol_weight (float): Component molecular weight
            model_levels_df (dataframe): Table with 137 CAMS levels data
            start_date (str): Query start date
            end_date (str): Query end date
            component_nom (str): Component chemical nomenclature
            apply_kernels (bool): Apply (True) or not (False) the averaging kernels 
            CAMS_UID (str): ADS user ID
            CAMS_key (str): ADS key
            
        Returns:
            model_ds (xarray): model dataset in xarray format
            units (str): Units of model_ds.component after the conversions
                (None if the component has no 'units' attribute)
    """

    if model == 'cams':

        # The conversions are chained: kg kg**-1 -> kg m**-2 -> molec cm-2 -> DU
        if model_ds.component.units == 'kg kg**-1':

            model_ds = CAMS_kg_kg_to_kg_m2(model_ds, model_levels_df, sensor, start_date, 
                                           end_date, component_nom, apply_kernels, CAMS_UID, CAMS_key)
            model_ds['component'] = model_ds.component.assign_attrs({'units': 'kg m**-2'})
            print('The model component units have been converted from kg kg**-1 to kg m**-2.')

        if model_ds.component.units == 'kg m**-2':

            model_ds = CAMS_kg_m2_to_molecules_cm2(model_ds, component_mol_weight)
            model_ds['component'] = model_ds.component.assign_attrs({'units': 'molec cm-2'})
            print('The model component units have been converted from kg m**-2 to molec cm-2.')
        
        if model_ds.component.units == 'molec cm-2' and component_nom in ('O3', 'SO2'):

            model_ds = CAMS_molecules_cm2_to_DU(model_ds)
            model_ds['component'] = model_ds.component.assign_attrs({'units': 'DU'})
            print('The model component units have been converted from molec cm-2 to DU.')

    # Read the units back from the data so the reported value always matches the dataset,
    # whichever conversions ran (or none, e.g. data already in DU or a non-CAMS model)
    units = model_ds.component.attrs.get('units')

    return model_ds, units

In [10]:
def nearest_neighbour(array, value):

    """ Return the position of the element of a 1D sequence that is closest to a value

        Args:
            array (arr): Sequence searched for the nearest neighbour (numbers or datetimes)
            value (float): Reference value

        Returns:
            index (int): Index of the first element with the smallest absolute difference
    """

    # Differences are taken element by element; callers pass both coordinates and
    # model time steps through this function
    distances = np.abs([element - value for element in array])

    return np.argmin(distances, axis = 0)

In [11]:
def closest_point(point, array):

    """ Return the row of a 2D array that lies closest to a point

        Args:
            point (tuple): Search coordinates
            array (arr): 2D array of candidate coordinates, one point per row

        Returns:
            pair (arr): Row of array with the smallest Euclidean distance to point
    """

    # cdist yields a (1, n) distance matrix: one distance per candidate row
    distances = cdist([point], array)
    nearest_row = distances.argmin()

    return array[nearest_row]

In [12]:
def pairwise(array):

    """ Group consecutive, non-overlapping elements into 2-tuples

        Unlike itertools.pairwise, the pairs do not overlap: [a, b, c, d] -> [(a, b), (c, d)].
        A trailing element without partner is dropped.

        Args:
            array (arr): Dates, coordinates list, etc. (any iterable)

        Returns:
            period (list): List of 2-tuples
    """

    # Zipping the same iterator with itself consumes two elements per tuple
    shared_iterator = iter(array)

    return list(zip(*[shared_iterator] * 2))

In [None]:
def binning(ds, lat_res, lon_res):

    """ Regrid onto a custom defined regular grid by averaging the values that fall in each cell

        Args:
            ds (xarray): Dataset with 'latitude' and 'longitude' coordinates
            lat_res (float): Spatial resolution for latitude
            lon_res (float): Spatial resolution for longitude

        Returns:
            ds (xarray): Dataset whose latitude / longitude coordinates are the cell centres
    """

    # Cell edges covering the whole globe
    lat_bins = np.arange(-90, 90 + lat_res/2, lat_res)
    lon_bins = np.arange(-180, 180 + lon_res/2, lon_res)

    # Cell centres, used as labels of the binned coordinates
    lat_center = np.arange(-90 + lat_res/2, 90, lat_res)
    lon_center = np.arange(-180 + lon_res/2, 180, lon_res)

    # Bins are right-closed intervals (a, b]; include_lowest closes the first bin on the left
    # so that values exactly at -90 / -180 are not dropped
    ds = ds.groupby_bins('latitude', lat_bins, labels = lat_center, include_lowest = True).mean()
    ds = ds.groupby_bins('longitude', lon_bins, labels = lon_center, include_lowest = True).mean()
    ds = ds.rename({'latitude_bins': 'latitude', 'longitude_bins': 'longitude'})

    return ds

In [13]:
def subset(ds, bbox, sensor, component_nom, sensor_type, subset_type):

    """ Subset any dataset (with latitude and longitude as coordinates) into desired bounding box.

        Apart from TROPOMI L2 sensor data (delegated to TROPOMI_subset), the limits are the grid
        points nearest to the bounding box corners; coordinates may be stored in ascending or
        descending order.

        Args:
            ds (xarray): Dataset in xarray format
            bbox (arr): Query bounding box ((lon_min, lat_min), (lon_max, lat_max))
            sensor (str): Name of the sensor
            component_nom (str): Component chemical nomenclature
            sensor_type (str): Sensor type
            subset_type (str):
            -  'sensor_subset': Sensor dataset will be subset
            -  'model_subset': Model dataset will be subset
    
        Returns:
            ds (xarray): Dataset in xarray format
    """

    if sensor == 'tropomi' and sensor_type == 'L2' and subset_type == 'sensor_subset':

        ds = TROPOMI_subset(ds, bbox, component_nom)

    else:

        # Get nearest longitude and latitude to bbox
        lon_min_index = nearest_neighbour(ds.longitude.data, bbox[0][0])
        lon_max_index = nearest_neighbour(ds.longitude.data, bbox[1][0])
        lat_min_index = nearest_neighbour(ds.latitude.data, bbox[0][1])
        lat_max_index = nearest_neighbour(ds.latitude.data, bbox[1][1])

        # Order the indices so that the slices are not empty when a coordinate is
        # stored in descending order (e.g. latitudes from 90 to -90)
        slice_lat = slice(min(lat_min_index, lat_max_index), max(lat_min_index, lat_max_index) + 1)
        slice_lon = slice(min(lon_min_index, lon_max_index), max(lon_min_index, lon_max_index) + 1)

        # Set limits
        ds = ds.isel(longitude = slice_lon, latitude = slice_lat)

    return ds

In [14]:
def prepare_df(match_df_time, sensor, component_nom, time, sensor_type):

    """ Prepare dataframe for match

        Args:
            match_df_time (dataframe): Sensor data of one timestamp, used to apply averaging kernels
            sensor (str): Name of the sensor
            component_nom (str): Component chemical nomenclature
            time (numpy.datetime64): Current time
            sensor_type (str): Sensor type
        
        Returns:
            match_df_time (dataframe): Dataframe used to apply averaging kernels
            - TROPOMI L2: averaged per layer / scanline / ground pixel / time / delta_time,
              with 'layer' and 'delta_time' moved to columns
            - Other gridded data: 'latitude' and 'longitude' moved to columns; for GOME-2 L2,
              missing 'delta_time' values are set to 12:00 of the current day
    """

    if sensor == 'tropomi' and sensor_type == 'L2':

        # Pass NaNs to data with qa_value of 0.5 or lower (these values will be shown as transparent)
        match_df_time.loc[match_df_time['qa_value'] <= 0.5, ['sensor_column', 'column_kernel']] = float('NaN')

        # Name the index levels so that the grouping below can refer to them
        if component_nom in ('CO', 'SO2'):
            match_df_time.index.names = ['corner', 'ground_pixel', 'layer', 'scanline']

        elif component_nom == 'O3':
            match_df_time.index.names = ['corner', 'ground_pixel', 'layer', 'level', 'scanline']

        # Average out the index levels that are not grouped on (e.g. 'corner')
        match_df_time = match_df_time.groupby(by = ['layer', 'scanline', 'ground_pixel', 'time', 'delta_time']).mean()
        match_df_time = match_df_time.reset_index(level = ['layer', 'delta_time'])

    elif sensor in ('iasi', 'gome') or (sensor == 'tropomi' and sensor_type == 'L3'):

        match_df_time = match_df_time.reset_index(level = ['latitude', 'longitude'])

        # TROPOMI L2 is handled above, so only GOME-2 L2 needs the delta_time filling
        if sensor == 'gome' and sensor_type == 'L2':

            year, month, day = (int(part) for part in time.astype('datetime64[D]').astype(str).split('-'))
            match_df_time['delta_time'] = match_df_time['delta_time'].fillna(
                value = dt.datetime(year, month, day, 12, 0, 0))
        
    return match_df_time

In [15]:
def generate_match_df(sensor_ds, model_ds, bbox, sensor, component_nom, sensor_type, apply_kernels = False):

    """ Intermediate merge table with total column or partial column from both datasets, 
        the averaging kernels are applied if possible

        For each sensor timestamp, the sensor data are subset to the bounding box, converted to a
        dataframe and either passed to TROPOMI_apply_kernels or matched with the model data by
        nearest neighbours (latitude, longitude and, for multi-step model data, time step).

        Args:
            sensor_ds (xarray): sensor dataset in xarray format (TROPOMI, IASI or GOME-2)
            model_ds (xarray): model dataset in xarray format (CAMS)
            bbox (arr): Query bounding box
            sensor (str): Name of the sensor
            component_nom (str): Component chemical nomenclature
            sensor_type (str): Sensor type
            apply_kernels (bool): Apply (True) or not (False) the averaging kernels 

        Returns:
            match_df (dataframe): Intermediate merge table with total column or partial column from both datasets
                (an empty dataframe if sensor_ds has no timestamps)
    """
    
    # Per-timestamp tables, concatenated once at the end
    # (DataFrame.append was deprecated in pandas 1.4 and removed in pandas 2.0)
    match_frames = []

    if sensor == 'tropomi' and sensor_type == 'L2' and apply_kernels == True:

        print('APPLICATION OF AVERAGING KERNELS')
        print('For the application of the averaging kernels, it is necessary to calculate:')
        print('1. Level pressures')
        print('2. Column kernels')
        print('The apriori profiles should be retrieved, but they are not necessary.')

        # Calculate TM5 level pressures, column kernels and apriori profiles
        # NOTE(review): support_input_ds, support_details_ds and component are not parameters
        # of this function; they are looked up as global (notebook) variables.
        print('DATA AVAILABILITY')
        sensor_ds = TROPOMI_pressure(sensor_ds, component_nom, support_input_ds, support_details_ds)
        sensor_ds = TROPOMI_column_kernel(sensor_ds, component_nom, support_details_ds)
        sensor_ds = TROPOMI_apriori_profile(sensor_ds, component_nom, component, support_details_ds)

    for time in sensor_ds.time.values:
        
        # Print estimated time or month
        if sensor_type == 'L2':
            day = np.datetime64(time).astype('datetime64[D]')
            print(f'FOR DATE: {day}')

        elif sensor_type == 'L3':
            month = np.datetime64(time).astype('datetime64[M]')
            print(f'FOR MONTH: {month}')

        # Reduce data to only one timestamp
        model_ds_time = model_ds.sel(time = time)
        sensor_ds_time = sensor_ds.sel(time = time)

        # Subset sensor dataset
        sensor_ds_time = subset(sensor_ds_time, bbox, sensor, component_nom, 
                                sensor_type, subset_type = 'sensor_subset')
        
        # Transform sensor data into dataframe and prepare it for merging it with the model data
        match_df_time = sensor_ds_time.to_dataframe()
        match_df_time = prepare_df(match_df_time, sensor, component_nom, time, sensor_type)
        
        if sensor == 'tropomi' and 'column_kernel' in list(sensor_ds.keys()) and apply_kernels == True:
            
            match_df_time = TROPOMI_apply_kernels(match_df_time, model_ds_time, sensor_ds_time, component_nom)
            
        else:

            if apply_kernels == True:
                print('The application of the averaging kernels cannot take place because there is not enough data.')
            
            # Get model timesteps
            model_times = model_ds_time.valid_time.data

            if 'hybrid' in list(model_ds.coords):

                print('The partial columns will be sumed up.')
                print('The sum will be matched to the sensor data by nearest neighbours.')

                # Total column: sum of the partial columns over the model levels
                model_ds_time = model_ds_time.component.sum(dim = 'hybrid', skipna = False)
               
                match_df_time['step_index'] = match_df_time.apply(lambda row: nearest_neighbour(model_times, row['delta_time']), axis = 1)
                match_df_time['model_time'] = match_df_time.apply(lambda row: model_ds_time.valid_time[row['step_index']].values, axis = 1)
                match_df_time['model_column'] = match_df_time.apply(lambda row: model_ds_time.sel(
                                                                                latitude = row['latitude'], 
                                                                                longitude = row['longitude'],
                                                                                method = 'nearest').isel(step = 
                                                                                int(row['step_index'])).values, 
                                                                                axis = 1)
   
                match_df_time = match_df_time.set_index('layer', append = True)
                
            else:

                print('The model dataset does not contain levels data.')
                print('The model dataset will be merged with the sensor dataset by nearest neighbours.')

                # Monthly data
                if 'step' not in list(model_ds.dims):
                    
                    match_df_time['model_column'] = match_df_time.apply(lambda row: float(model_ds_time.sel(
                                                                                    latitude = row['latitude'], 
                                                                                    longitude = row['longitude'],
                                                                                    method = 'nearest').component.values), 
                                                                                    axis = 1)
                # Hourly / Daily data
                else:

                    match_df_time['step_index'] = match_df_time.apply(lambda row: nearest_neighbour(model_times, row['delta_time']), axis = 1)
                    match_df_time['model_column'] = match_df_time.apply(lambda row: float(model_ds_time.sel(
                                                                                    latitude = row['latitude'], 
                                                                                    longitude = row['longitude'],
                                                                                    method = 'nearest').isel(step = 
                                                                                    int(row['step_index'])).component.values), 
                                                                                    axis = 1)

        # Keep the first row of each duplicated index entry
        match_df_time = match_df_time[~match_df_time.index.duplicated()]
        match_frames.append(match_df_time)

    match_df = pd.concat(match_frames) if match_frames else pd.DataFrame()

    return match_df

In [16]:
def generate_merge_df(match_df, sensor_ds, model_ds, sensor, apply_kernels = False, sensor_type = None):

    """ Build the final merge table with model and sensor column data and their differences

        When the model dataset has 'hybrid' coordinates, match_df holds one row per
        layer; for each sensor time those rows are collapsed to one value per pixel:
        model values are summed over layers when averaging kernels were applied
        (and a 'column_kernel' variable exists), otherwise averaged; sensor values
        are averaged over layers. Otherwise match_df already holds one value per
        pixel and is used as-is (on a copy, so the caller's table is not modified).

        Args:
            match_df (dataframe): Intermediate merge table with total column or partial column from both datasets
            sensor_ds (xarray): sensor dataset in xarray format (TROPOMI, IASI or GOME-2)
            model_ds (xarray): model dataset in xarray format (CAMS)
            sensor (str): Name of the sensor
            apply_kernels (bool): Apply (True) or not (False) the averaging kernels
            sensor_type (str, optional): Sensor type ('L2' or 'L3'). When None, the
                module-level global 'sensor_type' is used, as earlier versions did.

        Returns:
            merge_df (dataframe): Merge table with model_column, sensor_column,
            difference (model - sensor) and relative_difference ((model - sensor) / sensor)
    """

    # Earlier versions read sensor_type from the notebook namespace; keep that as fallback
    if sensor_type is None:
        sensor_type = globals().get('sensor_type')

    if 'hybrid' in list(model_ds.coords):

        frames = []

        for time in sensor_ds.time.values:

            # '@time' in the query string refers to this loop variable
            layers_ds = match_df.query('time == @time').to_xarray().sel(time = time)

            # Pixel coordinates are averaged over the layer dimension
            latitude = layers_ds.latitude.mean(dim = 'layer')
            longitude = layers_ds.longitude.mean(dim = 'layer')

            # Kernel-weighted partial columns add up to a total column; otherwise the
            # per-layer rows are averaged
            if apply_kernels and 'column_kernel' in list(layers_ds.keys()):
                model_column = layers_ds.model_column.sum(dim = 'layer', skipna = False).astype(float)
            else:
                model_column = layers_ds.model_column.mean(dim = 'layer', skipna = False).astype(float)
            model_column = model_column.assign_coords(latitude = latitude, longitude = longitude)

            # Sensor column is repeated on every layer row, so the mean recovers it
            sensor_column = layers_ds.sensor_column.mean(dim = 'layer', skipna = False).astype(float)
            sensor_column = sensor_column.assign_coords(latitude = latitude, longitude = longitude)

            merge_ds_time = xr.merge([model_column, sensor_column])
            merge_ds_time['difference'] = merge_ds_time.model_column - merge_ds_time.sensor_column
            merge_ds_time['relative_difference'] = merge_ds_time.difference / merge_ds_time.sensor_column
            frames.append(merge_ds_time.to_dataframe())

        merge_df = pd.concat(frames)

    else:

        # Copy so the caller's match_df does not silently gain the difference columns
        merge_df = match_df.copy()
        merge_df['difference'] = merge_df['model_column'] - merge_df['sensor_column']
        merge_df['relative_difference'] = merge_df['difference'] / merge_df['sensor_column']

    # Organize dataset for visualization
    if sensor == 'tropomi' and sensor_type == 'L2':
        merge_df = merge_df.reset_index().set_index(['scanline', 'ground_pixel', 'time'])
        merge_df = merge_df[['latitude', 'longitude', 'model_column', 'sensor_column', 'difference', 'relative_difference']]

    else:
        merge_df = merge_df.reset_index().set_index(['latitude', 'longitude', 'time'])
        merge_df = merge_df[['model_column', 'sensor_column', 'difference', 'relative_difference']]

    return merge_df

In [17]:
def plot_period(sensor_ds, sensor_type):

    """ Define plot period

        Asks the user whether to restrict the plots to specific dates. Answering
        'No' (case-insensitive, surrounding spaces ignored) keeps every date of
        sensor_ds; any other answer (e.g. just Enter) asks about each date in turn,
        again treating only 'No' as a rejection.

        Args:
            sensor_ds (xarray): sensor dataset in xarray format (TROPOMI, IASI or GOME-2)
            sensor_type (str): Sensor type ('L2' dates are printed per day, 'L3' per month)

        Returns:
            plot_dates (arr): Plot dates, with the same dtype as sensor_ds.time.values
    """

    def answered_no(prompt):
        # Only an explicit 'no' is negative; Enter or anything else means yes
        return input(prompt).strip().lower() == 'no'

    use_all_dates = answered_no('Do you want to visualize the plots for specific dates? Press Enter for Yes or write No:')
    dates = sensor_ds.time.values

    if use_all_dates:
        plot_dates = dates

    else:

        selected = []
        for date in dates:
            if not answered_no('Do you want to show the plots for ' + str(date) + '? Press Enter for Yes or write No:'):
                selected.append(date)

        # Keep the datetime dtype even if no date was selected
        plot_dates = np.array(selected, dtype = dates.dtype)

    print('The plots will be shown for the following dates:')
    if sensor_type == 'L2':
        print(plot_dates.astype('datetime64[D]'))

    elif sensor_type == 'L3':
        print(plot_dates.astype('datetime64[M]'))

    else:
        print(plot_dates)

    return plot_dates

In [18]:
def plot_extent(bbox):

    """ Define plot extent

        Asks the user whether to restrict the plots to a sub-extent of the query
        bounding box. Answering 'No' (case-insensitive) keeps the full bbox; any
        other answer prompts for each limit, re-asking until the value is a number
        inside the bbox and, for the maxima, strictly above the chosen minimum.

        Args:
            bbox (arr): Query bounding box ((lon_min, lat_min), (lon_max, lat_max))

        Returns:
            plot_bbox (tuple): Plot bounding box ((lon_min, lat_min), (lon_max, lat_max))
    """

    def read_limit(name, label, low, high, minimum = None):
        # Prompt until a valid number is entered; a non-numeric answer no longer crashes
        prompt = f'Write value of {name}: '
        while True:
            answer = input(prompt)
            prompt = f'Write value of {name} (again): '

            try:
                value = float(answer)
            except ValueError:
                print(f'ERROR: {label} must be a number.')
                continue

            # The chained comparison also rejects NaN
            if minimum is None:
                if low <= value <= high:
                    return value
                print(f'ERROR: {label} must be between {low} and {high}.')
            else:
                if low <= value <= high and value > minimum:
                    return value
                print(f'ERROR: {label} must be between {low} and {high} and be higher than the minimum {minimum}.')

    extent_answer = input(f'Do you want to visualize the plots for a specific extent? Press Enter for Yes or write No (default {bbox}):')

    if extent_answer.strip().lower() == 'no':
        plot_bbox = ((bbox[0][0], bbox[0][1]), (bbox[1][0], bbox[1][1]))

    else:
        plot_lon_min = read_limit('minimum longitude', 'Longitude', bbox[0][0], bbox[1][0])
        plot_lon_max = read_limit('maximum longitude', 'Longitude', bbox[0][0], bbox[1][0], plot_lon_min)
        plot_lat_min = read_limit('minimum latitude', 'Latitude', bbox[0][1], bbox[1][1])
        plot_lat_max = read_limit('maximum latitude', 'Latitude', bbox[0][1], bbox[1][1], plot_lat_min)

        plot_bbox = ((plot_lon_min, plot_lat_min), (plot_lon_max, plot_lat_max))

    print('The plots will be shown for the following spatial extent: ')
    print(plot_bbox)
    
    return plot_bbox

In [19]:
def colorbar_range(range_type, array, diff_array, max_all, min_all, max_all_diff, min_all_diff,
                   vmin_manual, vmax_manual, vmin_manual_diff, vmax_manual_diff):

    """ Define colorbar range

        Args:
            range_type (str): Range type for colorbar:
            -  'original': Show original values in range
            -  'equal': Show same scale in range
            -  'manual': Show scale in range given by user
            -  'centered': Show scale centered in 0
            array (xarray): Component for a specific time and model/sensor
            diff_array (xarray): Difference for a specific time
            min_all (float): Absolute vmin
            max_all (float): Absolute vmax
            min_all_diff (float): Absolute vmin for difference values
            max_all_diff (float): Absolute vmax for difference values 
            vmin_manual (float): Input vmin by user
            vmax_manual (float): Input vmax by user
            vmin_manual_diff (float): Input vmin by user for difference values
            vmax_manual_diff (float): Input vmax by user for difference values
            
        Returns:
            vmin, vmax (float): Limits of color bar

        Raises:
            KeyboardInterrupt: range_type is 'manual' and vmin_manual or vmax_manual is None
            ValueError: range_type is not one of the supported values
    """

    def centered(upper, lower):
        # Symmetric limits around 0 using the larger absolute bound (ties keep upper)
        bound = max(np.abs(upper), np.abs(lower))
        return -bound, bound

    # When the array is the difference itself, it always gets a 0-centered
    # (or user-given) difference scale, whatever range_type is
    if np.array_equal(array, diff_array, equal_nan = True):
        if vmin_manual_diff is None and vmax_manual_diff is None:
            return centered(max_all_diff, min_all_diff)
        return vmin_manual_diff, vmax_manual_diff

    # The colorbar will show the original range
    if range_type == 'original':
        return np.nanmin(array), np.nanmax(array)

    # The colorbar will be in the same scale for both datasets
    if range_type == 'equal':
        return min_all, max_all

    # The colorbar will be in the scale given by the user
    if range_type == 'manual':
        if vmin_manual is None or vmax_manual is None:
            print('ERROR: vmin_manual and vmax_manual have to be defined and cannot be None.')
            raise KeyboardInterrupt()
        return vmin_manual, vmax_manual

    # The colorbar will be centered at 0
    if range_type == 'centered':
        return centered(max_all, min_all)

    # Previously fell through and failed with UnboundLocalError on vmin
    raise ValueError(f"range_type must be 'original', 'equal', 'manual' or 'centered', got {range_type!r}")

In [None]:
def get_frame_possible_lengths(loc_min, loc_max):

    """ Get lengths of the frame partitions (white / black) for which the frame 
        is well adjusted to the minimum and maximum coordinates

        Spans of at most one degree are searched in tenths of a degree; larger
        spans in whole degrees. A length is an option when it divides the span
        exactly.

        Args:
            loc_min (float): Minimum latitude or longitude of frame (it should be an integer).
            loc_max (float): Maximum latitude or longitude of frame (it should be an integer).

        Returns:
            options (list): Possible lengths that will be well adjusted to the frame
            (empty when none could be computed, e.g. a zero-width span)
    """

    number = np.abs(loc_max - loc_min)
    fractional = 0 < number <= 1

    # Round away floating-point noise before truncating: 0.7 - 0.4 is 0.2999...,
    # which would otherwise be treated as a 0.2 span
    number_loop = int(round(number * 10 if fractional else number, 9))

    options = [divisor / 10 if fractional else divisor
               for divisor in range(1, number_loop + 1)
               if number_loop % divisor == 0]

    if options:
        print(f'Frame length between {loc_min} and {loc_max} should be one of these options: {options}')
    else:
        print('Frame length suggestions could not be computed. Consider changing your bounding box to be at least 1ºx1º or to be composed by integer coordinates.')
        print('Alternatively, you can inactivate the map frame by removing the following line in the function visualize_pcolormesh:')
        print('axs = map_frame(axs, lat_min, lat_max, lon_min, lon_max, breaks_lon, breaks_lat, width_lat, width_lon, height_lat, height_lon)')

    return options

In [None]:
def map_frame(axs, lat_min, lat_max, lon_min, lon_max, breaks_lon, breaks_lat, 
              width_lat, width_lon, height_lat, height_lon):

    """ Draw an alternating white / black frame around the map extent

        Args:
            axs: Axes of figure
            lat_min, lat_max, lon_min, lon_max (float): Map extent the frame surrounds
            breaks_lon (list): Start longitudes of the segments on the bottom and top edges
            breaks_lat (list): Start latitudes of the segments on the left and right edges
            width_lat (float): Thickness of the left and right edges
            width_lon (float): Length of each bottom / top segment
            height_lat (float): Length of each left / right segment
            height_lon (float): Thickness of the bottom and top edges

        Returns:
            axs: Axes with the frame patches added
    """

    def add_block(x, y, width, height, color):
        # clip_on = False lets the frame sit outside the axes data area
        axs.add_patch(mpatches.Rectangle(xy = [x, y], width = width, height = height,
                                         facecolor = color, clip_on = False, edgecolor = 'black', lw = 1, ls = 'solid'))

    # Bottom and top edges, alternating colors starting with white
    for i, x in enumerate(breaks_lon):
        color = 'white' if i % 2 == 0 else 'black'
        add_block(x, lat_min - height_lon, width_lon, height_lon, color)
        add_block(x, lat_max, width_lon, height_lon, color)

    # Left and right edges, alternating colors starting with white
    for i, y in enumerate(breaks_lat):
        color = 'white' if i % 2 == 0 else 'black'
        add_block(lon_min - width_lat, y, width_lat, height_lat, color)
        add_block(lon_max, y, width_lat, height_lat, color)

    # Corner squares: horizontal offsets use the side thickness (width_lat) and
    # vertical offsets the top/bottom thickness (height_lon); these were swapped
    # before, which only went unnoticed while both thicknesses were equal
    corners = ((lon_max, lat_max),
               (lon_min - width_lat, lat_min - height_lon),
               (lon_max, lat_min - height_lon),
               (lon_min - width_lat, lat_max))
    for x, y in corners:
        add_block(x, y, width_lat, height_lon, 'black')

    return axs

In [None]:
def map_markers(axs, bbox_list, coords_list, regions_names, lat_min, lat_max):

    """ Add markers to show regions, locations or texts above them

        Args:
            axs: Axes of figure
            bbox_list (list): List of search bounding boxes (eg. (lat_min, lat_max, lon_min, lon_max, ...)
            coords_list (list): List of search coordinates (eg. (lat, lon, lat, lon, ...)
            regions_names (list or str): Region names, drawn above the points of coords_list
            lat_min (float): Minimum latitude
            lat_max (float): Maximum latitude

        Returns:
            axs: Axes with the markers added
    """

    # Transform string to tuple (if there is only one element)
    if isinstance(regions_names, str):
        regions_names = (regions_names,)

    # Show rectangles; 'is not None' avoids an ambiguous elementwise comparison
    # when a numpy array is passed
    if bbox_list is not None:

        # NOTE(review): assumes pairwise() splits the flat list into consecutive
        # 2-item pairs, alternating (lat_min, lat_max) and (lon_min, lon_max)
        bbox_pairs = pairwise(bbox_list)

        for region_lats, region_lons in zip(bbox_pairs[0::2], bbox_pairs[1::2]):
            axs.add_patch(mpatches.Rectangle(xy = [region_lons[0], region_lats[0]], 
                                             width = region_lons[1] - region_lons[0], 
                                             height = region_lats[1] - region_lats[0],
                                             linewidth = 1.5, linestyle = '--',
                                             edgecolor = 'black', fill = False))

    if coords_list is not None:

        # Each pair is (lat, lon)
        coords = pairwise(coords_list)

        # Show points
        for point in coords:
            axs.scatter(point[1], point[0], c = 'red', s = 12, marker = 'o')

        # Show text slightly above each point
        if regions_names is not None:
            label_offset = np.abs(lat_max - lat_min) / 12
            for point, region_name in zip(coords, regions_names):
                axs.annotate(region_name, (point[1], point[0] + label_offset), 
                             fontsize = 12, ha = 'center', va = 'center')
        
    return axs

In [20]:
def visualize_pcolormesh(fig, axs, data_array, longitude, latitude, projection, color_scale, 
                         pad, long_name, units_name, vmin, vmax, lon_min, lon_max, lat_min, lat_max, 
                         width_lon, height_lat, bbox_list, coords_list, regions_names,
                         distribution_type = None):
    
    """ Set basic map configuration and draw data_array as a colored mesh

        Args:
            fig: Figure
            axs: Axes of figure (cleared before drawing)
            data_array (xarray): Variable values to plot - It must be 2-dimensional
            longitude (arr): Longitudes within data_array
            latitude (arr): Latitudes within data_array
            projection: Geographical projection
            color_scale (str): Name of the color scale for this plot (e.g. coolwarm)
            pad (float): Padding for the subtitles
            long_name (str): Plot name
            units_name (str): Component name and units
            vmin, vmax (float): Limits of color bar
            lon_min, lon_max, lat_min, lat_max (float): Limits of longitude and latitude values
            width_lon (int): Horizontal width of frame individual sections (black - white lines)
            height_lat (int): Vertical height of frame individual sections (black - white lines)
            bbox_list (list): List of search bounding boxes (eg. (lat_min, lat_max, lon_min, lon_max, ...)
            coords_list (list): List of search coordinates (eg. (lat, lon, lat, lon, ...)
            regions_names (list): Region names
            distribution_type (str, optional): 'aggregated', 'individual' or 'animated';
                no colorbar is drawn for 'animated'. When None, the module-level global
                'distribution_type' is used, as earlier versions did.

        Returns:
            im_ind: The pcolormesh artist
    """

    # Earlier versions read distribution_type from the notebook namespace; keep that as fallback
    if distribution_type is None:
        distribution_type = globals().get('distribution_type')

    # Missing (NaN) values are drawn fully transparent
    palette = copy(plt.get_cmap(color_scale))
    palette.set_bad(alpha = 0)
    axs.clear()

    # Limits are passed only through norm: Matplotlib >= 3.5 rejects a norm
    # combined with vmin / vmax
    im_ind = axs.pcolormesh(longitude, latitude, data_array, 
                            cmap = palette, 
                            transform = ccrs.PlateCarree(),
                            norm = colors.Normalize(vmin = vmin, vmax = vmax),
                            shading = 'auto'
                            )
                        
    axs.add_feature(cfeature.BORDERS, edgecolor = 'black', linewidth = 1)
    axs.add_feature(cfeature.COASTLINE, edgecolor = 'black', linewidth = 1)

    if projection == ccrs.PlateCarree():
        
        axs.set_extent([lon_min, lon_max, lat_min, lat_max], ccrs.PlateCarree())

        # Frame thickness is 1% of the longitude span on every side
        diff = (np.abs(lon_max - lon_min))
        width_lat, height_lon = diff/100, diff/100
        breaks_lon = list(np.arange(lon_min, lon_max, width_lon))
        breaks_lat = list(np.arange(lat_min, lat_max, height_lat))

        gl = axs.gridlines(draw_labels = True, linestyle = '--')
        gl.xlocator = mticker.FixedLocator(breaks_lon[1:])
        gl.ylocator = mticker.FixedLocator(breaks_lat)
        gl.xformatter = LONGITUDE_FORMATTER
        gl.yformatter = LATITUDE_FORMATTER
        gl.xlabel_style = {'size': 13.5}
        gl.ylabel_style = {'size': 13.5}
        gl.xpadding = 10
        gl.ypadding = 10
        gl.right_labels = False
        gl.top_labels = False
        axs = map_frame(axs, lat_min, lat_max, lon_min, lon_max, breaks_lon, breaks_lat,
                        width_lat, width_lon, height_lat, height_lon)

    axs.set_title(long_name, fontsize = 18, pad = pad)
    axs.tick_params(labelsize = 14)
    
    axs = map_markers(axs, bbox_list, coords_list, regions_names, lat_min, lat_max)
    
    # Animations are not given a colorbar here
    if distribution_type != 'animated':
        
        cbr = fig.colorbar(im_ind, ax = axs, extend = 'both', orientation = 'horizontal', 
                           fraction = 0.05, pad = 0.15)   
        cbr.set_label(units_name, fontsize = 16)
        cbr.ax.tick_params(labelsize = 14)
        cbr.ax.xaxis.get_offset_text().set_fontsize(14)
      
    return im_ind

In [21]:
def comparison_maps(fig, axs, merge_ds_time, range_type, sensor, model, sensor_type, model_type, 
                    projection, pad, units_name, plot_bbox, color_scale, max_all, min_all, 
                    min_all_diff, max_all_diff, width_lon, height_lat,
                    vmin_manual, vmax_manual, vmin_manual_diff, vmax_manual_diff, 
                    bbox_list, coords_list, regions_names):

    """ Create 3 plots with:
        -   Component concentration for model data
        -   Component concentration for sensor data
        -   Component concentration for difference data

        Args:
            fig (figure): Plot figure
            axs (axes): Plot axes (the first three are used, in the order above)
            merge_ds_time (xarray): Merge xarray with total column data and their difference at specific time 
            range_type (str): Range type for colorbar:
            -  'original': Show original values in range
            -  'equal': Show same scale in range
            -  'manual': Show scale in range given by user
            -  'centered': Show scale centered in 0
            sensor (str): Name of the sensor
            model (str): Name of the model
            sensor_type (str): Sensor type
            model_type (str): Model type:
            -  'Forecast'
            -  'Reanalysis'
            projection (projection): Geographical projection
            pad (float): Padding for the subtitles
            units_name (str): Component name and units
            plot_bbox (arr): Plot bounding box ((lon_min, lat_min), (lon_max, lat_max))
            color_scale (list): Name of color scale (e.g. coolwarm) (in order for: model, sensor, difference)
            max_all (float): Absolute vmax
            min_all (float): Absolute vmin
            min_all_diff (float): Absolute vmin for difference values
            max_all_diff (float): Absolute vmax for difference values 
            width_lon (int): Horizontal width of frame individual sections (black - white lines)
            height_lat (int): Vertical height of frame individual sections (black - white lines)
            vmin_manual (float): Input vmin by user
            vmax_manual (float): Input vmax by user
            vmin_manual_diff (float): Input vmin by user for difference values
            vmax_manual_diff (float): Input vmax by user for difference values
            bbox_list (list): List of search bounding boxes (eg. (lat_min, lat_max, lon_min, lon_max, ...)
            coords_list (list): List of search coordinates (eg. (lat, lon, lat, lon, ...)
            regions_names (list): Region names

        Returns:
            im (list): The three pcolormesh artists (model, sensor, difference)
    """

    # The third panel passes this same array, which colorbar_range compares
    # against diff_array to apply the difference color limits
    diff_array = merge_ds_time.difference

    # Use one sensor label everywhere, so GOME-2 is not shortened to GOME in the difference title
    sensor_label = 'GOME-2' if sensor == 'gome' else sensor.upper()

    # (data, title) for: model (CAMS), sensor (TROPOMI, IASI or GOME-2), difference
    panels = [
        (merge_ds_time['model_column'], model.upper() + ' (' + model_type + ')'),
        (merge_ds_time['sensor_column'], sensor_label + ' (' + sensor_type + ')'),
        (diff_array, 'Difference (' + model.upper() + ' - ' + sensor_label + ')'),
    ]

    im = []

    for index, (array, long_name) in enumerate(panels):

        vmin, vmax = colorbar_range(range_type, array, diff_array, max_all, min_all, 
                                    max_all_diff, min_all_diff, vmin_manual, vmax_manual, 
                                    vmin_manual_diff, vmax_manual_diff)

        im.append(visualize_pcolormesh(
                                       fig = fig, axs = axs[index],
                                       data_array = array,
                                       longitude = array.longitude,
                                       latitude = array.latitude,
                                       projection = projection,
                                       color_scale = color_scale[index],
                                       pad = pad,
                                       long_name = long_name,
                                       units_name = units_name,
                                       vmin = vmin,
                                       vmax = vmax,
                                       lon_min = plot_bbox[0][0],
                                       lon_max = plot_bbox[1][0],
                                       lat_min = plot_bbox[0][1],
                                       lat_max = plot_bbox[1][1],
                                       width_lon = width_lon,
                                       height_lat = height_lat,
                                       bbox_list = bbox_list, 
                                       coords_list = coords_list,
                                       regions_names = regions_names
                                      ))
 
    return im

In [None]:
def define_absolute_limits(vmin_manual, vmax_manual, model_variable, sensor_variable, diff_variable):

    """ Define absolute minimum and maximum within model and sensor datasets.
        Gets manual minimum and maximum or calculates it

        Without manual limits, the range spans the NaN-ignoring minimum and maximum
        of both datasets. With manual limits, the range is made symmetric around 0
        using the largest absolute value given; a missing (None) limit is ignored.

        Args:
            vmin_manual (float): Input vmin by user
            vmax_manual (float): Input vmax by user
            model_variable (array): Variable to get limits in model data column
            sensor_variable (array): Variable to get limits in sensor data column
            diff_variable (array): Variable to get limits in difference data column

        Returns
            min_all (float): Absolute vmin
            max_all (float): Absolute vmax
            min_all_diff (float): Absolute vmin for difference values
            max_all_diff (float): Absolute vmax for difference values     
    """

    # Define absolute minimum and maximum within model and sensor datasets
    if vmin_manual is None and vmax_manual is None:

        max_all = max(np.nanmax(sensor_variable), np.nanmax(model_variable))
        min_all = min(np.nanmin(sensor_variable), np.nanmin(model_variable))

    else:

        # Symmetric around 0; vmax_manual wins ties. Skipping a None limit avoids the
        # TypeError np.abs(None) used to raise when only one limit was given
        bound = max(np.abs(limit) for limit in (vmax_manual, vmin_manual) if limit is not None)
        min_all, max_all = -bound, bound

    # Define absolute minimum and maximum within difference
    min_all_diff = np.nanmin(diff_variable)
    max_all_diff = np.nanmax(diff_variable)

    return min_all, max_all, min_all_diff, max_all_diff

In [22]:
def _merge_slice_at_time(merge_df, time):

    """ Select a single time step of the merge table as an xarray dataset,
        with its latitude/longitude re-assigned as coordinates

        Args:
            merge_df (dataframe): Merge table (time must be an index level so that .sel(time = ...) works)
            time: Time step to select

        Returns:
            merge_ds_time (xarray): Dataset of the selected time step
    """

    # '@time' is resolved by pandas from this function's local variables
    merge_ds_time = merge_df.query('time == @time').to_xarray().sel(time = time)

    return merge_ds_time.assign_coords(latitude = merge_ds_time.latitude,
                                       longitude = merge_ds_time.longitude)


def _distribution_period_label(sensor, sensor_type, time):

    """ Period text used in the distribution titles: month for IASI and GOME L3 data,
        day for everything else

        Args:
            sensor (str): Name of the sensor
            sensor_type (str): Sensor type
            time: Time step

        Returns:
            label (str): 'Month: <YYYY-MM>' or 'Date: <YYYY-MM-DD>'
    """

    if sensor in ('iasi', 'gome') and sensor_type == 'L3':
        return f"Month: {np.datetime64(time).astype('datetime64[M]')}"

    return f"Date: {np.datetime64(time).astype('datetime64[D]')}"


def visualize_model_vs_sensor(model, sensor, component_nom, units, merge_df, plot_dates, plot_bbox, pad, y, 
                              model_type, sensor_type, range_type, distribution_type, projection,
                              color_scale, width_lon, height_lat, 
                              vmin_manual = None, vmax_manual = None,
                              vmin_manual_diff = None, vmax_manual_diff = None, 
                              bbox_list = None, coords_list = None, regions_names = None):

    """ Plot model and sensor datasets in the study area for the selected dates, 
        along with a plot of the differences

        Args:
            model (str): Name of the model
            sensor (str): Name of the sensor
            component_nom (str): Component chemical nomenclature
            units (str): Component units
            merge_df (dataframe): Merge table with total column data and their difference
            plot_dates (arr): Plot dates
            plot_bbox (arr): Plot extent
            pad (float): Padding for the subtitles
            y (float): y-position of main title
            model_type (str): Model type:
            -  'Forecast'
            -  'Reanalysis'
            sensor_type (str): Sensor type
            range_type (str): Range type for colorbar:
            -  'original': Show original values in range
            -  'equal': Show same scale in range
            -  'manual': Show scale in range given by user
            -  'centered': Show scale centered in 0
            distribution_type (str): 
            -  'aggregated': Aggregate plots by time
            -  'seasonal': Aggregate plots by season (one figure per available season)
            -  'individual': Show individual plots
            -  'animated': Show animation (also saved as 'animation.gif' in the working directory)
            projection (projection): Geographical projection
            color_scale (list): Name of color scale (e.g. coolwarm) (in order for: model, sensor, difference)
            width_lon (int): Horizontal width of frame individual sections (black - white lines)
            height_lat (int): Vertical height of frame individual sections (black - white lines)
            vmin_manual (float): Input vmin by user
            vmax_manual (float): Input vmax by user
            vmin_manual_diff (float): Input vmin by user for difference values
            vmax_manual_diff (float): Input vmax by user for difference values
            bbox_list, coords_list, regions_names: Passed through unchanged to comparison_maps

        Raises:
            KeyboardInterrupt: If color_scale does not have three entries or
                               distribution_type is not one of the supported options
    """
    
    if len(color_scale) != 3:
        print('ERROR: color_scale has to include the scales (e.g. coolwarm) for the three maps (in order for: model, sensor, difference).')
        raise KeyboardInterrupt()

    # Get min and max 
    merge_df_bbox = merge_df.query('longitude >= @plot_bbox[0][0] and longitude <= @plot_bbox[1][0] and latitude >= @plot_bbox[0][1] and latitude <= @plot_bbox[1][1]')
    model_variable = merge_df_bbox['model_column']
    sensor_variable = merge_df_bbox['sensor_column']
    diff_variable = merge_df_bbox['difference']
    min_all, max_all, min_all_diff, max_all_diff = define_absolute_limits(vmin_manual, vmax_manual, 
                                                                          model_variable, sensor_variable, 
                                                                          diff_variable)

    units_name = component_nom + ' (' + units + ')'

    def draw_maps(fig, axs, merge_ds):
        # Every comparison_maps argument except the figure, axes and dataset is fixed for this call
        return comparison_maps(fig, axs, merge_ds, range_type, sensor, model, sensor_type, model_type, 
                               projection, pad, units_name, plot_bbox, color_scale, max_all, min_all, 
                               min_all_diff, max_all_diff, width_lon, height_lat, vmin_manual, vmax_manual,
                               vmin_manual_diff, vmax_manual_diff, bbox_list, coords_list, regions_names)

    if distribution_type == 'aggregated':
            
        merge_ds_time = merge_df.to_xarray().mean(dim = 'time')
        merge_ds_time = merge_ds_time.assign_coords(latitude = merge_ds_time.latitude,
                                                    longitude = merge_ds_time.longitude)

        fig, axs = plt.subplots(1, 3, figsize = (20, 5), subplot_kw = {'projection': projection})
        fig.set_facecolor('w')

        draw_maps(fig, axs, merge_ds_time)

        fig.suptitle(f'DISTRIBUTION OF {component_nom} (All times)',
                    fontsize = 18, fontweight = 'bold', y = y)

    elif distribution_type == 'seasonal':
        
        merge_df = merge_df.reset_index()
        merge_df['season'] = merge_df.apply(lambda row: get_season(row['time']), axis = 1)
        available_seasons = np.unique(merge_df['season'])
        merge_ds_seasons = merge_df.set_index(['latitude', 'longitude', 'time', 'season'])
        # Average over time within each (latitude, longitude, season) group
        merge_ds_seasons = merge_ds_seasons.groupby(level = [0, 1, 3]).mean().to_xarray()
        
        for season in available_seasons:

            merge_ds_season = merge_ds_seasons.sel(season = season)

            fig, axs = plt.subplots(1, 3, figsize = (20, 5), subplot_kw = {'projection': projection})
            fig.set_facecolor('w')
            
            draw_maps(fig, axs, merge_ds_season)
            
            fig.suptitle(f'DISTRIBUTION OF {component_nom} (Season: {season})',
                         fontsize = 18, fontweight = 'bold', y = y)

    elif distribution_type == 'individual':
        
        for time in plot_dates:

            merge_ds_time = _merge_slice_at_time(merge_df, time)

            fig, axs = plt.subplots(1, 3, figsize = (20, 5), subplot_kw = {'projection': projection})
            fig.set_facecolor('w')
            
            draw_maps(fig, axs, merge_ds_time)

            period = _distribution_period_label(sensor, sensor_type, time)
            fig.suptitle(f'DISTRIBUTION OF {component_nom} ({period})',
                         fontsize = 18, fontweight = 'bold', y = y)
                             
            plt.show()
        
    elif distribution_type == 'animated':

        fig, axs = plt.subplots(1, 3, figsize = (25, 10), subplot_kw = {'projection': projection})
        fig.set_facecolor('w')

        period = _distribution_period_label(sensor, sensor_type, plot_dates[0])
        fig_title = fig.text(0.5, 0.95, f'DISTRIBUTION OF {component_nom} ({period})', 
                             ha = 'center', fontsize = 22, fontweight = 'bold')

        # First frame; its artists are also used to build the colorbars below
        im = draw_maps(fig, axs, _merge_slice_at_time(merge_df, plot_dates[0]))

        def animate(i):

            frame_im = draw_maps(fig, axs, _merge_slice_at_time(merge_df, plot_dates[i]))
            period = _distribution_period_label(sensor, sensor_type, plot_dates[i])
            fig_title.set_text(f'DISTRIBUTION OF {component_nom} ({period})')

            return frame_im

        anim = animation.FuncAnimation(fig, animate, frames = len(plot_dates), blit = True, interval = 1000)

        for j in range(0, 3):
           
            cbr = fig.colorbar(im[j], ax = axs[j], extend = 'both', orientation = 'horizontal', fraction = 0.05, pad = 0.15)
            cbr.set_label(units_name, fontsize = 18) 
            cbr.ax.tick_params(labelsize = 16)
            cbr.ax.xaxis.get_offset_text().set_fontsize(16)

        display(HTML(anim.to_jshtml()))
        anim.save('animation.gif')
        plt.close()

    else:
        print('The distribution type (distribution_type) must be defined as aggregated, seasonal, individual or animated.')
        raise KeyboardInterrupt()

In [23]:
def visualize_model_original_vs_calculated(model, component_nom, units, merge_df, 
                                           model_total_ds, plot_dates, plot_bbox, pad, y, 
                                           model_type, range_type, projection, color_scale,
                                           width_lon, height_lat,
                                           vmin_manual = None, vmax_manual = None, step = 2):

    """ Plot model total columns from the original dataset and the calculated one 
        in the study area for the selected dates (one figure per date)

        Args:
            model (str): Name of the model
            component_nom (str): Component chemical nomenclature
            units (str): Component units
            merge_df (dataframe): Merge result
            model_total_ds (xarray): CAMS total columns dataset in xarray format
            plot_dates (arr): Plot dates
            plot_bbox (arr): Plot extent
            pad (float): Padding for the subtitles
            y (float): y-position of main title
            model_type (str): Model type:
            -  'Forecast'
            -  'Reanalysis'
            range_type (str): Range type for colorbar:
            -  'original': Show original values in range
            -  'equal': Show same scale in range
            -  'manual': Show scale in range given by user
            -  'centered': Show scale centered in 0
            projection: Geographical projection
            color_scale (list): Name of color scale (e.g. coolwarm) (in order for: calculated, original)
            width_lon (int): Horizontal width of frame individual sections (black - white lines)
            height_lat (int): Vertical height of frame individual sections (black - white lines)
            vmin_manual (float): Input vmin by user
            vmax_manual (float): Input vmax by user
            step (int): Index along the 'step' dimension of model_total_ds that is plotted
                        as the original columns (default 2, the previously hard-coded value)

        Raises:
            KeyboardInterrupt: If color_scale does not have two entries
    """

    # Validate the inputs before doing any computation
    if len(color_scale) != 2:
        print('ERROR: color_scale has to include the scales (e.g. coolwarm) for the two maps (in order for: original, calculated).')
        raise KeyboardInterrupt()

    units_name = component_nom + ' (' + units + ')'

    # Get min and max before splitting the data into timesteps.
    # The original columns are taken over the whole dataset (all steps and times).
    min_calculated = np.nanmin(merge_df['model_column'])
    max_calculated = np.nanmax(merge_df['model_column'])
    min_original = np.nanmin(model_total_ds.component)
    max_original = np.nanmax(model_total_ds.component)
    max_all = max(max_original, max_calculated)
    min_all = min(min_original, min_calculated)
    min_all_diff, max_all_diff = None, None

    for time in plot_dates:

        fig, axs = plt.subplots(1, 2, figsize = (20, 5), subplot_kw = {'projection': projection})
        fig.set_facecolor('w')
        
        merge_ds_time = merge_df.query('time == @time').to_xarray().sel(time = time)
        merge_ds_time = merge_ds_time.assign_coords(latitude = merge_ds_time.latitude,
                                                    longitude = merge_ds_time.longitude)

        # First panel: CAMS calculated total columns; second panel: CAMS original total columns
        panels = [
            (merge_ds_time.model_column, 'CALCULATED TOTAL COLUMNS '),
            (model_total_ds.component.isel(step = step).sel(time = time), 'ORIGINAL TOTAL COLUMNS '),
        ]

        for panel_index, (array, title_prefix) in enumerate(panels):

            vmin, vmax = colorbar_range(range_type, array, None, max_all, min_all, 
                                        max_all_diff, min_all_diff, vmin_manual, vmax_manual, 
                                        vmin_manual_diff = None, vmax_manual_diff = None)
            long_name = title_prefix + model.upper() + ' (' + model_type + ')'
            visualize_pcolormesh(
                                 fig = fig, axs = axs[panel_index],
                                 data_array = array,
                                 longitude = array.longitude,
                                 latitude = array.latitude,
                                 projection = projection,
                                 color_scale = color_scale[panel_index],
                                 pad = pad,
                                 long_name = long_name,
                                 units_name = units_name,
                                 vmin = vmin, 
                                 vmax = vmax, 
                                 lon_min = plot_bbox[0][0],
                                 lon_max = plot_bbox[1][0],
                                 lat_min = plot_bbox[0][1],
                                 lat_max = plot_bbox[1][1],
                                 width_lon = width_lon,
                                 height_lat = height_lat,
                                 bbox_list = None, 
                                 coords_list = None,
                                 regions_names = None
                                )

        day = np.datetime64(time).astype('datetime64[D]')
        fig.suptitle(f'DISTRIBUTION OF {component_nom} (Date: {day})',
                     fontsize = 18, fontweight = 'bold', y = y)
        plt.show()

In [24]:
def get_google_api(keys_path = 'data/keys.txt'):

    """ Get Google API key for reverse geocoding (get country given the coordinates)
        and export the Google credentials as environment variables

        The keys file is read line by line (trailing whitespace stripped) and
        lines 2, 3 and 4 (indices 1 to 3) are used:
            - line 2: GOOGLE API KEY        -> os.environ['GOOGLE_API_KEY']
            - line 3: GOOGLE CLIENT ID      -> os.environ['GOOGLE_CLIENT']
            - line 4: GOOGLE CLIENT SECRET  -> os.environ['GOOGLE_CLIENT_SECRET']
        The first line is not used by this function.

        Args:
            keys_path (str): Path to the keys file (default 'data/keys.txt')

        Returns:
            google_api_key (str): Google API key (second line of the file)

        Raises:
            IndexError: If the file has fewer than four lines
    """

    # NOTE(review): indices 1-3 are used, so the first line of the file is skipped.
    # They are kept as-is so existing key files keep working; confirm the intended layout.
    with open(keys_path, 'r') as keys_file:
        environ_keys = [key.rstrip() for key in keys_file]

    # Set environment variables in your system
    os.environ['GOOGLE_API_KEY'] = environ_keys[1]
    os.environ['GOOGLE_CLIENT'] = environ_keys[2]
    os.environ['GOOGLE_CLIENT_SECRET'] = environ_keys[3]

    return environ_keys[1]

In [25]:
def get_season(day):

    """ Get season given the day

        Seasons use fixed calendar boundaries, independent of the year:
            - Winter: 21 Dec - 20 Mar
            - Spring: 21 Mar - 20 Jun
            - Summer: 21 Jun - 22 Sep
            - Autumn: 23 Sep - 20 Dec

        Args:
            day (datetime.date, datetime.datetime or pandas.Timestamp): Date

        Returns:
            season (str): Season of the year
    """

    # datetime.datetime (and pandas.Timestamp, a subclass) cannot be ordered against
    # datetime.date, so reduce datetime-like inputs to their calendar date first
    if isinstance(day, dt.datetime):
        day = day.date()

    # Reference leap year, so that 29 February can be mapped as well
    Y = 2000

    seasons = [('Winter', (dt.date(Y,  1,  1),  dt.date(Y,  3, 20))),
               ('Spring', (dt.date(Y,  3, 21),  dt.date(Y,  6, 20))),
               ('Summer', (dt.date(Y,  6, 21),  dt.date(Y,  9, 22))),
               ('Autumn', (dt.date(Y,  9, 23),  dt.date(Y, 12, 20))),
               ('Winter', (dt.date(Y, 12, 21),  dt.date(Y, 12, 31)))]

    day = day.replace(year = Y)

    # The ranges above cover every day of the year, so next() always finds a match
    return next(season for season, (start, end) in seasons if start <= day <= end)

In [26]:
def linear_regression(X, Y, component_nom, axs):

    """ Fit a linear equation to scatter plot between X and Y and compute the fit statistics

        Args:
            X (array): Input sensor component values, as a column (callers pass values.reshape(-1, 1))
            Y (array): Input model component values, as a column (callers pass values.reshape(-1, 1))
            component_nom (str): Component chemical nomenclature (not used by the computation)
            axs (axes or sequence of axes): Scatter plot axes, or a sequence whose first
                                            element is the scatter plot axes; its x-limits
                                            define the range of the fitted line
        
        Returns:
            fit_X (array): X in linear equation fit_Y = A * fit_X + B (10 points spanning the x-axis limits)
            fit_Y (array): Y in linear equation fit_Y = A * fit_X + B
            slope (float): A in linear equation fit_Y = A * fit_X + B
            intercept (float): B in linear equation fit_Y = A * fit_X + B
            R2 (float): Coefficient of determination of the fit on (X, Y)
            RMSE (float): Root mean squared error of the fit
            MSE (float): Mean squared error of the fit
    """

    # Fit regression and get slope, intercept and R2
    reg = LinearRegression().fit(X, Y)
    slope = reg.coef_[0][0]
    intercept = reg.intercept_[0]
    R2 = reg.score(X, Y)

    # The fitted line spans the current x-limits of the scatter plot axes
    try:
        scatter_ax = axs[0]
    except TypeError:
        # A single Axes object is not subscriptable
        scatter_ax = axs
    lim_min, lim_max = scatter_ax.get_xlim()
    fit_X = np.linspace(lim_min, lim_max, 10)
    fit_Y = fit_X * slope + intercept

    # Calculate MSE and RMSE. RMSE is derived from MSE because the `squared`
    # argument of mean_squared_error is deprecated/removed in recent scikit-learn.
    Y_pred = intercept + slope * X
    MSE = mean_squared_error(y_true = Y, y_pred = Y_pred)
    RMSE = np.sqrt(MSE)

    return fit_X, fit_Y, slope, intercept, R2, RMSE, MSE

In [27]:
def stats_plots_general_settings(component_nom, axs, units, lim_min, lim_max):

    """ Set common settings for the statistics figure (scatter plot and two histograms)

        Args:
            component_nom (str): Component chemical nomenclature
            axs (sequence of axes): Three axes:
                                    [0] scatter plot (sensor on x, model on y),
                                    [1] histogram of sensor values,
                                    [2] histogram of model values
            units (str): Component units
            lim_min (float): Minimum value of component in scale
            lim_max (float): Maximum value of component in scale
    """

    sensor_label = f'Sensor {component_nom} ({units})'
    model_label = f'Model {component_nom} ({units})'

    # Scatter plot: same limits on both axes so the 1:1 line is the diagonal
    axs[0].set_xlabel(sensor_label, fontsize = 20)
    axs[0].set_ylabel(model_label, fontsize = 20)
    axs[0].set_ylim([lim_min, lim_max])

    # Histograms
    axs[1].set_xlabel(sensor_label, fontsize = 20)
    axs[2].set_xlabel(model_label, fontsize = 20)
    for i in (1, 2):
        axs[i].set_ylabel('Count', fontsize = 20)
    
    # All: ticks, shared x-limits and offset-text size
    for i in (0, 1, 2):
        axs[i].tick_params(labelsize = 18)
        axs[i].set_xlim([lim_min, lim_max])
        axs[i].locator_params(axis = 'both', nbins = 8)
        axs[i].xaxis.get_offset_text().set_fontsize(16)
        axs[i].yaxis.get_offset_text().set_fontsize(16)

In [28]:
def scatter_plot(merge_df, component_nom, units, sensor, plot_dates, y, extent_definition, 
                 show_seasons, scatter_plot_type, lim_min = None, lim_max = None, *args):

    """ Scatter plot between the model and sensor datasets in the study area for the selected dates (bbox or countries)

        Args:
            merge_df (dataframe): Merge result
            component_nom (str): Component chemical nomenclature
            units (str): Component units
            sensor (str): Name of the sensor
            plot_dates (arr): Plot dates
            plot_bbox (arr): Plot extent
            y (float): y-position of main title
            extent_definition (str):
            * 'country': Scatter plots for countries list
            * 'bbox': Scatter plots for bbox coordinates
            scatter_plot_type (str):
            * 'aggregated': Aggregate plots by time, country or season
            * 'individual': Individual plots per time, country or season
            *args: plot_countries, plot_bbox, lim_min, lim_max
    """

    summary = []

    # Set colors for scatter points
    sns.color_palette('colorblind', 10)

    # Drop NaN values
    merge_df = merge_df.dropna()

    # Prepare df
    if lim_min == None:
        lim_min = min(np.nanmin(merge_df['sensor_column']), np.nanmin(merge_df['model_column']))
    else:
        merge_df = merge_df[(merge_df['sensor_column'] >= lim_min) & (merge_df['model_column'] >= lim_min)]         

    if lim_max == None:
        lim_max = max(np.nanmax(merge_df['sensor_column']), np.nanmax(merge_df['model_column']))
    else:
        merge_df = merge_df[(merge_df['sensor_column'] <= lim_max) & (merge_df['model_column'] <= lim_max)]

    merge_df = merge_df.query('longitude >= @plot_bbox[0][0] and longitude <= @plot_bbox[1][0] and latitude >= @plot_bbox[0][1] and latitude <= @plot_bbox[1][1]')         
    merge_df = merge_df.reset_index()
    merge_df = merge_df[merge_df['time'].isin(plot_dates)]

    if show_seasons == False:

        if extent_definition == 'bbox':

            if scatter_plot_type == 'aggregated':

                if not merge_df.empty:
                    
                    fig, axs = plt.subplots(1, 3, figsize = (28, 7))
                    fig.set_facecolor('w')
            
                    # Scatter plot and histograms
                    sp = sns.scatterplot(data = merge_df, x = 'sensor_column', y = 'model_column', 
                                         hue = 'time', ax = axs[0])
                    sns.histplot(data = merge_df, x = 'sensor_column', kde = True,  ax = axs[1])
                    sns.histplot(data = merge_df, x = 'model_column', kde = True,  ax = axs[2])

                    stats_plots_general_settings(component_nom, axs, units, lim_min, lim_max)
                    fig.suptitle(f'{component_nom} (All times)', fontsize = 25, fontweight = 'bold', y = y)

                    # Line 1:1
                    line = mlines.Line2D([0, 1], [0, 1], color = 'grey', linestyle = '--', label = '1:1')
                    line.set_transform(axs[0].transAxes)
                    axs[0].add_line(line)

                    # Linear regression
                    X = merge_df['sensor_column'].values.reshape(-1, 1) 
                    Y = merge_df['model_column'].values.reshape(-1, 1) 
                    fit_X, fit_Y, slope, intercept, R2, RMSE, MSE = linear_regression(X, Y, component_nom, axs)
                    axs[0].plot(fit_X, fit_Y, color = 'black', label = 'A + B*X')

                    # Legend settings
                    if sensor_type == 'L3':

                        leg = sp.legend(loc = 'upper center', bbox_to_anchor = (0.5, -0.2),
                                        fancybox = True, ncol = 2, fontsize = 18)
                        leg.set_title('Legend', prop = {'size': 18, 'weight': 'bold'})

                    else:
                                        
                        leg = sp.legend(loc = 'upper center', bbox_to_anchor = (0.5, -0.2),
                                        fancybox = True, ncol = 2, fontsize = 18)
                        leg.set_title('Legend', prop = {'size': 18, 'weight': 'bold'})

                    plt.show()

                    # Update summary
                    summary.append({'Period': plot_dates, 'Location': plot_bbox, 
                                    'A': slope, 'B': intercept, 
                                    'R2': R2, 'RMSE': RMSE, 'MSE': MSE})

            elif scatter_plot_type == 'individual':
                
                for time in plot_dates:
                    
                    # Select dataframe for a time
                    merge_df_time = merge_df.query('time == @time')
                    
                    if not merge_df_time.empty:
                        
                        fig, axs = plt.subplots(1, 3, figsize = (28, 7))
                        fig.set_facecolor('w')

                        # Scatter plot and histograms
                        sp = sns.scatterplot(data = merge_df_time, x = 'sensor_column', y = 'model_column', ax = axs[0])
                        sns.histplot(data = merge_df_time, x = 'sensor_column', kde = True,  ax = axs[1])
                        sns.histplot(data = merge_df_time, x = 'model_column', kde = True,  ax = axs[2])

                        stats_plots_general_settings(component_nom, axs, units, lim_min, lim_max)
                        
                        if sensor == 'tropomi' or (sensor == 'gome' and sensor_type == 'L2'):
                            day = np.datetime64(time).astype('datetime64[D]')
                            fig.suptitle(f'{component_nom} (Date: {day})', 
                                         fontsize = 25, fontweight = 'bold', y = y)
                            
                        else:
                            month = np.datetime64(time).astype('datetime64[M]')
                            fig.suptitle(f'{component_nom} (Month: {month})', 
                                         fontsize = 25, fontweight = 'bold', y = y)

                        # Line 1:1
                        line = mlines.Line2D([0, 1], [0, 1], color = 'grey', linestyle = '--', label = '1:1')
                        line.set_transform(axs[0].transAxes)
                        axs[0].add_line(line)

                        # Linear regression
                        X = merge_df_time['sensor_column'].values.reshape(-1, 1) 
                        Y = merge_df_time['model_column'].values.reshape(-1, 1) 
                        fit_X, fit_Y, slope, intercept, R2, RMSE, MSE = linear_regression(X, Y, component_nom, axs)
                        axs[0].plot(fit_X, fit_Y, color = 'black', label = 'A + B*X')

                        # Legend settings
                        leg = sp.legend(loc = 'upper center', bbox_to_anchor = (0.5, -0.2),
                                        fancybox = True, ncol = 2, fontsize = 18)
                        leg.set_title('Legend', prop = {'size': 18, 'weight': 'bold'})
                        
                        plt.show()

                        # Update summary
                        summary.append({'Period': time, 'Location': plot_bbox, 
                                        'A': slope, 'B': intercept, 
                                        'R2': R2, 'RMSE': RMSE, 'MSE': MSE})

        elif extent_definition == 'country':

            # Read Google API key for reverse geocoding (get country by coordinates)
            google_api_key = get_google_api()

            # Reverse geocoding
            merge_df['country'] = merge_df.apply(lambda row: geocoder.google([row['latitude'], row['longitude']], 
                                                 method='reverse', key = google_api_key).country_long, axis = 1)

            # Find data for the countries in search list
            merge_df = merge_df[merge_df['country'].isin(plot_countries)]
            available_countries = np.unique(merge_df['country'])

            if scatter_plot_type == 'aggregated':

                if not merge_df.empty:
                    
                    fig, axs = plt.subplots(1, 3, figsize = (28, 7))
                    fig.set_facecolor('w')

                    # Scatter plot and histograms
                    sp = sns.scatterplot(data = merge_df, x = 'sensor_column', y = 'model_column', 
                                         hue = 'country', ax = axs[0])
                    sns.histplot(data = merge_df, x = 'sensor_column', kde = True,  ax = axs[1])
                    sns.histplot(data = merge_df, x = 'model_column', kde = True,  ax = axs[2])

                    stats_plots_general_settings(component_nom, axs, units, lim_min, lim_max)
                    fig.suptitle(f'{component_nom} (All countries)', fontsize = 25, fontweight = 'bold', y = y)
                    
                    # Linear regression
                    X = merge_df['sensor_column'].values.reshape(-1, 1) 
                    Y = merge_df['model_column'].values.reshape(-1, 1) 
                    fit_X, fit_Y, slope, intercept, R2, RMSE, MSE = linear_regression(X, Y, component_nom, axs)
                    axs[0].plot(fit_X, fit_Y, color = 'black', label = 'A + B*X')

                    # Line 1:1
                    line = mlines.Line2D([0, 1], [0, 1], color = 'grey', linestyle = '--', label = '1:1')
                    line.set_transform(axs[0].transAxes)
                    axs[0].add_line(line)

                    # Legend settings
                    leg = sp.legend(loc = 'upper center', bbox_to_anchor = (0.5, -0.2),
                                    fancybox = True, ncol = 2, fontsize = 18)
                    leg.set_title('Legend', prop = {'size': 18, 'weight': 'bold'})

                    plt.show()

                    # Update summary
                    summary.append({'Period': plot_dates, 'Location': available_countries, 
                                    'A': slope, 'B': intercept, 
                                    'R2': R2, 'RMSE': RMSE, 'MSE': MSE})

            elif scatter_plot_type == 'individual':

                for plot_country in plot_countries:

                    merge_df_country = merge_df[merge_df['country'] == plot_country]

                    if not merge_df_country.empty:
                        
                        fig, axs = plt.subplots(1, 3, figsize = (28, 7))
                        fig.set_facecolor('w')

                        # Scatter plot and histograms
                        sp = sns.scatterplot(data = merge_df_country, x = 'sensor_column', y = 'model_column', ax = axs[0])
                        sns.histplot(data = merge_df_country, x = 'sensor_column', kde = True,  ax = axs[1])
                        sns.histplot(data = merge_df_country, x = 'model_column', kde = True,  ax = axs[2])

                        stats_plots_general_settings(component_nom, axs, units, lim_min, lim_max)
                        fig.suptitle(f'{component_nom} ({plot_country})', fontsize = 25, fontweight = 'bold', y = y)

                        # Line 1:1
                        line = mlines.Line2D([0, 1], [0, 1], color = 'grey', linestyle = '--', label = '1:1')
                        line.set_transform(axs[0].transAxes)
                        axs[0].add_line(line)

                        # Linear regression
                        X = merge_df_country['sensor_column'].values.reshape(-1, 1) 
                        Y = merge_df_country['model_column'].values.reshape(-1, 1) 
                        fit_X, fit_Y, slope, intercept, R2, RMSE, MSE = linear_regression(X, Y, component_nom, axs)
                        axs[0].plot(fit_X, fit_Y, color = 'black', label = 'A + B*X')

                        # Legend settings
                        leg = sp.legend(loc = 'upper center', bbox_to_anchor = (0.5, -0.2),
                                        fancybox = True, ncol = 2, fontsize = 18)
                        leg.set_title('Legend', prop = {'size': 18, 'weight': 'bold'})
                        
                        plt.show()

                        # Update summary
                        summary.append({'Period': plot_dates, 'Location': plot_country, 
                                        'A': slope, 'B': intercept, 
                                        'R2': R2, 'RMSE': RMSE, 'MSE': MSE})

            else:
                print('ERROR: scatter_plot_type is wrongly defined. The options are ''aggregated'' and ''individual''.')
                raise KeyboardInterrupt()

        else:
            print('ERROR: extent_definition is wrongly defined. The options are ''bbox'' and ''country''.')
            raise KeyboardInterrupt()
                
    elif show_seasons == True:
        
        if show_seasons == True and extent_definition == 'country':
            print('ERROR: Set up show_seasons to False in order to show the scatter plots by countries.')
            raise KeyboardInterrupt()

        plot_seasons = ['Winter', 'Spring', 'Summer', 'Autumn']

        # Find data for the seasons in list
        merge_df['season'] = merge_df.apply(lambda row: get_season(row['time']), axis = 1)
        available_seasons = np.unique(merge_df['season'])

        if scatter_plot_type == 'aggregated':

            if not merge_df.empty:
                
                fig, axs = plt.subplots(1, 3, figsize = (28, 7))
                fig.set_facecolor('w')

                # Scatter plot and histograms
                sp = sns.scatterplot(data = merge_df, x = 'sensor_column', y = 'model_column', 
                                     hue = 'season', ax = axs[0])
                sns.histplot(data = merge_df, x = 'sensor_column', kde = True, ax = axs[1])
                sns.histplot(data = merge_df, x = 'model_column', kde = True, ax = axs[2])

                stats_plots_general_settings(component_nom, axs, units, lim_min, lim_max)
                fig.suptitle(f'{component_nom} (All seasons)', fontsize = 25, fontweight = 'bold', y = y)
 
                # Line 1:1
                line = mlines.Line2D([0, 1], [0, 1], color = 'grey', linestyle = '--', label = '1:1')
                line.set_transform(axs[0].transAxes)
                axs[0].add_line(line)

                # Linear regression
                X = merge_df['sensor_column'].values.reshape(-1, 1) 
                Y = merge_df['model_column'].values.reshape(-1, 1) 
                fit_X, fit_Y, slope, intercept, R2, RMSE, MSE = linear_regression(X, Y, component_nom, axs)
                axs[0].plot(fit_X, fit_Y, color = 'black', label = 'A + B*X')

                # Legend settings
                leg = sp.legend(loc = 'upper center', bbox_to_anchor = (0.5, -0.2),
                                fancybox = True, ncol = 2, fontsize = 18)
                leg.set_title('Legend', prop = {'size': 18, 'weight': 'bold'})

                plt.show()

                # Update summary
                summary.append({'Period': available_seasons, 'Location': plot_bbox, 
                                'A': slope, 'B': intercept, 
                                'R2': R2, 'RMSE': RMSE, 'MSE': MSE})
                                
        elif scatter_plot_type == 'individual':

            for plot_season in plot_seasons:
                
                # Prepare df
                merge_df_season = merge_df[merge_df['season'] == plot_season]
                
                if not merge_df_season.empty:
                    
                    fig, axs = plt.subplots(1, 3, figsize = (28, 7))
                    fig.set_facecolor('w')

                    # Scatter plot and histograms
                    sp = sns.scatterplot(data = merge_df_season, x = 'sensor_column', y = 'model_column', ax = axs[0])
                    sns.histplot(data = merge_df_season, x = 'sensor_column', kde = True, ax = axs[1])
                    sns.histplot(data = merge_df_season, x = 'model_column', kde = True, ax = axs[2])

                    fig.suptitle(f'{component_nom} ({plot_season})', fontsize = 25, fontweight = 'bold', y = y)
                    stats_plots_general_settings(component_nom, axs, units, lim_min, lim_max)

                    # Line 1:1
                    line = mlines.Line2D([0, 1], [0, 1], color = 'grey', linestyle = '--', label = '1:1')
                    line.set_transform(axs[0].transAxes)
                    axs[0].add_line(line)

                    # Linear regression
                    X = merge_df_season['sensor_column'].values.reshape(-1, 1) 
                    Y = merge_df_season['model_column'].values.reshape(-1, 1) 
                    fit_X, fit_Y, slope, intercept, R2, RMSE, MSE = linear_regression(X, Y, component_nom, axs)
                    axs[0].plot(fit_X, fit_Y, color = 'black', label = 'A + B*X')

                    # Legend settings
                    leg = sp.legend(loc = 'upper center', bbox_to_anchor = (0.5, -0.2),
                                    fancybox = True, ncol = 2, fontsize = 18)
                    leg.set_title('Legend', prop = {'size': 18, 'weight': 'bold'})

                    plt.show()

                    # Update summary
                    summary.append({'Period': plot_season, 'Location':  plot_bbox, 
                                    'A': slope, 'B': intercept, 
                                    'R2': R2, 'RMSE': RMSE, 'MSE': MSE})
        else:
            print('ERROR: scatter_plot_type is wrongly defined. The options are ''aggregated'' and ''individual''.')
            raise KeyboardInterrupt()

    else:
        print('ERROR: show_seasons is wrongly defined. The options are True and False.')
        raise KeyboardInterrupt()
    
    summary = pd.DataFrame(summary)

    return summary

In [29]:
def timeseries(merge_df, component_nom, sensor, sensor_type, model, plot_dates, units, 
               ymin, ymax, xticks, regions_names, coords_list):

    """ Get component data for the closest coordinates to the list of search coordinates and plot them along time

        For every search coordinate and every date, the closest point with data at that date is
        selected, and its model and sensor columns are plotted against time (a plot is only drawn
        when more than one date is found for the region).

        Args:
            merge_df (dataframe): Merge result (the dates are read from its third index level)
            component_nom (str): Component chemical nomenclature
            sensor (str): Name of the sensor
            sensor_type (str): Sensor type ('L2' labels the x-axis 'Estimated time', 'L3' labels it 'Month')
            model (str): Name of the model
            plot_dates (arr): Plot dates
            units (str): Component units
            ymin (float): Minimum y-axis value
            ymax (float): Maximum y-axis value
            xticks (list): Positions of the x-axis ticks
            regions_names (list or str): Region names, one per coordinate pair
            coords_list (list): List of search coordinates (e.g. (lat, lon, lat, lon, ...))

        Returns:
            timeseries_table (dataframe): Dataframe with results from search, indexed by
            region, lat_search, lon_search, latitude, longitude and time
    """

    index_names = ['region', 'lat_search', 'lon_search', 'latitude', 'longitude', 'time']

    # Transform string to tuple (if there is only one element), otherwise zip would iterate its characters
    if isinstance(regions_names, str):
        regions_names = tuple([regions_names])

    # Drop NaN values
    merge_df = merge_df.dropna()

    # Drop the dates that have NaN values
    plot_dates = np.intersect1d(plot_dates, np.unique(merge_df.index.get_level_values(2)))

    # Get coordinates pairs
    coords_search = pairwise(coords_list)

    region_tables = []

    for coords, region_name in zip(coords_search, regions_names):

        time_tables = []

        for time in plot_dates:

            # List of available points per time
            points_at_time = merge_df.query('time == @time').reset_index()
            available_points = list(zip(points_at_time['latitude'], points_at_time['longitude']))

            # Get closest pair to coordinates in search list
            closest = closest_point(coords, available_points)
            lat_found, lon_found = closest[0], closest[1]

            match = merge_df.query('latitude == @lat_found and longitude == @lon_found and time == @time').reset_index()
            match['lat_search'] = coords[0]
            match['lon_search'] = coords[1]
            match['region'] = region_name
            time_tables.append(match)

        # pd.concat replaces DataFrame.append, which was removed in pandas 2.0
        region_table = pd.concat(time_tables) if time_tables else pd.DataFrame()

        if not region_table.empty:
            region_tables.append(region_table)

        # Plot variations in time
        if len(region_table) > 1:

            fig, ax = plt.subplots(figsize = (30, 7))
            fig.set_facecolor('w')

            ax.plot(region_table['time'], region_table['model_column'], 
                    label = model.upper(), linestyle = '--', marker = 'o', 
                    linewidth = 2, markersize = 10, color = 'blue')
            ax.plot(region_table['time'], region_table['sensor_column'], 
                    label = 'GOME-2' if sensor == 'gome' else sensor.upper(),
                    linestyle = '--', marker = 'o', linewidth = 2, markersize = 10, color = 'red')

            ax.legend(loc = 'center left', bbox_to_anchor = (1, 0.5), prop = {'size': 25})

            if sensor_type == 'L2':
                ax.set_xlabel('Estimated time', fontsize = 25)

            elif sensor_type == 'L3':
                ax.set_xlabel('Month', fontsize = 25)

            ax.tick_params(labelsize = 22)
            ax.yaxis.get_offset_text().set_size(18) 
            ax.set_ylim([ymin, ymax])
            ax.set_xticks(xticks)
            ax.set_ylabel(f'{component_nom} ({units})', fontsize = 25)
            ax.set_title(f'{component_nom} concentration in {region_name} ({coords[0]}, {coords[1]})', 
                         fontsize = 25, fontweight = 'bold', y = 1.05)
            plt.show()

    # Without any match, return an empty table with the expected index instead of failing in set_index
    if not region_tables:
        return pd.DataFrame(columns = index_names).set_index(index_names)

    timeseries_table = pd.concat(region_tables).set_index(index_names)

    return timeseries_table

In [None]:
def monthly_annual_cycle(merge_df, component_nom, sensor, model, units, 
                         ymin, ymax, regions_names, bbox_list):

    """ Get monthly annual cycle by region, defined by coordinates

        For every region, the model and sensor columns are averaged by calendar month
        (all years pooled together) and plotted with their standard deviation as error bars.

        Args:
            merge_df (dataframe): Merge result
            component_nom (str): Component chemical nomenclature
            sensor (str): Name of the sensor
            model (str): Name of the model
            units (str): Component units
            ymin (float): Minimum y-axis value
            ymax (float): Maximum y-axis value
            regions_names (list or str): Region names
            bbox_list (list): List of search bounding boxes (e.g. (lat_min, lat_max, lon_min, lon_max, ...))

        Returns:
            monthly_annual_cycle_table (dataframe): One row per region and month with the columns
            'location', 'month', 'model mean', 'model std', 'sensor mean' and 'sensor std'
            (months without data in a region are NaN; empty if no region is given)
    """

    columns = ['location', 'month', 'model mean', 'model std', 'sensor mean', 'sensor std']

    # Drop NaN values
    merge_df = merge_df.dropna()

    # Transform string to tuple (if there is only one element)
    if isinstance(regions_names, str):
        regions_names = tuple([regions_names])

    # Calendar month (1-12) of every row, computed column-wise rather than row by row
    merge_df = merge_df.reset_index()
    merge_df['month'] = pd.to_datetime(merge_df['time']).dt.month
    available_months = np.unique(merge_df['month'])

    regions_lats = pairwise(bbox_list)[0::2]
    regions_lons = pairwise(bbox_list)[1::2]

    monthly_annual_cycle_table = []

    for region_lats, region_lons, region_name in zip(regions_lats, regions_lons, regions_names):

        # Define and apply bounding box
        region_bbox = ((region_lons[0], region_lats[0]), (region_lons[1], region_lats[1]))
        merge_df_region = merge_df.query('longitude >= @region_bbox[0][0] and longitude <= @region_bbox[1][0] and latitude >= @region_bbox[0][1] and latitude <= @region_bbox[1][1]')

        # Mean and sample standard deviation by month; reindexing keeps every available month
        # (NaN where this region has no data) so all regions share the same x-axis
        grouped = merge_df_region.groupby('month')[['model_column', 'sensor_column']]
        means = grouped.mean().reindex(available_months)
        stds = grouped.std().reindex(available_months)

        summary_region = pd.DataFrame({'location': region_name,
                                       'month': available_months,
                                       'model mean': means['model_column'].values,
                                       'model std': stds['model_column'].values,
                                       'sensor mean': means['sensor_column'].values,
                                       'sensor std': stds['sensor_column'].values})

        fig, ax = plt.subplots(figsize = (30, 7))
        fig.set_facecolor('w')

        # Error bars: one standard deviation around the monthly mean
        ax.errorbar(available_months, summary_region['model mean'], yerr = summary_region['model std'],
                    label = model.upper(),
                    linestyle = '--', marker = 'o', linewidth = 2, markersize = 10, color = 'blue')
        ax.errorbar(available_months, summary_region['sensor mean'], yerr = summary_region['sensor std'],
                    label = 'GOME-2' if sensor == 'gome' else sensor.upper(),
                    linestyle = '--', marker = 'o', linewidth = 2, markersize = 10, color = 'red')

        ax.legend(loc = 'center left', bbox_to_anchor = (1, 0.5), prop = {'size': 25})

        # Format axes
        ax.tick_params(labelsize = 22)
        ax.yaxis.get_offset_text().set_size(18)
        ax.set_xticks(np.arange(1, 13))
        ax.set_xticklabels(['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep','Oct', 'Nov', 'Dec'])
        ax.set_ylim([ymin, ymax])
        ax.set_ylabel(f'{component_nom} ({units})', fontsize = 25)
        ax.set_title(f'{component_nom} mean concentration around {region_name} {region_bbox}', 
                     fontsize = 25, fontweight = 'bold', y = 1.05)
        plt.show()

        monthly_annual_cycle_table.append(summary_region)

    # pd.concat raises on an empty list, so return an empty table with the expected columns instead
    if not monthly_annual_cycle_table:
        return pd.DataFrame(columns = columns)

    monthly_annual_cycle_table = pd.concat(monthly_annual_cycle_table)

    return monthly_annual_cycle_table

In [None]:
def _trend_sources(sensor_break_date, model_break_date):

    """ Return the (source, color) pairs to fit; a series with a break date is split into '-1' and '-2' segments """

    if sensor_break_date is None:
        sensor_sources = [('sensor', 'blue')]
    else:
        sensor_sources = [('sensor-1', 'blue'), ('sensor-2', 'black')]

    if model_break_date is None:
        model_sources = [('model', 'red')]
    else:
        model_sources = [('model-1', 'red'), ('model-2', 'green')]

    return sensor_sources + model_sources


def _trend_source_data(summary_region, source, sensor_break_date, model_break_date):

    """ Return the months (X) and regional means (Y) of one source, restricted to its side of the break date

        Segment '-1' keeps the dates up to and including the break date, segment '-2' the dates after it,
        and a source without a segment suffix keeps every date.
    """

    base, _, segment = source.partition('-')
    break_date = sensor_break_date if base == 'sensor' else model_break_date

    if segment == '1':
        rows = summary_region[summary_region['time'] <= break_date]
    elif segment == '2':
        rows = summary_region[summary_region['time'] > break_date]
    else:
        rows = summary_region

    return rows['month'].values, rows[base + ' mean'].values


def trends(merge_df, component_nom, sensor, model, units, ymin, ymax, 
           plot_dates, regions_names, bbox_list, sensor_break_date, model_break_date):

    """ Get trends by region, defined by coordinates. These trends are generated following linear and sinusoidal models.

        For every region, the model and sensor columns are averaged by date, indexed by the number of
        months since the first plot date, and fitted with A + B*X and C*sin(D*X + E) + N. The yearly
        change rate is derived from the linear fit.

        Args:
            merge_df (dataframe): Merge table with total column data and their difference
            component_nom (str): Component chemical nomenclature
            sensor (str): Name of the sensor
            model (str): Name of the model
            units (str): Component units
            ymin (float): Minimum y-axis value
            ymax (float): Maximum y-axis value
            plot_dates (arr): Plot dates
            regions_names (list or str): Region names
            bbox_list (list): List of search bounding boxes (e.g. (lat_min, lat_max, lon_min, lon_max, ...))
            sensor_break_date (str): Break date (e.g. '2013-01-01') or None for sensor data
            model_break_date (str): Break date (e.g. '2013-01-01') or None for model data

        Returns:
            trends_table (dataframe): Dataframe with trends for all regions given
            (empty if there are fewer than 12 dates or no region has data)
    """

    if len(np.unique(merge_df.reset_index()['time'])) < 12:
        print('The trends can only be derived if there are data for more than a year.')
        return pd.DataFrame()

    # Sinusoidal model
    def objective_function_sin(X, C, D, E, N):
        return C * np.sin(D * X + E) + N

    # Transform string to tuple (if there is only one element)
    if isinstance(regions_names, str):
        regions_names = tuple([regions_names])

    regions_lats = pairwise(bbox_list)[0::2]
    regions_lons = pairwise(bbox_list)[1::2]

    # Drop NaN values
    merge_df = merge_df.dropna()

    # Drop the dates that have NaN values (the same for every region, so done once)
    plot_dates = np.intersect1d(plot_dates, np.unique(merge_df.index.get_level_values(2)))
    if len(plot_dates) == 0:
        print('There are no dates with data to derive the trends.')
        return pd.DataFrame()
    start_date = pd.to_datetime(plot_dates[0])

    sensor_label = 'GOME-2' if sensor == 'gome' else sensor.upper()
    sources = _trend_sources(sensor_break_date, model_break_date)

    trends_table = []

    for region_lats, region_lons, region_name in zip(regions_lats, regions_lons, regions_names):

        # Get data for each bounding box
        region_bbox = ((region_lons[0], region_lats[0]), (region_lons[1], region_lats[1]))
        merge_df_region = merge_df.query('longitude >= @region_bbox[0][0] and longitude <= @region_bbox[1][0] and latitude >= @region_bbox[0][1] and latitude <= @region_bbox[1][1]')

        # Retrieve mean by date, with the number of months since the start date as x value
        summary_region = []
        for time in plot_dates:
            merge_df_region_time = merge_df_region.query('time == @time')
            if merge_df_region_time.empty:
                continue
            timestamp = pd.to_datetime(time)
            summary_region.append({'location': region_name, 
                                   'time': time,
                                   'month': (timestamp.year - start_date.year) * 12 + (timestamp.month - start_date.month),
                                   'model mean': merge_df_region_time['model_column'].mean(),
                                   'sensor mean': merge_df_region_time['sensor_column'].mean(),
                                  })

        if not summary_region:
            print(f'There are no data at {region_name} to derive trends.')
            continue

        summary_region = pd.DataFrame(summary_region)
        summary_region_trend = []

        # Plot trends
        fig, ax = plt.subplots(figsize = (30, 7))
        fig.set_facecolor('w')

        for source, color in sources:

            label_name = model.upper() if source.startswith('model') else sensor_label

            # Get data to generate the trend fits
            X, Y = _trend_source_data(summary_region, source, sensor_break_date, model_break_date)
            if len(X) == 0:
                print(f'{source} data at {region_name} are empty on this side of the break date, so no trend can be derived.')
                continue

            # Plot scatter
            ax.scatter(X, Y, color = color, s = 100)

            # Linear model (linear_regression returns the slope before the intercept)
            linear_fit_X, linear_fit_Y, B, A, linear_R2, linear_RMSE, linear_MSE = linear_regression(X.reshape(-1, 1), Y.reshape(-1, 1), component_nom, ax)
            ax.plot(X, A + B * X, '--', color = color, linewidth = 3, label = f'{label_name} (A + B*X)')

            # Sinusoidal model; initial guesses: amplitude = data range, 12-month period (D = 2*pi/12),
            # no phase shift, offset = mean
            try:
                p0 = [max(Y) - min(Y), np.pi / 6, 0, np.mean(Y)]
                sin_params, _ = curve_fit(objective_function_sin, X, Y, p0 = p0)

            except (RuntimeError, TypeError, ValueError, np.linalg.LinAlgError):
                # curve_fit raises RuntimeError when it does not converge and TypeError when
                # there are fewer points than parameters
                print(f'{source} data at {region_name} cannot be fitted with a sinusoidal model to account for the seasonality, consider working with a larger dataset.')
                sin_R2, sin_MSE, sin_RMSE = np.nan, np.nan, np.nan
                C, D, E, N = np.nan, np.nan, np.nan, np.nan

            else:
                C, D, E, N = sin_params
                sin_fit_Y = objective_function_sin(X, C, D, E, N)

                # Calculate R2, MSE and RMSE (sqrt of MSE instead of the removed squared=False option)
                sin_R2 = r2_score(y_true = Y, y_pred = sin_fit_Y)
                sin_MSE = mean_squared_error(y_true = Y, y_pred = sin_fit_Y)
                sin_RMSE = np.sqrt(sin_MSE)

                ax.plot(X, sin_fit_Y, color = color, linewidth = 3, 
                        label = f'{label_name} (C*np.sin(D*X + E) + N)')

            # Calculate number of months in study period (X is sorted because plot_dates is)
            start_period = X[0]
            end_period = X[-1]
            period = end_period - start_period

            # Calculate concentration difference in study period from the linear fit
            start_conc = A + B * start_period
            end_conc = A + B * end_period
            diff_conc = end_conc - start_conc

            # Calculate change rate per year (12 months); a single date spans no time
            if period == 0:
                change_rate_units, change_rate_100 = np.nan, np.nan
            else:
                change_rate_units = (diff_conc * 12) / period
                change_rate_100 = (diff_conc * 12) * 100 / (start_conc * period)

            # Update summary
            summary_region_trend.append({'Location': region_name, 
                                         'Source': source.capitalize(),
                                         'Rate ('+ units +' y-1)': change_rate_units,
                                         'Rate (% y-1)': change_rate_100,
                                         'A': A, 'B': B,
                                         'C': C, 'D': D,
                                         'E': E, 'N': N,
                                         'Linear R2': linear_R2, 
                                         'Linear RMSE': linear_RMSE,
                                         'Linear MSE': linear_MSE,
                                         'Sinusoidal R2': sin_R2, 
                                         'Sinusoidal RMSE': sin_RMSE,
                                         'Sinusoidal MSE': sin_MSE,
                                         })

        # Format axes
        ax.legend(loc='center left', bbox_to_anchor = (1, 0.5), prop = {'size': 27})
        ax.tick_params(labelsize = 26)
        ax.yaxis.get_offset_text().set_size(20)
        months_num = summary_region['month'].values
        ax.set_xticks(np.arange(months_num[0], months_num[-1] + 2, 6))
        ax.set_xticklabels(np.arange(months_num[0], months_num[-1] + 2, 6))
        ax.set_ylim([ymin, ymax])
        ax.set_xlabel('Number of months', fontsize = 30)
        ax.set_ylabel(f'{component_nom} ({units})', fontsize = 30)
        ax.set_title(f'{component_nom} mean concentration around {region_name} {region_bbox}', 
                     fontsize = 30, fontweight = 'bold', y = 1.05)

        plt.show()

        trends_table.append(pd.DataFrame(summary_region_trend))

    # pd.concat raises on an empty list
    if not trends_table:
        return pd.DataFrame()

    trends_table = pd.concat(trends_table)

    return trends_table

In [None]:
def trend_maps(trends_table, component_nom, sensor, model, sensor_type, model_type, 
               units, plot_dates, pad, y, source_1, source_2, start_period, end_period, 
               vmin_manual_rate_units, vmax_manual_rate_units, 
               vmin_manual_rate_100, vmax_manual_rate_100,
               vmin_manual_diff_units, vmax_manual_diff_units, 
               vmin_manual_diff_100, vmax_manual_diff_100, width_lon, height_lat,
               bbox_list, coords_list, regions_names):

    """ Draw the trend maps of one sensor/model source pair as two figures of three panels.

        The first figure shows rates in units/year and the second in %/year. Each figure has:
        -   Rate for model data (source_2)
        -   Rate for sensor data (source_1)
        -   Rate difference (model - sensor)

        Args:
            trends_table (dataframe): Trends table with 'source', 'latitude' and 'longitude'
                                      columns plus the 'rate (<units> y-1)' and 'rate (% y-1)' columns
            component_nom (str): Component chemical nomenclature
            sensor (str): Name of the sensor
            model (str): Name of the model
            sensor_type (str): Sensor type
            model_type (str): Model type:
            -  'Forecast'
            -  'Reanalysis'
            units (str): Component units
            plot_dates (arr): Plot dates (not used to draw the maps)
            pad (float): Padding for the subtitles
            y (float): y-position of main title
            source_1 (str): This can be sensor, sensor-1 or sensor-2
            source_2 (str): This can be model, model-1 or model-2
            start_period (int): Number of months that have passed at start date
            end_period (int): Number of months that have passed at end date
            vmin_manual_rate_units (float): Input vmin by user for rate in units/year
            vmax_manual_rate_units (float): Input vmax by user for rate in units/year
            vmin_manual_rate_100 (float): Input vmin by user for rate in %/year
            vmax_manual_rate_100 (float): Input vmax by user for rate in %/year
            vmin_manual_diff_units (float): Input vmin by user rate difference in units/year
            vmax_manual_diff_units (float): Input vmax by user rate difference in units/year
            vmin_manual_diff_100 (float): Input vmin by user for rate difference in %/year
            vmax_manual_diff_100 (float): Input vmax by user for rate difference in %/year
            width_lon (int): Horizontal width of frame individual sections (black - white lines)
            height_lat (int): Vertical height of frame individual sections (black - white lines)
            bbox_list (list): List of search bounding boxes (eg. (lat_min, lat_max, lon_min, lon_max, ...)
            coords_list (list): List of search coordinates (eg. (lat, lon, lat, lon, ...)
            regions_names (list): Region names

        Note:
            This function reads the notebook-level variables `projection`, `color_scale`
            (one entry per panel, indexed 0-2) and `plot_bbox`, where
            plot_bbox[0] = (lon_min, lat_min) and plot_bbox[1] = (lon_max, lat_max).
            They must be defined before calling it.
    """

    # The conversion does not depend on the rate type, so do it once for both figures
    trends_ds = trends_table.set_index(['source', 'latitude', 'longitude']).to_xarray()

    # The sensor panel labels 'gome' as 'GOME-2'; use the same label in the difference title
    sensor_label = 'GOME-2' if sensor == 'gome' else sensor.upper()
    model_name = model.upper() + ' (' + model_type + ')'
    sensor_name = sensor_label + ' (' + sensor_type + ')'
    diff_name = 'Difference (' + model.upper() + ' - ' + sensor_label + ')'

    # Map extent, frame and region settings shared by every panel
    common_kwargs = dict(projection = projection,
                         pad = pad,
                         lon_min = plot_bbox[0][0],
                         lon_max = plot_bbox[1][0],
                         lat_min = plot_bbox[0][1],
                         lat_max = plot_bbox[1][1],
                         width_lon = width_lon,
                         height_lat = height_lat,
                         bbox_list = bbox_list,
                         coords_list = coords_list,
                         regions_names = regions_names)

    for rate_type, units_name, vmin, vmax, vmin_diff, vmax_diff in zip(['rate (' + units + ' y-1)', 'rate (% y-1)'],
                                                                       ['Trend (' + units + ' y-1)', 'Trend (% y-1)'],
                                                                       [vmin_manual_rate_units, vmin_manual_rate_100],
                                                                       [vmax_manual_rate_units, vmax_manual_rate_100],
                                                                       [vmin_manual_diff_units, vmin_manual_diff_100],
                                                                       [vmax_manual_diff_units, vmax_manual_diff_100]):

        model_array = trends_ds.sel(source = source_2)[rate_type]
        sensor_array = trends_ds.sel(source = source_1)[rate_type]

        # (data, title, colour limits) for the model, sensor and difference panels
        panels = [(model_array, model_name, vmin, vmax),
                  (sensor_array, sensor_name, vmin, vmax),
                  (model_array - sensor_array, diff_name, vmin_diff, vmax_diff)]

        fig, axs = plt.subplots(1, 3, figsize = (20, 5), subplot_kw = {'projection': projection})
        fig.set_facecolor('w')

        for i, (array, long_name, panel_vmin, panel_vmax) in enumerate(panels):
            visualize_pcolormesh(
                                fig = fig, axs = axs[i],
                                data_array = array,
                                longitude = array.longitude,
                                latitude = array.latitude,
                                color_scale = color_scale[i],
                                long_name = long_name,
                                units_name = units_name,
                                vmin = panel_vmin,
                                vmax = panel_vmax,
                                **common_kwargs
                               )

        fig.suptitle(f'TRENDS OF {component_nom} (Months {str(start_period)} - {str(end_period)})', 
                    fontsize = 18, fontweight = 'bold', y = y)
        plt.show()

In [None]:
def visualize_bbox_trends(merge_df, component_nom, sensor, model, units, plot_dates, 
                          sensor_break_date, model_break_date, pad, y,
                          vmin_manual_rate_units, vmax_manual_rate_units, 
                          vmin_manual_rate_100, vmax_manual_rate_100,
                          vmin_manual_diff_units, vmax_manual_diff_units, 
                          vmin_manual_diff_100, vmax_manual_diff_100,
                          width_lon, height_lat, 
                          bbox_list = None, coords_list = None, regions_names = None,
                          sensor_type = None, model_type = None):

    """ Get trends at every pixel of the plot bounding box and draw them as maps.

        The trends are generated following a linear model. This function is only available
        for trends without breaks or with a common break at the same day for both the model
        and sensor datasets. Without a break, one sensor/model pair of maps is drawn; with a
        break, one pair is drawn for the dates up to and including the break date and another
        for the dates after it.

        Args:
            merge_df (dataframe): Merge table indexed by (latitude, longitude, time) with
                                  'model_column' and 'sensor_column' columns
            component_nom (str): Component chemical nomenclature
            sensor (str): Name of the sensor
            model (str): Name of the model
            units (str): Component units
            plot_dates (arr): Plot dates
            sensor_break_date (str): Break date (e.g. '2013-01-01) or None for sensor data
            model_break_date (str): Break date (e.g. '2013-01-01) or None for model data
            pad (float): Padding for the subtitles
            y (float): y-position of main title
            vmin_manual_rate_units (float): Input vmin by user for rate in units/year
            vmax_manual_rate_units (float): Input vmax by user for rate in units/year
            vmin_manual_rate_100 (float): Input vmin by user for rate in %/year
            vmax_manual_rate_100 (float): Input vmax by user for rate in %/year
            vmin_manual_diff_units (float): Input vmin by user rate difference in units/year
            vmax_manual_diff_units (float): Input vmax by user rate difference in units/year
            vmin_manual_diff_100 (float): Input vmin by user for rate difference in %/year
            vmax_manual_diff_100 (float): Input vmax by user for rate difference in %/year
            width_lon (int): Horizontal width of frame individual sections (black - white lines)
            height_lat (int): Vertical height of frame individual sections (black - white lines)
            bbox_list (list): List of search bounding boxes (eg. (lat_min, lat_max, lon_min, lon_max, ...)
            coords_list (list): List of search coordinates (eg. (lat, lon, lat, lon, ...)
            regions_names (list): Region names
            sensor_type (str): Sensor type. If None, the notebook-level variable
                               `sensor_type` is used
            model_type (str): Model type ('Forecast' or 'Reanalysis'). If None, the
                              notebook-level variable `model_type` is used

        Returns:
            trends_table (dataframe): One row per pixel and source with 'latitude', 'longitude',
            'source', 'rate (<units> y-1)', 'rate (% y-1)' and the fit coefficients 'A' and 'B'.
            Sources whose dates span less than one month are left out.

        Raises:
            KeyboardInterrupt: If the sensor and model break dates differ
    """

    # The maps compare both datasets period by period, so they must share the same break
    # date or have none. This single test covers every mismatch the maps cannot handle.
    if sensor_break_date != model_break_date:
        print('Trend maps can only be generated when the break date is the same for both datasets or when there are no break dates.')
        raise KeyboardInterrupt()

    sensor_type = _notebook_value('sensor_type', sensor_type)
    model_type = _notebook_value('model_type', model_type)

    break_date = model_break_date
    if break_date is None:
        sources = ['sensor', 'model']
        source_pairs = [('sensor', 'model')]
    else:
        sources = ['sensor-1', 'sensor-2', 'model-1', 'model-2']
        # Each sensor period is mapped against the matching model period
        source_pairs = [('sensor-1', 'model-1'), ('sensor-2', 'model-2')]

    print('WARNING: Computing the trends for each coordinate will take some time.')

    # Drop NaN values
    merge_df = merge_df.dropna()

    # Drop the dates that have NaN values
    plot_dates = np.intersect1d(plot_dates, np.unique(merge_df.index.get_level_values(2)))
    if plot_dates.size == 0:
        print('There are no valid data to compute the trends.')
        return pd.DataFrame()
    start_date = pd.to_datetime(plot_dates[0])

    # Visit only the (latitude, longitude) pairs that still have data: a pixel that was NaN
    # at every date is gone after dropna and would otherwise produce an empty summary
    trends_table = []
    for (latitude, longitude), merge_df_region in merge_df.groupby(level = [0, 1]):
        summary_region = _pixel_monthly_summary(merge_df_region, plot_dates, start_date)
        summary_region_trend = _pixel_trends(summary_region, latitude, longitude, sources,
                                             break_date, units, component_nom)
        if summary_region_trend:
            trends_table.append(pd.DataFrame(summary_region_trend))

    if not trends_table:
        print('There are not enough data to compute the trends.')
        return pd.DataFrame()
    trends_table = pd.concat(trends_table)

    # Month numbers of the whole study period (shared by all pixels), used in the map titles
    plot_times = pd.to_datetime(plot_dates)
    period_table = pd.DataFrame({'time': plot_times,
                                 'month': _months_since(plot_times, start_date)})

    for source_1, source_2 in source_pairs:

        # Get number of months at start and end of this period
        period_months = _segment_rows(period_table, source_1, break_date)['month'].values
        if period_months.size == 0:
            print(f'There are no dates for {source_1} and {source_2}.')
            continue
        start_period = period_months[0]
        end_period = period_months[-1]

        # Generate maps
        trend_maps(trends_table, component_nom, sensor, model, sensor_type, model_type, 
                   units, plot_dates, pad, y, source_1, source_2, start_period, end_period, 
                   vmin_manual_rate_units, vmax_manual_rate_units, 
                   vmin_manual_rate_100, vmax_manual_rate_100,
                   vmin_manual_diff_units, vmax_manual_diff_units, 
                   vmin_manual_diff_100, vmax_manual_diff_100, width_lon, height_lat,
                   bbox_list, coords_list, regions_names)

    return trends_table


def _notebook_value(name, value):
    """ Return value, falling back to the notebook-level variable called name when it is None. """
    if value is not None:
        return value
    try:
        return globals()[name]
    except KeyError:
        raise NameError(f"name '{name}' is not defined: pass it to visualize_bbox_trends") from None


def _months_since(times, start_date):
    """ Return the number of calendar months between start_date and each of times (int64 array). """
    times = pd.DatetimeIndex(times)
    return np.asarray((times.year - start_date.year) * 12 + (times.month - start_date.month),
                      dtype = 'int64')


def _segment_rows(table, source, break_date):
    """ Return the rows of a table with a 'time' column that belong to source's period.

        'sensor'/'model' use every row, '-1' sources the rows up to and including
        break_date, and '-2' sources the rows after it.
    """
    segment = source.partition('-')[2]
    if segment == '1':
        return table[table['time'] <= break_date]
    if segment == '2':
        return table[table['time'] > break_date]
    return table


def _pixel_monthly_summary(merge_df_region, plot_dates, start_date):
    """ Return one chronological row per plotted date of a pixel with its month number
        and its 'model column' / 'sensor column' values.
    """
    pixel = merge_df_region.reset_index()
    # Keep only plotted dates, and the first record of each date (as the per-date lookup did)
    pixel = pixel[pixel['time'].isin(plot_dates)]
    pixel = pixel.drop_duplicates(subset = 'time', keep = 'first').sort_values('time')
    return pd.DataFrame({'time': pixel['time'].to_numpy(),
                         'month': _months_since(pixel['time'], start_date),
                         'model column': pixel['model_column'].to_numpy(),
                         'sensor column': pixel['sensor_column'].to_numpy()})


def _pixel_trends(summary_region, latitude, longitude, sources, break_date, units, component_nom):
    """ Fit a linear trend per source for one pixel and return its trends-table rows.

        Sources whose dates do not span at least one month are skipped, because their
        change rate would divide by a zero-length period.
    """
    rows = []

    # linear_regression is given an axes with the scattered points; the original notes it
    # uses it to get the limits for the regression. The figure is never shown and is always
    # closed so that one figure per pixel does not accumulate.
    fig, ax = plt.subplots(figsize = (30, 7))
    fig.set_facecolor('w')
    try:
        for source in sources:

            # Get data to generate the trend fits
            segment = _segment_rows(summary_region, source, break_date)
            X = segment['month'].values
            Y = segment[source.partition('-')[0] + ' column'].values
            if X.size == 0 or X[0] == X[-1]:
                continue

            ax.scatter(X, Y, s = 100)

            # Linear model: fitted column = A + B * month
            _, _, B, A, _, _, _ = linear_regression(X.reshape(-1, 1), Y.reshape(-1, 1), component_nom, ax)

            # Fitted concentration at the first and last month of the study period
            start_period = X[0]
            end_period = X[-1]
            start_conc = A + B*start_period
            end_conc = A + B*end_period
            diff_conc = end_conc - start_conc
            period = end_period - start_period

            # Annualise the change (months -> years), in units and relative to the start value
            change_rate_units = (diff_conc * 12) / period
            change_rate_100 = (diff_conc * 12) * 100 / (start_conc * period)

            rows.append({'latitude': latitude, 
                         'longitude': longitude,
                         'source': source,
                         'rate (' + units + ' y-1)': change_rate_units,
                         'rate (% y-1)': change_rate_100,
                         'A': A, 'B': B
                        })
    finally:
        plt.close(fig)

    return rows