In [1]:
import numpy as np
from pyspark import SparkConf, SparkContext

In [2]:
# Initialize Spark Context
# getOrCreate() returns the context that is already running when this cell is
# re-executed; calling the SparkContext constructor a second time in the same
# kernel raises "ValueError: Cannot run multiple SparkContexts at once".
sc = SparkContext.getOrCreate(
    SparkConf().setMaster("local[*]").setAppName("FraudDetection")
)


In [None]:
# Path to the input CSV; a relative path, so it depends on the working
# directory the notebook is started from.
file_path = "../creditcard.csv/creditcard.csv"

In [None]:
# 1. Improved Data Loading and Parsing
def _parse_field(raw):
    """Convert one CSV field to a float, returning 0.0 when it is not numeric."""
    try:
        # float() accepts signs, decimal points and exponents; str.isdigit()
        # only recognises unsigned integers, so it must not be used to decide
        # whether a field is numeric.
        return float(raw.strip().strip('"'))
    except ValueError:
        return 0.0


def _parse_line(line):
    """Split one CSV line on commas and convert every field with _parse_field."""
    return [_parse_field(field) for field in line.split(",")]


def load_and_parse_data(sc, filepath):
    """Load a CSV file into an RDD of (features, label) pairs.

    The first line of the file is treated as the header, and every line equal
    to it is dropped. Each remaining line is split on commas; surrounding
    double quotes are removed from each field before it is converted to float.
    Fields that cannot be converted (empty or non-numeric text) become 0.0, so
    a partially malformed record is kept rather than discarded. Lines that
    yield fewer than two fields (for example blank lines) are skipped.

    Args:
        sc: Object providing ``textFile(path)`` and ``emptyRDD()``, such as a
            SparkContext.
        filepath: Path of the CSV file, passed to ``sc.textFile``.

    Returns:
        A cached RDD whose elements are ``(features, label)``: ``features`` is
        the list of floats from all columns except the last, and ``label`` is
        the last column as a float. If any exception is raised while loading or
        counting, the error is printed and ``sc.emptyRDD()`` is returned.
    """
    try:
        lines = sc.textFile(filepath)

        # take(1) returns an empty list for an empty file instead of raising.
        first_lines = lines.take(1)
        header = first_lines[0] if first_lines else None

        # Skip header and parse data
        rows = lines.filter(lambda line: line != header).map(_parse_line)

        # Create feature-label pairs, handle empty lines
        rdd_data = rows.filter(lambda cols: len(cols) > 1).map(
            lambda cols: (cols[:-1], cols[-1])
        )

        # Cache as we'll reuse this RDD
        rdd_data.cache()

        # Count features for verification
        sample = rdd_data.take(1)
        num_features = len(sample[0][0]) if sample else 0
        print(f"Loaded dataset with {rdd_data.count()} records and {num_features} features")

        return rdd_data

    except Exception as e:
        # Best-effort contract: report the failure and hand back an empty RDD
        # so the calling cell still receives an RDD.
        print(f"Error loading data: {str(e)}")
        return sc.emptyRDD()


In [None]:
# Build the (features, label) RDD from the configured CSV path.
rdd_data = load_and_parse_data(sc=sc, filepath=file_path)

In [None]:
# Parse CSV and cache the RDD
# Load the raw lines and the header inside this cell so it runs on a fresh
# kernel instead of depending on variables left over from an earlier session.
raw_data = sc.textFile(file_path)
header = raw_data.first()

# Blank lines are skipped (int('') would raise), and the label is stripped of
# double quotes the same way the feature columns are, so a quoted label such
# as "0" parses instead of raising ValueError.
parsed_data = raw_data.filter(lambda line: line.strip() and line != header) \
                     .map(lambda line: line.split(",")) \
                     .map(lambda cols: (
                         [float(x.strip('"')) for x in cols[:-1]],  # Features: every column except the last
                         int(cols[-1].strip().strip('"'))           # Class label: last column
                     )).cache()

# Count classes for imbalance analysis
class_counts = parsed_data.map(lambda x: (x[1], 1)) \
                        .reduceByKey(lambda a, b: a + b) \
                        .collect()

print("Class distribution:", dict(class_counts))