In [16]:
import torch

# Step-up selection of t-statistics against precomputed |t| critical values.
def bh_tstat_torch_precomputed(t_stats: torch.Tensor, crit: torch.Tensor) -> torch.Tensor:
    """Return a keep-mask from a step-up rule on ``|t_stats|``.

    Statistics are ranked by absolute value, largest first. Rank ``k`` is a
    "hit" when ``|t|_(k) >= crit[k]`` (or ``>= crit`` for a scalar threshold).
    Every statistic ranked at or above the *last* hit is kept. This is the
    step-up structure of Benjamini-Hochberg when ``crit`` holds per-rank |t|
    thresholds. NOTE(review): assumes those thresholds are non-increasing in
    rank; confirm with whatever precomputes ``crit``.

    Args:
        t_stats: Test statistics, any shape. NaN entries are treated as the
            least significant tests and are never kept.
        crit: Critical value(s) for ``|t|``. Either a scalar (0-d tensor or
            Python number) or a 1-D tensor with one threshold per rank
            (index 0 = largest ``|t|``). It is moved to ``t_stats``'s device.

    Returns:
        Bool tensor shaped like ``t_stats``; ``True`` marks kept tests.
    """
    flat = t_stats.flatten()
    M = flat.numel()
    device = flat.device
    # Accept Python numbers and critical vectors living on another device.
    crit = torch.as_tensor(crit, device=device)

    abs_flat = flat.abs()
    if abs_flat.is_floating_point():
        # torch orders NaN as the largest value. A descending argsort would
        # place NaNs first, and the step-up rule would then keep them whenever
        # any lower rank hits. Rank them last instead so they are never kept.
        abs_flat = abs_flat.masked_fill(abs_flat.isnan(), float("-inf"))
    order = torch.argsort(abs_flat, descending=True)
    t_desc = abs_flat[order]

    hits = t_desc >= crit
    keep_flat = torch.zeros(M, dtype=torch.bool, device=device)
    if hits.any():
        # Step-up: the deepest hit rank decides; everything ranked above it is kept too.
        last = int(torch.nonzero(hits).max())
        keep_flat[order[: last + 1]] = True

    return keep_flat.view(t_stats.shape)

# Worked example on a tiny tensor. Names carry an `example_` prefix so this
# demo does not overwrite the notebook's own `t_stats` / `crit` (the In [11]
# cell displays a different, 3x4 `t_stats`).
example_t_stats = torch.tensor([
    [ 0.5, -2.0,  3.0],
    [ 1.1, -4.2,  0.1]
])

example_crit = torch.tensor(1.5)

# Run the mask function
example_mask = bh_tstat_torch_precomputed(example_t_stats, example_crit)

# Show the results
print("Original t_stats tensor:")
print(example_t_stats)
print("\nAbsolute flattened values and their original indices:")
example_abs_flat = example_t_stats.flatten().abs()
for idx, val in enumerate(example_abs_flat):
    print(f"Index {idx:>2}: {val.item():.2f}")

example_order = torch.argsort(example_abs_flat, descending=True)
print("\nIndices in descending abs order:", example_order.tolist())
print("Sorted abs values:", example_abs_flat[example_order].tolist())

print("\nMask (True = keep) reshaped to original shape:")
print(example_mask)

# With one scalar threshold the step-up rule reduces to a plain |t| >= crit test.
assert torch.equal(example_mask, example_t_stats.abs() >= example_crit)

Original t_stats tensor:
tensor([[ 0.5000, -2.0000,  3.0000],
        [ 1.1000, -4.2000,  0.1000]])

Absolute flattened values and their original indices:
Index  0: 0.50
Index  1: 2.00
Index  2: 3.00
Index  3: 1.10
Index  4: 4.20
Index  5: 0.10

Indices in descending abs order: [4, 2, 1, 3, 0, 5]
Sorted abs values: [4.199999809265137, 3.0, 2.0, 1.100000023841858, 0.5, 0.10000000149011612]

Mask (True = keep) reshaped to original shape:
tensor([[False,  True,  True],
        [False,  True, False]])


In [11]:
# NOTE(review): displays whatever `t_stats` is bound to when this cell runs.
# It was executed as In [11], i.e. before In [16] above. The 3x4 output saved
# below does not come from any assignment visible in this section. Confirm
# `t_stats` is defined upstream and not rebound by another cell; otherwise
# Restart & Run All will show a different value.
t_stats

tensor([[ 3.0820, -0.5869, -4.3576,  1.1369],
        [-2.1690, -2.7972,  0.8067,  1.6761],
        [-1.4385, -0.8067, -1.1933,  0.3641]])