In [18]:
# test if old and new add are the same
# this is concerning correspondence with 
import torch

In [19]:
# contributed by @yannadani.
def one_hot_add(inputs, shift):
    """Performs (inputs + shift) % vocab_size in the one-hot space.
    Args:
        inputs: Tensor of shape `[..., vocab_size]`. Typically a soft/hard one-hot
        Tensor.
        shift: Tensor of shape `[..., vocab_size]`. Typically a soft/hard one-hot
        Tensor specifying how much to shift the corresponding one-hot vector in
        inputs. Soft values perform a "weighted shift": for example,
        shift=[0.2, 0.3, 0.5] performs a linear combination of 0.2 * shifting by
        zero; 0.3 * shifting by one; and 0.5 * shifting by two.
    Returns:
        Tensor of same shape and dtype as inputs.
    """
    inputs = torch.stack((inputs, torch.zeros_like(inputs)), dim = -1)
    shift = torch.stack((shift, torch.zeros_like(shift)), dim = -1)
    inputs_fft = torch.fft(inputs, 1) #ignore last and first dimension to do batched fft
    shift_fft = torch.fft(shift, 1)
    result_fft_real = inputs_fft[...,0]*shift_fft[...,0] - inputs_fft[...,1]*shift_fft[...,1]
    result_fft_imag = inputs_fft[...,0]*shift_fft[...,1] + inputs_fft[...,1]*shift_fft[...,0]
    result_fft = torch.stack((result_fft_real,result_fft_imag), dim = -1)
    return torch.ifft(result_fft, 1)[...,0] #return only the real part

def one_hot_add_old(inputs, shift):
    """Performs (inputs - shift) % vocab_size in the one-hot space.
    Args:
        inputs: Tensor of shape `[..., vocab_size]`. Typically a soft/hard one-hot
        Tensor.
        shift: Tensor of shape `[..., vocab_size]`. Typically a soft/hard one-hot
        Tensor specifying how much to shift the corresponding one-hot vector in
        inputs. Soft values perform a "weighted shift": for example,
        shift=[0.2, 0.3, 0.5] performs a linear combination of 0.2 * shifting by
        zero; 0.3 * shifting by one; and 0.5 * shifting by two.
    Returns:
        Tensor of same shape and dtype as inputs.
    """
    shift = shift.type(inputs.dtype)
    vocab_size = inputs.shape[-1]
    # Form a [..., vocab_size, vocab_size] matrix. Each batch element of
    # inputs will vector-matrix multiply the vocab_size x vocab_size matrix. This
    # "shifts" the inputs batch element by the corresponding shift batch element.
    shift_matrix = torch.stack([torch.roll(shift, i, dims=-1)
                            for i in range(vocab_size)], dim=-2)
    shift_matrix = torch.transpose(shift_matrix, -1, -2)
    outputs = torch.einsum('...v,...uv->...u', inputs, shift_matrix)
    return outputs

In [20]:
inputs = torch.Tensor([[0,0,1],[0,1,0], [1,0,0]])
shift = torch.Tensor([[0,0,1],[0,1,0], [0,0,1]])

In [23]:
%timeit disc_utils.one_hot_add_old(inputs, shift).long()

79.6 µs ± 1.53 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [22]:
%timeit disc_utils.one_hot_add(inputs, shift).long()

84.2 µs ± 611 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [11]:
disc_utils.one_hot_add_old(inputs, shift).long() == disc_utils.one_hot_add(inputs, shift).long()

tensor([[True, True, True],
        [True, True, True],
        [True, True, True]])

In [None]:
# Why are these soft shifts not working for either? 

In [41]:
inputs = torch.Tensor([[0,0,1],[0,1,0], [1,0,0]])
shift = torch.randn((3,3)).softmax(dim=-1)
print(inputs, shift)

tensor([[0., 0., 1.],
        [0., 1., 0.],
        [1., 0., 0.]]) tensor([[0.0449, 0.7862, 0.1688],
        [0.3378, 0.4390, 0.2232],
        [0.2466, 0.2959, 0.4575]])


In [42]:
disc_utils.one_hot_add_old(inputs, shift).long()

tensor([[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]])

In [43]:
disc_utils.one_hot_add(inputs, shift).long()

tensor([[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]])