In [1]:
import torch

In [2]:
def boost_max(x: torch.Tensor, num_groups: int = 512, iqr_factor: float = 1_000) -> bool:
    """Detect an outlier in ``x`` from the peak magnitude of each group.

    ``x`` is flattened into ``num_groups`` equally sized rows, each row is
    reduced to its largest absolute value, and the largest of those peaks is
    compared against the fence ``q3 + iqr_factor * (q3 - q1)``, where ``q1``
    and ``q3`` are the 25th and 75th percentiles of the row peaks.

    The bound and the peak are printed, followed by the verdict line.

    Args:
        x: Tensor to inspect. Its element count must be divisible by
            ``num_groups``; otherwise the reshape raises ``RuntimeError``.
        num_groups: Number of rows the tensor is split into (default 512).
        iqr_factor: Multiplier applied to the inter-quartile range when
            building the upper fence (default 1_000).

    Returns:
        ``True`` when a fault is reported, ``False`` otherwise.
    """
    # abs() already allocates a fresh tensor, so no explicit clone is needed;
    # reshape (unlike view) also accepts non-contiguous inputs.
    group_peaks = x.detach().abs().reshape(num_groups, -1).amax(dim=1)

    # torch.quantile only accepts float32/float64 input.
    if group_peaks.dtype not in (torch.float32, torch.float64):
        group_peaks = group_peaks.float()

    # The probabilities must share dtype and device with the input,
    # otherwise torch.quantile raises.
    probs = torch.tensor(
        [0.25, 0.75], dtype=group_peaks.dtype, device=group_peaks.device
    )
    q1, q3 = torch.quantile(group_peaks, probs)
    bound = q3 + iqr_factor * (q3 - q1)
    peak = group_peaks.max()
    print('bound: {}, max_value: {}'.format(bound, peak))

    # Comparisons against NaN are always False, so a NaN (or inf) peak would
    # slip through `peak > bound`; treat any non-finite peak as a fault.
    fault_found = bool(peak > bound) or not bool(torch.isfinite(peak))
    if fault_found:
        print('[FAULT DETECTED]')
    else:
        print('No fault detected')
    return fault_found

In [3]:
def boost_sum(x: torch.Tensor, num_groups: int = 512, iqr_factor: float = 1_000) -> bool:
    """Detect an outlier in ``x`` from the absolute sum of each group.

    ``x`` is flattened into ``num_groups`` equally sized rows, each row is
    summed and the absolute value of that sum is taken. The largest of these
    values is compared against the fence ``q3 + iqr_factor * (q3 - q1)``,
    where ``q1`` and ``q3`` are the 25th and 75th percentiles of the row
    values.

    The bound and the largest value are printed, followed by the verdict line.

    Args:
        x: Tensor to inspect. Its element count must be divisible by
            ``num_groups``; otherwise the reshape raises ``RuntimeError``.
        num_groups: Number of rows the tensor is split into (default 512).
        iqr_factor: Multiplier applied to the inter-quartile range when
            building the upper fence (default 1_000).

    Returns:
        ``True`` when a fault is reported, ``False`` otherwise.
    """
    # torch.quantile only accepts float32/float64, so accumulate directly in
    # a dtype it can consume (unchanged for float32/float64 inputs).
    if x.dtype in (torch.float32, torch.float64):
        acc_dtype = x.dtype
    else:
        acc_dtype = torch.float32

    # sum() allocates its own result, so no explicit clone is needed;
    # reshape (unlike view) also accepts non-contiguous inputs.
    group_sums = x.detach().reshape(num_groups, -1).sum(dim=1, dtype=acc_dtype).abs()

    # The probabilities must share dtype and device with the input,
    # otherwise torch.quantile raises.
    probs = torch.tensor(
        [0.25, 0.75], dtype=group_sums.dtype, device=group_sums.device
    )
    q1, q3 = torch.quantile(group_sums, probs)
    bound = q3 + iqr_factor * (q3 - q1)
    peak = group_sums.max()
    print('bound: {}, max_value: {}'.format(bound, peak))

    # Comparisons against NaN are always False, so a NaN (or inf) peak would
    # slip through `peak > bound`; treat any non-finite peak as a fault.
    fault_found = bool(peak > bound) or not bool(torch.isfinite(peak))
    if fault_found:
        print('[FAULT DETECTED]')
    else:
        print('No fault detected')
    return fault_found

In [5]:
def test(fault: float, type: str, size=(4096, 4096), seed=None):
    """Inject one faulty value into a random tensor and run both detectors.

    A standard-normal tensor is drawn; for ``'grads'`` it is scaled by 1e-8.
    Element ``[0, 0]`` is then overwritten with ``fault`` and the tensor is
    passed to ``boost_max`` and ``boost_sum``, whose reports are printed
    under a short header.

    Args:
        fault: Value written into element ``[0, 0]``.
        type: ``'weights'`` or ``'grads'``. The name shadows the builtin but
            is kept because it is part of the call signature.
        size: Shape of the random tensor (default ``(4096, 4096)``). It needs
            at least two dimensions, and its element count must suit the row
            grouping used by ``boost_max`` and ``boost_sum``.
        seed: Optional integer. When given, the tensor is drawn from a
            dedicated generator seeded with it, making the run reproducible;
            when ``None`` the global RNG is used.

    Raises:
        ValueError: If ``type`` is neither ``'weights'`` nor ``'grads'``.
    """
    # Validate before drawing so an invalid call never consumes RNG state.
    if type not in ('weights', 'grads'):
        raise ValueError("Type must be either 'weights' or 'grads'")

    generator = None if seed is None else torch.Generator().manual_seed(seed)
    tensor = torch.normal(mean=0, std=1, size=size, generator=generator)
    if type == 'grads':
        tensor = tensor * 1e-8

    original_value = tensor[0, 0].item()
    tensor[0, 0] = fault

    print('---------------------------------------------------------')
    print('fault injected in {}: from {} to {}'.format(type, original_value, fault))
    print('result for function max')
    boost_max(tensor)
    print('result for function sum')
    boost_sum(tensor)
    print('\n')

In [7]:
# Seed the global RNG so the numbers printed below are reproducible under
# Restart Kernel -> Run All.
torch.manual_seed(0)

# Sweep the injected fault over three magnitudes, running each one on both
# tensor kinds.
for fault in (1e+23, 1e+3, 2):
    for kind in ('weights', 'grads'):
        test(fault, kind)

---------------------------------------------------------
fault injected in weights: from 0.2096564918756485 to 1e+23
result for function max
bound: 343.3172302246094, max_value: 9.999999778196308e+22
[FAULT DETECTED]
result for function sum
bound: 151825.015625, max_value: 9.999999778196308e+22
[FAULT DETECTED]


---------------------------------------------------------
fault injected in grads: from -1.1620688411539959e-08 to 1e+23
result for function max
bound: 3.6666260712081566e-06, max_value: 9.999999778196308e+22
[FAULT DETECTED]
result for function sum
bound: 0.0015446824254468083, max_value: 9.999999778196308e+22
[FAULT DETECTED]


---------------------------------------------------------
fault injected in weights: from 2.2651472091674805 to 1000.0
result for function max
bound: 338.6874084472656, max_value: 1000.0
[FAULT DETECTED]
result for function sum
bound: 160468.875, max_value: 1104.6865234375
No fault detected


---------------------------------------------------------
