# Meta Analysis Step 2: AWS Platform Processing

This notebook processes AWS (Athena) database tables and extracts metadata.

**Purpose:**
- Connect to AWS Athena database (no credential renewal - uses IAM role)
- Load PCDS metadata from Step 1
- Extract AWS table metadata (columns, data types)
- Process row counts and date ranges
- Save results as Parquet/JSON for Step 3

**Inputs:**
- `pcds_metadata.json` - Metadata from Step 1

**Outputs:**
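- `aws_column_info_<table>.parquet` / `aws_date_counts_<table>.parquet` - Per-table column metadata and date-wise row counts
- `aws_metadata_<table>.json` - Per-table status, SQL engine state and row count
- `aws_summary.json` - Index of the per-table files and their pull status
- `aws_metadata.json` - Structured metadata (PCDS + AWS) for next steps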

**Note:** This runs on Linux/remote server with IAM role. No AWS credential renewal needed.

## Cell 1: Import Required Libraries

In [None]:
import re
import os
import csv
import json
import shutil
import pickle
import argparse
import warnings
import numpy as np
import pandas as pd
import pyathena as pa
import pandas.io.sql as psql

from upath import UPath
from loguru import logger
from tqdm import tqdm
from datetime import datetime, timedelta, timezone
from dataclasses import dataclass, field, fields, is_dataclass
from configparser import ConfigParser
from confection import Config
from unittest import mock
from enum import Enum
from typing import Literal, Dict, List
from collections import defaultdict, abc

warnings.filterwarnings("ignore", category=UserWarning, message='.*pandas only supports SQLAlchemy connectable.*')

# Note: This notebook uses Parquet format for cross-platform compatibility
# Install pyarrow if needed: pip install pyarrow or conda install -c conda-forge pyarrow
import pyarrow

## Cell 2: Constants and Configuration

In [None]:
# --- Global Constants ---
SEP = '; '  # separator used when joining column names / data types into one string
AWS_DT_FORMAT = '%Y-%m-%d'  # date format stored on SQLengine as ``_format``
TODAY = datetime.now()  # naive local timestamp taken once at import; not referenced in the visible code
ONEDAY = timedelta(days=1)  # not referenced in the visible code
WIDTH = 80  # number of '=' characters in the log separator line
NO_DATE = 'no_date_provided'  # sentinel value of ``aws_dt`` meaning "no date column to filter on"

class PullStatus(Enum):
    """Enumeration for data pull status codes.

    The member *values* are the human-readable strings written to the
    per-table and summary JSON files.
    """
    NONEXIST_AWS = 'Nonexisting AWS Table'
    NONDATE_AWS = 'Nonexisting Date Variable in AWS'
    EMPTY_AWS = 'Empty AWS Table'
    NO_MAPPING = 'Column Mapping Not Provided'  # not referenced in the visible code
    SUCCESS = 'Successful Data Access'

# --- SQL Templates for AWS (Athena) ---
# All templates are filled with str.format(); the values are interpolated
# verbatim (no escaping), so they must come from trusted configuration.

# Column names and data types of one table; placeholders: {db}, {table}.
AWS_SQL_META = """
select column_name, data_type from information_schema.columns
where table_schema = LOWER('{db}') and table_name = LOWER('{table}')
"""

# Total row count; placeholders: {db}, {table}, {limit} (WHERE condition).
# A line consisting only of "where" is removed later by SQLengine.clean_query.
AWS_SQL_NROW = """
SELECT COUNT(*) AS nrow FROM {db}.{table}
where {limit}
"""

# Row count per date value; placeholders: {date}, {db}, {table}, {limit}.
AWS_SQL_DATE = """
SELECT {date}, count(*) AS nrows
FROM {db}.{table} 
WHERE {limit}
GROUP BY {date}
"""

## Cell 3: Exception Classes and Data Types

In [None]:
# --- Custom Exceptions ---
class NONEXIST_TABLE(Exception):
    """Signals that the requested database table/view does not exist."""

class NONEXIST_DATEVAR(Exception):
    """Signals that no date-like variable is available for the table."""

# --- Helper Functions for Configuration Reading ---
def read_str_lst(lst_str, sep='\n'):
    """Split a delimited string into its non-empty pieces.

    Leading/trailing whitespace of the whole string is stripped before
    splitting; the individual pieces are returned untouched.
    """
    pieces = lst_str.strip().split(sep)
    return list(filter(None, pieces))

def read_dstr_lst(dct_str, sep='='):
    """Turn newline-separated ``key<sep>value`` lines into a dictionary.

    Each line is split on the first occurrence of ``sep`` only; keys and
    values are whitespace-stripped in the result.
    """
    raw_pairs = dict(entry.split(sep, 1) for entry in read_str_lst(dct_str))
    cleaned = {}
    for raw_key, raw_value in raw_pairs.items():
        cleaned[raw_key.strip()] = raw_value.strip()
    return cleaned

# --- Base Type Class ---
class BaseType:
    """Mixin for config dataclasses: builds nested dataclasses and pretty-prints.

    Subclasses are expected to be decorated with ``@dataclass`` so that
    ``__post_init__`` runs after field assignment.
    """

    def __post_init__(self):
        """Replace mapping values of dataclass-typed fields with instances.

        A field whose declared type is a dataclass and whose value is a
        mapping is rebuilt as ``FieldType(**value)``. Values that already
        are instances of the declared type are kept as they are, so a
        config object can also be assembled from ready-made parts.
        """
        for _field in fields(self):
            if not is_dataclass(_field.type):
                continue
            current = getattr(self, _field.name)
            if isinstance(current, _field.type):
                # Already constructed; ``Type(**instance)`` would raise TypeError.
                continue
            setattr(self, _field.name, _field.type(**current))

    def tolog(self, indent=1, padding=''):
        """Return a multi-line ``ClassName(field=value, ...)`` string for logging.

        Args:
            indent: Number of tab characters added per nesting level; also
                passed to ``pprint.pformat`` for dictionary values.
            padding: Prefix inherited from the enclosing object.
        """
        import pprint as pp

        def get_val(x, pad):
            if isinstance(x, BaseType):
                return x.tolog(indent, pad)
            elif isinstance(x, dict):
                return pp.pformat(x, indent)
            else:
                return repr(x)

        cls_name = self.__class__.__name__
        padding = padding + '\t' * indent
        fields_str = [f'{padding}{k}={get_val(v, padding)}' for k, v in vars(self).items()]
        return f'{cls_name}(\n' + ',\n'.join(fields_str) + '\n)'

# --- Configuration Dataclasses ---
@dataclass
class MetaRange:
    """Range configuration for row selection"""
    start_rows: int | None
    end_rows: int | None

    def __iter__(self):
        yield from [self.start_rows or 1, self.end_rows or float('inf')]

@dataclass
class MetaInput(BaseType):
    """Input configuration (the ``input`` section of the config file)."""
    name: str  # analysis name; parse_config overrides it with --name
    step: str  # used by parse_config to name the copied config file
    env: str  # not referenced in the visible code
    range: MetaRange  # row window; a mapping here is converted by BaseType.__post_init__
    category: Literal['loan', 'dpst']
    clear_cache: bool = True  # not referenced in the visible code

@dataclass
class MetaCSV:
    """CSV output settings: target file plus the column list."""
    file: UPath
    columns: str

    def __post_init__(self):
        # The config supplies the columns as one newline-separated string;
        # keep them as a list from here on.
        raw_columns = self.columns
        self.columns = read_str_lst(raw_columns)

@dataclass
class S3Config:
    """S3 path configuration (not referenced elsewhere in the visible code)."""
    run: UPath
    data: UPath

@dataclass
class LogConfig:
    """Logger settings as read from the config file."""
    level: Literal['info', 'warning', 'debug', 'error']
    format: str
    file: str
    overwrite: bool

    def todict(self):
        """Return the settings as keyword arguments for ``logger.add``."""
        write_mode = 'w' if self.overwrite else 'a'
        return dict(
            level=self.level.upper(),
            format=self.format,
            sink=self.file,
            mode=write_mode,
        )

@dataclass
class NextConfig:
    """Settings describing what is handed over to the next step."""
    file: UPath
    fields: str

    def __post_init__(self):
        # ``fields`` arrives as newline-separated ``key=value`` text and is
        # stored as a dictionary afterwards.
        raw_fields = self.fields
        self.fields = read_dstr_lst(raw_fields)

@dataclass
class CacheConfig:
    """Cache configuration (not used in this notebook)"""
    enable: bool
    directory: UPath
    expire_hours: int | None = None  # defaults to None, so the annotation allows it
    force_restart: bool = False
    verbose: bool = False

@dataclass
class MetaOutput(BaseType):
    """Output configuration (the ``output`` section of the config file).

    Dataclass-typed fields given as mappings are converted into instances
    by ``BaseType.__post_init__``.
    """
    folder: UPath  # created on start-up; all result files are written here
    to_pkl: UPath  # not referenced in the visible code
    csv: MetaCSV  # start_setup may delete ``csv.file``
    to_s3: S3Config
    log: LogConfig  # passed to ``logger.add`` via ``todict()``
    next: NextConfig
    cache: CacheConfig

@dataclass
class MetaConfig(BaseType):
    """Main configuration class combining the input and output sections."""
    input: MetaInput  # a mapping is converted by BaseType.__post_init__
    output: MetaOutput  # a mapping is converted by BaseType.__post_init__

@dataclass
class MetaRecord:
    """Record tracking during processing"""
    # Per-table metadata that main() fills and copies into the next-step JSON.
    next_d: dict = field(default_factory=dict)
    # Not assigned anywhere in the visible code; defaults to None.
    pull_status: PullStatus = None

## Cell 4: Configuration Reading Functions

In [None]:
#--- Patch confection library to preserve case sensitivity ---#
def patch_confection():
    """Make confection build case-preserving ConfigParser objects.

    ``confection.get_configparser`` is replaced (through ``mock.patch``) by a
    version that sets ``optionxform = str`` so option names keep their case.
    The patch is installed once per process: the active patcher is remembered
    on this function and later calls return immediately.
    """
    # The previous guard tested ``hasattr(patcher, 'is_local')`` on a freshly
    # created patcher. In CPython's unittest.mock that attribute is only set
    # when the patch is entered, so the guard never prevented re-patching and
    # every call stacked one more never-stopped patch.
    if getattr(patch_confection, '_patcher', None) is not None:
        return

    def get_configparser(interpolate: bool = True):
        from confection import CustomInterpolation
        config = ConfigParser(
            interpolation=CustomInterpolation() if interpolate else None,
            allow_no_value=True,
        )
        # Keep option names exactly as written (the default lower-cases them).
        config.optionxform = str
        return config

    patcher = mock.patch('confection.get_configparser', wraps=get_configparser)
    patcher.start()
    patch_confection._patcher = patcher

#--- Read configuration file and create config object ---#
def read_config(config_class: BaseType, config_path: None | UPath | str = None, overrides=None):
    """Load a confection config and build ``config_class`` from it.

    Args:
        config_class: Class called with the config sections as keyword
            arguments.
        config_path: Path of a config file, or the config text itself when it
            does not point to an existing file.
        overrides: Optional mapping of dotted keys to values, forwarded to
            confection. ``None`` means no overrides.

    Returns:
        ``config_class(**sections)`` where the entries of an optional ``root``
        section are merged into the top level (real sections win on clashes).
    """
    patch_confection()
    # A fresh dict per call instead of a shared mutable default argument.
    overrides = {} if overrides is None else overrides
    if UPath(config_path).is_file():
        config = Config().from_disk(config_path, overrides=overrides)
    else:
        config = Config().from_str(config_path, overrides=overrides)
    root_values = config.pop('root', {})
    return config_class(**{**root_values, **config})

## Cell 5: Utility Classes and Functions

In [None]:
#--- Start logging session with separator ---#
def start_run():
    """Log a separator line marking the start of a run."""
    separator = '=' * WIDTH
    logger.info('\n\n' + separator)

#--- End logging session with separator ---#
def end_run():
    """Log a separator line marking the end of a run."""
    separator = '=' * WIDTH
    logger.info('\n\n' + separator)

class IO:
    """File I/O helpers (Parquet for DataFrames, JSON for metadata, legacy pickle)."""

    @staticmethod
    def write_dataframe(file, df):
        """Write ``df`` (including its index) to ``file`` as snappy-compressed Parquet."""
        file = UPath(file)
        df.to_parquet(file, index=True, engine='pyarrow', compression='snappy')

    @staticmethod
    def read_dataframe(file):
        """Load a DataFrame from the Parquet file ``file``."""
        file = UPath(file)
        return pd.read_parquet(file, engine='pyarrow')

    @staticmethod
    def write_json(file, data, cls=None):
        """Write ``data`` to ``file`` as indented JSON.

        Values the json module cannot encode natively are converted:
        NumPy integer/floating/boolean scalars to Python scalars, arrays to
        lists, pandas missing values to ``null``, dates/datetimes to ISO
        strings and sets to lists. Anything else raises ``TypeError``.

        Args:
            file: Target path, opened in text write mode.
            data: Object to serialize.
            cls: Optional ``json.JSONEncoder`` subclass forwarded to ``json.dump``.
        """
        import datetime as dt

        def convert(obj):
            # np.bool_ is neither np.integer nor np.floating and used to fall
            # through to the TypeError below.
            if isinstance(obj, (np.integer, np.floating, np.bool_)):
                return obj.item()
            elif isinstance(obj, np.ndarray):
                return obj.tolist()
            elif pd.isna(obj):
                # Must stay ahead of the datetime branch so pd.NaT becomes null.
                return None
            elif isinstance(obj, (dt.datetime, dt.date)):
                return obj.isoformat()
            elif isinstance(obj, set):
                return list(obj)
            raise TypeError(f"Object of type {type(obj)} is not JSON serializable")

        with open(file, 'w') as f:
            json.dump(data, f, indent=2, default=convert, cls=cls)

    @staticmethod
    def read_json(file):
        """Read the JSON file ``file`` and return the decoded object."""
        with open(file, 'r') as fp:
            return json.load(fp)

    @staticmethod
    def write_pickle(file, data):
        """Deprecated: Use write_dataframe or write_json instead"""
        with open(file, 'wb') as f:
            pickle.dump(data, f)

    @staticmethod
    def read_pickle(file):
        """Deprecated: Use read_dataframe or read_json instead"""
        # NOTE: unpickling can execute arbitrary code; only load files
        # produced by this pipeline.
        with open(file, 'rb') as fp:
            return pickle.load(fp)

    @staticmethod
    def delete_file(file):
        """Delete ``file`` if it exists; do nothing otherwise."""
        if (filepath := UPath(file)).exists():
            filepath.unlink()

class UDict(dict):
    """Dictionary whose lookups ignore the case of string keys.

    Keys are stored exactly as inserted; ``d[key]``, ``key in d`` and
    ``d.get(key)`` resolve to the first stored key that matches
    case-insensitively. Non-string keys are matched by equality only.
    """

    def __getitem__(self, key):
        return super().__getitem__(self._match(key))

    def __contains__(self, key):
        try:
            self._match(key)
        except KeyError:
            return False
        return True

    def _match(self, key):
        """Return the stored key matching ``key``; raise ``KeyError`` if none does."""
        # Only strings can be compared case-insensitively; anything else used
        # to fail with AttributeError on ``.lower()``.
        wanted = key.lower() if isinstance(key, str) else None
        for k in self:
            if k == key:
                return k
            if wanted is not None and isinstance(k, str) and k.lower() == wanted:
                return k
        raise KeyError(key)

    def update(self, other=None, **kwargs):
        """Update the dictionary.

        Args:
            other: Mapping or iterable of ``(key, value)`` pairs; entries are
                stored under the given keys.
            **kwargs: Values for keys that must already exist
                (case-insensitively). Each value is written to the existing
                stored key, so no differently-cased duplicate is created.

        Raises:
            KeyError: If a keyword names a key that is not present.
        """
        if other is not None:
            for k, v in other.items() if isinstance(other, abc.Mapping) else other:
                self[k] = v
        for k, v in kwargs.items():
            # Explicit lookup instead of ``assert``: the existence check must
            # survive ``python -O``.
            existing = self._match(k)
            super().__setitem__(existing, v)

    def get(self, key, default_value=None):
        """Case-insensitive ``dict.get``."""
        try:
            return self[key]
        except KeyError:
            return default_value

class Misc:
    """Small conversion helpers."""

    @staticmethod
    def convert2int(a):
        """Return ``int(a)``, or ``None`` when ``a`` cannot be converted."""
        try:
            result = int(a)
        except (TypeError, ValueError):
            result = None
        return result

    @staticmethod
    def convert2datestr(a):
        """Format ``datetime`` values as ``YYYY-MM-DD``; pass anything else through."""
        return a.strftime('%Y-%m-%d') if isinstance(a, datetime) else a

## Cell 6: AWS Database Connection and SQL Engine

In [None]:
class SQLengine:
    """SQL query engine for AWS Athena database.

    The instance keeps the state of the table currently being processed:
    the table name found in the last query (``_table``), the WHERE condition
    (``_where``), the date column's data type (``_type``), the SQL date
    expression (``_date``) and the underlying column name (``_dateraw``).
    """

    def __init__(self, platform: Literal['AWS']):
        self._platform = platform
        self.reset()

    def reset(self):
        """Reset internal state"""
        self._where = None
        self._type = None
        self._date = None
        self._dateraw = None
        self._table = None
        self._format = AWS_DT_FORMAT

    def extract_var(self, stmt):
        """Return ``(stmt, column)`` for a SQL date expression.

        ``column`` is the lower-cased first argument of ``F(G(col, '..'), '..')``
        or ``F(col, '..')`` (function names must contain an underscore, e.g.
        ``DATE_FORMAT``); when neither shape matches, it is ``stmt`` itself
        lower-cased.
        """
        def _extract_var():
            word, time, tagt = r'\w+_\w+', r"'[^']*'", r'[^,]+'
            pattern1 = fr"{word}\({word}\(({tagt}),\s*{time}\),\s*{time}\)"
            pattern2 = fr"{word}\(({tagt}),\s*{time}\)"
            if (m := re.match(pattern1, stmt)):
                return stmt, m.group(1)
            elif (m := re.match(pattern2, stmt)):
                return stmt, m.group(1)
            return stmt, stmt

        date_var, date_raw = _extract_var()
        return date_var, date_raw.lower()

    def query(self, query, connection, **query_kwargs):
        """Clean ``query``, run it on ``connection`` and return a DataFrame.

        Column names of the result are lower-cased. Extra keyword arguments
        are forwarded to ``pandas.io.sql.read_sql_query``.
        """
        query = self.clean_query(query)
        df = psql.read_sql_query(query, connection, **query_kwargs)

        #>>> Normalize column names to lowercase <<<#
        df.columns = [x.lower() for x in df.columns]
        return df

    def clean_query(self, query: str):
        """Prepare a SQL string for execution.

        * Records the table after the last ``FROM`` in ``_table`` (``None``
          when the query has no FROM clause).
        * Appends ``as <column>`` to the first occurrence of the current date
          expression in the SELECT part, unless it is already followed by
          ``AS <name>``.
        * Removes lines that consist only of ``where``.
        """
        #>>> Extract table name from query <<<#
        # Searching the reversed text makes the *last* FROM the first hit.
        table_pattern = r'([\w.]+)\s+MORF\b'
        table_match = re.search(table_pattern, query[::-1], flags=re.I)
        self._table = table_match.group(1)[::-1] if table_match else None

        #>>> Add alias to date column if needed <<<#
        # Negative lookahead: skip occurrences that already carry an alias.
        # (The former pattern had ``\\s`` in a raw string, i.e. a literal
        # backslash, so the lookahead never matched and an already aliased
        # expression received a second alias.)
        date_pattern = r'(?!\s+AS\s+\w+)'
        if self._date and (match := re.search(
            re.escape(self._date) + date_pattern,
            re.split(r'\b(?:FROM|WHERE)\b', query, flags=re.I)[0],
            flags=re.I
        )):
            # The searched text is a prefix of ``query``, so the span is
            # valid for ``query`` as well.
            st, ed = match.span()
            query = query[:st] + f'{self._date} as {self._dateraw}' + query[ed:]

        #>>> Remove empty WHERE clauses <<<#
        where_pattern = r'^\s*where\s*$'
        return re.sub(where_pattern, '', query, flags=re.I | re.M)

    def get_where_sql(self, date_var: str, date_type: str, start_dt=None, end_dt=None, where_cstr='') -> None:
        """Build the WHERE condition and store it in ``_where``.

        Args:
            date_var: Date column, optionally followed by ``(format)``.
            date_type: Data type of the date column.
            start_dt: Optional inclusive lower date bound.
            end_dt: Optional inclusive upper date bound.
            where_cstr: Optional extra condition. A trailing ``(select ...)``
                is executed first and replaced by its quoted result.

        Also sets ``_type``, ``_date`` and ``_dateraw``. Returns nothing.
        """
        self._type = date_type

        #>>> Handle subquery in where constraint <<<#
        if not pd.isna(where_cstr) and (m := re.search(r'(?<=\()select.*(?=\))', where_cstr)):
            rhs = self.query_AWS(m.group()).iloc[0, 0]
            # ``m.start() - 1`` also drops the opening parenthesis.
            if isinstance(rhs, str):
                where_cstr = "%s '%s'" % (where_cstr[:m.start() - 1], rhs)
            else:
                where_cstr = "%s '%s'" % (where_cstr[:m.start() - 1], rhs.strftime('%Y-%m-%d'))

        where_sql = [where_cstr]
        self.get_date_sql(date_var, date_type)

        #>>> Add date range filters <<<#
        if not pd.isna(start_dt):
            start_dt = Misc.convert2datestr(start_dt)
            where_sql.append(f"{self._date} >= '{start_dt}'")
        if not pd.isna(end_dt):
            end_dt = Misc.convert2datestr(end_dt)
            where_sql.append(f"{self._date} <= '{end_dt}'")

        # Skip missing *and* blank parts: an empty constraint used to yield a
        # condition starting with ' AND ', which is invalid SQL.
        self._where = ' AND '.join(
            x for x in where_sql if not pd.isna(x) and str(x).strip()
        )

    @staticmethod
    def get_date_format(date_var):
        """Split ``'column (format)'`` into ``(column, format)``; format may be ``None``."""
        pattern = r'^(.+?)(?:\s*\(([^)]+)\))?$'
        date_var, date_format = re.match(pattern, date_var).groups()
        return date_var, date_format

    def get_date_sql(self, date_var: str, date_type: str):
        """Derive the SQL date expression and store it in ``_date``/``_dateraw``.

        Non-date columns with a format are wrapped in ``DATE_PARSE``; every
        date-like expression is then wrapped in ``DATE_FORMAT(..., '%Y-%m-%d')``.
        """
        date_var, date_format = self.get_date_format(date_var)
        is_date = re.search(r'time|date', date_type, re.IGNORECASE)

        #>>> Parse string dates if format provided <<<#
        if date_format and (not is_date):
            date_var = f"DATE_PARSE({date_var}, '{date_format}')"
            is_date = True

        #>>> Convert to standard string format <<<#
        if is_date:
            date_var = f"DATE_FORMAT({date_var}, '%Y-%m-%d')"

        self._date, self._dateraw = self.extract_var(date_var)

    def __repr__(self):
        return f'SQL({self._platform})\n' \
               f'   table: {self._table}\n' \
               f'   where: {self._where}\n' \
               f'   date : {self._date} ({self._dateraw})'

    def query_AWS(self, query_stmt: str, **query_kwargs):
        """Execute query on AWS Athena (no credential renewal - uses IAM role)"""
        CONN = pa.connect(
            s3_staging_dir="s3://355538383407-us-east-1-athena-output/uscb-analytics/",
            region_name="us-east-1",
        )
        return self.query(query_stmt, CONN, **query_kwargs)

## Cell 7: AWS Processing Functions

In [None]:
# Initialize global objects
# Both are module-level singletons shared by the processing functions and
# main(): ``proc_aws`` carries the per-table SQL state, ``record`` the
# per-table metadata handed to the next step.
proc_aws = SQLengine('AWS')
record = MetaRecord()

#--- Process AWS table metadata (columns and row count) ---#
def process_aws_meta(row):
    """Fetch column metadata and the (filtered) row count of one AWS table.

    Args:
        row: Mapping with ``aws_tbl`` (``database.table``) and ``aws_dt``
            (date column, optionally with ``(format)``, or ``NO_DATE``);
            ``aws_where``, ``start_dt`` and ``end_dt`` are optional.

    Returns:
        ``{'column': <DataFrame of column_name/data_type>, 'row': <DataFrame
        with the row count>}``.

    Raises:
        NONEXIST_TABLE: If a query fails with a database error, or the
            metadata lookup returns no columns for a table that needs a
            date column.

    Side effects:
        Updates the shared ``proc_aws`` engine (``_where``, ``_date``, ...).
    """
    database, table = (info_str := row['aws_tbl']).split('.', maxsplit=1)
    CONN = pa.connect(
        s3_staging_dir="s3://355538383407-us-east-1-athena-output/uscb-analytics/",
        region_name="us-east-1",
    )
    logger.info(f"\tStart processing {info_str}")

    # A missing constraint must become an empty condition; formatting None
    # into the template would produce the invalid SQL "where None".
    raw_where = row.get('aws_where')
    base_where = '' if pd.isna(raw_where) else raw_where

    #>>> Query column metadata and row counts <<<#
    try:
        df_type = proc_aws.query(AWS_SQL_META.format(table=table, db=database), CONN)
        date_var = re.match(r'(\w+)(?=\s*\()?', row['aws_dt']).group(1)
        if date_var == NO_DATE:
            proc_aws._where = base_where
        else:
            type_hits = df_type.query(f"column_name == '{date_var.lower()}'")['data_type']
            if type_hits.empty:
                # ``.item()`` on an empty selection used to raise an uncaught
                # ValueError and abort the whole run.
                if df_type.empty:
                    # No columns at all: the metadata lookup found no such table.
                    logger.warning(f"Couldn't find {table.lower()} in {database.lower()}")
                    raise NONEXIST_TABLE("AWS View Not Existing")
                # Table exists but the date column does not: count rows without
                # a date filter and leave the date expression unset so the
                # date-wise step reports the missing date variable.
                logger.warning(f"Column {date_var.upper()} not found in {table.upper()}")
                proc_aws._where = base_where
            else:
                proc_aws.get_where_sql(
                    date_var=row['aws_dt'],
                    date_type=type_hits.iloc[0],
                    start_dt=row.get('start_dt'),
                    end_dt=row.get('end_dt'),
                    where_cstr=row.get('aws_where')
                )
        nrow_sql = AWS_SQL_NROW.format(table=table, db=database, limit=proc_aws._where)
        df_nrow = proc_aws.query(nrow_sql, CONN)
    except pd.errors.DatabaseError as err:
        logger.warning(f"Couldn't find {table.lower()} in {database.lower()}")
        raise NONEXIST_TABLE("AWS View Not Existing") from err

    df_type.columns = [x.lower() for x in df_type.columns]
    return {'column': df_type, 'row': df_nrow}

#--- Query AWS table for date-wise row counts ---#
def process_aws_date(row):
    """Return the number of rows per date value of one AWS table.

    Uses the date expression and WHERE condition currently stored on the
    shared ``proc_aws`` engine, so the metadata step for the same table must
    have run before.

    Args:
        row: Mapping with ``aws_tbl`` in ``database.table`` form.

    Returns:
        DataFrame with the date column and ``nrows`` (lower-case column names).

    Raises:
        NONEXIST_DATEVAR: If no date expression is set or the query fails
            with a database error.
    """
    database, table = (info_str := row['aws_tbl']).split('.', maxsplit=1)

    # Without a date expression the template would contain the literal text
    # "None" and the query could only fail; report that case directly.
    if not proc_aws._date:
        raise NONEXIST_DATEVAR("Date-like Variable Not In AWS")

    CONN = pa.connect(
        s3_staging_dir="s3://355538383407-us-east-1-athena-output/uscb-analytics/",
        region_name="us-east-1",
    )
    try:
        date_sql = AWS_SQL_DATE.format(
            table=table, limit=proc_aws._where or '', date=proc_aws._date, db=database
        )
        df_meta = proc_aws.query(date_sql, CONN)
        logger.info(f"\tFinish Processing {info_str}")
    except pd.errors.DatabaseError as err:
        if proc_aws._dateraw:
            logger.warning(f"Column {proc_aws._dateraw.upper()} not found in {table.upper()}")
        raise NONEXIST_DATEVAR("Date-like Variable Not In AWS") from err

    df_meta.columns = [x.lower() for x in df_meta.columns]
    return df_meta

#--- Initialize output folders and logging ---#
def start_setup(start_row, C_out):
    """Prepare the output folder and attach the log sink.

    When the run starts at the first row (``start_row <= 1``) a CSV left over
    from an earlier run is removed; a missing file or an unusable
    ``start_row``/path is ignored.

    Args:
        start_row: First row to process.
        C_out: Output configuration providing ``csv.file``, ``folder`` and ``log``.
    """
    # Plain ``if`` instead of ``assert``: assertions are stripped under
    # ``python -O``, which would delete the CSV on every run.
    try:
        if start_row <= 1:
            os.remove(C_out.csv.file)
    except (TypeError, FileNotFoundError):
        pass
    os.makedirs(C_out.folder, exist_ok=True)
    logger.add(**C_out.log.todict())

## Cell 8: Configuration Parsing

In [None]:
#--- Parse command line arguments and load configuration ---#
def parse_config():
    """Parse command line arguments and load the matching configuration.

    The config file is selected by ``--category``; ``--name`` and ``--query``
    are applied as overrides. The output folder is created and the config
    file is copied into it as ``<step>_aws.cfg``.

    Returns:
        The loaded ``MetaConfig``.
    """
    parser = argparse.ArgumentParser(description='Conduct Meta Info Analysis - AWS Step')
    parser.add_argument(
        '--category',
        choices=['loan', 'dpst'],
        default='dpst',
        help='which meta template to use',
    )
    parser.add_argument(
        '--name', type=str,
        default='test_0827',
        help='how to name this analysis (override)'
    )
    parser.add_argument(
        '--query', type=str,
        default='group == "test_0827"',
        help='row selection forwarded as input.table.select_rows (override)'
    )
    # parse_known_args tolerates arguments that belong to the host process
    # (a Jupyter kernel passes "-f <connection file>"), which parse_args()
    # rejects with SystemExit.
    args, _unknown = parser.parse_known_args()

    config_paths = {
        'dpst': r'files/inputs/config_meta_dpst.cfg',
        'loan': r'files/inputs/config_meta_loan.cfg',
    }
    config_path = config_paths[args.category]

    # NOTE(review): the override targets input.table.select_rows, but
    # MetaInput declares no ``table`` field - confirm against the .cfg files.
    config = read_config(
        MetaConfig,
        config_path=config_path,
        overrides={
            'input.table.select_rows': args.query,
            'input.name': args.name
        }
    )
    (out_folder := UPath(config.output.folder)).mkdir(parents=True, exist_ok=True)
    shutil.copy(config_path, out_folder.joinpath(f'{config.input.step}_aws.cfg'))
    return config

## Cell 9: Main Execution - AWS Processing

In [None]:
def _save_aws_results(df_dict, df_next, output_folder):
    """Write per-table Parquet/JSON files, the summary and the next-step metadata."""
    for table_name, table_data in df_dict.items():
        if table_data['aws_meta'] and 'column' in table_data['aws_meta']:
            # Save column info as parquet
            col_file = output_folder / f'aws_column_info_{table_name}.parquet'
            IO.write_dataframe(col_file, table_data['aws_meta']['column'])

        if table_data['aws_date'] is not None:
            # Save date counts as parquet
            date_file = output_folder / f'aws_date_counts_{table_name}.parquet'
            IO.write_dataframe(date_file, table_data['aws_date'])

        # Save other metadata as JSON
        meta_file = output_folder / f'aws_metadata_{table_name}.json'
        IO.write_json(meta_file, {
            'status': table_data['status'],
            'sql_engine': table_data['sql_engine'],
            'nrows': table_data['aws_meta'].get('row', pd.DataFrame()).to_dict() if table_data['aws_meta'] else {}
        })

    # Summary with file paths. The paths are listed for every table, whether
    # or not the corresponding Parquet file was written above.
    summary_data = {
        name: {
            'column_file': str(output_folder / f'aws_column_info_{name}.parquet'),
            'date_file': str(output_folder / f'aws_date_counts_{name}.parquet'),
            'meta_file': str(output_folder / f'aws_metadata_{name}.json'),
            'status': data['status']
        } for name, data in df_dict.items()
    }
    IO.write_json(output_folder / 'aws_summary.json', summary_data)

    # Next step metadata (merged PCDS+AWS)
    IO.write_json(output_folder / 'aws_metadata.json', df_next)


def main():
    """Run the AWS meta analysis for all tables listed by Step 1.

    Loads the configuration, reads ``pcds_metadata.json`` from the output
    folder, queries column metadata, row count and date-wise counts for each
    table inside the configured row window, and writes the results as
    Parquet/JSON files. Returns early (after logging an error) when the
    Step 1 file is missing.
    """
    config = parse_config()
    df_dict, df_next = {}, {}
    C_out, C_in = config.output, config.input
    start_row, end_row = C_in.range
    start_setup(start_row, C_out)
    logger.info('Configuration:\n' + config.tolog())

    start_run()

    #>>> Load PCDS metadata from Step 1 <<<#
    output_folder = UPath(C_out.folder)
    pcds_metadata_file = output_folder / 'pcds_metadata.json'

    if not pcds_metadata_file.exists():
        logger.error(f"PCDS metadata file not found: {pcds_metadata_file}")
        logger.error("Please run Step 1 (meta_analysis_1_pcds.ipynb) first!")
        return

    pcds_metadata = IO.read_json(pcds_metadata_file)
    logger.info(f"Loaded PCDS metadata for {len(pcds_metadata)} tables")

    #>>> Process each table <<<#
    total = len(pcds_metadata)
    for i, (name, pcds_meta) in enumerate(tqdm(
        pcds_metadata.items(), desc='Processing AWS ...', total=total
    ), start=1):
        if (i < start_row or i > end_row):
            continue

        aws_m, aws_d = {}, None
        record.next_d = UDict(pcds_meta)

        logger.info(f">>> Start {name}")

        pull_status = PullStatus.SUCCESS

        #>>> Try to pull AWS table metadata <<<#
        try:
            aws_m = process_aws_meta(pcds_meta)
        except NONEXIST_TABLE:
            logger.error(f"AWS table {name} does not exist")
            # Skipping the table must not leave its date/where state on the
            # shared engine for the next table.
            proc_aws.reset()
            continue

        #>>> Try to get date-wise counts from AWS <<<#
        try:
            aws_d = process_aws_date(pcds_meta)
            # The former check ``len(aws_m) == 0`` could never be true because
            # aws_m always holds the 'column' and 'row' entries; an empty
            # date-count result is what signals that no rows were found.
            if aws_d is not None and len(aws_d) == 0:
                pull_status = PullStatus.EMPTY_AWS
        except NONEXIST_DATEVAR:
            if pull_status == PullStatus.SUCCESS:
                pull_status = PullStatus.NONDATE_AWS

        #>>> Store results <<<#
        df_dict[name] = {
            'aws_meta': aws_m,
            'aws_date': aws_d,
            'status': pull_status.value,
            'sql_engine': {
                'where': proc_aws._where,
                'date': proc_aws._date,
                'dateraw': proc_aws._dateraw,
                'type': proc_aws._type
            }
        }

        #>>> Save metadata for next step <<<#
        if aws_m and 'column' in aws_m:
            record.next_d.update(
                aws_cols=SEP.join(aws_m['column']['column_name'].tolist()),
                aws_types=SEP.join(aws_m['column']['data_type'].tolist()),
                aws_nrows=int(aws_m['row'].iloc[0, 0]) if 'row' in aws_m else 0,
                aws_where=proc_aws._where,
                aws_dt_type=proc_aws._type
            )

        df_next[name] = record.next_d.copy()

        #>>> Reset engine for next iteration <<<#
        proc_aws.reset()

    #>>> Save results using Parquet/JSON format <<<#
    _save_aws_results(df_dict, df_next, output_folder)

    logger.info(f"Saved AWS results to {output_folder}")
    logger.info(f"  - Individual parquet/JSON files: {len(df_dict)} tables")
    logger.info(f"  - aws_summary.json: summary of all files")
    logger.info(f"  - aws_metadata.json: metadata for next steps")

    end_run()

# Script entry point: runs the analysis when the file is executed directly.
# NOTE(review): inside a notebook kernel ``__name__`` is also '__main__', so
# executing this cell starts the run as well - confirm that is intended.
if __name__ == '__main__':
    main()

## Run the Analysis

Uncomment the cell below to run the AWS processing:

In [None]:
# main()