In [5]:
import torch
from typing import Callable
import time

In [9]:
x = torch.normal(mean=0, std=1, size=(10000, 10000)) * 1e-8
x[0][0] = 1e-23
# x[0][0] = 2

x = x.abs()
k = 6
mean = torch.mean(x)
std = torch.std(x)
min, max = torch.aminmax(x)
print('max: {}, upper bound: {}'.format(max, mean + k * std))
if max > mean + k * std:
    print('[FAULT DETECTED]')
else:
    print('No fault detected')

max: 5.709773986950495e-08, upper bound: 4.4150699096690005e-08
[FAULT DETECTED]


In [10]:
x = torch.normal(mean=0, std=1, size=(10000, 10000)) * 1e-8
x[0][0] = 1e-23
# x[0][0] = 2

x = x.abs()
group_size = 100
x = x.view(-1, group_size, group_size).amax(dim=(1,2))
print('shape: {}'.format(x.shape))
k = 6
mean = torch.mean(x)
std = torch.std(x)
min, max = torch.aminmax(x)
print('max: {}, upper bound: {}'.format(max, mean + k * std))
if max > mean + k * std:
    print('[FAULT DETECTED]')
else:
    print('No fault detected')

    # max: 1.000000071494503e+25, upper bound: 6.100000459174899e+24

shape: torch.Size([10000])
max: 5.66467939222548e-08, upper bound: 5.7984074430805777e-08
No fault detected


In [177]:
def chebyshev(x: torch.Tensor):
    x = x.abs()
    k = 6
    min, max = torch.aminmax(x)
    mean = torch.mean(x)
    std = torch.std(x)
    return max > mean + k * std

def twoBoundaries(x: torch.Tensor):
    x = x.abs()
    k = 6
    group_size = 100
    x = x.view(-1, group_size, group_size).mean(dim=(1,2))
    min, max = torch.aminmax(x)
    mean = torch.mean(x)
    std = torch.std(x)
    return max > mean + k * std
    


In [178]:
def test(x: torch.Tensor, f: Callable[[torch.Tensor], bool], iters: int):
    start_time = time.time()
    for i in range(iters):
        f(x)
    return time.time() - start_time

In [184]:
x = torch.normal(mean=0, std=1e-9, size=(20000, 20000)) * 1e-8
x[0][0] = 1e-30

result1 = test(x, chebyshev, 100)
result2 = test(x, twoBoundaries, 100)
print('time without optimization: {}'.format(result1))
print('time with optimization: {}'.format(result2))
print('overhead: {}%'.format((result1 / result2 - 1) * 100))

time without optimization: 54.331501483917236
time with optimization: 29.21265935897827
overhead: 85.98615352429013%
