# Column Statistics Analysis - Step 2: AWS Data Collection

This notebook collects column-level statistics from the AWS (Athena) platform.
It reads the metadata produced in step 1 and saves its results for use in step 3 (comparison).

**Note:** This notebook runs on a Linux/remote server, so `aws_creds_renew` is not needed.

## Cell 1: Import Required Libraries

In [None]:
import re
import os
import json
import pickle
import warnings
import numpy as np
import pandas as pd
import pyathena as pa
import functools as ft
import multiprocessing as mp
import threading as td
import datetime as dt
import time

from upath import UPath
from loguru import logger
from tqdm import tqdm
from typing import Literal
from dataclasses import dataclass, field, fields
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed

warnings.filterwarnings('ignore', message=r'pandas only supports SQLAlchemy connectable .*', category=UserWarning)

# Note: This notebook uses Parquet format for cross-platform compatibility
# Install pyarrow if needed: pip install pyarrow or conda install -c conda-forge pyarrow
import pyarrow

## Cell 2: Constants and Configuration

In [None]:
# --- Global Constants ---
SEP = '; '                  # delimiter MetaJSON uses to split the column/type lists
AWS_DT_FORMAT = '%Y-%m-%d'  # strftime pattern for ISO-style dates
TODAY = dt.datetime.now()   # naive local timestamp captured once, at import time
WIDTH = 80                  # length of the '=' banner logged by start_run()/end_run()

# Partition schemes a table can be analysed under. 'daily' is listed because
# get_aws_where() builds a clause for it.
TPartition = Literal['whole', 'year', 'year_month', 'empty', 'year_week', 'week', 'snapshot', 'daily']

# --- SQL Templates for AWS Column Statistics ---
# Both templates are filled with str.format() using the placeholders
# {data_type}, {column_name}, {db}, {table} and {limit} (the WHERE predicate),
# and both return one row with the same eleven col_* columns.

# Continuous (numeric) columns: statistics are computed on the column values.
AWS_Cont_SQL = """
SELECT
    '{data_type}' AS col_type,
    COUNT({column_name}) AS col_count,
    COUNT(DISTINCT {column_name}) AS col_distinct,
    MAX({column_name}) AS col_max,
    MIN({column_name}) AS col_min,
    AVG(CAST({column_name} AS DOUBLE)) AS col_avg,
    STDDEV_SAMP(CAST({column_name} AS DOUBLE)) AS col_std,
    SUM(CAST({column_name} AS DOUBLE)) AS col_sum,
    SUM(CAST({column_name} AS DOUBLE) * CAST({column_name} AS DOUBLE)) AS col_sum_sq,
    '' AS col_freq,
    COUNT(*) - COUNT({column_name}) AS col_missing
FROM {db}.{table}
WHERE {limit};
"""

# Categorical columns: statistics are computed on the per-value frequencies
# (NULL forms its own group). col_freq lists the ten most frequent values as
# 'value(count)' joined by '; '.
#   * col_freq is ordered by rn so that ties on value_freq come out in the same
#     deterministic order that was used to pick the top ten.
#   * COALESCE wraps the scalar subquery: when the column has no NULLs the
#     subquery returns no row (i.e. NULL), and col_missing must then be 0, as it
#     is in AWS_Cont_SQL.
AWS_Catg_SQL = """
WITH FreqTable_RAW AS (
    SELECT
        {column_name} AS p_col,
        COUNT(*) AS value_freq
    FROM  {db}.{table}
    WHERE {limit}
    GROUP BY {column_name}
),FreqTable AS (
    SELECT
        p_col, value_freq,
        ROW_NUMBER() OVER (ORDER BY value_freq DESC, p_col ASC) AS rn
    FROM FreqTable_RAW
)
SELECT
    '{data_type}' AS col_type,
    SUM(value_freq) AS col_count,
    COUNT(value_freq) AS col_distinct,
    MAX(value_freq) AS col_max,
    MIN(value_freq) AS col_min,
    AVG(CAST(value_freq AS DOUBLE)) AS col_avg,
    STDDEV_SAMP(CAST(value_freq AS DOUBLE)) AS col_std,
    SUM(value_freq) AS col_sum,
    SUM(value_freq * value_freq) AS col_sum_sq,
    (SELECT ARRAY_JOIN(ARRAY_AGG(COALESCE(CAST(p_col AS VARCHAR), '') || '(' || CAST(value_freq AS VARCHAR) || ')' ORDER BY rn), '; ') FROM FreqTable WHERE rn <= 10) AS col_freq,
    COALESCE((SELECT value_freq FROM FreqTable WHERE p_col IS NULL), 0) AS col_missing
FROM FreqTable
"""

## Cell 3: Core Data Types

In [None]:
class Timer:
    """Context manager for timing code execution.

    The clock starts when the ``with`` block is entered. ``time`` and
    ``pause()`` may be read at any point after that, inside or after the block.
    """

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Nothing to release; returning None lets any exception propagate.
        pass

    @property
    def time(self):
        """Seconds elapsed since the timer was started or last reset."""
        return time.perf_counter() - self.start

    def pause(self):
        """Return elapsed time and reset timer"""
        elapsed = self.time
        self.start = time.perf_counter()
        return elapsed

    @staticmethod
    def to_str(value):
        """Convert seconds to human-readable format.

        The value is rounded to whole seconds first, so float input prints
        integer hours/minutes and the seconds part can never show as 60.
        """
        total_seconds = int(round(value))
        minutes, seconds = divmod(total_seconds, 60)
        hours, minutes = divmod(minutes, 60)
        return f'{hours} hours {minutes} minutes {seconds} seconds'

@dataclass
class MetaOut:
    """Metadata output structure for one platform (PCDS or AWS).

    Instances are built by ``MetaJSON.__init__`` from the ``pcds_*`` / ``aws_*``
    keys of the metadata JSON.
    """
    col2COL: dict     # this platform's column name -> the other platform's column name
    col2type: dict    # this platform's column name -> its data type string
    infostr: str      # table identifier (``<platform>_tbl``); the AWS value is split on '.' into db and table
    rowvar: str       # date column (``<platform>_dt``); may carry a trailing "(format)" suffix
    rowexclude: list  # NOTE(review): main_aws passes this as a single exclude clause (string-like) - confirm the annotation
    rowtype: str      # data type of the date column (``<platform>_dt_type``)
    nrows: int        # value of ``<platform>_nrows``
    where: str        # extra filter clause (``<platform>_where``)

@dataclass(init=False)
class MetaJSON:
    """Container for metadata from previous meta analysis step.

    Built from one flat record of the metadata JSON. Keys that match a field
    name are copied as-is; the per-platform keys (``<platform>_dt``, ``_tbl``,
    ``_where``, ``_nrows``, ``_cols``, ``_types``, ``_dt_type``, ``_exclude``)
    are folded into the ``pcds`` and ``aws`` ``MetaOut`` objects.

    Raises:
        KeyError: if any of the per-platform keys is missing from ``kwargs``.
    """
    pcds: MetaOut
    aws: MetaOut
    last_modified: str
    partition: TPartition = 'whole'
    tokenised_cols: list = field(default_factory=list)

    def __init__(self, **kwargs):
        # Local import: MISSING is only needed here to detect declared defaults.
        from dataclasses import MISSING

        # With init=False the generated __init__ (which applies defaults) does
        # not exist, and dataclass removes the class attribute of a
        # default_factory field. Apply the declared defaults explicitly so
        # that e.g. ``tokenised_cols`` always exists on the instance.
        for spec in fields(self):
            if spec.name in kwargs:
                setattr(self, spec.name, kwargs[spec.name])
            elif spec.default is not MISSING:
                setattr(self, spec.name, spec.default)
            elif spec.default_factory is not MISSING:
                setattr(self, spec.name, spec.default_factory())

        def col2col(a_str, b_str, sep=SEP):
            # Pair the i-th item of one delimited list with the i-th of the other.
            return {k: v for k, v in zip(a_str.split(sep), b_str.split(sep))}

        for key, other in [('pcds', 'aws'), ('aws', 'pcds')]:
            out = MetaOut(
                rowvar=kwargs['%s_dt' % key],
                infostr=kwargs['%s_tbl' % key],
                where=kwargs['%s_where' % key],
                nrows=kwargs['%s_nrows' % key],
                col2COL=col2col(kwargs['%s_cols' % key], kwargs['%s_cols' % other]),
                col2type=col2col(kwargs['%s_cols' % key], kwargs['%s_types' % key]),
                rowtype=kwargs['%s_dt_type' % key],
                rowexclude=kwargs['%s_exclude' % key]
            )
            setattr(self, key, out)

@dataclass
class CSMeta:
    """Run metadata recorded alongside one table/vintage of column statistics."""
    pcds_table: str
    aws_table: str
    partition: TPartition
    vintage: str
    pcds_time: int = 0
    aws_time: int = 0

    def todict(self):
        """Return the declared fields as a plain ``{name: value}`` dict."""
        snapshot = {}
        for spec in fields(self):
            snapshot[spec.name] = getattr(self, spec.name)
        return snapshot

## Cell 4: Utility Functions

In [None]:
def start_run():
    """Log a banner line (two newlines, then WIDTH '=' characters) at the start of a run."""
    banner = '=' * WIDTH
    logger.info('\n\n' + banner)

def end_run():
    """Log a banner line (two newlines, then WIDTH '=' characters) at the end of a run."""
    banner = '=' * WIDTH
    logger.info('\n\n' + banner)

class IO:
    """File I/O utilities - uses Parquet/JSON for cross-platform compatibility"""

    @staticmethod
    def write_dataframe(file, df):
        """Save DataFrame in portable Parquet format (index included, snappy-compressed)."""
        file = UPath(file)
        df.to_parquet(file, index=True, engine='pyarrow', compression='snappy')

    @staticmethod
    def read_dataframe(file):
        """Load DataFrame from Parquet format"""
        file = UPath(file)
        return pd.read_parquet(file, engine='pyarrow')

    @staticmethod
    def write_json(file, data, cls=None):
        """Save ``data`` to ``file`` as indented JSON.

        Values the json module cannot encode natively are converted by
        ``convert``: NumPy scalars and arrays, sets, missing-value markers
        (written as null) and dates/datetimes (ISO strings).

        Args:
            file: destination path.
            data: object to serialise.
            cls: optional ``json.JSONEncoder`` subclass.

        Raises:
            TypeError: if a value is of a type ``convert`` does not handle.
        """
        import numpy as np
        import pandas as pd
        import datetime as dt

        def convert(obj):
            # np.generic covers every NumPy scalar (ints, floats, np.bool_, ...).
            if isinstance(obj, np.generic):
                return obj.item()
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            if isinstance(obj, (set, frozenset)):
                return list(obj)
            # pd.isna returns an array for array-like input; only trust a scalar
            # answer. This check must precede the datetime one so that pd.NaT
            # is written as null.
            missing = pd.isna(obj)
            if isinstance(missing, (bool, np.bool_)) and missing:
                return None
            if isinstance(obj, (dt.datetime, dt.date)):
                return obj.isoformat()
            raise TypeError(f"Object of type {type(obj)} is not JSON serializable")

        with open(file, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, default=convert, cls=cls)

    @staticmethod
    def read_json(file):
        """Load from JSON"""
        with open(file, 'r', encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def write_pickle(file, data):
        """Deprecated: Use write_dataframe or write_json instead"""
        with open(file, 'wb') as f:
            pickle.dump(data, f)

    @staticmethod
    def read_pickle(file):
        """Deprecated: Use read_dataframe or read_json instead"""
        # SECURITY: unpickling can execute arbitrary code - only load files
        # that were produced by a trusted run of this pipeline.
        with open(file, 'rb') as f:
            return pickle.load(f)

    @staticmethod
    def read_meta_json(json_file):
        """Read metadata JSON and convert to MetaJSON objects (one per top-level key)."""
        data = IO.read_json(json_file)
        return {k: MetaJSON(**v) for k, v in data.items()}

## Cell 5: Date Handling Functions for AWS

In [None]:
def get_iso_week_dates(year, week):
    """Return the (start, end) dates of a week, clamped to the calendar year.

    Week 1 is the Monday-to-Sunday week that contains January 1st, and week
    ``n`` starts ``n - 1`` weeks after that Monday. Both bounds are clipped to
    Jan 1 .. Dec 31 of ``year`` and returned as 'YYYY-MM-DD' strings.

    NOTE: ISO 8601 defines week 1 as the week containing the first Thursday,
    so for years where January 1st is not in that week the numbering here is
    shifted relative to ISO.
    """
    year_start = dt.datetime(year, 1, 1)
    year_end = dt.datetime(year, 12, 31)

    # Monday of the week in which January 1st falls.
    anchor_monday = year_start - dt.timedelta(days=year_start.weekday())
    week_start = anchor_monday + dt.timedelta(weeks=week - 1)
    week_end = week_start + dt.timedelta(days=6)

    if week_start < year_start:
        week_start = year_start
    if week_end > year_end:
        week_end = year_end

    return tuple(day.strftime('%Y-%m-%d') for day in (week_start, week_end))

def parse_format_date(str_w_format):
    """Split ``"name (format)"`` into its name and optional format.

    Returns the ``re.Match`` whose groups are ``(name, format)``; ``format`` is
    ``None`` when there is no trailing parenthesised part. Returns ``None``
    when the string does not match at all (e.g. it is empty).
    """
    name_with_optional_format = re.compile(r'^(.+?)(?:\s*\(([^)]+)\))?$')
    return name_with_optional_format.match(str_w_format)

def parse_exclude_date(exclude_clause):
    """Convert date exclusions to AWS format.

    A clause that starts with
    ``DATE_FORMAT(DATE_PARSE(<col>, '%Y%m%d'), '%Y-%m-%d') [not] in ('YYYY-MM-DD', ...)``
    is rewritten to compare the raw column against the dates re-formatted with
    the column's own format: ``<col> [not] in ('YYYYMMDD', ...)``. Any text
    after the closing parenthesis is kept. Clauses that do not start with this
    pattern are returned unchanged.

    Raises:
        ValueError: if a listed date is not in 'YYYY-MM-DD' form.
    """
    p2 = r"DATE_FORMAT\(DATE_PARSE\((?P<col>\w+),\s*'(?P<fmt>%Y%m%d)'\),\s*'%Y-%m-%d'\)\s+(?P<op>not in|in)\s+\((?P<dates>.*?)\)"
    m = re.match(p2, exclude_clause, flags=re.I)
    if not m:
        return exclude_clause

    col, fmt, op, dates = m.group('col', 'fmt', 'op', 'dates')
    # Strip whitespace before the quotes: items after ", " start with a space,
    # which would otherwise keep the opening quote and break strptime.
    new_dates = ', '.join(
        "'%s'" % dt.datetime.strptime(item.strip().strip("'"), '%Y-%m-%d').strftime(fmt)
        for item in dates.split(',')
    )
    # Re-attach whatever followed the matched clause instead of dropping it.
    return '%s %s (%s)' % (col, op, new_dates) + exclude_clause[m.end():]

def get_aws_where(date_var, date_type, date_partition, date_range, date_format, snapshot=None, exclude_clauses=None):
    """Build the WHERE predicate that selects one vintage of an AWS table.

    Args:
        date_var: date column, optionally followed by its format in
            parentheses, e.g. ``"load_dt (%Y%m%d)"``.
        date_type: data type of the date column; string/varchar columns are
            wrapped in ``DATE_PARSE``.
        date_partition: one of 'whole', 'year', 'year_month', 'year_week',
            'week' or 'daily'.
        date_range: the vintage value, either bare or as ``"<column>=<value>"``.
            Week values may be 'YYYY-Www' or 'YYYY-ww'.
        date_format: format used to parse string dates when ``date_var`` does
            not carry its own; falls back to '%Y%m%d' when empty.
        snapshot: when truthy, no date filter is applied and only the exclude
            clauses are returned.
        exclude_clauses: extra filter clauses; empty/None entries are ignored.

    Returns:
        The predicate as a string; '1=1' when there is nothing to filter on.

    Raises:
        ValueError: for an unsupported ``date_partition``.
        AssertionError: when the column named in ``date_range`` differs from
            ``date_var``.
    """
    if '=' in date_range:
        _date_var, date_range = date_range.split('=', 1)
        assert date_var.split()[0] == _date_var, f"Date Variable Should Match: {date_var} vs {_date_var}"

    if date_var and (m := parse_format_date(date_var)):
        date_var, embedded_format = m.groups()
        # Only a format embedded in date_var overrides the caller's argument;
        # a bare column name must not reset date_format to None.
        date_format = embedded_format or date_format

    if date_type and re.match(r'^(string|varchar)', date_type, re.IGNORECASE):
        date_var = f"DATE_PARSE({date_var}, '{date_format or '%Y%m%d'}')"

    active_excludes = [x for x in (exclude_clauses or ()) if x]
    exclude_clause = ' AND '.join('(%s)' % parse_exclude_date(x) for x in active_excludes)

    if snapshot:
        # An empty predicate would leave a dangling WHERE in the SQL templates.
        return exclude_clause or "1=1"
    elif date_partition == 'whole':
        base_clause = "1=1"
    elif date_partition == 'year':
        base_clause = f"DATE_FORMAT({date_var}, '%Y') = '{date_range}'"
    elif date_partition == 'year_month':
        base_clause = f"DATE_FORMAT({date_var}, '%Y-%m') = '{date_range}'"
    elif date_partition in ('year_week', 'week'):
        if '-W' in date_range:
            year, week = date_range.split('-W')
        else:
            year, week = date_range.split('-')
        # '%Y-%v' renders as 'YYYY-ww' (no 'W'), so both input spellings are
        # normalised to that form.
        # NOTE(review): %v is the Monday-based week number that the date_format
        # documentation pairs with %x, not %Y - confirm the intended behaviour
        # for days around New Year.
        base_clause = f"DATE_FORMAT({date_var}, '%Y-%v') = '{int(year)}-{int(week):02d}'"
    elif date_partition == 'daily':
        base_clause = f"DATE({date_var}) = DATE('{date_range}')"
    else:
        raise ValueError(f"Unsupported partition type: {date_partition}")

    if exclude_clause:
        return f"({base_clause}) AND ({exclude_clause})"
    return base_clause

## Cell 6: AWS SQL Engine

In [None]:
import pandas.io.sql as psql

class SQLengine:
    """SQL query engine for AWS"""

    # Athena connection settings; override on a subclass or an instance to
    # point at a different staging bucket or region.
    S3_STAGING_DIR = "s3://355538383407-us-east-1-athena-output/uscb-analytics/"
    REGION_NAME = "us-east-1"

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the per-table state attributes."""
        self._where = None
        self._type = None
        self._date = None
        self._dateraw = None
        self._table = None

    def query(self, query, connection, **query_kwargs):
        """Execute SQL query and return DataFrame with lower-cased column names."""
        df = psql.read_sql_query(query, connection, **query_kwargs)
        df.columns = [x.lower() for x in df.columns]
        return df

    def query_AWS(self, query_stmt: str, **query_kwargs):
        """Execute query on AWS Athena.

        A fresh connection is opened for the call and closed afterwards, even
        when the query raises.
        """
        # No aws_creds_renew needed - running on Linux/remote server with IAM role
        connection = pa.connect(
            s3_staging_dir=self.S3_STAGING_DIR,
            region_name=self.REGION_NAME,
        )
        try:
            return self.query(query_stmt, connection, **query_kwargs)
        finally:
            connection.close()

proc_aws = SQLengine()

## Cell 7: AWS Column Analyzer

In [None]:
class PsuedoLock:
    """No-op stand-in for a lock, used when columns are processed serially."""

    def __enter__(self):
        # Hand back the object itself, mirroring ``with lock as l``.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Nothing to release; a None result lets exceptions propagate.
        return None

class ColumnAnalyzer:
    """AWS column analysis engine"""

    # Executor used when n_jobs is given: 'thread' or 'process'.
    # NOTE(review): 'process' mode sends run_aws_single to spawned workers, so
    # this class must be importable from them - confirm for your environment.
    parallel = 'process'

    @staticmethod
    def run_aws_single(data_type, column_name, db, table, limit='', lock=None):
        """Run statistics on single AWS column.

        Args:
            data_type: column data type string, e.g. 'decimal(10,2)'.
            column_name: column to analyse.
            db, table: database and table to query.
            limit: WHERE predicate; an empty value means no filter.
            lock: accepted for interface compatibility; not used.

        Returns:
            One-row DataFrame with the col_* statistics.
        """
        continuous_types = ('tinyint', 'smallint', 'integer', 'bigint', 'float', 'double', 'decimal')

        if re.match('^time', data_type, flags=re.I):
            column_name = 'CAST(%s AS DATE)' % column_name

        # Classify on the base type name only ('decimal(10,2)' -> 'decimal').
        # A substring test would also treat composite types such as
        # 'array(integer)' as numeric.
        base_type = re.split(r'[^a-z]', data_type.strip().lower(), maxsplit=1)[0]
        sql_template = AWS_Cont_SQL if base_type in continuous_types else AWS_Catg_SQL

        sql_stmt = sql_template.format(
            db=db, table=table, column_name=column_name, data_type=data_type,
            limit=limit or '1=1',  # an empty predicate would leave a dangling WHERE
        )

        # No lock needed for credential renewal on Linux/remote server
        proc_aws_local = SQLengine()
        return proc_aws_local.query_AWS(sql_stmt)

    def run_aws_column_analysis(self, info_str: str, columns_info: dict, where_clause: str = "1=1", n_jobs=None) -> pd.DataFrame:
        """Run column analysis on AWS table with optional parallelization.

        Args:
            info_str: table identifier; the part after the last '.' is the
                table, everything before it the database.
            columns_info: mapping of column name -> data type.
            where_clause: predicate applied to every column query.
            n_jobs: worker count; ``None`` processes the columns serially.

        Returns:
            DataFrame indexed by column name, one row per analysed column, in
            ``columns_info`` order. In parallel mode, columns whose query
            failed are logged and left out.

        Raises:
            ValueError: for an unknown ``parallel`` mode, an ``info_str``
                without a '.', or when no column produced a result.
        """
        results = {}
        db_name, table_name = info_str.rsplit('.', 1)
        logger.info(f"Executing AWS analysis for {table_name}")
        worker = ft.partial(self.run_aws_single, db=db_name, table=table_name, limit=where_clause)

        try:
            if n_jobs is None:
                locker = PsuedoLock()
                for col_name, data_type in tqdm(columns_info.items()):
                    results[col_name] = worker(data_type, col_name, lock=locker)
            else:
                manager = None
                if self.parallel == "thread":
                    executor_class = ThreadPoolExecutor
                    locker = td.Lock()
                elif self.parallel == "process":
                    executor_class = ft.partial(ProcessPoolExecutor, mp_context=mp.get_context('spawn'))
                    manager = mp.Manager()
                    locker = manager.Lock()
                else:
                    raise ValueError(f"Unsupported parallel mode: {self.parallel!r}")

                try:
                    # The with-block waits for the workers and shuts the pool down.
                    with executor_class(max_workers=n_jobs) as executor:
                        futures = {}
                        for col_name, data_type in tqdm(columns_info.items()):
                            futures[executor.submit(worker, data_type, col_name, lock=locker)] = col_name

                        for future in tqdm(as_completed(futures), total=len(futures), desc='Processing ... '):
                            col_name = futures[future]
                            try:
                                results[col_name] = future.result()
                            except Exception as e:
                                logger.error(f"Task failed for column {col_name}: {e}")
                finally:
                    # Stop the manager's helper process once the lock is no longer needed.
                    if manager is not None:
                        manager.shutdown()

            # Futures finish in arbitrary order; restore the input column order.
            ordered = {name: results[name] for name in columns_info if name in results}
            df = pd.concat(ordered.values(), keys=ordered).droplevel(1)
            return df
        except Exception as e:
            logger.error(f"Error in AWS analysis: {e}")
            raise

## Cell 8: Main Execution - AWS Data Collection

In [None]:
def main_aws(
    meta_json_path='path/to/meta_analysis_output.json',
    meta_csv_path='path/to/meta_analysis.csv',
    pcds_summary_path='output/column_stats_pcds/pcds_summary.json',
    output_folder='output/column_stats_aws',
    n_process=4,
):
    """Main execution function for AWS column statistics collection.

    For every table listed in the metadata CSV, and for every vintage recorded
    for it in the PCDS summary, this queries the column statistics and writes a
    Parquet file with the statistics plus a JSON file with the run metadata.
    A summary of everything written is saved as ``aws_summary.json``.

    Args:
        meta_json_path: metadata JSON from the meta_analysis step.
        meta_csv_path: metadata CSV; the table name is read from its
            'PCDS Table Details with DB Name' column.
        pcds_summary_path: PCDS summary JSON from step 1 (supplies the
            vintages). When missing, each table is processed once as
            'entire_dataset'.
        output_folder: folder for all output files (created if needed).
        n_process: number of parallel workers per table.

    Returns:
        dict: ``{table name: {vintage: {stats_file, meta_file, table_name, meta_data}}}``.
    """
    output_folder = UPath(output_folder)
    output_folder.mkdir(exist_ok=True, parents=True)

    start_run()

    # Load metadata from previous steps
    meta_json = IO.read_meta_json(meta_json_path)
    meta_csv = pd.read_csv(meta_csv_path)

    # Load PCDS summary to get vintages
    if UPath(pcds_summary_path).exists():
        pcds_summary = IO.read_json(pcds_summary_path)
        logger.info(f"Loaded PCDS summary from {pcds_summary_path}")
    else:
        logger.warning(f"PCDS summary not found at {pcds_summary_path}, using default vintages")
        pcds_summary = {}

    CA = ColumnAnalyzer()
    summary_data = {}

    for _, row in tqdm(meta_csv.iterrows(), desc='Processing AWS tables...', total=len(meta_csv)):
        name = row.get('PCDS Table Details with DB Name')
        logger.info(f"Processing dataset: {name}")

        # Load metadata for this table
        meta_info = meta_json.get(name)
        if not meta_info:
            logger.warning(f"No metadata found for {name}, skipping")
            continue

        meta_aws = meta_info.aws
        partition = meta_info.partition

        if partition == 'empty':
            continue

        # Get vintages from PCDS summary
        if name in pcds_summary:
            vintages = list(pcds_summary[name])
        else:
            vintages = ['entire_dataset']

        # Clean table name for file naming
        table_name = name.split('.')[-1].lower()

        # Remove PII and tokenized columns
        tokenised = set(meta_info.tokenised_cols or ())
        col2type_filtered = {k: v for k, v in meta_aws.col2type.items() if k not in tokenised}

        for vintage in vintages:
            logger.info(f"Processing vintage: {vintage}")

            # Build WHERE clause for this vintage
            rowvar = meta_aws.rowvar
            is_snapshot = partition == 'snapshot'
            if vintage == 'entire_dataset':
                # The fallback vintage means "no date filter". Passing it to a
                # dated partition would build a comparison against the literal
                # 'entire_dataset', which matches no rows.
                date_range = vintage
                date_partition = partition if is_snapshot else 'whole'
            else:
                date_range = '%s=%s' % (re.sub(r"\s*\(.*?\)$", "", rowvar), vintage)
                date_partition = partition

            aws_where = get_aws_where(
                date_var=rowvar,
                date_type=meta_aws.rowtype,
                date_partition=date_partition,
                date_range=date_range,
                date_format='%Y%m%d',
                snapshot=is_snapshot,
                exclude_clauses=[meta_aws.where, meta_aws.rowexclude]
            )

            # Compute column statistics
            with Timer() as timer:
                aws_stats = CA.run_aws_column_analysis(
                    meta_aws.infostr, col2type_filtered, aws_where, n_process
                )
                aws_time = timer.pause()

            # Save individual parquet file for this table/vintage
            stats_file = output_folder / f'aws_stats_{table_name}_{vintage}.parquet'
            IO.write_dataframe(stats_file, aws_stats)
            logger.info(f"Saved stats to {stats_file}")

            # Save metadata as JSON
            meta_data = CSMeta(
                pcds_table=meta_info.pcds.infostr,
                aws_table=meta_aws.infostr,
                partition=partition,
                vintage=vintage,
                aws_time=aws_time,
            ).todict()

            meta_file = output_folder / f'aws_meta_{table_name}_{vintage}.json'
            IO.write_json(meta_file, {
                'meta_data': meta_data,
                'aws_where': aws_where
            })
            logger.info(f"Saved metadata to {meta_file}")

            # Add to summary
            summary_data.setdefault(name, {})[vintage] = {
                'stats_file': str(stats_file),
                'meta_file': str(meta_file),
                'table_name': table_name,
                'meta_data': meta_data
            }

            logger.info(f"Completed AWS analysis for vintage {vintage}")

        proc_aws.reset()

    # Save summary file
    summary_file = output_folder / 'aws_summary.json'
    IO.write_json(summary_file, summary_data)
    logger.info(f"Summary saved to {summary_file}")
    logger.info(f"Processed {len(summary_data)} tables with {sum(len(v) for v in summary_data.values())} total vintages")

    end_run()
    return summary_data

# Script entry point: run the collection only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    results = main_aws()

## Run the Analysis

Uncomment the cell below to run the AWS data collection:

In [None]:
# results = main_aws()