In [1]:
import torch.nn as nn
import torch
import card_embedding as ce
import filter_test_data as ftd

# Seed PyTorch's global RNG before any module is constructed (presumably so the
# randomly initialised embedding weights are the same on every run — confirm).
torch.manual_seed(1)
import timeit
def benchmark(func, *args, **kwargs):
    """Time a single call of ``func`` and report its peak CUDA memory usage.

    Args:
        func: Callable to measure.
        *args: Positional arguments forwarded to ``func`` unchanged.
        **kwargs: Keyword arguments forwarded to ``func`` unchanged.

    Returns:
        A tuple ``(output, seconds, peak_bytes)`` where ``output`` is whatever
        ``func`` returned, ``seconds`` is the elapsed wall-clock time measured
        with ``timeit.default_timer``, and ``peak_bytes`` is
        ``torch.cuda.max_memory_allocated()`` after the call, or ``0`` when
        CUDA is not available.
    """
    # Only touch the CUDA API when a device exists, so the helper also works
    # on CPU-only machines instead of raising.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        # Wait for previously queued GPU work so it is not billed to this
        # call, then restart peak-memory tracking.
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
    begin = timeit.default_timer()
    output = func(*args, **kwargs)
    if use_cuda:
        # Wait for the GPU work launched by ``func`` to finish before
        # stopping the clock.
        torch.cuda.synchronize()
    end = timeit.default_timer()
    peak_bytes = torch.cuda.max_memory_allocated() if use_cuda else 0
    return output, (end - begin), peak_bytes

# Size argument passed to both constructors below (presumably the embedding
# dimension: the tensors printed by the next cell have 8 columns).
d = 8
# Holder object that is handed to FilterEmbedding; created on the CUDA device.
shared = ce.SharedEmbeddingHolder(d, device='cuda')
filter_embedding_compiled = ce.FilterEmbedding(shared, d, device='cuda')
# NOTE(review): if this is the compile() inherited from torch.nn.Module, it only
# compiles the module's __call__ path; the later cells invoke .forward(...) and
# .forward_v2(...) directly, which would bypass the compiled code. Confirm
# against card_embedding.FilterEmbedding.
filter_embedding_compiled.compile()




In [2]:
# Input used by this and the following cell: the nested fixtures followed by
# the simple ones from filter_test_data.
test_data = ftd.test_data_nested + ftd.test_data_simple

# Print the result of each implementation on the same input, one after the
# other, so the two tensors can be compared.
for variant_fn in (filter_embedding_compiled.forward,
                   filter_embedding_compiled.forward_v2):
    print(variant_fn(test_data))

tensor([[-0.8465,  0.6920, -2.7417, -2.3558, -2.6807, -3.5117,  0.2203, -4.1612],
        [ 1.1584,  1.2318, -2.7416,  2.9254,  5.9676,  0.1013, -3.7345, -4.2022],
        [ 1.8089, -1.4140, -2.3255,  1.9536, -0.9993, -1.2065,  0.4727, -7.1884],
        [-0.7145, -0.7915,  0.5151,  4.8474,  1.4616, -0.1473, -1.1910, -6.7304],
        [ 0.8837, -2.8052,  0.7395,  1.1925, -0.8534, -1.7399, -2.2889, -4.4179],
        [-0.5561, -0.7399, -0.9056, -2.0540, -2.0976,  0.7589,  0.9070, -0.9450],
        [ 1.9092,  0.4985, -1.3569,  1.5612,  2.3272, -0.5008, -0.5739, -3.0624],
        [ 0.5776, -0.4001, -0.4550, -0.5947, -2.3788,  0.0537,  0.5271, -0.2673],
        [ 0.1509,  0.3151, -0.8073,  0.9548, -0.7206,  0.2526,  0.8241, -0.6592]],
       device='cuda:0', grad_fn=<StackBackward0>)
tensor([[-0.8465,  0.6920, -2.7417, -2.3558, -2.6807, -3.5117,  0.2203, -4.1612],
        [ 1.1584,  1.2318, -2.7416,  2.9254,  5.9676,  0.1013, -3.7345, -4.2022],
        [ 1.8089, -1.4140, -2.3255,  1.9536, -0

In [4]:
# Untimed passes through both implementations before measuring (presumably a
# warm-up to absorb one-off compilation / CUDA start-up cost).
for i in range(5):
    for variant_fn in (filter_embedding_compiled.forward,
                       filter_embedding_compiled.forward_v2):
        variant_fn(test_data)

# Number of timed repetitions, and running sums of the per-call seconds.
n = 1000
total_time = total_time_v2 = 0
for i in range(n):
    # Both variants are measured inside the same iteration, forward first.
    _, time, _ = benchmark(filter_embedding_compiled.forward, test_data)
    total_time += time
    _, time_v2, _ = benchmark(filter_embedding_compiled.forward_v2, test_data)
    total_time_v2 += time_v2

# Mean seconds per call for each variant, then the ratio forward / forward_v2.
print(
    f"Average time for forward: {total_time/n}",
    f"Average time for forward_v2: {total_time_v2/n}",
    f"Speedup: {total_time/total_time_v2}",
    sep="\n",
)


Average time for forward: 0.0023157882159284783
Average time for forward_v2: 0.002071467038935225
Speedup: 1.1179459640925975
