In [1]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler

# Obtain the notebook's Spark session: reuse an already-active one if present,
# otherwise start a new one under this application name.
spark = (
    SparkSession.builder
    .appName("VectorAssemblerExample")
    .getOrCreate()
)

23/12/18 12:37:15 WARN Utils: Your hostname, user-HP-EliteBook-840-G7-Notebook-PC resolves to a loopback address: 127.0.1.1; using 192.168.1.141 instead (on interface wlp0s20f3)


In [2]:
from pyspark.ml.regression import LinearRegression

# Toy regression set: each tuple is one observation laid out as (x1, x2, y).
data = [
    (-3.0965012, 5.2371198, -0.7370271),
    (-0.2100299, -0.7810844, -1.3284768),
    (8.3525083, 5.3337562, 21.8897181),
    (-3.0380369, 6.5357180, 0.3469820),
    (5.9354651, 6.0223208, 17.9566144),
    (-6.8357707, 5.6629804, -8.1598308),
    (8.8919844, -2.5149762, 15.3622538),
    (6.3404984, 4.1778706, 16.7931822),
]
columns = ["x1", "x2", "y"]

# Pack the two predictor columns into a single vector column named "features";
# the label column "y" is carried through unchanged.
assembler = VectorAssembler(inputCols=["x1", "x2"], outputCol="features")
df = assembler.transform(spark.createDataFrame(data, schema=columns))

# Preview the assembled rows
df.show()

# Configure the estimator: predictors come from "features", the target from "y".
# NOTE(review): elasticNetParam is set but regParam is not; if regParam stays at
# Spark's documented default of 0.0, no penalty is applied and elasticNetParam
# has no effect — confirm whether regularization was intended.
lr = LinearRegression(
    maxIter=10,
    elasticNetParam=0.8,
    featuresCol="features",
    labelCol="y",
)

# Train on the assembled DataFrame
lrModel = lr.fit(df)

# Report the fitted weights and the bias term
coefficients, intercept = lrModel.coefficients, lrModel.intercept
print(f"Coefficients: {coefficients}")
print(f"Intercept: {intercept}")

# Training diagnostics that Spark attaches to the fitted model
trainingSummary = lrModel.summary

# Optimizer diagnostics: iteration count and the recorded objective values.
# The objectiveHistory line appears in this cell's saved output, so it is
# printed here to keep the code and the recorded output in agreement.
print(f"numIterations: {trainingSummary.totalIterations}")
print(f"objectiveHistory: {trainingSummary.objectiveHistory}")

# Residuals table for the training data, followed by the overall fit metrics
trainingSummary.residuals.show()
print(f"RMSE: {trainingSummary.rootMeanSquaredError}")
print(f"r2: {trainingSummary.r2}")

# Stop the Spark session once all results above have been printed.
# NOTE(review): `spark` is created in an earlier cell; after this call that
# cell presumably has to be re-run before this cell can be executed again —
# confirm when re-running cells individually.
spark.stop()

+----------+----------+----------+--------------------+
|        x1|        x2|         y|            features|
+----------+----------+----------+--------------------+
|-3.0965012| 5.2371198|-0.7370271|[-3.0965012,5.237...|
|-0.2100299|-0.7810844|-1.3284768|[-0.2100299,-0.78...|
| 8.3525083| 5.3337562|21.8897181|[8.3525083,5.3337...|
|-3.0380369|  6.535718|  0.346982|[-3.0380369,6.535...|
| 5.9354651| 6.0223208|17.9566144|[5.9354651,6.0223...|
|-6.8357707| 5.6629804|-8.1598308|[-6.8357707,5.662...|
| 8.8919844|-2.5149762|15.3622538|[8.8919844,-2.514...|
| 6.3404984| 4.1778706|16.7931822|[6.3404984,4.1778...|
+----------+----------+----------+--------------------+

Coefficients: [2.0021398601708835,0.9962345581488143]
Intercept: -0.01921941939824674
numIterations: 0
objectiveHistory: [0.0]
+--------------------+
|           residuals|
+--------------------+
|  0.2644210690598129|
| -0.1106048738731169|
|-0.12762453162039478|
|-0.06233194035702172|
| 0.09255845281667163|
| -0.09609916374