In [1]:
import numpy as np
import TensorFrost as tf
import matplotlib.pyplot as plt

tf.initialize(tf.opengl)

test_axis = 1

def PrefixSum(A, axis = -1):
    axis = len(A.shape) + axis if axis < 0 else axis
    group_size = 128
    grouped = tf.split_dim(A, group_size, axis)
    group_scan = tf.prefix_sum(tf.sum(grouped, axis = axis + 1), axis = axis)
    ids = grouped.indices
    gid, eid = ids[axis], ids[axis + 1]
    ids = [ids[i] for i in range(len(ids)) if i != axis + 1]
    ids[axis] = gid - 1
    group_scan = tf.prefix_sum(grouped + tf.select((gid == 0) | (eid != 0), 0, group_scan[tuple(ids)]), axis = axis + 1)
    full_scan = tf.merge_dim(group_scan, target_size = A.shape[axis], axis = axis + 1)
    return full_scan

def Scan():
    data = tf.input([-1, -1, -1], tf.int32)
    return PrefixSum(data, axis = test_axis)

scan_program = tf.compile(Scan)

TensorFrost module loaded!
Scan:
  Kernel count: 5
  Intermediate buffers: 0
  Host readbacks: 0
  Host writes: 0
  Lines of generated code: 493
  IR Compile time: 5.439100 ms
  Steps time: 1511.394409 ms



In [2]:
# Generate some random data to scan (ints between 0 and 10)
data = np.random.randint(0, 10, (25, 1), dtype=np.int32)

data_tf = tf.tensor(data)
scan_tf, grouped, group_scan1, group_scan, sums = scan_program(data_tf)

# do scan in numpy
scan_np = np.cumsum(data, axis=test_axis)

#print error
print("Error: ", np.max(np.abs(scan_tf.numpy - scan_np)))

print("Data: ", data.flatten())
print("Grouped Scan: ", scan_np.flatten())
print("Grouped Scan: ", group_scan.numpy.flatten())
print("Grouped Scan: ", scan_tf.numpy.flatten())
print("Grouped: ", grouped.numpy.flatten())
print("Grouped Scan1: ", group_scan1.numpy.flatten())
print("Full Scan: ", scan_tf.numpy.flatten())
print("Sums: ", sums.numpy.flatten())
print("Diff: ", np.abs(scan_tf.numpy - scan_np).flatten())

Error:  0
Data:  [1 9 2 2 5 6 2 2 3 5 3 0 1 7 6 6 6 5 5 5 6 7 9 5 5]
Grouped Scan:  [  1  10  12  14  19  25  27  29  32  37  40  40  41  48  54  60  66  71
  76  81  87  94 103 108 113]
Grouped Scan:  [  1  10  12  14  19  25  27  29  32  37  40  40  41  48  54  60  66  71
  76  81  87  94 103 108 113 113 113 113 113 113 113 113]
Grouped Scan:  [  1  10  12  14  19  25  27  29  32  37  40  40  41  48  54  60  66  71
  76  81  87  94 103 108 113]
Grouped:  [1 9 2 2 5 6 2 2 3 5 3 0 1 7 6 6 6 5 5 5 6 7 9 5 5 0 0 0 0 0 0 0]
Grouped Scan1:  [ 29  60 108 113]
Full Scan:  [  1  10  12  14  19  25  27  29  32  37  40  40  41  48  54  60  66  71
  76  81  87  94 103 108 113]
Sums:  [29 31 48  5]
Diff:  [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
