In [1]:
import TensorFrost as tf
import numpy as np
import matplotlib.pyplot as plt
import time

tf.initialize(tf.opengl)

def matmul():
    A = tf.input([-1, -1], tf.float32)
    N, M = A.shape
    B = tf.input([M,  -1], tf.float32)
    K = B.shape[1]

    C = (tf.sin(A) @ tf.cos(B))**2.0

    return [C]

mmul = tf.compile(matmul)

TensorFrost module loaded!
matmul:
  Kernel count: 1
  Intermediate buffers: 0
  Host readbacks: 0
  Host writes: 0
  Lines of generated code: 424
  IR Compile time: 0.944300 ms
  Compiler time: 1385.295044 ms



In [2]:
all_kernels = tf.get_all_generated_kernels()
print("Generated kernels:")
for k in all_kernels:
    print(k)

Generated kernels:

#version 460

uint pcg(uint v) {
  uint state = v * 747796405u + 2891336453u;
  uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
  return (word >> 22u) ^ word;
}

float pcgf(uint v) {
  return float(pcg(v)) / float(0xffffffffu);
}

float asfloat(uint x) {
  return uintBitsToFloat(x);
}

uint asuint(float x) {
  return floatBitsToUint(x);
}

uint asuint(int x) {
  return uint(x);
}

uint asuint(uint x) {
  return x;
}

int asint(uint x) {
  return int(x);
}

uniform int var[32];
layout(std430, binding = 0) buffer buf_A {
  uint A_mem[];
};

float atomicAdd_A(int index, float val) {
	uint uval = floatBitsToUint(val);
	uint tmp0 = 0;
	uint tmp1 = 0;

	while (true) {
		tmp0 = atomicCompSwap(A_mem[index], tmp1, uval);
		if (tmp1 == tmp0) break;
		tmp1 = tmp0;
		uval = floatBitsToUint(val + uintBitsToFloat(tmp1));
	}

	return uintBitsToFloat(tmp1);
}

layout(std430, binding = 1) buffer buf_B {
  uint B_mem[];
};

float atomicAdd_B(int index, float val)

In [5]:
Anp = np.random.rand(4096, 4096).astype(np.float32)
Bnp = np.random.rand(4096, 4096).astype(np.float32)
A = tf.tensor(Anp)
B = tf.tensor(Bnp)

start = time.time()
repeat = 100
for i in range(repeat):
    C, = mmul(A, B)
Cnp = C.numpy
tf_time = (time.time() - start) / repeat


#compare to numpy
start = time.time()
for i in range(repeat):
    Cnp2 = (np.sin(Anp) @ np.cos(Bnp))**2.0
np_time = (time.time() - start) / repeat

Cerror = np.linalg.norm(Cnp - Cnp2) / np.linalg.norm(Cnp2)
print("Error:", Cerror)
print("TF Time:", tf_time)
print("NP Time:", np_time)
print("Speedup:", np_time / tf_time)

tf_flops = 2 * Anp.shape[0] * Anp.shape[1] * Bnp.shape[1] / tf_time
print("TF GFLOPS:", tf_flops / 1e9)
np_flops = 2 * Anp.shape[0] * Anp.shape[1] * Bnp.shape[1] / np_time
print("NP GFLOPS:", np_flops / 1e9)

Error: 2.0023163e-06
TF Time: 0.09573355436325073
NP Time: 0.37513524055480957
Speedup: 3.9185345519649153
TF GFLOPS: 1435.6403497827166
NP GFLOPS: 366.37174707642356
