In [1]:
# from decimal import Decimal

# for calculating the square root of the discriminant
from math import sqrt
# for generating random numbers (seed makes the generated data reproducible)
from random import randint, random, seed

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DecimalType, DoubleType, FloatType, ArrayType
import pyspark.sql.functions as F
# for user-defined functions
from pyspark.sql.functions import udf, col

# Create (or reuse) the local Spark session used by the rest of the notebook.
# getOrCreate() hands back the already-active session if there is one, so a
# second bare SparkSession.builder.getOrCreate() call would be redundant.
spark = (
    SparkSession.builder
    .appName("create df from dict")
    .master("local")
    .getOrCreate()
)

# Column types of the generated data set:
#   a, b, c -- coefficients of a*x^2 + b*x + c (never null)
#   d       -- discriminant b^2 - 4*a*c (never null)
#   x1, x2  -- analytically computed roots; null when there is no real root
schema = StructType([
    StructField('a', DoubleType(), False),
    StructField('b', DoubleType(), False),
    StructField('c', DoubleType(), False),
    StructField('d', DoubleType(), False),
    StructField('x1', DoubleType(), True),
    StructField('x2', DoubleType(), True)
])

In [2]:
# Number of random quadratic equations to generate.
n = 2000000
# Fixed seed so that Restart & Run All regenerates exactly the same data set.
SEED = 42


def analytic_roots(a, b, c):
    """Return the real roots (x1, x2) of a*x^2 + b*x + c = 0 in closed form.

    With d = b^2 - 4ac, x1 corresponds to (-b - sqrt(d)) / (2a) and x2 to
    (-b + sqrt(d)) / (2a).  A double root (d == 0) is returned in both slots
    and (None, None) is returned when there is no real root.  For a == 0 the
    equation is linear: its single root (if b != 0) is returned as x1 and
    x2 is None.
    """
    if a == 0:
        return (-c / b, None) if b != 0 else (None, None)
    d = b ** 2 - 4 * a * c
    if d > 0:
        # Numerically stable form: evaluating -b + sqrt(d) directly loses most
        # significant digits when 4ac is small relative to b^2 (catastrophic
        # cancellation).  Add quantities of the same sign instead and recover
        # the other root from x1 * x2 = c / a.
        s = sqrt(d) if b >= 0 else -sqrt(d)
        q = -0.5 * (b + s)
        root_q, root_c = q / a, c / q
        # q / a is exactly (-b - sqrt(d)) / (2a) when b >= 0, the other root otherwise.
        return (root_q, root_c) if b >= 0 else (root_c, root_q)
    if d == 0:
        x = -b / 2 / a
        return x, x
    return None, None


# Generate coefficients a, b, c of the quadratic equation a*x^2 + b*x + c
# together with the discriminant d and the analytically computed roots.
seed(SEED)
vals = []
for _ in range(n):
    a = 20 * random() - 8
    b = 10 * random() + 4
    c = 6 * random() - 7
    d = b ** 2 - 4 * a * c
    x1, x2 = analytic_roots(a, b, c)
    vals.append((a, b, c, d, x1, x2))

In [3]:
# Turn the generated Python rows into a Spark DataFrame, typed by the
# explicit `schema` declared in the first cell.
df = spark.createDataFrame(data=vals, schema=schema)

In [4]:
def f(x, a, b, c):
    """Evaluate the quadratic polynomial a*x^2 + b*x + c at x."""
    return a * x ** 2 + b * x + c


def solve_q_eq(left, right, a, b, c, eps=1.e-8):
    """Find a root of a*x^2 + b*x + c inside [left, right] by bisection.

    Parameters
    ----------
    left, right : float
        Ends of the bracket.  f(left) and f(right) must not share the same
        strict sign (either may be exactly 0).
    a, b, c : float
        Polynomial coefficients.
    eps : float
        Stop once the bracket is no wider than eps, or once it cannot be
        halved any further in floating point.

    Returns
    -------
    float
        A point within the final bracket around a root.

    Raises
    ------
    ValueError
        If f has the same strict sign at both ends, so there is no sign
        change for bisection to follow.
    """
    f_left = f(left, a, b, c)
    f_right = f(right, a, b, c)
    if f_left == 0:
        return left
    if f_right == 0:
        return right
    if (f_left > 0) == (f_right > 0):
        raise ValueError("f(left) and f(right) must have opposite signs")
    while abs(right - left) > eps:
        mid = (left + right) / 2
        # For large |x| neighbouring doubles can be farther apart than eps;
        # once the midpoint collapses onto an end the bracket cannot shrink.
        if mid == left or mid == right:
            break
        f_mid = f(mid, a, b, c)
        if f_mid == 0:
            return mid
        # Keep the half whose ends still straddle the sign change.
        if (f_mid > 0) == (f_left > 0):
            left, f_left = mid, f_mid
        else:
            right = mid
    return (left + right) / 2


def solve_q_eq_totally(a, b, c):
    """Find the real roots of a*x^2 + b*x + c = 0 using `solve_q_eq`.

    The roots are ordered like the closed-form expressions
    x1 = (-b - sqrt(d)) / (2a) and x2 = (-b + sqrt(d)) / (2a), d = b^2 - 4ac,
    so the result can be compared slot-by-slot with analytically computed
    roots.

    Returns
    -------
    tuple
        (x1, x2).  Both slots hold the vertex for a double root, and
        (None, None) is returned when there is no real root.  For a == 0
        (a linear equation) the single root, if any, is x1 and x2 is None.
    """
    if a == 0:
        return (-c / b, None) if b != 0 else (None, None)
    d = b ** 2 - 4 * a * c
    if d < 0:
        return None, None
    vertex = -b / 2 / a
    f_vertex = f(vertex, a, b, c)
    # Double root: either d == 0 exactly, or rounding left f(vertex) without
    # the sign opposite to a, so there is no sign change to bisect.
    if d == 0 or f_vertex == 0 or (f_vertex > 0) == (a > 0):
        return vertex, vertex
    # Cauchy's bound: every root satisfies |x| < 1 + max(|b|, |c|) / |a|.
    # Doubling it keeps f(+-bound) safely on the sign of a despite rounding.
    bound = 2 * (1 + max(abs(b), abs(c)) / abs(a))
    # f is monotonic on each side of the vertex, so each half holds one root.
    left_root = solve_q_eq(-bound, vertex, a, b, c)
    right_root = solve_q_eq(vertex, bound, a, b, c)
    # (-b - sqrt(d)) / (2a) is the left root for a > 0, the right one for a < 0.
    if a > 0:
        return left_root, right_root
    return right_root, left_root

In [5]:
# Return type of 'solve_udf', the Spark UDF wrapper around 'solve_q_eq_totally':
# a struct of the two (possibly null) roots found by binary search.
# Despite its name, this is a StructType, not an ArrayType.
array_schema = StructType([
    StructField('x1_calc', DoubleType(), nullable=True),
    StructField('x2_calc', DoubleType(), nullable=True),
])

solve_udf = udf(lambda a, b, c: solve_q_eq_totally(a, b, c), returnType=array_schema)

# df2 = all columns of df plus the struct of computed roots, named 'roots'.
df2 = df.select(
    'a', 'b', 'c', 'd', 'x1', 'x2',
    solve_udf('a', 'b', 'c').alias('roots'),
)

In [6]:
# df2.select('roots.x2_calc').show()

In [7]:
# Persist coefficients, analytic roots and bisection roots as Parquet.
# mode("overwrite") keeps the cell re-runnable: without it the write fails
# once the "quadr_eq" output path already exists.
df2.write.mode("overwrite").parquet("quadr_eq")