In [39]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import aggregate, mean, col, locate, lit, expr, struct, min, sum
import pandas as pd

In [3]:
spark = SparkSession.builder.appName("PySparkDemo").getOrCreate()
spark.sparkContext.setLogLevel("WARN")

In [4]:
spark

In [5]:
# Read CSV file and load data into a DataFrame
df = spark.read.csv('2015-summary.csv', header=True, inferSchema=True)

In [6]:
df.printSchema()

root
 |-- DEST_COUNTRY_NAME: string (nullable = true)
 |-- ORIGIN_COUNTRY_NAME: string (nullable = true)
 |-- count: integer (nullable = true)



In [7]:
df.show()

+--------------------+-------------------+-----+
|   DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|
+--------------------+-------------------+-----+
|       United States|            Romania|   15|
|       United States|            Croatia|    1|
|       United States|            Ireland|  344|
|               Egypt|      United States|   15|
|       United States|              India|   62|
|       United States|          Singapore|    1|
|       United States|            Grenada|   62|
|          Costa Rica|      United States|  588|
|             Senegal|      United States|   40|
|             Moldova|      United States|    1|
|       United States|       Sint Maarten|  325|
|       United States|   Marshall Islands|   39|
|              Guyana|      United States|   64|
|               Malta|      United States|    1|
|            Anguilla|      United States|   41|
|             Bolivia|      United States|   30|
|       United States|           Paraguay|    6|
|             Algeri

In [45]:
df.groupBy("DEST_COUNTRY_NAME").agg(sum("count").alias("Total")).show()

+--------------------+-----+
|   DEST_COUNTRY_NAME|Total|
+--------------------+-----+
|            Anguilla|   41|
|              Russia|  176|
|            Paraguay|   60|
|             Senegal|   40|
|              Sweden|  118|
|            Kiribati|   26|
|              Guyana|   64|
|         Philippines|  134|
|            Djibouti|    1|
|            Malaysia|    2|
|           Singapore|    3|
|                Fiji|   24|
|              Turkey|  138|
|                Iraq|    1|
|             Germany| 1468|
|              Jordan|   44|
|               Palau|   30|
|Turks and Caicos ...|  230|
|              France|  935|
|              Greece|   30|
+--------------------+-----+
only showing top 20 rows



In [91]:
df.selectExpr("locate('i', ORIGIN_COUNTRY_NAME, 1)").show()

+---------------------------------+
|locate(i, ORIGIN_COUNTRY_NAME, 1)|
+---------------------------------+
|                                6|
|                                6|
|                                0|
|                                3|
|                                4|
|                                2|
|                                0|
|                                3|
|                                3|
|                                3|
|                                2|
|                                0|
|                                3|
|                                3|
|                                3|
|                                3|
|                                0|
|                                3|
|                                3|
|                                2|
+---------------------------------+
only showing top 20 rows



In [94]:
df.select(expr("locate('i', ORIGIN_COUNTRY_NAME, 1)")).show()

+---------------------------------+
|locate(i, ORIGIN_COUNTRY_NAME, 1)|
+---------------------------------+
|                                6|
|                                6|
|                                0|
|                                3|
|                                4|
|                                2|
|                                0|
|                                3|
|                                3|
|                                3|
|                                2|
|                                0|
|                                3|
|                                3|
|                                3|
|                                3|
|                                0|
|                                3|
|                                3|
|                                2|
+---------------------------------+
only showing top 20 rows



In [6]:
dfr = spark.createDataFrame([(1, [20.0, 4.0, 2.0, 6.0, 10.0])], ("id", "values"))

In [7]:
def merge(acc, x):
    count = acc.count + 1
    sum = acc.sum + x
    return struct(count.alias("count"), sum.alias("sum"))

In [12]:
dfr.select(
    aggregate(
        "values",
        struct(lit(0).alias("count"), lit(0.0).alias("sum")),
        merge,
        lambda acc: acc.sum / acc.count,
    ).alias("mean")
).show()

+----+
|mean|
+----+
| 8.4|
+----+



In [32]:
df.groupBy().agg({"count": "min"}).show()

+----------+
|min(count)|
+----------+
|         1|
+----------+



In [33]:
dfx = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x'])

In [34]:
dfx.show()

+---------+
|        x|
+---------+
|[1, 2, 3]|
|   [4, 5]|
+---------+



In [61]:
df.selectExpr("sum(count) as count").show()

+------+
| count|
+------+
|453316|
+------+



In [62]:
df.groupBy().count().show()

+-----+
|count|
+-----+
|  256|
+-----+



In [59]:
def func1(x):
    journey = f'{x.ORIGIN_COUNTRY_NAME} --> {x.DEST_COUNTRY_NAME}'
    return (journey)

rdd2 = df.rdd.map(func1)

In [60]:
rdd2.take(5)

['Romania --> United States',
 'Croatia --> United States',
 'Ireland --> United States',
 'United States --> Egypt',
 'India --> United States']

In [63]:
data = [("Banana",1000,"USA"), ("Carrots",1500,"USA"), ("Beans",1600,"USA"), \
      ("Orange",2000,"USA"),("Orange",2000,"USA"),("Banana",400,"China"), \
      ("Carrots",1200,"China"),("Beans",1500,"China"),("Orange",4000,"China"), \
      ("Banana",2000,"Canada"),("Carrots",2000,"Canada"),("Beans",2000,"Mexico")]

columns= ["Product","Amount","Country"]
df = spark.createDataFrame(data = data, schema = columns)
df.printSchema()
df.show(truncate=False)

root
 |-- Product: string (nullable = true)
 |-- Amount: long (nullable = true)
 |-- Country: string (nullable = true)

+-------+------+-------+
|Product|Amount|Country|
+-------+------+-------+
|Banana |1000  |USA    |
|Carrots|1500  |USA    |
|Beans  |1600  |USA    |
|Orange |2000  |USA    |
|Orange |2000  |USA    |
|Banana |400   |China  |
|Carrots|1200  |China  |
|Beans  |1500  |China  |
|Orange |4000  |China  |
|Banana |2000  |Canada |
|Carrots|2000  |Canada |
|Beans  |2000  |Mexico |
+-------+------+-------+



In [65]:
pivotDF = df.groupBy("Product").pivot("Country").sum("Amount")
pivotDF.printSchema()
pivotDF.show(truncate=False)

                                                                                

root
 |-- Product: string (nullable = true)
 |-- Canada: long (nullable = true)
 |-- China: long (nullable = true)
 |-- Mexico: long (nullable = true)
 |-- USA: long (nullable = true)



                                                                                

+-------+------+-----+------+----+
|Product|Canada|China|Mexico|USA |
+-------+------+-----+------+----+
|Orange |null  |4000 |null  |4000|
|Beans  |null  |1500 |2000  |1600|
|Banana |2000  |400  |null  |1000|
|Carrots|2000  |1200 |null  |1500|
+-------+------+-----+------+----+



In [69]:
pivotDF = df.groupBy("Country").pivot("Product").sum("Amount")
pivotDF.printSchema()
pivotDF.show(truncate=False)

root
 |-- Country: string (nullable = true)
 |-- Banana: long (nullable = true)
 |-- Beans: long (nullable = true)
 |-- Carrots: long (nullable = true)
 |-- Orange: long (nullable = true)





+-------+------+-----+-------+------+
|Country|Banana|Beans|Carrots|Orange|
+-------+------+-----+-------+------+
|China  |400   |1500 |1200   |4000  |
|USA    |1000  |1600 |1500   |4000  |
|Mexico |null  |2000 |null   |null  |
|Canada |2000  |null |2000   |null  |
+-------+------+-----+-------+------+

