In [1]:
import os
import sys

os.environ['PYSPARK_PYTHON'] = sys.executable
os.environ['PYSPARK_DRIVER_PYTHON'] = sys.executable

## 创建spark对象

In [2]:
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("Python Spark SQL basic example").getOrCreate()
spark

`spark.createDataFrame`接受的对象包括列表List，pandas.DataFrame  numpy.ndarray等。

### 1. 从list创建

In [4]:
from datetime import datetime, date
import pandas as pd
from pyspark.sql import Row

df = spark.createDataFrame([
    Row(a=1, b=2., c='string1', d=date(2000, 1, 1), e=datetime(2000, 1, 1, 12, 0)),
    Row(a=2, b=3., c='string2', d=date(2000, 2, 1), e=datetime(2000, 1, 2, 12, 0)),
    Row(a=4, b=5., c='string3', d=date(2000, 3, 1), e=datetime(2000, 1, 3, 12, 0))
])
df.show()

+---+---+-------+----------+-------------------+
|  a|  b|      c|         d|                  e|
+---+---+-------+----------+-------------------+
|  1|2.0|string1|2000-01-01|2000-01-01 12:00:00|
|  2|3.0|string2|2000-02-01|2000-01-02 12:00:00|
|  4|5.0|string3|2000-03-01|2000-01-03 12:00:00|
+---+---+-------+----------+-------------------+



#### Schema: 定义数据类型
在 Apache Spark 中，Schema 是用于定义和描述数据框（DataFrame）中列的数据类型的结构。Schema 包括每个列的名称和对应的数据类型。使用 Schema 可以确保数据框中的数据按照预期的格式和类型进行组织。

Spark 数据框可以具有显式的或隐式的 Schema。显式 Schema 是在创建数据框时明确指定的，而隐式 Schema 是通过 Spark 在运行时自动推断的。在大多数情况下，显式 Schema 是推荐的，因为它可以提高性能并确保数据的正确解析。

schema的几种形式：
1. str: schema='a long, b double, c string, d date, e timestamp'
2. list: schema = ['a', 'b', 'c', 'd', 'e']
3. StructType: 
   ```python
    from pyspark.sql.types import StructType, StructField, StringType, IntegerType
    schema = StructType([
        StructField("Name", StringType(), True),
        StructField("Age", IntegerType(), True)
    ])
   
   ```

In [5]:
df = spark.createDataFrame([
    (1, 2., 'string1', date(2000, 1, 1), datetime(2000, 1, 1, 12, 0)),
    (2, 3., 'string2', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
    (3, 4., 'string3', date(2000, 3, 1), datetime(2000, 1, 3, 12, 0))
], schema='a long, b double, c string, d date, e timestamp')
df.show()

+---+---+-------+----------+-------------------+
|  a|  b|      c|         d|                  e|
+---+---+-------+----------+-------------------+
|  1|2.0|string1|2000-01-01|2000-01-01 12:00:00|
|  2|3.0|string2|2000-02-01|2000-01-02 12:00:00|
|  3|4.0|string3|2000-03-01|2000-01-03 12:00:00|
+---+---+-------+----------+-------------------+



### 2. 从`pandas.DataFrame`创建

In [6]:
pandas_df = pd.DataFrame({
    'a': [1, 2, 3],
    'b': [2., 3., 4.],
    'c': ['string1', 'string2', 'string3'],
    'd': [date(2000, 1, 1), date(2000, 2, 1), date(2000, 3, 1)],
    'e': [datetime(2000, 1, 1, 12, 0), datetime(2000, 1, 2, 12, 0), datetime(2000, 1, 3, 12, 0)]
})
df = spark.createDataFrame(pandas_df)
df.show()

+---+---+-------+----------+-------------------+
|  a|  b|      c|         d|                  e|
+---+---+-------+----------+-------------------+
|  1|2.0|string1|2000-01-01|2000-01-01 12:00:00|
|  2|3.0|string2|2000-02-01|2000-01-02 12:00:00|
|  3|4.0|string3|2000-03-01|2000-01-03 12:00:00|
+---+---+-------+----------+-------------------+



In [7]:
df.head(10)

[Row(a=1, b=2.0, c='string1', d=datetime.date(2000, 1, 1), e=datetime.datetime(2000, 1, 1, 12, 0)),
 Row(a=2, b=3.0, c='string2', d=datetime.date(2000, 2, 1), e=datetime.datetime(2000, 1, 2, 12, 0)),
 Row(a=3, b=4.0, c='string3', d=datetime.date(2000, 3, 1), e=datetime.datetime(2000, 1, 3, 12, 0))]

In [8]:
df.printSchema()

root
 |-- a: long (nullable = true)
 |-- b: double (nullable = true)
 |-- c: string (nullable = true)
 |-- d: date (nullable = true)
 |-- e: timestamp (nullable = true)



In [9]:
df.show(1)

+---+---+-------+----------+-------------------+
|  a|  b|      c|         d|                  e|
+---+---+-------+----------+-------------------+
|  1|2.0|string1|2000-01-01|2000-01-01 12:00:00|
+---+---+-------+----------+-------------------+
only showing top 1 row



### 设置jupyter notebook展示

In [10]:
spark.conf.set('spark.sql.repl.eagerEval.enabled', True)
df

a,b,c,d,e
1,2.0,string1,2000-01-01,2000-01-01 12:00:00
2,3.0,string2,2000-02-01,2000-01-02 12:00:00
3,4.0,string3,2000-03-01,2000-01-03 12:00:00


In [11]:
# vertically show
df.show(1, vertical=True)

-RECORD 0------------------
 a   | 1                   
 b   | 2.0                 
 c   | string1             
 d   | 2000-01-01          
 e   | 2000-01-01 12:00:00 
only showing top 1 row



In [12]:
df.columns

['a', 'b', 'c', 'd', 'e']

## 统计
统计，和dataframe的describe()类似。

In [13]:
df.select('a', 'b', 'c').describe().show()

+-------+---+---+-------+
|summary|  a|  b|      c|
+-------+---+---+-------+
|  count|  3|  3|      3|
|   mean|2.0|3.0|   NULL|
| stddev|1.0|1.0|   NULL|
|    min|  1|2.0|string1|
|    max|  3|4.0|string3|
+-------+---+---+-------+



In [13]:
df.collect()

[Row(a=1, b=2.0, c='string1', d=datetime.date(2000, 1, 1), e=datetime.datetime(2000, 1, 1, 12, 0)),
 Row(a=2, b=3.0, c='string2', d=datetime.date(2000, 2, 1), e=datetime.datetime(2000, 1, 2, 12, 0)),
 Row(a=3, b=4.0, c='string3', d=datetime.date(2000, 3, 1), e=datetime.datetime(2000, 1, 3, 12, 0))]

collect() 会展示driver里的全部数据，容易爆内存，需要注意。为了防止爆内存，可以使用take(n)来控制返回的行数，或者`tail()`、`head()`只取头尾部分数据。

In [14]:
df.take(1)

[Row(a=1, b=2.0, c='string1', d=datetime.date(2000, 1, 1), e=datetime.datetime(2000, 1, 1, 12, 0))]

In [15]:
type(df)

pyspark.sql.dataframe.DataFrame

pyspark dataframe 转 pandas dataframe： `to_pandas()`。

In [16]:
df.toPandas()

Unnamed: 0,a,b,c,d,e
0,1,2.0,string1,2000-01-01,2000-01-01 12:00:00
1,2,3.0,string2,2000-02-01,2000-01-02 12:00:00
2,3,4.0,string3,2000-03-01,2000-01-03 12:00:00


## 列操作

In [15]:
from pyspark.sql import Column
from pyspark.sql.functions import upper

df.show()

+---+---+-------+----------+-------------------+
|  a|  b|      c|         d|                  e|
+---+---+-------+----------+-------------------+
|  1|2.0|string1|2000-01-01|2000-01-01 12:00:00|
|  2|3.0|string2|2000-02-01|2000-01-02 12:00:00|
|  3|4.0|string3|2000-03-01|2000-01-03 12:00:00|
+---+---+-------+----------+-------------------+



In [16]:
type(df.c) == type(upper(df.c)) == type(df.c.isNull())

True

### 查询: select
用`select`函数查询某一列的数据。

In [17]:
df.select(df.c, upper(df.c)).show()

+-------+--------+
|      c|upper(c)|
+-------+--------+
|string1| STRING1|
|string2| STRING2|
|string3| STRING3|
+-------+--------+



也可直接输入列名查询：

In [18]:
df.select('c').show()

+-------+
|      c|
+-------+
|string1|
|string2|
|string3|
+-------+



`selectExpr`是`select`的一个变种，可以接收SQL表达式并返回DataFrame结果。

In [28]:
# 取特定列
df.selectExpr('a', 'b').show()

+---+---+
|  a|  b|
+---+---+
|  1|2.0|
|  2|3.0|
|  3|4.0|
+---+---+



In [30]:
df.selectExpr('max(b) as max_b').show()

+-----+
|max_b|
+-----+
|  4.0|
+-----+



In [29]:
df.selectExpr('count(a) as Count_a').show()

+-------+
|Count_a|
+-------+
|      3|
+-------+



### 创建新列: withColumn
用`withColumn`添加或替换数据框中的列。

In [19]:
df.withColumn('upper_c', upper(df.c)).show()

+---+---+-------+----------+-------------------+-------+
|  a|  b|      c|         d|                  e|upper_c|
+---+---+-------+----------+-------------------+-------+
|  1|2.0|string1|2000-01-01|2000-01-01 12:00:00|STRING1|
|  2|3.0|string2|2000-02-01|2000-01-02 12:00:00|STRING2|
|  3|4.0|string3|2000-03-01|2000-01-03 12:00:00|STRING3|
+---+---+-------+----------+-------------------+-------+



In [20]:
df = df.withColumn('upper_c', upper(df.c))
df.show()

+---+---+-------+----------+-------------------+-------+
|  a|  b|      c|         d|                  e|upper_c|
+---+---+-------+----------+-------------------+-------+
|  1|2.0|string1|2000-01-01|2000-01-01 12:00:00|STRING1|
|  2|3.0|string2|2000-02-01|2000-01-02 12:00:00|STRING2|
|  3|4.0|string3|2000-03-01|2000-01-03 12:00:00|STRING3|
+---+---+-------+----------+-------------------+-------+



In [21]:
df = df.withColumn('b^^2', df.b ** 2)
df.show()

+---+---+-------+----------+-------------------+-------+----+
|  a|  b|      c|         d|                  e|upper_c|b^^2|
+---+---+-------+----------+-------------------+-------+----+
|  1|2.0|string1|2000-01-01|2000-01-01 12:00:00|STRING1| 4.0|
|  2|3.0|string2|2000-02-01|2000-01-02 12:00:00|STRING2| 9.0|
|  3|4.0|string3|2000-03-01|2000-01-03 12:00:00|STRING3|16.0|
+---+---+-------+----------+-------------------+-------+----+



### 筛选: filter
例如筛选a变量中大于1的数据：

In [22]:
print(f'Columns: {df.columns}')
df.filter(df.a > 1).show()

Columns: ['a', 'b', 'c', 'd', 'e', 'upper_c', 'b^^2']
+---+---+-------+----------+-------------------+-------+----+
|  a|  b|      c|         d|                  e|upper_c|b^^2|
+---+---+-------+----------+-------------------+-------+----+
|  2|3.0|string2|2000-02-01|2000-01-02 12:00:00|STRING2| 9.0|
|  3|4.0|string3|2000-03-01|2000-01-03 12:00:00|STRING3|16.0|
+---+---+-------+----------+-------------------+-------+----+



In [32]:
df.filter(df.c == 'string1').show()

+---+---+-------+----------+-------------------+-------+----+
|  a|  b|      c|         d|                  e|upper_c|b^^2|
+---+---+-------+----------+-------------------+-------+----+
|  1|2.0|string1|2000-01-01|2000-01-01 12:00:00|STRING1| 4.0|
+---+---+-------+----------+-------------------+-------+----+



### 删除： drop

drop意为丢弃，删除。`pyspark.pandas.DataFrame.drop`有以下4个参数：
1. labels: 要删除的列的标签，str or list.
2. axis: 删除行还是列。行：0 or 'index'。列：1 or 'columns'。默认为0.
3. index：删除索引，即行。可选参数。
4. columns: 删除列，可选参数，`labels, axis=1`等价于`columns=labels`。

注意`pyspark.sql.DataFrame.drop`和`pyspark.pandas.DataFrame.drop`功能相同，但是参数不同，`pyspark.sql.DataFrame.drop`只有一个输入参数：列名或Column类（即df.a）。

In [35]:
df.show()

+---+---+-------+----------+-------------------+-------+----+
|  a|  b|      c|         d|                  e|upper_c|b^^2|
+---+---+-------+----------+-------------------+-------+----+
|  1|2.0|string1|2000-01-01|2000-01-01 12:00:00|STRING1| 4.0|
|  2|3.0|string2|2000-02-01|2000-01-02 12:00:00|STRING2| 9.0|
|  3|4.0|string3|2000-03-01|2000-01-03 12:00:00|STRING3|16.0|
+---+---+-------+----------+-------------------+-------+----+



In [43]:
# 注意参数类型
type(df)

pyspark.sql.dataframe.DataFrame

In [41]:
df.drop('a').show()

+---+-------+----------+-------------------+-------+----+
|  b|      c|         d|                  e|upper_c|b^^2|
+---+-------+----------+-------------------+-------+----+
|2.0|string1|2000-01-01|2000-01-01 12:00:00|STRING1| 4.0|
|3.0|string2|2000-02-01|2000-01-02 12:00:00|STRING2| 9.0|
|4.0|string3|2000-03-01|2000-01-03 12:00:00|STRING3|16.0|
+---+-------+----------+-------------------+-------+----+



In [42]:
df.drop(df.e).show()

+---+---+-------+----------+-------+----+
|  a|  b|      c|         d|upper_c|b^^2|
+---+---+-------+----------+-------+----+
|  1|2.0|string1|2000-01-01|STRING1| 4.0|
|  2|3.0|string2|2000-02-01|STRING2| 9.0|
|  3|4.0|string3|2000-03-01|STRING3|16.0|
+---+---+-------+----------+-------+----+



### 增加新的行：Union
spark里没法直接增加新数据，但是可以通过合并的方式进行增加。

In [49]:
df = df.drop('upper_c', 'b^^2')
df

a,b,c,d,e
1,2.0,string1,2000-01-01,2000-01-01 12:00:00
2,3.0,string2,2000-02-01,2000-01-02 12:00:00
3,4.0,string3,2000-03-01,2000-01-03 12:00:00


In [51]:
schema = 'a long, b double, c string, d date, e timestamp'
df2 = spark.createDataFrame([
    (4, 5., 'string4', date(2000, 4, 1), datetime(2000, 1, 4, 12, 0)),
    (4, 5., 'string4', date(2000, 4, 1), datetime(2000, 1, 4, 12, 0)),
    (4, 6., 'string4', date(2000, 4, 1), datetime(2000, 1, 4, 12, 0)),
], schema = schema)
df2

a,b,c,d,e
4,5.0,string4,2000-04-01,2000-01-04 12:00:00
4,5.0,string4,2000-04-01,2000-01-04 12:00:00
4,6.0,string4,2000-04-01,2000-01-04 12:00:00


In [52]:
df = df.union(df2)
df

a,b,c,d,e
1,2.0,string1,2000-01-01,2000-01-01 12:00:00
2,3.0,string2,2000-02-01,2000-01-02 12:00:00
3,4.0,string3,2000-03-01,2000-01-03 12:00:00
4,5.0,string4,2000-04-01,2000-01-04 12:00:00
4,5.0,string4,2000-04-01,2000-01-04 12:00:00
4,5.0,string4,2000-04-01,2000-01-04 12:00:00
4,5.0,string4,2000-04-01,2000-01-04 12:00:00
4,6.0,string4,2000-04-01,2000-01-04 12:00:00


### 去重：dropDuplicates
`dropDuplicates`函数默认只去除全部列的值都一样的样本，如果要更严格的去重可以指定对比的列。

In [53]:
df.dropDuplicates().show()

+---+---+-------+----------+-------------------+
|  a|  b|      c|         d|                  e|
+---+---+-------+----------+-------------------+
|  1|2.0|string1|2000-01-01|2000-01-01 12:00:00|
|  2|3.0|string2|2000-02-01|2000-01-02 12:00:00|
|  3|4.0|string3|2000-03-01|2000-01-03 12:00:00|
|  4|5.0|string4|2000-04-01|2000-01-04 12:00:00|
|  4|6.0|string4|2000-04-01|2000-01-04 12:00:00|
+---+---+-------+----------+-------------------+



有两个a = 4, c = string4的数据，还可以继续去重：

In [55]:
df.dropDuplicates(['a', 'c']).show()

+---+---+-------+----------+-------------------+
|  a|  b|      c|         d|                  e|
+---+---+-------+----------+-------------------+
|  1|2.0|string1|2000-01-01|2000-01-01 12:00:00|
|  2|3.0|string2|2000-02-01|2000-01-02 12:00:00|
|  3|4.0|string3|2000-03-01|2000-01-03 12:00:00|
|  4|5.0|string4|2000-04-01|2000-01-04 12:00:00|
+---+---+-------+----------+-------------------+



## Applying a Function
终于可以用上装饰器了。

In [26]:
import pandas as pd
from pyspark.sql.functions import pandas_udf

@pandas_udf('long')
def pandas_plus_one(series: pd.Series) -> pd.Series:
    return series + 1

df.select(pandas_plus_one('a').alias('a+1')).show()

+---+
|a+1|
+---+
|  2|
|  3|
|  4|
+---+



In [27]:
df.select(pandas_plus_one(df.a).alias('a+1')).show()

+---+
|a+1|
+---+
|  2|
|  3|
|  4|
+---+



In [29]:
df

a,b,c,d,e
1,2.0,string1,2000-01-01,2000-01-01 12:00:00
2,3.0,string2,2000-02-01,2000-01-02 12:00:00
3,4.0,string3,2000-03-01,2000-01-03 12:00:00


In [30]:
def pandas_filter_func(iterator):
    for pandas_df in iterator:
        yield pandas_df[pandas_df.a == 1]

df.mapInPandas(pandas_filter_func, schema=df.schema).show()

+---+---+-------+----------+-------------------+
|  a|  b|      c|         d|                  e|
+---+---+-------+----------+-------------------+
|  1|2.0|string1|2000-01-01|2000-01-01 12:00:00|
+---+---+-------+----------+-------------------+



## Grouping Data

In [31]:
df = spark.createDataFrame([
    ['red', 'banana', 1, 10], ['blue', 'banana', 2, 20], ['red', 'carrot', 3, 30],
    ['blue', 'grape', 4, 40], ['red', 'carrot', 5, 50], ['black', 'carrot', 6, 60],
    ['red', 'banana', 7, 70], ['red', 'grape', 8, 80]], schema=['color', 'fruit', 'v1', 'v2'])
df.show()

+-----+------+---+---+
|color| fruit| v1| v2|
+-----+------+---+---+
|  red|banana|  1| 10|
| blue|banana|  2| 20|
|  red|carrot|  3| 30|
| blue| grape|  4| 40|
|  red|carrot|  5| 50|
|black|carrot|  6| 60|
|  red|banana|  7| 70|
|  red| grape|  8| 80|
+-----+------+---+---+



In [33]:
df.groupby('color').avg().show()

+-----+-------+-------+
|color|avg(v1)|avg(v2)|
+-----+-------+-------+
|  red|    4.8|   48.0|
| blue|    3.0|   30.0|
|black|    6.0|   60.0|
+-----+-------+-------+



In [35]:
df.groupby('fruit').avg().show()

+------+------------------+------------------+
| fruit|           avg(v1)|           avg(v2)|
+------+------------------+------------------+
|banana|3.3333333333333335|33.333333333333336|
|carrot| 4.666666666666667|46.666666666666664|
| grape|               6.0|              60.0|
+------+------------------+------------------+



可以嵌套使用Python函数。

In [53]:
def plus_mean(pandas_df):
    return pandas_df.assign(v1=pandas_df.v1 - pandas_df.v1.mean())

def mean(pandas_df):
    return pandas_df.assign(v2=pandas_df.v2.mean())

In [52]:
print(type(df))
df

<class 'pyspark.sql.dataframe.DataFrame'>


color,fruit,v1,v2
red,banana,1,10
blue,banana,2,20
red,carrot,3,30
blue,grape,4,40
red,carrot,5,50
black,carrot,6,60
red,banana,7,70
red,grape,8,80


In [50]:
df.groupby('color').applyInPandas(plus_mean, schema=df.schema).show()

+-----+------+---+---+
|color| fruit| v1| v2|
+-----+------+---+---+
|black|carrot|  0| 60|
| blue|banana| -1| 20|
| blue| grape|  1| 40|
|  red|banana| -3| 10|
|  red|carrot| -1| 30|
|  red|carrot|  0| 50|
|  red|banana|  2| 70|
|  red| grape|  3| 80|
+-----+------+---+---+



In [55]:
df.groupby('fruit').applyInPandas(mean, schema=df.schema).show()

+-----+------+---+---+
|color| fruit| v1| v2|
+-----+------+---+---+
|  red|banana|  1| 33|
| blue|banana|  2| 33|
|  red|banana|  7| 33|
|  red|carrot|  3| 46|
|  red|carrot|  5| 46|
|black|carrot|  6| 46|
| blue| grape|  4| 60|
|  red| grape|  8| 60|
+-----+------+---+---+



In [56]:
df.describe()

summary,color,fruit,v1,v2
count,8,8,8.0,8.0
mean,,,4.5,45.0
stddev,,,2.449489742783178,24.49489742783178
min,black,banana,1.0,10.0
max,red,grape,8.0,80.0


In [57]:
df1 = spark.createDataFrame(
    [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
    ('time', 'id', 'v1'))

df2 = spark.createDataFrame(
    [(20000101, 1, 'x'), (20000101, 2, 'y')],
    ('time', 'id', 'v2'))

def merge_ordered(l, r):
    return pd.merge_ordered(l, r)

df1.groupby('id').cogroup(df2.groupby('id')).applyInPandas(
    merge_ordered, schema='time int, id int, v1 double, v2 string').show()

+--------+---+---+----+
|    time| id| v1|  v2|
+--------+---+---+----+
|20000101|  1|1.0|   x|
|20000102|  1|3.0|NULL|
|20000101|  2|2.0|   y|
|20000102|  2|4.0|NULL|
+--------+---+---+----+



## 统计均值等特征的方法

In [3]:
from pyspark.sql.functions import col, count, mean, sum, min, max, stddev, skewness, kurtosis
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").appName("Python Spark SQL basic example").getOrCreate()
df = spark.createDataFrame([
    ['red', 'banana', 1, 10], ['blue', 'banana', 2, 20], ['red', 'carrot', 3, 30],
    ['blue', 'grape', 4, 40], ['red', 'carrot', 5, 50], ['black', 'carrot', 6, 60],
    ['red', 'banana', 7, 70], ['red', 'grape', 8, 80]], schema=['color', 'fruit', 'v1', 'v2'])
df.show()

+-----+------+---+---+
|color| fruit| v1| v2|
+-----+------+---+---+
|  red|banana|  1| 10|
| blue|banana|  2| 20|
|  red|carrot|  3| 30|
| blue| grape|  4| 40|
|  red|carrot|  5| 50|
|black|carrot|  6| 60|
|  red|banana|  7| 70|
|  red| grape|  8| 80|
+-----+------+---+---+



In [4]:

df.groupby('color').agg(count('*'), mean('v1'), mean('v2')).show()

+-----+--------+-------+-------+
|color|count(1)|avg(v1)|avg(v2)|
+-----+--------+-------+-------+
|  red|       5|    4.8|   48.0|
| blue|       2|    3.0|   30.0|
|black|       1|    6.0|   60.0|
+-----+--------+-------+-------+

