In [9]:
from pyspark.ml.feature import OneHotEncoder,StringIndexer,VectorAssembler, StandardScaler
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression 
from pyspark.sql.functions import col,stddev_samp

# 데이터 수집

In [1]:
df = spark.read.format('csv')\
    .options(header = 'true', inferSchema = 'true')\
    .load('/root/git_code/data/train.csv').cache()

                                                                                

In [77]:
test_df = spark.read.format('csv')\
    .options(header = 'true', inferSchema = 'true')\
    .load('/root/git_code/data/test.csv').cache()

24/03/26 18:19:38 WARN CacheManager: Asked to cache already cached data.


In [79]:
df.select("Survived").distinct().show()
df.select("Pclass").distinct().show()

+--------+
|Survived|
+--------+
|       1|
|       0|
+--------+

+------+
|Pclass|
+------+
|     1|
|     3|
|     2|
+------+



# 결측지 제거
결측치 처리
1. 평균: 위험, 튀는 값이 있어 평균이 틀어질 수 있음
2. 중위수
3. 삭제: 표본이 아주 많고 결측치 비율이 적은 경우 괜찮 -> 지금은 날릴거

In [2]:
from pyspark.sql.functions import isnan, count
from pyspark.sql.functions import mean, col, split, regexp_extract, when, lit

# when(조건A, 조건A가 True일 시 value).otherwise(조건A가 False일 시 value)
df.select([count(when(isnan(c) | col(c).isNull(), c)).alias(c) for c in df.columns])\
  .show()

df.groupby('Survived').count()

                                                                                

+-----------+--------+------+----+---+---+-----+-----+------+----+-----+--------+
|PassengerId|Survived|Pclass|Name|Sex|Age|SibSp|Parch|Ticket|Fare|Cabin|Embarked|
+-----------+--------+------+----+---+---+-----+-----+------+----+-----+--------+
|          0|       0|     0|   0|  0|177|    0|    0|     0|   0|  687|       2|
+-----------+--------+------+----+---+---+-----+-----+------+----+-----+--------+



DataFrame[Survived: int, count: bigint]

In [4]:
from pyspark.sql.functions import avg


mean_age_by_sex_pclass = df.groupBy("Sex", "Pclass").agg(avg("Age").alias("MeanAge"))
# 결측치를 해당 그룹의 평균 나이로 채우기
filled_df = df.join(mean_age_by_sex_pclass, ["Sex", "Pclass"], "left") \
              .withColumn("AgeFilled", when(col("Age").isNull(), col("MeanAge")).otherwise(col("Age"))) \
              .drop("MeanAge")
# 결과 확인
filled_df.show()

+------+------+-----------+--------+--------------------+----+-----+-----+----------------+-------+-----+--------+------------------+
|   Sex|Pclass|PassengerId|Survived|                Name| Age|SibSp|Parch|          Ticket|   Fare|Cabin|Embarked|         AgeFilled|
+------+------+-----------+--------+--------------------+----+-----+-----+----------------+-------+-----+--------+------------------+
|  male|     3|          1|       0|Braund, Mr. Owen ...|22.0|    1|    0|       A/5 21171|   7.25| NULL|       S|              22.0|
|female|     1|          2|       1|Cumings, Mrs. Joh...|38.0|    1|    0|        PC 17599|71.2833|  C85|       C|              38.0|
|female|     3|          3|       1|Heikkinen, Miss. ...|26.0|    0|    0|STON/O2. 3101282|  7.925| NULL|       S|              26.0|
|female|     1|          4|       1|Futrelle, Mrs. Ja...|35.0|    1|    0|          113803|   53.1| C123|       S|              35.0|
|  male|     3|          5|       0|Allen, Mr. Willia...|35.0|

In [6]:
filled_df.select([count(when(isnan(c) | col(c).isNull(), c)).alias(c) for c in filled_df.columns])\
  .show()

+---+------+-----------+--------+----+---+-----+-----+------+----+-----+--------+---------+
|Sex|Pclass|PassengerId|Survived|Name|Age|SibSp|Parch|Ticket|Fare|Cabin|Embarked|AgeFilled|
+---+------+-----------+--------+----+---+-----+-----+------+----+-----+--------+---------+
|  0|     0|          0|       0|   0|177|    0|    0|     0|   0|  687|       2|        0|
+---+------+-----------+--------+----+---+-----+-----+------+----+-----+--------+---------+



# 분석내용
Pclass, 나이, 성별에 따라 생존여부

# 범주형 데이터에 대한 전처리 준비

In [10]:
params = {
    'inputCol' : 'Sex',
    'outputCol' : 'SexIdx'
}
strIdx = StringIndexer(**params)

In [11]:
params = {
    'inputCol' : 'SexIdx',
    'outputCol' : 'SexClassVec'
}
encode = OneHotEncoder(**params)

In [12]:
#정답
params = {
    'inputCol' : 'Survived',
    'outputCol' : 'label'
}
label_strIdx = StringIndexer(**params)

In [13]:
stage = [strIdx, encode, label_strIdx]
stage

[StringIndexer_ccc5fefdbfd4,
 OneHotEncoder_9ece67685625,
 StringIndexer_f98d9c01e018]

# 연속형 데이터에 대한 전처리 준비

In [14]:
numCols = ['AgeFilled','Fare']
#scaling
for c in numCols:
    filled_df = filled_df.withColumn(c+'Scaled', col(c) / filled_df.agg(stddev_samp(c)).first()[0])

In [15]:
filled_df.show(2)

+------+------+-----------+--------+--------------------+----+-----+-----+---------+-------+-----+--------+---------+------------------+-------------------+
|   Sex|Pclass|PassengerId|Survived|                Name| Age|SibSp|Parch|   Ticket|   Fare|Cabin|Embarked|AgeFilled|   AgeFilledScaled|         FareScaled|
+------+------+-----------+--------+--------------------+----+-----+-----+---------+-------+-----+--------+---------+------------------+-------------------+
|  male|     3|          1|       0|Braund, Mr. Owen ...|22.0|    1|    0|A/5 21171|   7.25| NULL|       S|     22.0|1.6564889525368385|0.14589454188740145|
|female|     1|          2|       1|Cumings, Mrs. Joh...|38.0|    1|    0| PC 17599|71.2833|  C85|       C|     38.0|2.8612081907454483| 1.4344612962375451|
+------+------+-----------+--------+--------------------+----+-----+-----+---------+-------+-----+--------+---------+------------------+-------------------+
only showing top 2 rows



In [16]:
inputs = ['SexClassVec', 'AgeFilledScaled', 'FareScaled']

In [17]:
assembler = VectorAssembler(inputCols=inputs, outputCol='features', handleInvalid = 'keep')
stage +=[assembler]
stage


#첫번째 코드 아래 코드와 동일
# params = {
#     'inputCols' : ['SexClassVec', 'AgeScaled', 'FareScaled'],
#     'outputCol' : 'features'
# }
# assembler = VectorAssembler(**params)

[StringIndexer_ccc5fefdbfd4,
 OneHotEncoder_9ece67685625,
 StringIndexer_f98d9c01e018,
 VectorAssembler_388c698015c2]

# 학습용, 평가용 데이터 준비 끝

In [18]:
#pipe line 
pipeline = Pipeline(stages=stage)
piplineModel = pipeline.fit(filled_df)
dataset = piplineModel.transform(filled_df)

                                                                                

In [19]:
dataset.show(5)

+------+------+-----------+--------+--------------------+----+-----+-----+----------------+-------+-----+--------+---------+------------------+-------------------+------+-------------+-----+--------------------+
|   Sex|Pclass|PassengerId|Survived|                Name| Age|SibSp|Parch|          Ticket|   Fare|Cabin|Embarked|AgeFilled|   AgeFilledScaled|         FareScaled|SexIdx|  SexClassVec|label|            features|
+------+------+-----------+--------+--------------------+----+-----+-----+----------------+-------+-----+--------+---------+------------------+-------------------+------+-------------+-----+--------------------+
|  male|     3|          1|       0|Braund, Mr. Owen ...|22.0|    1|    0|       A/5 21171|   7.25| NULL|       S|     22.0|1.6564889525368385|0.14589454188740145|   0.0|(1,[0],[1.0])|  0.0|[1.0,1.6564889525...|
|female|     1|          2|       1|Cumings, Mrs. Joh...|38.0|    1|    0|        PC 17599|71.2833|  C85|       C|     38.0|2.8612081907454483| 1.434461

In [20]:
(train, test) = dataset.randomSplit([0.7,0.3], seed=14)

# 적절한 모델준비

In [21]:
lr = LogisticRegression(labelCol='label', featuresCol='features', maxIter=10)

In [22]:
lrModel = lr.fit(train) #훈련
predictions = lrModel.transform(test) #예측
predictions.show(5)

24/03/26 18:47:56 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS


+------+------+-----------+--------+--------------------+----+-----+-----+--------+--------+-----------+--------+-----------------+------------------+------------------+------+-----------+-----+--------------------+--------------------+--------------------+----------+
|   Sex|Pclass|PassengerId|Survived|                Name| Age|SibSp|Parch|  Ticket|    Fare|      Cabin|Embarked|        AgeFilled|   AgeFilledScaled|        FareScaled|SexIdx|SexClassVec|label|            features|       rawPrediction|         probability|prediction|
+------+------+-----------+--------+--------------------+----+-----+-----+--------+--------+-----------+--------+-----------------+------------------+------------------+------+-----------+-----+--------------------+--------------------+--------------------+----------+
|female|     1|         12|       1|Bonnell, Miss. El...|58.0|    0|    0|  113783|   26.55|       C103|       S|             58.0|  4.36710723850621|0.5342758740842081|   1.0|  (1,[],[])|  1.0

In [23]:
predictions.select('Survived','label','prediction').show(10)

+--------+-----+----------+
|Survived|label|prediction|
+--------+-----+----------+
|       1|  1.0|       1.0|
|       1|  1.0|       1.0|
|       1|  1.0|       1.0|
|       1|  1.0|       1.0|
|       1|  1.0|       1.0|
|       1|  1.0|       1.0|
|       1|  1.0|       1.0|
|       1|  1.0|       1.0|
|       1|  1.0|       1.0|
|       1|  1.0|       1.0|
+--------+-----+----------+
only showing top 10 rows



# 평가

In [24]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator
evaluator = BinaryClassificationEvaluator(rawPredictionCol='rawPrediction')
evaluator.evaluate(predictions)

                                                                                

0.8173761513567332

In [25]:
predictions.createOrReplaceTempView('predic')
print('테스트 표본:',predictions.count(), "\n"
'틀린 갯수: ', spark.sql("select * from predic where label != prediction").count())

테스트 표본: 298 
틀린 갯수:  59
