In [2]:
import torch
import time

In [55]:
def test_mask(N: int, shape: tuple, device: str = 'cuda'):
    start = time.time()    
    input = torch.round(torch.randn(*shape))
    min_max = torch.tensor([0, 0])
    if device == 'cuda':
        input = input.cuda()
        min_max = min_max.cuda()
    for i in range(N):
        grad = torch.abs(input)
        mask = grad > 0
        min_max[0] = torch.min(grad[mask])
        min_max[1] = torch.max(grad[mask])
    return time.time() - start


def test_refactor_mask_filter(N: int, shape: tuple, device: str = 'cuda'):
    start = time.time()
    input = torch.round(torch.randn(*shape))
    min_max = torch.tensor([0, 0]).cuda()
    if device == 'cuda':
        input = input.cuda()
        min_max = min_max.cuda()
    for i in range(N):
        grad = torch.abs(input)
        min_max[1] = torch.max(grad)
        grad = torch.where(grad == 0, float('inf'), grad)
        min_max[0] = torch.min(grad)
    return time.time() - start


def test_refactor_mask_filter_mean(N: int, shape: tuple, device: str = 'cuda'):
    start = time.time()
    input = torch.round(torch.randn(*shape))
    min_max = torch.tensor([0, 0]).cuda()
    if device == 'cuda':
        input = input.cuda()
        min_max = min_max.cuda()
    for i in range(N):
        grad = torch.abs(input)
        grad = torch.where(grad == 0, ((grad[0][0] + grad[0][1])/2).item(), grad)
        min_max[0], min_max[0] = torch.aminmax(grad)
    return time.time() - start


def test_refactor_mask_math(N: int, shape: tuple, device: str = 'cuda'):
    start = time.time()
    input = torch.round(torch.randn(*shape))
    min_max = torch.tensor([0, 0]).cuda()
    min_denorm = torch.tensor([2 ** -133], dtype=torch.bfloat16)
    if device == 'cuda':
        input = input.cuda()
        min_max = min_max.cuda()
        min_denorm = min_denorm.cuda()
    for i in range(N):
        grad = torch.abs(input)
        min_max[1] = torch.max(grad)
        # grad.add_(min_denorm / grad)
        grad = grad + min_denorm / grad
        min_max[0] = torch.min(grad)
    return time.time() - start


def test_without_mask(N: int, shape: tuple, device: str = 'cuda'):
    start = time.time()    
    input = torch.round(torch.randn(*shape))
    min_max = torch.tensor([0,0]).cuda()
    if device == 'cuda':
        input = input.cuda()
        min_max = min_max.cuda()
    for i in range(N):
        grad = torch.abs(input)
        min_max[0] = torch.min(grad)
        min_max[1] = torch.max(grad)
    return time.time() - start

In [56]:
shape = (10000, 10000)
N = 10000
# NOTE(review): in document order this is the first cell that does CUDA work,
# so one-time CUDA warm-up costs may land in the first timed call on a fresh
# kernel; consider a short warm-up run before timing.
time_test_mask = test_refactor_mask_filter_mean(N, shape, device='cuda')
time_test_without_mask = test_without_mask(N, shape, device='cuda')
# Ratio of the variant's time to the unmasked baseline (1.0 = no overhead).
print('overhead of test_refactor_mask_filter_mean for GPU: ', time_test_mask / time_test_without_mask)

overhead of test_refactor_mask_filter_mean for GPU:  1.734465464374379


In [41]:
N = 100
shape = (10000, 10000)
# Ratio of masked time to the unmasked baseline (1.0 = no overhead).
time_test_mask = test_mask(N, shape, device='cpu')
time_test_without_mask = test_without_mask(N, shape, device='cpu')
print(
    'overhead of test_mask for CPU: ',
    time_test_mask / time_test_without_mask,
)

overhead of test_mask for CPU:  7.014615993037051


In [46]:
N = 100
shape = (10000, 10000)
# Ratio of masked time to the unmasked baseline (1.0 = no overhead).
time_test_mask = test_mask(N, shape, device='cuda')
time_test_without_mask = test_without_mask(N, shape, device='cuda')
print(
    'overhead of test_mask for GPU: ',
    time_test_mask / time_test_without_mask,
)

overhead of test_mask for GPU:  4.447599304294528


In [47]:
N = 100
shape = (10000, 10000)
# Ratio of the where(+inf) variant's time to the unmasked baseline.
time_test_mask = test_refactor_mask_filter(N, shape, device='cpu')
time_test_without_mask = test_without_mask(N, shape, device='cpu')
print(
    'overhead of test_refactor_mask_filter for CPU: ',
    time_test_mask / time_test_without_mask,
)

overhead of test_refactor_mask_filter for CPU:  2.246481991501602


In [52]:
N = 10000
shape = (10000, 10000)
# Ratio of the where(+inf) variant's time to the unmasked baseline.
time_test_mask = test_refactor_mask_filter(N, shape, device='cuda')
time_test_without_mask = test_without_mask(N, shape, device='cuda')
print(
    'overhead of test_refactor_mask_filter for GPU: ',
    time_test_mask / time_test_without_mask,
)

overhead of test_refactor_mask_filter for GPU:  1.9041356556827576


In [51]:
N = 100
shape = (10000, 10000)
# Ratio of the arithmetic (tiny / x) variant's time to the unmasked baseline.
time_test_mask = test_refactor_mask_math(N, shape, device='cpu')
time_test_without_mask = test_without_mask(N, shape, device='cpu')
print(
    'overhead of test_refactor_mask_math for CPU: ',
    time_test_mask / time_test_without_mask,
)

overhead of test_refactor_mask_math for CPU:  3.1817730339385304


In [53]:
N = 10000
shape = (10000, 10000)
# Ratio of the arithmetic (tiny / x) variant's time to the unmasked baseline.
time_test_mask = test_refactor_mask_math(N, shape, device='cuda')
time_test_without_mask = test_without_mask(N, shape, device='cuda')
print(
    'overhead of test_refactor_mask_math for GPU: ',
    time_test_mask / time_test_without_mask,
)

overhead of test_refactor_mask_math for GPU:  2.2570415979145313
