# General functions

In [2]:
def comparison_check(sensor, model, component_nom, model_full_name, sensor_type):

    """ Check if the comparison is possible

        A diagnostic is printed and KeyboardInterrupt is raised (which stops the
        notebook run) as soon as one of the checks fails. Nothing is returned.

        Args:
            sensor (str): Name of the sensor
            model (str): Name of the model
            component_nom (str): Component chemical nomenclature
            model_full_name (str): Full name of the CAMS model among:
            - 'cams-global-atmospheric-composition-forecasts' 
            - 'cams-global-reanalysis-eac4-monthly'
            sensor_type (str): Sensor type

        Raises:
            KeyboardInterrupt: If the model name, the sensor / model combination
            or the component is not supported.
    """

    forecast = 'cams-global-atmospheric-composition-forecasts'
    reanalysis = 'cams-global-reanalysis-eac4-monthly'

    # The model name is validated first: this check used to be nested inside a
    # condition that already required a supported name, so it could never fire.
    if model_full_name not in (forecast, reanalysis):

        print('ERROR: The model is not supported.')
        print('The models that are currently supported are:')
        print('- cams-global-atmospheric-composition-forecasts')
        print('- cams-global-reanalysis-eac4-monthly')
        raise KeyboardInterrupt()

    # Required model dataset and retrievable components per sensor product
    if sensor == 'tropomi':
        # TROPOMI is accepted whatever the value of sensor_type is
        required_model, supported_components = forecast, ['NO2', 'CO', 'O3', 'SO2']

    else:
        products = {('iasi', 'L2'): (forecast, ['O3', 'CO']),
                    ('iasi', 'L3'): (reanalysis, ['CO', 'O3']),
                    ('gome', 'L2'): (forecast, ['NO2', 'O3', 'HCHO']),
                    ('gome', 'L3'): (reanalysis, ['NO2'])}
        required_model, supported_components = products.get((sensor, sensor_type), (None, None))

    # required_model is None for an unknown sensor product, so it never matches
    if model != 'cams' or required_model != model_full_name:

        print('Currently, it is possible to compare:')
        print('1 - Forecast data from CAMS model and L2 data from TROPOMI, IASI and GOME-2 sensors.')
        print('2 - Reanalysis data from CAMS model and L3 monthly data from IASI and GOME-2 sensors.')

        raise KeyboardInterrupt()

    if component_nom not in supported_components:

        print(f'ERROR: This specific component cannot be retrieved by the sensor {sensor.upper()} ({sensor_type}).')
        raise KeyboardInterrupt()

    print('The comparison is possible and will start now.')

In [3]:
def components_table(sensor, component_nom, sensor_type):

    """ Look up the properties of a component in a small reference table
        (molecular weight, full name and names used in the sensor datasets)

        Args:
            sensor (str): Name of the sensor
            component_nom (str): Component chemical nomenclature
            sensor_type (str): Sensor type

        Returns:
            component (str): Component name
            component_mol_weight (float): Component molecular weight
            sensor_product_type (str): Component product name in the TROPOMI database 
            (None for every other sensor)
            sensor_column (str): Component column name in TROPOMI, IASI or GOME-2 database
    """

    # The GOME-2 L2 and L3 products share the same column names
    gome_columns = ['NO2trop', '-', 'O3total', '-', '-', 'HCHOtotal', '-']

    # One entry per component, in the same order in every list
    table = pd.DataFrame({
        'Nomenclature': ['NO2', 'CO', 'O3', 'SO2', 'CH4', 'HCHO', 'NH3'],
        'Weight': [46.005, 28.01, 48, 64.066, 16.04, 30.031, 17.031],
        'Component': ['nitrogen_dioxide', 'carbon_monoxide', 'ozone', 'sulfur_dioxide', 
                      'methane', 'formaldehyde', 'ammonia'],
        'TROPOMI_L2_product': ['L2__NO2___', 'L2__CO____', 'L2__O3____', 'L2__SO2___', 'L2__CH4___', '-', '-'],
        'TROPOMI_L2_column': ['nitrogendioxide_tropospheric_column',
                              'carbonmonoxide_total_column',
                              'ozone_total_vertical_column',
                              'sulfurdioxide_total_vertical_column',
                              'methane_tropospheric_column',
                              '-',
                              '-'],
        'IASI_L3_column': ['-', 'COgridDAY', 'O3gridDAY', '-', '-', '-', 'NH3gridDAY'],
        'IASI_L2_column': ['-', 'CO_total_column', 'O3_total_column', '-', '-', '-', ''],
        'GOME_L3_column': gome_columns,
        'GOME_L2_column': gome_columns,
    })

    # Boolean mask of the row that describes the requested component
    is_requested = table['Nomenclature'] == component_nom

    def lookup(column_name):
        # First value of the column in the matching row (IndexError if there is none)
        return table.loc[is_requested, column_name].iloc[0]

    component = lookup('Component')
    component_mol_weight = lookup('Weight')

    # Only TROPOMI has a product name
    sensor_product_type = lookup('TROPOMI_L2_product') if sensor == 'tropomi' else None

    # The column header is built from the sensor and its type, e.g. 'IASI_L3_column'
    sensor_column = lookup(sensor.upper() + '_' + sensor_type + '_column')

    return component, component_mol_weight, sensor_product_type, sensor_column

In [4]:
def generate_folders(model, sensor, component_nom, sensor_type):

    """ Create the download folders of the model and the sensor (existing folders are kept)

        The folders are created inside an 'adc-toolbox' directory located under 
        the first two components of the current working directory.

        Args:
            model (str): Name of the model
            sensor (str): Name of the sensor
            component_nom (str): Component chemical nomenclature
            sensor_type (str): Sensor type
    """

    toolbox_root = os.path.join('/', '/'.join(os.getcwd().split('/')[1:3]), 'adc-toolbox')

    model_folder = 'data/' + model + '/' + component_nom
    sensor_folder = 'data/' + sensor + '/' + component_nom

    # L3 (monthly) sensor data goes into its own subfolder
    if sensor_type == 'L3':
        sensor_folder = sensor_folder + '/monthly/'

    for folder in (model_folder, sensor_folder):
        os.makedirs(os.path.join(toolbox_root, os.path.relpath(folder)), exist_ok = True)

In [5]:
def search_period(start_date, end_date, sensor, sensor_type):

    """ Give list or tuple with dates that will be used to download the datasets

        Args:
            start_date (str): Query start date ('YYYY-MM-DD')
            end_date (str): Query end date ('YYYY-MM-DD')
            sensor (str): Name of the sensor
            sensor_type (str): Sensor type

        Returns:
            dates (list or tuple): Query dates
            - IASI / GOME-2 L2: tuple of days ('YYYY-MM-DD')
            - IASI / GOME-2 L3: tuple of months ('YYYY-MM')
            - TROPOMI: list of (start, end) timestamp pairs, one per day 

        Raises:
            ValueError: If the sensor / sensor type combination is not supported
    """

    def to_midnight(date_str):
        # 'YYYY-MM-DD' -> datetime at 00:00:00 of that day
        parts = date_str.split('-')
        return dt.datetime(int(parts[0]), int(parts[1]), int(parts[2]), 0, 0, 0)

    print('SEARCH PERIOD')

    # One entry per day between both dates (inclusive)
    range_dt = pd.date_range(to_midnight(start_date), to_midnight(end_date))

    if sensor in ('gome', 'iasi') and sensor_type == 'L2':
        dates = tuple(np.unique([date.strftime('%Y-%m-%d') for date in range_dt]))
        print(f'- In days: {dates}')

    elif sensor in ('gome', 'iasi') and sensor_type == 'L3':
        # np.unique collapses the days of the same month into one entry
        dates = tuple(np.unique([date.strftime('%Y-%m') for date in range_dt]))
        print(f'- In months: {dates}')

    elif sensor == 'tropomi':
        # Each day is queried from 00:00:00 to 23:00:00
        range_dt_final = range_dt + dt.timedelta(hours = 23)
        dates = list(zip([date.strftime('%Y-%m-%dT%H:%M:%SZ') for date in range_dt], 
                         [date.strftime('%Y-%m-%dT%H:%M:%SZ') for date in range_dt_final]))
        print(f'- In days: {dates}')

    else:
        # Without this branch the function failed with an UnboundLocalError on return
        raise ValueError(f'Unsupported sensor / sensor type combination: {sensor} ({sensor_type}).')
    
    return dates

In [None]:
def search_bbox(lon_min, lat_min, lon_max, lat_max):

    """ Build the bounding box from its corner coordinates and print its limits
        
        Args:
            lon_min (float): Minimum longitude
            lat_min (float): Minimum latitude
            lon_max (float): Maximum longitude
            lat_max (float): Maximum latitude
            
        Returns:
            bbox (tuple): Query bounding box as ((lon_min, lat_min), (lon_max, lat_max))

    """

    lower_corner = (lon_min, lat_min)
    upper_corner = (lon_max, lat_max)

    summary = ('SEARCH BOUNDING BOX',
               f'Latitudes: from {lat_min} to {lat_max}',
               f'Longitudes: from {lon_min} to {lon_max}')

    for line in summary:
        print(line)

    return (lower_corner, upper_corner)

In [None]:
def available_period(sensor, sensor_type, dates, component_nom, *args):

    """ Remove dates if the folders where the dataset had to be downloaded are empty (dataset not available)

        Empty folders are deleted from disk. For GOME-2 L2 data the empty satellite 
        subfolders are deleted first, so that a date folder that only contains empty 
        satellite subfolders is also reported as not available. A date folder that 
        does not exist is reported as not available as well.
        
        Args:
            sensor (str): Name of the sensor
            sensor_type (str): Sensor type ('L2' or 'L3')
            dates (list): Query dates
            component_nom (str): Component chemical nomenclature
            *args: satellites (list of satellite names, only used for GOME-2 L2 data)
            
        Returns:
            dates (tuple): Available dates (sorted, without duplicates)

        Raises:
            ValueError: If the sensor type is not 'L2' or 'L3'

    """

    clean_satellites = sensor_type == 'L2' and sensor == 'gome'

    if clean_satellites:

        if args:
            # Satellites are passed as one list; bare names are accepted as well
            satellite_names = args if isinstance(args[0], str) else args[0]

        else:
            # Fall back on the notebook-level variable that was read before
            satellite_names = satellites

    toolbox_root = os.path.join('/', '/'.join(os.getcwd().split('/')[1:3]), 'adc-toolbox')
    dates_to_delete = []

    for date in dates:
        
        if sensor_type == 'L3':
            relative_path = 'data/' + sensor + '/' + component_nom + '/monthly/' + date

        elif sensor_type == 'L2':
            relative_path = 'data/' + sensor + '/' + component_nom + '/' + date

        else:
            raise ValueError(f'Unsupported sensor type: {sensor_type}.')

        output_path = os.path.join(toolbox_root, os.path.relpath(relative_path))

        # Satellite subfolders go first: removing the date folder before listing 
        # them raised FileNotFoundError, and a date folder that only contained 
        # empty subfolders was never detected as empty.
        if clean_satellites:
            for satellite in satellite_names:
                satellite_path = os.path.join(output_path, satellite)
                if os.path.isdir(satellite_path) and not os.listdir(satellite_path):
                    os.rmdir(satellite_path)

        if not os.path.isdir(output_path):
            dates_to_delete.append(date)

        elif not os.listdir(output_path):
            os.rmdir(output_path)
            dates_to_delete.append(date)

    dates_to_keep = np.setdiff1d(dates, np.array(dates_to_delete))
    dates = tuple(dates_to_keep)
    
    return dates

In [6]:
def sensor_download(sensor, sensor_type, component_nom, dates, *args):

    """ Download the sensor datasets of every queried date

        NOTE(review): bbox, product_type and satellites are read as free names 
        and not taken from *args, so they have to exist in the calling notebook
        namespace - confirm against the callers.

        Args:
            sensor (str): Name of the sensor
            sensor_type (str): Sensor type
            component_nom (str): Component chemical nomenclature
            dates (list or tuple): Query dates
            *args: bbox, satellites or product_type

        Returns:
            dates (list or tuple): Query dates for TROPOMI (and for any unknown sensor), 
            available dates for IASI and GOME-2
    """ 

    print('RESULTS')

    if sensor == 'tropomi':

        input_type = 'Query'

        for date in dates:
            print(f'For {date}:')
            TROPOMI_download(input_type, bbox, date, product_type, component_nom)

        return dates

    # Nothing to download for any other sensor than IASI or GOME-2
    if sensor not in ('iasi', 'gome'):
        return dates

    for date in dates:

        print(f'For {date}:')

        for satellite in satellites:

            if sensor == 'iasi':

                if sensor_type == 'L2':
                    IASI_L2_download(component_nom, date, satellite)

                elif sensor_type == 'L3':
                    IASI_L3_download(component_nom, date, satellite)

            else:

                if sensor_type == 'L2':
                    GOME_L2_download(component_nom, date, satellite)

                elif sensor_type == 'L3':
                    GOME_L3_download(component_nom, date, satellite)

    # Keep only the dates for which something was downloaded
    return available_period(sensor, sensor_type, dates, component_nom, satellites)

In [7]:
def sensor_read(sensor, sensor_type, sensor_column, component_nom, dates, *args):

    """ Read sensor datasets as xarray dataset objects

        NOTE(review): lat_res and lon_res are read as free names and not taken 
        from *args, so they have to exist in the calling notebook namespace - 
        confirm against the callers.

        Args:
            sensor (str): Name of the sensor
            sensor_type (str): Sensor type
            sensor_column (str): Component column name in the sensor database
            component_nom (str): Component chemical nomenclature
            dates (list or tuple): Available dates
            *args: satellites, lat_res, lon_res

        Returns:
            sensor_ds (xarray): sensor dataset in xarray format (None if there are no 
            dates to read or the sensor / sensor type combination is unknown)
            support_input_ds (xarray): TROPOMI dataset that contains support input data in xarray format
            (None for other sensors)
            support_details_ds (xarray): TROPOMI dataset that contains support details data in xarray format
            (None for other sensors)
    """ 
    
    # Defined up front: without a default value the return statement failed with 
    # an UnboundLocalError right after printing the message below.
    sensor_ds = None
    support_input_ds = None
    support_details_ds = None

    if dates:

        if sensor == 'tropomi':
            sensor_ds, support_input_ds, support_details_ds = TROPOMI_read(dates, component_nom, sensor_column)
        
        elif sensor == 'iasi' and sensor_type == 'L2':
            sensor_ds = IASI_L2_read(component_nom, sensor_column, dates, lat_res, lon_res)

        elif sensor == 'iasi' and sensor_type == 'L3':
            sensor_ds = IASI_L3_read(component_nom, sensor_column, dates, lat_res, lon_res)

        elif sensor == 'gome' and sensor_type == 'L2':
            # The GOME-2 L2 reader does not take the column name
            sensor_ds = GOME_L2_read(component_nom, dates, lat_res, lon_res)

        elif sensor == 'gome' and sensor_type == 'L3':
            sensor_ds = GOME_L3_read(component_nom, sensor_column, dates, lat_res, lon_res)

    else:
        print('The datasets could not be downloaded for the dates that were queried.')
        
    return sensor_ds, support_input_ds, support_details_ds

In [8]:
def sensor_convert_units(sensor_ds, sensor, component):

    """ Convert the units of the sensor dataset for any component from mol/m2 to molecules/cm2
        (and then to Dobson units if the component is ozone)

        Only TROPOMI and IASI data are converted, the dataset of any other sensor 
        is returned as it is.

        Args:
            sensor_ds (xarray): sensor dataset in xarray format (TROPOMI, IASI or GOME-2)
            sensor (str): Name of the sensor
            component (str): Component name
            
        Returns:
            sensor_ds (xarray): sensor dataset in xarray format
    """

    # mol m-2 -> molec cm-2: multiply by 6.02214*10**19 
    # molec cm-2 -> DU: divide by 2.69*10**16

    if sensor == 'tropomi':

        # The apriori profile (if there is one) is scaled along with the column
        has_apriori = 'apriori_profile' in list(sensor_ds.keys())
        
        if sensor_ds['sensor_column'].units == 'mol m-2':

            sensor_ds['sensor_column'] = sensor_ds['sensor_column'] * 6.02214*10**19
            sensor_ds['sensor_column'] = sensor_ds['sensor_column'].assign_attrs({'units': 'molec cm-2'})
            print('The sensor component units have been converted from mol m-2 to molec cm-2.')
            
            if has_apriori:
                sensor_ds['apriori_profile'] = sensor_ds['apriori_profile'] * 6.02214*10**19

        # Checked on its own (as in the IASI branch), so that ozone data that 
        # already comes in molec cm-2 is converted too
        if sensor_ds['sensor_column'].units == 'molec cm-2' and component == 'ozone':

            sensor_ds['sensor_column'] = sensor_ds['sensor_column'] / (2.69*10**16)
            sensor_ds['sensor_column'] = sensor_ds['sensor_column'].assign_attrs({'units': 'DU'})
            print('The sensor component units have been converted from molec cm-2 to DU.')

            if has_apriori:
                sensor_ds['apriori_profile'] = sensor_ds['apriori_profile'] / (2.69*10**16)

    elif sensor == 'iasi':
        
        if sensor_ds.units == 'mol m-2':

            sensor_ds = sensor_ds * 6.02214*10**19
            sensor_ds = sensor_ds.assign_attrs({'units': 'molec cm-2'})
            print('The sensor component units have been converted from mol m-2 to molec cm-2.')

        if sensor_ds.units == 'molec cm-2' and component == 'ozone':

            sensor_ds = sensor_ds / (2.69*10**16)
            sensor_ds = sensor_ds.assign_attrs({'units': 'DU'})
            print('The sensor component units have been converted from molec cm-2 to DU.')

    return sensor_ds

In [9]:
def model_convert_units(model, model_ds, sensor, component_mol_weight, component, model_levels_df, 
                        start_date, end_date, component_nom, apply_kernels = False, 
                        CAMS_UID = None, CAMS_key = None):

    """ Convert the units of the model dataset for any component from kg/kg or kg/m2 to molecules/cm2
        (and then to Dobson units if the component is ozone)

        The conversions are chained: each step only runs if the component carries 
        the units that the step expects. Only CAMS data is converted.

        Args:
            model (str): Name of the model
            model_ds (xarray): model dataset in xarray format (CAMS)
            sensor (str): Name of the sensor
            component_mol_weight (float): Component molecular weight
            component (str): Component name
            model_levels_df (dataframe): Table with 137 CAMS levels data
            start_date (str): Query start date
            end_date (str): Query end date
            component_nom (str): Component chemical nomenclature
            apply_kernels (bool): Apply (True) or not (False) the averaging kernels 
            CAMS_UID (str): ADS user ID
            CAMS_key (str): ADS key
            
        Returns:
            model_ds (xarray): model dataset in xarray format
            units (str): Units of the model component after the conversion
    """

    if model == 'cams':

        if model_ds.component.units == 'kg kg**-1':

            model_ds = CAMS_kg_kg_to_kg_m2(model_ds, model_levels_df, sensor, start_date, 
                                           end_date, component_nom, apply_kernels, CAMS_UID, CAMS_key)
            model_ds['component'] = model_ds.component.assign_attrs({'units': 'kg m**-2'})
            print('The model component units have been converted from kg kg**-1 to kg m**-2.')
            
        if model_ds.component.units == 'kg m**-2':

            model_ds = CAMS_kg_m2_to_molecules_cm2(model_ds, component_mol_weight)
            model_ds['component'] = model_ds.component.assign_attrs({'units': 'molec cm-2'})
            print('The model component units have been converted from kg m**-2 to molec cm-2.')

        if model_ds.component.units == 'molec cm-2' and component == 'ozone':

            model_ds = CAMS_molecules_cm2_to_DU(model_ds)
            model_ds['component'] = model_ds.component.assign_attrs({'units': 'DU'})
            print('The model component units have been converted from molec cm-2 to DU.')

    # The units are read back from the dataset. They used to be set to 'molec cm-2' 
    # whenever the last step was skipped (even if the data came in other units) and 
    # were left undefined for any model other than CAMS.
    units = model_ds.component.units

    return model_ds, units

In [10]:
def nearest_neighbour(array, value):

    """ Find index of the closest value in a 1D-array

        Args:
            array (arr): Array to find the nearest neighbour
            value (float): Search value

        Returns:
            index (int): Position of the element that is closest to the search value
            (the first one if several elements are equally close)
    """

    # Element-wise offsets to the search value
    offsets = [element - value for element in array]

    return np.argmin(np.abs(offsets), axis = 0)

In [11]:
def closest_point(point, array):

    """ Find the pair of values in a 2D-array that is closest to a point

        Args:
            point (tuple): Search coordinates
            array (arr): Array to find the nearest neighbour

        Returns:
            pair (arr): Row of the array with the shortest distance to the point
    """

    # Distances from the point to every row of the array
    distances = cdist([point], array)
    closest_index = distances.argmin()

    return array[closest_index]

In [12]:
def pairwise(dates):

    """ Group the dates in consecutive pairs (1st with 2nd, 3rd with 4th...)

        Args:
            dates (arr): All dates

        Returns:
            period (list): Dates grouped in pairs, a last date without partner is dropped
    """

    # zip pulls twice from the same iterator to build each pair
    shared_iterator = iter(dates)

    return list(zip(*[shared_iterator] * 2))

In [13]:
def subset(ds, bbox, sensor, component_nom, subset_type):

    """ Subset any dataset (with latitude and longitude as coordinates) into desired bounding box.

        TROPOMI sensor data is subset by its own function, any other dataset is 
        cut between the grid points that are closest to the bounding box limits.

        Args:
            ds (xarray): Dataset in xarray format
            bbox (arr): Query bounding box as ((lon_min, lat_min), (lon_max, lat_max))
            sensor (str): Name of the sensor
            component_nom (str): Component chemical nomenclature
            subset_type (str): 'sensor_subset' if the dataset is the sensor dataset
    
        Returns:
            ds (xarray): Dataset in xarray format
    """

    if sensor == 'tropomi' and subset_type == 'sensor_subset':

        ds = TROPOMI_subset(ds, bbox, component_nom)

    else:

        # Get nearest longitude and latitude to bbox
        lon_min_index = nearest_neighbour(ds.longitude.data, bbox[0][0])
        lon_max_index = nearest_neighbour(ds.longitude.data, bbox[1][0])
        lat_min_index = nearest_neighbour(ds.latitude.data, bbox[0][1])
        lat_max_index = nearest_neighbour(ds.latitude.data, bbox[1][1])

        # Sort the indexes: if a coordinate is stored in descending order, the index
        # of the minimum is larger than the index of the maximum and the slice
        # would be empty
        lon_first, lon_last = sorted((lon_min_index, lon_max_index))
        lat_first, lat_last = sorted((lat_min_index, lat_max_index))

        # Define slices (both limits are included)
        slice_lat = slice(lat_first, lat_last + 1)
        slice_lon = slice(lon_first, lon_last + 1)

        # Set limits
        ds = ds.isel(longitude = slice_lon, latitude = slice_lat)

    return ds

In [14]:
def prepare_df(match_df, sensor, component_nom, time):

    """ Prepare the sensor dataframe for the match with the model data

        NOTE(review): sensor_type is read as a free name (it is not a parameter), 
        so it has to exist in the calling notebook namespace - confirm against 
        the callers.

        Args:
            match_df (dataframe): Dataframe used to apply averaging kernels
            (for TROPOMI it is also modified in place)
            sensor (str): Name of the sensor
            component_nom (str): Component chemical nomenclature
            time (timestamp): Current time
        
        Returns:
            match_df (dataframe): Dataframe used to apply averaging kernels
    """

    if sensor == 'tropomi':

        # Data with qa_value of 0.5 or less becomes NaN (these values will be shown as transparent)
        low_quality = match_df['qa_value'] <= 0.5
        match_df.loc[low_quality, ['sensor_column', 'column_kernel']] = float('NaN')

        # Name the index levels, the number of levels depends on the component
        if component_nom in ('CO', 'SO2'):
            match_df.index.names = ['corner', 'ground_pixel', 'layer', 'scanline']

        elif component_nom == 'O3':
            match_df.index.names = ['corner', 'ground_pixel', 'layer', 'level', 'scanline']

        # Average over the levels that are not part of the grouping keys
        grouping_keys = ['layer', 'scanline', 'ground_pixel', 'time', 'delta_time']
        match_df = match_df.groupby(by = grouping_keys).mean()

        return match_df.reset_index(level = ['layer', 'delta_time'])

    if sensor in ('iasi', 'gome'):

        match_df = match_df.reset_index(level = ['latitude', 'longitude'])

        if sensor == 'gome' and sensor_type == 'L2':

            # Missing observation times are filled with noon of the current day
            date_parts = time.astype('datetime64[D]').astype(str).split('-')
            noon = dt.datetime(int(date_parts[0]), int(date_parts[1]), int(date_parts[2]), 12, 0, 0)
            match_df['delta_time'] = match_df['delta_time'].fillna(value = noon)

    return match_df

In [15]:
def generate_match_table(sensor_ds, model_ds, bbox, sensor, component_nom, apply_kernels = False):

    """ Intermediate merge table with total column or partial column from both datasets, 
        the averaging kernels are applied if possible

        The sensor data of each timestamp is subset, turned into a dataframe and 
        matched to the model data; the dataframes of all timestamps are then 
        concatenated.

        NOTE(review): sensor_type, component, support_input_ds and support_details_ds 
        are read as free names (they are not parameters), so they have to exist in 
        the calling notebook namespace - confirm against the callers.

        Args:
            sensor_ds (xarray): sensor dataset in xarray format (TROPOMI, IASI or GOME-2)
            model_ds (xarray): model dataset in xarray format (CAMS)
            bbox (arr): Query bounding box
            sensor (str): Name of the sensor
            component_nom (str): Component chemical nomenclature
            apply_kernels (bool): Apply (True) or not (False) the averaging kernels 

        Returns:
            match_table (dataframe): Intermediate merge table with total column or partial column from both datasets
            (empty dataframe if the sensor dataset has no timestamps)
    """
    
    # One dataframe per timestamp. They are concatenated once at the end because 
    # DataFrame.append is not available anymore in pandas 2.0.
    match_frames = []

    if sensor == 'tropomi' and apply_kernels == True:

        print('APPLICATION OF AVERAGING KERNELS')
        print('For the application of the averaging kernels, it is necessary to calculate:')
        print('1. Level pressures')
        print('2. Column kernels')
        print('The apriori profiles should be retrieved, but they are not necessary.')

        # Calculate TM5 level pressures, column kernels and apriori profiles
        print('DATA AVAILABILITY')
        sensor_ds = TROPOMI_pressure(sensor_ds, component_nom, support_input_ds, support_details_ds)
        sensor_ds = TROPOMI_column_kernel(sensor_ds, component_nom, support_details_ds)
        sensor_ds = TROPOMI_apriori_profile(sensor_ds, component, support_details_ds)

    for time in sensor_ds.time.values:
        
        # Print estimated time or month
        if sensor_type == 'L2':
            day = np.datetime64(time).astype('datetime64[D]')
            print(f'FOR DATE: {day}')

        elif sensor_type == 'L3':
            month = np.datetime64(time).astype('datetime64[M]')
            print(f'FOR MONTH: {month}')

        # Reduce data to only one timestamp
        model_ds_time = model_ds.sel(time = time)
        sensor_ds_time = sensor_ds.sel(time = time)

        # Subset sensor dataset
        sensor_ds_time = subset(sensor_ds_time, bbox, sensor, component_nom, subset_type = 'sensor_subset')
        
        # Transform sensor data into dataframe and prepare it for merging it with the model data
        match_df = sensor_ds_time.to_dataframe()
        match_df = prepare_df(match_df, sensor, component_nom, time)
        
        if sensor == 'tropomi' and 'column_kernel' in list(sensor_ds.keys()) and apply_kernels == True:

            match_df = TROPOMI_apply_kernels(match_df, model_ds_time, sensor_ds_time, component_nom)
            
        else:
            
            if apply_kernels == True:
                print('The application of the averaging kernels cannot take place because there is no data of the column kernels.')

            if 'hybrid' in list(model_ds.coords):

                print('The partial columns will be sumed up.')
                print('The sum will be matched to the sensor data by nearest neighbours.')

                # From here on model_ds_time is the data array of the summed component
                model_ds_time = model_ds_time.component.sum(dim = 'hybrid', skipna = False)
                model_times = model_ds_time.valid_time.data
                
                # Model step that is closest in time to each sensor observation
                match_df['step_index'] = match_df.apply(lambda row: nearest_neighbour(model_times, row['delta_time']), axis = 1)
                match_df['model_time'] = match_df.apply(lambda row: model_ds_time.valid_time[row['step_index']].values, axis = 1)
                match_df['model_column'] = match_df.apply(lambda row: model_ds_time.sel(
                                                                      latitude = row['latitude'], 
                                                                      longitude = row['longitude'],
                                                                      method = 'nearest').isel(step = 
                                                                      int(row['step_index'])).values, 
                                                                      axis = 1)
   
                match_df = match_df.set_index('layer', append = True)
                
            else:

                print('The model dataset does not contain levels data.')
                print('The model dataset will be merged with the sensor dataset by nearest neighbours.')

                model_times = model_ds_time.valid_time.data

                # Monthly data
                if 'step' not in list(model_ds.dims):
                    
                    match_df['model_column'] = match_df.apply(lambda row: float(model_ds_time.sel(
                                                                                latitude = row['latitude'], 
                                                                                longitude = row['longitude'],
                                                                                method = 'nearest').component.values), 
                                                                                axis = 1)
                # Hourly / Daily data
                else:

                    match_df['step_index'] = match_df.apply(lambda row: nearest_neighbour(model_times, row['delta_time']), axis = 1)
                    match_df['model_column'] = match_df.apply(lambda row: float(model_ds_time.sel(
                                                                                latitude = row['latitude'], 
                                                                                longitude = row['longitude'],
                                                                                method = 'nearest').isel(step = 
                                                                                int(row['step_index'])).component.values), 
                                                                                axis = 1)

        # Keep only the first row of each index value
        match_df = match_df[~match_df.index.duplicated()]
        match_frames.append(match_df)

    # pd.concat does not accept an empty list
    if not match_frames:
        return pd.DataFrame()

    match_table = pd.concat(match_frames)

    return match_table

In [16]:
def generate_merge_table(match_table, sensor_ds, model_ds, sensor, apply_kernels = False):

    """ Final merge table with total column component data for each dataset, 
        their difference in each grid point are calculated

        The match table that is passed in is not modified.

        Args:
            match_table (dataframe): Intermediate merge table with total column or partial column from both datasets
            sensor_ds (xarray): sensor dataset in xarray format (TROPOMI, IASI or GOME-2)
            model_ds (xarray): model dataset in xarray format (CAMS)
            sensor (str): Name of the sensor
            apply_kernels (bool): Apply (True) or not (False) the averaging kernels
        
        Returns:
            merge_table (dataframe): Merge table with datasets column data and their difference 
            (sensor minus model)
    """

    if 'hybrid' in list(model_ds.coords):

        merge_frames = []

        for time in sensor_ds.time.values:

            match_ds = match_table.query('time == @time').to_xarray()
            match_ds_time = match_ds.sel(time = time)

            # Read latitudes and longitudes from data array
            latitude = match_ds_time.latitude.mean(dim = 'layer')
            longitude = match_ds_time.longitude.mean(dim = 'layer')

            # Get sum of CAMS data of each layer to get column data
            if 'column_kernel' in list(match_ds.keys()) and apply_kernels == True:
                model_final_ds_time = match_ds_time.model_column.sum(dim = 'layer', skipna = False).astype(float)

            else:
                model_final_ds_time = match_ds_time.model_column.mean(dim = 'layer', skipna = False).astype(float)

            model_final_ds_time = model_final_ds_time.assign_coords(latitude = latitude, longitude = longitude)

            # Get mean of TROPOMI data of each layer (it must be equal)
            sensor_final_ds_time = match_ds.sensor_column.sel(time = time).mean(dim = 'layer', skipna = False).astype(float)
            sensor_final_ds_time = sensor_final_ds_time.assign_coords(latitude = latitude, longitude = longitude)

            merge_ds_time = xr.merge([model_final_ds_time, sensor_final_ds_time])
            merge_ds_time['difference'] = merge_ds_time.sensor_column - merge_ds_time.model_column
            merge_frames.append(merge_ds_time.to_dataframe())

        merge_table = pd.concat(merge_frames)

    else:

        # assign returns a new dataframe: the difference used to be written into 
        # the match table of the caller as a side effect
        merge_table = match_table.assign(difference = match_table['sensor_column'] - match_table['model_column'])

    # Organize dataset for visualization
    if sensor == 'tropomi':
        merge_table = merge_table.reset_index().set_index(['scanline', 'ground_pixel', 'time'])
        merge_table = merge_table[['latitude', 'longitude', 'model_column', 'sensor_column', 'difference']]

    else:
        merge_table = merge_table.reset_index().set_index(['latitude', 'longitude', 'time'])
        merge_table = merge_table[['model_column', 'sensor_column', 'difference']]
    
    return merge_table

In [17]:
def plot_period(sensor_ds, sensor):

    """ Define plot period by asking the user which dates have to be plotted

        Args:
            sensor_ds (xarray): sensor dataset in xarray format (TROPOMI, IASI or GOME-2)
            sensor (str): Name of the sensor ('tropomi', 'iasi' or 'gome')

        Returns:
            plot_dates (arr): Plot dates. All the times in sensor_ds if the user answers No,
            otherwise the dates accepted one by one (empty array if all of them are rejected)
    """

    negative_answers = ('No', 'no')

    # IASI and GOME-2 dates are offered and printed with monthly resolution
    is_monthly = sensor == 'iasi' or sensor == 'gome'

    period_answer = input('Do you want to visualize the plots for specific dates? Press Enter for Yes or write No:')

    if period_answer in negative_answers:
        plot_dates = sensor_ds.time.values

    else:
        if sensor == 'tropomi':
            options = sensor_ds.time.values

        elif is_monthly:
            options = sensor_ds.time.values.astype('datetime64[M]')

        else:
            # Without this check the function failed further down with a NameError
            print('ERROR: The sensor is not supported.')
            print('The sensors that are currently supported are tropomi, iasi and gome.')
            raise KeyboardInterrupt()

        options_df = pd.DataFrame({'Date': options})
        selected_dates = []

        for date in options_df['Date']:
            date_answer = input('Do you want to show the plots for ' + str(date) + '? Press Enter for Yes or write No:')
            if date_answer not in negative_answers:
                selected_dates.append(date)

        # Always return an array (even when nothing was selected), so that 
        # the monthly conversion below and the callers never receive a plain list
        plot_dates = np.array(selected_dates, dtype = object)

    print('The plots will be shown for the following dates:')
    if sensor == 'tropomi':
        print(plot_dates)

    elif is_monthly:
        print(plot_dates.astype('datetime64[M]'))

    return plot_dates

In [18]:
def _ask_plot_limit(limit, axis, lower, upper, floor = None):

    """ Ask the user for one coordinate until a valid number is written

        Args:
            limit (str): Limit that is asked ('minimum' or 'maximum')
            axis (str): Coordinate that is asked ('longitude' or 'latitude')
            lower (float): Lowest value allowed by the query bounding box
            upper (float): Highest value allowed by the query bounding box
            floor (float): Minimum already chosen by the user. If it is given, 
            the answer has to be higher than it; if it is None, a minimum is being 
            asked and it has to be lower than upper to leave room for the maximum

        Returns:
            value (float): Accepted coordinate
    """

    if floor is None:
        error_message = f'ERROR: {axis.capitalize()} must be between {lower} and {upper} (upper limit excluded).'
    else:
        error_message = f'ERROR: {axis.capitalize()} must be between {lower} and {upper} and be higher than the minimum {floor}.'

    answer = input(f'Write value of {limit} {axis}: ')

    while True:

        try:
            value = float(answer)

        except ValueError:
            print('ERROR: The value must be a number.')

        else:
            # Positive range checks also reject NaN, which would pass negated comparisons
            if floor is None:
                is_valid = lower <= value < upper
            else:
                is_valid = lower <= value <= upper and value > floor

            if is_valid:
                return value

            print(error_message)

        answer = input(f'Write value of {limit} {axis} (again): ')


def plot_extent(bbox):

    """ Define plot extent

        Args:
            bbox (arr): Query bounding box as ((lon_min, lat_min), (lon_max, lat_max))

        Returns:
            plot_bbox (arr): Plot bounding box as ((lon_min, lat_min), (lon_max, lat_max)).
            It is the query bounding box if the user answers No, otherwise the extent 
            written by the user (it always lies within the query bounding box)
    """

    extent_answer = input(f'Do you want to visualize the plots for a specific extent? Press Enter for Yes or write No (default {bbox}):')

    if extent_answer == 'No' or extent_answer == 'no':
        plot_bbox = ((bbox[0][0], bbox[0][1]), (bbox[1][0], bbox[1][1]))

    else:
        # The minimum must be strictly lower than the bbox maximum: otherwise no 
        # maximum could ever be accepted and the user would be asked forever
        plot_lon_min = _ask_plot_limit('minimum', 'longitude', bbox[0][0], bbox[1][0])
        plot_lon_max = _ask_plot_limit('maximum', 'longitude', bbox[0][0], bbox[1][0], floor = plot_lon_min)
        plot_lat_min = _ask_plot_limit('minimum', 'latitude', bbox[0][1], bbox[1][1])
        plot_lat_max = _ask_plot_limit('maximum', 'latitude', bbox[0][1], bbox[1][1], floor = plot_lat_min)

        # Define plot bbox
        plot_bbox = ((plot_lon_min, plot_lat_min), (plot_lon_max, plot_lat_max))

    print('The plots will be shown for the following spatial extent: ')
    print(plot_bbox)

    return plot_bbox

In [19]:
def colorbar_range(range_type, merge, array, 
                   max_all, min_all, max_all_diff, min_all_diff,
                   vmin_manual = None, vmax_manual = None):

    """ Define colorbar range

        Args:
            range_type (str): Range type for colorbar:
            -  'original': Show original values in range
            -  'equal': Show same scale in range
            -  'manual': Show the scale given by the user
            merge (xarray): Merge result for a specific time
            array (xarray): Component for a specific time and model/sensor
            max_all (float): Upper limit used when range_type is 'equal'
            min_all (float): Lower limit used when range_type is 'equal'
            max_all_diff (float): Maximum of the difference values
            min_all_diff (float): Minimum of the difference values
            vmin_manual (float): Input vmin by user (needed when range_type is 'manual')
            vmax_manual (float): Input vmax by user (needed when range_type is 'manual')

        Returns:
            vmin, vmax (float): Limits of color bar

        Raises:
            KeyboardInterrupt: If range_type is 'manual' and a manual limit is missing,
            or if range_type is none of the supported types
    """

    # The colorbar for difference will be defined: it is recognised by comparing the 
    # values with merge.difference and it is always symmetric around zero
    if np.array_equal(array, merge.difference, equal_nan = True):

        # fmax ignores a NaN limit instead of leaving the range undefined
        limit = np.fmax(np.abs(max_all_diff), np.abs(min_all_diff))
        return -limit, limit

    # The colorbar will show the original range
    if range_type == 'original':
        return np.nanmin(array), np.nanmax(array)

    # The colorbar will be in the same scale for both datasets
    if range_type == 'equal':
        return min_all, max_all

    # The colorbar will be in the scale given by the user
    if range_type == 'manual':

        if vmin_manual is None or vmax_manual is None:
            print('ERROR: vmin_manual and vmax_manual have to be defined and cannot be None.')
            raise KeyboardInterrupt()

        return vmin_manual, vmax_manual

    # Unknown range types used to end in an UnboundLocalError
    print('ERROR: The range type (range_type) must be defined as original, equal or manual.')
    raise KeyboardInterrupt()

In [20]:
def visualize_pcolormesh(fig, axs, data_array, longitude, latitude, projection, color_scale, 
                         pad, long_name, units_name, vmin, vmax, lonmin, lonmax, latmin, latmax,
                         distribution_type = None):
    
    """ Visualize one dataset as a colored mesh on a map

        Args:
            fig: Figure
            axs: Axes of figure where the dataset is drawn (they are cleared first)
            data_array (xarray): Variable values to plot - It must be 2-dimensional
            longitude (arr): Longitudes within data_array
            latitude (arr): Latitudes within data_array
            projection: Geographical projection
            color_scale (str): Color scale for the color bar
            pad (float): Padding for the subtitles
            long_name (str): Plot name
            units_name (str): Component name and units
            vmin, vmax (float): Limits of color bar
            lonmin, lonmax, latmin, latmax (float): Limits of longitude and latitude values
            distribution_type (str): If it is 'animated', no color bar is added to the figure.
            If it is None, the notebook variable with the same name is used (when it exists)

        Returns:
            im_ind: Mesh returned by pcolormesh
    """

    if distribution_type is None:
        # Backward compatibility: this function used to read the variable 
        # distribution_type straight from the notebook namespace
        distribution_type = globals().get('distribution_type')

    # Values below vmin are transparent
    palette = copy(plt.get_cmap(color_scale))
    palette.set_under(alpha = 0)
    axs.clear()

    # The limits are given only through the norm: recent versions of matplotlib 
    # do not accept vmin/vmax and norm at the same time
    im_ind = axs.pcolormesh(
                        longitude, latitude, data_array, 
                        cmap = palette, 
                        transform = ccrs.PlateCarree(),
                        norm = colors.Normalize(vmin = vmin, vmax = vmax),
                        shading = 'auto'
                        )
                        
    axs.add_feature(cfeature.BORDERS, edgecolor = 'black', linewidth = 1)
    axs.add_feature(cfeature.COASTLINE, edgecolor = 'black', linewidth = 1)
    axs.gridlines()

    # The extent and the labeled gridlines are only set for the PlateCarree projection
    if projection == ccrs.PlateCarree():
        
        axs.set_extent([lonmin, lonmax, latmin, latmax], ccrs.PlateCarree())
        gl = axs.gridlines(draw_labels = True, linestyle = '--')
        gl.top_labels = False
        gl.right_labels = False
        gl.xformatter = LONGITUDE_FORMATTER
        gl.yformatter = LATITUDE_FORMATTER
        gl.xlabel_style = {'size': 16}
        gl.ylabel_style = {'size': 16}

    axs.set_title(long_name, fontsize = 18, pad = pad)
    axs.tick_params(labelsize = 14)
    
    if distribution_type != 'animated':
        
        cbr = fig.colorbar(im_ind, ax = axs, extend = 'both', orientation = 'horizontal', 
                           fraction = 0.05, pad = 0.15)   
        cbr.set_label(units_name, fontsize = 16)
        cbr.ax.tick_params(labelsize = 14)
        cbr.ax.xaxis.get_offset_text().set_fontsize(14)
      
    return im_ind

In [21]:
def create_maps(fig, axs, merge, range_type, sensor, model, sensor_type, model_type, projection, pad, 
                units_name, plot_bbox, color_scale, max_all, min_all, min_all_diff, max_all_diff,
                vmin_manual = None, vmax_manual = None):

    """ Draw the model, sensor and difference maps in three axes of a figure

        Args:
            fig: Figure
            axs: Axes of figure (the first three are used)
            merge (xarray): Merge result with model_column, sensor_column and difference
            range_type (str): Range type for colorbar
            sensor (str): Name of the sensor
            model (str): Name of the model
            sensor_type (str): Sensor type
            model_type (str): Model type
            projection: Geographical projection
            pad (float): Padding for the subtitles
            units_name (str): Component name and units
            plot_bbox (arr): Plot extent
            color_scale (str): Name of color scale
            max_all, min_all (float): Limits passed to colorbar_range
            min_all_diff, max_all_diff (float): Limits of the differences passed to colorbar_range
            vmin_manual (float): Input vmin by user
            vmax_manual (float): Input vmax by user

        Returns:
            im (list): Meshes of the three maps (model, sensor and differences)
    """

    def draw_panel(panel_axs, panel_array, panel_title):

        # Every panel gets its own color limits before being drawn
        panel_vmin, panel_vmax = colorbar_range(range_type, merge, panel_array, max_all, min_all, 
                                                max_all_diff, min_all_diff, vmin_manual, vmax_manual)

        return visualize_pcolormesh(
                                   fig = fig, axs = panel_axs,
                                   data_array = panel_array,
                                   longitude = panel_array.longitude,
                                   latitude = panel_array.latitude,
                                   projection = projection,
                                   color_scale = color_scale,
                                   pad = pad,
                                   long_name = panel_title,
                                   units_name = units_name,
                                   vmin = panel_vmin,
                                   vmax = panel_vmax,
                                   lonmin = plot_bbox[0][0],
                                   lonmax = plot_bbox[1][0],
                                   latmin = plot_bbox[0][1],
                                   latmax = plot_bbox[1][1]
                                   )

    # First plot - CAMS
    model_im = draw_panel(axs[0], merge['model_column'], model.upper() + ' (' + model_type + ')')

    # Second plot - TROPOMI, IASI or GOME-2
    sensor_im = draw_panel(axs[1], merge['sensor_column'], sensor.upper() + ' (' + sensor_type + ')')

    # Third plot - Differences
    difference_im = draw_panel(axs[2], merge.difference, 'Differences plot')

    return [model_im, sensor_im, difference_im]

In [None]:
def define_absolute_limits(merge_table_bbox, vmin_manual, vmax_manual):

    """ Define absolute minimum and maximum within model and sensor datasets.
        Gets manual minimum and maximum or calculates it

        Args:
            merge_table_bbox (dataframe): Merge table for plot bbox
            vmin_manual (float): Input vmin by user (or None)
            vmax_manual (float): Input vmax by user (or None)

        Returns:
            min_all (float): Absolute vmin
            max_all (float): Absolute vmax
            min_all_diff (float): Absolute vmin for difference values
            max_all_diff (float): Absolute vmax for difference values     
    """

    # Define absolute minimum and maximum within model and sensor datasets.
    # They are calculated from the data unless both manual limits are given 
    # (a single manual limit used to fail when its absolute value was taken from None)
    if vmin_manual is None or vmax_manual is None:

        min_model = np.nanmin(merge_table_bbox['model_column'])
        max_model = np.nanmax(merge_table_bbox['model_column'])
        min_sensor = np.nanmin(merge_table_bbox['sensor_column'])
        max_sensor = np.nanmax(merge_table_bbox['sensor_column'])
        max_all = max(max_sensor, max_model)
        min_all = min(min_sensor, min_model)

    else:

        # Symmetric range around zero defined by the manual limit with the largest 
        # magnitude (fmax ignores a NaN limit instead of leaving the range undefined)
        limit = np.fmax(np.abs(vmax_manual), np.abs(vmin_manual))
        min_all = -limit
        max_all = limit

    # Define absolute minimum and maximum within difference
    min_all_diff = np.nanmin(merge_table_bbox['difference'])
    max_all_diff = np.nanmax(merge_table_bbox['difference'])

    return min_all, max_all, min_all_diff, max_all_diff

In [22]:
def _merge_at_time(merge_table, time):

    """ Get the merge result of one time as xarray, with latitude and longitude as coordinates

        Args:
            merge_table (dataframe): Merge table with datasets column data and their difference
            time: Time to select

        Returns:
            merge (xarray): Merge result for the given time
    """

    # The query string refers to the local variable 'time' through '@time'
    merge = merge_table.query('time == @time').to_xarray()
    latitude = merge.sel(time = time).latitude
    longitude = merge.sel(time = time).longitude

    return merge.sel(time = time).assign_coords(latitude = latitude, longitude = longitude)


def _distribution_title(component_nom, time, sensor, sensor_type):

    """ Get the main title of the distribution plots for one time

        Args:
            component_nom (str): Component chemical nomenclature
            time: Time of the plot
            sensor (str): Name of the sensor
            sensor_type (str): Sensor type

        Returns:
            title (str): Title with the month (IASI and GOME-2 L3 data) or the day (any other case)
    """

    if (sensor == 'iasi' and sensor_type == 'L3') or (sensor == 'gome' and sensor_type == 'L3'):
        month = np.datetime64(time).astype('datetime64[M]')
        return f'DISTRIBUTION OF {component_nom} (Month: {month})'

    day = np.datetime64(time).astype('datetime64[D]')
    return f'DISTRIBUTION OF {component_nom} (Date: {day})'


def visualize_model_vs_sensor(model, sensor, component_nom, units, merge_table, plot_dates, plot_bbox, pad, y, 
                              model_type, sensor_type, range_type, distribution_type, projection,
                              color_scale, vmin_manual = None, vmax_manual = None):

    """ Plot model and sensor datasets in the study area for the selected dates, 
        along with a plot of the differences

        Args:
            model (str): Name of the model
            sensor (str): Name of the sensor
            component_nom (str): Component chemical nomenclature
            units (str): Component units
            merge_table (dataframe): Merge table with datasets column data and their difference
            plot_dates (arr): Plot dates
            plot_bbox (arr): Plot extent
            pad (float): Padding for the subtitles
            y (float): y-position of main title
            model_type (str): Model type:
            -  'Forecast'
            -  'Reanalysis'
            sensor_type (str): Sensor type
            range_type (str): Range type for colorbar:
            -  'original': Show original values in range
            -  'equal': Show same scale in range
            -  'manual': Show the scale given by vmin_manual and vmax_manual
            distribution_type (str): 
            -  'aggregated': Aggregate plots by time (mean of all the times in merge_table)
            -  'individual': Show individual plots (one figure per plot date)
            -  'animated': Show animation (it is also saved as animation.mp4)
            projection: Geographical projection
            color_scale (str): Name of color scale (e.g. coolwarm)
            vmin_manual (float): Input vmin by user
            vmax_manual (float): Input vmax by user

        Raises:
            KeyboardInterrupt: If distribution_type is not supported or if an animation 
            is requested without any plot date
    """

    # Stop before doing any work if the distribution type is not supported
    if distribution_type not in ('aggregated', 'individual', 'animated'):
        print('The distribution type (distribution_type) must be defined as aggregated, individual or animated.')
        raise KeyboardInterrupt()

    units_name = component_nom + ' (' + units + ')'

    # Get min and max 
    merge_table_bbox = merge_table.query('longitude >= @plot_bbox[0][0] and longitude <= @plot_bbox[1][0] and latitude >= @plot_bbox[0][1] and latitude <= @plot_bbox[1][1]')
    min_all, max_all, min_all_diff, max_all_diff = define_absolute_limits(merge_table_bbox, vmin_manual, vmax_manual)

    # NOTE(review): create_maps does not receive distribution_type, so visualize_pcolormesh 
    # decides whether to add its own colorbars from a notebook-level variable with that name, 
    # not from the argument of this function - confirm both always agree
    def draw_maps(fig, axs, merge):

        return create_maps(fig, axs, merge, range_type, sensor, model, sensor_type, model_type, 
                           projection, pad, units_name, plot_bbox, color_scale, max_all, min_all, 
                           min_all_diff, max_all_diff, vmin_manual, vmax_manual)

    if distribution_type == 'aggregated':

        # The mean is calculated with all the times in merge_table (plot_dates is not used)
        merge = merge_table.to_xarray().mean(dim = 'time')
        latitude = merge.latitude
        longitude = merge.longitude
        merge = merge.assign_coords(latitude = latitude, longitude = longitude)

        fig, axs = plt.subplots(1, 3, figsize = (20, 5), subplot_kw = {'projection': projection})
        fig.set_facecolor('w')

        draw_maps(fig, axs, merge)

        fig.suptitle(f'DISTRIBUTION OF {component_nom} (All times)',
                    fontsize = 18, fontweight = 'bold', y = y)

    elif distribution_type == 'individual':

        for time in plot_dates:

            fig, axs = plt.subplots(1, 3, figsize = (20, 5), subplot_kw = {'projection': projection})
            fig.set_facecolor('w')

            draw_maps(fig, axs, _merge_at_time(merge_table, time))

            fig.suptitle(_distribution_title(component_nom, time, sensor, sensor_type),
                        fontsize = 18, fontweight = 'bold', y = y)

            plt.show()

    else:

        # The first frame is drawn before the animation starts, so one date is needed at least
        if len(plot_dates) == 0:
            print('ERROR: At least one plot date is needed to create the animation.')
            raise KeyboardInterrupt()

        fig, axs = plt.subplots(1, 3, figsize = (25, 10), subplot_kw = {'projection': projection})
        fig.set_facecolor('w')

        fig_title = fig.text(0.5, 0.95, _distribution_title(component_nom, plot_dates[0], sensor, sensor_type), 
                             ha = 'center', fontsize = 22, fontweight = 'bold')

        # The meshes of the first frame are the ones linked to the colorbars below
        im = draw_maps(fig, axs, _merge_at_time(merge_table, plot_dates[0]))

        def animate(i):

            frame_im = draw_maps(fig, axs, _merge_at_time(merge_table, plot_dates[i]))
            fig_title.set_text(_distribution_title(component_nom, plot_dates[i], sensor, sensor_type))

            return frame_im

        anim = animation.FuncAnimation(fig, animate, frames = len(plot_dates), blit = True, interval = 1000)

        for j in range(0, 3):

            cbr = fig.colorbar(im[j], ax = axs[j], extend = 'both', orientation = 'horizontal', fraction = 0.05, pad = 0.15)
            cbr.set_label(units_name, fontsize = 18) 
            cbr.ax.tick_params(labelsize = 16)
            cbr.ax.xaxis.get_offset_text().set_fontsize(16)

        display(HTML(anim.to_jshtml()))
        anim.save('animation.mp4')
        plt.close()

In [23]:
def visualize_model_original_vs_calculated(model, component_nom, units, merge_table, 
                                           model_total_ds, plot_dates, plot_bbox, pad, y, 
                                           model_type, range_type, projection, color_scale,
                                           vmin_manual = None, vmax_manual = None, step = 2):

    """ Plot model total columns from the original dataset and the calculated one 
        in the study area for the selected dates

        Args:
            model (str): Name of the model
            component_nom (str): Component chemical nomenclature
            units (str): Component units
            merge_table (dataframe): Merge result
            model_total_ds (xarray): CAMS total columns dataset in xarray format
            plot_dates (arr): Plot dates
            plot_bbox (arr): Plot extent
            pad (float): Padding for the subtitles
            y (float): y-position of main title
            model_type (str): Model type:
            -  'Forecast'
            -  'Reanalysis'
            range_type (str): Range type for colorbar:
            -  'original': Show original values in range
            -  'equal': Show same scale in range
            -  'manual': Show the scale given by vmin_manual and vmax_manual
            projection: Geographical projection
            color_scale (str): Name of color scale (e.g. coolwarm)
            vmin_manual (float): Input vmin by user
            vmax_manual (float): Input vmax by user
            step (int): Position along the step dimension of model_total_ds that is plotted
    """

    units_name = component_nom + ' (' + units + ')'

    # Get min and max before splitting the data into timesteps
    min_calculated = np.nanmin(merge_table['model_column'])
    max_calculated = np.nanmax(merge_table['model_column'])
    min_original = np.nanmin(model_total_ds.component)
    max_original = np.nanmax(model_total_ds.component)
    max_all = max(max_original, max_calculated)
    min_all = min(min_original, min_calculated)

    # colorbar_range also needs the limits of the differences (it uses them 
    # when the array to plot is equal to the differences)
    min_all_diff = np.nanmin(merge_table['difference'])
    max_all_diff = np.nanmax(merge_table['difference'])

    for time in plot_dates:

        fig, axs = plt.subplots(1, 2, figsize = (20, 5), subplot_kw = {'projection': projection})
        fig.set_facecolor('w')
        
        merge = merge_table.query('time == @time').to_xarray()
        latitude = merge.sel(time = time).latitude
        longitude = merge.sel(time = time).longitude
        merge = merge.sel(time= time).assign_coords(latitude = latitude, longitude = longitude)

        # First plot - CAMS calculated total columns
        # (arguments are passed by name: the former positional call shifted them 
        # and the manual limits never reached colorbar_range)
        array = merge.model_column
        vmin, vmax = colorbar_range(range_type, merge, array, 
                                    max_all = max_all, min_all = min_all, 
                                    max_all_diff = max_all_diff, min_all_diff = min_all_diff,
                                    vmin_manual = vmin_manual, vmax_manual = vmax_manual)
        long_name = 'CALCULATED TOTAL COLUMNS ' + model.upper() + ' (' + model_type + ')'
        visualize_pcolormesh(
                            fig = fig, axs = axs[0],
                            data_array = array.fillna(-999),
                            longitude = array.longitude,
                            latitude = array.latitude,
                            projection = projection,
                            color_scale = color_scale,
                            pad = pad,
                            long_name = long_name,
                            units_name = units_name,
                            vmin = vmin, 
                            vmax = vmax, 
                            lonmin = plot_bbox[0][0],
                            lonmax = plot_bbox[1][0],
                            latmin = plot_bbox[0][1],
                            latmax = plot_bbox[1][1]
                            )

        # Second plot - CAMS original total columns
        array = model_total_ds.component.isel(step = step).sel(time = time)
        vmin, vmax = colorbar_range(range_type, merge, array, 
                                    max_all = max_all, min_all = min_all, 
                                    max_all_diff = max_all_diff, min_all_diff = min_all_diff,
                                    vmin_manual = vmin_manual, vmax_manual = vmax_manual)
        long_name = 'ORIGINAL TOTAL COLUMNS ' + model.upper() + ' (' + model_type + ')'
        visualize_pcolormesh(
                            fig = fig, axs = axs[1],
                            data_array = array.fillna(-999),
                            longitude = array.longitude,
                            latitude = array.latitude,
                            projection = projection,
                            color_scale = color_scale,
                            pad = pad,
                            long_name = long_name,
                            units_name = units_name,
                            vmin = vmin,
                            vmax = vmax, 
                            lonmin = plot_bbox[0][0],
                            lonmax = plot_bbox[1][0],
                            latmin = plot_bbox[0][1],
                            latmax = plot_bbox[1][1]
                            )

        day = np.datetime64(time).astype('datetime64[D]')
        fig.suptitle(f'DISTRIBUTION OF {component_nom} (Date: {day})',
                    fontsize = 18, fontweight = 'bold', y = y)
        plt.show()

In [24]:
def get_google_api():

    """ Get Google API key for reverse geocoding (get country given the coordinates)
        and set the Google credentials as environment variables
        
        Returns:
            api_key (str): Google API key

        Raises:
            IndexError: If data/keys.txt has less than three lines with content
    """

    # Open txt file with three lines:
    # GOOGLE API KEY (first line), GOOGLE CLIENT ID (second line) and GOOGLE CLIENT SECRET (third line)
    # The file is closed as soon as it is read and blank lines are ignored
    with open('data/keys.txt', 'r') as keys_file:
        entries = [line.rstrip() for line in keys_file]
    entries = [entry for entry in entries if entry]

    if len(entries) >= 4:
        # NOTE(review): files with four or more lines are read as before (the first line 
        # is skipped and the next three are used) - confirm which layout is intended
        api_key, client_id, client_secret = entries[1:4]

    elif len(entries) == 3:
        # Layout described above: reading it used to fail with an IndexError
        api_key, client_id, client_secret = entries

    else:
        raise IndexError('data/keys.txt must contain the Google API key, client id and client secret in three lines.')

    # Set environment variables in your system
    os.environ['GOOGLE_API_KEY'] = api_key
    os.environ['GOOGLE_CLIENT'] = client_id
    os.environ['GOOGLE_CLIENT_SECRET'] = client_secret

    return api_key

In [25]:
def get_season(day):

    """ Get season given the day

        Args:
            day (datetime): Date. Any object with month and day attributes is accepted
            (e.g. datetime.date, datetime.datetime or pandas Timestamp)
        
        Returns:
            season (str): Season of the year ('Winter', 'Spring', 'Summer' or 'Autumn')
    """

    # Only the month and the day are compared. Building dates of a reference year 
    # failed for datetime inputs, which cannot be compared with plain dates
    month_day = (day.month, day.day)

    # Winter goes from 21 Dec to 20 Mar, so it covers both ends of the year
    if month_day < (3, 21) or month_day >= (12, 21):
        season = 'Winter'

    # Spring goes from 21 Mar to 20 Jun
    elif month_day < (6, 21):
        season = 'Spring'

    # Summer goes from 21 Jun to 22 Sep
    elif month_day < (9, 23):
        season = 'Summer'

    # Autumn goes from 23 Sep to 20 Dec
    else:
        season = 'Autumn'

    return season

In [26]:
def linear_regression(X, Y, component_nom):

    """ Fit a linear equation to scatter plot between X and Y and print results

        Args:
            X (array): Input sensor component values (a single column)
            Y (array): Input model component values
            component_nom (str): Component chemical nomenclature
        
        Returns:
            fit_X (array): X in linear equation fit_Y = A * fit_X + B (10 values between 
            the minimum and the maximum of X)
            fit_Y (array): Y in linear equation fit_Y = A * fit_X + B
            score (float): Coefficient of determination
            coefficient (float): A in linear equation fit_Y = A * fit_X + B
            intercept (float): B in linear equation fit_Y = A * fit_X + B
    """

    reg = LinearRegression().fit(X, Y)

    # The fitted parameters are arrays: their first element is taken explicitly, 
    # because converting arrays with dimensions to float is deprecated in NumPy.
    # Flattening first makes it work for both 1-D and 2-D targets
    coefficient = np.ravel(reg.coef_)[0]
    intercept = np.ravel(reg.intercept_)[0]
    score = reg.score(X, Y)

    fit_X = np.linspace(np.nanmin(X), np.nanmax(X), 10)
    fit_Y = fit_X * coefficient + intercept

    print(f'Fit equation: {component_nom}_model = {component_nom}_sensor * {coefficient:.2f} + ({intercept:.2E})')
    print(f'Coefficient of determination (R2): {score:.2f}')

    return fit_X, fit_Y, score, coefficient, intercept

In [27]:
def scatter_plot_general_settings(component_nom, axs, units, lim_min, lim_max):

    """ Set common settings for scatter plots

        Args:
            component_nom (str): Component chemical nomenclature
            axs: Axes of the figure: scatter plot (first), histogram of the sensor 
            values (second) and histogram of the model values (third)
            units (str): Component units
            lim_min (float): Minimum value of component in scale
            lim_max (float): Maximum value of component in scale
    """

    sensor_label = f'Sensor {component_nom} ({units})'
    model_label = f'Model {component_nom} ({units})'

    # Scatter plot: sensor against model, both in the same scale
    scatter_axs = axs[0]
    scatter_axs.set_xlabel(sensor_label, fontsize = 16)
    scatter_axs.set_ylabel(model_label, fontsize = 16)
    scatter_axs.tick_params(labelsize = 14)
    scatter_axs.set_xlim([lim_min, lim_max])
    scatter_axs.set_ylim([lim_min, lim_max])

    # Histograms: they only differ in the name of the x axis
    for position, histogram_label in enumerate((sensor_label, model_label), start = 1):
        histogram_axs = axs[position]
        histogram_axs.set_xlabel(histogram_label, fontsize = 16)
        histogram_axs.set_ylabel('Count', fontsize = 16)
        histogram_axs.tick_params(labelsize = 14)
        histogram_axs.set_xlim([lim_min, lim_max])

In [28]:
def _scatter_figure(data, component_nom, units, lim_min, lim_max, y, title, hue = None, legend_title = None):

    """ Draw and show one figure: scatter plot with its linear fit plus the sensor and model histograms

        Args:
            data (dataframe): Rows to plot, with the columns 'sensor_column' and 'model_column'
            component_nom (str): Component chemical nomenclature
            units (str): Component units
            lim_min (float): Minimum value of component in scale
            lim_max (float): Maximum value of component in scale
            y (float): y-position of main title
            title (str): Main title of the figure
            hue (str): Column used to colour the scatter points (None for a single colour)
            legend_title (str): Title of the scatter plot legend (None to skip the legend)

        Returns:
            score, coefficient, intercept: Fit statistics as returned by linear_regression
    """

    fig, axs = plt.subplots(1, 3, figsize = (20, 5))
    fig.set_facecolor('w')

    # Linear regression
    X = data['sensor_column'].values.reshape(-1, 1)
    Y = data['model_column'].values.reshape(-1, 1)
    fit_X, fit_Y, score, coefficient, intercept = linear_regression(X, Y, component_nom)
    axs[0].plot(fit_X, fit_Y, color = 'black', label = 'Linear regression')

    # Scatter plot and histograms
    sp = sns.scatterplot(data = data, x = 'sensor_column', y = 'model_column', hue = hue, ax = axs[0])
    sns.histplot(data = data, x = 'sensor_column', kde = True, ax = axs[1])
    sns.histplot(data = data, x = 'model_column', kde = True, ax = axs[2])

    scatter_plot_general_settings(component_nom, axs, units, lim_min, lim_max)
    fig.suptitle(title, fontsize = 18, fontweight = 'bold', y = y)

    # Scatter plot legend (only for the aggregated figures, which colour the points by group)
    if legend_title is not None:
        leg = sp.legend(loc = 'upper center', bbox_to_anchor = (0.5, -0.2),
                        fancybox = True, ncol = 3, fontsize = 14)
        leg.set_title(legend_title, prop = {'size': 14})

    plt.show()

    return score, coefficient, intercept


def scatter_plot(merge_table, component_nom, units, sensor, plot_dates, y, extent_definition, 
                 show_seasons, scatter_plot_type, lim_min = None, lim_max = None, *args):

    """ Scatter plot between the model and sensor datasets in the study area for the selected dates (bbox or countries)

        Args:
            merge_table (dataframe): Merge result
            component_nom (str): Component chemical nomenclature
            units (str): Component units
            sensor (str): Name of the sensor
            plot_dates (arr): Plot dates
            y (float): y-position of main title
            extent_definition (str):
            * 'country': Scatter plots for countries list
            * 'bbox': Scatter plots for bbox coordinates
            show_seasons (bool): Group the data by season instead of by time or country 
            (cannot be combined with extent_definition = 'country')
            scatter_plot_type (str):
            * 'aggregated': Aggregate plots by time, country or season
            * 'individual': Individual plots per time, country or season
            lim_min (float): Minimum value of component in scale 
            (None: minimum of the sensor and model columns of merge_table)
            lim_max (float): Maximum value of component in scale 
            (None: maximum of the sensor and model columns of merge_table)
            *args: Accepted for backward compatibility, not read by this function

        Returns:
            summary (dataframe): One row per figure drawn, with the columns 
            'Period', 'Location', 'Score', 'Coefficient' and 'Intercept' (empty if nothing was drawn)

        Raises:
            KeyboardInterrupt: If show_seasons, extent_definition or scatter_plot_type are wrongly defined

        NOTE(review): plot_bbox, plot_countries and sensor_type are not parameters; 
        they are looked up in the module scope and must be defined there before the call.
    """

    # Check the options first, so that a wrong option stops the run before any
    # filtering, reverse geocoding or plotting is done
    if show_seasons not in (True, False):
        print('ERROR: show_seasons is wrongly defined. The options are True and False.')
        raise KeyboardInterrupt()

    if show_seasons:
        if extent_definition == 'country':
            print('ERROR: Set up show_seasons to False in order to show the scatter plots by countries.')
            raise KeyboardInterrupt()

    elif extent_definition not in ('bbox', 'country'):
        print('ERROR: extent_definition is wrongly defined. The options are ''bbox'' and ''country''.')
        raise KeyboardInterrupt()

    if scatter_plot_type not in ('aggregated', 'individual'):
        print('ERROR: scatter_plot_type is wrongly defined. The options are ''aggregated'' and ''individual''.')
        raise KeyboardInterrupt()

    # NOTE(review): color_palette only returns a palette, it does not apply it;
    # sns.set_palette may have been intended - confirm before changing the plot colours
    sns.color_palette('colorblind', 10)

    if lim_min is None:
        lim_min = min(np.nanmin(merge_table['sensor_column']), np.nanmin(merge_table['model_column']))
    if lim_max is None:
        lim_max = max(np.nanmax(merge_table['sensor_column']), np.nanmax(merge_table['model_column']))

    summary = []

    # Keep only the data inside the bounding box
    merge = merge_table
    merge = merge.query('longitude >= @plot_bbox[0][0] and longitude <= @plot_bbox[1][0] and latitude >= @plot_bbox[0][1] and latitude <= @plot_bbox[1][1]')

    if show_seasons:

        plot_seasons = ['Winter', 'Spring', 'Summer', 'Autumn']

        # Prepare df (copied, so that the new column is not written on a slice)
        merge = merge.reset_index()
        merge = merge[merge['time'].isin(plot_dates)].copy()

        # Find data for the seasons in list
        merge['Season'] = merge.apply(lambda row: get_season(row['time']), axis = 1)
        available_seasons = np.unique(merge['Season'])

        if scatter_plot_type == 'aggregated':

            if not merge.empty:

                score, coefficient, intercept = _scatter_figure(merge, component_nom, units, lim_min, lim_max, y, 
                                                                f'{component_nom} (All seasons)', 
                                                                hue = 'Season', legend_title = 'Fit and seasons')

                # Update summary
                summary.append({'Period': available_seasons, 'Location': plot_bbox, 
                                'Score': score, 'Coefficient': coefficient, 
                                'Intercept': intercept})

        else:

            for plot_season in plot_seasons:

                # Prepare df
                merge_season = merge[merge['Season'] == plot_season]

                if not merge_season.empty:

                    score, coefficient, intercept = _scatter_figure(merge_season, component_nom, units, 
                                                                    lim_min, lim_max, y, 
                                                                    f'{component_nom} ({plot_season})')

                    # Update summary
                    summary.append({'Period': plot_season, 'Location': plot_bbox, 
                                    'Score': score, 'Coefficient': coefficient, 
                                    'Intercept': intercept})

    elif extent_definition == 'bbox':

        if scatter_plot_type == 'aggregated':

            # Prepare df
            merge = merge.reset_index()
            merge = merge[merge['time'].isin(plot_dates)]

            if not merge.empty:

                # L3 data is labelled by month, anything else by date
                legend_title = 'Fit and months' if sensor_type == 'L3' else 'Fit and dates'
                score, coefficient, intercept = _scatter_figure(merge, component_nom, units, lim_min, lim_max, y, 
                                                                f'{component_nom} (All times)', 
                                                                hue = 'time', legend_title = legend_title)

                # Update summary
                summary.append({'Period': plot_dates, 'Location': plot_bbox, 
                                'Score': score, 'Coefficient': coefficient, 
                                'Intercept': intercept})

        else:

            for time in plot_dates:

                # Prepare df
                merge_time = merge.query('time == @time')

                if not merge_time.empty:

                    if sensor == 'tropomi' or (sensor == 'gome' and sensor_type == 'L2'):
                        day = np.datetime64(time).astype('datetime64[D]')
                        title = f'{component_nom} (Date: {day})'

                    else:
                        month = np.datetime64(time).astype('datetime64[M]')
                        title = f'{component_nom} (Month: {month})'

                    score, coefficient, intercept = _scatter_figure(merge_time, component_nom, units, 
                                                                    lim_min, lim_max, y, title)

                    # Update summary
                    summary.append({'Period': time, 'Location': plot_bbox, 
                                    'Score': score, 'Coefficient': coefficient, 
                                    'Intercept': intercept})

    else:

        # Prepare df (copied, so that the new column is not written on a slice)
        merge = merge.reset_index()
        merge = merge[merge['time'].isin(plot_dates)].copy()

        # Read Google API key for reverse geocoding (get country by coordinates)
        google_api_key = get_google_api()

        # Reverse geocoding (one request per row)
        merge['Country'] = merge.apply(lambda row: geocoder.google([row['latitude'], row['longitude']], 
                                        method='reverse', key = google_api_key).country_long, axis = 1)

        # Find data for the countries in search list
        merge = merge[merge['Country'].isin(plot_countries)]
        available_countries = np.unique(merge['Country'])

        if scatter_plot_type == 'aggregated':

            if not merge.empty:

                score, coefficient, intercept = _scatter_figure(merge, component_nom, units, lim_min, lim_max, y, 
                                                                f'{component_nom} (All countries)', 
                                                                hue = 'Country', legend_title = 'Fit and countries')

                # Update summary
                summary.append({'Period': plot_dates, 'Location': available_countries, 
                                'Score': score, 'Coefficient': coefficient, 
                                'Intercept': intercept})

        else:

            for plot_country in plot_countries:

                merge_country = merge[merge['Country'] == plot_country]

                if not merge_country.empty:

                    score, coefficient, intercept = _scatter_figure(merge_country, component_nom, units, 
                                                                    lim_min, lim_max, y, 
                                                                    f'{component_nom} ({plot_country})')

                    # Update summary
                    summary.append({'Period': plot_dates, 'Location': plot_country, 
                                    'Score': score, 'Coefficient': coefficient, 
                                    'Intercept': intercept})

    summary = pd.DataFrame(summary)

    return summary

In [29]:
def retrieve_coords(merge_table, coords_search_list, component_nom, sensor, model, plot_dates, units):

    """ Get component data for the closest coordinates to the list of search coordinates and plot them along time

        Args:
            merge_table (dataframe): Merge result
            coords_search_list (list): List of search coordinates
            component_nom (str): Component chemical nomenclature
            sensor (str): Name of the sensor
            model (str): Name of the model
            plot_dates (arr): Plot dates
            units (str): Component units

        Returns:
            retrieval_table_all (dataframe): Dataframe with results from search, indexed by 
            'lat_search', 'lon_search', 'latitude', 'longitude' and 'time'
    """

    coords_search = pairwise(coords_search_list)

    # The tables are collected in a list and joined with pd.concat, 
    # because DataFrame.append is not available in pandas >= 2.0
    retrieval_tables = []
    retrieval_table_all = pd.DataFrame()

    for search_point in coords_search:

        lat_search = search_point[0]
        lon_search = search_point[1]

        for time in plot_dates:

            # List of available points per time
            retrieval_table = merge_table.query('time == @time').reset_index()
            available_points = list(zip(retrieval_table['latitude'], retrieval_table['longitude']))

            # Get closest pair to coordinates in search list (searched only once per point and time)
            point_found = closest_point(search_point, available_points)
            lat_found = point_found[0]
            lon_found = point_found[1]
            retrieval_table = merge_table.query('latitude == @lat_found and longitude == @lon_found and time == @time')

            retrieval_table = retrieval_table.reset_index()
            retrieval_table['lat_search'] = lat_search
            retrieval_table['lon_search'] = lon_search

            # Keep retrieval table together with the ones of the previous coordinates
            retrieval_tables.append(retrieval_table)

        if retrieval_tables:
            retrieval_table_all = pd.concat(retrieval_tables)

        # Rows retrieved so far for these search coordinates
        retrieval_table_time = retrieval_table_all[(retrieval_table_all['lat_search'] == lat_search) & 
                                                   (retrieval_table_all['lon_search'] == lon_search)]

        # Plot variations in time (a single row cannot show any variation)
        if len(retrieval_table_time) > 1:

            fig, ax = plt.subplots(figsize = (30, 5))
            fig.set_facecolor('w')

            ax.plot(retrieval_table_time['time'], retrieval_table_time['sensor_column'], color = 'red', label = sensor.upper())
            ax.plot(retrieval_table_time['time'], retrieval_table_time['model_column'], color = 'black', label = model.upper())

            ax.legend(loc='center left', bbox_to_anchor = (1, 0.5), prop = {'size': 25})

            if sensor == 'tropomi':
                ax.set_xlabel('Estimated time', fontsize = 25)

            elif sensor == 'iasi' or sensor == 'gome':
                ax.set_xlabel('Month', fontsize = 25)

            ax.tick_params(labelsize = 22)
            ax.set_ylabel(f'{component_nom} ({units})', fontsize = 25)
            ax.set_title(f'{component_nom} total column near ({lat_search}, {lon_search})', 
                         fontsize = 25, fontweight = 'bold', y = 1.05)

    retrieval_table_all = retrieval_table_all.set_index(['lat_search', 
                                                         'lon_search', 
                                                         'latitude', 
                                                         'longitude', 
                                                         'time'])

    return retrieval_table_all