# Update batching

This notebook explores the concept of update batching using a very simple example.

First, include some libraries.

In [None]:
# Begin - startup boilerplate code

import pkgutil

if 'fibertree_bootstrap' not in [pkg.name for pkg in pkgutil.iter_modules()]:
  !python3 -m pip  install git+https://github.com/Fibertree-project/fibertree-bootstrap --quiet

# End - startup boilerplate code


from fibertree_bootstrap import *
fibertree_bootstrap(style="tree", animation="movie")

## Create a simple graph

The graph consists of a vertex vector (__vtx__) and an adjacency matrix (__g_SD__). 

In [None]:
S = 10
D = S

vtx = Tensor.fromRandom(["S"], [S], (1.0,), 9, seed=10)
#vtx = Tensor.fromUncompressed(["S"], [1]*S)
vtx.setMutable(True).setColor("blue")
vtx.setName("vertices")

g_SD = Tensor.fromRandom(["S", "D"], [S, D], (1.0, 0.28), 1, seed=100)
g_SD.setMutable(True).setColor("green")
g_SD.setName("graph-SD")

displayTensor(vtx)
displayTensor(g_SD)
                         
print("Graph")
displayGraph(g_SD.getRoot())


## Basic algorithm - Source Stationary

In this notebook, we use a simple example in which the value of each vertex (in __vtx__) is updated with the sum of the values of all its incoming neighbor vertices. The following source-stationary (push) dataflow does that via a concordant traversal of the adjacency matrix (__g_SD__). Note the widely scattered writes to the new vertices (in __vtx_new__), i.e., a discordant traversal.

In [None]:
vtx_new = Tensor(rank_ids=["S"])
vtx_new.setColor("blue")
vtx_new.setName("vertices_new")

vtx_s = vtx.getRoot()
g_SD_s = g_SD.getRoot()
vtx_new_d = vtx_new.getRoot()

canvas = createCanvas(vtx, g_SD, vtx_new)

for s, (vtx_val, g_SD_d) in vtx_s & g_SD_s:
    for d, (vtx_new_ref, _) in vtx_new_d << g_SD_d:
        vtx_new_ref += vtx_val
        canvas.addFrame((s, ), (s, d), (d,))

displayTensor(vtx_new)
displayCanvas(canvas)
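
For readers without fibertree at hand, the push dataflow can be approximated with plain dictionaries. This is only a sketch: the toy graph, the `push_update` helper and the `write_order` list are hypothetical and are not part of the fibertree API. It records the order in which destinations are written, which makes the scattered (discordant) writes visible.

```python
# Hypothetical toy data (not the random tensors used in the notebook).
vtx = {0: 1, 1: 2, 2: 3, 3: 4}                  # vertex id -> value
g_sd = {0: [1, 3], 1: [2], 2: [0, 3], 3: [1]}   # source -> destinations


def push_update(vtx, g_sd):
    """Source-stationary sketch: read each source once, scatter to its destinations."""
    vtx_new = {}
    write_order = []
    # Concordant walk over sources present in both the vertex vector and the graph
    for s in sorted(vtx.keys() & g_sd.keys()):
        for d in g_sd[s]:
            vtx_new[d] = vtx_new.get(d, 0) + vtx[s]
            write_order.append(d)   # destinations arrive in no particular order
    return vtx_new, write_order


vtx_new, write_order = push_update(vtx, g_sd)
print(vtx_new)
print(write_order)
```

The write order jumps back and forth across the destination ids, which is the access pattern that update batching is meant to tame.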

## Basic algorithm - Destination stationary

Just for the record, here is the same algorithm in a destination-stationary (pull) form. To facilitate concordant traversal of the adjacency matrix, we create a rank-swapped version (__g_DS__). Note the discordant traversal of the original vertex vector (__vtx__) for this dataflow.

In [None]:
# Create destination-to-source adjacency matrix

g_DS = g_SD.swapRanks()
g_DS.setName("graph-DS")

displayTensor(g_DS)

In [None]:
vtx_new = Tensor(rank_ids=["S"])
vtx_new.setColor("blue")
vtx_new.setName("vertices_new")

vtx_s = vtx.getRoot()
g_DS_d = g_DS.getRoot()
vtx_new_d = vtx_new.getRoot()

canvas = createCanvas(vtx, g_DS, vtx_new)

for d, (vtx_new_ref, g_DS_s) in vtx_new_d << g_DS_d:
    for s, (vtx_val, _) in vtx_s & g_DS_s:
        
        vtx_new_ref += vtx_val
        canvas.addFrame((s, ), (d, s), (d,))

displayTensor(vtx_new)
displayCanvas(canvas)
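
The pull form can be sketched with plain dictionaries as well. The helpers `swap_ranks` and `pull_update` and the toy graph below are hypothetical stand-ins, not fibertree calls. The sketch shows that each destination is written exactly once, while the reads of the source values become the scattered accesses.

```python
# Hypothetical toy data (not the random tensors used in the notebook).
vtx = {0: 1, 1: 2, 2: 3, 3: 4}                  # vertex id -> value
g_sd = {0: [1, 3], 1: [2], 2: [0, 3], 3: [1]}   # source -> destinations


def swap_ranks(g_sd):
    """Build a destination -> sources map from a source -> destinations map."""
    g_ds = {}
    for s in sorted(g_sd):
        for d in g_sd[s]:
            g_ds.setdefault(d, []).append(s)
    return dict(sorted(g_ds.items()))


def pull_update(vtx, g_ds):
    """Destination-stationary sketch: each destination gathers from its sources."""
    vtx_new = {}
    read_order = []
    for d, sources in g_ds.items():
        total = 0
        for s in sources:
            if s in vtx:                # only sources that carry a value contribute
                total += vtx[s]
                read_order.append(s)    # source reads are the scattered accesses
        vtx_new[d] = total
    return vtx_new, read_order


g_ds = swap_ranks(g_sd)
vtx_new, read_order = pull_update(vtx, g_ds)
print(g_ds)
print(vtx_new)
```

On this toy graph the result matches what a push traversal of the same edges would produce, since both forms sum the same contributions per destination.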

## Functions used in the following steps

To generalize the implementation of algorithms that use update batching, the dataflows call these functions to calculate a value to be reduced (__generateShard()__) and a value to do a final update of each vertex (__updateValue()__). The very simple functions below implement the simple algorithm described above, and could be replaced to implement different algorithms. 

In [None]:

#
# The function for the first step
#
def generateShard(vtx_val):
    """Create a value to be reduced later"""
    
    return vtx_val

#
# The function for the final step
#

def updateValue(vtx_val_old, vtx_val_new):
    """Update the vertex value"""
    
    return vtx_val_new
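
As an illustration of how these two hooks might be swapped out, here is a hypothetical pair with the same signatures. The names `generateShardHalved` and `updateValueAccumulate` and the behavior they implement are assumptions made for this sketch, not part of the notebook's algorithm, and they are exercised here with plain numbers only.

```python
def generateShardHalved(vtx_val):
    """Hypothetical variant: each source contributes half of its value."""
    return vtx_val / 2


def updateValueAccumulate(vtx_val_old, vtx_val_new):
    """Hypothetical variant: add the reduced value to the old vertex value."""
    return vtx_val_old + vtx_val_new


# Two incoming neighbors with values 4 and 6, destination currently holding 10
shards = [generateShardHalved(v) for v in (4, 6)]
reduced = sum(shards)                      # the reduction done when the log is replayed
updated = updateValueAccumulate(10, reduced)
print(shards, reduced, updated)
```

Because the dataflows only call the hooks, a change like this would not require touching the traversal code itself.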

## Step 1 - Update batch sequence

Do a (concordant) source-stationary traversal of the vertices (__vtx__) and adjacency matrix (__g_SD__), but instead of trying to directly update each destination vertex (a discordant traversal), just log the destination vertex id and source vertex value (as a tuple) in a rank-2 tensor (__bins__). When logging, select a bin id (top rank coordinate) by partitioning the destination vertex ids (e.g., divide the destination vertex id (coordinate) by 2). Note that additions to the fibers in the lower rank of __bins__ are always made at the end of the fibers.

In [None]:
coordinates_per_bin = 2

bins = Tensor(rank_ids=["B", "N"])
bins.setColor("purple").setName("bins")

vtx_s = vtx.getRoot()
g_SD_s = g_SD.getRoot()
bins_b = bins.getRoot()

canvas = createCanvas(vtx, g_SD, bins)
n = 0

for s, (vtx_val, g_SD_d) in vtx_s & g_SD_s:
    for d, _ in g_SD_d:
        n += 1
        
        b = d//coordinates_per_bin
        bins_n = bins_b.getPayloadRef(b)
        
        bins_n.append(n, (d,  generateShard(vtx_val)))
        canvas.addFrame([(s, )], [(s, d)], [(b, n)])
        
displayTensor(bins)
displayCanvas(canvas)
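
The logging step can be mimicked with a dictionary of lists. This is a sketch under stated assumptions: the toy graph and the `build_bins` helper are hypothetical, and a Python list stands in for the lower-rank fiber of __bins__. It shows the bin selection by integer division and that every log entry is appended at the end of its bin.

```python
# Hypothetical toy data (not the random tensors used in the notebook).
vtx = {0: 1, 1: 2, 2: 3, 3: 4}                  # vertex id -> value
g_sd = {0: [1, 3], 1: [2], 2: [0, 3], 3: [1]}   # source -> destinations
coordinates_per_bin = 2


def build_bins(vtx, g_sd, coordinates_per_bin):
    """Log (destination id, source value) tuples, partitioned by destination id."""
    bins = {}
    n = 0                                   # running log sequence number
    for s in sorted(vtx.keys() & g_sd.keys()):
        for d in g_sd[s]:
            n += 1
            b = d // coordinates_per_bin    # bin id from the destination id
            # n only grows, so each entry lands at the end of its bin
            bins.setdefault(b, []).append((n, (d, vtx[s])))
    return bins


bins = build_bins(vtx, g_sd, coordinates_per_bin)
for b in sorted(bins):
    print(b, bins[b])
```

Within a bin the sequence numbers are strictly increasing, while the destination ids stay inside the narrow range assigned to that bin.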


## Step 2 - Replay the log

Do a concordant traversal of all the elements of the log tensor (__bins__) and, using the logged vertex id and value, reduce on a per-vertex-id basis into a new vertex vector (__vtx_new__). Although the traversal of the new vertex vector will be discordant, while processing each bin (top rank coordinate of __bins__) the range of active vertex ids (coordinates) will be quite small.

In [None]:
vtx_new = Tensor(rank_ids=["S"])
vtx_new.setColor("blue")
vtx_new.setName("vertices_new")

bins_b = bins.getRoot()
vtx_new_s = vtx_new.getRoot()

canvas = createCanvas(bins, vtx_new)

for b, bins_n in bins_b:
    for n, (s, val) in bins_n:
        # Look up (or create) the payload for vertex s without shadowing the root fiber
        vtx_new_ref = vtx_new_s.getPayloadRef(s)
        vtx_new_ref += val
        canvas.addFrame((b, n), (s,))
        
        
displayTensor(vtx_new)
displayCanvas(canvas)
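
A plain-Python sketch of the replay may help make the locality claim concrete. The `bins` literal below is a hypothetical log shaped like the output of Step 1 for a four-vertex toy graph with two coordinates per bin, and `replay` is an illustrative helper, not a fibertree function. Besides the reduction, it records the span of vertex ids touched while each bin is processed.

```python
# Hypothetical log: bin id -> list of (sequence number, (vertex id, value))
bins = {
    0: [(1, (1, 1)), (4, (0, 3)), (6, (1, 4))],
    1: [(2, (3, 1)), (3, (2, 2)), (5, (3, 3))],
}


def replay(bins):
    """Reduce the logged values per vertex id and track the ids touched per bin."""
    vtx_new = {}
    spans = {}
    for b in sorted(bins):                  # concordant walk over the log
        touched = []
        for _n, (vertex_id, val) in bins[b]:
            vtx_new[vertex_id] = vtx_new.get(vertex_id, 0) + val
            touched.append(vertex_id)
        spans[b] = (min(touched), max(touched))
    return vtx_new, spans


vtx_new, spans = replay(bins)
print(vtx_new)
print(spans)
```

Each bin only touches a window of `coordinates_per_bin` vertex ids, so the discordant accesses stay within a small working set.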
        

## Step 2 - Replay the log - with shortcuts

In the dataflow above, the coordinates passed to __getPayloadRef()__ follow a regular pattern: within each bin they fall in a narrow range that lies beyond the coordinates of all earlier bins. One can exploit this with the "start_pos" shortcut, which tells __getPayloadRef()__ where to begin its search for the desired coordinate. The following cell displays a control to enable or disable the use of the shortcut for the following log replay dataflow.

In [None]:
createEnableControl("Use shortcut")

In [None]:

vtx_new = Tensor(rank_ids=["S"])
vtx_new.setColor("blue")
vtx_new.setName("vertices_new")

bins_b = bins.getRoot()
vtx_new_s = vtx_new.getRoot()

canvas = createCanvas(bins, vtx_new)

next_start_pos = 0

for b, bins_n in bins_b:
    start_pos = next_start_pos
    
    for n, (s, val) in bins_n:
        vtx_new_ref = vtx_new_s.getPayloadRef(s, start_pos=start_pos)
        if enable["Use shortcut"]:
            next_start_pos = max(next_start_pos, vtx_new_s.getSavedPos())
        
        vtx_new_ref += val
        canvas.addFrame((b, n), (s,))
        
(n, distance) = vtx_new_s.getSavedPosStats()
print(f"Average search distance = {distance/n:4.2f}")
      
displayTensor(vtx_new)
displayCanvas(canvas)
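
The effect of a starting position can be illustrated with a sorted coordinate list and a linear search. This is an assumption-laden sketch: `find_or_insert` and `replay` are hypothetical helpers, the hypothetical `bins` literal omits the sequence numbers, and the search and its statistics are not claimed to match how fibertree implements `getPayloadRef()` or `getSavedPosStats()`.

```python
def find_or_insert(coords, values, coord, start_pos=0):
    """Linear search from start_pos; insert a zero payload if coord is absent."""
    pos = start_pos
    while pos < len(coords) and coords[pos] < coord:
        pos += 1
    if pos == len(coords) or coords[pos] != coord:
        coords.insert(pos, coord)
        values.insert(pos, 0)
    return pos, pos - start_pos             # position found and steps searched


def replay(bins, use_shortcut):
    """Replay a log of (vertex id, value) pairs, counting total search steps."""
    coords, values = [], []
    total_steps = 0
    next_start_pos = 0
    for b in sorted(bins):
        start_pos = next_start_pos          # fixed for the whole bin
        for vertex_id, val in bins[b]:
            pos, steps = find_or_insert(coords, values, vertex_id, start_pos)
            if use_shortcut:
                next_start_pos = max(next_start_pos, pos)
            values[pos] += val
            total_steps += steps
    return dict(zip(coords, values)), total_steps


# Hypothetical log: bin id -> list of (vertex id, value)
bins = {0: [(1, 1), (0, 3), (1, 4)], 1: [(3, 1), (2, 2), (3, 3)]}

result_plain, steps_plain = replay(bins, use_shortcut=False)
result_short, steps_short = replay(bins, use_shortcut=True)
print(result_plain == result_short, steps_plain, steps_short)
```

The starting position is safe to reuse because every coordinate in a later bin is larger than every coordinate in an earlier bin, so the search never needs to look before it.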
        

## Step 3 - New vertex value mapping

Sometimes there is an algorithmic step to update the original vertex values with a function of the original and new vertex values. Note that for iterative algorithms this step can be fused with the next update batch sequence (step 1). If processed separately, it is a simple (concordant) traversal of the original and new vertex values. Note that we copy the original vertex tensor (__vtx__) into a new tensor (__vtx_copy__) to hold the updated vertices, to avoid clobbering the original vertex tensor.

In [None]:
import copy

vtx_copy = copy.deepcopy(vtx)

vtx_s = vtx_copy.getRoot()
vtx_new_s = vtx_new.getRoot()

canvas = createCanvas(vtx_new, vtx_copy)

for s, (vtx_ref, vtx_new_val) in vtx_s << vtx_new_s:
    vtx_ref <<= updateValue(vtx_ref, vtx_new_val)
    canvas.addFrame((s,), (s,))

displayTensor(vtx_copy)
displayCanvas(canvas)
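
In plain Python the mapping step might look like the sketch below. It assumes that the `<<` traversal visits only the coordinates present in its right-hand operand, so vertices without a new value keep their old one; that reading, the toy data, and the `apply_updates` helper are assumptions for illustration.

```python
def updateValue(vtx_val_old, vtx_val_new):
    """Same simple rule as in the notebook: take the new value."""
    return vtx_val_new


def apply_updates(vtx, vtx_new):
    """Return an updated copy of vtx, leaving the original untouched."""
    vtx_copy = dict(vtx)
    for s in sorted(vtx_new):               # concordant walk over the new values
        vtx_copy[s] = updateValue(vtx_copy.get(s, 0), vtx_new[s])
    return vtx_copy


# Hypothetical data: vertex 4 has no incoming edges, so it has no new value
vtx = {0: 1, 1: 2, 2: 3, 3: 4, 4: 9}
vtx_new = {0: 3, 1: 5, 2: 2, 3: 4}

vtx_copy = apply_updates(vtx, vtx_new)
print(vtx_copy)
print(vtx)
```

Working on a copy keeps the original values available, for example to rerun an earlier cell with a different pair of hook functions.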

## Testing area

For running alternative algorithms

In [None]:
enable