In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName('lin_reg').getOrCreate()
%matplotlib inline

In [2]:
data = spark.read.csv('../data/Ecommerce_Customers.csv',inferSchema = True,header=True)
print(data.count(),len(data.columns))

500 8


In [3]:
data.printSchema()

root
 |-- Email: string (nullable = true)
 |-- Address: string (nullable = true)
 |-- Avatar: string (nullable = true)
 |-- Avg Session Length: double (nullable = true)
 |-- Time on App: double (nullable = true)
 |-- Time on Website: double (nullable = true)
 |-- Length of Membership: double (nullable = true)
 |-- Yearly Amount Spent: double (nullable = true)



In [4]:
data.head(1)[0].asDict()

{'Email': 'mstephenson@fernandez.com',
 'Address': '835 Frank TunnelWrightmouth, MI 82180-9605',
 'Avatar': 'Violet',
 'Avg Session Length': 34.49726772511229,
 'Time on App': 12.65565114916675,
 'Time on Website': 39.57766801952616,
 'Length of Membership': 4.0826206329529615,
 'Yearly Amount Spent': 587.9510539684005}

In [5]:
from pyspark.ml.regression import LinearRegression
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler

In [6]:
print(data.columns)

['Email', 'Address', 'Avatar', 'Avg Session Length', 'Time on App', 'Time on Website', 'Length of Membership', 'Yearly Amount Spent']


In [7]:
assembler = VectorAssembler(inputCols=['Avg Session Length', 'Time on App', 'Time on Website', 'Length of Membership'],
                           outputCol='features')
output = assembler.transform(data)
print(output.limit(1).select('features').show())

+--------------------+
|            features|
+--------------------+
|[34.4972677251122...|
+--------------------+

None


In [8]:
dataset = output.select(['features','Yearly Amount Spent'])
dataset.show()

+--------------------+-------------------+
|            features|Yearly Amount Spent|
+--------------------+-------------------+
|[34.4972677251122...|  587.9510539684005|
|[31.9262720263601...|  392.2049334443264|
|[33.0009147556426...| 487.54750486747207|
|[34.3055566297555...|  581.8523440352177|
|[33.3306725236463...|  599.4060920457634|
|[33.8710378793419...|   637.102447915074|
|[32.0215955013870...|  521.5721747578274|
|[32.7391429383803...|  549.9041461052942|
|[33.9877728956856...|  570.2004089636196|
|[31.9365486184489...|  427.1993848953282|
|[33.9925727749537...|  492.6060127179966|
|[33.8793608248049...|  522.3374046069357|
|[29.5324289670579...|  408.6403510726275|
|[33.1903340437226...|  573.4158673313865|
|[32.3879758531538...|  470.4527333009554|
|[30.7377203726281...|  461.7807421962299|
|[32.1253868972878...| 457.84769594494855|
|[32.3388993230671...| 407.70454754954415|
|[32.1878120459321...|  452.3156754800354|
|[32.6178560628234...|   605.061038804892|
+----------

In [9]:
train_data,test_data = dataset.randomSplit([0.8,0.2])
train_data.count(),test_data.count()

(394, 106)

In [10]:
# ?LinearRegression

In [11]:
lr = LinearRegression(featuresCol = 'features',labelCol='Yearly Amount Spent',predictionCol='preds')
print(lr)

LinearRegression_5923c4439070


In [12]:
model = lr.fit(dataset)

In [13]:
print('Coefficients : ',model.coefficients)
print('Intercept : ',model.intercept)

Coefficients :  [25.734271084670716,38.709153810828816,0.43673883558514964,61.57732375487594]
Intercept :  -1051.5942552990748


In [14]:
summary = model.summary
print(summary.r2)
print(summary.rootMeanSquaredError)
summary.residuals.show()

0.9843155370226727
9.923256785022229
+-------------------+
|          residuals|
+-------------------+
| -6.788234090018818|
| 11.841128565326073|
| -17.65262700858966|
| 11.454889631178617|
| 7.7833824373080915|
|-1.8347332184773677|
|  4.620232401352382|
| -8.526545950978175|
| 11.012210896516763|
|-13.828032682158891|
| -16.04456458615175|
|  8.786634365463442|
| 10.425717191807507|
| 12.161293785003522|
|  9.989313714461446|
| 10.626662732649379|
|  20.15641408428496|
|-3.7708446586326545|
| -4.129505481591934|
|  9.206694655890487|
+-------------------+
only showing top 20 rows



In [16]:
data.describe().show()

+-------+-----------------+--------------------+-----------+------------------+------------------+------------------+--------------------+-------------------+
|summary|            Email|             Address|     Avatar|Avg Session Length|       Time on App|   Time on Website|Length of Membership|Yearly Amount Spent|
+-------+-----------------+--------------------+-----------+------------------+------------------+------------------+--------------------+-------------------+
|  count|              500|                 500|        500|               500|               500|               500|                 500|                500|
|   mean|             null|                null|       null| 33.05319351819619|12.052487937166134| 37.06044542094859|   3.533461555915055|  499.3140382585909|
| stddev|             null|                null|       null|0.9925631110845354|0.9942156084725424|1.0104889067564033|  0.9992775024112585|   79.3147815497068|
|    min|aaron04@yahoo.com|0001 Mack MillNor..