In [1]:
import TensorFrost as tf
import numpy as np
import matplotlib.pyplot as plt
import time

tf.initialize(tf.opengl)

def matmul():
    A = tf.input([-1, -1], tf.float32)
    N, M = A.shape
    B = tf.input([M,  -1], tf.float32)
    K = B.shape[1]

    C = (tf.sin(A) @ tf.cos(B))**2.0

    # i,j,k = tf.indices([N, K, M])
    # C = tf.sum(tf.sin(A[i, k]) * tf.cos(B[k, j]))**2.0
    return C

mmul = tf.compile(matmul)

TensorFrost module loaded!
matmul:
  Kernel count: 2
  Intermediate buffers: 1
  Host readbacks: 0
  Host writes: 0
  Lines of generated code: 508
  IR Compile time: 1.237300 ms
  Host Compile time: 1420.039551 ms
  Shader Compile time: 37.693501 ms



In [2]:
all_kernels = tf.get_all_generated_kernels()
print("Generated kernels:")
for k in all_kernels:
    print(k)

Generated kernels:
(('\n#version 460\n\nuint pcg(uint v) {\n  uint state = v * 747796405u + 2891336453u;\n  uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;\n  return (word >> 22u) ^ word;\n}\n\nfloat pcgf(uint v) {\n  return float(pcg(v)) / float(0xffffffffu);\n}\n\nfloat asfloat(uint x) {\n  return uintBitsToFloat(x);\n}\n\nuint asuint(float x) {\n  return floatBitsToUint(x);\n}\n\nuint asuint(bool x) {\n\treturn uint(x);\n}\n\nuint asuint(int x) {\n  return uint(x);\n}\n\nuint asuint(uint x) {\n  return x;\n}\n\nint asint(uint x) {\n  return int(x);\n}\n\nbool asbool(uint x) {\n  return bool(x);\n}\n\n\nstruct UBO {\n  int M;\n  int N;\n};\n\n', 'layout(std430, binding = 0) buffer buf_m0 {\n  uint m0_mem[];\n};\n\nfloat atomicAdd_m0(int index, float val) {\n\tuint uval = floatBitsToUint(val);\n\tuint tmp0 = 0;\n\tuint tmp1 = 0;\n\n\twhile (true) {\n\t\ttmp0 = atomicCompSwap(m0_mem[index], tmp1, uval);\n\t\tif (tmp1 == tmp0) break;\n\t\ttmp1 = tmp0;\n\t\tuval = fl

In [3]:
Anp = np.random.rand(4096, 4096).astype(np.float32)
Bnp = np.random.rand(4096, 4096).astype(np.float32)
A = tf.tensor(Anp)
B = tf.tensor(Bnp)

start = time.time()
repeat = 32
for i in range(repeat):
    C = mmul(A, B)
Cnp = C.numpy
tf_time = (time.time() - start) / repeat


#compare to numpy
start = time.time()
for i in range(repeat):
    Cnp2 = (np.sin(Anp) @ np.cos(Bnp))**2.0
np_time = (time.time() - start) / repeat

Cerror = np.linalg.norm(Cnp - Cnp2) / np.linalg.norm(Cnp2)
print("Error:", Cerror)
print("TF Time:", tf_time)
print("NP Time:", np_time)
print("Speedup:", np_time / tf_time)

tf_flops = 2 * Anp.shape[0] * Anp.shape[1] * Bnp.shape[1] / tf_time
print("TF GFLOPS:", tf_flops / 1e9)
np_flops = 2 * Anp.shape[0] * Anp.shape[1] * Bnp.shape[1] / np_time
print("NP GFLOPS:", np_flops / 1e9)

Error: 2.0017699e-06
TF Time: 0.07846880704164505
NP Time: 0.3770780712366104
Speedup: 4.805451815222412
TF GFLOPS: 1751.510678619827
NP GFLOPS: 364.4840789104368
