In [1]:
import os
os.environ['PYSPARK_PYTHON'] = '/nfshome/lj1230/.conda/envs/myEnv/bin/python3.5'

from pyspark import SparkContext
sc = SparkContext('local', 'pyspark')

from pyspark.sql.session import SparkSession
spark = SparkSession(sc)

In [49]:
def createIndex(shapefile):
    import rtree
    import fiona.crs
    import geopandas as gpd
    zones = gpd.read_file(shapefile).to_crs(fiona.crs.from_epsg(5070))
    index = rtree.Rtree()
    for idx, geometry in enumerate(zones.geometry):
        index.insert(idx, geometry.bounds)
    return (index, zones)

def findZone(p, index, zones):
    match = index.intersection((p.x, p.y, p.x, p.y))
    for idx in match:
        if zones.geometry[idx].contains(p):
            return zones.plctract10[idx], zones.plctrpop10[idx]
    return None

def processTweets(pid, records):
    import re
    import pyproj
    import shapely.geometry as geom
    
    pattern = re.compile("\w+")
    proj = pyproj.Proj(init="epsg:5070", preserve_units=True)
    index, zones = createIndex("500cities_tracts.geojson")
    drug_set = set(open('drug_illegal.txt', 'r').read().split("\n")) | set(open('drug_sched2.txt', 'r').read().split("\n"))
    drug_wor = {e for e in drug_set if " " not in e}
    drug_pha = {e for e in drug_set if " " in e}
    
    counts = {}
    for record in records:
        flag = 0
        row = record.strip().split(",")
        if len(set(row[-1].split(" ")) & drug_wor) > 0:  # First check words
            flag = 1
        else:  # if no words then check phrases
            try:
                words = pattern.findall(row[-2].lower())
            except:
                continue
            length = len(words)
            if length > 1:
                phrases = set()
                for i in range(2, min(9, length + 1)):  # Longest length of possible phrases is 8
                    for j in range(len(words) - i + 1):
                        phrases.add(" ".join(words[j:j + i]))
                if len(phrases & drug_pha) > 0:
                    flag = 1
        if flag == 1:
            try:
                p = geom.Point(proj(float(row[3]), float(row[2])))
                zone_id, zone_pop = findZone(p, index, zones)
            except:
                continue
            if zone_id and zone_pop > 0:
                counts[zone_id] = counts.get(zone_id, 0.0) + 1.0 / zone_pop
    return counts.items()

if __name__=="__main__":
    counts = sc.textFile("tweets-sample.csv").mapPartitionsWithIndex(processTweets) \
                                             .reduceByKey(lambda x, y: x + y).sortBy(lambda x: x[0]).collect()
counts

[('0107000-01073005800', 0.0009487666034155598),
 ('0137000-01089000202', 0.00022732439190725165),
 ('0137000-01089001402', 0.00020627062706270627),
 ('0137000-01089002801', 0.00025680534155110427),
 ('0137000-01089010622', 0.00022922636103151864),
 ('0137000-01089010702', 0.001366120218579235),
 ('0137000-01089011200', 0.0001590836780146357),
 ('0150000-01097000500', 0.0005509641873278236),
 ('0150000-01097002900', 0.0002727024815925825),
 ('0150000-01097003100', 0.0002288329519450801),
 ('0150000-01097003205', 0.000281135788585887),
 ('0150000-01097003605', 0.0012642225031605564),
 ('0150000-01097003606', 0.00032),
 ('0150000-01097003607', 0.0005256241787122207),
 ('0151000-01101001400', 0.000552791597567717),
 ('0151000-01101001500', 0.0005319148936170213),
 ('0151000-01101002400', 0.00099601593625498),
 ('0151000-01101002600', 0.000167897918065816),
 ('0151000-01101002800', 0.00047393364928909954),
 ('0151000-01101003302', 0.00015683814303638644),
 ('0151000-01101005402', 0.0001768