In [1]:
import TensorFrost as tf
import numpy as np
import matplotlib.pyplot as plt
import time

# Select the OpenGL backend. The tiled matmul kernel below relies on group barriers
# (tf.group_barrier), which are not emulated on the CPU backend, so the tiled kernel
# version only works in OpenGL.
tf.initialize(tf.opengl)

def matmul():
    """Build a TensorFrost program computing the matrix product C = A @ B.

    Inputs (declared with ``tf.input``; the compiled program is called below as
    ``mmul(A, B)``):
        A: float32 tensor of shape [N, M].
        B: float32 tensor of shape [M, K]; its row count is tied to A's column count.

    Returns:
        C: float32 buffer of shape [N, K].

    Kernel layout (group tiling + per-thread tiling):
        * Each work group has GK x GK threads and produces one BK x BK tile of C.
        * Each thread accumulates TBK x TBK outputs, strided by GK inside the tile.
        * For every BK-wide step along the reduction dimension M, the group copies a
          BK x BK tile of A and of B into group-shared buffers, waits on a group
          barrier, and accumulates outer products read from the shared tiles.

    NOTE(review): there is no tail handling. The dispatch size [N / TBK, K / TBK] is
    an integer division, the tile loads are not guarded, and the generated shader
    places the group barriers inside the `is_inside_dispatch` branch. Presumably
    N, M and K must all be multiples of BK (64) for correct results -- confirm
    before using other shapes.
    """
    A = tf.input([-1, -1], tf.float32)
    N, M = A.shape
    B = tf.input([M, -1], tf.float32)
    K = B.shape[1]

    # --- Earlier variants, kept for reference ------------------------------------
    # C = tf.sin(A) @ tf.cos(B)

    # i,j,k = tf.indices([N, K, M])
    # C = tf.sum(tf.sin(A[i, k]) * tf.cos(B[k, j]))**2.0

    # C = tf.buffer([N, K], tf.float32)
    # BK = 32

    # group tiled version (one output per thread)
    # with tf.kernel(C.shape, group_size=[BK,BK]) as (i, j):
    #     A_tile = tf.group_buffer(BK*BK, tf.float32)
    #     B_tile = tf.group_buffer(BK*BK, tf.float32)
    #     tx = i.block_thread_index(1)
    #     ty = i.block_thread_index(0)

    #     result = tf.const(0.0)

    #     # reduce over M (A's columns / B's rows), not K
    #     with tf.loop(0, M, BK) as blk:
    #         A_tile[tx * BK + ty] = A[i,  ty + blk]
    #         B_tile[tx * BK + ty] = B[tx + blk, j]

    #         tf.group_barrier()

    #         with tf.loop(BK) as k:
    #             result.val += A_tile[tx * BK + k] * B_tile[k * BK + ty]

    #         tf.group_barrier()

    #     C[i, j] = result.val

    C = tf.buffer([N, K], tf.float32)
    BK = 64          # side of the C tile produced by one work group
    TBK = 4          # each thread computes a TBK x TBK block of outputs
    GK = BK // TBK   # threads per group along each axis

    # 2D local + group tiled version
    with tf.kernel([N / TBK, K / TBK], group_size=[GK, GK]) as (i, j):
        A_group_tile = tf.group_buffer(BK*BK, tf.float32)
        B_group_tile = tf.group_buffer(BK*BK, tf.float32)
        # thread position inside the group, and the group's base index along i / j
        tx = i.block_thread_index(1)
        ty = i.block_thread_index(0)
        gx = i - tx
        gy = j - ty

        # per-thread accumulators for the TBK x TBK outputs this thread owns
        results = [tf.const(0.0) for _ in range(TBK*TBK)]

        # The contracted dimension is M (A is [N, M], B is [M, K]). Looping to K was
        # only correct when K == M.
        with tf.loop(0, M, BK) as blk:
            # cooperatively load the BK x BK tiles of A and B into shared memory
            for blk_ty in range(TBK):
                for blk_tx in range(TBK):
                    ltx = blk_tx * GK + tx
                    lty = blk_ty * GK + ty
                    A_group_tile[ltx * BK + lty] = A[TBK * gx + ltx, blk + lty]

            for blk_tx in range(TBK):
                for blk_ty in range(TBK):
                    ltx = blk_tx * GK + tx
                    lty = blk_ty * GK + ty
                    B_group_tile[ltx * BK + lty] = B[blk + ltx, TBK * gy + lty]

            # wait for all threads to finish loading
            tf.group_barrier()

            with tf.loop(BK) as k:
                # perform outer product of an A column slice and a B row slice and accumulate
                B_row = [B_group_tile[k * BK + (tid * GK + ty)] for tid in range(TBK)]
                for tid in range(TBK):
                    A_value = A_group_tile[(tid * GK + tx) * BK + k]
                    for tid2 in range(TBK):
                        results[tid * TBK + tid2].val += A_value * B_row[tid2]

            # keep the tiles intact until every thread has consumed them
            tf.group_barrier()

        # write back; outputs owned by one thread are strided by GK within the tile
        for tid in range(TBK):
            for tid2 in range(TBK):
                ltx = tid * GK + tx
                lty = tid2 * GK + ty
                C[TBK * gx + ltx, TBK * gy + lty] = results[tid * TBK + tid2]

    return C

# Compile the TensorFrost program defined by `matmul`; the resulting callable is
# invoked below as mmul(A, B) and returns the C tensor. Compilation prints the
# kernel/codegen statistics shown in the cell output.
mmul = tf.compile(matmul)

TensorFrost module loaded!
matmul:
  Kernel count: 1
  Intermediate buffers: 0
  Host readbacks: 0
  Host writes: 0
  Lines of generated code: 543
  IR Compile time: 18.211700 ms
  Codegen time: 3.268200 ms
  Host Compile time: 1465.661987 ms
  Shader Compile time: 79.253105 ms



In [2]:
all_kernels = tf.get_all_generated_kernels()
print("Generated kernels:")
for kernel_entry in all_kernels:
    # NOTE(review): entry [0][2] appears to hold the generated shader source
    # (see the output below); confirm the tuple layout against TensorFrost.
    shader_source = kernel_entry[0][2]
    print(shader_source)

Generated kernels:
shared float B_group_tile[4096];
shared float A_group_tile[4096];

layout (local_size_x = 16, local_size_y = 16, local_size_z = 1) in;

void main() {
  int block_id = int(gl_WorkGroupID.x + var._kernel_block_offset);
  int block_thread_id0 = int(gl_LocalInvocationID.x);
  int block_thread_id1 = int(gl_LocalInvocationID.y);
  int block_thread_id2 = int(gl_LocalInvocationID.z);

  int vdiv = var.K / 4;
  int vdiv_2 = var.N / 4;
  int blocks_shape_0 = ((vdiv + 16) - 1) / 16;
  int vdiv_3 = block_id / blocks_shape_0;
  int index_0 = ((block_id - (vdiv_3 * blocks_shape_0)) * 16) + block_thread_id0;
  int index_1 = (vdiv_3 * 16) + block_thread_id1;
  bool is_inside_dispatch = (index_0 < vdiv) && (index_1 < vdiv_2);
  if (is_inside_dispatch)
  {
    //float A_group_tile[4096]
    //float B_group_tile[4096]
    int tx = block_thread_id1;
    int ty = block_thread_id0;
    int gx = index_1 - tx;
    int gy = index_0 - ty;
    float vconst_6 = 0.f;
    float vconst_7 = 0.f;
  

In [4]:
# Benchmark the TensorFrost tiled matmul against NumPy on square float32 matrices.
# The tiled kernel has no tail handling, so the size must be a multiple of its
# 64-wide tile (BK inside `matmul`).
MAT_SIZE = 4096
assert MAT_SIZE % 64 == 0, "tiled kernel requires sizes that are multiples of 64"

# Seeded generator so Restart & Run All reproduces the same inputs.
rng = np.random.default_rng(0)
Anp = rng.random((MAT_SIZE, MAT_SIZE), dtype=np.float32)
Bnp = rng.random((MAT_SIZE, MAT_SIZE), dtype=np.float32)
A = tf.tensor(Anp)
B = tf.tensor(Bnp)

# Warm-up call, excluded from timing. Reading the result back presumably waits
# for the GPU work, so it cannot spill into the timed region.
mmul(A, B).numpy

repeat_tf = 256
start = time.perf_counter()
for _ in range(repeat_tf):
    C = mmul(A, B)
# Final readback so the queued dispatches are included in the measured time.
Cnp = C.numpy
tf_time = (time.perf_counter() - start) / repeat_tf

# NumPy reference (fewer repeats: the CPU product is much slower).
repeat_np = 16
start = time.perf_counter()
for _ in range(repeat_np):
    Cnp2 = Anp @ Bnp
np_time = (time.perf_counter() - start) / repeat_np

# Relative Frobenius-norm error of the TensorFrost result w.r.t. NumPy.
Cerror = np.linalg.norm(Cnp - Cnp2) / np.linalg.norm(Cnp2)
print("Error:", Cerror)
print("TF Time:", tf_time)
print("NP Time:", np_time)
print("Speedup:", np_time / tf_time)

# [N, M] @ [M, K] performs N*M*K multiply-adds = 2*N*M*K floating-point ops.
matmul_flop_count = 2 * Anp.shape[0] * Anp.shape[1] * Bnp.shape[1]
tf_flops = matmul_flop_count / tf_time
print("TF GFLOPS:", tf_flops / 1e9)
np_flops = matmul_flop_count / np_time
print("NP GFLOPS:", np_flops / 1e9)

assert Cerror < 1e-4, f"TensorFrost result deviates from NumPy (relative error {Cerror})"

Error: 8.507844e-07
TF Time: 0.015718752518296242
NP Time: 0.30787502229213715
Speedup: 19.586479393563717
TF GFLOPS: 8743.62983398488
NP GFLOPS: 446.4115096079039
