In [1]:
import pyspark,ogr,datetime

### 读取中国每个省

In [2]:
# Open the province polygons read-only (update flag 0) with OGR's shapefile driver.
driver = ogr.GetDriverByName('ESRI Shapefile')
dataSource = driver.Open("./data/cn.shp", 0)
# Open() hands back None on failure unless OGR exceptions are enabled; stop
# here with a clear message instead of an AttributeError on the next line.
if dataSource is None:
    raise RuntimeError("Could not open ./data/cn.shp")
layer = dataSource.GetLayerByIndex(0)

## 构造一个用来计算的几何结构。注意这里保存的是 WKT 字符串，因为 GDAL 的几何对象无法被 Python 序列化（pickle），不能直接放进列表里传给 Spark

In [3]:
# Snapshot every feature of the layer as a (FIRST_NAME value, ISO WKT) tuple
# of plain strings.
layer.ResetReading()  # rewind the layer cursor before iterating
featArr = [
    (feature.GetField("FIRST_NAME"), feature.geometry().ExportToIsoWkt())
    for feature in layer
]

In [4]:
# Show the first entry's name and the first 300 characters of its WKT.
first_name, first_wkt = featArr[0]
print(first_name, first_wkt[:300])

北京 POLYGON ((117.201635735 40.077285175,117.188380491 40.0635675290001,117.172605018 40.047241738,117.17614785 40.059573125,117.176306955 40.0598749100001,117.177471185 40.0619996300001,117.182759255 40.0684182210001,117.17511575 40.0703352350001,117.091653325 40.075145455,117.079871615 40.075415475000


In [5]:
# getOrCreate() reuses an already-running context, so re-executing this cell
# does not fail with "Cannot run multiple SparkContexts at once".
sc = pyspark.SparkContext.getOrCreate()

In [6]:
def isPoint(line):
    """Return True when a CSV line carries a usable point in columns 2 and 3.

    The line is split on commas; fields[2] and fields[3] are parsed as floats
    and turned into a WKT POINT that OGR must accept and report as valid.

    Args:
        line: one comma-separated text record.

    Returns:
        bool: False for lines with too few fields, non-numeric coordinates,
        or a point OGR cannot build or considers invalid; True otherwise.
    """
    fields = line.split(",")
    # Only the expected parse failures are treated as "not a point". A bare
    # except would also hide KeyboardInterrupt and environment errors.
    try:
        x = float(fields[2])
        y = float(fields[3])
    except (IndexError, ValueError):
        return False
    try:
        geom = ogr.CreateGeometryFromWkt("POINT({0} {1})".format(x, y))
    except RuntimeError:
        # Raised instead of returning None when OGR exceptions are enabled.
        return False
    return geom is not None and geom.IsValid()

In [7]:
# Load the CSV as an RDD of text lines and keep only those isPoint accepts.
rdd = sc.textFile("./data/eq2013.csv").filter(isPoint)

In [8]:
# Peek at the first five lines that passed the isPoint filter.
rdd.take(5)

['2013/5/31,23:41:56,-113.408,37.175,6.6,2.5,ML2.5,SLC,,UTAH,',
 '2013/5/31,23:09:05,-113.411,37.178,6,2.5,ML2.5,SLC,,UTAH,',
 '2013/5/31,22:45:34,-113.413,37.172,4,2.9,ML2.9,SLC,,UTAH,',
 '2013/5/31,22:34:26,-113.414,37.174,3.2,2.8,ML2.8,SLC,,UTAH,',
 '2013/5/31,22:34:02,-178.08,51.127,26,3.1,ML3.1,AEIC,,ANDREANOF ISLANDS, ALEUTIAN IS.']

### map方法，构造一个结构：
### 如果这个省包含了点，则为(省名，数量1)
### 不在中国任何一个省里面，则为(other，数量1)

In [9]:
def myMap(pnt,featArr):
    """Tag one CSV point record with the first polygon that contains it.

    Args:
        pnt: comma-separated record; fields 2 and 3 are parsed as the point's
            x and y coordinates.
        featArr: sequence of (name, WKT) pairs describing the polygons.

    Returns:
        tuple: (name, 1) for the first polygon in featArr whose geometry
        contains the point, or ("other", 1) when none does.

    Raises:
        IndexError, ValueError: if the record lacks numeric fields 2 and 3
            (only when featArr is non-empty).
    """
    # Nothing to test against: same result as looping over an empty list.
    if not featArr:
        return ("other", 1)

    # The point does not depend on the polygon, so build it once instead of
    # re-splitting and re-parsing the record for every feature.
    fields = pnt.split(",")
    pntGeom = ogr.CreateGeometryFromWkt(
        "POINT({0} {1})".format(float(fields[2]), float(fields[3])))

    for feat in featArr:
        geo = ogr.CreateGeometryFromWkt(feat[1])
        if geo.Contains(pntGeom):
            return (feat[0], 1)
    return ("other", 1)

In [10]:
# Turn every record into a (name-or-"other", 1) pair via myMap.
maprdd = rdd.map(lambda record: myMap(record, featArr))

In [11]:
# Inspect the first 100 (key, 1) pairs produced by myMap.
print(maprdd.take(100))

[('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('西藏', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other', 1), ('other'

### reduceByKey是分组聚合

In [12]:
# Sum the 1s per key, pull the totals to the driver, and report the wall time.
started = datetime.datetime.now()
res = maprdd.reduceByKey(lambda left, right: left + right).collect()
elapsed = datetime.datetime.now() - started
print(elapsed)

0:00:48.190894


### 计算结果

In [13]:
# Per-key totals collected by the reduceByKey step.
print(res)

[('other', 8042), ('云南', 10), ('四川', 69), ('青海', 17), ('山东', 1), ('贵州', 1), ('西藏', 23), ('新疆', 20), ('内蒙古', 4), ('台湾', 12), ('广东', 1), ('广西', 1), ('甘肃', 2), ('辽宁', 1)]


# 先把不是中国范围内的数据过滤掉

### 用gdal获取到中国的extent

In [16]:
# Bounding box of the province layer. filterExtent indexes it as
# (min x, max x, min y, max y).
cnext = layer.GetExtent()
cnext

(73.61689383500004, 135.08727119000002, 18.278927775000056, 53.56026110500005)

### 写一个过滤方法，过滤中国范围以外的数据

In [17]:
def filterExtent(line,cnext):
    """Tell whether a CSV record's point lies inside a bounding box.

    Args:
        line: comma-separated record; fields 2 and 3 are parsed as x and y.
        cnext: indexable extent read as (min x, max x, min y, max y).

    Returns:
        bool: True when the point is inside the box, edges included.
    """
    fields = line.split(",")
    lon, lat = float(fields[2]), float(fields[3])
    return cnext[0] <= lon <= cnext[1] and cnext[2] <= lat <= cnext[3]

In [19]:
# Drop records outside the layer's bounding box first, then run the
# per-polygon containment test only on what is left.
maprdd2 = (
    rdd.filter(lambda record: filterExtent(record, cnext))
       .map(lambda record: myMap(record, featArr))
)

### 速度显著提升

In [20]:
# Time the same aggregation on the extent-filtered RDD (maprdd2). Aggregating
# the unfiltered maprdd here would only repeat the earlier full-dataset run.
s = datetime.datetime.now()
res = maprdd2.reduceByKey(lambda x,y : x+y).collect()
print(datetime.datetime.now() -s)

0:00:06.146958


In [21]:
# Per-key totals from the most recent reduceByKey run.
print(res)

[('other', 232), ('云南', 10), ('四川', 69), ('青海', 17), ('山东', 1), ('贵州', 1), ('西藏', 23), ('新疆', 20), ('内蒙古', 4), ('台湾', 12), ('广东', 1), ('广西', 1), ('甘肃', 2), ('辽宁', 1)]
