In [28]:
import pyspark

from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import DecisionTreeRegressor
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf
from pyspark.sql.types import StructType,StructField,FloatType

# Local Spark session using every available core on this machine.
# NOTE: getOrCreate() hands back an already-active session if one exists
# (e.g. when this notebook cell is re-executed). In that case settings that
# only apply when the driver JVM is launched, such as spark.driver.memory,
# are not re-applied; restart the kernel to change them.
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("Ml test")
    .config("spark.driver.memory", "10g")
    .getOrCreate()
)

# Input: a single part file written by Spark, under a directory whose name
# translates to "branch/department workload data".
path = "E:/Data/данные по загруженности отделений/part-00000-231f14d4-e9c4-4ac3-b3ea-4e249e71b4a2-c000.csv"
delimiter = ","
encoding = "utf-8"

# Every column is read as a nullable float. `label` is the value the model
# will learn to predict; the remaining columns are candidate features.
my_schema = StructType(fields=[
    StructField(column_name, FloatType())
    for column_name in (
        "day_of_the_week",
        "label",
        "day_",
        "month_",
        "time_interval_start",
        "time_interval_end",
    )
])

# header=True only skips the first line. Because an explicit schema is given,
# Spark assigns these names to the CSV columns by position, so the file's
# column order must match the tuple above.
# The "file:///" prefix forces the local filesystem rather than the default
# Hadoop filesystem.
df = (
    spark.read
    .options(header=True, delimiter=delimiter, encoding=encoding)
    .csv('file:///' + path, schema=my_schema)
)

# Pack the columns that may influence the result (everything except `label`)
# into the single vector column that Spark ML estimators consume.
# With VectorAssembler's default handleInvalid="error", a null in any of these
# columns (e.g. an unparseable CSV value) fails the job when it is evaluated.
feature_columns = [
    "day_of_the_week",
    "day_",
    "month_",
    "time_interval_start",
    "time_interval_end",
]
assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")

df1 = assembler.transform(df)

# Fixed seed so the train/test split (and therefore the reported metric) is
# reproducible between runs. Without it, every execution of this cell trained
# and tested on different rows. Spark's split is only repeatable for the same
# input data and partitioning.
SEED = 42

# Split into a held-out test set (20%) and a training set (80%).
# randomSplit returns the parts in the same order as the weights.
df_test, df_train = df1.randomSplit([0.2, 0.8], seed=SEED)

# Regression decision tree. The column names are written out explicitly even
# though they equal the estimator's defaults, so the link to `assembler` /
# `my_schema` is visible here.
dt = DecisionTreeRegressor(featuresCol="features", labelCol="label", seed=SEED)

# Train on the training split only.
model = dt.fit(df_train)

# Score the held-out split; this adds a `prediction` column.
prediction = model.transform(df_test)
prediction.show()

# show() only displays a sample of rows. Quantify the error over the whole
# test split instead.
evaluator = RegressionEvaluator(
    labelCol="label", predictionCol="prediction", metricName="rmse"
)
rmse = evaluator.evaluate(prediction)
print(f"Test RMSE: {rmse:.4f}")


+---------------+-----+----+------+-------------------+-----------------+--------------------+-----------------+
|day_of_the_week|label|day_|month_|time_interval_start|time_interval_end|            features|       prediction|
+---------------+-----+----+------+-------------------+-----------------+--------------------+-----------------+
|            1.0|  0.0| 8.0|  12.0|               18.0|             19.0|[1.0,8.0,12.0,18....|              0.0|
|            1.0|  0.0| 8.0|  12.0|               19.0|             20.0|[1.0,8.0,12.0,19....|              0.0|
|            1.0|  0.0| 8.0|  12.0|               21.0|             22.0|[1.0,8.0,12.0,21....|              0.0|
|            1.0|  0.0|15.0|  12.0|                9.0|             10.0|[1.0,15.0,12.0,9....|8.785714285714286|
|            1.0|  0.0|15.0|  12.0|               22.0|             23.0|[1.0,15.0,12.0,22...|              0.0|
|            1.0|  0.0|22.0|  12.0|               14.0|             15.0|[1.0,22.0,12.0,14...|  