In [2]:
from __future__ import print_function

from pyspark.ml.regression import LinearRegression

from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors

if __name__ == "__main__":
    # Script entry point: fit a linear regression on half of the rows in
    # ./assets/regression.txt and print a (prediction, label) tuple for every
    # row in the other half.

    # Create a SparkSession (Note, the config section is only for Windows!)
    spark = (
        SparkSession.builder
        .config("spark.sql.warehouse.dir", "file:///C:/temp")
        .appName("LinearRegression")
        .getOrCreate()
    )

    # Everything that uses the session lives inside try/finally so that the
    # session is stopped even when loading, parsing or training raises.
    try:
        # Load up our data and convert it to the format MLLib expects.
        # Each input line is "label,feature"; the single feature is wrapped
        # in a dense vector.
        inputLines = spark.sparkContext.textFile("./assets/regression.txt")
        data = inputLines.map(lambda x: x.split(",")).map(
            lambda x: (float(x[0]), Vectors.dense(float(x[1])))
        )

        # Convert this RDD to a DataFrame
        colNames = ["label", "features"]
        df = data.toDF(colNames)

        # Note, there are lots of cases where you can avoid going from an RDD
        # to a DataFrame. Perhaps you're importing data from a real database.
        # Or you are using structured streaming to get your data.

        # Let's split our data into training data and testing data.
        # No seed is passed, so the split differs from run to run.
        trainingDF, testDF = df.randomSplit([0.5, 0.5])

        # Now create our linear regression model
        lir = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)

        # Train the model using our training data
        model = lir.fit(trainingDF)

        # Now see if we can predict values in our test data.
        # Generate predictions using our linear regression model for all
        # features in our test dataframe. The result is consumed by a single
        # action below, so it does not need to be cached.
        fullPredictions = model.transform(testDF)

        # Pull the prediction and the "known" correct label out together in
        # one pass. Selecting both columns at once keeps each pair on the same
        # row, which avoids building two RDDs and re-pairing them with zip()
        # (zip requires identical partitioning on both sides).
        predictionAndLabel = (
            fullPredictions.select("prediction", "label")
            .rdd.map(lambda row: (row[0], row[1]))
            .collect()
        )

        # Print out the predicted and actual values for each point
        for prediction in predictionAndLabel:
            print(prediction)
    finally:
        # Stop the session
        spark.stop()


(-1.731535346275294, -2.54)
(-1.7029488960013295, -2.29)
(-1.5886030949054708, -2.27)
(-1.4814039063781035, -2.07)
(-1.4385242309671564, -1.96)
(-1.4385242309671564, -1.94)
(-1.3527648801452625, -1.91)
(-1.374204717850736, -1.88)
(-1.3384716550082802, -1.8)
(-1.3384716550082802, -1.64)
(-1.1883927910699656, -1.6)
(-1.1383665030905277, -1.57)
(-1.2241258539124216, -1.53)
(-1.0097274768576867, -1.48)
(-1.0669003774056158, -1.47)
(-1.0311673145631601, -1.36)
(-1.0811936025425983, -1.33)
(-0.8382087752138987, -1.29)
(-0.8739418380563545, -1.27)
(-0.8882350631933369, -1.26)
(-0.8739418380563545, -1.23)
(-0.8167689375084254, -1.2)
(-0.9168215134673016, -1.16)
(-0.7881824872344607, -1.03)
(-0.802475712371443, -0.99)
(-0.673836686138602, -0.97)
(-0.7095697489810578, -0.97)
(-0.7810358746659696, -0.96)
(-0.623810398159164, -0.95)
(-0.7524494243920049, -0.94)
(-0.6523968484331286, -0.89)
(-0.6452502358646374, -0.87)
(-0.6166637855906727, -0.84)
(-0.6595434610016198, -0.81)
(-0.5952239478851993, 