In [None]:
import duckdb
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Union, List, Iterable
import shutil

In [None]:
def get_schema_df(path: Union[Path, str]) -> pd.DataFrame:
    """Read the schema of a parquet file with DuckDB's ``parquet_schema``.

    Only schema entries whose ``repetition_type`` is ``'OPTIONAL'`` are kept.

    Parameters
    ----------
    path : Path or str
        Parquet file to inspect.

    Returns
    -------
    pd.DataFrame
        Columns ``name``, ``type``, ``converted_type`` and ``logical_type``,
        one row per matching schema entry.
    """
    # Escape single quotes so a path such as "o'neil.parquet" cannot terminate
    # the SQL string literal early.
    path_literal = str(path).replace("'", "''")
    df = duckdb.query(f"""
        SELECT name, type, converted_type, logical_type
        FROM parquet_schema('{path_literal}')
        WHERE repetition_type = 'OPTIONAL'
        ;
        """).df()
    return df

In [None]:
def validate_schema(df: pd.DataFrame, schema: List[dict]):
    """Compare a parquet schema table against an expected column specification.

    Parameters
    ----------
    df : pd.DataFrame
        One row per column, with at least ``name``, ``type`` and
        ``converted_type`` columns (the shape ``get_schema_df`` returns).
    schema : list of dict
        Expected columns. Each dict has ``name``, ``type`` and
        ``converted_type`` keys. Any iterable of such dicts is accepted.

    Returns
    -------
    (bool, list of str)
        ``valid`` is True only when no problem was found. ``messages``
        describes every problem that was found.
    """
    # Materialize so len() works even if a generator/iterable was passed.
    schema = list(schema)

    def _normalize(value):
        # A NULL may come back as None or as NaN depending on the column dtype;
        # treat both as None so they match a ``None`` in the specification.
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return None
        return value

    messages = []
    if len(df) != len(schema):
        messages.append(f"schema should have {len(schema)} fields but has {len(df)}.")

    for vd in schema:
        # get row by name
        row = df[df['name'] == vd['name']]
        if len(row) < 1:
            messages.append(f"No field named {vd['name']}.")
            # Nothing to compare types against; indexing [0] below would fail.
            continue
        if len(row) > 1:
            messages.append(f"More than one field named {vd['name']}.")

        # Compare against the first matching field.
        rec = row.to_dict(orient='records')[0]

        if _normalize(rec['type']) != _normalize(vd['type']):
            messages.append(f"{vd['name']} has type {rec['type']} but should be {vd['type']}.")

        if _normalize(rec['converted_type']) != _normalize(vd['converted_type']):
            messages.append(f"{vd['name']} has converted_type {rec['converted_type']} but should be {vd['converted_type']}.")

    # Every failed check appends a message, so "valid" means "no messages".
    return not messages, messages

In [None]:
def validate_path(path: Union[Path, str], pattern: str, schema: dict):
    """Validate one parquet file, or every match of ``pattern`` under a directory.

    Each file's schema is read with ``get_schema_df`` and checked with
    ``validate_schema``. The first failing file stops the walk.

    Raises
    ------
    ValueError
        Args are the message ``"File: <path> not valid."`` and the list of
        problems reported by ``validate_schema``.
    """
    root = Path(path)
    candidates = root.glob(pattern) if root.is_dir() else [root]

    for candidate in candidates:
        passed, messages = validate_schema(get_schema_df(candidate), schema)
        if passed:
            continue
        raise ValueError(f"File: {str(candidate)} not valid.", messages)

In [None]:
def attempt_cast_to(path: Union[str, Path], backup_file: bool = True, cleanup: bool = False):
    """Attempt to fix a parquet file by casting its columns to the timeseries types.

    The file at ``path`` is rewritten with these column types:
    - ``location_id``, ``variable_name``, ``measurement_unit`` and
      ``configuration`` become VARCHAR.
    - ``value_time`` and ``reference_time`` become TIMESTAMP.
    - ``value`` becomes FLOAT.

    Columns not listed in the SELECT are dropped. This only works locally,
    where the user has write permissions.

    Parameters
    ----------
    path : str or Path
        Parquet file to rewrite in place.
    backup_file : bool, default True
        Copy the original file to ``<path>.bak`` before rewriting it.
    cleanup : bool, default False
        After rewriting, delete ``<path>.bak`` if it exists and the rewritten
        file has the same row count and the same number of distinct
        ``location_id`` values as the backup. Otherwise the backup is kept.

    Raises
    ------
    FileNotFoundError
        If ``path`` is not an existing file.
    """
    path = Path(path)
    if not path.is_file():
        raise FileNotFoundError(path)

    backup_path = Path(f"{path}.bak")
    # The cast copy is written here first and then moved over `path`, so the
    # source file is never read and overwritten by the same COPY statement.
    tmp_path = Path(f"{path}.tmp")

    def _sql_literal(p: Path) -> str:
        # Quote a path as a SQL string literal, escaping embedded single quotes.
        return "'" + str(p).replace("'", "''") + "'"

    # operation could be destructive. back up file
    if backup_file:
        shutil.copy(path, backup_path)

    try:
        duckdb.query(f"""
            COPY (
                SELECT
                    location_id::varchar as location_id
                    , value_time::timestamp as value_time
                    , reference_time::timestamp as reference_time
                    , value::float as value
                    , variable_name::varchar as variable_name
                    , measurement_unit::varchar as measurement_unit
                    , configuration::varchar as configuration
                FROM read_parquet({_sql_literal(path)})
            ) TO {_sql_literal(tmp_path)} (FORMAT PARQUET);
        """)
        # tmp_path is in the same directory, so this is an os.replace-style rename.
        tmp_path.replace(path)
    finally:
        # Remove any partial output left behind by a failed COPY.
        if tmp_path.exists():
            tmp_path.unlink()

    # Only drop the backup when the rewrite preserved the row count and the
    # number of distinct locations. Skip entirely if there is no backup to compare.
    if cleanup and backup_path.exists():
        counts = duckdb.query(f"""
            SELECT
                (SELECT count(*) FROM read_parquet({_sql_literal(path)})) as new_rows
                , (SELECT count(*) FROM read_parquet({_sql_literal(backup_path)})) as org_rows
                , (SELECT count(distinct(location_id)) FROM read_parquet({_sql_literal(path)})) as new_locations
                , (SELECT count(distinct(location_id)) FROM read_parquet({_sql_literal(backup_path)})) as org_locations
        ;""").to_df()

        same_rows = counts["new_rows"][0] == counts["org_rows"][0]
        same_locations = counts["new_locations"][0] == counts["org_locations"][0]
        if same_rows and same_locations:
            backup_path.unlink()

In [None]:
def fix_path_schema(path: Union[Path, str], pattern: str, backup_file: bool = True, cleanup: bool = False):
    """Run ``attempt_cast_to`` on one parquet file or on every match under a directory.

    Parameters
    ----------
    path : Path or str
        A parquet file, or a directory to search with ``pattern``.
    pattern : str
        Glob pattern applied when ``path`` is a directory (e.g. ``'**/*.parquet'``).
    backup_file, cleanup : bool
        Passed through to ``attempt_cast_to``.
    """
    path = Path(path)
    if path.is_dir():
        # Snapshot the matches before rewriting anything. attempt_cast_to
        # creates ``.bak`` backups in these directories, and entries added while
        # a lazy glob is still iterating may or may not be yielded.
        files = sorted(path.glob(pattern))
    else:
        files = [path]

    for file in files:
        attempt_cast_to(file, backup_file=backup_file, cleanup=cleanup)

In [None]:
class TimeseriesLocalPath():
    """One or more local parquet paths expected to follow the timeseries schema.

    Each path may be a file or a directory. Directories are expanded with
    ``pattern``, a glob that defaults to ``'**/*.parquet'``.
    """

    def __init__(self, path: Union[Path, str, List[Union[Path, str]]], pattern: str = '**/*.parquet'):
        self.path = path
        self.pattern = pattern
        self.path_patterns = self._get_path_patterns()

        # Expected parquet physical `type` and `converted_type` for each column,
        # in the format consumed by validate_schema.
        self.timeseries_schema = [
            {
                "name": "location_id",
                "type": "BYTE_ARRAY",
                "converted_type": "UTF8"
            },
            {
                "name": "reference_time",
                "type": "INT64",
                "converted_type": "TIMESTAMP_MICROS"
            },
            {
                "name": "value_time",
                "type": "INT64",
                "converted_type": "TIMESTAMP_MICROS"
            },
            {
                "name": "value",
                "type": "FLOAT",
                "converted_type": None
            },
            {
                "name": "variable_name",
                "type": "BYTE_ARRAY",
                "converted_type": "UTF8"
            },
            {
                "name": "measurement_unit",
                "type": "BYTE_ARRAY",
                "converted_type": "UTF8"
            },
            {
                "name": "configuration",
                "type": "BYTE_ARRAY",
                "converted_type": "UTF8"
            },
        ]

    def _is_multi(self) -> bool:
        """True when ``self.path`` holds several paths (a list or tuple)."""
        return isinstance(self.path, (list, tuple))

    def _path_list(self) -> List[Union[Path, str]]:
        """Return ``self.path`` as a list, whether one path or several were given."""
        return list(self.path) if self._is_multi() else [self.path]

    def _pattern_for(self, p: Union[Path, str]) -> str:
        """Return ``<dir>/<pattern>`` for a directory, or the path itself for anything else."""
        p = Path(p)
        return str(Path(p, self.pattern)) if p.is_dir() else str(p)

    def _get_path_patterns(self) -> str:
        """Build the string returned by ``__str__``.

        A single path gives one path/glob string. Several paths give the Python
        ``str()`` of a list of such strings, e.g. ``"['a/**/*.parquet', 'b.parquet']"``.
        NOTE(review): presumably so it can be spliced into a DuckDB
        ``read_parquet([...])`` query — confirm with callers.
        """
        if self._is_multi():
            return str([self._pattern_for(p) for p in self.path])
        return self._pattern_for(self.path)

    def validate(self):
        """Check every file against ``timeseries_schema``; raises ValueError on the first failure."""
        for p in self._path_list():
            validate_path(p, self.pattern, self.timeseries_schema)

    def fix_schema(self, backup_file: bool = True, cleanup: bool = False):
        """Rewrite every file with the expected column types (see ``attempt_cast_to``)."""
        for p in self._path_list():
            fix_path_schema(p, self.pattern, backup_file=backup_file, cleanup=cleanup)

    def __str__(self):
        return self.path_patterns

    def __repr__(self):
        return f"{type(self).__name__}(path={self.path!r}, pattern={self.pattern!r})"

In [None]:
# Rewrite the baseline parquet files with the expected column types, then
# re-check their schema. fix_schema keeps a `.bak` copy of each file; with
# cleanup=True that copy is removed once row/location counts match.
timeseries_dir = '/data/common/baselines/nwm30_retrospective_conus/streamflow_hourly_inst/'
ts = TimeseriesLocalPath(timeseries_dir)
print(ts)
ts.fix_schema(cleanup=True)
ts.validate()

In [None]:
# Distinct `variable_name` values present in usgs_2016.parquet.
variable_names_query = duckdb.query("""
SELECT distinct(variable_name)
FROM read_parquet('/data/playground/retro_demo/retro/timeseries/usgs_2016.parquet');""")
df = variable_names_query.to_df()
df

In [None]:
# `variable_name_domains` is not defined anywhere else in this notebook, so this
# cell raised NameError on a fresh kernel. Define the allowed values here.
# TODO(review): confirm the allowed variable names; "streamflow_hourly_inst" is
# inferred from the `streamflow_hourly_inst` baseline directory name.
variable_name_domains = ["streamflow_hourly_inst"]
# True when at least one allowed variable name appears in the file.
len(set(variable_name_domains).intersection(set(df["variable_name"]))) >= 1

In [None]:
# Distinct `measurement_unit` values present in usgs_2016.parquet.
measurement_units_query = duckdb.query("""
SELECT distinct(measurement_unit)
FROM read_parquet('/data/playground/retro_demo/retro/timeseries/usgs_2016.parquet');""")
df = measurement_units_query.to_df()
df

In [None]:
# Allowed values for the `measurement_unit` column.
measurement_unit_domains = [
    "cms",
]

In [None]:
# True when at least one allowed measurement unit appears in the file.
not set(measurement_unit_domains).isdisjoint(df["measurement_unit"])