In [3]:
import numpy as np
import math
import torch
from torch import nn
from torch.nn import functional as F

In [4]:
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> torch.Tensor:
    """
    Create sinusoidal timestep embeddings, following the scheme used in
    Denoising Diffusion Probabilistic Models.

    For ``half_dim = embedding_dim // 2`` the frequencies are
    ``freq[i] = max_period ** (-i / (half_dim - downscale_freq_shift))`` for
    ``i = 0 .. half_dim - 1``. Each timestep ``t`` is mapped to
    ``[sin(scale * t * freq), cos(scale * t * freq)]`` (halves concatenated,
    not interleaved).

    Args
        timesteps (torch.Tensor):
            a 1-D Tensor of N indices, one per batch element. These may be fractional.
            The values are cast to float32 before use.
        embedding_dim (int):
            the dimension of the output. If odd, the last column is zero padding.
        flip_sin_to_cos (bool):
            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
        downscale_freq_shift (float):
            Controls the delta between frequencies between dimensions. With the
            default of 1 the lowest frequency is exactly ``1 / max_period``.
        scale (float):
            Factor multiplied into the angles *before* sin/cos are taken.
        max_period (int):
            Controls the longest period (i.e. the minimum frequency) of the embeddings.
    Returns
        torch.Tensor: an [N x embedding_dim] float32 Tensor of positional embeddings,
        on the same device as `timesteps`.
    Raises
        AssertionError: if `timesteps` is not 1-D.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    freq_index = torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = -math.log(max_period) * freq_index / (half_dim - downscale_freq_shift)
    # Index 0 always corresponds to exponent 0 (frequency 1). Pin it explicitly:
    # when half_dim == downscale_freq_shift (e.g. embedding_dim 2 or 3 with the
    # default shift) the division above evaluates 0 / 0 = NaN at that index,
    # which would turn the whole first sin/cos column into NaN. For any other
    # denominator the value at index 0 is already (+/-)0, so this is a no-op.
    exponent = torch.where(freq_index == 0, torch.zeros_like(exponent), exponent)

    freqs = torch.exp(exponent)
    # Outer product (N, 1) * (1, half_dim) -> (N, half_dim), then scale the angles.
    angles = scale * (timesteps[:, None].float() * freqs[None, :])

    sin_part, cos_part = torch.sin(angles), torch.cos(angles)
    # Default layout is [sin | cos]; flip_sin_to_cos swaps the two halves.
    halves = [cos_part, sin_part] if flip_sin_to_cos else [sin_part, cos_part]
    emb = torch.cat(halves, dim=-1)

    # An odd embedding_dim leaves one column unfilled: zero pad on the right.
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb

In [5]:
# Demo input: the ten consecutive float32 timesteps 0.0, 1.0, ..., 9.0.
timesteps = torch.arange(10, dtype=torch.float32)
timesteps

tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])

In [8]:
# Embed the demo timesteps into 10 dimensions with the default options:
# columns 0-4 hold the sines, columns 5-9 the cosines.
get_timestep_embedding(timesteps=timesteps, embedding_dim=10)

tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          1.0000e+00,  1.0000e+00,  1.0000e+00,  1.0000e+00,  1.0000e+00],
        [ 8.4147e-01,  9.9833e-02,  9.9998e-03,  1.0000e-03,  1.0000e-04,
          5.4030e-01,  9.9500e-01,  9.9995e-01,  1.0000e+00,  1.0000e+00],
        [ 9.0930e-01,  1.9867e-01,  1.9999e-02,  2.0000e-03,  2.0000e-04,
         -4.1615e-01,  9.8007e-01,  9.9980e-01,  1.0000e+00,  1.0000e+00],
        [ 1.4112e-01,  2.9552e-01,  2.9995e-02,  3.0000e-03,  3.0000e-04,
         -9.8999e-01,  9.5534e-01,  9.9955e-01,  1.0000e+00,  1.0000e+00],
        [-7.5680e-01,  3.8942e-01,  3.9989e-02,  4.0000e-03,  4.0000e-04,
         -6.5364e-01,  9.2106e-01,  9.9920e-01,  9.9999e-01,  1.0000e+00],
        [-9.5892e-01,  4.7943e-01,  4.9979e-02,  5.0000e-03,  5.0000e-04,
          2.8366e-01,  8.7758e-01,  9.9875e-01,  9.9999e-01,  1.0000e+00],
        [-2.7942e-01,  5.6464e-01,  5.9964e-02,  6.0000e-03,  6.0000e-04,
          9.6017e-01,  8.2534e-0

In [9]:

class PeriodicEncoding(nn.Module):
    """Sinusoidal encoding of a batch of scalars as an ``nn.Module``.

    The module stores ``n_dim // 2`` frequencies, spaced geometrically from 1
    down to ``1 / max_period``, in the non-trainable buffer ``freq``.
    """

    def __init__(
            self,
            n_dim: int,
            max_period: int = 10000,
    ):
        """
        Args
            n_dim (int):
                size of the produced encoding; must be even.
            max_period (int):
                longest period covered; the smallest frequency is ``1 / max_period``.
        """
        super().__init__()
        assert n_dim % 2 == 0

        n_freq = n_dim // 2
        # Evenly spaced ramp in [0, 1]; exponentiating it gives a geometric
        # progression of frequencies from 1 to 1 / max_period.
        ramp = torch.linspace(0, 1, steps=n_freq, dtype=torch.float32)
        # A buffer (not a Parameter): moves with .to()/.cuda() and is saved in
        # the state dict, but is never updated by an optimizer.
        self.register_buffer('freq', (-math.log(max_period) * ramp).exp())

    def forward(self, x: torch.Tensor):
        """Return ``[sin(x * freq), cos(x * freq)]`` concatenated on the last axis.

        A 1-D input of length N yields an (N, n_dim) tensor.
        """
        # (N, 1) * (1, n_freq) -> (N, n_freq) table of angles.
        angles = x[:, None] * self.freq[None, :]
        return torch.cat([angles.sin(), angles.cos()], dim=-1)

In [10]:
# Same demo through the module. linspace(0, 1, half_dim) equals
# arange(half_dim) / (half_dim - 1), so the result should match
# get_timestep_embedding(timesteps, embedding_dim=10) with its default shift of 1.
PeriodicEncoding(n_dim=10)(x=timesteps)

tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          1.0000e+00,  1.0000e+00,  1.0000e+00,  1.0000e+00,  1.0000e+00],
        [ 8.4147e-01,  9.9833e-02,  9.9998e-03,  1.0000e-03,  1.0000e-04,
          5.4030e-01,  9.9500e-01,  9.9995e-01,  1.0000e+00,  1.0000e+00],
        [ 9.0930e-01,  1.9867e-01,  1.9999e-02,  2.0000e-03,  2.0000e-04,
         -4.1615e-01,  9.8007e-01,  9.9980e-01,  1.0000e+00,  1.0000e+00],
        [ 1.4112e-01,  2.9552e-01,  2.9995e-02,  3.0000e-03,  3.0000e-04,
         -9.8999e-01,  9.5534e-01,  9.9955e-01,  1.0000e+00,  1.0000e+00],
        [-7.5680e-01,  3.8942e-01,  3.9989e-02,  4.0000e-03,  4.0000e-04,
         -6.5364e-01,  9.2106e-01,  9.9920e-01,  9.9999e-01,  1.0000e+00],
        [-9.5892e-01,  4.7943e-01,  4.9979e-02,  5.0000e-03,  5.0000e-04,
          2.8366e-01,  8.7758e-01,  9.9875e-01,  9.9999e-01,  1.0000e+00],
        [-2.7942e-01,  5.6464e-01,  5.9964e-02,  6.0000e-03,  6.0000e-04,
          9.6017e-01,  8.2534e-0