In [1]:
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput
from diffusers.models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
from einops import rearrange, repeat, reduce

In [2]:
block_out_channels = 4
act_fn = 'silu'
num_class_embeds = 5

In [11]:
sample_size = 4
channels = 2
batch = 3
sample = torch.randn(batch, channels, sample_size)
timestep = [3, 4, 5]
labels = torch.tensor([1, 2, 3])
class_cond = torch.tensor([1, 2, 3]).unsqueeze(-1)

In [5]:
tlabels = labels.repeat(16,1).T

# Class append embedding

In [12]:
class_cond.shape

torch.Size([3, 1])

In [14]:
class_cond = class_cond.view(batch, class_cond.shape[1], 1).expand(batch, class_cond.shape[1], sample_size)

In [17]:
class_cond

tensor([[[1, 1, 1, 1]],

        [[2, 2, 2, 2]],

        [[3, 3, 3, 3]]])

# Fourier +  time step class embedding

In [124]:
time_proj = GaussianFourierProjection(embedding_size=block_out_channels,
                                        set_W_to_weight=False,
                                        log=False,
                                        flip_sin_to_cos=False)
timestep_input_dim = 2 * block_out_channels
time_embed_dim = sample_size

In [125]:
class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

In [126]:
timesteps = torch.tensor(timestep)
timestep_embed = time_proj(timesteps)
timestep_embed.shape

torch.Size([3, 8])

In [127]:
timestep_embed = timestep_embed[..., None]
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]                                                 ])
timestep_embed.shape

torch.Size([3, 8, 4])

In [128]:
timestep_embed = timestep_embed.broadcast_to(sample.shape[:1] + timestep_embed.shape[1:])
timestep_embed.shape

torch.Size([3, 8, 4])

In [129]:
timestep_embed

tensor([[[-0.0751, -0.0751, -0.0751, -0.0751],
         [-0.5917, -0.5917, -0.5917, -0.5917],
         [-0.5586, -0.5586, -0.5586, -0.5586],
         [ 0.6599,  0.6599,  0.6599,  0.6599],
         [-0.9972, -0.9972, -0.9972, -0.9972],
         [-0.8062, -0.8062, -0.8062, -0.8062],
         [ 0.8294,  0.8294,  0.8294,  0.8294],
         [ 0.7513,  0.7513,  0.7513,  0.7513]],

        [[-0.9117, -0.9117, -0.9117, -0.9117],
         [ 0.2016,  0.2016,  0.2016,  0.2016],
         [-0.2541, -0.2541, -0.2541, -0.2541],
         [ 0.0861,  0.0861,  0.0861,  0.0861],
         [-0.4108, -0.4108, -0.4108, -0.4108],
         [-0.9795, -0.9795, -0.9795, -0.9795],
         [-0.9672, -0.9672, -0.9672, -0.9672],
         [-0.9963, -0.9963, -0.9963, -0.9963]],

        [[-0.7967, -0.7967, -0.7967, -0.7967],
         [ 0.8620,  0.8620,  0.8620,  0.8620],
         [ 0.8941,  0.8941,  0.8941,  0.8941],
         [-0.7791, -0.7791, -0.7791, -0.7791],
         [ 0.6043,  0.6043,  0.6043,  0.6043],
         

In [130]:
class_labels = time_proj(labels)
class_labels.shape

torch.Size([3, 8])

In [131]:
class_emb = class_embedding(class_labels.to(dtype=torch.float32))
class_emb.shape

torch.Size([3, 4])

In [132]:
class_emb = class_emb.unsqueeze(1)
class_emb.shape

torch.Size([3, 1, 4])

In [133]:
class_emb

tensor([[[ 0.2214,  0.0418,  0.5393,  0.5950]],

        [[ 0.4667, -0.1553,  0.1507,  0.3344]],

        [[ 0.2767, -0.1439,  0.2361,  0.3369]]], grad_fn=<UnsqueezeBackward0>)

In [134]:
timestep_embed = (timestep_embed + class_emb)


In [135]:
timestep_embed

tensor([[[ 0.1463, -0.0333,  0.4641,  0.5199],
         [-0.3703, -0.5499, -0.0524,  0.0033],
         [-0.3373, -0.5168, -0.0194,  0.0364],
         [ 0.8813,  0.7017,  1.1992,  1.2549],
         [-0.7758, -0.9554, -0.4579, -0.4022],
         [-0.5848, -0.7644, -0.2669, -0.2112],
         [ 1.0508,  0.8712,  1.3687,  1.4244],
         [ 0.9727,  0.7931,  1.2906,  1.3463]],

        [[-0.4450, -1.0670, -0.7610, -0.5774],
         [ 0.6683,  0.0463,  0.3523,  0.5359],
         [ 0.2126, -0.4094, -0.1033,  0.0803],
         [ 0.5528, -0.0692,  0.2369,  0.4205],
         [ 0.0559, -0.5661, -0.2601, -0.0765],
         [-0.5128, -1.1348, -0.8287, -0.6451],
         [-0.5005, -1.1225, -0.8164, -0.6328],
         [-0.5296, -1.1516, -0.8455, -0.6619]],

        [[-0.5201, -0.9406, -0.5606, -0.4598],
         [ 1.1386,  0.7181,  1.0981,  1.1989],
         [ 1.1708,  0.7502,  1.1303,  1.2311],
         [-0.5024, -0.9230, -0.5430, -0.4421],
         [ 0.8810,  0.4604,  0.8404,  0.9413],
         

# Fourier + timestep embedding + time step class embedding

In [10]:
time_proj = GaussianFourierProjection(embedding_size=block_out_channels,
                                        set_W_to_weight=False,
                                        log=False,
                                        flip_sin_to_cos=False)
timestep_input_dim = 2 * block_out_channels
time_embed_dim = sample_size

In [11]:
time_embed_dim = block_out_channels * 4
time_mlp = TimestepEmbedding(
    in_channels=timestep_input_dim,
    time_embed_dim=time_embed_dim,
    act_fn=act_fn,
)
time_embed_dim

16

In [12]:
class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

In [13]:
timesteps = torch.tensor(timestep)
timestep_embed = time_proj(timesteps)
timestep_embed.shape

torch.Size([3, 8])

In [14]:
timestep_embed = time_mlp(timestep_embed)
timestep_embed.shape

torch.Size([3, 16])

In [15]:
timestep_embed

tensor([[-0.0429,  0.4077, -0.2517,  0.1377, -0.2646, -0.1397,  0.1769, -0.2928,
         -0.2095, -0.3262, -0.1920,  0.1318,  0.2291,  0.0844,  0.3249,  0.1121],
        [-0.2860,  0.1787, -0.1658,  0.0799, -0.0692,  0.1741, -0.0674, -0.2295,
          0.0851, -0.2297, -0.2457,  0.2494,  0.2485,  0.2833, -0.2228,  0.0233],
        [ 0.0506,  0.1341, -0.1155,  0.0074, -0.1417, -0.0106, -0.1381, -0.1494,
         -0.1552,  0.1464, -0.3253,  0.0620, -0.1281,  0.2302,  0.0016,  0.2741]],
       grad_fn=<AddmmBackward0>)

In [18]:
class_labels = time_proj(labels)
class_labels.shape

torch.Size([3, 8])

In [149]:
class_emb = class_embedding(class_labels.to(dtype=torch.float32))
class_emb.shape

torch.Size([3, 16])

In [150]:
class_emb = class_emb
class_emb.shape

torch.Size([3, 16])

In [17]:
class_emb

NameError: name 'class_emb' is not defined

In [152]:
timestep_embed = (timestep_embed + class_emb)


In [153]:
timestep_embed

tensor([[-0.3112, -0.1256,  0.5594, -0.3711, -0.0022, -0.1895,  0.1036, -0.2802,
         -0.1124, -0.0380, -0.2356,  0.0416, -0.2704,  0.2143,  0.0297, -0.0095],
        [-0.4184,  0.0163,  0.3937, -0.2800,  0.0953, -0.4466,  0.2610, -0.1681,
         -0.0777,  0.0624, -0.1192, -0.3193, -0.2907,  0.1534,  0.1745, -0.0607],
        [-0.5161, -0.0913,  0.1761, -0.4073, -0.1938, -0.4360,  0.1556,  0.2373,
          0.1228,  0.3006,  0.0466, -0.1802, -0.1473, -0.0377,  0.4903, -0.4182]],
       grad_fn=<AddBackward0>)

# Fourier +  Embedding class embedding

In [235]:
time_proj = GaussianFourierProjection(embedding_size=12,
                                        set_W_to_weight=False,
                                        log=False,
                                        flip_sin_to_cos=False)
timestep_input_dim = 2 * block_out_channels
time_embed_dim = block_out_channels * 4

In [219]:
timestep_input_dim, time_embed_dim

(8, 16)

In [233]:
class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)

In [236]:
timesteps = torch.tensor(timestep)
timestep_embed = time_proj(timesteps)
timestep_embed.shape

torch.Size([3, 24])

In [222]:
timestep_embed = timestep_embed[..., None]
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]])
timestep_embed.shape                                  

torch.Size([3, 8, 4])

In [223]:
timestep_embed = timestep_embed.broadcast_to(sample.shape[:1] + timestep_embed.shape[1:])
timestep_embed.shape

torch.Size([3, 8, 4])

In [224]:
timestep_embed

tensor([[[ 0.9573,  0.9573,  0.9573,  0.9573],
         [-0.6163, -0.6163, -0.6163, -0.6163],
         [ 0.4987,  0.4987,  0.4987,  0.4987],
         [-0.9955, -0.9955, -0.9955, -0.9955],
         [-0.2892, -0.2892, -0.2892, -0.2892],
         [-0.7875, -0.7875, -0.7875, -0.7875],
         [ 0.8668,  0.8668,  0.8668,  0.8668],
         [ 0.0945,  0.0945,  0.0945,  0.0945]],

        [[-0.9912, -0.9912, -0.9912, -0.9912],
         [-0.9353, -0.9353, -0.9353, -0.9353],
         [ 0.6412,  0.6412,  0.6412,  0.6412],
         [ 0.7962,  0.7962,  0.7962,  0.7962],
         [-0.1321, -0.1321, -0.1321, -0.1321],
         [ 0.3539,  0.3539,  0.3539,  0.3539],
         [ 0.7674,  0.7674,  0.7674,  0.7674],
         [-0.6050, -0.6050, -0.6050, -0.6050]],

        [[ 0.8482,  0.8482,  0.8482,  0.8482],
         [ 0.0594,  0.0594,  0.0594,  0.0594],
         [ 0.7644,  0.7644,  0.7644,  0.7644],
         [-0.3578, -0.3578, -0.3578, -0.3578],
         [ 0.5297,  0.5297,  0.5297,  0.5297],
         

In [206]:
# class_labels = time_proj(labels)
# class_labels.shape

torch.Size([3, 8])

In [225]:
class_emb = class_embedding(labels)
class_emb.shape

torch.Size([3, 16])

In [226]:
class_emb

tensor([[-1.8699,  0.2417, -1.5484,  0.6525,  0.1199, -0.2426, -0.5914, -0.0531,
         -0.2414,  0.0256,  0.3370,  0.4558, -0.6298,  1.3408, -0.1176, -2.3916],
        [-0.0326, -0.2886,  0.2077,  0.1122, -0.8515, -2.5415, -0.6106,  0.8874,
          3.1955, -1.1452,  1.9092,  2.6204,  0.8997, -0.8625,  0.6959, -0.3227],
        [-1.1923, -0.3724,  1.1350,  0.8426, -1.8330,  0.0184, -2.0985,  0.6326,
          0.4346, -0.2235, -1.1300,  1.2645, -0.9645,  0.5853,  1.1197,  2.0987]],
       grad_fn=<EmbeddingBackward0>)

In [181]:
class_emb = class_emb.repeat([1, sample.shape[2]])
# .reshape(
#                         (sample.shape[0], 1, sample.shape[2]))
class_emb.shape

torch.Size([3, 32])

In [227]:
class_emb

tensor([[-1.8699,  0.2417, -1.5484,  0.6525,  0.1199, -0.2426, -0.5914, -0.0531,
         -0.2414,  0.0256,  0.3370,  0.4558, -0.6298,  1.3408, -0.1176, -2.3916],
        [-0.0326, -0.2886,  0.2077,  0.1122, -0.8515, -2.5415, -0.6106,  0.8874,
          3.1955, -1.1452,  1.9092,  2.6204,  0.8997, -0.8625,  0.6959, -0.3227],
        [-1.1923, -0.3724,  1.1350,  0.8426, -1.8330,  0.0184, -2.0985,  0.6326,
          0.4346, -0.2235, -1.1300,  1.2645, -0.9645,  0.5853,  1.1197,  2.0987]],
       grad_fn=<EmbeddingBackward0>)

In [228]:
timestep_embed = (timestep_embed + class_emb)


RuntimeError: The size of tensor a (4) must match the size of tensor b (16) at non-singleton dimension 2

In [None]:
timestep_embed

tensor([[[ 0.1463, -0.0333,  0.4641,  0.5199],
         [-0.3703, -0.5499, -0.0524,  0.0033],
         [-0.3373, -0.5168, -0.0194,  0.0364],
         [ 0.8813,  0.7017,  1.1992,  1.2549],
         [-0.7758, -0.9554, -0.4579, -0.4022],
         [-0.5848, -0.7644, -0.2669, -0.2112],
         [ 1.0508,  0.8712,  1.3687,  1.4244],
         [ 0.9727,  0.7931,  1.2906,  1.3463]],

        [[-0.4450, -1.0670, -0.7610, -0.5774],
         [ 0.6683,  0.0463,  0.3523,  0.5359],
         [ 0.2126, -0.4094, -0.1033,  0.0803],
         [ 0.5528, -0.0692,  0.2369,  0.4205],
         [ 0.0559, -0.5661, -0.2601, -0.0765],
         [-0.5128, -1.1348, -0.8287, -0.6451],
         [-0.5005, -1.1225, -0.8164, -0.6328],
         [-0.5296, -1.1516, -0.8455, -0.6619]],

        [[-0.5201, -0.9406, -0.5606, -0.4598],
         [ 1.1386,  0.7181,  1.0981,  1.1989],
         [ 1.1708,  0.7502,  1.1303,  1.2311],
         [-0.5024, -0.9230, -0.5430, -0.4421],
         [ 0.8810,  0.4604,  0.8404,  0.9413],
         

---

In [40]:
time_proj_t = Timesteps(block_out_channels, flip_sin_to_cos=False, downscale_freq_shift=0.0)
timestep_input_dim = block_out_channels

In [41]:
time_embed_dim = block_out_channels * 4
time_mlp = TimestepEmbedding(
                in_channels=timestep_input_dim,
                time_embed_dim=time_embed_dim,
                act_fn=act_fn,
            )

In [30]:
class_embedding_e = nn.Embedding(num_class_embeds,
                                    time_embed_dim)

In [31]:
class_embedding_t = TimestepEmbedding(timestep_input_dim,
                                                     time_embed_dim)

In [59]:
data = torch.randn(3, 2, 3)

In [61]:
data.shape[1:]

torch.Size([2, 3])

In [64]:
timesteps = [3, 4, 5]
ts = torch.tensor(timesteps)

In [69]:
tsp = time_proj_t(ts)
tsp.shape
# tp = time_proj_g(ts)

torch.Size([3, 4])

In [74]:
tspt = time_mlp(tsp)
tspt.shape

torch.Size([3, 16])

In [72]:
tspt = tsp.repeat([1, 1, data.shape[2]])
tspt.shape

torch.Size([1, 3, 12])

In [73]:
tspt = tspt.broadcast_to(data.shape[:1] + tspt.shape[1:])
tspt.shape

torch.Size([3, 3, 12])