# Logistic regression sample for Spark

In [1]:
// Evaluate the session handle so the kernel starts the Spark session before any data is read.
spark

Waiting for a Spark session to start...

org.apache.spark.sql.SparkSession@51890b9c

### 1. Read DataFrame

In [56]:
// Load the Titanic training CSV, treating its first line as the header row.
// No schema is supplied, so every column comes back as a string (see the printed schema).
val df = spark.read
  .format("csv")
  .option("header", true)
  .load("/user/jay-n/SampleData/titanic_train.csv")
df.printSchema()

root
 |-- PassengerId: string (nullable = true)
 |-- Survived: string (nullable = true)
 |-- Pclass: string (nullable = true)
 |-- Name: string (nullable = true)
 |-- Sex: string (nullable = true)
 |-- Age: string (nullable = true)
 |-- SibSp: string (nullable = true)
 |-- Parch: string (nullable = true)
 |-- Ticket: string (nullable = true)
 |-- Fare: string (nullable = true)
 |-- Cabin: string (nullable = true)
 |-- Embarked: string (nullable = true)



df = [PassengerId: string, Survived: string ... 10 more fields]


lastException: Throwable = null


[PassengerId: string, Survived: string ... 10 more fields]

In [62]:
// Keep only the rows where both model inputs (Age and Sex) are present.
val adf = df.where($"Age".isNotNull && $"Sex".isNotNull)

// Print the row count before and after the filter to see how many rows were dropped.
Seq(df, adf).foreach(frame => println(frame.count()))

891
714


adf = [PassengerId: string, Survived: string ... 10 more fields]


[PassengerId: string, Survived: string ... 10 more fields]

In [63]:
// Narrow the frame to the two features and the label, converting the numeric ones from
// string to int.
// NOTE(review): an int cast discards any fractional part of Age — confirm that is intended.
val bdf = adf.select(
  $"Age".cast("int").as("Age"),
  $"Sex",
  $"Survived".cast("int").as("Survived")
)
bdf.printSchema()

root
 |-- Age: integer (nullable = true)
 |-- Sex: string (nullable = true)
 |-- Survived: integer (nullable = true)



bdf = [Age: int, Sex: string ... 1 more field]


[Age: int, Sex: string ... 1 more field]

### 2. Create Pipelines

#### 2-1. OneHotEncoder

In [68]:
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}

// Step 1: map the categorical "Sex" strings to numeric indices in a new "SexIndex" column
// (the sample output shows male -> 0.0 and female -> 1.0).
// The indexer is fitted on `bdf`, the same frame that is transformed right after, so the cell
// runs from a clean kernel without depending on a DataFrame defined outside this notebook.
val indexer = new StringIndexer()
  .setInputCol("Sex")
  .setOutputCol("SexIndex")
  .fit(bdf)
val indexed = indexer.transform(bdf)

// Step 2: one-hot encode the index into the sparse vector column "SexVec".
// The sample output shows a length-1 vector: [1.0] for male and all zeros for female.
// NOTE(review): calling transform() directly on OneHotEncoder only works on Spark versions
// where it is a Transformer (2.x); on 3.x it is an Estimator and must be fitted first —
// confirm the target Spark version.
val encoder = new OneHotEncoder()
  .setInputCol("SexIndex")
  .setOutputCol("SexVec")

val cdf = encoder.transform(indexed)
cdf.show(3, false)

+---+------+--------+--------+-------------+
|Age|Sex   |Survived|SexIndex|SexVec       |
+---+------+--------+--------+-------------+
|22 |male  |0       |0.0     |(1,[0],[1.0])|
|38 |female|1       |1.0     |(1,[],[])    |
|26 |female|1       |1.0     |(1,[],[])    |
+---+------+--------+--------+-------------+
only showing top 3 rows



indexer = strIdx_a2a1e715a5ba
indexed = [Age: int, Sex: string ... 2 more fields]
encoder = oneHot_4be47d4899da
cdf = [Age: int, Sex: string ... 3 more fields]


lastException: Throwable = null


[Age: int, Sex: string ... 3 more fields]

In [69]:
// Confirm the columns added by indexing/encoding: SexIndex (double) and SexVec (vector).
cdf.printSchema()

root
 |-- Age: integer (nullable = true)
 |-- Sex: string (nullable = true)
 |-- Survived: integer (nullable = true)
 |-- SexIndex: double (nullable = true)
 |-- SexVec: vector (nullable = true)



#### 2-2. VectorAssembler (TEST)

In [71]:
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.Vectors

// Trial run of the assembler on its own: merge the numeric Age column and the one-hot
// SexVec column into a single "features" vector column and eyeball a few rows.
val assembler = new VectorAssembler()
  .setOutputCol("features")
  .setInputCols(Array("Age", "SexVec"))

val output = assembler.transform(cdf)
output.show(numRows = 3, truncate = false)

+---+------+--------+--------+-------------+----------+
|Age|Sex   |Survived|SexIndex|SexVec       |features  |
+---+------+--------+--------+-------------+----------+
|22 |male  |0       |0.0     |(1,[0],[1.0])|[22.0,1.0]|
|38 |female|1       |1.0     |(1,[],[])    |[38.0,0.0]|
|26 |female|1       |1.0     |(1,[],[])    |[26.0,0.0]|
+---+------+--------+--------+-------------+----------+
only showing top 3 rows



assembler = vecAssembler_19346505d429
output = [Age: int, Sex: string ... 4 more fields]


[Age: int, Sex: string ... 4 more fields]

### 3. Fit & Check

In [72]:
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.classification.LogisticRegression

// Stage 0: assemble Age and SexVec into the "features" vector column.
val assembler = new VectorAssembler()
  .setOutputCol("features")
  .setInputCols(Array("Age", "SexVec"))

// Stage 1: logistic regression using "Survived" as the label column.
val lr = new LogisticRegression()
  .setLabelCol("Survived")

// Chain both stages and fit them on the encoded frame. Indexing and one-hot encoding were
// already applied to `cdf` outside the pipeline, so they are not part of the fitted model.
val pipeline = new Pipeline()
  .setStages(Array(assembler, lr))
val model = pipeline.fit(cdf)

assembler = vecAssembler_939c185f30a0
lr = logreg_423e25e14ba5
pipeline = pipeline_1004ed49807a
model = pipeline_1004ed49807a


pipeline_1004ed49807a

In [90]:
// Print the two fitted pipeline stages in order (index 0, then index 1).
Seq(0, 1).foreach(i => println(model.stages(i)))

vecAssembler_939c185f30a0
logreg_423e25e14ba5


In [91]:
import org.apache.spark.ml.classification.LogisticRegressionModel
// Take the second fitted stage (index 1) and narrow it to LogisticRegressionModel so that
// model-specific members such as the coefficients and the training summary can be accessed.
val logRegModel = {
  val fittedStage = model.stages(1)
  fittedStage.asInstanceOf[LogisticRegressionModel]
}

logRegModel = logreg_423e25e14ba5


logreg_423e25e14ba5

#### 3-1. DEBUG

In [93]:
// Print the fitted coefficients (feature weights) and the intercept of the model.
{
  val weights = logRegModel.coefficientMatrix
  val bias = logRegModel.interceptVector
  println(s"coefficients: $weights")
  println(s"intercepts: $bias")
}

coefficients: -0.0054616405614571475  -2.4659403174938883  
intercepts: [1.2782309194083987]


In [96]:
// Check the objective function history, i.e. whether training is converging.
import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression}
val trainingSummary = logRegModel.summary
// One loss value is recorded per optimizer iteration; print them one per line.
val objectiveHistory = trainingSummary.objectiveHistory
println("objectiveHistory:")
objectiveHistory.foreach(println)

objectiveHistory:
0.6754313479380432
0.6331096467875821
0.5933719022991755
0.5462606025000956
0.542209931657562
0.532297514278512
0.5255861640910967
0.5251767386069403
0.525172397326206
0.5251722630081502
0.5251722562852661
0.5251722562159999


trainingSummary = org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary@44b439cd
objectiveHistory = Array(0.6754313479380432, 0.6331096467875821, 0.5933719022991755, 0.5462606025000956, 0.542209931657562, 0.532297514278512, 0.5255861640910967, 0.5251767386069403, 0.525172397326206, 0.5251722630081502, 0.5251722562852661, 0.5251722562159999)


Array(0.6754313479380432, 0.6331096467875821, 0.5933719022991755, 0.5462606025000956, 0.542209931657562, 0.532297514278512, 0.5255861640910967, 0.5251767386069403, 0.525172397326206, 0.5251722630081502, 0.5251722562852661, 0.5251722562159999)

In [97]:
// ROC curve information.
// Narrow the training summary to the binary summary type to reach the ROC-related metrics.
val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary]
val roc = binarySummary.roc
// 20 rows is the default for show(); it is spelled out here to make the cap visible.
roc.show(20)
println(s"areaUnderROC: ${binarySummary.areaUnderROC}")

+--------------------+--------------------+
|                 FPR|                 TPR|
+--------------------+--------------------+
|                 0.0|                 0.0|
|                 0.0|0.006896551724137931|
|                 0.0|0.013793103448275862|
|0.009433962264150943|0.020689655172413793|
| 0.01179245283018868| 0.02413793103448276|
| 0.01179245283018868|0.041379310344827586|
| 0.01179245283018868| 0.05517241379310345|
|0.014150943396226415| 0.05862068965517241|
|0.014150943396226415| 0.06206896551724138|
| 0.01650943396226415| 0.06551724137931035|
|0.025943396226415096| 0.06551724137931035|
| 0.02830188679245283| 0.06551724137931035|
|0.030660377358490566| 0.06551724137931035|
|0.030660377358490566| 0.07241379310344828|
| 0.03537735849056604| 0.08275862068965517|
| 0.03537735849056604| 0.09655172413793103|
| 0.03773584905660377| 0.11379310344827587|
| 0.04009433962264151|  0.1310344827586207|
| 0.05188679245283019| 0.15862068965517243|
| 0.05188679245283019| 0.1827586

binarySummary = org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary@44b439cd
roc = [FPR: double, TPR: double]


[FPR: double, TPR: double]

In [84]:
import org.apache.spark.sql.functions._

// Pick the decision threshold that maximizes the F-measure reported by the training summary.
val fMeasure = binarySummary.fMeasureByThreshold
val maxFMeasure = fMeasure
  .agg(max("F-Measure"))
  .first()
  .getDouble(0)
val bestThreshold = fMeasure
  .filter(col("F-Measure") === maxFMeasure)
  .select("threshold")
  .first()
  .getDouble(0)
// logRegModel is the same instance as model.stages(1), so this also changes the threshold
// used by later model.transform(...) calls.
logRegModel.setThreshold(bestThreshold)

fMeasure = [threshold: double, F-Measure: double]
maxFMeasure = 0.7414965986394558
bestThreshold = 0.2221382134617676


logreg_423e25e14ba5

### 4. Prediction

In [102]:
// Score the same frame the pipeline was fitted on; the output gains the rawPrediction,
// probability and prediction columns.
val result = model.transform(cdf)
result.show(numRows = 10, truncate = false)

+---+------+--------+--------+-------------+----------+----------------------------------------+----------------------------------------+----------+
|Age|Sex   |Survived|SexIndex|SexVec       |features  |rawPrediction                           |probability                             |prediction|
+---+------+--------+--------+-------------+----------+----------------------------------------+----------------------------------------+----------+
|22 |male  |0       |0.0     |(1,[0],[1.0])|[22.0,1.0]|[1.3078654904375469,-1.3078654904375469]|[0.7871557559895479,0.212844244010452]  |0.0       |
|38 |female|1       |1.0     |(1,[],[])    |[38.0,0.0]|[-1.070688578073027,1.070688578073027]  |[0.2552721577194062,0.7447278422805937] |1.0       |
|26 |female|1       |1.0     |(1,[],[])    |[26.0,0.0]|[-1.1362282648105129,1.1362282648105129]|[0.24301352898468334,0.7569864710153167]|1.0       |
|35 |female|1       |1.0     |(1,[],[])    |[35.0,0.0]|[-1.0870734997573985,1.0870734997573985]|[0.2521697

result = [Age: int, Sex: string ... 7 more fields]


[Age: int, Sex: string ... 7 more fields]