In [None]:
import torch
from torch import autograd

'''
    Implementation from:
        https://github.com/numenta/htmpapers/blob/master/arxiv/how_can_we_be_so_dense/src/pytorch/functions/k_winners.py
'''

class k_winners(autograd.Function):
  r'''
  A simple K-winner take all autograd function for creating layers with sparse
  output.

  For every row of a 2-D input, the k units with the largest (optionally
  boosted) activity keep their original value and every other unit is set to
  zero. Gradients flow back only through the winning units.

   .. note::
      Code adapted from this excellent tutorial:
      https://github.com/jcjohnson/pytorch-examples
  '''


  @staticmethod
  def forward(ctx, x, k, dutyCycles, boostStrength):
    r'''
        Use the boost strength to compute a boost factor for each unit represented
        in x. These factors are used to increase the impact of each unit to improve
        their chances of being chosen. This encourages participation of more columns
        in the learning process.

        The boosting function is a curve defined as:

            boostFactors = exp[-boostStrength * (dutyCycle - targetDensity)]

        Intuitively this means that units that have been active (i.e. in the
        top-k) at the target activation level have a boost factor of 1, meaning
        their activity is not boosted. Columns whose duty cycle drops too much
        below that of their neighbors are boosted depending on how infrequently
        they have been active. Units that have been active more than the target
        activation level have a boost factor below 1, meaning their activity is
        suppressed and they are less likely to be in the top-k.

        Note that we do not transmit the boosted values. We only use boosting to
        determine the winning units.

        The target activation density for each unit is k / number of units. The
        boostFactor depends on the dutyCycle via an exponential function:

                boostFactor
                    ^
                    |
                    |\
                    | \
                1 _ |  \
                    |    _
                    |      _ _
                    |          _ _ _ _
                    +--------------------> dutyCycle
                    |
                targetDensity

        :param ctx:
            Context object that can be used to stash information for
            backward computation.

        :param x:
            Current activity of each unit, as a 2-D tensor indexed
            (sample, unit).

        :param k:
            The activity of the top k units will be allowed to remain,
            the rest are set to zero.

        :param dutyCycles:
            The averaged duty cycle of each unit. Must broadcast against x.
            Only read when boostStrength > 0.

        :param boostStrength:
            A boost strength of 0.0 has no effect on x.

        :return:
            A tensor representing the activity of x after k-winner take all.
    '''
    if boostStrength > 0.0:
      targetDensity = float(k) / x.size(1)
      boostFactors = torch.exp((targetDensity - dutyCycles) * boostStrength)
      boosted = x.detach() * boostFactors
    else:
      boosted = x.detach()

    # Pick the winners on the boosted activity, but emit the *unboosted* values
    # of x at those positions; every other position stays zero.
    _, indices = boosted.topk(k, dim=1, sorted=False)
    res = torch.zeros_like(x)
    res.scatter_(1, indices, x.gather(1, indices))

    ctx.save_for_backward(indices)
    return res


  @staticmethod
  def backward(ctx, grad_output):
    '''
        In the backward pass, the incoming gradient is passed through unchanged
        for the winning units and set to 0 for all the others.

        :param ctx:
            Context object holding the winner indices saved by forward.

        :param grad_output:
            Gradient of the loss with respect to the output of forward.

        :return:
            The gradient with respect to x, followed by None for each of the
            non-differentiable arguments (k, dutyCycles, boostStrength).
    '''
    indices, = ctx.saved_tensors

    # The buffer is a plain tensor: marking it requires_grad=True would make it
    # a leaf, and in-place writes to such a leaf fail when backward runs with
    # grad mode enabled (e.g. create_graph=True).
    grad_x = torch.zeros_like(grad_output)
    grad_x.scatter_(1, indices, grad_output.gather(1, indices))

    # One gradient per forward input.
    return grad_x, None, None, None

In [None]:
dtype = torch.float
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# cell runs on any machine.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# torch.backends.cuda.matmul.allow_tf32 = False  # Uncomment this to run on GPU

# The commented-out line above disables TensorFloat32. This is a feature that
# allows networks to run at a much faster speed while sacrificing precision.
# Although TensorFloat32 works well on most real models, for our toy model
# in this tutorial, the sacrificed precision causes convergence issue.
# For more information, see:
# https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold input and outputs.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Create random Tensors for weights.
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

learningIterations = 0   # number of samples seen so far
dutyCyclePeriod = 1000   # window (in samples) of the duty-cycle running average
k = 25                   # number of hidden units kept active per sample
boostStrength = 1.0      # > 0 enables duty-cycle boosting in k_winners

# Running average of how often each hidden unit is active.
dutyCycle = torch.zeros(H, device=device, dtype=dtype)

# To apply our Function, we use the Function.apply method. The alias keeps its
# historical name 'relu', although the operation is k-winners-take-all.
relu = k_winners.apply

learning_rate = 1e-6
for t in range(20):
    # Forward pass: compute predicted y using operations; the hidden
    # non-linearity is our custom autograd operation. Arguments follow the
    # forward signature: (x, k, dutyCycles, boostStrength).
    z = x.mm(w1)
    hidden = relu(z, k, dutyCycle, boostStrength)
    y_pred = hidden.mm(w2)

    # Compute and print loss (every 5th iteration of this short run).
    loss = (y_pred - y).pow(2).sum()
    if t % 5 == 4:
        print(t, loss.item())

    # Use autograd to compute the backward pass.
    loss.backward()

    # Update weights using gradient descent
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        # Manually zero the gradients after updating weights
        w1.grad.zero_()
        w2.grad.zero_()

    # Update the duty cycle as a running average over at most dutyCyclePeriod
    # samples. Activity is measured on the k-winners output (not on the raw
    # pre-activation z), so it reflects which units actually won.
    learningIterations += N
    period = min(dutyCyclePeriod, learningIterations)
    dutyCycle.mul_(period - N)
    dutyCycle.add_(hidden.gt(0).sum(dim=0, dtype=torch.float))
    dutyCycle.div_(period)


In [None]:
# Display the current per-unit duty cycle tensor (whatever the most recently
# executed cell left in `dutyCycle`).
dutyCycle

In [None]:
# Stand-alone check of the duty-cycle running average: feed random batches and
# track, per column, the fraction of samples that are strictly positive.
# Entries of x are drawn from a standard normal, so each duty-cycle entry
# should hover around 0.5.
batch_size = 12
learningIterations = 0   # number of samples seen so far
dutyCyclePeriod = 1000   # window (in samples) of the running average
n = 10                   # number of units (columns)

dutyCycle = torch.zeros(n, device=device, dtype=dtype)
for i in range(1000):
    x = torch.randn(batch_size, n, device=device, dtype=dtype)

    # Same update rule as in the training loop: scale the old average back to
    # a count, add this batch's activations, and renormalise.
    learningIterations += batch_size
    period = min(dutyCyclePeriod, learningIterations)
    dutyCycle.mul_(period - batch_size)
    dutyCycle.add_(x.gt(0).sum(dim=0, dtype=torch.float))
    dutyCycle.div_(period)

    # Report progress occasionally instead of flooding the output.
    if i % 100 == 99:
        print(i, dutyCycle)

In [None]:
# Boolean-mask indexing flattens the selection, so this shows how many entries
# of x are strictly positive.
x[x>0].shape

In [None]:
# Element-wise mask of the strictly positive entries of x; it has the same
# shape as x.
a = x.gt(0)
a.shape

In [None]:
# Summing the mask over dim 0 collapses that dimension; show the resulting
# shape.
a.sum(dim=0, dtype=torch.float).shape

In [None]:
# Count of positive entries per column, restricted to the first three rows of
# the mask.
a[0:3].sum(dim=0)

In [None]:
# First three rows of the mask, for comparison with the column counts.
a[0:3]

In [None]:
# Per-column count of strictly positive entries of x, as floats. This is the
# same expression that is added to dutyCycle in the update loops.
x.gt(0).sum(dim=0, dtype=torch.float)

In [None]:
# Full element-wise mask of the strictly positive entries of x.
x.gt(0)