In [None]:
# Begin - startup boilerplate code

import pkgutil

if 'fibertree_bootstrap' not in [pkg.name for pkg in pkgutil.iter_modules()]:
  !python3 -m pip  install git+https://github.com/Fibertree-project/fibertree-bootstrap --quiet

# End - startup boilerplate code


from fibertree_bootstrap import *
fibertree_bootstrap(style="uncompressed", animation="spacetime")

## Initialize Sliders & default parameters

In [None]:

# Initial values
# Most cells below are modified from spMspM_pruned
# These module-level defaults also seed the sliders built further down;
# set_params() overwrites them with the slider values.

M = 6 #defines M rank
N = 6 #defines N rank
K = 6 #defines K rank
density = [0.5, 1.0] #density list, passed unchanged to Tensor.fromRandom for both A and B (presumably one entry per rank -- TODO confirm)
interval = 5 #defines max value in A or B
seed = 10 #defines random seed (A uses seed, B uses 2*seed)
sample_rate = 0.4 # threshold handed to UniformRandomPrune when sampling A

def set_params(rank_M, rank_N, rank_K, tensor_density, uniform_sample_rate, max_value, rand_seed):
    """Copy the widget values into the notebook-level parameters.

    Every argument is stored as given in its matching global, except
    ``tensor_density``, which is stored reversed in ``density``.
    """
    global M, N, K
    global density, seed, sample_rate, interval

    M, N, K = rank_M, rank_N, rank_K

    # The pair arrives from the range slider and is kept in reversed order.
    density = tensor_density[::-1]

    seed, sample_rate, interval = rand_seed, uniform_sample_rate, max_value

# One slider per set_params() keyword argument, each initialised from the
# defaults defined above.
# NOTE(review): `interactive` and `widgets` are presumably the ipywidgets names
# re-exported by the bootstrap star import -- confirm.
# The cells below read the globals when they execute, so they have to be
# re-run after a slider is moved for the new values to take effect.
w = interactive(set_params,
             rank_M=widgets.IntSlider(min=2, max=10, step=1, value=M),
             rank_N=widgets.IntSlider(min=2, max=10, step=1, value=N),
             rank_K=widgets.IntSlider(min=2, max=10, step=1, value=K),
             tensor_density=widgets.FloatRangeSlider(min=0.0, max=1.0, step=0.05, value=[0.5, 1.0]),
             uniform_sample_rate=widgets.FloatSlider(min=0, max=1.0, step=0.05, value=sample_rate),
             max_value=widgets.IntSlider(min=0, max=20, step=1, value=interval),
             rand_seed=widgets.IntSlider(min=0, max=100, step=1, value=seed))

display(w)


## Input Tensors

In [None]:
# Random input A with rank ids [M, K]; density, interval and seed come from
# the parameter cell above.
a = Tensor.fromRandom(["M", "K"], [M, K], density, interval, seed=seed)
#a.setName("A")
a.setColor("blue")
displayTensor(a)

# Create swapped rank version of a
# (the outer-product cells iterate a_swapped with k as the outermost coordinate)
a_swapped = a.swapRanks()
#a_swapped.setName("A_swapped")
displayTensor(a_swapped)

# Random input B with rank ids [N, K]; same density and interval as A, but
# generated with a different seed (2*seed).
b = Tensor.fromRandom(["N", "K"], [N, K], density, interval, seed=2*seed)
#b.setName("B")
b.setColor("green")
displayTensor(b)

# Create swapped rank version of b
b_swapped = b.swapRanks()
#b_swapped.setName("B_swapped")
displayTensor(b_swapped)


## Reference Output

In [None]:
# Reference result, computed without any sampling: for each (m, n) visited,
# a_val * b_val is accumulated into z_validate over the k coordinates
# produced by a_k & b_k.
# NOTE(review): presumably `<<` populates the left-hand (output) fiber at the
# coordinates of the right-hand fiber and `&` intersects coordinates --
# confirm against the fibertree documentation.
z_validate = Tensor(rank_ids=["M", "N"], shape=[M, N])

a_m = a.getRoot()
b_n = b.getRoot()
z_m = z_validate.getRoot()

for m, (z_n, a_k) in z_m << a_m:
    for n, (z_ref, b_k) in z_n << b_n:
        for k, (a_val, b_val) in a_k & b_k:
            z_ref += a_val * b_val

displayTensor(z_validate)



def compareZ(z):
    """Return the mean absolute difference between ``z`` and ``z_validate``.

    Both tensors are walked together with ``|`` at the M and N ranks, and the
    absolute difference of the two unpacked payloads is averaged over every
    (m, n) pair that walk yields.

    Args:
        z: Tensor to compare against the module-level ``z_validate``.

    Returns:
        ``total / count`` for the visited pairs, or ``0.0`` when the walk
        yields no pairs (nothing differs, and a division by zero is avoided).
    """

    # The counter must not be called `n`: the inner loop binds `n` to the
    # current coordinate, which would overwrite the running count.
    count = 0
    total = 0

    z1 = z_validate.getRoot()
    z2 = z.getRoot()

    for m, (ab_n, z1_n, z2_n) in z1 | z2:
        for n, (ab_val, z1_val, z2_val) in z1_n | z2_n:
            # Unpack the payloads to plain values so abs() can be applied
            z1_val = Payload.get(z1_val)
            z2_val = Payload.get(z2_val)

            count += 1
            total += abs(z1_val - z2_val)

    if count == 0:
        return 0.0

    return total / count

## Prune Functions and helper functions

In [None]:
# MCMM uniform random sampling
class UniformRandomPrune:
    """Prune callback that keeps an element with a fixed probability threshold.

    Each call draws one number from ``random.uniform(0, 1)`` and reports
    whether it fell below ``sample_rate``. The ``n``, ``c`` and ``p``
    arguments are accepted but not consulted.
    """

    def __init__(self, sample_rate=0.5):
        self.sample_rate = sample_rate

    def __call__(self, n, c, p):
        keep = random.uniform(0, 1) < self.sample_rate
        print(f"Preserve = {keep}")
        return keep
        
# MCMM sample against number of elements
class RandomSizePrune:
    """Prune callback whose keep-threshold scales with the payload's size.

    The payload ``p`` is asked for ``countValues()``; one draw from
    ``random.uniform(0, 1)`` is then compared against
    ``countValues() / max_size``.
    """

    def __init__(self, max_size=4):
        self.max_size = max_size

    def __call__(self, n, c, p):
        occupancy = p.countValues()
        draw = random.uniform(0, 1)
        keep = draw < occupancy / self.max_size
        print(f"Preserve = {keep}")
        return keep

# Recursive helper that returns the total absolute magnitude of a Fiber of arbitrary rank.
# It is modeled after countValues(), but has not been tested thoroughly.
# Open question: would this be useful as a Fiber method (e.g., for computing matrix norms)?
def get_magnitude(p):
    """Sum the absolute values of all payloads reachable from ``p``.

    Payloads that are themselves Fibers are descended into recursively;
    payloads reported empty by ``Payload.isEmpty`` contribute nothing.
    """
    total = 0
    for payload in p.payloads:
        if Payload.contains(payload, Fiber):
            # Nested fiber: add the magnitude of everything beneath it.
            total += get_magnitude(payload)
            continue
        if Payload.isEmpty(payload):
            continue
        total += np.absolute(payload.v())
    return total
    
# MCMM weight sample by sum of elements
class RandomMagnitudePrune:
    """Prune callback whose keep-threshold scales with the payload's magnitude.

    One draw from ``random.uniform(0, 1)`` is compared against
    ``get_magnitude(p) / max_mag``.
    """

    def __init__(self, max_mag=6):
        # As written, max_mag (and the magnitudes compared against it) should
        # be plain numbers, not payloads.
        self.max_mag = max_mag

    def __call__(self, n, c, p):
        weight = get_magnitude(p)
        draw = random.uniform(0, 1)
        keep = draw < weight / self.max_mag
        print(f"Preserve = {keep}")
        return keep
        
        

## Outer Product w/ MCMM; uniform sampling A

In [None]:
# Outer-product loop nest (k outermost) in which the top-level fiber of the
# rank-swapped A is passed through UniformRandomPrune before being combined
# with B.
z = Tensor(rank_ids=["M", "N"], shape=[M, N])

# Callable that returns True when a uniform draw falls below sample_rate.
sample_tensor = UniformRandomPrune(sample_rate)

a_k = a_swapped.getRoot()
b_k = b_swapped.getRoot()
z_m = z.getRoot()


canvas = createCanvas(a_swapped, b_swapped, z)

# NOTE(review): presumably prune() keeps only the k entries for which
# sample_tensor returns True -- confirm against the fibertree documentation.
for k, (a_m, b_n) in a_k.prune(sample_tensor) & b_k:
    for m, (z_n, a_val) in z_m << a_m:
        for n, (z_ref, b_val) in z_n << b_n:
            z_ref += a_val * b_val
            # One animation frame per multiply-accumulate, tagged with the
            # coordinates used in A (k, m), B (k, n) and Z (m, n).
            canvas.addFrame((k, m), (k, n), (m, n))

# Difference of the sampled result from the reference z_validate.
print(f"Error = {compareZ(z)}")

displayTensor(z)
displayCanvas(canvas)

## Outer Product, sample with threshold by number of elements

In [None]:
# Outer-product loop nest (k outermost) in which each k entry of A is kept
# with a threshold proportional to how many values it holds.
z = Tensor(rank_ids=["M", "N"], shape=[M, N])


a_k = a_swapped.getRoot()
b_k = b_swapped.getRoot()
z_m = z.getRoot()

#traverse to get max elements per a_k
# (used as the normaliser, so the fullest k entry gets a threshold of 1)
max_size = 0
for k, a_m in a_k:
    size = a_m.countValues()
    if size > max_size:
        max_size = size
print(max_size)

sample_tensor = RandomSizePrune(max_size)

canvas = createCanvas(a_swapped, b_swapped, z)

# NOTE(review): presumably prune() keeps only the k entries for which
# sample_tensor returns True -- confirm against the fibertree documentation.
for k, (a_m, b_n) in a_k.prune(sample_tensor) & b_k:
    for m, (z_n, a_val) in z_m << a_m:
        for n, (z_ref, b_val) in z_n << b_n:
            z_ref += a_val * b_val
            # One animation frame per multiply-accumulate, tagged with the
            # coordinates used in A (k, m), B (k, n) and Z (m, n).
            canvas.addFrame((k, m), (k, n), (m, n))

# Difference of the sampled result from the reference z_validate.
print(f"Error = {compareZ(z)}")

displayTensor(z)
displayCanvas(canvas)

## Outer Product, sample with threshold by magnitude

In [None]:
# Outer-product loop nest (k outermost) in which each k entry of A is kept
# with a threshold proportional to its total absolute magnitude.
z = Tensor(rank_ids=["M", "N"], shape=[M, N])


a_k = a_swapped.getRoot()
b_k = b_swapped.getRoot()
z_m = z.getRoot()

#traverse to get max magnitude per a_k
# get_magnitude() is the same measure RandomMagnitudePrune applies to each
# entry, so the normaliser and the per-entry magnitude always agree. It also
# yields a plain number, so no payload unwrapping is needed; the previous
# `max_mag.v()` raised AttributeError when max_mag was still the int 0
# (e.g. A has no entries).
max_mag = 0
for k, a_m in a_k:
    max_mag = max(max_mag, get_magnitude(a_m))

sample_tensor = RandomMagnitudePrune(max_mag)

canvas = createCanvas(a_swapped, b_swapped, z)

# NOTE(review): presumably prune() keeps only the k entries for which
# sample_tensor returns True -- confirm against the fibertree documentation.
for k, (a_m, b_n) in a_k.prune(sample_tensor) & b_k:
    for m, (z_n, a_val) in z_m << a_m:
        for n, (z_ref, b_val) in z_n << b_n:
            z_ref += a_val * b_val
            # One animation frame per multiply-accumulate, tagged with the
            # coordinates used in A (k, m), B (k, n) and Z (m, n).
            canvas.addFrame((k, m), (k, n), (m, n))

# Difference of the sampled result from the reference z_validate.
print(f"Error = {compareZ(z)}")

displayTensor(z)
displayCanvas(canvas)