In [1]:
import findspark
findspark.init('/home/venkat/Downloads/spark-3.2.0-bin-hadoop3.2')
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName('lr_example').getOrCreate()

21/11/28 20:48:00 WARN Utils: Your hostname, venkat-VirtualBox resolves to a loopback address: 127.0.1.1; using 10.0.2.15 instead (on interface enp0s3)
21/11/28 20:48:00 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
21/11/28 20:48:02 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [2]:
from pyspark.ml.regression import LinearRegression

In [4]:
data = spark.read.csv('Ecommerce_Customers.csv', inferSchema=True, header = True)

In [6]:
data.printSchema()

root
 |-- Email: string (nullable = true)
 |-- Address: string (nullable = true)
 |-- Avatar: string (nullable = true)
 |-- Avg Session Length: double (nullable = true)
 |-- Time on App: double (nullable = true)
 |-- Time on Website: double (nullable = true)
 |-- Length of Membership: double (nullable = true)
 |-- Yearly Amount Spent: double (nullable = true)



In [9]:
data.head(1)[0]

Row(Email='mstephenson@fernandez.com', Address='835 Frank TunnelWrightmouth, MI 82180-9605', Avatar='Violet', Avg Session Length=34.49726772511229, Time on App=12.65565114916675, Time on Website=39.57766801952616, Length of Membership=4.0826206329529615, Yearly Amount Spent=587.9510539684005)

In [10]:
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler

In [12]:
data.columns

['Email',
 'Address',
 'Avatar',
 'Avg Session Length',
 'Time on App',
 'Time on Website',
 'Length of Membership',
 'Yearly Amount Spent']

In [13]:
assembler = VectorAssembler(inputCols=['Avg Session Length','Time on App','Time on Website', 'Length of Membership'],
                           outputCol='features')


In [14]:
output = assembler.transform(data)

In [16]:
output.head(1) # note the features column

[Row(Email='mstephenson@fernandez.com', Address='835 Frank TunnelWrightmouth, MI 82180-9605', Avatar='Violet', Avg Session Length=34.49726772511229, Time on App=12.65565114916675, Time on Website=39.57766801952616, Length of Membership=4.0826206329529615, Yearly Amount Spent=587.9510539684005, features=DenseVector([34.4973, 12.6557, 39.5777, 4.0826]))]

In [18]:
final_data = output.select('features', 'Yearly Amount Spent')

In [19]:
train_data, test_data = final_data.randomSplit([0.7,0.3])

In [21]:
lr = LinearRegression(labelCol='Yearly Amount Spent', featuresCol='features')

In [23]:
lr_model = lr.fit(train_data)

21/11/28 20:55:23 WARN Instrumentation: [5cf611f0] regParam is zero, which might cause numerical instability and overfitting.


In [24]:
test_results = lr_model.evaluate(test_data)

In [26]:
test_results.residuals.show()

+-------------------+
|          residuals|
+-------------------+
|  8.107964284077525|
| -4.869613053370983|
| -6.970904013451047|
|-14.038236792610121|
| -6.130288672039853|
| -15.15954138930465|
|  -8.87022448520554|
| -6.055634363489446|
|-18.108582677412244|
|  6.371180350933116|
|  7.999929943677728|
|  -6.78314992276654|
|-10.041095715589506|
|    -17.98529623572|
| 10.775472016664594|
| 10.940977235129196|
|-6.6132503805926035|
|-18.614665111724207|
|-0.8702273495001123|
|  -9.16850600211933|
+-------------------+
only showing top 20 rows



In [27]:
test_results.rootMeanSquaredError

9.3008394412747

In [28]:
lr_model.coefficients

DenseVector([25.2451, 38.5868, 0.9014, 61.6917])

## Mimic unlabeled data

In [29]:
unlabeled_data = test_data.select('features')

In [30]:
unlabeled_data.show()

+--------------------+
|            features|
+--------------------+
|[29.5324289670579...|
|[30.8364326747734...|
|[31.0613251567161...|
|[31.0662181616375...|
|[31.5171218025062...|
|[31.5741380228732...|
|[31.7207699002873...|
|[31.7656188210424...|
|[31.8164283341993...|
|[31.8209982016720...|
|[31.8512531286083...|
|[31.8745516945853...|
|[31.8854062999117...|
|[31.9048571310136...|
|[31.9096268275227...|
|[31.9262720263601...|
|[31.9453957483445...|
|[31.9563005605233...|
|[32.0047530203648...|
|[32.0085045178551...|
+--------------------+
only showing top 20 rows



In [31]:
predictions = lr_model.transform(unlabeled_data)

In [32]:
predictions.show()

+--------------------+------------------+
|            features|        prediction|
+--------------------+------------------+
|[29.5324289670579...|400.53238678854996|
|[30.8364326747734...| 472.3715134803606|
|[31.0613251567161...|494.52636207135265|
|[31.0662181616375...| 462.9715300002845|
|[31.5171218025062...|282.04870932242557|
|[31.5741380228732...| 559.5688135498915|
|[31.7207699002873...| 547.6451579632285|
|[31.7656188210424...| 502.6097159990966|
|[31.8164283341993...| 519.2310741810686|
|[31.8209982016720...|418.30410066228023|
|[31.8512531286083...|464.99231672312067|
|[31.8745516945853...|399.06839416903404|
|[31.8854062999117...|  400.144368688065|
|[31.9048571310136...|491.93515365853614|
|[31.9096268275227...| 552.6705636565746|
|[31.9262720263601...| 381.2639562091972|
|[31.9453957483445...| 663.6331743182445|
|[31.9563005605233...|  565.740596858923|
|[32.0047530203648...|464.61620847012955|
|[32.0085045178551...| 452.3657270308747|
+--------------------+------------