Load and Explore the Data

In [0]:
# Load the dataset
file_path = "/FileStore/tables/creditcard.csv"  # Update with your file path
df = spark.read.csv(file_path, header=True, inferSchema=True)

# Show the first 5 rows
display(df.limit(5))

Time,V1,V2,V3,V4,V5,V6,V7,V8,V9,V10,V11,V12,V13,V14,V15,V16,V17,V18,V19,V20,V21,V22,V23,V24,V25,V26,V27,V28,Amount,Class
0.0,-1.3598071336738,-0.0727811733098497,2.53634673796914,1.37815522427443,-0.338320769942518,0.462387777762292,0.239598554061257,0.0986979012610507,0.363786969611213,0.0907941719789316,-0.551599533260813,-0.617800855762348,-0.991389847235408,-0.311169353699879,1.46817697209427,-0.470400525259478,0.207971241929242,0.0257905801985591,0.403992960255733,0.251412098239705,-0.018306777944153,0.277837575558899,-0.110473910188767,0.0669280749146731,0.128539358273528,-0.189114843888824,0.133558376740387,-0.0210530534538215,149.62,0
0.0,1.19185711131486,0.26615071205963,0.16648011335321,0.448154078460911,0.0600176492822243,-0.0823608088155687,-0.0788029833323113,0.0851016549148104,-0.255425128109186,-0.166974414004614,1.61272666105479,1.06523531137287,0.48909501589608,-0.143772296441519,0.635558093258208,0.463917041022171,-0.114804663102346,-0.183361270123994,-0.145783041325259,-0.0690831352230203,-0.225775248033138,-0.638671952771851,0.101288021253234,-0.339846475529127,0.167170404418143,0.125894532368176,-0.0089830991432281,0.0147241691924927,2.69,0
1.0,-1.35835406159823,-1.34016307473609,1.77320934263119,0.379779593034328,-0.503198133318193,1.80049938079263,0.791460956450422,0.247675786588991,-1.51465432260583,0.207642865216696,0.624501459424895,0.066083685268831,0.717292731410831,-0.165945922763554,2.34586494901581,-2.89008319444231,1.10996937869599,-0.121359313195888,-2.26185709530414,0.524979725224404,0.247998153469754,0.771679401917229,0.909412262347719,-0.689280956490685,-0.327641833735251,-0.139096571514147,-0.0553527940384261,-0.0597518405929204,378.66,0
1.0,-0.966271711572087,-0.185226008082898,1.79299333957872,-0.863291275036453,-0.0103088796030823,1.24720316752486,0.23760893977178,0.377435874652262,-1.38702406270197,-0.0549519224713749,-0.226487263835401,0.178228225877303,0.507756869957169,-0.28792374549456,-0.631418117709045,-1.0596472454325,-0.684092786345479,1.96577500349538,-1.2326219700892,-0.208037781160366,-0.108300452035545,0.0052735967825345,-0.190320518742841,-1.17557533186321,0.647376034602038,-0.221928844458407,0.0627228487293033,0.0614576285006353,123.5,0
2.0,-1.15823309349523,0.877736754848451,1.548717846511,0.403033933955121,-0.407193377311653,0.0959214624684256,0.592940745385545,-0.270532677192282,0.817739308235294,0.753074431976354,-0.822842877946363,0.53819555014995,1.3458515932154,-1.11966983471731,0.175121130008994,-0.451449182813529,-0.237033239362776,-0.0381947870352842,0.803486924960175,0.408542360392758,-0.0094306971323291,0.79827849458971,-0.137458079619063,0.141266983824769,-0.206009587619756,0.502292224181569,0.219422229513348,0.215153147499206,69.99,0


In [0]:
# Print the schema
df.printSchema()

# Count the total number of rows
print("Total rows:", df.count())

# Check the class distribution (fraud vs. non-fraud)
display(df.groupBy("Class").count())

root
 |-- Time: double (nullable = true)
 |-- V1: double (nullable = true)
 |-- V2: double (nullable = true)
 |-- V3: double (nullable = true)
 |-- V4: double (nullable = true)
 |-- V5: double (nullable = true)
 |-- V6: double (nullable = true)
 |-- V7: double (nullable = true)
 |-- V8: double (nullable = true)
 |-- V9: double (nullable = true)
 |-- V10: double (nullable = true)
 |-- V11: double (nullable = true)
 |-- V12: double (nullable = true)
 |-- V13: double (nullable = true)
 |-- V14: double (nullable = true)
 |-- V15: double (nullable = true)
 |-- V16: double (nullable = true)
 |-- V17: double (nullable = true)
 |-- V18: double (nullable = true)
 |-- V19: double (nullable = true)
 |-- V20: double (nullable = true)
 |-- V21: double (nullable = true)
 |-- V22: double (nullable = true)
 |-- V23: double (nullable = true)
 |-- V24: double (nullable = true)
 |-- V25: double (nullable = true)
 |-- V26: double (nullable = true)
 |-- V27: double (nullable = true)
 |-- V28: double (nulla

Class,count
1,492
0,284315


Data Preprocessing

In [0]:
df = df.drop("Time")

Handle Imbalanced Data

In [0]:
from pyspark.sql.functions import col

# Separate fraud and non-fraud data
fraud_df = df.filter(col("Class") == 1)
non_fraud_df = df.filter(col("Class") == 0)

# Undersample non-fraud data
sampled_non_fraud_df = non_fraud_df.sample(fraction=0.01, seed=42)

# Combine the datasets
balanced_df = fraud_df.union(sampled_non_fraud_df)

# Check the new class distribution
display(balanced_df.groupBy("Class").count())

Class,count
1,492
0,2866


scale the features

In [0]:
from pyspark.ml.feature import VectorAssembler, StandardScaler

# Assemble features into a single vector
assembler = VectorAssembler(inputCols=["Amount"] + [f"V{i}" for i in range(1, 29)], outputCol="features")
df = assembler.transform(df)

# Scale the features
scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures")
scaler_model = scaler.fit(df)
df = scaler_model.transform(df)

# Show the transformed DataFrame
display(df.select("scaledFeatures", "Class").limit(5))

scaledFeatures,Class
"Map(vectorType -> dense, length -> 29, values -> List(0.5981926061623504, -0.6942411021638852, -0.0440748471935518, 1.6727705625425264, 0.9733638055579661, -0.24511615322163974, 0.34706733585762145, 0.1936785982994624, 0.08263713433481461, 0.3311272018898171, 0.08338539885285962, -0.5404060870051797, -0.6182946323291837, -0.9960971732491677, -0.3246096164490301, 1.604011022937432, -0.5368319260681518, 0.2448630241444076, 0.030769878583235596, 0.49628115539005957, 0.3261174434913668, -0.02492332120662109, 0.3828537662094795, -0.17691102375613274, 0.11050672655740985, 0.24658501006984543, -0.39216974306340907, 0.33089104174320316, -0.0637810387248247))",0
"Map(vectorType -> dense, length -> 29, values -> List(0.010754832980729332, 0.6084952594310952, 0.16117563692663997, 0.10979690934882776, 0.3165223710679578, 0.04348327570891918, -0.06181985742147667, -0.06370009791342846, 0.07125335796867134, -0.23249378077926594, -0.15334935939751806, 1.580000075702499, 1.0660866993727665, 0.49141734116292446, -0.14998221852172033, 0.6943591996604377, 0.5294328243583352, -0.13516973179726777, -0.21876219825788853, -0.1790857349082969, -0.08961070531227833, -0.30737626492223197, -0.8800752094067651, 0.16220089886855985, -0.5611295646957747, 0.32069333790440524, 0.26106901709386826, -0.022255639115493928, 0.0446074393682205))",0
"Map(vectorType -> dense, length -> 29, values -> List(1.5139126604025908, -0.6934992452238191, -0.8115764015230559, 1.1694664397320775, 0.26823082294214673, -0.3645711458040887, 1.3514512133714873, 0.6397745147279934, 0.20737236543359341, -1.378672931058512, 0.19069927901388312, 0.6118286359340674, 0.06613650257222503, 0.7206985870825093, -0.17311358493070836, 2.562901685603002, -3.298229581937316, 1.3068655851353743, -0.14478973730846748, -2.7785559725451017, 0.6809737760183282, 0.3376311034323613, 1.063356404316374, 1.456317189039576, -1.1380901404568373, -0.6285356170752573, -0.2884462456250787, -0.13713661493785764, -0.18102050930977134))",0
"Map(vectorType -> dense, length -> 29, values -> List(0.4937627781115511, -0.49332403207741354, -0.11216922771730702, 1.1825143748617035, -0.6097255708019895, -0.0074688672314202975, 0.9361481886931922, 0.1920703010029605, 0.31601704471831427, -1.2625009557851772, -0.05046786454608884, -0.221891224748239, 0.17837067456552683, 0.5101678056040084, -0.30035996630214945, -0.6898361983181155, -1.209293870164272, -0.8054432281407429, 2.3453004047924333, -1.514202265029094, -0.2698547516872058, -0.1474430378266673, 0.00726689464375752, -0.3047760123187692, -1.9410237319908816, 1.2419015323822085, -0.46021653361476894, 0.1553959344495727, 0.18618825967134908))",0
"Map(vectorType -> dense, length -> 29, values -> List(0.2798255614577122, -0.5913287255806574, 0.5315401165822371, 1.0214098823894229, 0.28465490447934966, -0.2950149181822865, 0.07199845677494943, 0.479301441909162, -0.22650983355730225, 0.744324980360488, 0.6916238180002389, -0.8061451706863964, 0.5386257022751221, 1.3522419776167227, -1.1680314634958737, 0.19132313625856917, -0.5152042170828695, -0.2790802962693358, -0.04556892286021143, 0.9870355642029609, 0.5299378632216263, -0.012839195108414563, 1.1000093400718385, -0.22012300956314917, 0.23324967844983227, -0.39520094794314836, 1.0416094709933832, 0.5436188420170638, 0.6518147717866488))",0


Train a Machine Learning Model

Split the Data

In [0]:
train_df, test_df = df.randomSplit([0.8, 0.2], seed=42)

Train a Logistic Regression Model

In [0]:
from pyspark.ml.classification import LogisticRegression

# Initialize the model
lr = LogisticRegression(featuresCol="scaledFeatures", labelCol="Class")

# Train the model
lr_model = lr.fit(train_df)

Evaluate the Model

In [0]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator

# Make predictions
predictions = lr_model.transform(test_df)

# Evaluate using AUC-ROC
evaluator = BinaryClassificationEvaluator(labelCol="Class", rawPredictionCol="rawPrediction", metricName="areaUnderROC")
auc = evaluator.evaluate(predictions)
print(f"AUC-ROC: {auc}")

AUC-ROC: 0.968724625140629


getting an AUC-ROC score of 0.9687 (96.87%) accuracy.