In [1]:
import pyspark
# sc = pyspark.SparkContext(pyspark.SparkConf().setMaster("local[12]"))
import numpy as np
import time
from collections import namedtuple
# Key of one tile. `type` is "out" for target (output) tiles and "in" for
# source (input) tiles; x / y give the tile's position within image img_id.
tile_id = namedtuple('tile_id', 'img_id z x y type')
# One unit of work: the request it belongs to, the target tile being built, the
# source tile it reads, plus payload slots that later pipeline stages fill in.
tile_item = namedtuple(
    'tile_item',
    'request target_tile_id source_tile_id source_tile_data out_tile_data',
)

In [2]:
# Full-scale setting (slow): IM_W, IM_H = 60000, 60000 -- the demo uses a smaller image.
IM_W, IM_H = 6000, 6000
TILE_W, TILE_H = 2000, 2000
# Output tiles per image. Ceiling division so a partial edge tile is still counted,
# matching the np.arange(0, IM_W, TILE_W) grid used by get_output_tiles.
TILES_PER_IMAGE = ((IM_W + TILE_W - 1) // TILE_W) * ((IM_H + TILE_H - 1) // TILE_H)

def stack(arr_list, axis = 0):
    """
    Stack equally shaped arrays along a new leading axis (like ``np.stack(arr_list, 0)``).

    Written because numpy 1.8.2 does not have the stack command.

    arr_list: any iterable of arrays (list, tuple, generator, Python 3 ``map`` object, ...)
    axis: only 0 is supported
    returns: for arrays with ndim >= 1, an array of shape (n,) + arr.shape
    raises: AssertionError if axis != 0
    """
    assert axis == 0, "Only works for axis 0"
    # np.vstack needs a real sequence: NumPy >= 1.16 deprecates generators / map
    # objects here and later versions raise TypeError, so build a list explicitly.
    return np.vstack([np.expand_dims(x, 0) for x in arr_list])

def get_output_tiles(req_str):
    """
    Enumerate every output tile of the image named ``req_str``.

    Returns a list of ``(req_str, tile_id)`` pairs, one per grid cell of size
    TILE_W x TILE_H over IM_W x IM_H, each with ``z=-1`` and ``type="out"``.
    x varies slowest, y fastest.
    """
    x_starts = np.arange(0, IM_W, TILE_W)
    y_starts = np.arange(0, IM_H, TILE_H)
    pairs = []
    for x in x_starts:
        for y in y_starts:
            pairs.append((req_str, tile_id(req_str, -1, x, y, "out")))
    return pairs

def calc_input_tiles(o_tile, n_size = 1, delay = 0):
    """
    List the source tiles needed to reconstruct output tile ``o_tile``.

    The neighbourhood is the (2*n_size+1) x (2*n_size+1) block of tiles centred on
    ``o_tile`` (3x3 for the default ``n_size=1``), clipped to the image bounds
    ``[0, IM_W) x [0, IM_H)``.

    o_tile: tile_id of the output tile
    n_size: neighbourhood radius, in tiles
    delay: seconds to sleep first, simulating the cost of computing the neighbourhood
    returns: list of ``(o_tile, source_tile_id)`` pairs; each source tile is ``o_tile``
        with ``type="in"`` and shifted x / y
    """
    time.sleep(delay)
    x_starts = np.arange(o_tile.x - n_size * TILE_W, o_tile.x + (n_size + 1) * TILE_W, TILE_W)
    y_starts = np.arange(o_tile.y - n_size * TILE_H, o_tile.y + (n_size + 1) * TILE_H, TILE_H)
    # Clip on both sides: tiles past the right/bottom edge lie outside the image just
    # like tiles with negative offsets (previously only the lower bound was checked).
    return [(o_tile, o_tile._replace(type = "in", x = ix, y = iy))
            for ix in x_starts if 0 <= ix < IM_W
            for iy in y_starts if 0 <= iy < IM_H]

def pull_input_tile(i_tile, delay = 2.5):
    """
    Simulate reading source tile ``i_tile`` and return random but repeatable data.

    Values are drawn uniformly from [-127, 127) and truncated to int16, so they fit
    in the signed 8-bit range. The result has shape (TILE_W, TILE_H, 4).

    i_tile: tile_id whose ``type`` starts with "in"
    delay: seconds to sleep first, simulating I/O latency
    raises: AssertionError if i_tile is not an input tile
    """
    assert i_tile.type.startswith("in")
    time.sleep(delay)
    # Seed from the tile id so the same tile always yields the same data. A private
    # RandomState produces exactly the values np.random.seed + np.random.uniform did,
    # without resetting the process-global RNG as a side effect.
    # NOTE(review): hash() of a tuple containing str is salted per process in Python 3,
    # so this is only reproducible across processes when PYTHONHASHSEED is fixed.
    rng = np.random.RandomState(hash(i_tile) % 4294967295)
    return rng.uniform(-127, 127, size = (TILE_W, TILE_H, 4)).astype(np.int16)

RECON_TIME = 1.0  # simulated seconds spent on one partial reconstruction

def partial_reconstruction(src_tile, targ_tile, tile_data, delay = RECON_TIME):
    """
    Simulated partial reconstruction of ``targ_tile`` from one source tile.

    Sleeps ``delay`` seconds, then returns a copy of ``tile_data`` with every value
    greater than 20 set to 0. ``src_tile`` and ``targ_tile`` are accepted but unused.
    The input array itself is not modified.
    """
    time.sleep(delay)
    recon = tile_data.copy()
    too_bright = recon > 20
    recon[too_bright] = 0
    return recon

def combine_reconstructions(many_slices):
    """
    Merge the partial reconstructions of one target tile by summing them element-wise.

    many_slices: iterable of equally shaped arrays
    returns: their element-wise sum (stacked on a new axis 0, then reduced over it)
    """
    stacked = stack(many_slices, 0)
    return stacked.sum(axis = 0)

def full_reconstruction(src_tiles, pr_delay = RECON_TIME/2):
    """
    Reconstruct one target tile in a single step from all of its source items.

    src_tiles: iterable of tile_item whose ``source_tile_data`` is filled in
    pr_delay: simulated delay passed to every partial_reconstruction call
    returns: element-wise sum of the partial reconstructions
    """
    partials = []
    for item in src_tiles:
        partials.append(partial_reconstruction(item.source_tile_id,
                                               item.target_tile_id,
                                               item.source_tile_data,
                                               delay = pr_delay))
    return combine_reconstructions(partials)
    

In [3]:
# one request string per image; each request expands to its grid of output tiles
req_rdd = sc.parallelize(['test1', 'test2'])
out_tile_rdd = req_rdd.flatMap(get_output_tiles).repartition(100)
# expand every output tile into (request, (output tile, input tile)) pairs and wrap them as tile_items
all_tile_rdd = (out_tile_rdd
                .flatMapValues(calc_input_tiles)
                .map(lambda kv: tile_item(request = kv[0],
                                          target_tile_id = kv[1][0],
                                          source_tile_id = kv[1][1],
                                          source_tile_data = None,
                                          out_tile_data = None)))

In [4]:
first_tile_item = all_tile_rdd.first()
n_tile_items = all_tile_rdd.count()
print(first_tile_item, n_tile_items)

# Simple / Naive Approach
The simple (naive) approach makes no attempt to group together operations that read the same input tile. It reads the input tile separately for every (output tile, input tile) pair and reconstructs each output tile independently.

In [6]:
# parallel reading of the data: every tile_item pulls its own copy of its source tile
all_tile_rdd_data = all_tile_rdd.map(
    lambda item: item._replace(source_tile_data = pull_input_tile(item.source_tile_id)))
def ti_full_reconstruction(x):
    """Map a ``(target_tile_id, tile_items)`` group to ``(target_tile_id, reconstructed tile)``."""
    target_key, grouped_items = x
    reconstructed = full_reconstruction(grouped_items)
    return target_key, reconstructed
# parallel combining of the tiles: gather all items of each target tile, then rebuild it
recon_tiles_rdd = (all_tile_rdd_data
                   .groupBy(lambda item: item.target_tile_id)
                   .map(ti_full_reconstruction))

In [7]:
%%time
all_shapes = recon_tiles_rdd.mapValues(lambda x: x.shape).collect()
print('All Results', len(all_shapes))

# Grouped Read Approach

In [9]:
def grp_tile_read(x):
    """
    Read one source tile once and attach the data to every tile_item that needs it.

    x: ``(source_tile_id, tile_items)`` as produced by groupBy on source_tile_id
    returns: list of those tile_items with ``source_tile_data`` filled in
    """
    src_tile_id, tile_items = x
    shared_data = pull_input_tile(src_tile_id)
    filled = []
    for item in tile_items:
        filled.append(item._replace(source_tile_data = shared_data))
    return filled
# read every source tile exactly once, then regroup the items by target tile
single_read_tiles_rdd = (all_tile_rdd
                         .groupBy(lambda item: item.source_tile_id)
                         .flatMap(grp_tile_read))
gr_recon_tiles_rdd = (single_read_tiles_rdd
                      .groupBy(lambda item: item.target_tile_id)
                      .map(ti_full_reconstruction))

In [10]:
%%time
all_shapes = gr_recon_tiles_rdd.mapValues(lambda x: x.shape).collect()
print('All Results', len(all_shapes))

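# Third Approach: Partial Reconstructions
Partial results to final results: read each source tile once, compute a partial reconstruction for every (source tile, target tile) pair, and then sum the partial results into the final target tiles.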

In [12]:
def ti_partial_reconstruction(in_tile_item):
    """
    Run partial_reconstruction on a single tile_item.

    Returns a copy of the item with ``out_tile_data`` set to the partial result and
    ``source_tile_data`` cleared, so the input data is not carried further down the
    pipeline.
    """
    partial = partial_reconstruction(in_tile_item.source_tile_id,
                                     in_tile_item.target_tile_id,
                                     in_tile_item.source_tile_data)
    return in_tile_item._replace(source_tile_data = None, out_tile_data = partial)

def ti_partial_collect(x):
    """
    Combine all partial reconstructions belonging to one target tile.

    x: ``(target_tile_id, tile_items)`` as produced by groupBy on target_tile_id;
        each item's ``out_tile_data`` holds one partial reconstruction
    returns: ``(target_tile_id, combined tile)``
    raises: AssertionError if the group is empty
    """
    k, in_tile_items = x
    # Materialise once: works for any iterable (not only ones with len()), and the
    # result is reused below instead of being computed and then thrown away.
    iti_list = list(in_tile_items)
    assert len(iti_list) > 0, "Cannot provide empty partial collection set"
    # a real list rather than a lazy Python 3 map object
    all_part_recon = [item.out_tile_data for item in iti_list]
    return (k, combine_reconstructions(all_part_recon))

# group together all items by source tile, read that source tile once and put it into every item
single_read_tiles_rdd = (all_tile_rdd
                         .groupBy(lambda item: item.source_tile_id)
                         .flatMap(grp_tile_read))
# run a partial reconstruction on every item
partial_recon_tiles_rdd = single_read_tiles_rdd.map(ti_partial_reconstruction)
# combine all the partial reconstructions to create the final reconstruction
full_recon_tiles_rdd = (partial_recon_tiles_rdd
                        .groupBy(lambda item: item.target_tile_id)
                        .map(ti_partial_collect))

In [13]:
%%time
all_shapes = full_recon_tiles_rdd.mapValues(lambda x: x.shape).collect()
print('All Results',len(all_shapes))

In [14]:
# Check that the partial-reconstruction pipeline reproduces the naive pipeline for one
# target tile. The tile is looked up by key: first() on two differently shuffled RDDs
# need not return the same target tile. assert_array_equal also checks shapes (zip did
# not) and avoids a Python-level loop over every pixel.
ref_key, partial_tile = full_recon_tiles_rdd.first()
naive_tile = recon_tiles_rdd.filter(lambda kv: kv[0] == ref_key).first()[1]
np.testing.assert_array_equal(partial_tile, naive_tile)

# Compressing Tiles and Output
Here we compress each reconstructed tile, group the tiles by image, and then write the output.

In [16]:
from io import BytesIO
def compress_tile(in_tile):
    """
    Serialise ``in_tile`` as a compressed ``.npz`` archive held in memory.

    The array is stored under the name ``out_tile``; restore it with
    ``np.load(BytesIO(blob))['out_tile']``.

    returns: the raw archive bytes
    """
    out_stream = BytesIO()
    np.savez_compressed(out_stream, out_tile = in_tile)
    # Return the bytes verbatim. readlines() keeps each line ending, so re-joining the
    # pieces with "\n" doubled every newline byte and corrupted the zip data; on
    # Python 3, str.join cannot join bytes at all.
    return out_stream.getvalue()


# compress every reconstructed tile, then gather the tiles of each image (keyed by img_id)
all_image_rdd = (full_recon_tiles_rdd
                 .mapValues(compress_tile)
                 .groupBy(lambda kv: kv[0].img_id))

In [17]:
%%time
all_image_rdd.saveAsPickleFile('big_image_out.jp2')

# Light Keys and Heavy Values
## DataFrame vs RDD

In [19]:
# Make a cached performant RDD to start with: 101 input keys with x == y == 0..100
def x_to_tile(x):
    """Build an input tile_id whose x and y are both ``int(x)``."""
    return tile_id("", 0, int(x), int(x), "in")

keys_rdd = sc.parallelize(range(101), 20).map(x_to_tile).cache()
_ = keys_rdd.collect()  # run one action so the cached partitions are materialised

# Using Standard RDDs
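With a standard RDD, the key and the image travel together in one record, so every image has to be loaded (here, `pull_input_tile` runs for every key) before the filter on the `x` value can be applied.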

In [21]:
# pair every key with its (simulated, slow) image data
kimg_rdd = keys_rdd.map(lambda key: (key, pull_input_tile(key)))

In [22]:
%%time
kimg_rdd.filter(lambda kv_data: kv_data[0].x>99).collect()

# Using DataFrames
With DataFrames, the exact same query can be answered by filtering on the light `x` column without looking at the image column at all. Here the image column is only evaluated at the very end.

In [24]:
from pyspark.sql import functions as F
import pyspark.sql.types as sq_types
kmeta_df = sqlContext.createDataFrame(keys_rdd.map(lambda x: x._asdict()))
# applying python functions to DataFrames is more difficult and requires using typed UDFs
# pull_input_tile returns a 3-D (TILE_W, TILE_H, 4) array, so .tolist() yields three levels
# of nesting; a 2-level ArrayType leaves lists where IntegerType values are expected.
# (the variable name is kept for compatibility with other cells)
twod_arr_type = sq_types.ArrayType(sq_types.ArrayType(sq_types.ArrayType(sq_types.IntegerType())))
# the pull_input_tile function is wrapped into a udf so it can be applied to create the new image column
# numpy data is not directly supported and typed arrays must be used instead, therefore we run the .tolist command
pull_tile_udf = F.udf(lambda x: pull_input_tile(x_to_tile(x)).tolist(), returnType = twod_arr_type)
kimg_df = kmeta_df.withColumn('Image', pull_tile_udf(kmeta_df['x']))

In [25]:
%%time
s_query = kimg_df.where(kimg_df['x']>99)
s_query.show()

In [26]:
%%time
kimg_df.where(kimg_df['x']==27).show()

In [27]:
# show the array to make sure it matches
first_row = s_query.first()
iv_arr = np.array(first_row.Image)
iv_arr

In [28]:
# test cartesian product later kmeta_df.join(s_query, on = ["x", "y"], how = 'inner').show()

In [29]:
# print the query plan, to see where the filter on x sits relative to the Image UDF
s_query.explain()

In [30]:
pt_2d_type = sq_types.StructType(fields = [sq_types.StructField("_1", sq_types.IntegerType()),
                                           sq_types.StructField("_2", sq_types.IntegerType())])

def _brightest_point(img):
    """
    Index of the largest value in ``img`` as a 2-tuple of plain Python ints.

    For the 3-D (W, H, channel) tiles, only the first two (spatial) indices are kept,
    so the result matches the two fields of pt_2d_type. The indices are converted to
    int because Spark's UDF result conversion does not accept numpy integer scalars.
    """
    flat_idx = np.argmax(img)
    # shape passed positionally: the keyword was `dims` in old NumPy and `shape` in new
    idx = np.unravel_index(flat_idx, np.shape(img))
    return tuple(int(i) for i in idx[:2])

brightest_point_udf = F.udf(_brightest_point, returnType = pt_2d_type)
mean_point_udf = F.udf(lambda x: float(np.mean(x)), returnType = sq_types.DoubleType())
kimg_max_df = kimg_df.withColumn('MeanPoint', mean_point_udf(kimg_df['Image']))

In [31]:
%%time
four_img_query = kimg_max_df.where(kimg_max_df['x']>88).where(kimg_max_df['y']<92)
four_img_query.show()

In [32]:
# print the query plan, to see where the x / y filters sit relative to the UDF columns
four_img_query.explain()

In [33]:
# expose kimg_max_df to SQL queries as "ImageTable"
# (registerTempTable is the Spark 1.x API; Spark 2+ prefers createOrReplaceTempView)
kimg_max_df.registerTempTable("ImageTable")

In [34]:
# expose the metadata-only DataFrame (no Image column) to SQL queries as "MetaTable"
kmeta_df.registerTempTable("MetaTable")