# CAMS functions

In [None]:
def CAMS_download(dates, start_date, end_date, component, component_nom, model_full_name, model_level):

    """ Query and download the CAMS levels dataset from CDS API

        The GRIB file is saved as data/cams/<component_nom>/<CAMS_product_name> and is only
        requested from the CDS API when that file does not exist yet.

        Args:
            dates (arr): Query dates as 'YYYY-MM...' strings (only used by the monthly reanalysis)
            start_date (str): Query start date
            end_date (str): Query end date
            component (str): Component name
            component_nom (str): Component chemical nomenclature
            model_full_name (str): Full name of the CAMS model among:
            - 'cams-global-atmospheric-composition-forecasts' 
            - 'cams-global-reanalysis-eac4-monthly'
            model_level (str): Model levels (only used by the forecasts model):
            -  'Single' for total columns
            -  'Multiple' for levels

        Returns:
            CAMS_product_name (str): Product name of CAMS product
            CAMS_type (str): Model type:
            -  'Forecast'
            -  'Reanalysis'

        Raises:
            ValueError: If model_full_name or model_level is not one of the options above
    """

    forecast_model = 'cams-global-atmospheric-composition-forecasts'
    reanalysis_model = 'cams-global-reanalysis-eac4-monthly'
    period = start_date + '-' + end_date

    # Resolve product name and type first, so unsupported options raise a clear error
    # instead of reaching the return statement with undefined names
    if model_full_name == forecast_model:

        CAMS_type = 'Forecast'

        if model_level == 'Multiple':
            CAMS_product_name = component_nom + '-hourly-levels-' + period + '.grib'
        elif model_level == 'Single':
            CAMS_product_name = component_nom + '-hourly-tc-' + period + '.grib'
        else:
            raise ValueError("model_level must be 'Single' or 'Multiple', got " + repr(model_level))

    elif model_full_name == reanalysis_model:

        CAMS_type = 'Reanalysis'
        CAMS_product_name = component_nom + '-monthly-tc-' + period + '.grib'

    else:
        raise ValueError('Unsupported CAMS model: ' + repr(model_full_name))

    target = 'data/cams/' + component_nom + '/' + CAMS_product_name

    if os.path.isfile(target):
        print('The file exists, it will not be downloaded again.')
        return CAMS_product_name, CAMS_type

    print('The file does not exist, it will be downloaded.')

    if model_full_name == reanalysis_model:

        # Monthly means are requested by year and month; keep the distinct values in
        # first-seen order. The request covers every year x month combination, so it can
        # return months that were not in `dates`
        years = list(dict.fromkeys(date.split('-')[0] for date in dates))
        months = list(dict.fromkeys(date.split('-')[1] for date in dates))

        request = {
            'format': 'grib',
            'variable': 'total_column_' + component,
            'year': years,
            'month': months,
            'product_type': 'monthly_mean',
        }

    else:

        request = {
            'date': start_date + '/' + end_date,
            'type': 'forecast',
            'format': 'grib',
            'variable': component if model_level == 'Multiple' else 'total_column_' + component,
            'time': '00:00',
            'leadtime_hour': ['0', '12', '18', '6'],
        }

        if model_level == 'Multiple':
            # All 137 model levels, as strings '1'..'137'
            request['model_level'] = [str(level) for level in range(1, 138)]

    # The CDS client writes straight to `target`, so its folder must exist
    os.makedirs(os.path.dirname(target), exist_ok = True)

    # Create the client only when a download is needed
    c = cdsapi.Client()
    c.retrieve(model_full_name, request, target)

    return CAMS_product_name, CAMS_type

In [None]:
def CAMS_read(CAMS_product_name, component, component_nom, dates):

    """ Read CAMS levels dataset as xarray dataset object

        The data variable is renamed to 'component', monthly products are trimmed to the
        requested months, longitudes are mapped to [-180, 180) and both longitude and
        latitude are sorted ascending.

        Args:
            CAMS_product_name (str): Product name of CAMS product
            component (str): Component name
            component_nom (str): Component chemical nomenclature
            dates (arr): Query dates ('YYYY-MM...' strings)
            
        Returns:
            CAMS_ds (xarray): CAMS levels dataset in xarray format, with 'time' as a dimension
            dates (arr): The input dates; for monthly products, a tuple of the 'YYYY-MM'
            months that were both requested and present in the file
    """

    # Read as xarray dataset object
    CAMS_ds = xr.open_dataset('data/cams/' + component_nom + '/' + CAMS_product_name)

    # Change name to component (ozone uses its own variable names in the GRIB files)
    if 'levels' in CAMS_product_name:

        if component == 'ozone':
            CAMS_ds = CAMS_ds.rename({'go3': 'component'})

        else:
            CAMS_ds = CAMS_ds.rename({component_nom.lower(): 'component'})

    elif 'tc' in CAMS_product_name:
        
        if component == 'ozone':
            CAMS_ds = CAMS_ds.rename({'gtco3': 'component'})

        else:
            CAMS_ds = CAMS_ds.rename({'tc' + component_nom.lower(): 'component'})

    # A single time step is read as a scalar coordinate: promote it to a dimension before
    # selecting by time (drop_sel needs 'time' to be a dimension). Skip it when 'time'
    # already is a dimension, where expand_dims would raise
    if CAMS_ds.time.values.size == 1 and 'time' not in CAMS_ds.dims:
        CAMS_ds = CAMS_ds.expand_dims(dim = ['time'])

    # Remove data for months that were downloaded but not asked for: the monthly request
    # asks for every combination of the requested years and months
    if 'monthly' in CAMS_product_name:
        
        datetimes = []

        for date in dates:

            year = int(date.split('-')[0])
            month = int(date.split('-')[1])
            time_str = np.datetime64(datetime(year, month, 1, 0, 0, 0, 0))
            datetimes.append(time_str)

        # Drop datetimes
        datetimes_to_delete = np.setdiff1d(CAMS_ds.time.values, np.array(datetimes))
        CAMS_ds = CAMS_ds.drop_sel(time = datetimes_to_delete) 

        # Update dates for analysis
        dates_to_keep = np.intersect1d(CAMS_ds.time.values, np.array(datetimes))
        dates = tuple(dates_to_keep.astype('datetime64[M]').astype(str))

    # Change longitude coordinates from [0, 360) to [-180, 180)
    CAMS_ds = CAMS_ds.assign_coords(longitude = (((CAMS_ds.longitude + 180) % 360) - 180)).sortby('longitude')
    CAMS_ds = CAMS_ds.sortby('latitude')

    return CAMS_ds, dates

In [None]:
def CAMS_137_levels():

    """ Create table with information about the 137 CAMS levels

        Reads data/cams/137-levels.csv, drops its first row, indexes the table by the 'n'
        column (renamed 'hybrid') and adds the derived columns:
        - 'ph [Pa]': 'ph [hPa]' converted to Pa
        - 'ph-diff [Pa]': difference with the previous row's 'ph [Pa]' (first row: its own value)
        - 'Depth [m]': difference with the next row's 'Geopotential Altitude [m]'
          (last row: its own value)
    
        Returns:
            CAMS_levels_df (dataframe): Table with 137 CAMS levels data
    """

    # Read csv table with 137 levels
    CAMS_levels_df = pd.read_csv('data/cams/137-levels.csv')

    # Drop first row and set n as index hybrid
    # (presumably the n = 0 top-of-atmosphere row, whose non-numeric cells are why the
    # columns below need pd.to_numeric — verify against the csv)
    CAMS_levels_df = CAMS_levels_df.drop(0).reset_index(drop = True)
    CAMS_levels_df = CAMS_levels_df.set_index('n')
    CAMS_levels_df.index.names = ['hybrid']

    # Change important columns to numeric
    CAMS_levels_df['ph [Pa]'] = pd.to_numeric(CAMS_levels_df['ph [hPa]']) * 100
    CAMS_levels_df['Geopotential Altitude [m]'] = pd.to_numeric(CAMS_levels_df['Geopotential Altitude [m]'])
    CAMS_levels_df['Density [kg/m^3]'] = pd.to_numeric(CAMS_levels_df['Density [kg/m^3]'])

    # The boundary values are patched on standalone Series before assigning the columns:
    # writing through df['col'].iloc[...] is chained assignment, which does not update the
    # DataFrame under pandas Copy-on-Write

    # Calculate half pressures differences
    ph = CAMS_levels_df['ph [Pa]']
    ph_diff = ph.diff(1)
    ph_diff.iloc[0] = ph.iloc[0]
    CAMS_levels_df['ph-diff [Pa]'] = ph_diff

    # Calculate difference from geopotential altitude
    altitude = CAMS_levels_df['Geopotential Altitude [m]']
    depth = altitude.diff(-1)
    depth.iloc[-1] = altitude.iloc[-1]
    CAMS_levels_df['Depth [m]'] = depth

    return CAMS_levels_df

In [None]:
def CAMS_kg_kg_to_kg_m2(CAMS_ds, CAMS_levels_df, conversion_method):

    """ Convert the units of the CAMS partial columns for any component from kg/kg to kg/m2

        Args:
            CAMS_ds (xarray): CAMS levels dataset in xarray format (levels along 'hybrid',
            presumably labelled 1..N — the 'Complex' method selects by those labels)
            CAMS_levels_df (dataframe): Table with 137 CAMS levels data
            conversion_method (str): Type of conversion. It can be:
            * Simple: Multiply the partial columns by the layer depth and density
            * Complex: Calculate the partial column above each CAMS half level
        
        Returns:
            CAMS_ds (xarray): CAMS levels dataset in xarray format

        Raises:
            ValueError: If conversion_method is not 'Simple' or 'Complex'
    """

    if conversion_method == 'Simple':

        # Create xarray object from levels df (aligned with CAMS_ds on the 'hybrid' index)
        CAMS_levels_df_ds = CAMS_levels_df.to_xarray()

        # From kg/kg to kg/m3
        CAMS_ds = CAMS_ds * CAMS_levels_df_ds['Density [kg/m^3]']

        # From kg/m3 to kg/m2
        CAMS_ds = CAMS_ds * CAMS_levels_df_ds['Depth [m]']

    elif conversion_method == 'Complex':
        
        g = 9.81
        gi = 1/g #s2/m

        # Running partial columns, initialised to 0 at the top of the atmosphere.
        # drop=True removes the scalar 'hybrid' coordinate from each slice so every entry is
        # labelled only by the index given to the final concat
        PC_last = 0 * CAMS_ds.sel(hybrid = 1, drop = True)
        da_hybrid = [PC_last]

        for hybrid in range(1, CAMS_ds.hybrid.size):

            component = CAMS_ds.sel(hybrid = hybrid + 1, drop = True)

            # Units: Component(kg/kg) * ph-diff(Pa = kg/m*s2)) * s2/m -> To kg/m2
            # NOTE(review): the mixing ratio of level hybrid + 1 is paired with 'ph-diff' of
            # level hybrid. If ph-diff[n] is the pressure thickness of level n, this pairing
            # looks shifted by one level — confirm the intended convention
            PC_last = PC_last + component * CAMS_levels_df['ph-diff [Pa]'].loc[hybrid] * gi

            da_hybrid.append(PC_last)

        # Concatenate once (concatenating the growing list every iteration was quadratic)
        CAMS_ds = xr.concat(da_hybrid, pd.Index(range(1, len(da_hybrid) + 1), name = 'hybrid'))

    else:
        raise ValueError("conversion_method must be 'Simple' or 'Complex', got " + repr(conversion_method))

    return CAMS_ds

In [None]:
def CAMS_kg_m2_to_molecules_cm2(CAMS_ds, component_mol_weight):

    """ Convert a CAMS column amount from kg/m2 to molecules/cm2

        molecules/cm2 = kg/m2 * 1000 (g/kg) * N_A (molecules/mol)
                        / (molecular weight (g/mol) * 10000 (cm2/m2))

        Args:
            CAMS_ds (xarray): CAMS levels dataset in xarray format, in kg/m2
            component_mol_weight (float): Component molecular weight (g/mol, as implied
            by the 1000 g/kg factor)

        Returns:
            CAMS_ds (xarray): CAMS levels dataset in xarray format, in molecules/cm2
    """

    avogadro = 6.022*10**23  # molecules/mol
    in_molecules_per_m2_times_mw = CAMS_ds * avogadro * 1000
    return in_molecules_per_m2_times_mw / (10000 * component_mol_weight)

In [None]:
def CAMS_molecules_cm2_to_DU(CAMS_ds, component_mol_weight):

    """ Convert a CAMS column amount from molecules/cm2 to Dobson Units (DU), e.g. for ozone

        One Dobson Unit is taken as 2.69e16 molecules/cm2, so the column is divided by
        that amount.

        Args:
            CAMS_ds (xarray): CAMS levels dataset in xarray format, in molecules/cm2
            component_mol_weight (float): Component molecular weight. Not used by this
            conversion; presumably kept so the signature mirrors CAMS_kg_m2_to_molecules_cm2

        Returns:
            CAMS_ds (xarray): CAMS levels dataset in xarray format, in DU
    """

    molecules_cm2_per_DU = 2.69*10**16
    return CAMS_ds / molecules_cm2_per_DU

In [None]:
def CAMS_interpolation(CAMS_ds, TROPOMI_ds, bbox, component_nom):

    """ Interpolate the data in the coordinates of CAMS dataset for each level to a grid of 100x100 
        and show how it compares to TROPOMI dataset

        Every (step, hybrid) slice of the 'component' variable is linearly interpolated from
        the CAMS longitude/latitude grid onto 100 x 100 points spanning bbox (points outside
        the convex hull of the CAMS grid become NaN). A figure with the original and the
        interpolated field at hybrid index 136 and step index 2, with the TROPOMI grid
        overlaid, is then shown.

        Args:
            CAMS_ds (xarray): CAMS levels dataset in xarray format
            TROPOMI_ds (xarray): TROPOMI dataset in xarray format
            bbox (arr): Query bounding box as ((lon_start, lat_start), (lon_end, lat_end))
            component_nom (str): Component chemical nomenclature
        
        Returns:
            CAMS_ds (xarray): Interpolated CAMS data array with dims
            (valid_time, hybrid, lon, lat); 'hybrid' is relabelled 0..N-1
    """

    import scipy.spatial

    # Grid data from CAMS
    x = CAMS_ds.longitude.values
    y = CAMS_ds.latitude.values
    x_old, y_old = np.meshgrid(x, y)

    # Grid data in 100x100
    xi = np.linspace(bbox[0][0], bbox[1][0], 100)
    yi = np.linspace(bbox[0][1], bbox[1][1], 100)
    x_new, y_new = np.meshgrid(xi, yi)

    # The source points are the same for every level and step: triangulate them once and
    # reuse it. This is the same linear interpolation griddata performs, without re-running
    # the Delaunay triangulation for every slice
    triangulation = scipy.spatial.Delaunay(np.column_stack((x_old.flatten(), y_old.flatten())))

    da_step = []

    for step in range(CAMS_ds.step.size):

        da_hybrid = []

        for hybrid in range(CAMS_ds.hybrid.size):
            
            z = CAMS_ds.isel(hybrid = hybrid, step = step).component.values

            interpolator = scipy.interpolate.LinearNDInterpolator(triangulation, z.flatten())

            # Broadcasting xi along columns and yi along rows yields a (100, 100) array whose
            # first axis follows latitude (yi) and second axis longitude (xi)
            zi = interpolator(xi[None, :], yi[:, None])

            # Create data array for each layer
            # NOTE(review): axis 0 of zi follows latitude but is labelled 'lon' (axis 1 'lat');
            # left unchanged because downstream code may rely on these dim names — confirm
            da = xr.DataArray(data = xr.Variable(('lon', 'lat'), zi),
                            dims = ['lon', 'lat'],
                            coords = {'longitude': xr.Variable('lon', xi),
                                        'latitude': xr.Variable('lat', yi)
                                    }
            )

            # Append arrays for each layer
            da_hybrid.append(da)

        # Concatenate data arrays for all layers
        da_step.append(xr.concat(da_hybrid, pd.Index(range(CAMS_ds.hybrid.size), 
                                                     name = 'hybrid')))

    CAMS_ds_new = xr.concat(da_step, pd.Index(CAMS_ds.valid_time.values, 
                                            name = 'valid_time'))

    # VISUALIZATION

    z_1L = CAMS_ds.isel(hybrid = 136, step = 2).component.values
    zi_1L = CAMS_ds_new.isel(hybrid = 136, valid_time = 2).values

    fig, ax = plt.subplots(nrows = 1, ncols = 2, figsize = (20, 10))

    # Show old CAMS grid
    im1 = ax[0].scatter(x_old, y_old, c = z_1L, cmap = 'coolwarm', vmin = np.nanmin(z_1L), vmax = np.nanmax(z_1L))

    # Show contour plot of new CAMS data
    im2 = ax[1].contourf(x_new, y_new, zi_1L, cmap = 'coolwarm', vmin = np.nanmin(zi_1L), vmax = np.nanmax(zi_1L))

    # Show new CAMS grid
    ax[1].scatter(x_new, y_new, marker = 'o', c = 'grey', s = 3)

    # Show TROPOMI grid
    TROPOMI_lat = TROPOMI_ds['latitude'].values
    TROPOMI_lon = TROPOMI_ds['longitude'].values
    ax[1].scatter(TROPOMI_lon, TROPOMI_lat, marker = 'x', c = 'black', s = 30)

    # Add colorbars
    cbr1 = fig.colorbar(im1, ax = ax[0])
    cbr2 = fig.colorbar(im2, ax = ax[1])
    cbr1.set_label(f'{component_nom} (mol/m²)', fontsize = 18)
    cbr2.set_label(f'{component_nom} (mol/m²)', fontsize = 18)

    for i in range(2):
        
        ax[i].set_xlim([bbox[0][0], bbox[1][0]])
        ax[i].set_ylim([bbox[0][1], bbox[1][1]])
        ax[i].set_xlabel('Longitude', fontsize = 18)
        ax[i].set_ylabel('Latitude', fontsize = 18)
        ax[i].tick_params(labelsize = 16)

    ax[0].set_title('Original', fontsize = 20, pad = 20)
    ax[1].set_title('Interpolated', fontsize = 20, pad = 20)
    fig.suptitle('COMPONENT FOR CAMS AT HYBRID = 137 AT 12:00', fontsize = 22)
    plt.show()

    return CAMS_ds_new