In [None]:
# Begin - startup boilerplate code

import pkgutil

if 'fibertree_bootstrap' not in [pkg.name for pkg in pkgutil.iter_modules()]:
  !python3 -m pip  install git+https://github.com/Fibertree-project/fibertree-bootstrap --quiet

# End - startup boilerplate code


# Star import: brings the fibertree names used by the cells below into the notebook namespace.
# NOTE(review): later cells use `Tensor`, `Fiber`, `Payload`, `displayTensor`, `createCanvas`,
# `displayCanvas`, `interactive`, `widgets`, `display`, `random` and `np` without any other
# import; presumably they all arrive through this star import — confirm.
from fibertree_bootstrap import *
# Select the display style and animation type used for the rest of the notebook
fibertree_bootstrap(style="uncompressed", animation="spacetime")

## Initialize Sliders & default parameters

In [None]:

# Initial values
# Most cells below are modified from spMspM_pruned
#
# These module-level names are the notebook's parameters. They seed the slider
# widgets and are overwritten by set_params() whenever a slider moves.

# Shapes of the M, N and K ranks
M = 4
N = 4
K = 4
# Passed as the density argument of Tensor.fromRandom (two entries, one per rank)
density = [1.0, 1.0]
cutoff = 2 #reuse cutoff value for K rank budget in sampling methods
magnitude_thres = 1
# Passed as the interval argument of Tensor.fromRandom
interval = 5
# Seed for Tensor.fromRandom (b is generated with 2*seed)
seed = 10
sample_rate = 0.4 # dictates sample threshold for portion of A

def set_params(rank_M, rank_N, rank_K, tensor_density, thres_cutoff, mag_thres, uniform_sample_rate, max_value, rand_seed):
    """Copy a set of slider values into the notebook-level parameters.

    Used as the callback of the `interactive` widget, so every argument
    corresponds to one slider.

    Args:
        rank_M, rank_N, rank_K: new values for the globals M, N and K.
        tensor_density: two-element sequence; stored in reversed order as `density`.
        thres_cutoff: new value for `cutoff`.
        mag_thres: new value for `magnitude_thres`.
        uniform_sample_rate: new value for `sample_rate`.
        max_value: new value for `interval`.
        rand_seed: new value for `seed`.

    Returns:
        None. The function only rebinds module-level names.
    """
    global M, N, K
    global density, interval, seed
    global cutoff, magnitude_thres, sample_rate

    # Rank shapes
    M, N, K = rank_M, rank_N, rank_K

    # Tensor generation settings; the density pair is flipped end-for-end
    density = tensor_density[::-1]
    interval = max_value
    seed = rand_seed

    # Pruning / sampling knobs
    cutoff = thres_cutoff
    magnitude_thres = mag_thres
    sample_rate = uniform_sample_rate

# Build one slider per set_params() argument. Each slider starts at the current
# value of the matching global, and moving any slider calls set_params() again.
# NOTE(review): the cells below read the globals only when they are (re)run, so
# they must be re-executed after changing a slider.
w = interactive(set_params,
             rank_M=widgets.IntSlider(min=2, max=10, step=1, value=M),
             rank_N=widgets.IntSlider(min=2, max=10, step=1, value=N),
             rank_K=widgets.IntSlider(min=2, max=10, step=1, value=K),
             tensor_density=widgets.FloatRangeSlider(min=0.1, max=1.0, step=0.05, value=density),
             thres_cutoff=widgets.IntSlider(min=0, max=10, step=1, value=cutoff),
             mag_thres=widgets.IntSlider(min=0, max=10, step=1, value=magnitude_thres),
             uniform_sample_rate=widgets.FloatSlider(min=0, max=1.0, step=0.05, value=sample_rate),
             max_value=widgets.IntSlider(min=0, max=20, step=1, value=interval),
             rand_seed=widgets.IntSlider(min=0, max=100, step=1, value=seed))

# Render the slider panel
display(w)


## Input Tensors

In [None]:
# Random input A with ranks [M, K], generated from the current slider parameters
a = Tensor.fromRandom(["M", "K"], [M, K], density, interval, seed=seed)
a.setColor("blue")
displayTensor(a)

# Create swapped rank version of a
# (rank order [K, M]; used later as the canvas tensor for the animations)
a_swapped = a.swapRanks()
displayTensor(a_swapped)

# Random input B with ranks [N, K]; uses 2*seed so it is generated from a different seed than A
b = Tensor.fromRandom(["N", "K"], [N, K], density, interval, seed=2*seed)
b.setColor("green")
displayTensor(b)

# Create swapped rank version of b
# (rank order [K, N]; used later as the canvas tensor for the animations)
b_swapped = b.swapRanks()
displayTensor(b_swapped)


## Reference Output

In [None]:
# Untiled, unpruned reference result with ranks [M, N]; compareZ() measures other results against it
z_validate = Tensor(rank_ids=["M", "N"], shape=[M, N])

a_m = a.getRoot()
b_n = b.getRoot()
z_m = z_validate.getRoot()

# For every m in A and every n in B, accumulate the products over the k
# coordinates present in both the A row and the B row: Z[m, n] += A[m, k] * B[n, k]
for m, (z_n, a_k) in z_m << a_m:
    for n, (z_ref, b_k) in z_n << b_n:
        for k, (a_val, b_val) in a_k & b_k:
            z_ref += a_val * b_val

displayTensor(z_validate)



def compareZ(z):
    """Return the mean absolute element-wise difference between `z` and `z_validate`.

    Both tensors are walked together, rank by rank, with the fiber `|` operator.
    Each leaf pair visited contributes `abs(z_validate_value - z_value)` to the
    total, and the total is divided by the number of leaf pairs visited.

    Args:
        z: a two-rank tensor to compare against the module-level `z_validate`.

    Returns:
        The average absolute difference, or 0.0 when no leaf pair was visited.
    """

    # Number of leaf pairs visited. This must not be called `n`: the inner loop
    # binds `n` to the current coordinate, which would clobber the counter.
    count = 0
    total = 0

    z1 = z_validate.getRoot()
    z2 = z.getRoot()

    for m, (ab_n, z1_n, z2_n) in z1 | z2:
        for n, (ab_val, z1_val, z2_val) in z1_n | z2_n:
            # Unpack the values: abs() needs the raw number, not the Payload wrapper
            z1_val = Payload.get(z1_val)
            z2_val = Payload.get(z2_val)

            count += 1
            total += abs(z1_val - z2_val)

    # Nothing to compare: report zero difference instead of dividing by zero
    if count == 0:
        return 0.0

    return total / count

## Prune Functions and helper functions

In [None]:
# Occupancy threshold prune: keep a payload only if it holds more than `threshold` values
class ThresholdPrune():
    """Prune callable that preserves payloads whose value count exceeds a fixed threshold."""

    def __init__(self, threshold=2):
        # Payloads need strictly more values than this to survive
        self.threshold = threshold

    def __call__(self, n, c, p):
        """Return True to preserve payload `p`; `n` and `c` are accepted but not used."""
        keep = p.countValues() > self.threshold
        print(f"Preserve = {keep}")
        return keep
    
# MCMM uniform random sampling: every payload is kept with the same probability
class UniformRandomPrune():
    """Prune callable that preserves each payload when a uniform draw falls below `sample_rate`."""

    def __init__(self, sample_rate=0.5):
        self.sample_rate = sample_rate

    def __call__(self, n, c, p):
        """Draw once from uniform(0, 1) and return True when the draw is below `sample_rate`.

        None of `n`, `c` or `p` influences the decision.
        """
        draw = random.uniform(0,1)
        keep = draw < self.sample_rate
        print(f"Preserve = {keep}")
        return keep
        
# MCMM sampling weighted by occupancy: fuller payloads are more likely to be kept
class RandomSizePrune():
    """Prune callable that preserves a payload with probability `countValues() / max_size`."""

    def __init__(self, max_size=4):
        # Value count that corresponds to a keep probability of 1
        self.max_size = max_size

    def __call__(self, n, c, p):
        """Return True when a uniform(0, 1) draw is below the payload's fill ratio."""
        occupancy = p.countValues()
        draw = random.uniform(0, 1)
        keep = draw < (occupancy / self.max_size)
        print(f"Preserve = {keep}")
        return keep

# Recursive helper returning the total absolute magnitude of a payload of arbitrary rank.
# Modeled after countValues(); not thoroughly tested.
# Open question from the author: would this be useful as a Fiber method (e.g. for matrix norms)?
def get_magnitude(p):
    """Return the sum of absolute values of every non-empty leaf under `p`.

    Args:
        p: either a fiber (possibly with fibers as payloads) or a single leaf payload.

    Returns:
        For a leaf, its absolute value. For a fiber, the sum of the absolute
        values of all non-empty leaves reachable from it (0 if there are none).
    """
    # Leaf: the magnitude is the absolute value. The original returned the
    # signed value here, which contradicted the "absolute magnitude" contract.
    if not Payload.contains(p, Fiber):
        return abs(Payload.get(p))

    mag = 0
    for el in p.payloads:
        if Payload.contains(el, Fiber):
            # Nested fiber: add up everything below it
            mag += get_magnitude(el)
        elif not Payload.isEmpty(el):
            # Unwrap before abs(); the builtin is enough, so `np` is not needed here
            mag += abs(Payload.get(el))
    return mag
        
# Not a sampling method: deterministic pruning on absolute magnitude, for payloads of any rank
class ValueMagnitudePrune():
    """Prune callable that preserves payloads whose absolute magnitude exceeds `mag_thres`."""

    def __init__(self, mag_thres=1):
        # Payloads need a magnitude strictly above this to survive
        self.mag_thres = mag_thres

    def __call__(self, n, c, p):
        """Return True to preserve payload `p`; `n` and `c` are accepted but not used.

        A fiber payload is judged on its aggregate absolute magnitude (via
        get_magnitude); a single element is judged on its own absolute value.
        """
        if Payload.contains(p, Fiber):
            magnitude = get_magnitude(p)
        else:
            # Use the absolute value so large negative elements are not pruned
            magnitude = abs(Payload.get(p))

        return magnitude > self.mag_thres


# Uniform sampling capped by a budget; callers may reset `n_sampled` between passes
class UniformBudgetPrune():
    """Prune callable: uniform random sampling that stops preserving once a budget is used up."""

    def __init__(self, n_sampled=0, max_n_sampled=5, sample_rate=0.5):
        self.n_sampled = n_sampled          # payloads preserved so far
        self.max_n_sampled = max_n_sampled  # budget: maximum number of payloads to preserve
        self.sample_rate = sample_rate      # keep probability while budget remains

    def __call__(self, n, c, p):
        """Return True to preserve the payload; the arguments do not affect the decision.

        No random draw is made once the budget is exhausted.
        """
        # Budget spent: reject without sampling
        if self.n_sampled >= self.max_n_sampled:
            print("Preserve = False")
            return False

        keep = random.uniform(0,1) < self.sample_rate
        if keep:
            self.n_sampled += 1

        print(f"Preserve = {keep}")
        return keep
        

# Occupancy-weighted sampling capped by a budget; callers may reset `n_sampled` between passes
class OccupancyBudgetPrune():
    """Prune callable: keep probability is `countValues() / max_size` until the budget is used up."""

    def __init__(self, n_sampled=0, max_n_sampled=5, max_size=6):
        self.n_sampled = n_sampled          # payloads preserved so far
        self.max_n_sampled = max_n_sampled  # budget: maximum number of payloads to preserve
        self.max_size = max_size            # value count that maps to keep probability 1

    def __call__(self, n, c, p):
        """Return True to preserve payload `p`; `n` and `c` are accepted but not used.

        Once the budget is exhausted the payload is rejected without being inspected.
        """
        # Budget spent: reject without sampling
        if self.n_sampled >= self.max_n_sampled:
            print("Preserve = False")
            return False

        occupancy = p.countValues()
        draw = random.uniform(0, 1)
        keep = draw < (occupancy / self.max_size)
        if keep:
            self.n_sampled += 1

        print(f"Preserve = {keep}")
        return keep
        

# Uniform random sampling whose rate and budget can be changed on the fly via update()
class DynamicUniformPrune():
    """Prune callable: budgeted uniform sampling with an `update` hook for changing the policy."""

    def __init__(self, counter=0, budget=5, sample_rate=0.5):
        self.counter = counter          # payloads preserved so far
        self.budget = budget            # maximum number of payloads to preserve
        self.sample_rate = sample_rate  # keep probability while budget remains

    def update(self, idx):
        """Adjust the sampling policy for position `idx`.

        Placeholder (DEBUG) policy: for any idx above 1, lift the budget to 1000
        and the sample rate to 1.0 so that everything afterwards is kept.
        """
        if idx > 1:
            self.sample_rate = 1.0
            self.budget = 1000

    def __call__(self, n, c, p):
        """Return True to preserve the payload; the arguments do not affect the decision."""
        # Budget spent: reject without sampling
        if self.counter >= self.budget:
            print("Preserve = False")
            return False

        keep = random.uniform(0, 1) < self.sample_rate
        if keep:
            self.counter += 1

        print(f"Preserve = {keep}")
        return keep


# Budgeted occupancy-weighted sampling; the normalizing threshold grows to the largest occupancy seen
class DynamicOccupancyMCMM():
    """Prune callable: keep probability is `countValues() / threshold`, within a budget.

    `threshold` is raised whenever a payload with more values than the current
    threshold is encountered, so the ratio never exceeds 1.
    """

    def __init__(self, counter=0, budget=5, threshold=1):
        self.counter = counter      # payloads preserved so far
        self.budget = budget        # maximum number of payloads to preserve
        self.threshold = threshold  # running maximum occupancy used as the normalizer

    def __call__(self, n, c, p):
        """Return True to preserve payload `p`; `n` and `c` are accepted but not used.

        Once the budget is exhausted the payload is rejected without being inspected.
        """
        # Budget spent: reject without sampling
        if self.counter >= self.budget:
            print("Preserve = False")
            return False

        occupancy = p.countValues()

        # Track the largest occupancy seen so far
        self.threshold = max(self.threshold, occupancy)

        draw = random.uniform(0, 1)
        keep = draw < (occupancy / self.threshold)
        if keep:
            self.counter += 1

        print(f"Preserve = {keep}")
        return keep

# Budgeted magnitude-threshold pruning; `counter` and `threshold` are plain attributes
# that the calling loop adjusts between passes
class DynamicMagnitudePrune():
    """Prune callable: preserve payloads whose get_magnitude() exceeds `threshold`, within a budget."""

    def __init__(self, counter=0, budget=5, threshold=0):
        self.counter = counter      # payloads preserved so far
        self.budget = budget        # maximum number of payloads to preserve
        self.threshold = threshold  # magnitude a payload must strictly exceed

    def __call__(self, n, c, p):
        """Return True to preserve payload `p`; `n` and `c` are accepted but not used.

        Once the budget is exhausted the payload is rejected without being inspected.
        """
        # Budget spent: reject without measuring
        if self.counter >= self.budget:
            print("Preserve = False")
            return False

        keep = get_magnitude(p) > self.threshold
        if keep:
            self.counter += 1

        print(f"Preserve = {keep}")
        return keep



## Tiling A&B (offline)

In [None]:
#FIXME add sliders for tile params (also to-do: online tiling)
# Tile sizes: half of each rank's shape, truncated to an integer
K0=int(K/2)
M0=int(M/2)
N0=int(N/2)

# Tiled copies of A and B. Rank order is [K1, M1|N1, K0, M0|N0]: the two outer
# ranks select a tile, the two inner ranks select an element inside the tile.
a_tiled = Tensor(rank_ids=["K1", "M1", "K0", "M0"])
a_tiled.setColor("blue")
b_tiled = Tensor(rank_ids=["K1", "N1", "K0", "N0"])
b_tiled.setColor("green")

# Fill a_tiled: split every (m, k) coordinate of A into a tile coordinate
# (quotient) and an in-tile coordinate (remainder), then copy the value over
a_m = a.getRoot()
a_tiled_k1 = a_tiled.getRoot()

for (m, a_k) in a_m:
    for (k, value) in a_k:
        k1 = k // K0
        k0 = k %  K0
        m1 = m // M0
        m0 = m % M0
        a_tiled_m0 = a_tiled_k1.getPayloadRef(k1, m1, k0, m0)
        a_tiled_m0 <<= value

# Fill b_tiled the same way from the (n, k) coordinates of B
b_n = b.getRoot()
b_tiled_k1 = b_tiled.getRoot()

for (n, b_k) in b_n:
    for (k, value) in b_k:
        k1 = k // K0
        k0 = k % K0
        n1 = n // N0
        n0 = n % N0
        b_tiled_n0 = b_tiled_k1.getPayloadRef(k1, n1, k0, n0)
        b_tiled_n0 <<= value
        
displayTensor(a_tiled)
displayTensor(b_tiled)

## Simple Tiled Execution

In [None]:
# Tiled output with ranks [M1, N1, M0, N0]; no pruning in this cell
z_tiled = Tensor(rank_ids=["M1", "N1", "M0", "N0"])

a_tiled_k1 = a_tiled.getRoot()
b_tiled_k1 = b_tiled.getRoot()
z_tiled_m1 = z_tiled.getRoot()

# Animation canvas over the untiled, rank-swapped inputs ([K, M] and [K, N])
canvas = createCanvas(a_swapped, b_swapped)

# Outer three loops (k1, m1, n1) pick a pair of tiles; inner three loops
# (k0, m0, n0) multiply the elements of those two tiles
for k1, (a_m1, b_n1) in a_tiled_k1 & b_tiled_k1:
    for m1, (z_n1, a_k0) in z_tiled_m1 << a_m1:
        for n1, (z_m0, b_k0) in z_n1 << b_n1:
            for k0, (a_m0, b_n0) in a_k0 & b_k0:
                for m0, (z_n0, a_val) in z_m0 << a_m0:
                    for n0, (z_ref, b_val) in z_n0 << b_n0:
                        z_ref += a_val * b_val
                        # Record the untiled (k, m) point of A and (k, n) point of B touched by this step
                        canvas.addFrame(((k1*K0)+k0, (m1*M0)+m0), ((k1*K0)+k0, (n1*N0)+n0))
                        
            
            
displayTensor(z_tiled)
displayCanvas(canvas)

## Prune entire tiles

In [None]:
# Tiled output with ranks [M1, N1, M0, N0]; this cell drops whole A tiles
z_tiled = Tensor(rank_ids=["M1", "N1", "M0", "N0"])

a_tiled_k1 = a_tiled.getRoot()
b_tiled_k1 = b_tiled.getRoot()
z_tiled_m1 = z_tiled.getRoot()

# Animation canvas over the untiled, rank-swapped inputs ([K, M] and [K, N])
canvas = createCanvas(a_swapped, b_swapped)

tile_budget = 1 #subsample tiles according to aggregate magnitude

# Arguments are (counter, budget, threshold): start with nothing kept, keep at most
# tile_budget tiles per pass, and require an aggregate magnitude above 11
sample_tensor = DynamicMagnitudePrune(0, tile_budget, 11)

for k1, (a_m1, b_n1) in a_tiled_k1 & b_tiled_k1:
    # The prune is applied to the M1 fiber, so each payload it judges is one whole A tile
    for m1, (z_n1, a_k0) in z_tiled_m1 << a_m1.prune(sample_tensor):
        for n1, (z_m0, b_k0) in z_n1 << b_n1:
            for k0, (a_m0, b_n0) in a_k0 & b_k0:
                for m0, (z_n0, a_val) in z_m0 << a_m0:
                    for n0, (z_ref, b_val) in z_n0 << b_n0:
                        z_ref += a_val * b_val
                        # Record the untiled (k, m) point of A and (k, n) point of B touched by this step
                        canvas.addFrame(((k1*K0)+k0, (m1*M0)+m0), ((k1*K0)+k0, (n1*N0)+n0))
    # Sample feedback mechanism: if fewer than tile_budget tiles were kept in this
    # k1 slice, lower the threshold by 1 so the next slice keeps tiles more easily
    if sample_tensor.counter < tile_budget:
        sample_tensor.threshold -= 1
        print(f"Updated Threshold: {sample_tensor.threshold}")
    sample_tensor.counter = 0 #reset per A M1 fiber
            
            
displayTensor(z_tiled)
displayCanvas(canvas)

## Prune to budget within tile (dynamically force balanced load)

In [None]:
# Tiled output with ranks [M1, N1, M0, N0]; this cell prunes individual A elements inside each tile
z_tiled = Tensor(rank_ids=["M1", "N1", "M0", "N0"])

a_tiled_k1 = a_tiled.getRoot()
b_tiled_k1 = b_tiled.getRoot()
z_tiled_m1 = z_tiled.getRoot()

# Animation canvas over the untiled, rank-swapped inputs ([K, M] and [K, N])
canvas = createCanvas(a_swapped, b_swapped)

per_tile_budget = 1 #subsample within tiles according to magnitude

# Arguments are (counter, budget, threshold): keep at most per_tile_budget
# elements per pass, and require a magnitude above 3
sample_tensor = DynamicMagnitudePrune(0, per_tile_budget, 3)

for k1, (a_m1, b_n1) in a_tiled_k1 & b_tiled_k1:
    for m1, (z_n1, a_k0) in z_tiled_m1 << a_m1:
        for n1, (z_m0, b_k0) in z_n1 << b_n1:
            for k0, (a_m0, b_n0) in a_k0 & b_k0:
                # The prune is applied to the M0 fiber, so each payload it judges is a single A value
                for m0, (z_n0, a_val) in z_m0 << a_m0.prune(sample_tensor):
                    for n0, (z_ref, b_val) in z_n0 << b_n0:
                        z_ref += a_val * b_val
                        # Record the untiled (k, m) point of A and (k, n) point of B touched by this step
                        canvas.addFrame(((k1*K0)+k0, (m1*M0)+m0), ((k1*K0)+k0, (n1*N0)+n0))
                # Sample feedback mechanism: if fewer than per_tile_budget elements were kept
                # from this M0 fiber, lower the threshold by 1 so later fibers keep elements more easily
                if sample_tensor.counter < per_tile_budget:
                    sample_tensor.threshold -= 1
                    print(f"Updated Threshold: {sample_tensor.threshold}")
                sample_tensor.counter = 0 #reset after each A M0 fiber (one k0 slice of a tile)
            
            
displayTensor(z_tiled)
displayCanvas(canvas)

## MCMM per tile

In [None]:
# Tiled output with ranks [M1, N1, M0, N0]; this cell randomly samples K0 coordinates of A
z_tiled = Tensor(rank_ids=["M1", "N1", "M0", "N0"])

a_tiled_k1 = a_tiled.getRoot()
b_tiled_k1 = b_tiled.getRoot()
z_tiled_m1 = z_tiled.getRoot()

# Animation canvas over the untiled, rank-swapped inputs ([K, M] and [K, N])
canvas = createCanvas(a_swapped, b_swapped)

per_tile_budget=4

# Arguments are (n_sampled, max_n_sampled, sample_rate); sample_rate is the notebook-level slider value
sample_tensor = UniformBudgetPrune(0, per_tile_budget, sample_rate)

for k1, (a_m1, b_n1) in a_tiled_k1 & b_tiled_k1:
    for m1, (z_n1, a_k0) in z_tiled_m1 << a_m1:
        for n1, (z_m0, b_k0) in z_n1 << b_n1:
            # The prune is applied to A's K0 fiber, so each payload it judges is one k0 slice
            # of the current A tile; the sampling is redone for every n1
            for k0, (a_m0, b_n0) in a_k0.prune(sample_tensor) & b_k0:
                for m0, (z_n0, a_val) in z_m0 << a_m0:
                    for n0, (z_ref, b_val) in z_n0 << b_n0:
                        z_ref += a_val * b_val
                        # Record the untiled (k, m) point of A and (k, n) point of B touched by this step
                        canvas.addFrame(((k1*K0)+k0, (m1*M0)+m0), ((k1*K0)+k0, (n1*N0)+n0))
    # NOTE(review): this reset runs once per k1 iteration, after all m1/n1 tiles of that
    # slice, so the budget is shared by every tile in the slice. The variable name and the
    # comment suggest a per-tile reset (inside the n1 loop) may have been intended — confirm.
    sample_tensor.n_sampled = 0 #reset per A tile
            
displayTensor(z_tiled)
displayCanvas(canvas)