In [1]:
import torch

acts = torch.Tensor(
    [[5,   4,   3],
     [1,   1,   1],
     [1,  10,   2],
     ]
    )
self_W_dec = torch.Tensor(
    [[1,2],
     [3,4],
     [5,6]]
    )
acts

tensor([[ 5.,  4.,  3.],
        [ 1.,  1.,  1.],
        [ 1., 10.,  2.]])

In [2]:
no_act = acts @ self_W_dec # No activation; same shape with input x.
no_act

tensor([[32., 44.],
        [ 9., 12.],
        [41., 54.]])

In [3]:
batch_size, h = acts.shape
batch_size, h

(3, 3)

In [4]:
torch.norm(self_W_dec, p=2, dim=-1)**2

tensor([ 5.0000, 25.0000, 61.0000])

In [5]:
torch.norm(self_W_dec, p=2, dim=-1)

tensor([2.2361, 5.0000, 7.8102])

In [6]:
dec_scaled_acts = (acts * torch.norm(self_W_dec, p=2, dim=-1)).pow(2)
print("If D (decoder) were unit vectors...")
dec_scaled_acts

If D (decoder) were unit vectors...


tensor([[ 125.0000,  400.0000,  549.0001],
        [   5.0000,   25.0000,   61.0000],
        [   5.0000, 2500.0000,  244.0000]])

In [7]:
max_indices = torch.argsort(dec_scaled_acts, dim=-1, descending=True)
max_indices

tensor([[2, 1, 0],
        [2, 1, 0],
        [1, 2, 0]])

In [8]:
cumulative_fa = torch.cumsum(torch.gather(dec_scaled_acts, -1, max_indices), dim=-1)
cumulative_fa[..., -1] = 10000 # Ensure the last cumulative norm is large enough
cumulative_fa

tensor([[  549.0001,   949.0001, 10000.0000],
        [   61.0000,    86.0000, 10000.0000],
        [ 2500.0000,  2744.0000, 10000.0000]])

In [9]:
# afa = 1
# afa = 70
# afa = 80
# afa = 900
# afa = 2000
# afa = 2600
# afa = 9999
afa = torch.norm(no_act, p=2, dim=1, keepdim=True).pow(2)
afa

tensor([[2960.0002],
        [ 225.0000],
        [4596.9995]])

In [10]:
cumulative_fa = torch.sqrt(cumulative_fa)
afa = torch.sqrt(afa)
cumulative_fa, afa

(tensor([[ 23.4307,  30.8058, 100.0000],
         [  7.8102,   9.2736, 100.0000],
         [ 50.0000,  52.3832, 100.0000]]),
 tensor([[54.4059],
         [15.0000],
         [67.8012]]))

In [11]:
# Basic way: if exceed
# print(cumulative_fa >= afa)
# k = (cumulative_fa >= afa).int().argmax(dim=-1) + 1

# Advance way: nearest index
k = torch.abs(cumulative_fa - afa).argmin(dim=-1) + 1

k

tensor([2, 2, 2])

In [12]:
# k = torch.clamp(k, max=2)  # when using parameter k (but, this will converge to TopK)
# k

In [13]:
acts_topk = acts * torch.zeros_like(acts, dtype=torch.bool).scatter_(
    dim=1,
    index=max_indices,
    src=torch.arange(acts.shape[1], device=acts.device).unsqueeze(0) < k.unsqueeze(1), # Mask of (batch_size, h); True for top-k positions
)
acts_topk


tensor([[ 0.,  4.,  3.],
        [ 0.,  1.,  1.],
        [ 0., 10.,  2.]])