# Exploring Cartesian Product

First, import the required libraries.

In [None]:
# Begin - startup boilerplate code

import pkgutil

# Install fibertree_bootstrap from GitHub only if it is not already importable.
# NOTE: the `!` line runs a shell command; this is IPython/Jupyter syntax,
# not plain Python.
if 'fibertree_bootstrap' not in [pkg.name for pkg in pkgutil.iter_modules()]:
  !python3 -m pip  install git+https://github.com/Fibertree-project/fibertree-bootstrap --quiet

# End - startup boilerplate code


# Presumably brings Tensor, createCanvas, displayCanvas, ... into scope and
# prepares the notebook environment -- TODO confirm against fibertree_bootstrap.
from fibertree_bootstrap import *
fibertree_bootstrap()

## Z_m = A_mk * B_mk

#### Load basic tensors

In [None]:
# Operands of Z_m = A_mk * B_mk, given as uncompressed 4x4 matrices with
# rank order [M, K].
As = Tensor.fromUncompressed(
  ["M", "K"],
  [[1, 0, 1, 1],
   [0, 0, 0, 0],
   [0, 0, 3, 3],
   [0, 0, 0, 0]])
Bs = Tensor.fromUncompressed(
  ["M", "K"],
  [[1, 0, 0, 1],
   [0, 0, 0, 0],
   [0, 0, 0, 3],
   [4, 0, 4, 0]])

# Expected result: the sum over k of A[m][k] * B[m][k] for each row m.
Zs_verify = Tensor.fromUncompressed(["M"], [2, 0, 9, 0])


#### Untiled Baseline Traversal

In [None]:
# Untiled baseline: for every row m present in both A and B, intersect the
# K fibers of A and B and accumulate a * b into Z[m].
Zs = Tensor(rank_ids = ["M"])

as_m = As.getRoot()
bs_m = Bs.getRoot()
zs_m = Zs.getRoot()

# The canvas records one animation frame per multiply; each frame lists the
# highlighted coordinates for As, Bs and Zs (in that order).
canvas = createCanvas(As, Bs, Zs)

# `as_m & bs_m` intersects the M fibers of A and B. `zs_m << (...)` pairs each
# resulting coordinate with Z's payload `z` at that coordinate (presumably
# creating the element if absent -- fibertree populate semantics; confirm).
for m, (z, (as_k, bs_k)) in zs_m << (as_m & bs_m):
  # Intersect the K fibers so only coordinates present in both are visited.
  for k, (a, b) in as_k & bs_k:
    # In-place update of the payload; this is what fills Zs.
    z += a * b
    canvas.addFrame((m, k), (m, k), (m,))

displayCanvas(canvas, width="50%")

#### Verify Result

In [None]:
# Compare the untiled result with the expected values; as the cell's last
# expression, the comparison result is displayed.
Zs_verify == Zs

## Tiled K -- Offline

#### Load pre-tiled tensors

In [None]:
# Tile width of the K0 rank; the 4 K columns split into K1 = 2 tiles.
K0 = 2

# A and B tiled offline into rank order [K1, M, K0]: tile k1 holds columns
# k1*K0 .. k1*K0 + K0 - 1 of every row m.
As_tiled_K = Tensor.fromUncompressed(
  ["K1", "M", "K0"],
  [[[1, 0], [0, 0], [0, 0], [0, 0]],    # k1 = 0
   [[1, 1], [0, 0], [3, 3], [0, 0]]])   # k1 = 1
Bs_tiled_K = Tensor.fromUncompressed(
  ["K1", "M", "K0"],
  [[[1, 0], [0, 0], [0, 0], [4, 0]],    # k1 = 0
   [[0, 1], [0, 0], [0, 3], [4, 0]]])   # k1 = 1


#### Tiled traversal

In [None]:
# Tiled traversal over the pre-tiled operands, loop order K1 -> M -> K0.
Zs = Tensor(rank_ids = ["M"])

as_k1 = As_tiled_K.getRoot()
bs_k1 = Bs_tiled_K.getRoot()
zs_m  = Zs.getRoot()

canvas = createCanvas(As_tiled_K, Bs_tiled_K, Zs)
# Outer loop: intersect the K1 (tile index) fibers of A and B.
for k1, (as_m, bs_m) in as_k1 & bs_k1:
  # Within tile k1, intersect the rows and locate Z[m]; partial sums from
  # every tile accumulate into the same z.
  for m, (z, (as_k0, bs_k0)) in zs_m << (as_m & bs_m):
    # Intersect the K0 fibers (offsets within the tile).
    for k0, (a, b) in as_k0 & bs_k0:
      z += a * b
      canvas.addFrame((k1, m, k0), (k1, m, k0), (m,))

displayCanvas(canvas, width="50%")

#### Verify result

In [None]:
# Compare the offline-tiled result with the expected values (displayed).
Zs_verify == Zs

## Tiled K -- Online, Monolithic, Separate

#### Define tensors for online tiling

In [None]:
# Tile width of the K0 rank.
K0 = 2

# Empty [K1, M, K0] tensors that the next cells fill by tiling As and Bs
# online, plus the output tensor.
As_tiled_K = Tensor(rank_ids=["K1", "M", "K0"])
Bs_tiled_K = Tensor(rank_ids=["K1", "M", "K0"])
Zs = Tensor(rank_ids=["M"])

#### Tile Tensor A

In [None]:
# Tile A online: copy each element A[m, k] into As_tiled_K at (k1, m, k0),
# where k1 = k // K0 is the tile index and k0 = k % K0 the offset in the tile.
# (Zs is not touched here; the traversal cell below creates it.)
canvas = createCanvas(As, As_tiled_K)

as_m = As.getRoot()
as_tiled_k1 = As_tiled_K.getRoot()

for m, as_k in as_m:
  for k, a in as_k:
    k1 = k // K0
    k0 = k % K0
    # Reference to the payload at (k1, m, k0) of the tiled tensor (presumably
    # creating intermediate fibers as needed); `<<=` stores `a` there.
    as_tiled_k0 = as_tiled_k1.getPayloadRef(k1, m, k0)
    as_tiled_k0 <<= a
    canvas.addFrame((m, k), (k1, m, k0))

# Sanity check: prints whether the tensor's root still compares equal to the
# fiber the insertions were made through.
print(As_tiled_K.getRoot() == as_tiled_k1)
displayCanvas(canvas, width="50%")

#### Tile Tensor B

In [None]:
# Tile B online, exactly as done for A: each element B[m, k] is written into
# Bs_tiled_K at coordinates (k // K0, m, k % K0).
canvas = createCanvas(Bs, Bs_tiled_K)

bs_m = Bs.getRoot()
bs_tiled_k1 = Bs_tiled_K.getRoot()

for m, bs_k in bs_m:
  for k, b in bs_k:
    k1, k0 = divmod(k, K0)
    bs_tiled_k0 = bs_tiled_k1.getPayloadRef(k1, m, k0)
    bs_tiled_k0 <<= b
    canvas.addFrame((m, k), (k1, m, k0))

print(Bs_tiled_K.getRoot() == bs_tiled_k1)
displayCanvas(canvas, width="50%")

#### Tiled Traversal

In [None]:
# Same K1 -> M -> K0 traversal as the offline case, now over the tensors
# that were tiled online in the previous two cells.
Zs = Tensor(rank_ids = ["M"])

as_k1 = As_tiled_K.getRoot()
bs_k1 = Bs_tiled_K.getRoot()
zs_m  = Zs.getRoot()

canvas = createCanvas(As_tiled_K, Bs_tiled_K, Zs)

# Intersect tiles, then rows (locating Z[m]), then in-tile offsets.
for k1, (as_m, bs_m) in as_k1 & bs_k1:
  for m, (z, (as_k0, bs_k0)) in zs_m << (as_m & bs_m):
    for k0, (a, b) in as_k0 & bs_k0:
      z += a * b
      canvas.addFrame((k1, m, k0), (k1, m, k0), (m,))

displayCanvas(canvas, width="50%")

#### Verify Result

In [None]:
# Compare the online-tiled (separate) result with the expected values (displayed).
Zs_verify == Zs

## Tiled K -- Online, Monolithic, Combined

#### Define tensors for online tiling (post-intersection)

In [None]:
# Tile width of the K0 rank.
K0 = 2

# One [K1, M, K0] tensor holding the products a*b of the intersected A and B
# elements, plus the output tensor.
ABs_tiled_K = Tensor(rank_ids=["K1", "M", "K0"])
Zs = Tensor(rank_ids=["M"])

#### Co-Tile A and B

In [None]:
# Co-tile A and B: intersect them once, multiply the matching elements, and
# store each product at (k // K0, m, k % K0) of ABs_tiled_K.
canvas = createCanvas(As, Bs, ABs_tiled_K)

as_m = As.getRoot()
bs_m = Bs.getRoot()
abs_tiled_k1 = ABs_tiled_K.getRoot()

for m, (as_k, bs_k) in as_m & bs_m:
  for k, (a, b) in as_k & bs_k:
    k1, k0 = divmod(k, K0)
    abs_tiled_k0 = abs_tiled_k1.getPayloadRef(k1, m, k0)
    abs_tiled_k0 <<= a * b
    canvas.addFrame((m, k), (m, k), (k1, m, k0))

displayCanvas(canvas, width="50%")

#### Tiled Traversal

In [None]:
# Worker loop over the precomputed products: sums ABs_tiled_K over K1 and K0
# into Z[m].
Zs = Tensor(rank_ids = ["M"])

canvas = createCanvas(ABs_tiled_K, Zs)

abs_tiled_k1 = ABs_tiled_K.getRoot()
zs_m  = Zs.getRoot()

# NOTE: Worker loop no longer contains intersections!
# (The A/B intersection happened during co-tiling, so only one operand remains.)
for k1, abs_tiled_m in abs_tiled_k1:
  for m, (z, abs_tiled_k0) in zs_m << abs_tiled_m:
    for k0, ab in abs_tiled_k0:
      z += ab
      canvas.addFrame((k1, m, k0), (m,))

displayCanvas(canvas, width="50%")


#### Verify Result

In [None]:
# Compare the co-tiled (combined) result with the expected values (displayed).
Zs_verify == Zs

## Tiled K -- Online, Incremental, Combined

#### Define workspace and current positions tensors

In [None]:
# Output tensor.
Zs = Tensor(rank_ids=["M"])

# Tile width (K0) and number of tiles (K1). K1 is hard-coded: it equals
# ceil(K / K0) only for the 4-column inputs above.  XXX MAGIC FOR NOW
K0 = 2
K1 = 2

# Scratch fiber that collects the products of one (k1, m) tile.
workspace = Tensor(rank_ids=["K0"])

# Per-row saved (A, B) positions, passed as start_pos to getRange so each tile
# can resume where the previous one stopped; every row starts at (0, 0).
current_positions = Tensor(rank_ids=["M"])
current_positions.setDefault((0, 0))

#### Traverse and Tile Simultaneously

In [None]:
# Traverse and tile in one pass. For each tile k1 and each row m present in
# both A and B:
#   1. take the K0-wide slice of A's and B's K fibers for tile k1,
#   2. multiply the intersecting elements into the workspace, indexed by the
#      in-tile offset k0,
#   3. reduce the workspace into Z[m].
canvas = createCanvas(As, Bs, workspace, Zs)

as_m = As.getRoot()
bs_m = Bs.getRoot()
zs_m = Zs.getRoot()

workspace_k0 = workspace.getRoot()
current_positions_m = current_positions.getRoot()


for k1 in range(K1):  # TODO: improve this outer loop (K1 is hard-coded)
  for m, (z, (pos_ref, (as_k, bs_k))) in zs_m << (current_positions_m << (as_m & bs_m)):

    workspace_k0.clear()

    # Get the starting positions saved by the previous tile for this row
    (a_pos, b_pos) = pos_ref

    # Presumably the coordinates k1*K0 .. k1*K0 + K0 - 1, searched from the
    # saved positions (fibertree getRange(start, size, start_pos=...)).
    as_k0 = as_k.getRange(k1 * K0, K0, start_pos = a_pos)
    bs_k0 = bs_k.getRange(k1 * K0, K0, start_pos = b_pos)

    # Update the starting positions for the next tile
    pos_ref <<= (as_k.getSavedPos(), bs_k.getSavedPos())

    # Tiling loop (with multiplication). The workspace rank is K0, so each
    # product goes at its offset within the tile (k % K0). Using k // K0
    # (the tile index) would give every element of the tile the same
    # coordinate.
    for k, (a, b) in as_k0 & bs_k0:
      workspace_k0.append(k % K0, a * b)
      print("Inserting ({}, {})".format(m, k))

    # Reduction and update loop
    for k0, ab in workspace_k0:
      z += ab
      # Recover the untiled K coordinate of this element for the A/B frames
      # (instead of reusing the stale `k` left over from the tiling loop).
      k = k1 * K0 + k0
      canvas.addFrame((m, k), (m, k), (k0,), (m,))
      print("Working on ({}, {}, {})".format(k1, m, k0))

displayCanvas(canvas, width="50%")

#### Verify result

In [None]:
# Compare the incremental (combined) result with the expected values (displayed).
Zs_verify == Zs

## Testing area

For running alternative algorithms