In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


In [5]:
class DiffusionEmbedding(nn.Module):
    """Sinusoidal diffusion-step embedding followed by a two-layer SiLU MLP.

    A fixed lookup table of shape ``(num_steps, embedding_dim)`` is built once
    and registered as a non-persistent buffer (it follows ``.to(device)`` but
    is not written to the ``state_dict``). ``forward`` selects the rows for the
    requested steps, projects them to ``shape[-1]`` values and broadcasts the
    result against ``shape``.

    Args:
        num_steps: Number of rows in the lookup table (valid step indices are
            ``0 .. num_steps - 1``).
        shape: Target shape of the output, e.g. the 4-tuple used in this
            notebook. Only ``shape[-1]`` determines layer sizes; the full
            shape is used for broadcasting in ``forward``.
        embedding_dim: Width of the sinusoidal table. Odd values are supported:
            the table is built with ``ceil(embedding_dim / 2)`` frequencies and
            the surplus cosine column is dropped.
        projection_dim: Hidden width of the MLP; defaults to ``embedding_dim``.
    """

    def __init__(self, num_steps, shape, embedding_dim=128, projection_dim=None):
        super().__init__()
        if projection_dim is None:
            projection_dim = embedding_dim
        # Use an integer frequency count (the previous ``embedding_dim / 2``
        # produced a float, which yields a table wider than ``embedding_dim``
        # whenever ``embedding_dim`` is odd), then trim to the exact width.
        half_dim = (embedding_dim + 1) // 2
        table = self._build_embedding(num_steps, half_dim)[:, :embedding_dim]
        self.register_buffer("embedding", table, persistent=False)
        self.shape = shape
        self.projection1 = nn.Linear(embedding_dim, projection_dim)
        self.projection2 = nn.Linear(projection_dim, shape[-1])

    def forward(self, diffusion_step):
        """Embed the given steps and broadcast them to ``self.shape``.

        Args:
            diffusion_step: Index (list or tensor of step indices) into the
                lookup table; in this notebook a list of ``shape[0]`` ints.

        Returns:
            Tensor obtained by broadcasting the projected embedding, reshaped
            to ``(N, 1, 1, shape[-1])``, against ``self.shape``.
        """
        x = self.embedding[diffusion_step]
        x = F.silu(self.projection1(x))
        x = F.silu(self.projection2(x))
        x = x.unsqueeze(1).unsqueeze(1)
        # Allocate the broadcast target on the same device as the projected
        # embedding; a bare ``torch.zeros(shape)`` lives on the CPU and breaks
        # as soon as the module is moved to another device.
        return torch.zeros(self.shape, device=x.device) + x

    def _build_embedding(self, num_steps, dim=64):
        """Return the ``(num_steps, 2 * dim)`` table ``[sin | cos]``.

        Frequencies are ``10 ** (k / (dim - 1) * 4)`` for ``k = 0 .. dim - 1``.
        """
        steps = torch.arange(num_steps).unsqueeze(1)  # (T, 1)
        exponents = torch.arange(dim, dtype=torch.get_default_dtype())
        # ``max(..., 1)`` avoids 0 / 0 -> NaN when a single frequency is requested.
        exponents = exponents / max(dim - 1, 1) * 4.0
        frequencies = (10.0 ** exponents).unsqueeze(0)  # (1, dim)
        table = steps * frequencies  # (T, dim)
        return torch.cat([torch.sin(table), torch.cos(table)], dim=1)  # (T, 2 * dim)

In [6]:
# Smoke test: embed diffusion steps 1 and 3 (one per batch entry) and check
# that the result is broadcast to the configured shape.
diffusion_embedding = DiffusionEmbedding(num_steps=100, shape=(2, 3, 4, 6))
embeddings = diffusion_embedding(diffusion_step=[1, 3])
print(embeddings.shape)

torch.Size([2, 3, 4, 6])


In [8]:
class TimeEmbedding(nn.Module):
    """Sinusoidal encoding of the time index followed by a learnable linear map.

    ``shape`` is unpacked as ``(b, l, f, e)``. The position encoded is the
    index along axis 1 (``l``); sine and cosine channels alternate along
    axis 2 (``f``), and the value is constant along axes 0 and 3 before the
    linear layer is applied over the last axis.

    NOTE(review): the frequency channels run along ``f`` here, whereas
    ``FeatureEmbedding`` alternates along ``e`` — confirm this is intended.

    Args:
        shape: 4-element shape ``(b, l, f, e)`` of the produced tensor.
        max_len: Base of the geometric frequency progression.
    """

    def __init__(self, shape, max_len=10000.0):
        super().__init__()
        self.shape = shape
        self.max_len = max_len
        self.learnable = nn.Linear(shape[-1], shape[-1])

    def forward(self):
        """Return the ``(b, l, f, e)`` embedding, built on the module's device."""
        b, l, f, e = self.shape
        weight = self.learnable.weight
        device = weight.device  # build where the linear layer lives, not on the CPU

        # Every entry starts as its own time index (broadcast over b, f, e).
        positions = torch.arange(l, device=device).view(1, l, 1, 1)
        pe = torch.zeros(self.shape, device=device) + positions

        # One frequency per (sin, cos) pair along the f axis: (ceil(f / 2), 1).
        exponents = torch.arange(0, f, 2, device=device) / f
        div_term = (1 / torch.pow(self.max_len, exponents)).unsqueeze(-1)

        pe[:, :, 0::2] = torch.sin(pe[:, :, 0::2] * div_term)
        # For odd f there is one cosine row fewer than sine rows; without the
        # slice the broadcast result does not fit the target and assignment fails.
        pe[:, :, 1::2] = torch.cos(pe[:, :, 1::2] * div_term[: f // 2])

        # Match the layer's dtype so the module also works after .double()/.half().
        return self.learnable(pe.to(weight.dtype))

In [9]:
# Only the shape of `a` is used; its random values never enter the embedding.
a = torch.randn((2, 3, 4, 6))

time_embedding = TimeEmbedding(a.shape)
# Call the module itself rather than `.forward()` so that nn.Module hooks run
# and the usage matches the other embedding cells.
embeddings = time_embedding()
print(embeddings.shape)

torch.Size([2, 3, 4, 6])


In [10]:
class FeatureEmbedding(nn.Module):
    """Sinusoidal encoding of the feature index followed by a learnable linear map.

    ``shape`` is unpacked as ``(b, l, f, e)``. The position encoded is the
    index along axis 2 (``f``); sine and cosine channels alternate along the
    last axis (``e``), and the value is constant along axes 0 and 1 before the
    linear layer is applied over the last axis.

    Args:
        shape: 4-element shape ``(b, l, f, e)`` of the produced tensor.
        max_len: Base of the geometric frequency progression.
    """

    def __init__(self, shape, max_len=10000.0):
        super().__init__()
        self.shape = shape
        self.max_len = max_len
        self.learnable = nn.Linear(shape[-1], shape[-1])

    def forward(self):
        """Return the ``(b, l, f, e)`` embedding, built on the module's device."""
        b, l, f, e = self.shape
        weight = self.learnable.weight
        device = weight.device  # build where the linear layer lives, not on the CPU

        # Every entry starts as its own feature index (broadcast over b, l, e).
        positions = torch.arange(f, device=device).view(1, 1, f, 1)
        pe = torch.zeros(self.shape, device=device) + positions

        # One frequency per (sin, cos) pair along the last axis: (ceil(e / 2),).
        exponents = torch.arange(0, e, 2, device=device) / e
        div_term = 1 / torch.pow(self.max_len, exponents)

        pe[:, :, :, 0::2] = torch.sin(pe[:, :, :, 0::2] * div_term)
        # For odd e there is one cosine column fewer than sine columns; without
        # the slice the last-axis sizes differ and the multiplication fails.
        pe[:, :, :, 1::2] = torch.cos(pe[:, :, :, 1::2] * div_term[: e // 2])

        # Match the layer's dtype so the module also works after .double()/.half().
        return self.learnable(pe.to(weight.dtype))

In [11]:
# Only the shape of `a` is consumed below; its values are irrelevant.
a = torch.randn((2, 3, 4, 6))

# Smoke test: the feature embedding must come back with the same shape as `a`.
feature_embedding = FeatureEmbedding(shape=a.shape)
embeddings = feature_embedding()
print(embeddings.shape)

torch.Size([2, 3, 4, 6])
