# General functions

In [None]:
def comparison_check(sensor, model, component_nom, model_full_name):

    """ Validate that the requested sensor / model / component combination can be compared

        Each check that fails prints an explanation and aborts with KeyboardInterrupt.
        When every check passes, a confirmation message is printed.

        Args:
            sensor (str): Name of the sensor
            model (str): Name of the model
            component_nom (str): Component chemical nomenclature
            model_full_name (str): Full name of the CAMS model among:
            - 'cams-global-atmospheric-composition-forecasts' 
            - 'cams-global-reanalysis-eac4-monthly'

        Raises:
            KeyboardInterrupt: If the sensor-model pair, the model full name or the
            component is not supported
    """

    supported_models = ('cams-global-atmospheric-composition-forecasts',
                        'cams-global-reanalysis-eac4-monthly')

    # Components that each sensor is able to retrieve
    retrievable = {'tropomi': ['NO2', 'CO', 'O3', 'SO2', 'CH4'],
                   'iasi': ['CO', 'O3']}

    # 1. Only CAMS against TROPOMI or IASI is supported
    if not (model == 'cams' and (sensor == 'tropomi' or sensor == 'iasi')):

        print('The comparison is only possible for:')
        print('1. cams (CAMS model) vs. tropomi (TROPOMI sensor)')
        print('2. cams (CAMS model) vs. iasi (IASI sensor)')

        raise KeyboardInterrupt

    # 2. The CAMS dataset has to be one of the supported ones
    if model_full_name not in supported_models:

        print('ERROR: The model is not supported.')
        print('The models that are currently supported are:')
        for supported_model in supported_models:
            print('- ' + supported_model)

        raise KeyboardInterrupt

    # 3. The sensor has to be able to retrieve the component
    if component_nom not in retrievable[sensor]:

        print(f'ERROR: This specific component cannot be retrieved by the sensor {sensor.upper()}.')
        raise KeyboardInterrupt

    print('The comparison is possible and will start now.')

In [None]:
def components_table(sensor, component_nom):

    """ Look up the properties of a component (name, molecular weight, names in the sensor datasets)

        The properties of all the known components are gathered in a table and the row
        that matches the given nomenclature is read.

        Args:
            sensor (str): Name of the sensor
            component_nom (str): Component chemical nomenclature

        Returns:
            component (str): Component name
            component_mol_weight (float): Component molecular weight
            component_sensor_product (str): Component product name in TROPOMI database (None for IASI)
            component_sensor_column (str): Component column name in TROPOMI or IASI database
    """

    # One entry per component, in the same order in every list
    table = pd.DataFrame({
        'Nomenclature': ['NO2', 'CO', 'O3', 'SO2', 'CH4'],
        'Weight': [46.005, 28.01, 48, 64.066, 16.04],
        'Component': ['nitrogen_dioxide', 'carbon_monoxide', 'ozone', 'sulfur_dioxide', 'methane'],
        'TROPOMI_product': ['L2__NO2___', 'L2__CO____', 'L2__O3____', 'L2__SO2___', 'L2__CH4___'],
        'TROPOMI_column': ['nitrogendioxide_tropospheric_column',
                           'carbonmonoxide_total_column',
                           'ozone_total_vertical_column',
                           'sulfurdioxide_total_vertical_column',
                           'methane_tropospheric_column'],
        # '-' marks the components without an IASI column
        'IASI_column': ['-', 'COgridDAY', 'O3gridDAY', '-', '-'],
    })

    is_component = table['Nomenclature'] == component_nom

    def lookup(column_name):
        # First value of the column among the rows of the requested component
        return table[column_name].loc[is_component].iloc[0]

    component = lookup('Component')
    component_mol_weight = lookup('Weight')

    if sensor == 'tropomi':
        component_sensor_product = lookup('TROPOMI_product')

    elif sensor == 'iasi':
        component_sensor_product = None

    component_sensor_column = lookup(sensor.upper() + '_column')

    return component, component_mol_weight, component_sensor_product, component_sensor_column

In [None]:
def generate_folders(model, sensor, component_nom):

    """ Make sure that the download folders of the model and sensor datasets exist

        The folders data/<model>/<component_nom> and data/<sensor>/<component_nom> are created
        under the current working directory. Folders that already exist are left untouched.

        Args:
            model (str): Name of the model
            sensor (str): Name of the sensor
            component_nom (str): Component chemical nomenclature
    """

    base_dir = os.path.abspath('')

    # Build both paths before creating anything
    targets = [os.path.join(base_dir, 'data/' + source + '/' + component_nom)
               for source in (model, sensor)]

    for target in targets:
        os.makedirs(target, exist_ok = True)

In [None]:
def sensor_convert_units(sensor_ds, sensor_column, sensor):

    """ Convert the units of the sensor dataset for any component from mol/m2 to molecules/cm2

        The conversion only takes place when the units of the sensor column are 'mol m-2'.
        Otherwise a message and the units that were found are printed, and the dataset is
        returned unchanged. Sensors other than 'tropomi' and 'iasi' are returned unchanged
        without any message.

        Args:
            sensor_ds (xarray): sensor dataset in xarray format (TROPOMI or IASI)
            sensor_column (str): Name of sensor column in downloaded dataset
            sensor (str): Name of the sensor
            
        Returns:
            sensor_ds (xarray): sensor dataset in xarray format
    """

    def to_molecules_cm2(values):
        # mol/m2 -> molecules/cm2: 6.02214e23 molecules/mol * 1e-4 m2/cm2
        return values * 6.02214*10**19

    if sensor not in ('tropomi', 'iasi'):
        return sensor_ds

    if sensor_ds[sensor_column].units != 'mol m-2':

        # The same message is used for both sensors (the IASI branch used to mention the model)
        print('The sensor units could not be converted.')
        print(sensor_ds[sensor_column].units)
        return sensor_ds

    print('The sensor component units are mol m**-2. They will be converted to molecules cm**-2.')

    if sensor == 'tropomi':

        # Only the component column and, if present, the a-priori profile are converted
        sensor_ds[sensor_column] = to_molecules_cm2(sensor_ds[sensor_column])

        if 'apriori_profile' in list(sensor_ds.keys()):
            sensor_ds['apriori_profile'] = to_molecules_cm2(sensor_ds['apriori_profile'])

    else:

        # IASI: the whole dataset is multiplied by the factor
        sensor_ds = to_molecules_cm2(sensor_ds)

    return sensor_ds

In [None]:
def model_convert_units(model_ds, model, component_mol_weight):

    """ Bring the component of the model dataset from kg/kg or kg/m2 to molecules/cm2

        The GRIB units of the component decide which conversion is applied. Unknown units
        are reported and the dataset is returned unchanged, as it is for any model other
        than 'cams'.

        Note: the kg/kg conversion reads `model_levels` from the enclosing scope (it is not
        a parameter) and relies on CAMS_kg_kg_to_kg_m2 and CAMS_kg_m2_to_molecules_cm2,
        which are defined outside of this function.

        Args:
            model_ds (xarray): model dataset in xarray format (CAMS)
            model (str): Name of the model
            component_mol_weight (float): Component molecular weight
            
        Returns:
            model_ds (xarray): model dataset in xarray format
    """

    # Nothing to do for other models
    if model != 'cams':
        return model_ds

    grib_units = model_ds.component.GRIB_units

    if grib_units == 'kg kg**-1':

        # Two steps: kg/kg -> kg/m2 -> molecules/cm2
        print('The model component units are kg kg**-1. They will be converted to molecules cm**-2.')
        columns_ds = CAMS_kg_kg_to_kg_m2(model_ds, model_levels, 'Simple')
        model_ds = CAMS_kg_m2_to_molecules_cm2(columns_ds, component_mol_weight)

    elif grib_units == 'kg m**-2':

        print('The model component units are kg m**-2. They will be converted to molecules cm**-2.')
        model_ds = CAMS_kg_m2_to_molecules_cm2(model_ds, component_mol_weight)

    else:

        print('The model units could not be converted.')
        print(grib_units)

    return model_ds

In [None]:
def nearest_neighbour(array, value):

    """ Find index of the closest value in an array (it can be used to locate the nearest neighbours in space and time)

        The distances are computed in a single vectorised operation instead of a Python loop
        over the elements. In case of a tie, the first closest element wins.

        Args:
            array (arr): Array to find the nearest neighbour
            value (float): Search value

        Returns:
            index (int): Position of the element of array that is closest to value

        Raises:
            ValueError: If array is empty
    """

    distances = np.abs(np.asarray(array) - value)
    index = distances.argmin(0)
    
    return index

In [None]:
def pairwise(dates):

    """ Group consecutive dates two by two: (first, second), (third, fourth), ...

        The pairs do not overlap. A trailing date without a partner is dropped.

        Args:
            dates (arr): All dates

        Returns:
            period (list): List of (start, end) tuples
    """

    # The same iterator is consumed twice per pair, hence the non-overlapping pairs
    cursor = iter(dates)

    return [(start, end) for start, end in zip(cursor, cursor)]

In [None]:
def subset(ds, bbox):

    """ Subset any dataset (with latitude and longitude as coordinates) into desired bounding box.

        The grid points that are closest to the corners of the bounding box delimit the
        subset (both ends included). The indices are ordered before slicing, so the subset
        is not empty when a coordinate is stored in descending order.

        Args:
            ds (xarray): Dataset in xarray format
            bbox (arr): Query bounding box, read as ((lon_min, lat_min), (lon_max, lat_max))
    
        Returns:
            ds (xarray): Dataset in xarray format
    """

    # Get nearest longitude and latitude to bbox
    lon_indices = sorted((nearest_neighbour(ds.longitude.data, bbox[0][0]),
                          nearest_neighbour(ds.longitude.data, bbox[1][0])))
    lat_indices = sorted((nearest_neighbour(ds.latitude.data, bbox[0][1]),
                          nearest_neighbour(ds.latitude.data, bbox[1][1])))

    # Define slices (+ 1 to include the last index)
    slice_lon = slice(lon_indices[0], lon_indices[1] + 1)
    slice_lat = slice(lat_indices[0], lat_indices[1] + 1)

    # Set limits
    ds = ds.isel(longitude = slice_lon, latitude = slice_lat)

    return ds

In [None]:
def prepare_df(match_df, sensor):

    """ Get the dataframe ready for the merge

        - IASI: latitude and longitude are moved from the index to the columns.
        - TROPOMI: the sensor column and the column kernel of the rows with a qa_value
          under 0.5 are set to NaN, the index levels are renamed for CO, SO2 and O3, and
          the dataframe is averaged by layer, scanline, ground_pixel, time and delta_time.
          The given dataframe is modified in place before the averaging.
        - Any other sensor: the dataframe is returned as it is.

        Note: the TROPOMI branch reads `sensor_column` and `component_nom` from the
        enclosing scope (they are not parameters).

        Args:
            match_df (dataframe): Dataframe used to apply averaging kernels
            sensor (str): Name of the sensor
        
        Returns:
            match_df (dataframe): Dataframe used to apply averaging kernels
    """

    if sensor == 'iasi':
        return match_df.reset_index(level = ['latitude', 'longitude'])

    if sensor != 'tropomi':
        return match_df

    # Pass NaNs to data with qa_value under 0.5
    low_quality = match_df['qa_value'] < 0.5
    match_df.loc[low_quality, [sensor_column, 'column_kernel']] = float('NaN')

    # Rename the index levels (the other components keep their names)
    if component_nom in ('CO', 'SO2'):
        match_df.index.names = ['corner', 'ground_pixel', 'layer', 'scanline']

    elif component_nom == 'O3':
        match_df.index.names = ['corner', 'ground_pixel', 'layer', 'level', 'scanline']

    # Average over the levels that are not part of the grouping
    group_keys = ['layer', 'scanline', 'ground_pixel', 'time', 'delta_time']
    averaged_df = match_df.groupby(by = group_keys).mean()

    return averaged_df.reset_index(level = ['layer', 'delta_time'])

In [None]:
def plot_period(sensor_ds):

    """ Ask the user for the dates that will be plotted

        The user is asked whether specific dates are wanted. After a 'Yes' / 'yes', every
        date of the sensor dataset is offered one by one and only the accepted ones are
        kept. Any other answer (e.g. pressing Enter) selects all the dates.

        Args:
            sensor_ds (xarray): sensor dataset in xarray format (TROPOMI or IASI)

        Returns:
            plot_dates (arr): dates for which the datasets comparison will be shown
    """

    affirmative = ('Yes', 'yes')

    period_answer = input('Do you want to visualize the plots for specific dates? Write Yes or press Enter if you want to visualize all: ')

    if period_answer in affirmative:

        plot_dates = []
        options_df = pd.DataFrame({'Date': sensor_ds.time.values})

        # One question per available date
        for _, option in options_df.iterrows():
            date_answer = input('Do you want to show the plots for ' + str(option['Date']) + '? Yes or No: ')
            if date_answer in affirmative:
                plot_dates.append(option['Date'])

    else:
        plot_dates = sensor_ds.time.values

    print('The plots will be shown for the following dates:')
    print(plot_dates)

    return plot_dates

In [None]:
def colorbar_range(range_type, merge, array, *args):

    """ Define colorbar range

        Args:
            range_type (str): Range type for colorbar:
            -  'Original': Show original values in range
            -  'Positive': Show only positive values in range
            -  'Equal': Show same scale in range
            merge (xarray): Merge result for a specific time
            array (xarray): Component for a specific time and model/sensor
            *args: Dataset to share the scale with; only read when range_type is 'Equal'. Either:
            -  sensor_column (str): Name of the sensor column in merge
            -  model_total_ds (xarray), step (int), time: Total columns dataset of the model
               with the step index and the time to select

        Returns:
            vmin, vmax (float): Limits of color bar

        Raises:
            ValueError: If range_type is unknown or if the extra arguments for 'Equal'
            do not match one of the two supported forms
    """

    # The colorbar will show the original range
    if range_type == 'Original':

        return np.nanmin(array), np.nanmax(array)

    # The colorbar will show the original range only with positive values
    if range_type == 'Positive':

        lowest = np.nanmin(array)
        vmin = 0 if lowest < 0 else lowest

        return vmin, np.nanmax(array)

    # The colorbar will be in the same scale for both datasets
    if range_type == 'Equal':

        # Define arrays (taken from the arguments of the call, not from global variables)
        array_1 = merge.model_column

        if len(args) == 1:
            sensor_column = args[0]
            array_2 = merge[sensor_column]

        elif len(args) == 3:
            model_total_ds, step, time = args
            array_2 = model_total_ds.component.isel(step = step).sel(time = time)

        else:
            raise ValueError("The range type 'Equal' needs either the sensor column name or "
                             "the model total columns dataset, the step and the time.")

        min_1, max_1 = np.nanmin(array_1), np.nanmax(array_1)
        min_2, max_2 = np.nanmin(array_2), np.nanmax(array_2)

        # Define vmin: the minimum of the second array is only used if it is not negative
        vmin = min_2 if (min_2 < min_1 and min_2 >= 0) else min_1

        # Define vmax: largest of both maxima
        vmax = max_1 if max_2 < max_1 else max_2

        return vmin, vmax

    raise ValueError(f"Unknown range type '{range_type}'. Use 'Original', 'Positive' or 'Equal'.")

In [None]:
def visualize_pcolormesh(fig, axs, data_array, longitude, latitude, projection, color_scale, pad,
                         long_name, units, vmin, vmax, set_global = True, lonmin = -180, lonmax = 180, latmin = -90, latmax = 90):
    
    """ Draw one 2-dimensional field on a map axis, along with borders, coastlines and a colorbar

        The values under vmin are drawn transparent, which hides the cells filled with a
        very low placeholder value (the callers in this file fill the NaNs with -999).

        Args:
            fig: Figure
            axs: Axes of figure
            data_array (xarray): Variable values to plot - It must be 2-dimensional
            longitude (arr): Longitudes within data_array
            latitude (arr): Latitudes within data_array
            projection: Geographical projection
            color_scale (str): Color scale for the color bar
            pad (float): Padding for the subtitles
            long_name (str): Plot name
            units (str): Units of variable
            vmin, vmax (float): Limits of color bar
            set_global: Extent setting
            lonmin, lonmax, latmin, latmax (float): Limits of longitude and latitude values
    """

    # Work on a copy so that the registered colormap is not altered
    palette = copy(plt.get_cmap(color_scale))
    palette.set_under(alpha = 0)
    
    # The limits are given through the norm only: Matplotlib does not accept a norm together
    # with vmin / vmax, and the norm has to honour the requested vmin (it used to start at 0)
    im = axs.pcolormesh(
                        longitude, latitude, data_array, 
                        cmap = palette, 
                        transform = projection,
                        norm = colors.Normalize(vmin = vmin, vmax = vmax),
                        shading = 'auto'
                        )
                        
    axs.add_feature(cfeature.BORDERS, edgecolor = 'black', linewidth = 1)
    axs.add_feature(cfeature.COASTLINE, edgecolor = 'black', linewidth = 1)

    # Extent and labelled gridlines are only set for the plate carree projection
    if (projection == ccrs.PlateCarree()):
        axs.set_extent([lonmin, lonmax, latmin, latmax], projection)
        gl = axs.gridlines(draw_labels = True, linestyle = '--')
        gl.top_labels = False
        gl.right_labels = False
        gl.xformatter = LONGITUDE_FORMATTER
        gl.yformatter = LATITUDE_FORMATTER
        gl.xlabel_style = {'size': 16}
        gl.ylabel_style = {'size': 16}

    if(set_global):
        axs.set_global()
        axs.gridlines()

    axs.set_title(long_name, fontsize = 18, pad = pad)
    axs.tick_params(labelsize = 14)

    cbr = fig.colorbar(im, ax = axs, extend = 'both', orientation = 'horizontal', fraction = 0.05, pad = 0.15)   
    cbr.set_label(units, fontsize = 16)
    cbr.ax.tick_params(labelsize = 14)
    cbr.ax.xaxis.get_offset_text().set_fontsize(14)

In [None]:
def visualize_model_vs_sensor(model, sensor, component_nom, merged_df, plot_dates, bbox, pad, y, model_type, sensor_type, range_type):

    """ Show, for every selected date, three maps of the study area side by side:
        the model columns, the sensor columns and their differences

        Note: the name of the sensor column is read from `sensor_column` in the enclosing
        scope (it is not a parameter).

        Args:
            model (str): Name of the model
            sensor (str): Name of the sensor
            component_nom (str): Component chemical nomenclature
            merged_df (dataframe): Merge result 
            plot_dates (arr): All selected dates to plot
            bbox (arr): Query bounding box
            pad (float): Padding for the subtitles
            y (float): y-position of main title
            model_type (str): Model type:
            -  'Forecast'
            -  'Reanalysis'
            sensor_type (str): Sensor type ('NRT')
            range_type (str): Range type for colorbar:
            -  'Original': Show original values in range
            -  'Positive': Show only positive values in range
            -  'Equal': Show same scale in range
    """

    units = component_nom + ' (molecules/cm2)'
    projection = ccrs.PlateCarree()

    def draw_panel(fig, ax, panel, title, vmin, vmax):
        # All the maps share the projection, the color scale and the extent (NaNs are filled with -999)
        visualize_pcolormesh(
                            fig = fig, axs = ax,
                            data_array = panel.fillna(-999),
                            longitude = panel.longitude,
                            latitude = panel.latitude,
                            projection = ccrs.PlateCarree(),
                            color_scale = 'coolwarm',
                            pad = pad,
                            long_name = title,
                            units = units,
                            vmin = vmin, 
                            vmax = vmax, 
                            set_global = False,
                            lonmin = bbox[0][0],
                            lonmax = bbox[1][0],
                            latmin = bbox[0][1],
                            latmax = bbox[1][1]
                            )

    for time in plot_dates:

        fig, axs = plt.subplots(1, 3, figsize = (20, 4), subplot_kw = {'projection': projection})

        # The query resolves @time from the local variable of this loop
        merge = merged_df.query('time == @time').to_xarray()
        snapshot = merge.sel(time = time)
        merge = snapshot.assign_coords(latitude = snapshot.latitude, longitude = snapshot.longitude)

        # Left map - Model
        model_panel = merge.model_column
        vmin, vmax = colorbar_range(range_type, merge, model_panel, sensor_column)
        draw_panel(fig, axs[0], model_panel, model.upper() + ' (' + model_type + ')', vmin, vmax)

        # Central map - Sensor
        sensor_panel = merge[sensor_column]
        vmin, vmax = colorbar_range(range_type, merge, sensor_panel, sensor_column)
        draw_panel(fig, axs[1], sensor_panel, sensor.upper() + ' (' + sensor_type + ')', vmin, vmax)

        # Right map - Differences (always in their own range)
        difference_panel = merge.difference
        draw_panel(fig, axs[2], difference_panel, 'Differences plot',
                   np.nanmin(difference_panel), np.nanmax(difference_panel))

        # Main title: estimated time for TROPOMI, month for IASI
        if sensor == 'tropomi':
            fig.suptitle(f'DISTRIBUTION OF {component_nom} (Estimated time: {time})',
                    fontsize = 18, y = y)
        
        elif sensor == 'iasi':
            month = time.astype('datetime64[M]')
            fig.suptitle(f'DISTRIBUTION OF {component_nom} (Month: {month})',
                        fontsize = 18, y = y)

        plt.show()

In [None]:
def visualize_model_original_vs_calculated(model, component_nom, merged_df, model_total_ds, plot_dates, bbox, pad, y, model_type, range_type):

    """ Show, for every selected date, two maps of the study area side by side:
        the model total columns that were calculated (merge result) and the ones of the
        original total columns dataset

        The original total columns are always read at the step of index 2.

        Args:
            model (str): Name of the model
            component_nom (str): Component chemical nomenclature
            merged_df (dataframe): Merge result
            model_total_ds (xarray): CAMS total columns dataset in xarray format
            plot_dates (arr): All selected dates to plot
            bbox (arr): Query bounding box
            pad (float): Padding for the subtitles
            y (float): y-position of main title
            model_type (str): Model type:
            -  'Forecast'
            -  'Reanalysis'
            range_type (str): Range type for colorbar:
            -  'Original': Show original values in range
            -  'Positive': Show only positive values in range
            -  'Equal': Show same scale in range
    """

    units = component_nom + ' (molecules/cm2)'
    projection = ccrs.PlateCarree()

    def draw_panel(fig, ax, panel, title, vmin, vmax):
        # Both maps share the projection, the color scale and the extent (NaNs are filled with -999)
        visualize_pcolormesh(
                            fig = fig, axs = ax,
                            data_array = panel.fillna(-999),
                            longitude = panel.longitude,
                            latitude = panel.latitude,
                            projection = ccrs.PlateCarree(),
                            color_scale = 'coolwarm',
                            pad = pad,
                            long_name = title,
                            units = units,
                            vmin = vmin, 
                            vmax = vmax, 
                            set_global = False,
                            lonmin = bbox[0][0],
                            lonmax = bbox[1][0],
                            latmin = bbox[0][1],
                            latmax = bbox[1][1]
                            )

    for time in plot_dates:

        fig, axs = plt.subplots(1, 2, figsize = (20, 5), subplot_kw = {'projection': projection})

        # The query resolves @time from the local variable of this loop
        merge = merged_df.query('time == @time').to_xarray()
        snapshot = merge.sel(time = time)
        merge = snapshot.assign_coords(latitude = snapshot.latitude, longitude = snapshot.longitude)

        step = 2
        model_label = model.upper() + ' (' + model_type + ')'

        # Left map - Calculated total columns
        calculated_panel = merge.model_column
        vmin, vmax = colorbar_range(range_type, merge, calculated_panel, model_total_ds, step, time)
        draw_panel(fig, axs[0], calculated_panel, 'CALCULATED TOTAL COLUMNS ' + model_label, vmin, vmax)

        # Right map - Original total columns
        original_panel = model_total_ds.component.isel(step = step).sel(time = time)
        vmin, vmax = colorbar_range(range_type, merge, original_panel, model_total_ds, step, time)
        draw_panel(fig, axs[1], original_panel, 'ORIGINAL TOTAL COLUMNS ' + model_label, vmin, vmax)

        fig.suptitle(f'DISTRIBUTION OF {component_nom} (Estimated time: {time})',
                    fontsize = 18, y = y)
        plt.show()

In [None]:
def scatter_plot(merged_df, component_nom, plot_dates, y):

    """ Scatter plot between the model and sensor datasets in the study area for the selected dates

        For every date, the sensor values (x axis) are plotted against the model values
        (y axis), a linear regression is fitted, its equation and coefficient of
        determination are printed, and the fitted line is drawn in red. The regression only
        uses the rows where both values are available.

        Note: the name of the sensor column is read from `sensor_column` in the enclosing
        scope (it is not a parameter).

        Args:
            merged_df (dataframe): Merge result
            component_nom (str): Component chemical nomenclature
            plot_dates (arr): All selected dates to plot
            y (float): y-position of main title
    """

    for time in plot_dates:
        
        # The query resolves @time from the local variable of this loop
        merge = merged_df.query('time == @time')
        plt.scatter(merge[sensor_column].values, merge['model_column'].values, color = 'black', s = 5)
        plt.title(f'{component_nom} (Estimated time: {time})', fontsize = 18, y = y)
        plt.xlabel(f'Sensor {component_nom} (molecules/cm2)', fontsize = 16)
        plt.ylabel(f'Model {component_nom} (molecules/cm2)', fontsize = 16)

        # The regression cannot be fitted on missing values: keep the complete pairs only
        paired = merge[[sensor_column, 'model_column']].dropna()
        X = paired[sensor_column].values.reshape(-1, 1) 
        Y = paired['model_column'].values.reshape(-1, 1) 
        reg = LinearRegression().fit(X, Y)

        # coef_ and intercept_ are one-element arrays: .item() extracts the scalar
        # (float() on arrays with dimensions is deprecated in NumPy)
        slope = reg.coef_.item()
        intercept = reg.intercept_.item()

        print(f'Estimated time: {time}')
        print(f'Fit equation: Model {component_nom} = Sensor {component_nom} * {slope:.2f} + ({intercept:.2E})')
        print(f'Coefficient of determination (R2): {reg.score(X, Y):.2f}')
        
        # Fitted line over the range of the sensor values
        fit_X = np.linspace(np.nanmin(X), np.nanmax(X), 10)
        fit_Y = fit_X * slope + intercept
        plt.plot(fit_X, fit_Y, color = 'red')
        plt.show()