In [89]:
from pyspark.ml.regression import LinearRegression

# Load training data
training = spark.read.format("csv")\
    .option("header", "true")\
    .load("Churn_Modelling.csv")

training = training.drop("Surname", "Geography", "Gender")
training = training.withColumnRenamed("RowNumber", "label")
training.show(10)



+-----+----------+-----------+---+------+---------+-------------+---------+--------------+---------------+------+
|label|CustomerId|CreditScore|Age|Tenure|  Balance|NumOfProducts|HasCrCard|IsActiveMember|EstimatedSalary|Exited|
+-----+----------+-----------+---+------+---------+-------------+---------+--------------+---------------+------+
|    1|  15634602|        619| 42|     2|        0|            1|        1|             1|      101348.88|     1|
|    2|  15647311|        608| 41|     1| 83807.86|            1|        0|             1|      112542.58|     0|
|    3|  15619304|        502| 42|     8| 159660.8|            3|        1|             0|      113931.57|     1|
|    4|  15701354|        699| 39|     1|        0|            2|        0|             0|       93826.63|     0|
|    5|  15737888|        850| 43|     2|125510.82|            1|        1|             1|        79084.1|     0|
|    6|  15574012|        645| 44|     8|113755.78|            2|        1|             

In [90]:
from pyspark.sql.types import IntegerType

cols = ["label", "CustomerId", "CreditScore", "Age", "Tenure", "Balance", "NumOfProducts", "HasCrCard", "IsActiveMember", "EstimatedSalary", "Exited"]

for i in cols:
    training = training.withColumn(i, training[i].cast(IntegerType()))

training.take(5)

[Row(label=1, CustomerId=15634602, CreditScore=619, Age=42, Tenure=2, Balance=0, NumOfProducts=1, HasCrCard=1, IsActiveMember=1, EstimatedSalary=101348, Exited=1),
 Row(label=2, CustomerId=15647311, CreditScore=608, Age=41, Tenure=1, Balance=83807, NumOfProducts=1, HasCrCard=0, IsActiveMember=1, EstimatedSalary=112542, Exited=0),
 Row(label=3, CustomerId=15619304, CreditScore=502, Age=42, Tenure=8, Balance=159660, NumOfProducts=3, HasCrCard=1, IsActiveMember=0, EstimatedSalary=113931, Exited=1),
 Row(label=4, CustomerId=15701354, CreditScore=699, Age=39, Tenure=1, Balance=0, NumOfProducts=2, HasCrCard=0, IsActiveMember=0, EstimatedSalary=93826, Exited=0),
 Row(label=5, CustomerId=15737888, CreditScore=850, Age=43, Tenure=2, Balance=125510, NumOfProducts=1, HasCrCard=1, IsActiveMember=1, EstimatedSalary=79084, Exited=0)]

In [91]:
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler

assembler = VectorAssembler(inputCols=["CreditScore", "Age", "Tenure", "Balance", "NumOfProducts", "HasCrCard", "IsActiveMember", "EstimatedSalary", "Exited"], outputCol="features")

output = assembler.transform(training)

In [92]:
lr = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)

# Fit the model
lrModel = lr.fit(output)

# Print the coefficients and intercept for linear regression
print("Coefficients: %s" % str(lrModel.coefficients))
print("Intercept: %s" % str(lrModel.intercept))

# Summarize the model over the training set and print out some metrics
trainingSummary = lrModel.summary
print("numIterations: %d" % trainingSummary.totalIterations)
print("objectiveHistory: %s" % str(trainingSummary.objectiveHistory))
trainingSummary.residuals.show()
print("RMSE: %f" % trainingSummary.rootMeanSquaredError)
print("r2: %f" % trainingSummary.r2)

Coefficients: [0.15271031089535145,1.2625922971452483,-6.433142738839414,-0.00026656688692134554,23.979986883192446,3.6738767121073534,50.18854530055087,-0.0002770144260690374,-110.43125572988305]
Intercept: 4889.734693812233
numIterations: 9
objectiveHistory: [0.5000000000000002, 0.49988814163382367, 0.4997258643548936, 0.4997247099441776, 0.4997244131016744, 0.4997244102140144, 0.49972440868075785, 0.4997244086759044, 0.4997244086753518]
+-------------------+
|          residuals|
+-------------------+
|-4962.7612623715995|
| -5046.568307833587|
| -4856.020986340553|
|  -5057.25597620394|
| -5078.441831419957|
| -4855.716469639887|
| -5125.395261677535|
| -4875.470927119345|
| -5026.587014813095|
| -5027.505099515671|
|  -4958.46183527702|
|-4995.1072237132685|
| -4972.380964162884|
| -4954.061579678341|
| -5054.417489824353|
| -5047.497030601815|
| -4919.715100318681|
|-5025.8083109328045|
| -4958.615672943157|
| -5058.968796547549|
+-------------------+
only showing top 20 rows

RM