In [1]:
### Dependencies ###

# Local Processing
import fireducks.pandas as pd
import numpy as np
import plotly.graph_objects as go


# Snowpark for Python
from snowflake.snowpark import Session
from snowflake.snowpark.version import VERSION
from snowflake.snowpark.functions import col, count, when, mean, lit, corr
from snowflake.snowpark.types import StringType, LongType, DecimalType

# Workflow
import json

# Python
from collections import defaultdict as dd

In [2]:
### Secured connection to Snowflake ###
# Connection parameters are read from a local JSON file; the context manager
# closes the file handle as soon as the parameters have been parsed.
with open('connection.json') as connection_file:
    connection_parameters = json.load(connection_file)
session = Session.builder.configs(connection_parameters).create()
session.sql_simplifier_enabled = True

# Fully qualified name of the raw input table
database = 'FRAUD_DETECT_DB'
schema = 'FRAUD_DETECT_SM'
table = 'FRAUD_DATA'
input_tbl = f"{database}.{schema}.{table}"
fraud_data = session.table(input_tbl)
#fraud_data.show()

Data prep tasks to perform:
    - Missing Values
        - Drop column "prev_address_months_count"
        - Replace all "-1" occurrences in 'current_address_months_count' with the mean value
    - Balance dataset
        - Oversample lines where fraud_bool = 1
    - Eliminate highly correlated features

In [3]:
### Missing values

def missing_values_snowpark(session, table_name):
    """
    Report the number of missing values in each column of a Snowflake table (Snowpark Version)

    A value is counted as missing when it is NULL. In addition, empty strings are
    counted as missing in string columns, and -1 is counted as missing in the
    PREV_ADRESS_MONTH_COUNT and CURRENT_ADRESS_MONTH_COUNT columns.

    Parameters:
    - session (Session) : current session
    - table_name (str): Name of the table to check for missing values

    Output:
    - None (the per-column counts and a final summary are printed)
    """
    df = session.table(table_name)

    missing_values = False
    for field in df.schema.fields:
        column_name = field.name
        column = col(column_name)

        # field.datatype is an instance of the type class, so it has to be tested
        # with isinstance; comparing it to the class itself with == never matches.
        if isinstance(field.datatype, StringType):
            # String Columns
            is_missing = column.is_null() | (column == "")
        elif column_name in ('PREV_ADRESS_MONTH_COUNT', 'CURRENT_ADRESS_MONTH_COUNT'):
            # Special cases
            is_missing = column.is_null() | (column == -1)
        else:
            # Numerical columns
            is_missing = column.is_null()

        n_missing_values = df.filter(is_missing).count()
        print(f"Feature {column_name} lacks {n_missing_values} values")

        # Remember that at least one column has gaps so the summary below is accurate
        if n_missing_values > 0:
            missing_values = True

    if missing_values:
        print("Values are missing, please check details!")
    else:
        print("No missing values, great!")
    

In [4]:
# Print the per-column missing-value report for the raw input table
missing_values_snowpark(session,input_tbl)

Feature FRAUD_BOOL lacks 0 values
Feature INCOME lacks 0 values
Feature NAME_EMAIL_SIMILARITY lacks 0 values
Feature CUSTOMER_AGE lacks 0 values
Feature DAYS_SINCE_REQUEST lacks 0 values
Feature INTENDED_BALCON_AMOUNT lacks 0 values
Feature PAYMENT_TYPE lacks 0 values
Feature ZIP_COUNT_4W lacks 0 values
Feature VELOCITY_6H lacks 0 values
Feature VELOCITY_24H lacks 0 values
Feature VELOCITY_4W lacks 0 values
Feature BANK_BRANCH_COUNT_8W lacks 0 values
Feature DATE_OF_BIRTH_DISTINCT_EMAILS_4W lacks 0 values
Feature EMPLOYMENT_STATUS lacks 0 values
Feature CREDIT_RISK_SCORE lacks 0 values
Feature EMAIL_IS_FREE lacks 0 values
Feature HOUSING_STATUS lacks 0 values
Feature PHONE_HOME_VALID lacks 0 values
Feature PHONE_MOBILE_VALID lacks 0 values
Feature BANK_MONTHS_COUNT lacks 0 values
Feature HAS_OTHER_CARDS lacks 0 values
Feature PROPOSED_CREDIT_LIMIT lacks 0 values
Feature FOREIGN_REQUEST lacks 0 values
Feature SOURCE lacks 0 values
Feature SESSION_LENGTH_IN_MINUTES lacks 0 values
Feature

In [5]:
# Drop the nearly empty column; the result is stored as df_missing and
# fraud_data keeps referring to the raw table
df_missing = fraud_data.drop('PREV_ADRESS_MONTH_COUNT')

In [6]:
# Replace the -1 placeholders of the nearly complete column with the mean
# computed over the rows whose value is not -1
address_months = col('CURRENT_ADRESS_MONTH_COUNT')
mean_value = (
    df_missing
    .filter(address_months != -1)
    .select(mean(address_months))
    .collect()[0][0]
)
df_missing = df_missing.with_column(
    'CURRENT_ADRESS_MONTH_COUNT',
    when(address_months == -1, mean_value).otherwise(address_months),
)

In [7]:
### Handling Imbalances ###
def balance_df(df):
    """
    Balance the FRAUD_BOOL classes of a Snowpark DataFrame by random oversampling.

    Rows of the least frequent class are sampled and appended to `df` until that
    class has as many rows as the most frequent one.

    Parameters:
    - df (DataFrame): Snowpark DataFrame containing a FRAUD_BOOL column

    Output:
    - DataFrame: `df` followed by the additional sampled minority-class rows
    """
    class_counts = df.group_by(col('FRAUD_BOOL')).agg(count(lit(1)).alias('count')).collect()
    if not class_counts:
        # Empty input: there is nothing to balance
        return df

    # Select the classes by their row count. Comparing the Row objects directly
    # would order them by label first, which only works if label 0 is the majority.
    majority_class = max(class_counts, key=lambda row: row['COUNT'])
    minority_class = min(class_counts, key=lambda row: row['COUNT'])
    rows_needed = majority_class['COUNT'] - minority_class['COUNT']

    minority_class_df = df.filter(col('FRAUD_BOOL') == minority_class['FRAUD_BOOL'])
    # The minority size is already known from the aggregation above, so no extra
    # count query is needed inside the loop.
    minority_size = minority_class['COUNT']

    # NOTE(review): no seed is passed to sample(), so the added rows are presumably
    # not reproducible from one evaluation to the next - confirm.
    balanced_df = df
    while rows_needed > 0:
        # A single sample cannot exceed the minority size, hence the repeated draws
        sample_size = min(rows_needed, minority_size)
        balanced_df = balanced_df.union_all(minority_class_df.sample(n=sample_size))
        rows_needed -= sample_size
    return balanced_df

In [8]:
# Oversample the minority class, then display the row count of each class
# to check that both classes now have the same size
balanced_df = balance_df(df_missing)
balanced_df.group_by(col('fraud_bool')).agg(count(lit(1)).alias('count')).collect()

[Row(FRAUD_BOOL=1, COUNT=988971), Row(FRAUD_BOOL=0, COUNT=988971)]

Converting the correlation function to Snowpark queries proved too costly (execution time of more than 3 minutes); therefore, I will rely on the previous local analysis to process the cloud data.

In [9]:
### Manual Removal of Highly Correlated Features ###

# TODO: Find an efficient way to automate the detection of highly correlated features using Snowpark

# def get_correlation_matrix(df):
#     num_cols = [col.name for col in df.schema.fields if col.datatype in (LongType(),DecimalType())]
#     correlation_matrix = dd()
#     n = len(num_cols)
#     for i in range(n):
#         for j in range(i,n):
#             col1, col2 = num_cols[i], num_cols[j]
#             correlation_value = balanced_df.select(corr(col(col1), col(col2)).alias('result')).collect()[0]['RESULT']
#             correlation_matrix[(i,j)] = correlation_value
#     return correlation_matrix # Too LONG !!!!!

In [10]:
#corr_matrix = get_correlation_matrix(balanced_df)
#print(corr_matrix)

In [11]:
# Load the balanced dataset into a local pandas DataFrame for the correlation analysis
pd_df = balanced_df.to_pandas()

In [12]:
def display_corr_matrix(dataset):
    """
    Display a heatmap of the pairwise correlations between the non-object columns.

    Only the part of the matrix strictly above the diagonal is shown; the
    diagonal and the lower triangle are blanked out.

    Parameters:
    - dataset (DataFrame): pandas DataFrame to analyse

    Output:
    - None (the figure is shown)
    """
    non_object_columns = dataset.select_dtypes(exclude=['object'])
    corr_matrix = non_object_columns.corr()

    # True on the diagonal and below: these cells are replaced by NaN
    lower_and_diagonal = np.tril(np.ones(corr_matrix.shape)).astype(bool)
    upper_triangle = corr_matrix.where(~lower_and_diagonal)

    fig = go.Figure()
    fig.add_trace(
        go.Heatmap(
            z=upper_triangle.values,
            x=upper_triangle.index.values,
            y=upper_triangle.columns.values,
        )
    )
    fig.update_layout(
        plot_bgcolor='rgba(0,0,0,0)',
        xaxis={'showgrid': False},
        yaxis={'showgrid': False},
    )
    fig.show()

# Heatmap of the correlations in the balanced dataset
display_corr_matrix(pd_df)

In [13]:
def display_corr_distribution(dataset, title="Correlation Value Distribution", decimals=2):
    """
    Display a bar chart of how often each correlation value occurs.

    Correlations are computed between the non-object columns and only the pairs
    strictly above the diagonal are kept, so each pair is counted once. Values
    are rounded before being counted: raw floating-point correlations are
    almost all distinct, which would give every bar a frequency of 1.

    Parameters:
    - dataset (DataFrame): pandas DataFrame to analyse
    - title (str): Title of the figure
    - decimals (int): Number of decimal places used to group correlation values

    Output:
    - None (the figure is shown)
    """
    corr = dataset.select_dtypes(exclude=['object']).corr()
    mask = np.tril(np.ones(corr.shape)).astype(bool)

    # Flatten the upper triangle, then group the values by rounding them
    corr_values = corr.where(~mask).stack().round(decimals)
    corr_value_counts = corr_values.value_counts().sort_index()

    trace = go.Bar(
        x=corr_value_counts.index,
        y=corr_value_counts.values
    )

    fig = go.Figure()
    fig.add_trace(trace)
    fig.update_layout(
        title=title,
        xaxis_title="Correlation Value",
        yaxis_title="Frequency",
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        xaxis=dict(
            showgrid=False,
            tickvals=corr_value_counts.index,  # One tick per bar, at the bar's exact position
            tickformat=f".{decimals}f"  # Same precision as the grouping
        ),
        yaxis=dict(showgrid=False)
    )
    fig.show()

# Distribution of the correlation values in the balanced dataset
display_corr_distribution(pd_df)

In [14]:
# Local pandas copy of the balanced dataset without the MONTH column
# NOTE(review): balanced_df is built from random samples; if Snowpark re-evaluates it
# on every action, these rows may differ from the ones loaded into pd_df - confirm
cleaned_df_pd = balanced_df.drop('MONTH').to_pandas()

In [15]:
# Distribution of the correlation values once MONTH has been removed
display_corr_distribution(cleaned_df_pd)

In [18]:
# Save the balanced dataset without MONTH as table FRAUD_DATA_CLEANED, in overwrite mode
# NOTE(review): the table name is not qualified with a database/schema, so it presumably
# resolves against the session's current schema - confirm
cleaned_df = balanced_df.drop('MONTH')
cleaned_df.write.mode('overwrite').save_as_table('FRAUD_DATA_CLEANED')

In [20]:
### Verifying correlation
# Read the saved table back into pandas so the check runs on what was actually written
cleaned_df_pd = session.table('FRAUD_DATA_CLEANED').to_pandas()

In [21]:
display_corr_distribution(cleaned_df_pd) # same correlation distribution as before saving, great!

In [22]:
# Close the Snowflake session now that all the work is done
session.close()