In [2]:
import torch
import torch.nn as nn

In [None]:
class TimeEmbedding(nn.Module):
    """Sinusoidal embedding of scalar timesteps.

    Maps a batch of timesteps to vectors of size ``embedding_dim`` using the
    transformer-style sinusoidal scheme. For frequency index ``i`` in
    ``[0, embedding_dim // 2)``, with ``d = embedding_dim``:

        column 2i     = sin(t / 10000 ** (2i / d))
        column 2i + 1 = cos(t / 10000 ** (2i / d))

    If ``embedding_dim`` is odd, the final column has no sin/cos partner and
    is left as zero.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim

    # (B, ) -----> (B, embedding_dim)
    def forward(self, x):
        """Embed a batch of timesteps.

        Args:
            x: tensor of timesteps, expected shape ``(B,)``; a 0-d tensor is
               treated as a batch of one. Integer or floating dtypes work.

        Returns:
            Tensor of shape ``(B, embedding_dim)`` on ``x.device`` (default
            float dtype).
        """
        # (B,) -> (B, 1) so it broadcasts against the (half,) frequency vector.
        t = x.reshape(-1, 1)
        half = self.embedding_dim // 2
        embeddings = torch.zeros(size=(t.shape[0], self.embedding_dim), device=x.device)
        denominators = 10000 ** (2 * torch.arange(half, device=x.device) / self.embedding_dim)
        angles = t / denominators  # (B, half)
        # Bound the slices to 2*half columns: with an odd embedding_dim, `::2`
        # would select one column more than `angles` has.
        embeddings[:, 0:2 * half:2] = torch.sin(angles)
        embeddings[:, 1:2 * half:2] = torch.cos(angles)
        return embeddings

In [None]:
class diffusion_unet(nn.Module):
    def __init__(self, init_dim, output_dim, dim_mults, resnet_block_groups, input_channel=3, time_mult = 4):
        super().__init__()

        # initial conv layer: (B, 3, H, W) ----------> (B, init_dim, H, W)
        self.init_conv = nn.Conv2d(input_channel, init_dim, kernel_size=1)
        
        # set up downsampling/upsampling channel dimension changes: [(init_dim, init_dim*dim_mults[0]), (init_dim, init_dim*dim_mults[1]), ...]
        dims = [init_dim] + [n * init_dim for n in dim_mults]
        dims = list(zip(dims[:-1], dims[1:]))

        # time embedding: (batch, ) --------> (batch, time_dim)
        time_dim = time_mult * init_dim
        self.time_embed = nn.Sequential(
            TimeEmbedding(init_dim),
            nn.Linear(init_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim,time_dim)
        )