# Flag summary development

## Environment set-up

In [90]:
import boto3
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
from io import BytesIO, StringIO

# New logger function
from merge_log_config import logger

# Silence warnings
import warnings
from shapely.errors import ShapelyDeprecationWarning

# Suppress warning categories that only add noise to the notebook output.
# ShapelyDeprecationWarning is raised when creating a Point object from
# coords; the cause has not been identified.
for _category in (RuntimeWarning, ShapelyDeprecationWarning):
    warnings.filterwarnings("ignore", category=_category)

# Figure resolution used for every plot in this notebook.
plt.rcParams.update({"figure.dpi": 300})

In [91]:
# Create the boto3 S3 handles used by this notebook.
# NOTE(review): no credentials are passed here, so boto3 presumably resolves
# them from the environment — confirm for the deployment in use.
s3 = boto3.resource("s3")
s3_cl = boto3.client("s3")  # for lower-level processes

# Name of the S3 bucket that the station data is read from and the flag
# summaries are written to.
bucket_name = "wecc-historical-wx"

In [3]:
def merge_ds_to_df(ds: "xr.Dataset") -> "tuple[pd.DataFrame, pd.Index, dict, dict]":
    """Converts xarray ds for a station to pandas df in the format needed for processing.

    Parameters
    ----------
    ds: xr.Dataset
        Data object with information about each network and station

    Returns
    -------
    df: pd.DataFrame
        Table object with information about each network and station. Duplicated
        index entries are dropped, the table is sorted by index, the first index
        level is dropped and the remaining level is reset to a column. A
        ``station`` column and the ``anemometer_height_m`` /
        ``thermometer_height_m`` columns are present.
    MultiIndex: pd.Index
        Index of the de-duplicated, sorted table (expected to be the
        station/time multi-index), to be used on conversion back to ds
    attrs: dict
        ds attributes, saved so the final merged file can inherit them
    var_attrs: dict
        Attributes of each data variable keyed by variable name, saved so the
        final merged file can inherit them
    """

    # Save attributes so the final merged file can inherit them
    attrs = ds.attrs
    var_attrs = {var: ds[var].attrs for var in ds.data_vars}

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=RuntimeWarning)
        df = ds.to_dataframe()

    # Save instrumentation heights
    for height_column in ("anemometer_height_m", "thermometer_height_m"):
        _fill_instrument_height(df, ds, height_column)

    # De-duplicate time axis
    df = df[~df.index.duplicated()].sort_index()

    # Save station/time multiindex
    multi_index = df.index
    df["station"] = df.index.get_level_values(0)

    # Convert time/station index to columns and reset index
    df = df.droplevel(0).reset_index()

    return df, multi_index, attrs, var_attrs


def _fill_instrument_height(df, ds, column):
    """Add ``column`` to ``df`` in place from the same-named ``ds`` attribute, or NaN."""
    if column in df.columns:
        return

    try:
        df[column] = np.ones(ds["time"].shape) * getattr(ds, column)
    except Exception:
        # Best effort: a missing, non-numeric or wrongly sized height must not
        # abort the conversion. Unlike a bare ``except``, this still lets
        # KeyboardInterrupt and SystemExit propagate.
        logger.info("Filling %s with NaN.", column)
        df[column] = np.ones(len(df)) * np.nan

## Function

### Native timestep

In [92]:
def flag_summary_native(df: pd.DataFrame, network: str, station: str) -> None:
    """
    Generates summary of flags set on all QAQC tests and saves it as a csv to S3.

    The summary is a table with one row per unique flag value and one column
    per QAQC variable (``_eraqc`` suffix removed). Each cell holds the number
    of observations of that variable carrying that flag, and is empty when the
    flag never occurs for the variable.

    NOTE: the percentage of flagged observations and the total number of
    observations per variable are not computed here.

    Parameters
    ----------
    df : pd.DataFrame
        station dataset converted to dataframe through QAQC pipeline
    network : str
        network name, used as a folder in the destination key
    station : str
        station name, used as a folder and in the file name of the destination key

    Returns
    -------
    None
        The table is uploaded with the module-level ``s3_cl`` client to the
        module-level ``bucket_name`` bucket.
    """
    # identify _eraqc variables
    eraqc_vars = [var for var in df.columns if "_eraqc" in var]

    # generate df of counts of each unique flag (rows) for each variable (columns)
    flag_counts = df[eraqc_vars].apply(pd.Series.value_counts)

    # rename columns; `rename` returns the renamed table, so the DataFrame is
    # kept (assigning the result of `columns.str.replace` would replace the
    # table with a bare Index, which has no `to_csv`)
    flag_counts = flag_counts.rename(
        columns=lambda name: name.replace("_eraqc", "")
    )

    # Save file to station bucket. The index holds the flag values, so it is
    # written out as a labelled first column; without it the counts cannot be
    # tied back to a flag.
    new_buffer = StringIO()
    flag_counts.to_csv(new_buffer, index_label="eraqc_flag_values")
    content = new_buffer.getvalue()
    # NOTE(review): other keys in this bucket use a "_wx" suffix (e.g.
    # "3_qaqc_wx"); confirm whether "4_merge_qx" is intended.
    s3_cl.put_object(
        Bucket=bucket_name,
        Body=content,
        Key="4_merge_qx/{}/{}/eraqc_flag_counts_{}.csv".format(
            network, station, station
        ),
    )

    return None

In [63]:
# Alternate test station (VALLEYWATER network), kept for quick switching:
# url = "s3://wecc-historical-wx/3_qaqc_wx/VALLEYWATER/VALLEYWATER_6001.zarr"
# Open the zarr store of one ASOSAWOS station from the 3_qaqc_wx folder.
url = "s3://wecc-historical-wx/3_qaqc_wx/ASOSAWOS/ASOSAWOS_72493023230.zarr"
ds = xr.open_zarr(url)

In [64]:
# Convert the station Dataset to a DataFrame; the original index and the
# dataset/variable attributes are returned alongside it.
df, MultiIndex, attrs, var_attrs = merge_ds_to_df(ds)

In [65]:
# Integer codes 1 through 37 (not referenced again in the cells shown here).
eraqc_flags = list(range(1,38))
# Names of the QA/QC flag columns, identified by the "_eraqc" substring.
eraqc_vars = [var for var in df.columns if "_eraqc" in var]

In [None]:
# filter df for only qaqc columns (the "_eraqc" columns selected above)
df_test = df[eraqc_vars]

In [85]:
# generate df of counts of each unique flag for each variable:
# rows are flag values, columns are the "_eraqc" variables, and a cell is NaN
# when that flag never occurs for that variable
csv = df_test.apply(pd.Series.value_counts)

In [86]:
# Display the flag-count table
csv

Unnamed: 0,elevation_eraqc,pr_eraqc,ps_altimeter_eraqc,ps_eraqc,psl_eraqc,sfcWind_dir_eraqc,sfcWind_eraqc,tas_eraqc,tdps_eraqc
21.0,,,,14553.0,22551.0,,,,
23.0,,,25.0,4.0,28.0,,,10.0,201.0
26.0,,,,,,,,99.0,332.0
27.0,,,,,,,16.0,,16.0
28.0,,,,,,,,,25.0


In [87]:
# rename columns
# NOTE: `.str.replace` returns a new Index. The result is only displayed here
# and is not assigned back to `csv.columns`, so `csv` itself is unchanged.

csv.columns.str.replace("_eraqc", "", regex=True)

Index(['elevation', 'pr', 'ps_altimeter', 'ps', 'psl', 'sfcWind_dir',
       'sfcWind', 'tas', 'tdps'],
      dtype='object')

In [88]:
# Display the table again: the column names still carry the "_eraqc" suffix
# because the renamed Index from the previous cell was never assigned back.
csv

Unnamed: 0,elevation_eraqc,pr_eraqc,ps_altimeter_eraqc,ps_eraqc,psl_eraqc,sfcWind_dir_eraqc,sfcWind_eraqc,tas_eraqc,tdps_eraqc
21.0,,,,14553.0,22551.0,,,,
23.0,,,25.0,4.0,28.0,,,10.0,201.0
26.0,,,,,,,,99.0,332.0
27.0,,,,,,,16.0,,16.0
28.0,,,,,,,,,25.0


### Hourly