```bash
$ hadoop fs -mkdir linkage
$ hadoop fs -put block_*.csv linkage
```

In [1]:
sc

res0: org.apache.spark.SparkContext = org.apache.spark.SparkContext@22a8369c


In [4]:
val rawblocks = sc.textFile("data/linkage/*.csv")

rawblocks: org.apache.spark.rdd.RDD[String] = data/linkage/*.csv MapPartitionsRDD[3] at textFile at <console>:25


In [5]:
rawblocks.first

res2: String = "id_1","id_2","cmp_fname_c1","cmp_fname_c2","cmp_lname_c1","cmp_lname_c2","cmp_sex","cmp_bd","cmp_bm","cmp_by","cmp_plz","is_match"


In [7]:
val head = rawblocks.take(10)

head: Array[String] = Array("id_1","id_2","cmp_fname_c1","cmp_fname_c2","cmp_lname_c1","cmp_lname_c2","cmp_sex","cmp_bd","cmp_bm","cmp_by","cmp_plz","is_match", 607,53170,1,?,1,?,1,1,1,1,1,TRUE, 88569,88592,1,?,1,?,1,1,1,1,1,TRUE, 21282,26255,1,?,1,?,1,1,1,1,1,TRUE, 20995,42541,1,?,1,?,1,1,1,1,1,TRUE, 27989,34739,1,?,1,?,1,1,1,1,1,TRUE, 32442,69159,1,?,1,?,1,1,1,1,1,TRUE, 24738,29196,1,1,1,?,1,1,1,1,1,TRUE, 9904,89061,1,?,1,?,1,1,1,1,1,TRUE, 29926,36578,1,?,1,?,1,1,1,1,1,TRUE)


In [8]:
head.foreach(println)

"id_1","id_2","cmp_fname_c1","cmp_fname_c2","cmp_lname_c1","cmp_lname_c2","cmp_sex","cmp_bd","cmp_bm","cmp_by","cmp_plz","is_match"
607,53170,1,?,1,?,1,1,1,1,1,TRUE
88569,88592,1,?,1,?,1,1,1,1,1,TRUE
21282,26255,1,?,1,?,1,1,1,1,1,TRUE
20995,42541,1,?,1,?,1,1,1,1,1,TRUE
27989,34739,1,?,1,?,1,1,1,1,1,TRUE
32442,69159,1,?,1,?,1,1,1,1,1,TRUE
24738,29196,1,1,1,?,1,1,1,1,1,TRUE
9904,89061,1,?,1,?,1,1,1,1,1,TRUE
29926,36578,1,?,1,?,1,1,1,1,1,TRUE


In [9]:
//过滤标题行
def isHeader(line: String) = line.contains("id_1")
// 显式地指明函数返回类型
// def isHeader(line: String): Boolean = {
//     line.contains("id_1")
// }

isHeader: (line: String)Boolean


In [10]:
head.filter(isHeader).foreach(println)

"id_1","id_2","cmp_fname_c1","cmp_fname_c2","cmp_lname_c1","cmp_lname_c2","cmp_sex","cmp_bd","cmp_bm","cmp_by","cmp_plz","is_match"


In [11]:
head.filterNot(isHeader).length
//Int=9是因为head.length=10

res6: Int = 9


In [12]:
head.filter(x => !isHeader(x)).length

res7: Int = 9


In [13]:
head.filter(!isHeader(_)).length

res8: Int = 9


In [14]:
val noheader = rawblocks.filter(!isHeader(_))

noheader: org.apache.spark.rdd.RDD[String] = MapPartitionsRDD[4] at filter at <console>:28


In [15]:
noheader.first

res9: String = 607,53170,1,?,1,?,1,1,1,1,1,TRUE


2.7　从RDD到DataFrame

In [16]:
val prev = spark.read.csv("data/linkage/*.csv")

prev: org.apache.spark.sql.DataFrame = [_c0: string, _c1: string ... 10 more fields]


In [17]:
val parsed = spark.read.
        option("header","true").
        option("nullValue","?").
        option("inferSchema","true").
        csv("data/linkage/*.csv")

parsed: org.apache.spark.sql.DataFrame = [id_1: string, id_2: string ... 10 more fields]


In [18]:
parsed.printSchema()

root
 |-- id_1: string (nullable = true)
 |-- id_2: string (nullable = true)
 |-- cmp_fname_c1: string (nullable = true)
 |-- cmp_fname_c2: string (nullable = true)
 |-- cmp_lname_c1: string (nullable = true)
 |-- cmp_lname_c2: string (nullable = true)
 |-- cmp_sex: string (nullable = true)
 |-- cmp_bd: string (nullable = true)
 |-- cmp_bm: string (nullable = true)
 |-- cmp_by: integer (nullable = true)
 |-- cmp_plz: integer (nullable = true)
 |-- is_match: boolean (nullable = true)



In [19]:
parsed.count()

res11: Long = 5749133


一般来说，当数据可能被多个操作依赖时，并且相对于集群可用的内存和磁盘空间而言，如果数据集较小，而且重新生成的代价很高，那么数据就应该被缓存起来。

In [20]:
parsed.rdd.
    map(_.getAs[Boolean]("is_match")).
    countByValue()

res12: scala.collection.Map[Boolean,Long] = Map(true -> 20931, false -> 5728202)


In [21]:
parsed.
    groupBy("is_match").
    count().
    orderBy($"count".desc).
    show()

+--------+-------+
|is_match|  count|
+--------+-------+
|   false|5728201|
|    true|  20931|
|    null|      1|
+--------+-------+



In [22]:
parsed.agg(avg($"cmp_sex"),stddev_samp($"cmp_sex")).show()

+------------------+--------------------+
|      avg(cmp_sex)|stddev_samp(cmp_sex)|
+------------------+--------------------+
|0.9550012294607436|  0.2073014119031234|
+------------------+--------------------+



In [23]:
parsed.createOrReplaceTempView("linkage")

In [24]:
spark.sql("""
    select is_match,count(*) cnt
    from linkage
    group by is_match
    order by cnt Desc
""").show()

+--------+-------+
|is_match|    cnt|
+--------+-------+
|   false|5728201|
|    true|  20931|
|    null|      1|
+--------+-------+



2.9　DataFrame的统计信息

In [26]:
val summary = parsed.describe()

In [27]:
summary.select("summary", "cmp_fname_c1", "cmp_fname_c2").show()

+-------+--------------------+-------------------+
|summary|        cmp_fname_c1|       cmp_fname_c2|
+-------+--------------------+-------------------+
|  count|             5748126|             103699|
|   mean|  0.7129023464249419|  0.900008998936421|
| stddev| 0.38875843950829186|0.27133067681523776|
|    min|                   0|                  0|
|    max|2.68694413843136e-05|                  1|
+-------+--------------------+-------------------+



In [30]:
//调查变量与列 is_match 的值之间的相关性
val matches = parsed.where("is_match=true")
val matchSummary = matches.describe()

val misses = parsed.filter($"is_match"===false)
val missSummary = misses.describe()

matches: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [id_1: string, id_2: string ... 10 more fields]
matchSummary: org.apache.spark.sql.DataFrame = [summary: string, id_1: string ... 10 more fields]
misses: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [id_1: string, id_2: string ... 10 more fields]
missSummary: org.apache.spark.sql.DataFrame = [summary: string, id_1: string ... 10 more fields]


2.10　DataFrame的转置和重塑

In [31]:
summary.printSchema()

root
 |-- summary: string (nullable = true)
 |-- id_1: string (nullable = true)
 |-- id_2: string (nullable = true)
 |-- cmp_fname_c1: string (nullable = true)
 |-- cmp_fname_c2: string (nullable = true)
 |-- cmp_lname_c1: string (nullable = true)
 |-- cmp_lname_c2: string (nullable = true)
 |-- cmp_sex: string (nullable = true)
 |-- cmp_bd: string (nullable = true)
 |-- cmp_bm: string (nullable = true)
 |-- cmp_by: string (nullable = true)
 |-- cmp_plz: string (nullable = true)



In [34]:
summary.schema

res24: org.apache.spark.sql.types.StructType = StructType(StructField(summary,StringType,true), StructField(id_1,StringType,true), StructField(id_2,StringType,true), StructField(cmp_fname_c1,StringType,true), StructField(cmp_fname_c2,StringType,true), StructField(cmp_lname_c1,StringType,true), StructField(cmp_lname_c2,StringType,true), StructField(cmp_sex,StringType,true), StructField(cmp_bd,StringType,true), StructField(cmp_bm,StringType,true), StructField(cmp_by,StringType,true), StructField(cmp_plz,StringType,true))


In [37]:
summary.select("summary", "id_1", "id_2").show()

+-------+--------------------+-------------------+
|summary|                id_1|               id_2|
+-------+--------------------+-------------------+
|  count|             5749133|            5749133|
|   mean|   33324.47979999771|  66587.42400114964|
| stddev|   23659.86139888655|  23620.50188438175|
|    min|0.000235404896421846|0.00147710487444609|
|    max|                9999|              99999|
+-------+--------------------+-------------------+



In [35]:
val schema = summary.schema
val longForm = summary.flatMap(row => {
    val metric = row.getString(0) //获取每行的第一个元素即指标名称
    (1 until row.size).map(i => {
        (metric,schema(i).name,row.getString(i).toDouble)
    })
})

schema: org.apache.spark.sql.types.StructType = StructType(StructField(summary,StringType,true), StructField(id_1,StringType,true), StructField(id_2,StringType,true), StructField(cmp_fname_c1,StringType,true), StructField(cmp_fname_c2,StringType,true), StructField(cmp_lname_c1,StringType,true), StructField(cmp_lname_c2,StringType,true), StructField(cmp_sex,StringType,true), StructField(cmp_bd,StringType,true), StructField(cmp_bm,StringType,true), StructField(cmp_by,StringType,true), StructField(cmp_plz,StringType,true))
longForm: org.apache.spark.sql.Dataset[(String, String, Double)] = [_1: string, _2: string ... 1 more field]


In [40]:
longForm.show(5)

+-----+------------+---------+
|   _1|          _2|       _3|
+-----+------------+---------+
|count|        id_1|5749133.0|
|count|        id_2|5749133.0|
|count|cmp_fname_c1|5748126.0|
|count|cmp_fname_c2| 103699.0|
|count|cmp_lname_c1|5749133.0|
+-----+------------+---------+
only showing top 5 rows



In [42]:
val longDF = longForm.toDF("metric","field","value")
longDF.show(5)

+------+------------+---------+
|metric|       field|    value|
+------+------------+---------+
| count|        id_1|5749133.0|
| count|        id_2|5749133.0|
| count|cmp_fname_c1|5748126.0|
| count|cmp_fname_c2| 103699.0|
| count|cmp_lname_c1|5749133.0|
+------+------------+---------+
only showing top 5 rows



longDF: org.apache.spark.sql.DataFrame = [metric: string, field: string ... 1 more field]


In [43]:
val wideDF = longDF.
    groupBy("field").
    pivot("metric",Seq("count","mean")).
    agg(first("value"))

wideDF: org.apache.spark.sql.DataFrame = [field: string, count: double ... 1 more field]


In [44]:
wideDF.show(5)

+------------+---------+-------------------+
|       field|    count|               mean|
+------------+---------+-------------------+
|        id_2|5749133.0|  66587.42400114964|
|     cmp_plz|5736289.0|0.00552866147434343|
|cmp_lname_c1|5749133.0| 0.3156278513776009|
|cmp_lname_c2|   2465.0|  0.318296744405166|
|     cmp_sex|5749133.0| 0.9550012294607436|
+------------+---------+-------------------+
only showing top 5 rows



In [50]:
//Pivot.scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.first

def pivotSummary(desc: DataFrame): DataFrame = {
    val schema = desc.schema
    import desc.sparkSession.implicits._
    
    val lf = desc.flatMap(row => {
        val metric = row.getString(0)
        (1 until row.size).map(i => {
            (metric,schema(i).name,row.getString(i).toDouble)
        })
    }).toDF("metric","field","value")
    lf.groupBy("field").
        pivot("metric",Seq("count","mean","stddev","min","max")).
        agg(first("value"))
}

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.first
pivotSummary: (desc: org.apache.spark.sql.DataFrame)org.apache.spark.sql.DataFrame


In [51]:
val matchSummaryT = pivotSummary(matchSummary)
val missSummaryT = pivotSummary(missSummary)

matchSummaryT: org.apache.spark.sql.DataFrame = [field: string, count: double ... 4 more fields]
missSummaryT: org.apache.spark.sql.DataFrame = [field: string, count: double ... 4 more fields]


2.11　DataFrame的连接和特征选择
`用 Spark SQL 来表示这些连接会更容易，特别是当待连接的表中有许多列名在两个表中都存在时`

In [52]:
matchSummaryT.createOrReplaceTempView("match_desc")
missSummaryT.createOrReplaceTempView("miss_desc")
spark.sql("""
    select a.field,a.count+b.count as total,a.mean - b.mean as delta
    from match_desc a inner join miss_desc b 
    on a.field = b.field
    where a.field not in ("id_1","id_2")
    order by delta desc, total desc
""").show()

+------------+---------+--------------------+
|       field|    total|               delta|
+------------+---------+--------------------+
|     cmp_plz|5736289.0|  0.9563812499852176|
|cmp_lname_c2|   2464.0|  0.8064147192926266|
|      cmp_by|5748337.0|  0.7762059675300512|
|      cmp_bd|5748337.0|   0.775442311783404|
|cmp_lname_c1|5749132.0|  0.6838772482594513|
|      cmp_bm|5748337.0|  0.5109496938298685|
|cmp_fname_c1|5748125.0|  0.2854529057459947|
|cmp_fname_c2| 103698.0| 0.09104268062280174|
|     cmp_sex|5749132.0|0.032408185250332844|
+------------+---------+--------------------+



2.12　为生产环境准备模型

In [48]:
case class MatchData(
  id_1: Int,
  id_2: Int,
  cmp_fname_c1: Option[Double],
  cmp_fname_c2: Option[Double],
  cmp_lname_c1: Option[Double],
  cmp_lname_c2: Option[Double],
  cmp_sex: Option[Int],
  cmp_bd: Option[Int],
  cmp_bm: Option[Int],
  cmp_by: Option[Int],
  cmp_plz: Option[Int],
  is_match: Boolean
)

defined class MatchData


In [49]:
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types._
val cusSchema = StructType(Array(
    StructField("id_1",IntegerType,true),
    StructField("id_2",IntegerType,true),
    StructField("cmp_fname_c1",DoubleType,true),
    StructField("cmp_fname_c2",DoubleType,true),
    StructField("cmp_lname_c1",DoubleType,true),
    StructField("cmp_lname_c2",DoubleType,true),
    StructField("cmp_sex",IntegerType,true),
    StructField("cmp_bd",IntegerType,true),
    StructField("cmp_bm",IntegerType,true),
    StructField("cmp_by",IntegerType,true),
    StructField("cmp_plz",IntegerType,true),
    StructField("is_match",BooleanType,true)
))

val df = spark.read.
    option("header","true").
    option("nullValue","?").
    schema(cusSchema).
    csv("data/linkage/*.csv")

import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types._
cusSchema: org.apache.spark.sql.types.StructType = StructType(StructField(id_1,IntegerType,true), StructField(id_2,IntegerType,true), StructField(cmp_fname_c1,DoubleType,true), StructField(cmp_fname_c2,DoubleType,true), StructField(cmp_lname_c1,DoubleType,true), StructField(cmp_lname_c2,DoubleType,true), StructField(cmp_sex,IntegerType,true), StructField(cmp_bd,IntegerType,true), StructField(cmp_bm,IntegerType,true), StructField(cmp_by,IntegerType,true), StructField(cmp_plz,IntegerType,true), StructField(is_match,BooleanType,true))
df: org.apache.spark.sql.DataFrame = [id_1: int, id_2: int ... 10 more fields]


In [50]:
val matchData = df.as[MatchData]

matchData: org.apache.spark.sql.Dataset[MatchData] = [id_1: int, id_2: int ... 10 more fields]


In [52]:
matchData.show(2)

+-----+-----+------------+------------+------------+------------+-------+------+------+------+-------+--------+
| id_1| id_2|cmp_fname_c1|cmp_fname_c2|cmp_lname_c1|cmp_lname_c2|cmp_sex|cmp_bd|cmp_bm|cmp_by|cmp_plz|is_match|
+-----+-----+------------+------------+------------+------------+-------+------+------+------+-------+--------+
| 3148| 8326|         1.0|        null|         1.0|        null|      1|     1|     1|     1|      1|    true|
|14055|94934|         1.0|        null|         1.0|        null|      1|     1|     1|     1|      1|    true|
+-----+-----+------------+------------+------------+------------+-------+------+------+------+-------+--------+
only showing top 2 rows



matchData 这个 Dataset 中所有的列和值与 parsed 这个 DataFrame 中的数据是一样的，我们仍然可以对 matchData 使用所有 SQL 风格的 DataFrame API 方法以及 Spark SQL 代码。两者之间的主要区别是，当我们对 matchData 调用函数时，例如 map、flatMap 和 filter，我们处理的是 MatchData 这个 case 类，而不是 Row 类。

In [53]:
case class Score(value: Double){
    def +(oi: Option[Int])={
        Score(value + oi.getOrElse(0))
    }
}

defined class Score


In [58]:
def scoreMatchData(md: MatchData): Double = {
    (Score(md.cmp_lname_c1.getOrElse(0.0)) +
     md.cmp_plz + md.cmp_by + md.cmp_bd + md.cmp_bm).value
}

scoreMatchData: (md: MatchData)Double


In [55]:
⚠️scoreMatchData(matchData)

<console>: 64: error: type mismatch;

In [69]:
val scored = matchData.map{md => 
(scoreMatchData(md),md.is_match)
}.toDF("score","is_match")

scored: org.apache.spark.sql.DataFrame = [score: double, is_match: boolean]


2.13　评估模型

In [110]:
def crossTabs(scored: DataFrame, t: Double): DataFrame = {
  scored.
    selectExpr(s"score >= $t as above", "is_match").
    groupBy("above").
    pivot("is_match", Seq("true", "false")).
    count()
}

crossTabs: (scored: org.apache.spark.sql.DataFrame, t: Double)org.apache.spark.sql.DataFrame


In [73]:
crossTabs(scored, 4.0).show()

org.apache.spark.SparkException:  Job aborted due to stage failure: Task 1 in stage 57.0 failed 1 times, most recent failure: Lost task 1.0 in stage 57.0 (TID 1686, localhost, executor driver): java.lang.ClassCastException: $iw cannot be cast to $iw

### Another Practice

In [103]:
df.printSchema()

root
 |-- id_1: integer (nullable = true)
 |-- id_2: integer (nullable = true)
 |-- cmp_fname_c1: double (nullable = true)
 |-- cmp_fname_c2: double (nullable = true)
 |-- cmp_lname_c1: double (nullable = true)
 |-- cmp_lname_c2: double (nullable = true)
 |-- cmp_sex: integer (nullable = true)
 |-- cmp_bd: integer (nullable = true)
 |-- cmp_bm: integer (nullable = true)
 |-- cmp_by: integer (nullable = true)
 |-- cmp_plz: integer (nullable = true)
 |-- is_match: boolean (nullable = true)



In [107]:
df.createOrReplaceTempView("matchdataframe")

In [108]:
val result = spark.sql("""
    select cmp_lname_c1 + cmp_plz + cmp_by + cmp_bd + cmp_bm as score,
    is_match
    from matchdataframe
""")

result: org.apache.spark.sql.DataFrame = [score: double, is_match: boolean]


In [111]:
crossTabs(result, 4.0).show()

+-----+-----+-------+
|above| true|  false|
+-----+-----+-------+
| null|   35|  13603|
| true|20842|    636|
|false|   54|5713962|
+-----+-----+-------+



- [CASE CLASSES](https://docs.scala-lang.org/zh-cn/tour/case-classes.html)
- [Create a Dataset from an RDD](https://docs.databricks.com/spark/latest/dataframes-datasets/introduction-to-datasets.html)

In [1]:
case class Book(isbn: String)
val frankenstein = Book("978-0486282114")
frankenstein.isbn

Intitializing Scala interpreter ...

Spark Web UI available at http://172.16.8.92:4040
SparkContext available as 'sc' (version = 2.4.4, master = local[*], app id = local-1582682941971)
SparkSession available as 'spark'


defined class Book
frankenstein: Book = Book(978-0486282114)


//Create a Dataset from a DataFrame


In [47]:
import org.apache.spark.sql.DataFrame
import spark.implicits._

case class Company(name: String, foundingYear: Int, numEmployees: Int)
val inputSeq = Seq(Company("ABC",1980,310),
                   Company("XYZ",1983,904),
                   Company("NOP",2005,83))

val df = inputSeq.toDS().toDF()
// val df = sc.parallelize(inputSeq).toDF() //←报错
val companyDS = df.as[Company]
companyDS.show()

+----+------------+------------+
|name|foundingYear|numEmployees|
+----+------------+------------+
| ABC|        1980|         310|
| XYZ|        1983|         904|
| NOP|        2005|          83|
+----+------------+------------+



import org.apache.spark.sql.DataFrame
import spark.implicits._
defined class Company
inputSeq: Seq[Company] = List(Company(ABC,1980,310), Company(XYZ,1983,904), Company(NOP,2005,83))
df: org.apache.spark.sql.DataFrame = [name: string, foundingYear: int ... 1 more field]
companyDS: org.apache.spark.sql.Dataset[Company] = [name: string, foundingYear: int ... 1 more field]
