In [4]:
!unzip '/root/git_code/data/titanic.zip' -d '/root/git_code/data'

Archive:  /root/git_code/data/titanic.zip
  inflating: /root/git_code/data/gender_submission.csv  
  inflating: /root/git_code/data/test.csv  
  inflating: /root/git_code/data/train.csv  


In [26]:
from pyspark.ml.feature import OneHotEncoder,StringIndexer,VectorAssembler, StandardScaler
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression 
from pyspark.sql.functions import col,stddev_samp

# 데이터 수집

In [76]:
df = spark.read.format('csv')\
    .options(header = 'true', inferSchema = 'true')\
    .load('/root/git_code/data/train.csv').cache()

24/03/26 18:19:37 WARN CacheManager: Asked to cache already cached data.


In [77]:
test_df = spark.read.format('csv')\
    .options(header = 'true', inferSchema = 'true')\
    .load('/root/git_code/data/test.csv').cache()

24/03/26 18:19:38 WARN CacheManager: Asked to cache already cached data.


In [78]:
df.show(2)

+-----------+--------+------+--------------------+------+----+-----+-----+---------+-------+-----+--------+
|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|   Ticket|   Fare|Cabin|Embarked|
+-----------+--------+------+--------------------+------+----+-----+-----+---------+-------+-----+--------+
|          1|       0|     3|Braund, Mr. Owen ...|  male|22.0|    1|    0|A/5 21171|   7.25| NULL|       S|
|          2|       1|     1|Cumings, Mrs. Joh...|female|38.0|    1|    0| PC 17599|71.2833|  C85|       C|
+-----------+--------+------+--------------------+------+----+-----+-----+---------+-------+-----+--------+
only showing top 2 rows



In [79]:
df.select("Survived").distinct().show()
df.select("Pclass").distinct().show()

+--------+
|Survived|
+--------+
|       1|
|       0|
+--------+

+------+
|Pclass|
+------+
|     1|
|     3|
|     2|
+------+



# 결측지 제거
결측치 처리
1. 평균: 위험, 튀는 값이 있어 평균이 틀어질 수 있음
2. 중위수
3. 삭제: 표본이 아주 많고 결측치 비율이 적은 경우 괜찮 -> 지금은 날릴거

In [80]:
from pyspark.sql.functions import isnan, count
from pyspark.sql.functions import mean, col, split, regexp_extract, when, lit

# when(조건A, 조건A가 True일 시 value).otherwise(조건A가 False일 시 value)
df.select([count(when(isnan(c) | col(c).isNull(), c)).alias(c) for c in df.columns])\
  .show()

df.groupby('Survived').count()

+-----------+--------+------+----+---+---+-----+-----+------+----+-----+--------+
|PassengerId|Survived|Pclass|Name|Sex|Age|SibSp|Parch|Ticket|Fare|Cabin|Embarked|
+-----------+--------+------+----+---+---+-----+-----+------+----+-----+--------+
|          0|       0|     0|   0|  0|177|    0|    0|     0|   0|  687|       2|
+-----------+--------+------+----+---+---+-----+-----+------+----+-----+--------+



DataFrame[Survived: int, count: bigint]

In [81]:
# dropna(subset=['Age']) -> 결측치가 너무 많음 우선 그냥 날리기
df = df.dropna(subset=['Age'])

In [82]:
df.count()

714

In [83]:
df.select([count(when(isnan(c) | col(c).isNull(), c)).alias(c) for c in df.columns])\
  .show()

+-----------+--------+------+----+---+---+-----+-----+------+----+-----+--------+
|PassengerId|Survived|Pclass|Name|Sex|Age|SibSp|Parch|Ticket|Fare|Cabin|Embarked|
+-----------+--------+------+----+---+---+-----+-----+------+----+-----+--------+
|          0|       0|     0|   0|  0|  0|    0|    0|     0|   0|  529|       2|
+-----------+--------+------+----+---+---+-----+-----+------+----+-----+--------+



# 분석내용
Pclass, 나이, 성별에 따라 생존여부

# 범주형 데이터에 대한 전처리 준비

In [84]:
params = {
    'inputCol' : 'Sex',
    'outputCol' : 'SexIdx'
}
strIdx = StringIndexer(**params)

In [85]:
params = {
    'inputCol' : 'SexIdx',
    'outputCol' : 'SexClassVec'
}
encode = OneHotEncoder(**params)

In [86]:
#정답
params = {
    'inputCol' : 'Survived',
    'outputCol' : 'label'
}
label_strIdx = StringIndexer(**params)

In [87]:
stage = [strIdx, encode, label_strIdx]
stage

[StringIndexer_11c34870700d,
 OneHotEncoder_f06267523b0c,
 StringIndexer_817ef1ed4175]

# 연속형 데이터에 대한 전처리 준비

In [88]:
numCols = ['Age','Fare']
#scaling
for c in numCols:
    df = df.withColumn(c+'Scaled', col(c) / df.agg(stddev_samp(c)).first()[0])

In [89]:
df.show(2)

+-----------+--------+------+--------------------+------+----+-----+-----+---------+-------+-----+--------+------------------+------------------+
|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|   Ticket|   Fare|Cabin|Embarked|         AgeScaled|        FareScaled|
+-----------+--------+------+--------------------+------+----+-----+-----+---------+-------+-----+--------+------------------+------------------+
|          1|       0|     3|Braund, Mr. Owen ...|  male|22.0|    1|    0|A/5 21171|   7.25| NULL|       S|1.5144738264626911|0.1370020155009282|
|          2|       1|     1|Cumings, Mrs. Joh...|female|38.0|    1|    0| PC 17599|71.2833|  C85|       C|2.6159093366173756|1.3470283822837676|
+-----------+--------+------+--------------------+------+----+-----+-----+---------+-------+-----+--------+------------------+------------------+
only showing top 2 rows



In [90]:
inputs = ['SexClassVec', 'AgeScaled', 'FareScaled']

In [91]:
assembler = VectorAssembler(inputCols=inputs, outputCol='features', handleInvalid = 'keep')
stage +=[assembler]
stage


#첫번째 코드 아래 코드와 동일
# params = {
#     'inputCols' : ['SexClassVec', 'AgeScaled', 'FareScaled'],
#     'outputCol' : 'features'
# }
# assembler = VectorAssembler(**params)

[StringIndexer_11c34870700d,
 OneHotEncoder_f06267523b0c,
 StringIndexer_817ef1ed4175,
 VectorAssembler_521912917018]

# 학습용, 평가용 데이터 준비 끝

In [92]:
#pipe line 
pipeline = Pipeline(stages=stage)
piplineModel = pipeline.fit(df)
dataset = piplineModel.transform(df)

In [93]:
dataset.show(5)

+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+------------------+-------------------+------+-------------+-----+--------------------+
|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|          Ticket|   Fare|Cabin|Embarked|         AgeScaled|         FareScaled|SexIdx|  SexClassVec|label|            features|
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+------------------+-------------------+------+-------------+-----+--------------------+
|          1|       0|     3|Braund, Mr. Owen ...|  male|22.0|    1|    0|       A/5 21171|   7.25| NULL|       S|1.5144738264626911| 0.1370020155009282|   0.0|(1,[0],[1.0])|  0.0|[1.0,1.5144738264...|
|          2|       1|     1|Cumings, Mrs. Joh...|female|38.0|    1|    0|        PC 17599|71.2833|  C85|       C|2.6159093366173756| 1.3470283822837676|   1.0|    (1,[],[])|  1.0|[0.0,2.61590

In [94]:
(train, test) = dataset.randomSplit([0.7,0.3], seed=14)

# 적절한 모델준비

In [95]:
lr = LogisticRegression(labelCol='label', featuresCol='features', maxIter=10)

In [97]:
lrModel = lr.fit(train) #훈련
predictions = lrModel.transform(test) #예측
predictions.show(5)

+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+-------------------+-------------------+------+-------------+-----+--------------------+--------------------+--------------------+----------+
|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|          Ticket|   Fare|Cabin|Embarked|          AgeScaled|         FareScaled|SexIdx|  SexClassVec|label|            features|       rawPrediction|         probability|prediction|
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+-------------------+-------------------+------+-------------+-----+--------------------+--------------------+--------------------+----------+
|          3|       1|     3|Heikkinen, Miss. ...|female|26.0|    0|    0|STON/O2. 3101282|  7.925| NULL|       S| 1.7898327040013622|0.14975737556480773|   1.0|    (1,[],[])|  1.0|[0.0,1.7898327040...|[-0.7640567598771...|[0.317766

In [99]:
predictions.select('Survived','label','prediction').show(10)

+--------+-----+----------+
|Survived|label|prediction|
+--------+-----+----------+
|       1|  1.0|       1.0|
|       1|  1.0|       1.0|
|       0|  0.0|       0.0|
|       0|  0.0|       0.0|
|       1|  1.0|       1.0|
|       0|  0.0|       0.0|
|       1|  1.0|       1.0|
|       0|  0.0|       0.0|
|       0|  0.0|       1.0|
|       1|  1.0|       1.0|
+--------+-----+----------+
only showing top 10 rows



# 평가

In [101]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator
evaluator = BinaryClassificationEvaluator(rawPredictionCol='rawPrediction')
evaluator.evaluate(predictions)

0.8455151964418081

In [111]:
predictions.createOrReplaceTempView('predic')
print('테스트 표본:',predictions.count(), "\n"
'틀린 갯수: ', spark.sql("select * from predic where label != prediction").count())

테스트 표본: 237 
틀린 갯수:  54
