In [1]:
from __future__ import print_function

from pyspark.ml.regression import LinearRegression

from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors

In [3]:
if __name__ == "__main__":

    # Seed for the train/test split so that re-running the cell uses the same
    # partition of the data instead of a fresh random one each time.
    SPLIT_SEED = 42

    # Create (or reuse) the SparkSession that drives this job.
    spark = SparkSession.builder.appName("LinearRegression").getOrCreate()

    # Everything that touches Spark lives inside try/finally so the session is
    # stopped even when loading, training or prediction raises part-way through.
    try:
        # Load the raw text file. Each line is split on "," -- the first field
        # is used as the label and the second as the single feature. The
        # feature is wrapped in a dense Vector, which is the format MLlib expects.
        inputLines = spark.sparkContext.textFile("regression.txt")
        data = inputLines.map(lambda x: x.split(",")).map(
            lambda x: (float(x[0]), Vectors.dense(float(x[1]))))

        # Convert this RDD to a DataFrame with named columns.
        colNames = ["label", "features"]
        df = data.toDF(colNames)

        # Note: going through an RDD first is often unnecessary, e.g. when the
        # data is imported from a real database or arrives via structured
        # streaming and is therefore already a DataFrame.

        # Split the data 50/50 into a training set and a test set.
        trainingDF, testDF = df.randomSplit([0.5, 0.5], seed=SPLIT_SEED)

        # Configure the linear regression estimator.
        lir = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)

        # Train the model using the training data.
        model = lir.fit(trainingDF)

        # Generate predictions for every row of the test DataFrame. The result
        # keeps the original columns and adds a "prediction" column.
        fullPredictions = model.transform(testDF).cache()

        # Pull the prediction and the known label out of the SAME row in a
        # single select. Zipping two independently derived RDDs only works while
        # both keep identical partitioning and per-partition element counts;
        # selecting both columns together cannot misalign the pairs.
        predictionAndLabel = [
            (row["prediction"], row["label"])
            for row in fullPredictions.select("prediction", "label").collect()
        ]

        # Print out the predicted and actual values for each point.
        for pair in predictionAndLabel:
            print(pair)
    finally:
        # Stop the session, on success and on failure alike.
        spark.stop()
    
    

(-1.8939202205922057, -2.36)
(-1.6925392918130109, -2.29)
(-1.6206175315347269, -2.26)
(-1.5774644753677565, -2.17)
(-1.368891370560733, -2.12)
(-1.4192366027555319, -2.09)
(-1.3976600746720467, -1.94)
(-1.3185461383659345, -1.91)
(-1.2466243780876505, -1.79)
(-1.1962791458928517, -1.77)
(-1.2178556739763369, -1.75)
(-1.1962791458928517, -1.74)
(-1.0524356253362839, -1.67)
(-1.3257383143937629, -1.64)
(-1.1747026178093665, -1.6)
(-1.1675104417815383, -1.59)
(-1.2106634979485085, -1.53)
(-1.0524356253362839, -1.47)
(-1.0164747451971419, -1.36)
(-0.8510546965570889, -1.3)
(-1.0596278013641123, -1.29)
(-0.8438625205292605, -1.26)
(-0.8654390486127457, -1.25)
(-0.8582468725849173, -1.23)
(-0.8654390486127457, -1.22)
(-0.9013999287518877, -1.17)
(-0.7215955280561779, -1.11)
(-0.8079016403901187, -1.08)
(-0.8438625205292605, -1.04)
(-0.613712887638752, -1.03)
(-0.7719407602509767, -1.03)
(-0.6640581198335508, -1.01)
(-0.7863251123066335, -0.99)
(-0.7647485842231483, -0.96)
(-0.58494418352743