In [1]:
import TensorFrost as tf
import numpy as np
import matplotlib.pyplot as plt
import time

# Select TensorFrost's OpenGL backend (the kernels it emits are `#version 460`
# GLSL). This runs before any tf.compile / tf.tensor call below.
tf.initialize(tf.opengl)

def matmul():
    """Build the TensorFrost program C = (sin(A) @ cos(B)) ** 2.

    The function takes no Python arguments. Its inputs are declared with
    ``tf.input`` and become the positional arguments of the compiled callable
    (``mmul(A, B)`` further down).

    Inputs:
        A: float32 tensor of shape [N, M]. Both sizes are given as -1,
           presumably meaning "determined at call time" -- TODO confirm
           against the TensorFrost docs.
        B: float32 tensor of shape [M, K]. Its first dimension is tied to A's
           second dimension M, so the inner sizes of the matrix product agree.

    Returns:
        A one-element list ``[C]``; callers unpack it with ``C, = ...``.
    """
    A = tf.input([-1, -1], tf.float32)
    # N is not used below; the unpacking is only needed to get M for B's shape.
    N, M = A.shape
    B = tf.input([M,  -1], tf.float32)
    # K is not referenced afterwards.
    K = B.shape[1]

    # Elementwise sin/cos, matrix product, then elementwise square.
    C = (tf.sin(A) @ tf.cos(B))**2.0

    return [C]

# Trace and compile `matmul` into a callable. In the recorded run this step
# reported ~4.8 s of compiler time, so avoid re-running this cell needlessly.
mmul = tf.compile(matmul)

TensorFrost module loaded!
matmul:
  Kernel count: 1
  Intermediate buffers: 0
  Host readbacks: 0
  Host writes: 0
  Lines of generated code: 414
  IR Compile time: 0.857300 ms
  Compiler time: 4804.893555 ms



In [2]:
# Show the source of every kernel TensorFrost has generated so far, one per
# line, after a header. An empty list prints just the header.
all_kernels = tf.get_all_generated_kernels()
print("Generated kernels:", *all_kernels, sep="\n")

Generated kernels:

#version 460

uint pcg(uint v) {
  uint state = v * 747796405u + 2891336453u;
  uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
  return (word >> 22u) ^ word;
}

float pcgf(uint v) {
  return float(pcg(v)) / float(0xffffffffu);
}

float asfloat(uint x) {
  return uintBitsToFloat(x);
}

uint asuint(float x) {
  return floatBitsToUint(x);
}

uint asuint(int x) {
  return uint(x);
}

uint asuint(uint x) {
  return x;
}

int asint(uint x) {
  return int(x);
}

uniform int off[32];
uniform int var[32];

layout(std430, binding = 0) buffer memory {
  uint mem[];
};
layout (local_size_x = 16, local_size_y = 16, local_size_z = 1) in;

void main() {
  int block_id = int(gl_WorkGroupID.x);
  int block_thread_id0 = int(gl_LocalInvocationID.x);
  int block_thread_id1 = int(gl_LocalInvocationID.y);
  int block_thread_id2 = int(gl_LocalInvocationID.z);

  int v2_0 = block_id;
  int in_block_index_0 = block_thread_id1;
  int in_block_index_1 = block_thread_id0;

In [3]:
# Benchmark the compiled TensorFrost program against the equivalent NumPy
# expression, and check that both produce the same result.
SEED = 0     # fixed seed so the inputs (and the reported error) are reproducible
SIZE = 512   # side length of the square input matrices
repeat = 4000

rng = np.random.default_rng(SEED)
Anp = rng.random((SIZE, SIZE), dtype=np.float32)
Bnp = rng.random((SIZE, SIZE), dtype=np.float32)
A = tf.tensor(Anp)
B = tf.tensor(Bnp)

# One untimed call plus readback, so any one-off cost of the first dispatch
# is not charged to the timed loop.
C, = mmul(A, B)
Cnp = C.numpy

# time.perf_counter is the monotonic, high-resolution clock meant for interval
# timing; time.time follows the wall clock and can jump.
start = time.perf_counter()
for i in range(repeat):
    C, = mmul(A, B)
# Keep the readback inside the timed region. GPU dispatches are presumably
# queued asynchronously, and reading the result back waits for them to finish.
Cnp = C.numpy
tf_time = (time.perf_counter() - start) / repeat

# Reference implementation in NumPy.
start = time.perf_counter()
for i in range(repeat):
    Cnp2 = (np.sin(Anp) @ np.cos(Bnp))**2.0
np_time = (time.perf_counter() - start) / repeat

# Relative Frobenius-norm error between the two results (both are float32).
Cerror = np.linalg.norm(Cnp - Cnp2) / np.linalg.norm(Cnp2)
print("Error:", Cerror)
assert Cerror < 1e-4, f"TensorFrost result diverges from NumPy (relative error {Cerror:.3e})"
print("TF Time:", tf_time)
print("NP Time:", np_time)
print("Speedup:", np_time / tf_time)

# FLOPs of the matrix product only: 2*N*M*K (one multiply and one add per
# term). The elementwise sin/cos/square work is not counted.
matmul_flops = 2 * Anp.shape[0] * Anp.shape[1] * Bnp.shape[1]
tf_flops = matmul_flops / tf_time
print("TF GFLOPS:", tf_flops / 1e9)
np_flops = matmul_flops / np_time
print("NP GFLOPS:", np_flops / 1e9)

Error: 9.362458e-07
TF Time: 0.0011913295984268188
NP Time: 0.00409303492307663
Speedup: 3.4356864200147355
TF GFLOPS: 225.32425648995533
NP GFLOPS: 65.58347559815687
