In [1]:
import os
from typing import Iterable
from urllib.parse import quote

from pyspark.sql import DataFrame
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import array, col, explode, lit, struct
from pyspark.sql.types import FloatType, TimestampType, StringType, StructType, StructField, IntegerType
from pyspark.sql.window import Window

In [2]:
# Obtain the session before configuring it: unless the kernel pre-injects a
# `spark` object, the name does not exist yet at this point on a fresh run.
# getOrCreate() is idempotent, so a later call simply returns this session.
spark = SparkSession.builder.appName('demo-app').getOrCreate()
# Interpret and render timestamps in the Asia/Shanghai zone for this session.
spark.conf.set("spark.sql.session.timeZone", 'Asia/Shanghai')

In [3]:
# Build the application's session, or reuse the one that is already active.
spark = (
    SparkSession.builder
    .appName('demo-app')
    .getOrCreate()
)

In [4]:
# Explicit schema for the sales CSV so Spark does not have to infer types.
# `date` is declared as a plain string here; it is parsed separately.
# nullable=True means the column MAY contain nulls (it does not forbid them).
schema = StructType([
    StructField("date", StringType(), nullable=True),
    StructField("date_block_num", IntegerType(), nullable=True),
    StructField("shop_id", IntegerType(), nullable=True),
    StructField("item_id", IntegerType(), nullable=True),
    StructField("item_price", FloatType(), nullable=True),
    StructField("item_cnt_day", FloatType(), nullable=True),
])

In [5]:
%%time
# Read the sales file with its header row, applying the explicit schema.
df = (
    spark.read
    .option('header', 'True')
    .csv('data/sales_train.csv', schema=schema)
)
# Trailing expression: the row count, shown as the cell's result.
df.count()

CPU times: user 3.54 ms, sys: 2.21 ms, total: 5.75 ms
Wall time: 3.93 s


2935849

In [6]:
# Confirm the frame carries the column types declared in `schema`.
df.printSchema()

root
 |-- date: string (nullable = true)
 |-- date_block_num: integer (nullable = true)
 |-- shop_id: integer (nullable = true)
 |-- item_id: integer (nullable = true)
 |-- item_price: float (nullable = true)
 |-- item_cnt_day: float (nullable = true)



In [7]:
# Peek at the first five rows; `date` is still a string column at this point.
df.show(5)

+----------+--------------+-------+-------+----------+------------+
|      date|date_block_num|shop_id|item_id|item_price|item_cnt_day|
+----------+--------------+-------+-------+----------+------------+
|02.01.2013|             0|     59|  22154|     999.0|         1.0|
|03.01.2013|             0|     25|   2552|     899.0|         1.0|
|05.01.2013|             0|     25|   2552|     899.0|        -1.0|
|06.01.2013|             0|     25|   2554|   1709.05|         1.0|
|15.01.2013|             0|     25|   2555|    1099.0|         1.0|
+----------+--------------+-------+-------+----------+------------+
only showing top 5 rows



In [8]:
# Parse the dd.MM.yyyy strings directly into timestamps. to_timestamp does this
# in one step, without the detour through epoch seconds and a formatted string
# that the unix_timestamp -> from_unixtime -> cast chain takes.
df = df.withColumn('date', F.to_timestamp('date', 'dd.MM.yyyy'))
# Verify that `date` is now a timestamp column.
df.printSchema()

root
 |-- date: timestamp (nullable = true)
 |-- date_block_num: integer (nullable = true)
 |-- shop_id: integer (nullable = true)
 |-- item_id: integer (nullable = true)
 |-- item_price: float (nullable = true)
 |-- item_cnt_day: float (nullable = true)



In [9]:
# Re-display the first five rows to check how the converted `date` column renders.
df.show(5)

+-------------------+--------------+-------+-------+----------+------------+
|               date|date_block_num|shop_id|item_id|item_price|item_cnt_day|
+-------------------+--------------+-------+-------+----------+------------+
|2013-01-02 00:00:00|             0|     59|  22154|     999.0|         1.0|
|2013-01-03 00:00:00|             0|     25|   2552|     899.0|         1.0|
|2013-01-05 00:00:00|             0|     25|   2552|     899.0|        -1.0|
|2013-01-06 00:00:00|             0|     25|   2554|   1709.05|         1.0|
|2013-01-15 00:00:00|             0|     25|   2555|    1099.0|         1.0|
+-------------------+--------------+-------+-------+----------+------------+
only showing top 5 rows



In [10]:
def melt(
        df: DataFrame,
        id_vars: Iterable[str], value_vars: Iterable[str],
        var_name: str="variable", value_name: str="value") -> DataFrame:
    """Convert :class:`DataFrame` from wide to long format.

    Each input row yields one output row per entry in ``value_vars``.

    :param df: frame to reshape.
    :param id_vars: names of the columns to carry through unchanged.
    :param value_vars: names of the columns to unpivot. They are packed into
        a single array column, so they need a common element type.
    :param var_name: name of the output column holding the source column name.
    :param value_name: name of the output column holding the source value.
    :return: a frame with the ``id_vars`` columns followed by ``var_name``
        and ``value_name``.
    """
    # Materialise both inputs once. The signature promises any iterable, but
    # the list concatenation below would raise TypeError for e.g. a tuple.
    id_cols = list(id_vars)
    value_cols = list(value_vars)

    # Create array<struct<variable: str, value: ...>>
    _vars_and_vals = array(*(
        struct(lit(c).alias(var_name), col(c).alias(value_name))
        for c in value_cols))

    # Add to the DataFrame and explode: one row per (input row, value column)
    _tmp = df.withColumn("_vars_and_vals", explode(_vars_and_vals))

    # Keep the id columns and pull the two struct fields out as top-level columns.
    cols = id_cols + [
        col("_vars_and_vals")[x].alias(x) for x in [var_name, value_name]]
    return _tmp.select(*cols)

In [11]:
# Unpivot price and daily count into (variable, value) pairs, keeping the
# date/shop/item keys on every row.
df2 = melt(
    df,
    id_vars=['date', 'date_block_num', 'shop_id', 'item_id'],
    value_vars=['item_price', 'item_cnt_day'],
)
# Two value columns per input row, so this should be twice the input row count.
print(df2.count())
df2.show(5)

5871698
+-------------------+--------------+-------+-------+------------+-----+
|               date|date_block_num|shop_id|item_id|    variable|value|
+-------------------+--------------+-------+-------+------------+-----+
|2013-01-02 00:00:00|             0|     59|  22154|  item_price|999.0|
|2013-01-02 00:00:00|             0|     59|  22154|item_cnt_day|  1.0|
|2013-01-03 00:00:00|             0|     25|   2552|  item_price|899.0|
|2013-01-03 00:00:00|             0|     25|   2552|item_cnt_day|  1.0|
|2013-01-05 00:00:00|             0|     25|   2552|  item_price|899.0|
+-------------------+--------------+-------+-------+------------+-----+
only showing top 5 rows



In [12]:
# NOTE(review): scratch snippet parked in a bare string literal; it is never
# executed, the cell merely echoes the text. It references `df.MSZoning`,
# which is not one of the sales columns loaded in this notebook, so it looks
# like a leftover from another dataset -- candidate for removal.
"""
from pyspark.sql.functions import when

df2 = df.withColumn('cond', 
              when(df.MSZoning=='RH', 'RH_cond').
              when(df.MSZoning=='FV', 'FV_cond').
              when(df.MSZoning=='RL', 'RL_cond').
              otherwise(' ')
             )
"""

"\nfrom pyspark.sql.functions import when\n\ndf2 = df.withColumn('cond', \n              when(df.MSZoning=='RH', 'RH_cond').\n              when(df.MSZoning=='FV', 'FV_cond').\n              when(df.MSZoning=='RL', 'RL_cond').\n              otherwise(' ')\n             )\n"

In [13]:
%%time
# Persist the long-format frame as the ORC table `sales`, overwriting any
# existing data.
(
    df2.write
    .format('orc')
    .mode('overwrite')
    .saveAsTable('sales')
)

CPU times: user 2.95 ms, sys: 2.28 ms, total: 5.23 ms
Wall time: 16.6 s


In [14]:
%%time
# Read the saved table back through SQL as a sanity check.
# NOTE(review): the query limits to 100 rows, but show() is called without a
# row count and so prints only its default number of rows -- confirm intent.
spark.sql("""
select * from sales limit 100
""").show()

+-------------------+--------------+-------+-------+------------+-------+
|               date|date_block_num|shop_id|item_id|    variable|  value|
+-------------------+--------------+-------+-------+------------+-------+
|2013-01-02 00:00:00|             0|     59|  22154|  item_price|  999.0|
|2013-01-02 00:00:00|             0|     59|  22154|item_cnt_day|    1.0|
|2013-01-03 00:00:00|             0|     25|   2552|  item_price|  899.0|
|2013-01-03 00:00:00|             0|     25|   2552|item_cnt_day|    1.0|
|2013-01-05 00:00:00|             0|     25|   2552|  item_price|  899.0|
|2013-01-05 00:00:00|             0|     25|   2552|item_cnt_day|   -1.0|
|2013-01-06 00:00:00|             0|     25|   2554|  item_price|1709.05|
|2013-01-06 00:00:00|             0|     25|   2554|item_cnt_day|    1.0|
|2013-01-15 00:00:00|             0|     25|   2555|  item_price| 1099.0|
|2013-01-15 00:00:00|             0|     25|   2555|item_cnt_day|    1.0|
|2013-01-10 00:00:00|             0|  

In [15]:
%%time
# Take the database login from the environment so real credentials never need
# to be committed with the notebook. The fallbacks are the local demo values,
# which keep the default URL identical to the previously hard-coded one.
db_user = os.environ.get('SPARK_DB_USER', 'spark')
db_password = os.environ.get('SPARK_DB_PASSWORD', 'sparkpw')
# Percent-encode both values so characters such as '&' or '=' in a password
# cannot break the query string. Later cells append further '&key=value'
# parameters to `url`, so it must keep its '?...' query part.
url = 'jdbc:mysql://localhost:3306/spark_db?user={}&password={}'.format(
    quote(db_user, safe=''), quote(db_password, safe=''))
# useUnicode=true&characterEncoding=GBK

# Baseline write: append the wide frame to MySQL table `sales` with the
# writer's default JDBC settings.
df.write.jdbc(
    url=url,
    mode="append",
    table="sales",
    properties={"driver": 'com.mysql.jdbc.Driver'})

CPU times: user 16.5 ms, sys: 8.76 ms, total: 25.3 ms
Wall time: 3min 11s


In [16]:
%%time
# Same write through the generic writer, with a larger batch size and no
# transaction isolation.
# The save mode has to be set with .mode(): passed as .option('mode', ...) it
# is just an unrecognised data-source option, the writer stays on its default
# mode, and a re-run fails once `sales2` already exists.
df.write.mode('append').format('jdbc')\
    .option('url',url)\
    .option('dbtable',"sales2")\
    .option('batchsize',10000)\
    .option('isolationLevel',"NONE")\
    .option('driver','com.mysql.jdbc.Driver').save()

CPU times: user 35.1 ms, sys: 19.3 ms, total: 54.4 ms
Wall time: 7min 35s


In [17]:
%%time
# Overwrite `sales22`, this time with rewriteBatchedStatements=true appended
# to the JDBC URL and the `truncate` option switched on.
(
    df.write
    .mode('overwrite')
    .format('jdbc')
    .option('url', url + '&rewriteBatchedStatements=true')
    .option('dbtable', "sales22")
    .option('batchsize', 10000)
    .option('isolationLevel', "NONE")
    .option('driver', 'com.mysql.jdbc.Driver')
    .option("truncate", "true")
    .save()
)

CPU times: user 5.66 ms, sys: 2.98 ms, total: 8.63 ms
Wall time: 32.2 s


In [18]:
%%time
# Append the same rows to `sales22` again, using the batched-rewrite URL.
(
    df.write
    .mode('append')
    .format('jdbc')
    .option('url', url + '&rewriteBatchedStatements=true')
    .option('dbtable', "sales22")
    .option('batchsize', 10000)
    .option('isolationLevel', "NONE")
    .option('driver', 'com.mysql.jdbc.Driver')
    .save()
)

CPU times: user 4.21 ms, sys: 2.69 ms, total: 6.89 ms
Wall time: 31 s
