In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.sql import functions as F
from pyspark.sql.types import *
from pyspark.sql.functions import udf, hour, mean, month, year, to_date
from pyspark.sql.window import Window
from tqdm.notebook import tqdm

In [2]:
!pip install tqdm



In [3]:
spark = SparkSession.builder.getOrCreate()

In [4]:
spark.conf.set('spark.sql.repl.eagerEval.enabled', True)

In [5]:
spark

# Schema de données

In [6]:
schema = StructType() \
    .add("STATION", IntegerType(), False) \
    .add("DATE", TimestampType(), False) \
    .add("SOURCE", IntegerType(), True) \
    .add("LATITUDE", FloatType(), True) \
    .add("LONGITUDE", FloatType(), True) \
    .add("ELEVATION", StringType(), True) \
    .add("NAME", StringType(), True) \
    .add("REPORT_TYPE", StringType(), True) \
    .add("CALL_SIGN", StringType(), True) \
    .add("QUALITY_CONTROL", StringType(), True) \
    .add("WND", StringType(), True) \
    .add("CIG", StringType(), True) \
    .add("VIS", StringType(), True) \
    .add("TMP", StringType(), True) \
    .add("DEW", StringType(), True) \
    .add("SLP", StringType(), True) \
    .add("GA1", StringType(), True) \
    .add("GA2", StringType(), True) \
    .add("GA3", StringType(), True) \
    .add("GA4", StringType(), True) \
    .add("GF1", StringType(), True) \
    .add("MA1", StringType(), True) \
    .add("MW1", StringType(), True) \
    .add("MW2", StringType(), True) \
    .add("MW3", StringType(), True) \
    .add("OC1", StringType(), True) \
    .add("REM", StringType(), True) \
    .add("EQD", StringType(), True)

# Chargement des données

In [7]:
# station_2018 = spark.read.load("./data/2018", format="csv", header=True, schema=schema, inferSchema=False)
# station_2018.show()

In [8]:
cols_of_interest = ("STATION","DATE","SOURCE","LATITUDE","LONGITUDE","ELEVATION","NAME","REPORT_TYPE","CALL_SIGN","QUALITY_CONTROL","WND","CIG","VIS","TMP","DEW","SLP")

all_stations = spark.read.load("./data/*", format="csv", header=True, schema=schema, inferSchema=False).select(*cols_of_interest)

In [9]:
all_stations

STATION,DATE,SOURCE,LATITUDE,LONGITUDE,ELEVATION,NAME,REPORT_TYPE,CALL_SIGN,QUALITY_CONTROL,WND,CIG,VIS,TMP,DEW,SLP
826099999,2008-01-01 00:00:00,4,0.0,0.0,0.0,WXPOD8270,FM-15,KFMH,V020,"280,1,N,0062,1","22000,1,9,N","024140,1,N,1",201,-401,101701
826099999,2008-01-01 00:53:00,4,0.0,0.0,0.0,WXPOD8270,FM-15,KCCR,V020,"100,1,N,0021,1","22000,1,9,N","016093,1,N,1",99999,-301,102311
826099999,2008-01-01 00:53:00,4,0.0,0.0,0.0,WXPOD8270,FM-16,KDAL,V020,"360,1,N,0057,1","22000,1,9,N","016093,1,N,1",99999,99999,999999
826099999,2008-01-01 01:00:00,4,0.0,0.0,0.0,WXPOD8270,FM-15,KFMH,V020,"270,1,N,0046,1","22000,1,9,N","024140,1,N,1",101,-401,101811
826099999,2008-01-01 01:53:00,4,0.0,0.0,0.0,WXPOD8270,AUTO,KCGI,V020,"300,1,N,0067,1","22000,1,9,N","016093,1,N,1",301,-401,999999
826099999,2008-01-01 01:53:00,4,0.0,0.0,0.0,WXPOD8270,FM-15,KCCR,V020,"100,1,N,0031,1","22000,1,9,N","016093,1,N,1",99999,-301,102291
826099999,2008-01-01 02:00:00,4,0.0,0.0,0.0,WXPOD8270,FM-15,KFMH,V020,"270,1,N,0051,1","22000,1,9,N","024140,1,N,1",101,-401,101921
826099999,2008-01-01 02:53:00,4,0.0,0.0,0.0,WXPOD8270,AUTO,KCGI,V020,"280,1,N,0041,1","22000,1,9,N","016093,1,N,1",201,-401,101901
826099999,2008-01-01 02:53:00,4,0.0,0.0,0.0,WXPOD8270,FM-15,KCCR,V020,"080,1,N,0041,1","22000,1,9,N","016093,1,N,1",99999,-301,102311
826099999,2008-01-01 03:00:00,4,0.0,0.0,0.0,WXPOD8270,SY-MT,KFMH,V020,"300,1,N,0041,1","22000,1,9,N","016000,1,N,1",171,-441,102031


In [10]:
# Nombre de lignes
all_stations.count()

9573561

## Supprimer les lignes du champ TMP avec des valeurs vides OU des +9999

In [11]:
all_stations = all_stations.na.drop(how="any", subset=["TMP"]).filter(~all_stations.TMP.contains("+9999"))

## Supprimer les lignes du champ ELEVATION avec des valeurs vides OU des +9999

In [12]:
all_stations = all_stations.na.drop(how="any", subset=["ELEVATION"]).filter(~all_stations.ELEVATION.contains("+9999"))

## Supprimer les lignes du champ DEW avec des valeurs vides OU des +9999

In [13]:
all_stations = all_stations.na.drop(how="any", subset=["DEW"]).filter(~all_stations.DEW.contains("+9999"))

In [14]:
# Nombre de lignes après le drop
all_stations.count()

9097319

In [15]:
# Le schema de données
all_stations.printSchema()

root
 |-- STATION: integer (nullable = true)
 |-- DATE: timestamp (nullable = true)
 |-- SOURCE: integer (nullable = true)
 |-- LATITUDE: float (nullable = true)
 |-- LONGITUDE: float (nullable = true)
 |-- ELEVATION: string (nullable = true)
 |-- NAME: string (nullable = true)
 |-- REPORT_TYPE: string (nullable = true)
 |-- CALL_SIGN: string (nullable = true)
 |-- QUALITY_CONTROL: string (nullable = true)
 |-- WND: string (nullable = true)
 |-- CIG: string (nullable = true)
 |-- VIS: string (nullable = true)
 |-- TMP: string (nullable = true)
 |-- DEW: string (nullable = true)
 |-- SLP: string (nullable = true)



## Split de la colonne température

In [16]:
@udf(returnType=FloatType())
def extract_tmp(tmp_col: str):
    return int(tmp_col.split(',')[0].lstrip('+')) / 10

all_stations = all_stations.withColumn('temperature', extract_tmp(all_stations['TMP']))

## Split de la colonne DEW

In [17]:
@udf(returnType=FloatType())
def extract_dew(dew_col: str):
    return int(dew_col.split(',')[0].lstrip('+')) / 10

all_stations = all_stations.withColumn('precipitation', extract_dew(all_stations['DEW']))

In [18]:
all_stations.printSchema()

root
 |-- STATION: integer (nullable = true)
 |-- DATE: timestamp (nullable = true)
 |-- SOURCE: integer (nullable = true)
 |-- LATITUDE: float (nullable = true)
 |-- LONGITUDE: float (nullable = true)
 |-- ELEVATION: string (nullable = true)
 |-- NAME: string (nullable = true)
 |-- REPORT_TYPE: string (nullable = true)
 |-- CALL_SIGN: string (nullable = true)
 |-- QUALITY_CONTROL: string (nullable = true)
 |-- WND: string (nullable = true)
 |-- CIG: string (nullable = true)
 |-- VIS: string (nullable = true)
 |-- TMP: string (nullable = true)
 |-- DEW: string (nullable = true)
 |-- SLP: string (nullable = true)
 |-- temperature: float (nullable = true)
 |-- precipitation: float (nullable = true)



## Création du champ season

In [19]:
@udf(returnType=StringType())
def create_season(month: int):
    if month in [7, 8, 9]:
        season = 'Summer'
    elif month in [10, 11, 12]:
        season = 'Autumn'
    elif month in [1, 2, 3]:
        season = 'Winter'
    else:
        season = 'Spring'
    return season

all_stations = all_stations.withColumn('season', create_season(month("DATE")))

In [20]:
all_stations.printSchema()

root
 |-- STATION: integer (nullable = true)
 |-- DATE: timestamp (nullable = true)
 |-- SOURCE: integer (nullable = true)
 |-- LATITUDE: float (nullable = true)
 |-- LONGITUDE: float (nullable = true)
 |-- ELEVATION: string (nullable = true)
 |-- NAME: string (nullable = true)
 |-- REPORT_TYPE: string (nullable = true)
 |-- CALL_SIGN: string (nullable = true)
 |-- QUALITY_CONTROL: string (nullable = true)
 |-- WND: string (nullable = true)
 |-- CIG: string (nullable = true)
 |-- VIS: string (nullable = true)
 |-- TMP: string (nullable = true)
 |-- DEW: string (nullable = true)
 |-- SLP: string (nullable = true)
 |-- temperature: float (nullable = true)
 |-- precipitation: float (nullable = true)
 |-- season: string (nullable = true)



In [21]:
all_stations

STATION,DATE,SOURCE,LATITUDE,LONGITUDE,ELEVATION,NAME,REPORT_TYPE,CALL_SIGN,QUALITY_CONTROL,WND,CIG,VIS,TMP,DEW,SLP,temperature,precipitation,season
826099999,2008-01-01 00:00:00,4,0.0,0.0,0.0,WXPOD8270,FM-15,KFMH,V020,"280,1,N,0062,1","22000,1,9,N","024140,1,N,1",201,-401,101701,2.0,-4.0,Winter
826099999,2008-01-01 01:00:00,4,0.0,0.0,0.0,WXPOD8270,FM-15,KFMH,V020,"270,1,N,0046,1","22000,1,9,N","024140,1,N,1",101,-401,101811,1.0,-4.0,Winter
826099999,2008-01-01 01:53:00,4,0.0,0.0,0.0,WXPOD8270,AUTO,KCGI,V020,"300,1,N,0067,1","22000,1,9,N","016093,1,N,1",301,-401,999999,3.0,-4.0,Winter
826099999,2008-01-01 02:00:00,4,0.0,0.0,0.0,WXPOD8270,FM-15,KFMH,V020,"270,1,N,0051,1","22000,1,9,N","024140,1,N,1",101,-401,101921,1.0,-4.0,Winter
826099999,2008-01-01 02:53:00,4,0.0,0.0,0.0,WXPOD8270,AUTO,KCGI,V020,"280,1,N,0041,1","22000,1,9,N","016093,1,N,1",201,-401,101901,2.0,-4.0,Winter
826099999,2008-01-01 03:00:00,4,0.0,0.0,0.0,WXPOD8270,SY-MT,KFMH,V020,"300,1,N,0041,1","22000,1,9,N","016000,1,N,1",171,-441,102031,1.7,-4.4,Winter
826099999,2008-01-01 03:53:00,4,0.0,0.0,0.0,WXPOD8270,AUTO,KCGI,V020,"280,1,N,0041,1","22000,1,9,N","016093,1,N,1",101,-301,102001,1.0,-3.0,Winter
826099999,2008-01-01 04:00:00,4,0.0,0.0,0.0,WXPOD8270,FM-15,KFMH,V020,"999,9,V,0010,1","22000,1,9,N","024140,1,N,1",-301,-501,101961,-3.0,-5.0,Winter
826099999,2008-01-01 04:53:00,4,0.0,0.0,0.0,WXPOD8270,AUTO,KCGI,V020,"260,1,N,0036,1","22000,1,9,N","016093,1,N,1",1,-301,102071,0.0,-3.0,Winter
826099999,2008-01-01 05:00:00,4,0.0,0.0,0.0,WXPOD8270,FM-15,KFMH,V020,"210,1,N,0015,1","22000,1,9,N","024140,1,N,1",-401,-501,101991,-4.0,-5.0,Winter


## Moyenne des températures et des précipitations par année/mois/journée/saison

In [22]:
# Par année
mean_tmp_dew_by_year = all_stations.groupBy([year("DATE").alias("year")]).agg(F.round(mean("temperature"), 2).alias("mean_tmp"), F.round(mean("precipitation"), 2).alias("mean_dew"))
mean_tmp_dew_by_year = mean_tmp_dew_by_year.sort("year")
mean_tmp_dew_by_year

year,mean_tmp,mean_dew
2000,6.68,3.4
2002,4.97,1.17
2004,5.13,1.37
2006,4.2,0.06
2008,3.51,-0.58
2010,2.22,-1.53
2012,2.64,-1.06
2014,2.62,-1.14
2016,2.64,-0.99
2018,2.03,-1.54


In [23]:
# Par mois
mean_tmp_dew_by_month = all_stations.groupBy([year("DATE").alias("year"), month("DATE").alias("month")]).agg(F.round(mean("temperature"), 2).alias("mean_tmp"), F.round(mean("precipitation"), 2).alias("mean_dew"))
mean_tmp_dew_by_month = mean_tmp_dew_by_month.sort("year", "month")
mean_tmp_dew_by_month

year,month,mean_tmp,mean_dew
2000,1,0.31,-2.23
2000,2,0.36,-2.22
2000,3,1.6,-2.19
2000,4,4.48,0.74
2000,5,9.46,4.29
2000,6,11.58,7.05
2000,7,13.81,10.12
2000,8,13.54,10.29
2000,9,10.82,7.47
2000,10,8.54,5.97


## Création du champ 'year'

In [24]:
all_stations = all_stations.withColumn("year", year(all_stations["DATE"]))
all_stations

STATION,DATE,SOURCE,LATITUDE,LONGITUDE,ELEVATION,NAME,REPORT_TYPE,CALL_SIGN,QUALITY_CONTROL,WND,CIG,VIS,TMP,DEW,SLP,temperature,precipitation,season,year
826099999,2008-01-01 00:00:00,4,0.0,0.0,0.0,WXPOD8270,FM-15,KFMH,V020,"280,1,N,0062,1","22000,1,9,N","024140,1,N,1",201,-401,101701,2.0,-4.0,Winter,2008
826099999,2008-01-01 01:00:00,4,0.0,0.0,0.0,WXPOD8270,FM-15,KFMH,V020,"270,1,N,0046,1","22000,1,9,N","024140,1,N,1",101,-401,101811,1.0,-4.0,Winter,2008
826099999,2008-01-01 01:53:00,4,0.0,0.0,0.0,WXPOD8270,AUTO,KCGI,V020,"300,1,N,0067,1","22000,1,9,N","016093,1,N,1",301,-401,999999,3.0,-4.0,Winter,2008
826099999,2008-01-01 02:00:00,4,0.0,0.0,0.0,WXPOD8270,FM-15,KFMH,V020,"270,1,N,0051,1","22000,1,9,N","024140,1,N,1",101,-401,101921,1.0,-4.0,Winter,2008
826099999,2008-01-01 02:53:00,4,0.0,0.0,0.0,WXPOD8270,AUTO,KCGI,V020,"280,1,N,0041,1","22000,1,9,N","016093,1,N,1",201,-401,101901,2.0,-4.0,Winter,2008
826099999,2008-01-01 03:00:00,4,0.0,0.0,0.0,WXPOD8270,SY-MT,KFMH,V020,"300,1,N,0041,1","22000,1,9,N","016000,1,N,1",171,-441,102031,1.7,-4.4,Winter,2008
826099999,2008-01-01 03:53:00,4,0.0,0.0,0.0,WXPOD8270,AUTO,KCGI,V020,"280,1,N,0041,1","22000,1,9,N","016093,1,N,1",101,-301,102001,1.0,-3.0,Winter,2008
826099999,2008-01-01 04:00:00,4,0.0,0.0,0.0,WXPOD8270,FM-15,KFMH,V020,"999,9,V,0010,1","22000,1,9,N","024140,1,N,1",-301,-501,101961,-3.0,-5.0,Winter,2008
826099999,2008-01-01 04:53:00,4,0.0,0.0,0.0,WXPOD8270,AUTO,KCGI,V020,"260,1,N,0036,1","22000,1,9,N","016093,1,N,1",1,-301,102071,0.0,-3.0,Winter,2008
826099999,2008-01-01 05:00:00,4,0.0,0.0,0.0,WXPOD8270,FM-15,KFMH,V020,"210,1,N,0015,1","22000,1,9,N","024140,1,N,1",-401,-501,101991,-4.0,-5.0,Winter,2008


## Longitude et latitude par station

In [44]:
annee = int(input("Quelle année? : "))
longitude_latitude_par_station = all_stations.filter(all_stations.year == annee).groupBy(["STATION"]).agg(mean("LATITUDE").alias("lat"), mean("LONGITUDE").alias("long"))
longitude_latitude_par_station.show()

Quelle année? : 2016
+----------+-----------------+------------------+
|   STATION|              lat|              long|
+----------+-----------------+------------------+
|1027099999| 69.6500015258789|18.899999618530273|
| 841899999|              3.0| 131.8000030517578|
|1037099999|69.38333129882812| 20.28333282470703|
|1006099999|            78.25|22.816665649414062|
|1001499999|59.79192352294922| 5.340849876403809|
|1083099999|70.86666870117188| 29.03333282470703|
|1068099999|71.01667022705078|25.983333587646484|
|1046099999| 69.7868423461914|20.959444046020508|
|1002099999|80.05000305175781|             16.25|
|1045099999|69.83333587646484|21.883333206176758|
|1059099999|70.06881713867188| 24.97348976135254|
|1074099999|71.03333282470703| 27.83333396911621|
|1082099999| 69.1500015258789|             29.25|
|1052099999|70.68333435058594|23.683332443237305|
|1023099999|69.05575561523438|18.540355682373047|
|1011099999|80.06666564941406|              31.5|
|1016099999|78.93333435058594

In [45]:
longitude_latitude_par_station.count()

47

In [46]:
filtre_norvege = longitude_latitude_par_station.filter(longitude_latitude_par_station.long >= 5).filter(longitude_latitude_par_station.long <= 12).filter(longitude_latitude_par_station.lat >= 57).filter(longitude_latitude_par_station.lat <= 65)
filtre_norvege

STATION,lat,long
1001499999,59.79192352294922,5.340849876403809
1023199999,64.3499984741211,7.800000190734863


In [47]:
# Par jour, pour l'année choisie et pour une station choisie
num_station = int(input("Quelle station ? : "))
mean_tmp_dew_by_day = all_stations.filter(all_stations.year == annee).filter(all_stations.STATION == num_station).groupBy([to_date("DATE").cast("date").alias("date")]).agg(F.round(mean("temperature"), 2).alias("mean_tmp"), F.round(mean("precipitation"), 2).alias("mean_dew"))
mean_tmp_dew_by_day = mean_tmp_dew_by_day.sort("date")
mean_tmp_dew_by_day.show(365)

Quelle station ? : 1023199999
+----------+--------+--------+
|      date|mean_tmp|mean_dew|
+----------+--------+--------+
|2016-01-01|    7.35|   -0.94|
|2016-01-02|    4.67|   -3.06|
|2016-01-03|    1.25|    -8.5|
|2016-01-04|   -2.52|  -11.04|
|2016-01-05|   -0.21|   -6.91|
|2016-01-06|    0.58|   -3.58|
|2016-01-07|   -0.83|   -8.59|
|2016-01-08|   -5.31|  -11.65|
|2016-01-09|   -2.94|  -12.48|
|2016-01-10|   -2.83|  -11.02|
|2016-01-11|    1.81|   -6.35|
|2016-01-12|   -1.34|   -8.94|
|2016-01-13|   -1.46|   -9.04|
|2016-01-14|   -1.96|  -11.64|
|2016-01-15|   -0.46|  -11.81|
|2016-01-16|    2.15|   -5.15|
|2016-01-17|     2.5|   -3.69|
|2016-01-18|    2.75|   -3.98|
|2016-01-19|    1.04|   -5.91|
|2016-01-20|     2.0|   -5.19|
|2016-01-21|    2.42|   -2.85|
|2016-01-22|    3.75|   -3.23|
|2016-01-23|    6.23|    3.58|
|2016-01-24|    6.76|     4.5|
|2016-01-25|    7.42|    6.36|
|2016-01-26|    5.78|   -4.64|
|2016-01-27|    5.34|   -7.55|
|2016-01-28|    4.37|   -8.76|
|2016-01-

In [48]:
mean_tmp_dew_by_day.count()

363

## Apprentissage

In [50]:
from pyspark.ml.regression import LinearRegression
from pyspark.ml.feature import VectorAssembler

# Load training data
training = mean_tmp_dew_by_day

vectorAssembler = VectorAssembler(inputCols = ["mean_tmp"], outputCol = 'features')
v_training = vectorAssembler.transform(training)
v_training = v_training.select(['features', 'mean_dew'])
v_training.show(3)




+--------+--------+
|features|mean_dew|
+--------+--------+
|  [7.35]|   -0.94|
|  [4.67]|   -3.06|
|  [1.25]|    -8.5|
+--------+--------+
only showing top 3 rows



In [51]:
splits = v_training.randomSplit([0.7, 0.3])
train_df = splits[0]
test_df = splits[1]

In [52]:
lr = LinearRegression(featuresCol = 'features', labelCol='mean_dew', maxIter=10, regParam=0.3, elasticNetParam=0.8)

# Fit the model
lrModel = lr.fit(train_df)

# Print the coefficients and intercept for linear regression
print("Coefficients: %s" % str(lrModel.coefficients))
print("Intercept: %s" % str(lrModel.intercept))

Coefficients: [1.1761226284214128]
Intercept: -5.844433868951441


In [57]:
lr_predictions = lrModel.transform(train_df)
lr_predictions.select("prediction","mean_dew","features").show()

+-------------------+--------+--------+
|         prediction|mean_dew|features|
+-------------------+--------+--------+
|  2.800067449945942|   -0.94|  [7.35]|
| -8.808262892573401|  -11.04| [-2.52]|
| -5.162282744467022|   -3.58|  [0.58]|
|  -9.17286090738404|  -11.02| [-2.83]|
| -7.561572906446703|   -9.04| [-1.46]|
| -7.420438191036134|   -8.94| [-1.34]|
| -6.385450278025291|  -11.81| [-0.46]|
|-3.3157702178454036|   -5.15|  [2.15]|
| -4.621266335393171|   -5.91|  [1.04]|
| -2.610096640792556|   -3.98|  [2.75]|
|-3.4921886121086154|   -5.19|   [2.0]|
|-1.4339740123711433|   -3.23|  [3.75]|
|  1.482810106113961|    3.58|  [6.23]|
|  0.953554923324325|   -4.64|  [5.78]|
|  2.882396033935441|    6.36|  [7.42]|
|-0.7047779827498672|   -8.76|  [4.37]|
|-0.5518820410550838|  -24.77|   [4.5]|
|-1.8691393848870659|  -20.19|  [3.38]|
|-1.6574373117712113|   -0.93|  [3.56]|
|-1.6339148592027835|   -0.83|  [3.58]|
+-------------------+--------+--------+
only showing top 20 rows



In [53]:
# Summarize the model over the training set and print out some metrics
trainingSummary = lrModel.summary
print("RMSE: %f" % trainingSummary.rootMeanSquaredError)
print("r2: %f" % trainingSummary.r2)

RMSE: 2.564847
r2: 0.788795


In [54]:
train_df.describe().show()

+-------+-----------------+
|summary|         mean_dew|
+-------+-----------------+
|  count|              260|
|   mean|3.616884615384616|
| stddev|5.591731822457132|
|    min|           -24.77|
|    max|            13.81|
+-------+-----------------+



In [55]:
lr_predictions = lrModel.transform(test_df)
lr_predictions.select("prediction","mean_dew","features").show()

from pyspark.ml.evaluation import RegressionEvaluator
lr_evaluator = RegressionEvaluator(predictionCol="prediction", \
                 labelCol="mean_dew",metricName="r2")
print("R Squared (R2) on test data = %g" % lr_evaluator.evaluate(lr_predictions))

+--------------------+--------+--------+
|          prediction|mean_dew|features|
+--------------------+--------+--------+
| -0.3519411942234436|   -3.06|  [4.67]|
| -4.3742805834246745|    -8.5|  [1.25]|
|  -6.091419620919938|   -6.91| [-0.21]|
| -12.089645025869142|  -11.65| [-5.31]|
|  -6.820615650541214|   -8.59| [-0.83]|
|  -9.302234396510395|  -12.48| [-2.94]|
|  -3.715651911508684|   -6.35|  [1.81]|
|   -8.14963422065741|  -11.64| [-1.96]|
|  -2.904127297897909|   -3.69|   [2.5]|
| -2.9982171081716222|   -2.85|  [2.42]|
|  2.1061550991773093|     4.5|  [6.76]|
|  0.4360609668189035|   -7.55|  [5.34]|
|  1.0241222810296096|   -7.24|  [5.84]|
| -2.0455577791502777|  -11.35|  [3.23]|
|-0.41074732564451377|   -8.47|  [4.62]|
|   -1.33988420209743|   -2.58|  [3.83]|
| -3.4451437069717588|   -5.02|  [2.04]|
| -3.4921886121086154|   -3.76|   [2.0]|
|  -2.139647589423991|   -2.24|  [3.15]|
| -1.6103924066343547|     0.6|   [3.6]|
+--------------------+--------+--------+
only showing top

In [56]:
test_result = lrModel.evaluate(test_df)
print("Root Mean Squared Error (RMSE) on test data = %g" % test_result.rootMeanSquaredError)

Root Mean Squared Error (RMSE) on test data = 2.44512


In [74]:
# Viz
# lr_predictions.select("prediction","mean_dew","features").toPandas()
# viz.plot.scatter("mean_dew","features")# 

[DenseVector([-0.12]),
 DenseVector([1.69]),
 DenseVector([1.85]),
 DenseVector([2.02]),
 DenseVector([3.84]),
 DenseVector([-4.71]),
 DenseVector([-4.15]),
 DenseVector([-1.31]),
 DenseVector([-1.12]),
 DenseVector([-0.1]),
 DenseVector([3.18]),
 DenseVector([-2.35]),
 DenseVector([-0.15]),
 DenseVector([1.86]),
 DenseVector([2.14]),
 DenseVector([2.14]),
 DenseVector([3.75]),
 DenseVector([-1.32]),
 DenseVector([0.15]),
 DenseVector([1.42]),
 DenseVector([2.1]),
 DenseVector([1.27]),
 DenseVector([1.33]),
 DenseVector([1.49]),
 DenseVector([2.23]),
 DenseVector([-0.43]),
 DenseVector([2.97]),
 DenseVector([3.18]),
 DenseVector([4.09]),
 DenseVector([5.37]),
 DenseVector([6.46]),
 DenseVector([6.64]),
 DenseVector([7.21]),
 DenseVector([7.41]),
 DenseVector([7.93]),
 DenseVector([8.12]),
 DenseVector([8.37]),
 DenseVector([7.18]),
 DenseVector([9.23]),
 DenseVector([9.89]),
 DenseVector([10.46]),
 DenseVector([10.64]),
 DenseVector([11.76]),
 DenseVector([9.1]),
 DenseVector([9.18]),


## Choses à faire


0) Faire valider le 1er modele par Sayf
1) Arrondir toutes les valeurs
2) Refaire le split (train/test)
3) D'autres algos : bataille d'I.A.
4) Voir d'autres prédictions prtinentes à faire