In [3]:
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.types import *
# NOTE(review): this wildcard import pulls in PySpark functions named sum, max,
# min, abs, round, ... which shadow Python's builtins for the whole notebook.
# Prefer explicit imports (or `from pyspark.sql import functions as F`).
from pyspark.sql.functions import *
from pyspark import SparkContext

# Single SparkSession shared by every cell; getOrCreate() reuses an existing
# session if one is already running instead of starting another.
spark = SparkSession.builder.appName("dataframe_manipulations").getOrCreate()

# RDD-level entry point. SparkContext.getOrCreate() hands back the context the
# session above is already running rather than creating a second one.
sc = SparkContext.getOrCreate()

## PySpark – explode nested arrays into rows

#### Nested arrays → flatten → explode → one row per element
#### Single-level arrays → explode → one row per element

In [9]:
# Toy input: each row pairs a name with an array of arrays of subjects
# (array<array<string>>) -- a nested shape that a single explode() cannot
# unpack all the way down to individual subjects.
arrayArrayData = [
    ("James",   [["Java", "Scala", "C++"], ["Spark", "Java"]]),
    ("Michael", [["Spark", "Java", "C++"], ["Spark", "Java"]]),
    ("Robert",  [["CSharp", "VB"], ["Spark", "Python"]]),
]

# Only column names are supplied, so Spark infers the column types from the
# data; printSchema() shows subjects as an array of string arrays.
df = spark.createDataFrame(arrayArrayData, schema=['name', 'subjects'])
df.printSchema()
df.show(truncate=False)

root
 |-- name: string (nullable = true)
 |-- subjects: array (nullable = true)
 |    |-- element: array (containsNull = true)
 |    |    |-- element: string (containsNull = true)

+-------+-----------------------------------+
|name   |subjects                           |
+-------+-----------------------------------+
|James  |[[Java, Scala, C++], [Spark, Java]]|
|Michael|[[Spark, Java, C++], [Spark, Java]]|
|Robert |[[CSharp, VB], [Spark, Python]]    |
+-------+-----------------------------------+



In [13]:
# Redisplay the nested-array input (same table as the DataFrame-creation cell)
# as the starting point for the explode/flatten steps; full cell text shown.
df.show(n=20, truncate=False)

+-------+-----------------------------------+
|name   |subjects                           |
+-------+-----------------------------------+
|James  |[[Java, Scala, C++], [Spark, Java]]|
|Michael|[[Spark, Java, C++], [Spark, Java]]|
|Robert |[[CSharp, VB], [Spark, Python]]    |
+-------+-----------------------------------+



In [10]:
from pyspark.sql.functions import explode

# explode() on array<array<string>> unpacks only the OUTER level: each inner
# array becomes its own row (in a column named `col`), still wrapped as an array.
# To get one row per individual subject, the inner arrays must be merged first.
df.select("name", explode("subjects")).show()

+-------+------------------+
|   name|               col|
+-------+------------------+
|  James|[Java, Scala, C++]|
|  James|     [Spark, Java]|
|Michael|[Spark, Java, C++]|
|Michael|     [Spark, Java]|
| Robert|      [CSharp, VB]|
| Robert|   [Spark, Python]|
+-------+------------------+



In [11]:
from pyspark.sql.functions import flatten

# flatten() concatenates each row's inner arrays into a single array<string>.
# The result is named with alias() instead of renaming Spark's auto-generated
# label ('flatten(subjects)'): withColumnRenamed() is a silent no-op when that
# label doesn't match, which would leave the column with an unexpected name.
df2 = df.select(df.name, flatten(df.subjects).alias('subjects'))
df2.show(truncate=False)
# All of a person's subjects now sit in one flat array, so exploding it yields
# one row per element.

+-------+-------------------------------+
|name   |subjects                       |
+-------+-------------------------------+
|James  |[Java, Scala, C++, Spark, Java]|
|Michael|[Spark, Java, C++, Spark, Java]|
|Robert |[CSharp, VB, Spark, Python]    |
+-------+-------------------------------+



In [12]:
from pyspark.sql.functions import explode

# The arrays are flat now, so explode() emits one row per subject in column
# `col`; repeated entries (e.g. James's two "Java" values) are kept.
df3 = df2.select("name", explode("subjects"))
df3.show()

+-------+------+
|   name|   col|
+-------+------+
|  James|  Java|
|  James| Scala|
|  James|   C++|
|  James| Spark|
|  James|  Java|
|Michael| Spark|
|Michael|  Java|
|Michael|   C++|
|Michael| Spark|
|Michael|  Java|
| Robert|CSharp|
| Robert|    VB|
| Robert| Spark|
| Robert|Python|
+-------+------+

