In [4]:
from pyspark.sql import SparkSession 

spark = SparkSession.builder.appName("StrokePreProcessing").getOrCreate()

df = spark.read.csv("../../data/raw/patients_data.csv", header=True, inferSchema=True)

df.show(5)
df.printSchema()


+---------+-------+------+-------------+------------+----------------+-----------------+----------------+--------------+---------+---------+---------+-------------+-------+---------------------+----------------+------------+-----------+-------------------+-----------------------+-----------------------+-----------------+-------------------------+-----------------+-------------+--------------------+---------+---------------------+---------------+----------+------------+-------------+--------------------+----------------+--------+
|PatientID|  State|   Sex|GeneralHealth| AgeCategory|  HeightInMeters|WeightInKilograms|             BMI|HadHeartAttack|HadAngina|HadStroke|HadAsthma|HadSkinCancer|HadCOPD|HadDepressiveDisorder|HadKidneyDisease|HadArthritis|HadDiabetes|DeafOrHardOfHearing|BlindOrVisionDifficulty|DifficultyConcentrating|DifficultyWalking|DifficultyDressingBathing|DifficultyErrands| SmokerStatus|     ECigaretteUsage|ChestScan|RaceEthnicityCategory|AlcoholDrinkers|HIVTesting|FluVa

In [39]:
print(f"Số lượng bản ghi: {df.count()}")

Số lượng bản ghi: 237630


In [40]:
from pyspark.sql.functions import col, sum
df.select([sum(col(c).isNull().cast("int")).alias(c) for c in df.columns]).show()

+---------+-----+---+-------------+-----------+--------------+-----------------+---+--------------+---------+---------+---------+-------------+-------+---------------------+----------------+------------+-----------+-------------------+-----------------------+-----------------------+-----------------+-------------------------+-----------------+------------+---------------+---------+---------------------+---------------+----------+------------+-------------+-----------------+----------------+--------+
|PatientID|State|Sex|GeneralHealth|AgeCategory|HeightInMeters|WeightInKilograms|BMI|HadHeartAttack|HadAngina|HadStroke|HadAsthma|HadSkinCancer|HadCOPD|HadDepressiveDisorder|HadKidneyDisease|HadArthritis|HadDiabetes|DeafOrHardOfHearing|BlindOrVisionDifficulty|DifficultyConcentrating|DifficultyWalking|DifficultyDressingBathing|DifficultyErrands|SmokerStatus|ECigaretteUsage|ChestScan|RaceEthnicityCategory|AlcoholDrinkers|HIVTesting|FluVaxLast12|PneumoVaxEver|TetanusLast10Tdap|HighRiskLastYear|

In [41]:
from pyspark.sql.functions import when, col
df = df.withColumn("Sex", when(col("Sex") == "Male", 1).otherwise(0))

In [42]:
from pyspark.sql.functions import regexp_extract

df = df.withColumn("AgeMin", regexp_extract(col("AgeCategory"), r"(\d+)", 1).cast("int"))
df = df.withColumn("AgeMax", when(col("AgeCategory") == "Age 80 or older", 100)
                               .otherwise(regexp_extract(col("AgeCategory"), r"to (\d+)", 1).cast("int")))

df.select("AgeCategory", "AgeMin", "AgeMax").show(10)


+---------------+------+------+
|    AgeCategory|AgeMin|AgeMax|
+---------------+------+------+
|   Age 75 to 79|    75|    79|
|   Age 65 to 69|    65|    69|
|   Age 60 to 64|    60|    64|
|   Age 70 to 74|    70|    74|
|   Age 50 to 54|    50|    54|
|   Age 75 to 79|    75|    79|
|   Age 70 to 74|    70|    74|
|   Age 60 to 64|    60|    64|
|Age 80 or older|    80|   100|
|Age 80 or older|    80|   100|
+---------------+------+------+
only showing top 10 rows



In [43]:
from pyspark.sql.functions import when, col

df = df.withColumn("GeneralHealth",
                   when(col("GeneralHealth") == "Excellent", 4)
                   .when(col("GeneralHealth") == "Very good", 3)
                   .when(col("GeneralHealth") == "Good", 2)
                   .when(col("GeneralHealth") == "Fair", 1)
                   .when(col("GeneralHealth") == "Poor", 0))

df = df.withColumn("SmokerStatus",
                   when(col("SmokerStatus") == "Never smoked", 0)
                   .when(col("SmokerStatus") == "Former smoker", 1)
                   .when(col("SmokerStatus") == "Current smoker - now smokes some days", 2)
                   .when(col("SmokerStatus") == "Current smoker - now smokes every day", 3))

df = df.withColumn("ECigaretteUsage",
                   when(col("ECigaretteUsage") == "Never used e-cigarettes in my entire life", 0)
                   .when(col("ECigaretteUsage") == "Not at all (right now)", 1)
                   .when(col("ECigaretteUsage") == "Use them some days", 2)
                   .when(col("ECigaretteUsage") == "Use them every day", 3))

# Kiểm tra kết quả
df.select("GeneralHealth", 
          "SmokerStatus", 
          "ECigaretteUsage").show()


+-------------+------------+---------------+
|GeneralHealth|SmokerStatus|ECigaretteUsage|
+-------------+------------+---------------+
|            1|           1|              0|
|            3|           1|              0|
|            4|           0|              0|
|            3|           1|              0|
|            2|           0|              0|
|            3|           1|              0|
|            2|           0|              0|
|            1|           0|              0|
|            1|           0|              0|
|            2|           0|              0|
|            2|           0|              0|
|            2|           0|              1|
|            2|           0|              0|
|            4|           3|              1|
|            2|           0|              0|
|            2|           0|              0|
|            1|           0|              0|
|            0|           1|              0|
|            3|           0|              0|
|         

In [None]:
# Xoá cột nếu đã tồn tại
for col in ["TetanusLast10TdapIndex", "HadDiabetesIndex"]:
    if col in df.columns:
        df = df.drop(col)

In [None]:
from pyspark.ml.feature import StringIndexer

from pyspark.ml import Pipeline

indexers = [
    StringIndexer(inputCol=col, outputCol=col + "Index") 
    for col in ["TetanusLast10Tdap", "HadDiabetes"]
]

pipeline = Pipeline(stages=indexers)
df = pipeline.fit(df).transform(df)

In [5]:
from pyspark.sql.functions import col, mean, stddev

# Hàm xử lý ngoại lai
def remover_outliers(df, col_name):
    q1, q3 = df.approxQuantile(col_name, [0.25, 0.75], 0.01)  # Lấy Q1, Q3
    iqr = q3 - q1
    lower_bound = q1 - 1.5 * iqr
    upper_bound = q3 + 1.5 * iqr
    return df.filter((col(col_name) >= lower_bound) & (col(col_name) <= upper_bound))

# Xử lý ngoại lai và chuẩn hóa
for column in ["BMI", "HeightInMeters", "WeightInKilograms"]:
    # Xử lý ngoại lai
    df = remover_outliers(df, column)

    # Tính toán mean và stddev
    mean_val = df.agg(mean(col(column))).collect()[0][0]
    std_val = df.agg(stddev(col(column))).collect()[0][0]

    # Hiển thị giá trị trung bình (mean) và độ lệch chuẩn (stddev)
    print(f"Mean của {column}: {mean_val:.6f}")
    print(f"Standard Deviation của {column}: {std_val:.6f}")


    # Chuẩn hóa dữ liệu (z-score)
    df = df.withColumn(column, (col(column) - mean_val) / std_val)

# Kiểm tra kết quả
df.select("BMI", "HeightInMeters", "WeightInKilograms").show(10)


Mean của BMI: 28.016977
Standard Deviation của BMI: 5.363275
Mean của HeightInMeters: 1.705695
Standard Deviation của HeightInMeters: 0.104464
Mean của WeightInKilograms: 81.030545
Standard Deviation của WeightInKilograms: 17.610067
+--------------------+--------------------+--------------------+
|                 BMI|      HeightInMeters|   WeightInKilograms|
+--------------------+--------------------+--------------------+
|  0.7612926770531111| -0.7246055458739628| 0.21518684903257934|
|-0.00502993793843...| -1.0117850547624707| -0.5315452128888883|
| -1.0230645113331922|  0.7112931397185082| -0.5576665836092048|
|  0.3939799029059159|  0.7112931397185082|  0.8074617334421387|
|-0.04791410004891896| -0.2459730310598162|-0.17095609474719256|
|  0.6792535031803855|  1.3813798016083225|  1.5803151660839512|
|-0.40590435395690655|-0.05451888398414...|-0.35153464297942305|
|   0.412625299015657|-0.05451888398414...|  0.3696440265444868|
|  0.6158593697860572|-0.05451888398414...|  0.55022

In [47]:
cols_to_drop = [
    "PatientID", "State", "AgeCategory",
    "RaceEthnicityCategory", "TetanusLast10Tdap", "HadDiabetes"
]

df = df.drop(*cols_to_drop)


In [48]:
from pyspark.sql.functions import col

#Tách dữ liệu theo từng lớp
df_majority = df.filter(col("HadStroke") == 0)  # Lớp đa số
df_minority = df.filter(col("HadStroke") == 1)  # Lớp thiểu số

#Tính tỷ lệ nhân bản
majority_count = df_majority.count()
minority_count = df_minority.count()
oversampling_ratio = majority_count / minority_count  # Số lần nhân bản

#Lấy mẫu có thay thế để tăng số lượng lớp thiểu số
df_minority_oversampled = df_minority.sample(withReplacement=True, fraction=oversampling_ratio, seed=42)

#Kết hợp dữ liệu đã cân bằng
df_balanced = df_majority.union(df_minority_oversampled)

#Kiểm tra số lượng mẫu sau oversampling
df_balanced.groupBy("HadStroke").count().show()


+---------+------+
|HadStroke| count|
+---------+------+
|        0|216814|
|        1|216638|
+---------+------+



In [None]:
df_balanced.toPandas().to_csv("../data/processed_data/preprocess.csv", index=False)