In [None]:
# Import python packages
import streamlit as st

# Snowpark is used for the analyses below: get_active_session() hands back the
# Snowpark session this notebook is already running in.
from snowflake.snowpark.context import get_active_session
session = get_active_session()


In [None]:
-- Demo data for Snowflake Notebooks: a tiny inline table of snow chances.
-- The result of this cell is read from Python below as `cell2`.
SELECT 'FRIDAY'   AS SNOWDAY, 0.2 AS CHANCE_OF_SNOW
UNION ALL
SELECT 'SATURDAY' AS SNOWDAY, 0.5 AS CHANCE_OF_SNOW
UNION ALL
SELECT 'SUNDAY'   AS SNOWDAY, 0.9 AS CHANCE_OF_SNOW;

In [None]:
# Pull the SQL cell's result (exposed to Python as `cell2`) into pandas.
my_df = cell2.to_pandas()

# Plot the chance of snow for each day.
st.subheader("Chance of SNOW ❄️")
st.line_chart(data=my_df, x="SNOWDAY", y="CHANCE_OF_SNOW")

# Invite the reader to experiment with their own queries.
st.subheader("Try it out yourself and show off your skills 🥇")

In [None]:
import snowflake
import snowflake.snowpark as snowpark
import snowflake.snowpark.functions as F
from snowflake.snowpark.functions import udf, pandas_udf
from snowflake.snowpark.dataframe import map
from snowflake.snowpark.window import Window
from snowflake.snowpark.types import IntegerType, FloatType, StringType, BooleanType, StructType, MapType, ArrayType, PandasSeriesType
from snowflake.snowpark.context import get_active_session
import pandas as pd
import numpy as np
import json
import re
from typing import List

In [None]:
# Bind the active session and the source / cleansed tables used throughout.
session = get_active_session()

# Raw landing tables.
batch_raw = session.table("AATT_RAW.PUBLIC.BATCH_RAW")
field_raw = session.table("AATT_RAW.PUBLIC.FIELD_RAW")

# Previously cleansed tables.
batch_table = session.table("AATT_CLEANSED.PUBLIC.BATCH_CLEANSED")
field_table = session.table("AATT_CLEANSED.PUBLIC.FIELD_CLEANSED")

# Age thresholds (years) evaluated by the accuracy checks.
age_gates = [13, 16, 18]

subjects_df = session.table("AATT_RAW.PUBLIC.SUBJECTS")

In [None]:
@pandas_udf(
    return_type=PandasSeriesType(StringType()),
    input_types=[PandasSeriesType(StringType()) for _ in range(4)],
    is_permanent=False,
    replace=True,
)
def fill_subject_background(
    first_nations: pd.Series,
    subject_bg: pd.Series,
    father_bg: pd.Series,
    mother_bg: pd.Series,
) -> pd.Series:
    """Derive one background value per subject.

    Precedence, lowest to highest (later assignments overwrite earlier ones):
      1. the subject's own value, when present and not "NONE"/"none";
      2. the father's value, when both parents are present, equal and not
         "NONE"/"none";
      3. the literal "First Nations" when ``first_nations`` is one of
         "aboriginal", "torresStrait" or "both".
    Rows matching none of these stay pd.NA.
    """
    missing_markers = ("NONE", "none")

    is_first_nations = first_nations.isin(["aboriginal", "torresStrait", "both"])

    parents_agree = (
        father_bg.notna()
        & mother_bg.notna()
        & (father_bg == mother_bg)
        & ~father_bg.isin(missing_markers)
        & ~mother_bg.isin(missing_markers)
    )
    subject_known = subject_bg.notna() & ~subject_bg.isin(missing_markers)

    background = pd.Series(pd.NA, index=subject_bg.index)
    background.loc[subject_known] = subject_bg
    background.loc[parents_agree] = father_bg[parents_agree]
    background.loc[is_first_nations] = "First Nations"

    return background
    

In [None]:
@pandas_udf(
    return_type = PandasSeriesType(IntegerType()),
    input_types = [
        PandasSeriesType(StringType())
    ],
    is_permanent = False,
    replace = True
)
def parse_verification_time(verification_time: pd.Series) -> pd.Series:
    """Convert "HH:MM:SS" strings into a total number of seconds.

    Values that are missing, do not split into exactly three ':'-separated
    parts, or contain a non-numeric part map to None for that row instead of
    raising (which previously failed the whole vectorized batch).
    """

    def _to_seconds(parts):
        # str.split leaves missing inputs as None/NaN rather than a list;
        # NaN slipped past the old `is not None` guard and broke len().
        if not isinstance(parts, list) or len(parts) != 3:
            return None
        hours, minutes, seconds = (pd.to_numeric(part, errors="coerce") for part in parts)
        if pd.isna(hours) or pd.isna(minutes) or pd.isna(seconds):
            return None
        return seconds + 60 * minutes + 3600 * hours

    return verification_time.str.split(":").apply(_to_seconds)

In [None]:
# Load the country -> region-number and region-number -> region-name lookups
# shipped next to the notebook; `with` closes the files (json.load(open(...))
# left the handles open).
with open("country_region_num.json", "r") as lookup_file:
    country_region_num = json.load(lookup_file)
with open("region_name.json", "r") as lookup_file:
    region_name = json.load(lookup_file)

country_region_num_df = pd.DataFrame(
    list(country_region_num.items()), columns=["COUNTRY_CODE", "REGION_NUM"]
)
region_name_df = pd.DataFrame(
    list(region_name.items()), columns=["REGION_NUM", "REGION_NAME"]
)

country_region_num_df = session.create_dataframe(country_region_num_df)
region_name_df = session.create_dataframe(region_name_df)

# NOTE(review): JSON object keys are always strings, so REGION_NUM in
# region_name_df is text. If country_region_num.json stores region numbers as
# JSON numbers, this join compares NUMBER to VARCHAR and presumably relies on
# Snowflake's implicit cast — confirm both files agree on the type.
country_region_df = country_region_num_df.join(
    region_name_df,
    "REGION_NUM",
    how="inner",
).select(
    F.col("COUNTRY_CODE"),
    F.col("REGION_NAME"),
)

In [None]:
def _read_user_list(path: str) -> List[str]:
    """Return the non-blank lines of ``path`` with surrounding whitespace removed.

    Stripping (instead of only removing a trailing newline) also drops the
    carriage return left by Windows (CRLF) line endings, which would otherwise
    prevent usernames from matching CREATE_USER exactly. The file is closed
    on exit.
    """
    with open(path, "r") as user_file:
        return [line.strip() for line in user_file if line.strip()]


# Single-column DataFrames used as IN-subqueries by filter_raw_results.
whitelisted_users_school = session.create_dataframe(_read_user_list("school_users.txt"))
whitelisted_users_mystery = session.create_dataframe(_read_user_list("mystery_users.txt"))


In [None]:
field_raw = session.table('AATT_RAW.PUBLIC.FIELD_RAW')

def fix_raw_results(raw_results_df: snowpark.DataFrame) -> snowpark.DataFrame:
    """Normalise legacy result rows whose METHOD is a two-letter code.

    Rows with METHOD 'AE', 'AV' or 'AI' get METHOD rewritten to
    'Estimation', 'Verification' or 'Inference', and AGE_GATE set to:
      * 0  for 'AE' rows whose provider (NAME) is not one of
           Arissian, PrivateId, Privately, Yoti;
      * 16 for every other legacy row.
    All other rows, and all other columns, pass through unchanged.

    The result is a lazy DataFrame; nothing is written back to the source
    table, and METHOD is read from ``raw_results_df`` itself.

    Args:
        raw_results_df: raw results; must contain METHOD, NAME and AGE_GATE.

    Returns:
        DataFrame with the same columns and legacy rows normalised.
    """
    abbrev_lookup = {"AE": "Estimation", "AV": "Verification", "AI": "Inference"}
    # Legacy 'AE' rows from these providers are forced to gate 16 like AV/AI.
    gate16_estimation_providers = ("Arissian", "PrivateId", "Privately", "Yoti")

    method = F.col("METHOD")
    is_legacy = method.isin(tuple(abbrev_lookup))

    # AGE_GATE is derived first because it needs the original two-letter METHOD.
    legacy_age_gate = F.when(
        (method == "AE") & ~F.col("NAME").isin(gate16_estimation_providers),
        F.lit(0),
    ).otherwise(
        F.lit(16)
    )
    fixed_df = raw_results_df.with_column(
        "AGE_GATE",
        F.when(is_legacy, legacy_age_gate).otherwise(F.col("AGE_GATE")),
    )

    full_method_name = (
        F.when(method == "AE", F.lit(abbrev_lookup["AE"]))
        .when(method == "AV", F.lit(abbrev_lookup["AV"]))
        .when(method == "AI", F.lit(abbrev_lookup["AI"]))
        .otherwise(method)
    )
    return fixed_df.with_column("METHOD", full_method_name)

def filter_raw_results(raw_results_df: snowpark.DataFrame) -> snowpark.DataFrame:
    """Label each raw result with its test cohort and keep only labelled rows.

    CREATE_USER in the school whitelist -> "School"; otherwise in the mystery
    whitelist -> "Mystery Shoppers"; otherwise the row is dropped.
    RESULT_TIME is renamed TIMESTAMP and NAME is renamed PROVIDER.
    """
    creator = raw_results_df["CREATE_USER"]
    cohort = (
        F.when(creator.isin(whitelisted_users_school), "School")
        .when(creator.isin(whitelisted_users_mystery), "Mystery Shoppers")
        .otherwise(None)
    )
    labelled_df = raw_results_df.with_column("TEST_TYPE", cohort)

    kept_columns = [
        F.col("SUBJECT_ID"),
        F.col("TEST_TYPE"),
        F.col("RESULT_TIME").alias("TIMESTAMP"),
        F.col("NAME").alias("PROVIDER"),
        F.col("METHOD"),
        F.col("AGE_GATE"),
        F.col("VERIFICATION_TIME"),
        F.col("VERIFICATION_DATA"),
    ]
    return labelled_df.select(*kept_columns).where(F.col("TEST_TYPE").isNotNull())


In [None]:
@udf(
    replace = True,
    is_permanent = False,
    input_types = [
                StringType(),
                IntegerType(),
                StringType(),
                StringType()
    ],
    return_type = MapType(StringType(), FloatType()),
    session = session
)
def parse_raw_provider_result(method, age_gate, vendor, result_str):
    age_gate = None if age_gate is None else int(age_gate)
    match vendor:

        case 'IDmission':
            if method=="Estimation":
                try:
                    verification_data = json.loads(result_str)
                    estimated_age = round(float(verification_data[0]['resultData']['estimatedAge']), 2)
                    confidence = round(float(verification_data[0]['resultData']['realScore']), 2)
                    return {'estimate': estimated_age} if estimated_age > 0 else None
                except Exception as err:
                    return None
            
            if method=="Verification":
                try:
                    verification_data = json.loads(result_str)
                    ver_result_string = verification_data[0]['resultData']['verificationResult']
                    
                    if ver_result_string == "Approved":
                        ver_result = 1
                    elif ver_result_string == "Under the Age of 18":
                        ver_result = 0
                    else:
                        ver_result = 0
                    return {'18': ver_result}
                except Exception as err:
                    return None

        case 'PrivateId':
            try:
                verification_data = json.loads(result_str)
                estimated_age = round(float(verification_data['age']), 2)
                return {'estimate': estimated_age} if estimated_age > 0 else None
            except Exception as err:
                return None

        case 'VerifyChain':
            try:
                verification_data = json.loads(result_str)
                estimated_age = round(float(verification_data['estimatedAge']), 2)
                confidence = round(float(verification_data['confidence']), 2)
                return {'estimate': estimated_age} if estimated_age > 0 else None
            except Exception as err:
                return None

        case 'IDVerse':
            try:
                verification_data = json.loads(result_str)

                estimated_age = None
                for doc in verification_data['results']['documents']:
                    if estimated_age is None:
                        estimated_age = round(float(doc['calculatedData']['age']), 2)
                    else:
                        return None
                return {'estimate': estimated_age} if estimated_age > 0 else None
            except Exception as err:
                return None

        case 'Unissey':
            try:
                verification_data = json.loads(result_str)
                estimated_age = round(float(verification_data["data"]["details"]["age"]["age_estimation"]["estimated_age"]), 2)
                age_range = verification_data["data"]["details"]["age"]["age_estimation"]["age_range"]
                return {'estimate': estimated_age} if estimated_age > 0 else None
            except Exception as err:
                return None

        case 'Persona':
            if method=="Estimation":
                try:
                    verification_data = json.loads(result_str)

                    estimated_age = round(float(verification_data['data']['attributes']['fields']['selfie-estimated-age']['value']), 2)
                    return {'estimate': estimated_age} if estimated_age > 0 else None

                except Exception as err:
                    return None
            
            if method=="Verification":
                try:
                    verification_data = json.loads(result_str)
                    verified_docs = set()
                    for included_obj in verification_data['included']:
                        if included_obj['type'].startswith('verification/') and included_obj["attributes"]["status"] == "passed":
                            verified_docs.add(included_obj['type'])
                    if len(verified_docs) >= 2:
                        ver_result = 1
                    else:
                        ver_result = 0
                    return {'18': ver_result}

                except Exception as err:
                    return None

        case 'Yoti':
            try:
                verification_data = json.loads(result_str)
                estimated_age = round(float(verification_data["age"]), 2)
                return {'estimate': estimated_age} if estimated_age > 0 else None
            except Exception as err:
                return None

        case 'Rigr-AI':
            try:
                verification_data = json.loads(result_str)
                estimated_age = round(float(verification_data["age"]), 2)
                uncertainty = round(float(verification_data["uncertainty"]), 2)
                return {'estimate': estimated_age} if estimated_age > 0 else None
                
            except Exception as err:
                try:
                    verification_data = json.loads(result_str)
                    estimated_age = round(float(verification_data['response']['results'][0]['results'][0]['age']), 2)
                    uncertainty = round(float(verification_data['response']['results'][0]['results'][0]["uncertainty"]), 2)
                    return {'estimate': estimated_age} if estimated_age > 0 else None
                except Exception as err:
                    return None

        case 'Arissian':
            try:
                verification_data = json.loads(result_str)
                msgData = json.loads(verification_data['MsgData'])

                info_string = msgData['Info']
                age_gate_match = re.match(r".* (target age of (\d+)) .*", info_string)
                age_gate = int(age_gate_match.group(2))

                pass_match = re.match(r".* (\[AE_(PASS|FAIL)\])", info_string)
                bool_check = 1 if pass_match.group(2) == 'PASS' else 0

                #confidence_match = re.match(r"(.*) confidence .*", info_string)
                #confidence_level = confidence_match.group(1)

                return {str(age_gate): bool_check}

            except Exception:
                return None

        case 'Needemand':
            try:
                verification_data = json.loads(result_str)
                if age_gate is not None:
                    if verification_data['result'] == '1':
                        return {str(age_gate): 1}
                    elif verification_data['result'] == '0':
                        return {str(age_gate): 0}
                    else:
                        return None
                else:
                    return None
            except Exception:
                return None

        case 'ShareRing':
            try:
                verification_data = json.loads(result_str)
                ver_msg = verification_data["qrRes"]
                age_gate_match_pos = re.match(r"\nYes, I am (\d+) or over", ver_msg)
                if age_gate_match_pos:
                    age_gate = str(age_gate_match_pos.group(1))
                    return {age_gate: 1}
                age_gate_match_neg = re.match(r"\nNo, I'm not over (\d+)", ver_msg)
                if age_gate_match_neg:
                    age_gate = str(age_gate_match_neg.group(1))
                    return {age_gate: 0}
                return None
            except Exception:
                return None

        case 'RightCrowd':
            try:
                verification_data = json.loads(result_str)
                age_gate_pass = 1 if verification_data['result'] == "True" else 0
                
                return {str(verification_data['ageThreshold']): age_gate_pass}
            except Exception as err:
                return None

        case 'MyMahi':
            try:
                verification_data = json.loads(result_str)
                age_gate_results = verification_data['age_equal_or_over']
                return {str(age_gate): 1 if age_gate_results[age_gate] else 0 for age_gate in age_gate_results}
            except Exception as err:
                return None

        case 'VerifyMy':
            try:
                verification_data = json.loads(result_str)
                return {str(age_gate): 1 if verification_data['age_verified'] else 0}
            except Exception as err:
                return None

        case 'Privately':
            try:
                verification_data = json.loads(result_str)
                return {str(verification_data["age"]): 1 if verification_data["rlt"] else 0}
            except Exception as err:
                return None

        case _:
            return None

In [None]:
@pandas_udf(
    return_type = PandasSeriesType(StringType()),
    input_types = [
        PandasSeriesType(IntegerType()),
        PandasSeriesType(FloatType()),
        PandasSeriesType(MapType(StringType(), BooleanType())),
    ],
    is_permanent = False,
    replace = True
)
def format_verification_result(
    age_gate: pd.Series,
    estimated_age: pd.Series,
    dict_over_age_gate: pd.Series
) -> pd.Series:
    """Render each result as display text.

    Rows with an estimated age become that age as a string. Every other row
    looks its gate up in ``dict_over_age_gate`` and becomes "over N" or
    "under N"; gate 0 is looked up under the key '18'. Rows with no map, no
    gate, or no entry for the gate become None.
    """

    def _gate_label(gate, over_flags):
        # No parsed pass/fail map for this row.
        if pd.api.types.is_scalar(over_flags) and pd.isna(over_flags):
            return None
        # Without a gate there is nothing to look up; evaluating `pd.NA == 0`
        # in a condition used to raise and fail the whole batch.
        if pd.isna(gate):
            return None
        gate_key = "18" if gate == 0 else str(int(gate))
        passed = over_flags.get(gate_key, None)
        if passed is None:
            return None
        return f"over {gate_key}" if passed else f"under {gate_key}"

    has_estimate = estimated_age.notna()
    result_series = pd.Series(pd.NA, index=estimated_age.index, dtype="object")
    result_series.loc[has_estimate] = estimated_age[has_estimate].astype('str')

    # Only rows without an estimate need the (row-wise) gate lookup.
    needs_label = ~has_estimate
    if needs_label.any():
        result_series.loc[needs_label] = [
            _gate_label(gate, over_flags)
            for gate, over_flags in zip(age_gate[needs_label], dict_over_age_gate[needs_label])
        ]

    return result_series

In [None]:
def parse_raw_results(filtered_raw_results_df: snowpark.DataFrame) -> snowpark.DataFrame:
    """Parse each provider payload into an estimated age or a pass/fail map.

    Returns SUBJECT_ID, PROVIDER, METHOD, AGE_GATE, ESTIMATED_AGE (the
    'estimate' entry of the parsed map, or NULL) and DICT_OVER_AGE_GATE (the
    parsed map, kept only when there is no estimate).
    """
    # VERIFICATION_TIME is not part of this output, so it is not parsed here;
    # the former parse_verification_time call was discarded by the select below.
    parsed_df = filtered_raw_results_df.select(
        F.col("SUBJECT_ID"),
        F.col("PROVIDER"),
        F.col("METHOD"),
        F.col("AGE_GATE"),
        F.col("VERIFICATION_DATA")
    )

    parsed_df = parsed_df.with_column(
        "PARSED_RESULT",
        parse_raw_provider_result(
            parsed_df['METHOD'],
            parsed_df['AGE_GATE'],
            parsed_df['PROVIDER'],
            parsed_df['VERIFICATION_DATA']
        )
    )

    parsed_df = parsed_df.with_column(
        "ESTIMATED_AGE",
        F.get(parsed_df["PARSED_RESULT"], F.lit("estimate"))
    )

    return parsed_df.select(
        F.col("SUBJECT_ID"),
        F.col("PROVIDER"),
        F.col("METHOD"),
        F.col("AGE_GATE"),
        F.col("ESTIMATED_AGE"),
        F.when(
            F.col("ESTIMATED_AGE").isNull(),
            F.col("PARSED_RESULT")
        ).otherwise(None).alias("DICT_OVER_AGE_GATE")
    )


def write_verification_results_to_table(
    filtered_raw_results_df: snowpark.DataFrame,
    verification_results_df: snowpark.DataFrame
) -> snowpark.DataFrame:
    """Materialise filtered raw results and fill in their parsed values.

    The metadata columns of ``filtered_raw_results_df`` are written
    (overwriting) to the temporary table ``filtered_raw_results``. In that
    table, VERIFICATION_TIME is converted to seconds by
    parse_verification_time, and ESTIMATED_AGE / DICT_OVER_AGE_GATE start as
    NULL. ESTIMATED_AGE and DICT_OVER_AGE_GATE are then MERGEd in from
    ``verification_results_df`` on (SUBJECT_ID, PROVIDER, METHOD, AGE_GATE).

    NOTE(review): the MERGE key has no timestamp, so a subject that used the
    same provider/method/gate twice yields several matching source rows,
    which Snowflake may reject as a nondeterministic update — confirm.
    Rows with NULL AGE_GATE never match (NULL = NULL is not true).

    Returns:
        The temporary ``filtered_raw_results`` table after the MERGE.
    """
    staged_df = filtered_raw_results_df.select(
        F.col("SUBJECT_ID"), 
        F.col("TEST_TYPE"), 
        F.col("TIMESTAMP"), 
        F.col("PROVIDER"), 
        F.col("METHOD"), 
        F.col("AGE_GATE"), 
        # Store seconds rather than the raw "HH:MM:SS" text.
        parse_verification_time(F.col("VERIFICATION_TIME")).alias("VERIFICATION_TIME"),
    ).with_column(
        "ESTIMATED_AGE", 
        F.lit(None).cast(FloatType())
    ).with_column(
        "DICT_OVER_AGE_GATE",
        F.lit(None).cast(MapType(StringType(), BooleanType()))
    )
    
    staged_df.write.save_as_table("filtered_raw_results", mode="overwrite", table_type="temporary")
    # MERGE is only available on Table objects, hence the re-read.
    results_table = session.table('filtered_raw_results')
    
    results_table.merge(
        source = verification_results_df,
        join_expr = (
            (results_table['SUBJECT_ID'] == verification_results_df['SUBJECT_ID']) &
            (results_table['PROVIDER'] == verification_results_df['PROVIDER']) &
            (results_table['METHOD'] == verification_results_df['METHOD']) &
            (results_table['AGE_GATE'] == verification_results_df['AGE_GATE'])
        ),
        clauses = [
            F.when_matched().update(
                {
                    "ESTIMATED_AGE": verification_results_df["ESTIMATED_AGE"],
                    "DICT_OVER_AGE_GATE": verification_results_df["DICT_OVER_AGE_GATE"]
                }
            )
        ]
    )
    
    return results_table
    


def split_results_by_method(filtered_parsed_results_df: snowpark.DataFrame) -> snowpark.DataFrame:
    """Format each result and spread it into a per-method column.

    FORMATTED_RESULT (from format_verification_result) is copied into
    AE_RESULT, AV_RESULT or AI_RESULT according to METHOD; the other two stay
    NULL.

    Returns:
        SUBJECT_ID, TEST_TYPE, TIMESTAMP, PROVIDER, METHOD, AGE_GATE,
        AE_RESULT, AV_RESULT, AI_RESULT, VERIFICATION_TIME and ESTIMATED_AGE.
    """
    method_by_result_column = {
        "AE_RESULT": "Estimation",
        "AV_RESULT": "Verification",
        "AI_RESULT": "Inference",
    }

    formatted_df = filtered_parsed_results_df.with_column(
        "FORMATTED_RESULT",
        format_verification_result(
            F.col("AGE_GATE"),
            F.col("ESTIMATED_AGE"),
            F.col("DICT_OVER_AGE_GATE")
        )
    )

    # with_columns (plural) is required to add several columns at once;
    # with_column accepts only a single name/column pair.
    split_df = formatted_df.with_columns(
        list(method_by_result_column),
        [
            F.when(F.col("METHOD") == method, F.col("FORMATTED_RESULT")).otherwise(None)
            for method in method_by_result_column.values()
        ]
    )

    return split_df.select(
        F.col("SUBJECT_ID"),
        F.col("TEST_TYPE"), 
        F.col("TIMESTAMP"), 
        F.col("PROVIDER"),
        F.col("METHOD"),
        F.col("AGE_GATE"),
        F.col("AE_RESULT"),
        F.col("AV_RESULT"),
        F.col("AI_RESULT"),
        F.col("VERIFICATION_TIME"),
        # Kept so downstream steps can compute ABS_ERROR against the real age.
        F.col("ESTIMATED_AGE")
    )

In [None]:
subjects_df = session.table('AATT_RAW.PUBLIC.SUBJECTS')

def clean_subjects_table(subjects_df: snowpark.DataFrame) -> snowpark.DataFrame:
    """Derive per-subject background region and age columns.

    Returns:
        SUBJECT_ID (from ID);
        BACKGROUND — "First Nations", otherwise the REGION_NAME of the derived
            background country via country_region_df (NULL when unmatched);
        REAL_AGE_FLOAT — AGE_IN_MONTHS / 12 rounded to 2 dp, only for
            0 < AGE_IN_MONTHS < 300;
        SUBJECT_AGE — ">=25", "<10", the whole-year age as text, or NULL for
            non-positive / missing ages.
    """
    months = F.col("AGE_IN_MONTHS")
    plausible_age = (months > 0) & (months < 25*12)

    subjects_df = subjects_df.with_column(
        "BACKGROUND",
        fill_subject_background(
            subjects_df['ORIGIN'],
            subjects_df['COUNTRY_OF_BIRTH_SUBJECT'],
            subjects_df['COUNTRY_OF_BIRTH_FATHER'],
            subjects_df['COUNTRY_OF_BIRTH_MOTHER']
        )
    )
    
    subjects_df = subjects_df.join(
        country_region_df,
        subjects_df['BACKGROUND'] == country_region_df['COUNTRY_CODE'],
        how='left'
    )
    
    # Map countries to their region; "First Nations" is kept as-is.
    subjects_df = subjects_df.with_column(
        "BACKGROUND",
        F.when(
            F.col("BACKGROUND") == "First Nations",
            "First Nations"
        ).otherwise(
            F.col("REGION_NAME")
        )
    )
    
    subjects_df = subjects_df.with_column(
        "REAL_AGE_FLOAT",
        F.when(
            plausible_age,
            F.bround(months/12, 2)
        ).otherwise(
            None
        )
    )
    
    subjects_df = subjects_df.with_column(
        "SUBJECT_AGE",
        F.when(
            months >= 25*12,
            ">=25"
        ).when(
            # Require a positive age so 0 / negative placeholders are not
            # bucketed as "<10" (they have no REAL_AGE_FLOAT either).
            (months > 0) & (months < 10*12),
            "<10"
        ).when(
            plausible_age,
            F.cast(
                F.cast(
                    F.floor(F.col("REAL_AGE_FLOAT")),
                    IntegerType()
                ), 
                StringType()
            )
        ).otherwise(
            None
        )
    )
    
    return subjects_df.select(
        F.col("ID").alias("SUBJECT_ID"),
        F.col("BACKGROUND"),
        F.col("REAL_AGE_FLOAT"),
        F.col("SUBJECT_AGE")
    )

In [None]:
# NOTE(review): `filtered_field_result` is not created by any earlier cell in
# this notebook; it presumably comes from the fix/filter/parse/write/split
# functions defined above — define it before running this cell.

# Persist the *cleaned* subjects: downstream joins use SUBJECT_ID, BACKGROUND,
# REAL_AGE_FLOAT and SUBJECT_AGE, which clean_subjects_table produces (raw
# SUBJECTS exposes ID, not SUBJECT_ID).
processed_subjects_df = clean_subjects_table(subjects_df)

filtered_field_result.write.mode("overwrite").save_as_table("filtered_field_result", table_type="temp")
processed_subjects_df.write.mode("overwrite").save_as_table("processed_subjects", table_type="temp")

print(filtered_field_result.count())
print(processed_subjects_df.count())

In [None]:
# Re-read the materialised temporary tables so later cells start from them.
filtered_field_result, processed_subjects = (
    session.table(table_name)
    for table_name in ("filtered_field_result", "processed_subjects")
)




In [None]:
def create_cleansed_table(
    filtered_parsed_results_df: snowpark.DataFrame,
    processed_subjects_df: snowpark.DataFrame
) -> snowpark.DataFrame:
    """Join parsed results to subject data and keep in-scope rows.

    Rows are kept for School subjects aged 11 <= REAL_AGE_FLOAT < 18, and for
    Mystery Shoppers with REAL_AGE_FLOAT > 0. ABS_ERROR is
    |REAL_AGE_FLOAT - ESTIMATED_AGE|.

    NOTE(review): requires ESTIMATED_AGE in ``filtered_parsed_results_df``.
    Subjects without a REAL_AGE_FLOAT (e.g. the ">=25" bucket) are dropped by
    the cohort filter — confirm that is intended.
    """
    real_age = F.col("REAL_AGE_FLOAT")
    test_type = F.col("TEST_TYPE")

    in_school_cohort = (real_age >= 11) & (real_age < 18) & (test_type == "School")
    in_mystery_cohort = (real_age > 0 ) & (test_type == "Mystery Shoppers")

    output_columns = [
        "SUBJECT_ID",
        "SUBJECT_AGE",
        "BACKGROUND",
        "TEST_TYPE",
        "TIMESTAMP",
        "PROVIDER",
        "METHOD",
        "AGE_GATE",
        "AE_RESULT",
        "AV_RESULT",
        "AI_RESULT",
        "ABS_ERROR",
        "VERIFICATION_TIME",
    ]

    joined_df = filtered_parsed_results_df.join(
        processed_subjects_df,
        "SUBJECT_ID",
        how='inner'
    )
    with_error_df = joined_df.with_column(
        "ABS_ERROR",
        F.abs(real_age - F.col("ESTIMATED_AGE"))
    )
    in_scope_df = with_error_df.where(in_school_cohort | in_mystery_cohort)

    return in_scope_df.select(*[F.col(name) for name in output_columns])

In [None]:
@udf(
    replace = True,
    is_permanent = False,
    input_types = [
                StringType(),
                StringType(),
                IntegerType(),
                IntegerType()
    ],
    return_type = StringType()
)
def gate_check(real_age,
               estimated_age,
               result_age_gate,
               query_age_gate
):
    """Classify one result against one query age gate as TP / FP / TN / FN.

    "Positive" means the result says the subject is at or over
    ``query_age_gate``; the label is True/False depending on whether the real
    age agrees.

    Args:
        real_age: subject age as text — a number, ">=N" (treated as N) or
            "<N" (treated as N).
        estimated_age: a numeric age estimate, or "over N" / "under N".
        result_age_gate: gate the result was produced for; 0 means the result
            may be checked against any gate.
        query_age_gate: gate being evaluated.

    Returns:
        "TP", "FP", "TN", "FN", or None when the result does not apply to this
        gate or an input cannot be interpreted.
    """

    def positive_check(query, threshold):
        return "TP" if query >= threshold else "FP"

    def negative_check(query, threshold):
        return "TN" if query < threshold else "FN"

    # A result produced for a specific gate only answers that gate.
    if result_age_gate != 0 and query_age_gate != result_age_gate:
        return None

    try:
        if type(real_age) is str:
            if real_age.startswith(">="):
                real_age = float(real_age[2:])
            elif real_age.startswith("<"):
                real_age = float(real_age[1:])
            else:
                real_age = pd.to_numeric(real_age, errors='raise', downcast='float')
        else:
            real_age = pd.to_numeric(real_age, errors='raise', downcast='float')
    except Exception:
        return None
    # A missing age converts to NaN, which would silently compare as FP/FN.
    if pd.isna(real_age):
        return None

    if pd.isna(estimated_age):
        return None

    try:
        estimated_age_float = pd.to_numeric(estimated_age, errors='raise', downcast='float')
    except (ValueError, TypeError):
        # Not a number: expect a binary answer such as "over 18" / "under 16".
        answer = estimated_age.split(' ')
        if len(answer) < 2:
            return None
        try:
            answered_gate = int(answer[1])
        except ValueError:
            return None
        if query_age_gate != answered_gate:
            return None
        if answer[0] == "over":
            return positive_check(real_age, query_age_gate)
        if answer[0] == "under":
            return negative_check(real_age, query_age_gate)
        return None

    if estimated_age_float >= query_age_gate:
        return positive_check(real_age, query_age_gate)
    return negative_check(real_age, query_age_gate)


In [None]:
def filter_by_method_type(
    cleansed_df: snowpark.DataFrame,
    method:str = None,
    test_type:str = None,
    discard_outliers:bool = True
) -> snowpark.DataFrame:
    """Optionally null out ABS_ERROR outliers, then filter by method and test type.

    Args:
        cleansed_df: must contain ABS_ERROR, METHOD and TEST_TYPE.
        method: "Estimation", "Verification" or "Inference"; None keeps all.
        test_type: TEST_TYPE value to keep; None keeps all.
        discard_outliers: when True, ABS_ERROR is cast to float and set to NULL
            where it is NULL/NaN or outside mean +/- 2*stddev. The statistics
            are computed over the whole input (before the method/test-type
            filters) with one eager query. Rows are never dropped here.

    Raises:
        ValueError: if ``method`` is not one of the known methods.
    """
    abbrev_lookup = {"AE": "Estimation", "AV": "Verification", "AI": "Inference"}
    rev_abbrev_lookup = {v: k for k, v in abbrev_lookup.items()}

    if discard_outliers:
        abs_error_float = F.col('ABS_ERROR').cast(FloatType())

        # Overall statistics across all rows of the input.
        stats_row = cleansed_df.agg(
            F.avg(abs_error_float).alias("overall_mean"),
            F.stddev(abs_error_float).alias("overall_stdev")
        ).collect()[0]
        mean = stats_row["OVERALL_MEAN"]
        stdev = stats_row["OVERALL_STDEV"]

        # Build the condition in Python instead of mixing Python bools into
        # Column expressions with `&`.
        discard = F.is_null(abs_error_float) | F.equal_nan(abs_error_float)
        if mean is not None and stdev is not None:
            lower_bound = mean - 2*stdev
            upper_bound = mean + 2*stdev
            discard = discard | (abs_error_float < lower_bound) | (abs_error_float > upper_bound)

        cleansed_df = cleansed_df.with_column(
            "ABS_ERROR",
            F.when(discard, None).otherwise(abs_error_float)
        )

    if method is not None:
        if method not in rev_abbrev_lookup:
            raise ValueError(f"Invalid method: {method}. Must be one of {list(rev_abbrev_lookup.keys())}")
        cleansed_df = cleansed_df.where(F.col("METHOD") == method)

    if test_type is not None:
        cleansed_df = cleansed_df.where(F.col("TEST_TYPE") == test_type)
    
    return cleansed_df

In [None]:
def generate_check_age_gates(
    df: snowpark.DataFrame,
    age_gates: List[int],
    method:str = None
) -> snowpark.DataFrame:
    """Add one CHECK<gate> column per age gate using gate_check.

    Args:
        df: cleansed results; needs SUBJECT_AGE, AGE_GATE and the
            <abbrev>_RESULT column for ``method``, plus SUBJECT_ID, PROVIDER
            and ABS_ERROR, which are passed through.
        age_gates: gates to evaluate (duplicates are ignored).
        method: "Estimation", "Verification" or "Inference".

    Returns:
        SUBJECT_AGE, CHECK<gate>..., SUBJECT_ID, PROVIDER, AGE_GATE, ABS_ERROR.
        The trailing identifying columns are the ones group_aggregate reads;
        they used to be dropped here.

    Raises:
        AssertionError: if ``method`` is not one of the known methods.
    """
    abbrev_lookup = {"AE": "Estimation", "AV": "Verification", "AI": "Inference"}
    rev_abbrev_lookup = {v: k for k, v in abbrev_lookup.items()}

    if method not in rev_abbrev_lookup:
        raise AssertionError("No method selected")

    result_column = F.col(f"{rev_abbrev_lookup[method]}_RESULT")
    gates = list(dict.fromkeys(age_gates))
    check_columns = [f"CHECK{gate}" for gate in gates]

    if gates:
        df = df.with_columns(
            check_columns,
            [
                gate_check(
                    F.col("SUBJECT_AGE"),
                    result_column,
                    F.col("AGE_GATE"), F.lit(gate)
                )
                for gate in gates
            ]
        )

    return df.select(
        F.col("SUBJECT_AGE"),
        *[F.col(name) for name in check_columns],
        F.col("SUBJECT_ID"),
        F.col("PROVIDER"),
        F.col("AGE_GATE"),
        F.col("ABS_ERROR")
    )

In [None]:
def group_aggregate(
    groupby_columns: List[str],
    df: snowpark.DataFrame,
    age_gates: List[int] = None
) -> snowpark.DataFrame:
    """Aggregate per-gate confusion counts and absolute-error statistics.

    Each row of ``df`` is exploded into one row per CHECK<gate> column, keeping
    only the row's own AGE_GATE (or every gate when AGE_GATE is 0). The rows
    are then grouped by AGE_GATE plus ``groupby_columns``.

    Args:
        groupby_columns: grouping columns in addition to AGE_GATE.
        df: must contain AGE_GATE, ABS_ERROR, the grouping columns and one
            CHECK<gate> column per gate holding "TP"/"TN"/"FP"/"FN" or NULL.
        age_gates: gates to explode; defaults to every CHECK<n> column found
            in ``df`` (previously the notebook-global list was used).

    Returns:
        AGE_GATE, the grouping columns, SAMPLES, FPR, FNR, TPR, TNR, ACCURACY
        (percentages; NULL when the denominator is 0), MAE and
        ABSOLUTE_ERROR_STDEV, sorted by AGE_GATE then the grouping columns.

    Raises:
        ValueError: if no gates are given or found.
    """
    if age_gates is None:
        age_gates = [
            int(name[len("CHECK"):]) for name in df.columns if re.fullmatch(r"CHECK\d+", name)
        ]
    if not age_gates:
        raise ValueError("group_aggregate: no CHECK<gate> columns to aggregate")

    dimension_columns = list(dict.fromkeys(c for c in groupby_columns if c != "AGE_GATE"))
    key_columns = ["AGE_GATE"] + dimension_columns

    # Explode vertically: one (KEY=gate, VALUE=check label) row per CHECK column.
    gate_map = F.create_map(
        [item for ag in age_gates for item in (F.lit(f"{ag}"), f"CHECK{ag}")]
    )
    df_long = df.flatten(gate_map).where(
        (F.col("AGE_GATE") == 0) | (F.col("KEY") == F.col("AGE_GATE"))
    ).select(
        *[F.col(name) for name in dimension_columns],
        F.col("KEY").alias("AGE_GATE"),
        F.col("VALUE").alias("CHECK"),
        F.col("ABS_ERROR")
    )

    def _count(label):
        return F.sum(F.when(F.col("CHECK") == label, 1).otherwise(0))

    def _percent(numerator, denominator):
        return F.when(denominator == 0, None).otherwise(100 * numerator / denominator)

    # One aggregation pass: no self-join, so NULL group keys keep their MAE.
    aggregated_df = df_long.group_by(*key_columns).agg(
        _count("TP").alias("TP"),
        _count("TN").alias("TN"),
        _count("FP").alias("FP"),
        _count("FN").alias("FN"),
        F.avg("ABS_ERROR").alias("MAE"),
        F.stddev("ABS_ERROR").alias("ABSOLUTE_ERROR_STDEV")
    )

    tp, tn, fp, fn = (F.col(name) for name in ("TP", "TN", "FP", "FN"))
    total = tp + tn + fp + fn
    aggregated_df = aggregated_df.with_columns(
        ["SAMPLES", "FPR", "FNR", "TPR", "TNR", "ACCURACY"],
        [
            total,
            _percent(fp, tn + fp),
            _percent(fn, tp + fn),
            _percent(tp, tp + fn),
            _percent(tn, tn + fp),
            _percent(tp + tn, total)
        ]
    )

    return aggregated_df.select(
        *[F.col(name) for name in key_columns],
        F.col("SAMPLES"),
        F.col("FPR"),
        F.col("FNR"),
        F.col("TPR"),
        F.col("TNR"),
        F.col("ACCURACY"),
        F.col("MAE"),
        F.col("ABSOLUTE_ERROR_STDEV")
    ).sort(*key_columns)


In [None]:
def create_aggregate_table(cleansed_table, method, test_type, age_gates):
    """Filter, score and aggregate results for one method and test type.

    Pipeline: filter_by_method_type -> generate_check_age_gates ->
    group_aggregate grouped by PROVIDER and SUBJECT_AGE.
    """
    scoped_df = filter_by_method_type(cleansed_table, method=method, test_type=test_type)
    checked_df = generate_check_age_gates(scoped_df, age_gates, method=method)
    return group_aggregate(["PROVIDER", "SUBJECT_AGE"], checked_df)

In [None]:
# `cleansed_field` was never defined in this notebook; build it from the
# tables re-read above.
# NOTE(review): if AATT_CLEANSED.PUBLIC.FIELD_CLEANSED (`field_table`) was the
# intended input instead, swap it in here.
cleansed_field = create_cleansed_table(filtered_field_result, processed_subjects)
create_aggregate_table(cleansed_field, "Estimation", "School", age_gates)

In [None]:
import matplotlib.pyplot as plt
from reportlab.platypus import SimpleDocTemplate, Paragraph, Image, Spacer
from reportlab.lib.styles import getSampleStyleSheet
from reportlab.lib.units import inch