In [13]:
from typing import Union
from torch import Tensor
from torch import linalg as LA
import torch

In [20]:
def get_normal_tensors(x: Tensor, num_std: float = 1.5) -> Union[Tensor, None]:
    """Keep only the matrices in a batch whose matrix norm is not an outlier.

    Each slice ``x[i]`` is reduced to its matrix norm over dims 1 and 2
    (``torch.linalg.matrix_norm`` defaults to the Frobenius norm). A slice is
    kept when its norm lies within ``num_std`` sample standard deviations of
    the mean norm across the batch.

    Args:
        x: Batched tensor. Dims 1 and 2 are reduced as the matrix dims and the
            result is used as a mask over dim 0, so this presumably expects
            shape ``(N, rows, cols)`` -- confirm for other ranks.
        num_std: Half-width of the acceptance band, in sample standard
            deviations. Defaults to 1.5 (the previously hard-coded value).

    Returns:
        A new tensor holding the kept slices in their original order, or
        ``None`` when no slice is kept (including an empty batch).
    """
    gradients_norm = LA.matrix_norm(x, dim=(1, 2))
    count = gradients_norm.numel()

    if count == 0:
        return None
    if count == 1:
        # torch.std applies Bessel's correction (divides by N - 1), so a single
        # sample yields NaN; every comparison against NaN is False and the lone
        # slice would be dropped. One slice is trivially "normal" -- keep it.
        # clone() matches the copy semantics of boolean-mask indexing below.
        return x.clone()

    gradients_mean = torch.mean(gradients_norm)
    gradients_dev = torch.std(gradients_norm)
    band = gradients_dev * num_std

    # Inclusive band [mean - band, mean + band]; `&` is the idiomatic
    # element-wise AND for boolean masks.
    final_mask = (gradients_norm >= gradients_mean - band) & (
        gradients_norm <= gradients_mean + band
    )

    if not torch.any(final_mask):
        return None

    return x[final_mask]


# Demo: filter a batch of 100 random (10, 256) matrices. Seed the RNG so that
# Restart & Run All reproduces the same inputs and selection.
torch.manual_seed(0)
input_gradients = torch.rand((100, 10, 256))
selected_gradients = get_normal_tensors(input_gradients)

print(selected_gradients)
print(f"Input shape {input_gradients.size()}")
if selected_gradients is None:
    # get_normal_tensors returns None when nothing passes the filter;
    # calling .size() on None would raise AttributeError.
    print("Output shape: no tensors selected")
else:
    print(f"Output shape {selected_gradients.size()}")


tensor([[[0.6736, 0.7104, 0.4067,  ..., 0.7132, 0.8789, 0.7312],
         [0.6957, 0.6053, 0.5786,  ..., 0.7887, 0.0229, 0.3599],
         [0.6427, 0.4423, 0.3498,  ..., 0.5449, 0.4429, 0.1655],
         ...,
         [0.5649, 0.7931, 0.3568,  ..., 0.0497, 0.3231, 0.8504],
         [0.2210, 0.3201, 0.9285,  ..., 0.2219, 0.0550, 0.1944],
         [0.7691, 0.6161, 0.3649,  ..., 0.3164, 0.6924, 0.0797]],

        [[0.8193, 0.2650, 0.8584,  ..., 0.7130, 0.8624, 0.8869],
         [0.2401, 0.0912, 0.7904,  ..., 0.7126, 0.1617, 0.7804],
         [0.8993, 0.6937, 0.4602,  ..., 0.6567, 0.1363, 0.0419],
         ...,
         [0.3232, 0.4270, 0.4479,  ..., 0.4303, 0.2171, 0.0577],
         [0.7198, 0.1235, 0.8972,  ..., 0.6614, 0.5884, 0.9415],
         [0.1952, 0.0280, 0.5648,  ..., 0.0106, 0.6290, 0.6400]],

        [[0.7096, 0.5537, 0.4614,  ..., 0.4037, 0.9977, 0.0287],
         [0.7317, 0.3911, 0.7303,  ..., 0.2966, 0.0729, 0.0932],
         [0.8224, 0.3804, 0.1416,  ..., 0.2332, 0.7741, 0.