In [1]:
import pandas as pd
from pyspark.sql import SparkSession

### Preprocess

In [2]:
# 加载数据
df = pd.read_csv('./car_prices.csv')

# 检查所有列
print(df.info())

# 删除包含空值的行
df = df.dropna()

# 写回原文件
df.to_csv('./car_prices_without_null.csv', index=False)

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 558837 entries, 0 to 558836
Data columns (total 16 columns):
 #   Column        Non-Null Count   Dtype  
---  ------        --------------   -----  
 0   year          558837 non-null  int64  
 1   make          548536 non-null  object 
 2   model         548438 non-null  object 
 3   trim          548186 non-null  object 
 4   body          545642 non-null  object 
 5   transmission  493485 non-null  object 
 6   vin           558833 non-null  object 
 7   state         558837 non-null  object 
 8   condition     547017 non-null  float64
 9   odometer      558743 non-null  float64
 10  color         558088 non-null  object 
 11  interior      558088 non-null  object 
 12  seller        558837 non-null  object 
 13  mmr           558799 non-null  float64
 14  sellingprice  558825 non-null  float64
 15  saledate      558825 non-null  object 
dtypes: float64(4), int64(1), object(11)
memory usage: 68.2+ MB
None


In [3]:
spark = SparkSession.builder.appName('5003_project').getOrCreate()

your 131072x1 screen size is bogus. expect trouble


24/04/23 16:33:15 WARN Utils: Your hostname, SIMON_WANG resolves to a loopback address: 127.0.1.1; using 172.30.102.1 instead (on interface eth0)
24/04/23 16:33:15 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


24/04/23 16:33:16 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [4]:
data = spark.read.csv('./car_prices_without_null.csv', header=True, inferSchema=True)
data.show(3)

                                                                                

+----+----+--------+----------+-----+------------+-----------------+-----+---------+--------+-----+--------+--------------------+-------+------------+--------------------+
|year|make|   model|      trim| body|transmission|              vin|state|condition|odometer|color|interior|              seller|    mmr|sellingprice|            saledate|
+----+----+--------+----------+-----+------------+-----------------+-----+---------+--------+-----+--------+--------------------+-------+------------+--------------------+
|2015| Kia| Sorento|        LX|  SUV|   automatic|5xyktca69fg566472|   ca|      5.0| 16639.0|white|   black|kia motors americ...|20500.0|     21500.0|Tue Dec 16 2014 1...|
|2015| Kia| Sorento|        LX|  SUV|   automatic|5xyktca69fg561319|   ca|      5.0|  9393.0|white|   beige|kia motors americ...|20800.0|     21500.0|Tue Dec 16 2014 1...|
|2014| BMW|3 Series|328i SULEV|Sedan|   automatic|wba3c1c51ek116351|   ca|     45.0|  1331.0| gray|   black|financial service...|31900.0|   

In [5]:
data.printSchema()

root
 |-- year: integer (nullable = true)
 |-- make: string (nullable = true)
 |-- model: string (nullable = true)
 |-- trim: string (nullable = true)
 |-- body: string (nullable = true)
 |-- transmission: string (nullable = true)
 |-- vin: string (nullable = true)
 |-- state: string (nullable = true)
 |-- condition: double (nullable = true)
 |-- odometer: double (nullable = true)
 |-- color: string (nullable = true)
 |-- interior: string (nullable = true)
 |-- seller: string (nullable = true)
 |-- mmr: double (nullable = true)
 |-- sellingprice: double (nullable = true)
 |-- saledate: string (nullable = true)



### Data Cleaning

In [6]:
from pyspark.sql.functions import col
from pyspark.sql.types import IntegerType
from pyspark.ml.feature import StringIndexer

# 创建StringIndexer对象
stringIndexer = StringIndexer(inputCol="make", outputCol="make_index")

# 使用StringIndexer对象进行转换
indexedData = stringIndexer.fit(data).transform(data)

# 将"make_index"列转换为整数类型
indexedData = indexedData.withColumn("make_index", col("make_index").cast(IntegerType()))

# 替换"make"列为"make_index"列
indexedData = indexedData.withColumn("make", col("make_index")).drop("make_index")

# 查看转换后的数据
data = indexedData
data.show()

                                                                                

+----+----+-------------------+--------------------+-----------+------------+-----------------+-----+---------+--------+------+--------+--------------------+-------+------------+--------------------+
|year|make|              model|                trim|       body|transmission|              vin|state|condition|odometer| color|interior|              seller|    mmr|sellingprice|            saledate|
+----+----+-------------------+--------------------+-----------+------------+-----------------+-----+---------+--------+------+--------+--------------------+-------+------------+--------------------+
|2015|   8|            Sorento|                  LX|        SUV|   automatic|5xyktca69fg566472|   ca|      5.0| 16639.0| white|   black|kia motors americ...|20500.0|     21500.0|Tue Dec 16 2014 1...|
|2015|   8|            Sorento|                  LX|        SUV|   automatic|5xyktca69fg561319|   ca|      5.0|  9393.0| white|   beige|kia motors americ...|20800.0|     21500.0|Tue Dec 16 2014 1...|


In [7]:
# 处理state列
stringIndexer = StringIndexer(inputCol="state", outputCol="state_index")
indexedData = stringIndexer.fit(data).transform(data)
indexedData = indexedData.withColumn("state_index", col("state_index").cast(IntegerType()))
indexedData = indexedData.withColumn("state", col("state_index")).drop("state_index")

data = indexedData
data.show()

+----+----+-------------------+--------------------+-----------+------------+-----------------+-----+---------+--------+------+--------+--------------------+-------+------------+--------------------+
|year|make|              model|                trim|       body|transmission|              vin|state|condition|odometer| color|interior|              seller|    mmr|sellingprice|            saledate|
+----+----+-------------------+--------------------+-----------+------------+-----------------+-----+---------+--------+------+--------+--------------------+-------+------------+--------------------+
|2015|   8|            Sorento|                  LX|        SUV|   automatic|5xyktca69fg566472|    1|      5.0| 16639.0| white|   black|kia motors americ...|20500.0|     21500.0|Tue Dec 16 2014 1...|
|2015|   8|            Sorento|                  LX|        SUV|   automatic|5xyktca69fg561319|    1|      5.0|  9393.0| white|   beige|kia motors americ...|20800.0|     21500.0|Tue Dec 16 2014 1...|


In [8]:
# 处理model列
stringIndexer = StringIndexer(inputCol="model", outputCol="model_index")
indexedData = stringIndexer.fit(data).transform(data)
indexedData = indexedData.withColumn("model_index", col("model_index").cast(IntegerType()))
indexedData = indexedData.withColumn("model", col("model_index")).drop("model_index")

data = indexedData
data.show()

+----+----+-----+--------------------+-----------+------------+-----------------+-----+---------+--------+------+--------+--------------------+-------+------------+--------------------+
|year|make|model|                trim|       body|transmission|              vin|state|condition|odometer| color|interior|              seller|    mmr|sellingprice|            saledate|
+----+----+-----+--------------------+-----------+------------+-----------------+-----+---------+--------+------+--------+--------------------+-------+------------+--------------------+
|2015|   8|   40|                  LX|        SUV|   automatic|5xyktca69fg566472|    1|      5.0| 16639.0| white|   black|kia motors americ...|20500.0|     21500.0|Tue Dec 16 2014 1...|
|2015|   8|   40|                  LX|        SUV|   automatic|5xyktca69fg561319|    1|      5.0|  9393.0| white|   beige|kia motors americ...|20800.0|     21500.0|Tue Dec 16 2014 1...|
|2014|   7|    8|          328i SULEV|      Sedan|   automatic|wba3c1c

In [9]:
# 处理transmission
stringIndexer = StringIndexer(inputCol="transmission", outputCol="transmission_index")
indexedData = stringIndexer.fit(data).transform(data)
indexedData = indexedData.withColumn("transmission_index", col("transmission_index").cast(IntegerType()))
indexedData = indexedData.withColumn("transmission", col("transmission_index")).drop("transmission_index")

data = indexedData
data.show()

+----+----+-----+--------------------+-----------+------------+-----------------+-----+---------+--------+------+--------+--------------------+-------+------------+--------------------+
|year|make|model|                trim|       body|transmission|              vin|state|condition|odometer| color|interior|              seller|    mmr|sellingprice|            saledate|
+----+----+-----+--------------------+-----------+------------+-----------------+-----+---------+--------+------+--------+--------------------+-------+------------+--------------------+
|2015|   8|   40|                  LX|        SUV|           0|5xyktca69fg566472|    1|      5.0| 16639.0| white|   black|kia motors americ...|20500.0|     21500.0|Tue Dec 16 2014 1...|
|2015|   8|   40|                  LX|        SUV|           0|5xyktca69fg561319|    1|      5.0|  9393.0| white|   beige|kia motors americ...|20800.0|     21500.0|Tue Dec 16 2014 1...|
|2014|   7|    8|          328i SULEV|      Sedan|           0|wba3c1c

In [10]:
# 处理trim
stringIndexer = StringIndexer(inputCol="trim", outputCol="trim_index")
indexedData = stringIndexer.fit(data).transform(data)
indexedData = indexedData.withColumn("trim_index", col("trim_index").cast(IntegerType()))
indexedData = indexedData.withColumn("trim", col("trim_index")).drop("trim_index")

data = indexedData
data.show()

+----+----+-----+----+-----------+------------+-----------------+-----+---------+--------+------+--------+--------------------+-------+------------+--------------------+
|year|make|model|trim|       body|transmission|              vin|state|condition|odometer| color|interior|              seller|    mmr|sellingprice|            saledate|
+----+----+-----+----+-----------+------------+-----------------+-----+---------+--------+------+--------+--------------------+-------+------------+--------------------+
|2015|   8|   40|   2|        SUV|           0|5xyktca69fg566472|    1|      5.0| 16639.0| white|   black|kia motors americ...|20500.0|     21500.0|Tue Dec 16 2014 1...|
|2015|   8|   40|   2|        SUV|           0|5xyktca69fg561319|    1|      5.0|  9393.0| white|   beige|kia motors americ...|20800.0|     21500.0|Tue Dec 16 2014 1...|
|2014|   7|    8| 636|      Sedan|           0|wba3c1c51ek116351|    1|     45.0|  1331.0|  gray|   black|financial service...|31900.0|     30000.0|Th

In [11]:
# 处理body
stringIndexer = StringIndexer(inputCol="body", outputCol="body_index")
indexedData = stringIndexer.fit(data).transform(data)
indexedData = indexedData.withColumn("body_index", col("body_index").cast(IntegerType()))
indexedData = indexedData.withColumn("body", col("body_index")).drop("body_index")

data = indexedData
data.show()

+----+----+-----+----+----+------------+-----------------+-----+---------+--------+------+--------+--------------------+-------+------------+--------------------+
|year|make|model|trim|body|transmission|              vin|state|condition|odometer| color|interior|              seller|    mmr|sellingprice|            saledate|
+----+----+-----+----+----+------------+-----------------+-----+---------+--------+------+--------+--------------------+-------+------------+--------------------+
|2015|   8|   40|   2|   1|           0|5xyktca69fg566472|    1|      5.0| 16639.0| white|   black|kia motors americ...|20500.0|     21500.0|Tue Dec 16 2014 1...|
|2015|   8|   40|   2|   1|           0|5xyktca69fg561319|    1|      5.0|  9393.0| white|   beige|kia motors americ...|20800.0|     21500.0|Tue Dec 16 2014 1...|
|2014|   7|    8| 636|   0|           0|wba3c1c51ek116351|    1|     45.0|  1331.0|  gray|   black|financial service...|31900.0|     30000.0|Thu Jan 15 2015 0...|
|2015|  26|  127|  91|

In [12]:
# 处理color
stringIndexer = StringIndexer(inputCol="color", outputCol="color_index")
indexedData = stringIndexer.fit(data).transform(data)
indexedData = indexedData.withColumn("color_index", col("color_index").cast(IntegerType()))
indexedData = indexedData.withColumn("color", col("color_index")).drop("color_index")

data = indexedData
data.show()

+----+----+-----+----+----+------------+-----------------+-----+---------+--------+-----+--------+--------------------+-------+------------+--------------------+
|year|make|model|trim|body|transmission|              vin|state|condition|odometer|color|interior|              seller|    mmr|sellingprice|            saledate|
+----+----+-----+----+----+------------+-----------------+-----+---------+--------+-----+--------+--------------------+-------+------------+--------------------+
|2015|   8|   40|   2|   1|           0|5xyktca69fg566472|    1|      5.0| 16639.0|    1|   black|kia motors americ...|20500.0|     21500.0|Tue Dec 16 2014 1...|
|2015|   8|   40|   2|   1|           0|5xyktca69fg561319|    1|      5.0|  9393.0|    1|   beige|kia motors americ...|20800.0|     21500.0|Tue Dec 16 2014 1...|
|2014|   7|    8| 636|   0|           0|wba3c1c51ek116351|    1|     45.0|  1331.0|    3|   black|financial service...|31900.0|     30000.0|Thu Jan 15 2015 0...|
|2015|  26|  127|  91|   0| 

In [13]:
# 处理interior
stringIndexer = StringIndexer(inputCol="interior", outputCol="interior_index")
indexedData = stringIndexer.fit(data).transform(data)
indexedData = indexedData.withColumn("interior_index", col("interior_index").cast(IntegerType()))
indexedData = indexedData.withColumn("interior", col("interior_index")).drop("interior_index")

data = indexedData
data.show()

+----+----+-----+----+----+------------+-----------------+-----+---------+--------+-----+--------+--------------------+-------+------------+--------------------+
|year|make|model|trim|body|transmission|              vin|state|condition|odometer|color|interior|              seller|    mmr|sellingprice|            saledate|
+----+----+-----+----+----+------------+-----------------+-----+---------+--------+-----+--------+--------------------+-------+------------+--------------------+
|2015|   8|   40|   2|   1|           0|5xyktca69fg566472|    1|      5.0| 16639.0|    1|       0|kia motors americ...|20500.0|     21500.0|Tue Dec 16 2014 1...|
|2015|   8|   40|   2|   1|           0|5xyktca69fg561319|    1|      5.0|  9393.0|    1|       2|kia motors americ...|20800.0|     21500.0|Tue Dec 16 2014 1...|
|2014|   7|    8| 636|   0|           0|wba3c1c51ek116351|    1|     45.0|  1331.0|    3|       0|financial service...|31900.0|     30000.0|Thu Jan 15 2015 0...|
|2015|  26|  127|  91|   0| 

In [14]:
# 处理seller
stringIndexer = StringIndexer(inputCol="seller", outputCol="seller_index")
indexedData = stringIndexer.fit(data).transform(data)
indexedData = indexedData.withColumn("seller_index", col("seller_index").cast(IntegerType()))
indexedData = indexedData.withColumn("seller", col("seller_index")).drop("seller_index")

data = indexedData
data.show()

+----+----+-----+----+----+------------+-----------------+-----+---------+--------+-----+--------+------+-------+------------+--------------------+
|year|make|model|trim|body|transmission|              vin|state|condition|odometer|color|interior|seller|    mmr|sellingprice|            saledate|
+----+----+-----+----+----+------------+-----------------+-----+---------+--------+-----+--------+------+-------+------------+--------------------+
|2015|   8|   40|   2|   1|           0|5xyktca69fg566472|    1|      5.0| 16639.0|    1|       0|    23|20500.0|     21500.0|Tue Dec 16 2014 1...|
|2015|   8|   40|   2|   1|           0|5xyktca69fg561319|    1|      5.0|  9393.0|    1|       2|    23|20800.0|     21500.0|Tue Dec 16 2014 1...|
|2014|   7|    8| 636|   0|           0|wba3c1c51ek116351|    1|     45.0|  1331.0|    3|       0|    14|31900.0|     30000.0|Thu Jan 15 2015 0...|
|2015|  26|  127|  91|   0|           0|yv1612tb4f1310987|    1|     41.0| 14282.0|    1|       0|   123|27500.0

In [15]:
# 显示"saledate"列的前10条数据
saledate_data = data.select("saledate").head(10)

# 打印前10条数据
for row in saledate_data:
    print(row.saledate)

Tue Dec 16 2014 12:30:00 GMT-0800 (PST)
Tue Dec 16 2014 12:30:00 GMT-0800 (PST)
Thu Jan 15 2015 04:30:00 GMT-0800 (PST)
Thu Jan 29 2015 04:30:00 GMT-0800 (PST)
Thu Dec 18 2014 12:30:00 GMT-0800 (PST)
Tue Dec 30 2014 12:00:00 GMT-0800 (PST)
Wed Dec 17 2014 12:30:00 GMT-0800 (PST)
Tue Dec 16 2014 13:00:00 GMT-0800 (PST)
Thu Dec 18 2014 12:00:00 GMT-0800 (PST)
Tue Jan 20 2015 04:00:00 GMT-0800 (PST)


In [16]:
from pyspark.sql.functions import substring, unix_timestamp

# 将"timestamp_string"列转换为Unix时间戳
data = data.withColumn("timestamp_string", substring("saledate", 5, 20))
data = data.withColumn("timestamp", unix_timestamp("timestamp_string", "MMM dd yyyy HH:mm:ss"))

data = data.drop('saledate')
data = data.drop('timestamp_string')
data = data.drop('vin')

# 查看转换后的数据
data.show()

+----+----+-----+----+----+------------+-----+---------+--------+-----+--------+------+-------+------------+----------+
|year|make|model|trim|body|transmission|state|condition|odometer|color|interior|seller|    mmr|sellingprice| timestamp|
+----+----+-----+----+----+------------+-----+---------+--------+-----+--------+------+-------+------------+----------+
|2015|   8|   40|   2|   1|           0|    1|      5.0| 16639.0|    1|       0|    23|20500.0|     21500.0|1418704200|
|2015|   8|   40|   2|   1|           0|    1|      5.0|  9393.0|    1|       2|    23|20800.0|     21500.0|1418704200|
|2014|   7|    8| 636|   0|           0|    1|     45.0|  1331.0|    3|       0|    14|31900.0|     30000.0|1421267400|
|2015|  26|  127|  91|   0|           0|    1|     41.0| 14282.0|    1|       0|   123|27500.0|     27750.0|1422477000|
|2014|   7|  398| 139|   0|           0|    1|     43.0|  2641.0|    3|       0|    14|66000.0|     67000.0|1418877000|
|2015|   2|    0|  12|   0|           0|

In [17]:
# 将"timestamp"列丢弃末两位
data = data.withColumn("timestamp_str", data["timestamp"].cast("string"))
data = data.withColumn("timestamp_str_truncated", substring("timestamp_str", 0, 8))
data = data.withColumn("timestamp_int", data["timestamp_str_truncated"].cast("integer"))

data = data.drop('timestamp_str')
data = data.drop('timestamp_str_truncated')
data = data.drop('timestamp')
data = data.withColumnRenamed("timestamp_int", "timestamp")


data.show()

+----+----+-----+----+----+------------+-----+---------+--------+-----+--------+------+-------+------------+---------+
|year|make|model|trim|body|transmission|state|condition|odometer|color|interior|seller|    mmr|sellingprice|timestamp|
+----+----+-----+----+----+------------+-----+---------+--------+-----+--------+------+-------+------------+---------+
|2015|   8|   40|   2|   1|           0|    1|      5.0| 16639.0|    1|       0|    23|20500.0|     21500.0| 14187042|
|2015|   8|   40|   2|   1|           0|    1|      5.0|  9393.0|    1|       2|    23|20800.0|     21500.0| 14187042|
|2014|   7|    8| 636|   0|           0|    1|     45.0|  1331.0|    3|       0|    14|31900.0|     30000.0| 14212674|
|2015|  26|  127|  91|   0|           0|    1|     41.0| 14282.0|    1|       0|   123|27500.0|     27750.0| 14224770|
|2014|   7|  398| 139|   0|           0|    1|     43.0|  2641.0|    3|       0|    14|66000.0|     67000.0| 14188770|
|2015|   2|    0|  12|   0|           0|    1|  

In [18]:
data.printSchema()
print(data.dtypes)

root
 |-- year: integer (nullable = true)
 |-- make: integer (nullable = true)
 |-- model: integer (nullable = true)
 |-- trim: integer (nullable = true)
 |-- body: integer (nullable = true)
 |-- transmission: integer (nullable = true)
 |-- state: integer (nullable = true)
 |-- condition: double (nullable = true)
 |-- odometer: double (nullable = true)
 |-- color: integer (nullable = true)
 |-- interior: integer (nullable = true)
 |-- seller: integer (nullable = true)
 |-- mmr: double (nullable = true)
 |-- sellingprice: double (nullable = true)
 |-- timestamp: integer (nullable = true)

[('year', 'int'), ('make', 'int'), ('model', 'int'), ('trim', 'int'), ('body', 'int'), ('transmission', 'int'), ('state', 'int'), ('condition', 'double'), ('odometer', 'double'), ('color', 'int'), ('interior', 'int'), ('seller', 'int'), ('mmr', 'double'), ('sellingprice', 'double'), ('timestamp', 'int')]


In [19]:
# 1. 检查数据中是否存在 null 值
null_cols = [col for col in data.columns if data.select(col).where(data[col].isNull()).count() > 0]
if null_cols:
    print(f"Null values found in columns: {', '.join(null_cols)}")

In [20]:
from pyspark.ml.feature import VectorAssembler

assembler = VectorAssembler(inputCols = ['year', 'make', 'model', 'trim','body', 'transmission', 'state', 'condition', 'odometer', 'color', 'interior', 'seller', 'mmr', 'timestamp'], outputCol = 'features')
output = assembler.transform(data)
output.take(1)

24/04/23 16:33:30 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


[Row(year=2015, make=8, model=40, trim=2, body=1, transmission=0, state=1, condition=5.0, odometer=16639.0, color=1, interior=0, seller=23, mmr=20500.0, sellingprice=21500.0, timestamp=14187042, features=DenseVector([2015.0, 8.0, 40.0, 2.0, 1.0, 0.0, 1.0, 5.0, 16639.0, 1.0, 0.0, 23.0, 20500.0, 14187042.0]))]

In [21]:
output.printSchema()

root
 |-- year: integer (nullable = true)
 |-- make: integer (nullable = true)
 |-- model: integer (nullable = true)
 |-- trim: integer (nullable = true)
 |-- body: integer (nullable = true)
 |-- transmission: integer (nullable = true)
 |-- state: integer (nullable = true)
 |-- condition: double (nullable = true)
 |-- odometer: double (nullable = true)
 |-- color: integer (nullable = true)
 |-- interior: integer (nullable = true)
 |-- seller: integer (nullable = true)
 |-- mmr: double (nullable = true)
 |-- sellingprice: double (nullable = true)
 |-- timestamp: integer (nullable = true)
 |-- features: vector (nullable = true)



In [22]:
final_data = output.select(['features', 'sellingprice'])
final_data = final_data.dropna()
final_data.show(3, truncate=False)

+-------------------------------------------------------------------------------+------------+
|features                                                                       |sellingprice|
+-------------------------------------------------------------------------------+------------+
|[2015.0,8.0,40.0,2.0,1.0,0.0,1.0,5.0,16639.0,1.0,0.0,23.0,20500.0,1.4187042E7] |21500.0     |
|[2015.0,8.0,40.0,2.0,1.0,0.0,1.0,5.0,9393.0,1.0,2.0,23.0,20800.0,1.4187042E7]  |21500.0     |
|[2014.0,7.0,8.0,636.0,0.0,0.0,1.0,45.0,1331.0,3.0,0.0,14.0,31900.0,1.4212674E7]|30000.0     |
+-------------------------------------------------------------------------------+------------+
only showing top 3 rows



In [23]:
final_data.printSchema()

root
 |-- features: vector (nullable = true)
 |-- sellingprice: double (nullable = true)



In [None]:
final_data.describe().show()

### Train-Test Split

In [24]:
train, test = final_data.randomSplit([0.8,0.2])

### Utils

In [27]:
from pyspark.ml.evaluation import RegressionEvaluator
evaluator = RegressionEvaluator(labelCol="sellingprice", predictionCol="prediction")

def evaluate(pred):
    rmse = evaluator.evaluate(pred, {evaluator.metricName: "rmse"})
    r2 = evaluator.evaluate(pred, {evaluator.metricName: "r2"})

    return rmse, r2


### Linear Regression

In [26]:
from pyspark.ml.regression import LinearRegression

lr = LinearRegression(featuresCol='features', labelCol='sellingprice')
lr_model = lr.fit(train)

24/04/23 16:56:32 WARN Instrumentation: [2b8006f0] regParam is zero, which might cause numerical instability and overfitting.


[Stage 89:>                                                       (0 + 19) / 19]

24/04/23 16:56:34 WARN InstanceBuilder$NativeBLAS: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
24/04/23 16:56:34 WARN InstanceBuilder$NativeBLAS: Failed to load implementation from:dev.ludovic.netlib.blas.ForeignLinkerBLAS


                                                                                

24/04/23 16:56:34 WARN InstanceBuilder$NativeLAPACK: Failed to load implementation from:dev.ludovic.netlib.lapack.JNILAPACK


                                                                                

In [30]:
lr_predictions = lr_model.transform(test)
lr_predictions.show(10)

+--------------------+------------+------------------+
|            features|sellingprice|        prediction|
+--------------------+------------+------------------+
|(14,[0,1,2,7,8,11...|     13800.0| 13127.38850515896|
|(14,[0,1,2,7,8,11...|     17800.0|16324.811220911142|
|(14,[0,2,3,6,7,8,...|     13400.0|13387.499096099142|
|(14,[0,2,3,6,7,8,...|     15900.0|15168.139189288602|
|(14,[0,2,3,6,7,8,...|     17600.0|17693.486376644563|
|(14,[0,2,3,6,7,8,...|     19300.0|18718.029868359416|
|(14,[0,2,3,6,7,8,...|      9100.0| 9651.786257843516|
|(14,[0,2,3,7,8,9,...|     11200.0|12847.570522258291|
|(14,[0,2,3,7,8,9,...|     19200.0|17031.707015257605|
|(14,[0,2,3,7,8,11...|     11500.0|12226.852300727638|
+--------------------+------------+------------------+
only showing top 10 rows



In [32]:
lr_rmse, lr_r2 = evaluate(lr_predictions)
lr_rmse, lr_r2

                                                                                

(1652.3763049118988, 0.9704001648409633)

### Decision Tree

In [33]:
from pyspark.ml.regression import DecisionTreeRegressor

dt = DecisionTreeRegressor(featuresCol='features', labelCol='sellingprice')
dt_model = dt.fit(train)

                                                                                

In [34]:
dt_predictions = dt_model.transform(test)
dt_predictions.show(10)

+--------------------+------------+------------------+
|            features|sellingprice|        prediction|
+--------------------+------------+------------------+
|(14,[0,1,2,7,8,11...|     13800.0|12137.778249882595|
|(14,[0,1,2,7,8,11...|     17800.0|16497.701501771753|
|(14,[0,2,3,6,7,8,...|     13400.0|12137.778249882595|
|(14,[0,2,3,6,7,8,...|     15900.0| 14114.52932238553|
|(14,[0,2,3,6,7,8,...|     17600.0|16497.701501771753|
|(14,[0,2,3,6,7,8,...|     19300.0|18886.148238831614|
|(14,[0,2,3,6,7,8,...|      9100.0| 9322.025377172844|
|(14,[0,2,3,7,8,9,...|     11200.0|12137.778249882595|
|(14,[0,2,3,7,8,9,...|     19200.0|16497.701501771753|
|(14,[0,2,3,7,8,11...|     11500.0|12137.778249882595|
+--------------------+------------+------------------+
only showing top 10 rows



In [36]:
dt_rmse, dt_r2 = evaluate(dt_predictions)
dt_rmse, dt_r2

                                                                                

(2487.757450980491, 0.9329054187593008)

### Random Forest

In [40]:
from pyspark.ml.regression import RandomForestRegressor

rf = RandomForestRegressor(featuresCol='features', labelCol='sellingprice')
rf_model = rf.fit(train)

                                                                                

In [41]:
rf_predictions = rf_model.transform(test)
rf_predictions.show(10)

+--------------------+------------+------------------+
|            features|sellingprice|        prediction|
+--------------------+------------+------------------+
|(14,[0,1,2,7,8,11...|     13800.0|11502.019842857511|
|(14,[0,1,2,7,8,11...|     17800.0|14136.184882140451|
|(14,[0,2,3,6,7,8,...|     13400.0|13208.265514886763|
|(14,[0,2,3,6,7,8,...|     15900.0|14298.971691516828|
|(14,[0,2,3,6,7,8,...|     17600.0| 14995.26873292516|
|(14,[0,2,3,6,7,8,...|     19300.0|18668.290276713567|
|(14,[0,2,3,6,7,8,...|      9100.0| 11615.31538787431|
|(14,[0,2,3,7,8,9,...|     11200.0|11750.879354513714|
|(14,[0,2,3,7,8,9,...|     19200.0| 15347.23815934978|
|(14,[0,2,3,7,8,11...|     11500.0| 11054.13890218836|
+--------------------+------------+------------------+
only showing top 10 rows



In [43]:
rf_rmse, rf_r2 = evaluate(rf_predictions)
rf_rmse, rf_r2

                                                                                

(3146.2556264226287, 0.8926852580087221)