In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch.nn as nn
import math
import torch


class FourierFeatures(nn.Module):
    def __init__(self, in_features, out_features, std=1.):
        super().__init__()
        assert out_features % 2 == 0
        self.register_buffer('weight', torch.randn([out_features // 2, in_features]) * std)

    def forward(self, input):
        f = 2 * math.pi * input @ self.weight.T
        return torch.cat([f.cos(), f.sin()], dim=-1)


class Base2FourierFeatures(nn.Module):
    # jax to torch adaptation of VDM code
    # https://github.com/google-research/vdm/blob/main/model_vdm.py#L618
    def __init__(self, start=4, stop=8, step=1):
        self.start = start
        self.stop = stop
        self.step = step

    def forward(self, inputs):
        freqs = torch.arange(self.start, self.stop, self.step).to(dtype=inputs.dtype)

        # Create Base 2 Fourier features
        w = 2. ** freqs * 2 * math.pi
        w = torch.tile(w[None, :], (1, inputs.shape[-1]))

        # Compute features
        h = inputs.repeat(1, 1, len(freqs))
        h = w * h
        h = torch.concatenate([torch.sin(h), torch.cos(h)], axis=-1)
        return h

In [11]:
fourier_features = FourierFeatures(1, 128)
t = torch.randint(1, 1000, (100, 1))
temb = fourier_features(t)

In [12]:
print(t[0])
print(temb[0])

tensor([59])
tensor([-0.1310,  0.4321,  0.1970,  0.8801, -0.6391,  0.8497,  0.6323,  0.0497,
         0.6066,  0.9331,  0.9964, -0.6086, -0.8721, -0.8267, -0.2844, -0.3201,
        -0.2683, -0.1854,  0.9878,  0.7592, -0.0132,  0.1819,  0.2991,  0.3332,
         0.7246, -0.0344, -0.8963,  0.5551,  0.6776, -0.3859, -0.0868,  0.6080,
        -0.8535, -0.9583, -0.2037, -0.8667,  0.3768,  0.7629, -0.9520, -0.3909,
         0.9445, -0.7017,  0.9801,  0.2043,  0.7360,  0.9140, -0.7241,  0.5020,
         0.7858,  0.5664,  0.7246,  0.3241,  0.5446,  0.9582,  0.6491,  0.9756,
        -0.6356,  0.9341, -0.9754, -0.4616, -0.9167, -0.6897, -0.9992, -0.8050,
         0.9914,  0.9018,  0.9804, -0.4748, -0.7691, -0.5273, -0.7747,  0.9988,
        -0.7950, -0.3597,  0.0849, -0.7934,  0.4893, -0.5627, -0.9587, -0.9474,
        -0.9633, -0.9827, -0.1559,  0.6509, -0.9999,  0.9833,  0.9542, -0.9428,
         0.6892,  0.9994, -0.4435,  0.8318,  0.7354, -0.9226, -0.9962, -0.7939,
        -0.5212,  0.2858, -

In [15]:
start, stop, step = 4, 8, 1
N, L, C = 4, 128, 1024
inputs = torch.randn(N, L, C)
freqs = torch.arange(start, stop, step).to(dtype=inputs.dtype)
freqs

tensor([4., 5., 6., 7.])

In [17]:
w = 2. ** freqs * 2 * math.pi
print(w)
w = torch.tile(w[None, :], (1, inputs.shape[-1]))
print(w.shape)
print(w)


tensor([100.5310, 201.0619, 402.1239, 804.2477])
torch.Size([1, 4096])
tensor([[100.5310, 201.0619, 402.1239,  ..., 201.0619, 402.1239, 804.2477]])


In [30]:
h = inputs.repeat(1, 1, len(freqs)).shape

In [31]:
w * h

TypeError: only integer tensors of a single element can be converted to an index