In [1]:
import pyspark
from pyspark.sql import SparkSession, SQLContext
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, DateType
from pyspark.sql.functions import lit
from pyspark.ml.linalg import Vectors

In [2]:
# Start (or reuse) the Spark session; later cells create and read DataFrames via `sqlc`.
spark = (
    SparkSession.builder
    .appName("example")
    .getOrCreate()
)
sc = spark.sparkContext
# NOTE(review): SQLContext is deprecated since Spark 3.0 in favour of the SparkSession itself
# (`spark.read`, `spark.createDataFrame`); kept here because later cells go through `sqlc`.
sqlc = SQLContext(sc)

In [3]:
# Toy input rows as (uid, age, gender) tuples.
# NOTE(review): the (3, 36, 'male') row appears twice — confirm the duplicate is intentional.
df = sqlc.createDataFrame(
    data=[
        (0, 18, 'female'),
        (1, 25, 'male'),
        (2, 40, 'female'),
        (3, 36, 'male'),
        (3, 36, 'male'),
    ],
    schema=['uid', 'age', 'gender'],
)

In [4]:
# Print the DataFrame built above as a text table.
df.show()

+---+---+------+
|uid|age|gender|
+---+---+------+
|  0| 18|female|
|  1| 25|  male|
|  2| 40|female|
|  3| 36|  male|
|  3| 36|  male|
+---+---+------+



* Write `df` to CSV (with a header row, overwriting any previous output)

In [5]:
# Save `df` as CSV with a header line, replacing whatever already exists at the path.
# NOTE(review): hard-coded /tmp path; consider a config constant if this is reused.
(
    df.write
    .option('header', 'true')
    .mode('overwrite')
    .csv('/tmp/csv_test')
)

* Read the CSV back, letting Spark infer the column types

In [6]:
# Read the CSV back: column names come from the header row, and column types are inferred.
tdf = sqlc.read.options(header='true', inferschema='true').csv('/tmp/csv_test')

In [7]:
# The rows come back in a different order than df.show() printed them (compare the two tables).
# NOTE(review): presumably the CSV was written as several part files — don't rely on row order.
tdf.show()

+---+---+------+
|uid|age|gender|
+---+---+------+
|  2| 40|female|
|  3| 36|  male|
|  3| 36|  male|
|  0| 18|female|
|  1| 25|  male|
+---+---+------+



In [8]:
# With inferschema enabled, uid and age are read back as integers (see output).
tdf.printSchema()

root
 |-- uid: integer (nullable = true)
 |-- age: integer (nullable = true)
 |-- gender: string (nullable = true)



* Read the CSV as strings, then conform it to an explicit schema (cast existing columns, add missing ones as nulls)

In [9]:
# Target schema: uid is declared non-nullable, the remaining fields are nullable.
# `country` is not a column of the CSV written above; it gets filled in as nulls later.
schema = StructType([
    StructField(name, data_type, nullable)
    for name, data_type, nullable in (
        ('uid', IntegerType(), False),
        ('age', IntegerType(), True),
        ('country', StringType(), True),
        ('gender', StringType(), True),
    )
])

In [10]:
# Re-read the same CSV without type inference, so every column arrives as a string;
# the target `schema` is applied afterwards by casting rather than passed to the reader.
tdf = (
    sqlc.read
    .options(header='true')
    .csv('/tmp/csv_test')
)

In [11]:
# Display the rows read without type inference.
tdf.show()

+---+---+------+
|uid|age|gender|
+---+---+------+
|  2| 40|female|
|  3| 36|  male|
|  3| 36|  male|
|  0| 18|female|
|  1| 25|  male|
+---+---+------+



In [12]:
# Without inferschema, every column is read as a string (see output).
tdf.printSchema()

root
 |-- uid: string (nullable = true)
 |-- age: string (nullable = true)
 |-- gender: string (nullable = true)



In [13]:
# Conform `tdf` to `schema`. A schema column missing from the CSV is first added as a null
# column (appended at the end); every schema column is then cast to its declared type.
# Columns already present keep their position, and declared nullability is not enforced.
# NOTE(review): each withColumn() adds a projection to the query plan; for wide schemas a
# single select() would keep the plan smaller.
for field in schema.fields:
    tdf = tdf if field.name in tdf.columns else tdf.withColumn(field.name, lit(None))
    tdf = tdf.withColumn(field.name, tdf[field.name].cast(field.dataType))

In [14]:
# `country` was absent from the CSV, so it now appears as an all-null column at the end.
tdf.show()

+---+---+------+-------+
|uid|age|gender|country|
+---+---+------+-------+
|  2| 40|female|   null|
|  3| 36|  male|   null|
|  3| 36|  male|   null|
|  0| 18|female|   null|
|  1| 25|  male|   null|
+---+---+------+-------+



In [15]:
# Columns now carry the schema's types; note uid is still nullable here even though
# the schema declares it non-nullable.
tdf.printSchema()

root
 |-- uid: integer (nullable = true)
 |-- age: integer (nullable = true)
 |-- gender: string (nullable = true)
 |-- country: string (nullable = true)



In [16]:
# Empty DataFrame carrying exactly the target schema; the union below takes its column order.
res_df = sqlc.createDataFrame(data=[], schema=schema)

In [17]:
# union() pairs columns by position, so reorder tdf's columns to the schema's order first.
res_df.union(tdf.select(*schema.names)).show()

+---+---+-------+------+
|uid|age|country|gender|
+---+---+-------+------+
|  2| 40|   null|female|
|  3| 36|   null|  male|
|  3| 36|   null|  male|
|  0| 18|   null|female|
|  1| 25|   null|  male|
+---+---+-------+------+



In [18]:
# Schema of the unioned result: column order follows `schema`, but uid reports nullable = true.
res_df.union(tdf.select(schema.fieldNames())).printSchema()

root
 |-- uid: integer (nullable = true)
 |-- age: integer (nullable = true)
 |-- country: string (nullable = true)
 |-- gender: string (nullable = true)

