In [2]:
import torch
import torch.nn.functional as F

In [3]:
A = torch.randn((262000, 2000)) / 2000 # generate A so that Ah ~ normal(0,1)
h = torch.randn((2000, 1))
b = torch.randn((262000, 1))

W_size = 256
N_size = 16000

In [21]:
def Softmax(A, h, b):
    L = torch.autograd.Variable(F.linear(h.transpose(0, 1), weight=A, bias=b[:, 0]))
    return F.softmax(L, dim=1).data.transpose(0, 1)

def SVDSoftmax(A, h, b, W_size, N_size, B, V_t):
    h = torch.mm(V_t,h)
    z = torch.zeros_like(b)
    torch.add(torch.mm(B[:, :W_size], h[:W_size]), b, out=z)

    top_k_ind = torch.topk(z, k=W_size, dim=0, sorted=False)[1][:,0]
    torch.add(torch.mm(B[top_k_ind], h), b[top_k_ind], out=z[top_k_ind])
    z_exp = torch.exp(z - torch.max(z))
    return z_exp / z_exp.sum()

In [5]:
U, S, V = torch.svd(A)
V_t = V.transpose(0, 1)
B = U * S

In [6]:
%timeit SVDSoftmax(A, h, b, W_size, N_size, B, V_t)

10 loops, best of 3: 33.9 ms per loop


In [8]:
%timeit Softmax(A, h, b)

10 loops, best of 3: 96.2 ms per loop


In [22]:
import numpy as np
prob_softmax = Softmax(A, h, b)
prob_svdsoftmax = SVDSoftmax(A, h, b, W_size, N_size, B, V_t)
np.linalg.norm(prob_svdsoftmax - prob_softmax) / np.linalg.norm(prob_softmax)

0.020496925