In [1]:
from pyspark.sql import SparkSession

In [3]:
# Obtain the Spark session for this notebook, registered under the app name 'lrec'.
spark = (
    SparkSession.builder
    .appName('lrec')
    .getOrCreate()
)

In [4]:
from pyspark.ml.regression import LinearRegression

In [5]:
# Load the customer CSV: the first row supplies column names and
# column types are inferred from the data rather than read as strings.
data = spark.read.csv(
    'Ecommerce_Customers.csv',
    header=True,
    inferSchema=True,
)

In [6]:
# Inspect the inferred column names and types before choosing model features.
data.printSchema()

root
 |-- Email: string (nullable = true)
 |-- Address: string (nullable = true)
 |-- Avatar: string (nullable = true)
 |-- Avg Session Length: double (nullable = true)
 |-- Time on App: double (nullable = true)
 |-- Time on Website: double (nullable = true)
 |-- Length of Membership: double (nullable = true)
 |-- Yearly Amount Spent: double (nullable = true)



In [7]:
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler

In [8]:
# List the column names so the numeric feature columns can be copied exactly.
data.columns

['Email',
 'Address',
 'Avatar',
 'Avg Session Length',
 'Time on App',
 'Time on Website',
 'Length of Membership',
 'Yearly Amount Spent']

In [9]:
# Combine the four numeric customer-behaviour columns into a single
# vector column named 'features', which is what the regressor consumes.
assembler = VectorAssembler(
    inputCols=[
        'Avg Session Length',
        'Time on App',
        'Time on Website',
        'Length of Membership',
    ],
    outputCol='features',
)

In [10]:
# Apply the assembler: keeps every original column and appends the 'features' vector.
output = assembler.transform(data)

In [11]:
# Confirm the assembled 'features' vector column is present alongside the originals.
output.printSchema()

root
 |-- Email: string (nullable = true)
 |-- Address: string (nullable = true)
 |-- Avatar: string (nullable = true)
 |-- Avg Session Length: double (nullable = true)
 |-- Time on App: double (nullable = true)
 |-- Time on Website: double (nullable = true)
 |-- Length of Membership: double (nullable = true)
 |-- Yearly Amount Spent: double (nullable = true)
 |-- features: vector (nullable = true)



In [12]:
# Keep only what the model needs: the feature vector and the label column.
final_data = output.select('features', 'Yearly Amount Spent')

In [13]:
# 70/30 train/test split. A fixed seed is passed so that re-running the
# notebook draws the same split; without it the rows in each split, and
# therefore every metric computed below, change from run to run.
SPLIT_SEED = 42
train_data, test_data = final_data.randomSplit([0.7, 0.3], seed=SPLIT_SEED)

In [14]:
# Summary statistics (count, mean, stddev, min, max) for the training split's label column.
train_data.describe().show()

+-------+-------------------+
|summary|Yearly Amount Spent|
+-------+-------------------+
|  count|                356|
|   mean| 502.36026393520916|
| stddev|  79.23799917339379|
|    min|   266.086340948469|
|    max|  765.5184619388373|
+-------+-------------------+



In [15]:
# Verify the training split carries just the feature vector and the label.
train_data.printSchema()

root
 |-- features: vector (nullable = true)
 |-- Yearly Amount Spent: double (nullable = true)



In [16]:
# Configure the linear regression estimator: read inputs from the
# assembled 'features' vector and predict 'Yearly Amount Spent'.
lr = LinearRegression(
    featuresCol='features',
    labelCol='Yearly Amount Spent',
)

In [17]:
# Fit the regression on the training split only; the test split stays unseen.
lr_model = lr.fit(train_data)

In [18]:
# Score the fitted model on the held-out test split; the returned summary
# exposes the residuals, RMSE and R^2 inspected in the following cells.
test_results = lr_model.evaluate(test_data)

In [20]:
# Display the per-row residuals on the test split (show() prints the first 20 rows).
test_results.residuals.show()

+--------------------+
|           residuals|
+--------------------+
| -11.445939945854093|
|   -5.15821378510924|
|  0.6057396500099799|
|  -6.460782018019472|
|    3.80739357305481|
| -5.2381496038212845|
|   6.888761910002529|
| -3.6926471851822953|
|  0.6888809624192618|
| -2.3404741876775006|
| -16.982432600782943|
|  -4.704206057601709|
|0.061015682913364344|
|  -8.717066064615892|
|  16.617893601908463|
|   5.849469802428644|
|    5.46704501493457|
|    17.4277567084456|
|  -5.800217352999198|
|-0.30503602357214277|
+--------------------+
only showing top 20 rows



In [21]:
# Root mean squared error on the test split; compare against the label's
# mean/stddev from describe() to judge whether the error is small.
test_results.rootMeanSquaredError

8.981273594919186

In [22]:
# R^2 (coefficient of determination) on the test split.
test_results.r2

0.9870756596214079

In [23]:
# Drop the label to mimic new, unlabeled customers: only the feature vector remains.
unlabeled_data = test_data.select('features')

In [24]:
# Run the fitted model over the unlabeled rows; the result gains a 'prediction' column.
predictions = lr_model.transform(unlabeled_data)

In [25]:
# Preview the feature vectors next to their predicted yearly spend.
predictions.show()

+--------------------+------------------+
|            features|        prediction|
+--------------------+------------------+
|[30.3931845423455...| 331.3748097490477|
|[30.4925366965402...|287.62945950502376|
|[30.5743636841713...| 441.4586741080557|
|[31.1280900496166...| 563.7134687650741|
|[31.3662121671876...| 426.7814889834301|
|[31.5147378578019...| 495.0506376002827|
|[31.6548096756927...|  468.374661817546|
|[31.8124825597242...| 396.5029921689795|
|[31.8293464559211...| 384.4634570255557|
|[31.8530748017465...| 461.6255976500295|
|[31.9048571310136...| 490.9322900235991|
|[31.9453957483445...| 661.7241299952536|
|[32.0047530203648...| 463.6849654377161|
|[32.0085045178551...| 451.9142870933713|
|[32.0180740106320...|341.16521714340684|
|[32.0305497162129...| 588.4250136161832|
|[32.0478009788678...| 507.9835261711619|
|[32.0609143984100...| 610.1755620045694|
|[32.0883806304482...| 517.9660837411573|
|[32.0961089938451...| 375.7034914338153|
+--------------------+------------