In [1]:
import torch
import torch.nn as nn

In [None]:
class TimeEmbedding(nn.Module):
    """
        Sinusoidal embedding of timesteps.

        Maps timesteps of shape (B,) to a tensor of shape (B, embedding_dim) where, for
        i = 0 .. embedding_dim // 2 - 1:
            column 2i     = sin(t / 10000^(2i / embedding_dim))
            column 2i + 1 = cos(t / 10000^(2i / embedding_dim))

        Args:
            embedding_dim: size of the output embedding. If it is odd, the last column
                has no sin/cos partner and is left as zeros.
    """
    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim

    # (B, ) -----> (B, embedding_dim)
    def forward(self, x):
        # Column vector so every timestep broadcasts against every frequency.
        # reshape(-1, 1) also accepts a 0-dim (scalar) timestep, which len(x) could not.
        x = x.reshape(-1, 1)
        half = self.embedding_dim // 2
        embeddings = torch.zeros(size=(x.shape[0], self.embedding_dim), device=x.device)
        denominators = 10000 ** (2 * torch.arange(half, device=x.device) / self.embedding_dim)
        # Restrict the strided slices to the first 2 * half columns so an odd embedding_dim
        # does not produce a (half + 1)-wide slice that cannot hold `half` sin values.
        embeddings[:, 0:2 * half:2] = torch.sin(x / denominators)
        embeddings[:, 1:2 * half:2] = torch.cos(x / denominators)
        return embeddings

In [None]:
class WeightStandardizedConv2d(nn.Conv2d):
    """
        Conv2d with Weight Standardization.

        On every forward pass, each output filter of the kernel is normalized to zero mean and
        unit (population) variance across its (input-channel, height, width) dimensions, and the
        convolution is run with those standardized weights. The stored `self.weight` parameter
        itself is not modified.
    """
    def forward(self, x):
        weights = self.weight
        # Per-output-channel statistics (dim 0 kept), reduced over (C_in / groups, kH, kW);
        # keepdim=True so they broadcast back against the (C_out, C_in / groups, kH, kW) weight.
        w_mean = weights.mean(dim=[1, 2, 3], keepdim=True)
        # correction=0 -> biased / population variance.
        w_var = weights.var(dim=[1, 2, 3], keepdim=True, correction=0)
        # Larger eps for non-float32 inputs (presumably fp16/bf16) to keep the division stable.
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3

        weights = (weights - w_mean) / torch.sqrt(w_var + eps)

        # NOTE(review): F.conv2d is called directly, so a non-'zeros' padding_mode configured
        # on this layer is not applied here.
        return torch.nn.functional.conv2d(
            x,
            weights,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups
        )

In [None]:
class Block(nn.Module):
    """
        One convolutional stage:
            WeightStandardizedConv2d (3x3, padding=1, so spatial size is preserved)
            -> GroupNorm -> optional scale/shift modulation -> SiLU

        Args:
            in_channels: channels of the input feature map.
            out_channels: channels produced by the convolution.
            groups: number of GroupNorm groups (torch requires it to divide out_channels).
    """
    def __init__(self, in_channels, out_channels, groups=8):
        super().__init__()
        # Attribute names and registration order are part of the state_dict layout.
        self.proj = WeightStandardizedConv2d(in_channels, out_channels, 3, padding=1)
        self.norm = nn.GroupNorm(groups, out_channels)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        """
            Args:
                x: input feature map.
                scale_shift: optional (scale, shift) pair. When it is given and non-empty, the
                    normalized features h become h * (1 + scale) + shift before the activation.
        """
        h = self.norm(self.proj(x))
        if scale_shift:
            scale, shift = scale_shift
            h = h * (1 + scale) + shift
        return self.act(h)


In [None]:
class ResnetBlock(nn.Module):
    """
        Residual block: two Block layers (weight-standardized conv2d + GroupNorm + SiLU) plus a skip path.

        If `time_emb_dim` is given, an MLP (SiLU -> Linear) projects the time embedding to
        2 * out_channels values. These are reshaped to (B, 2 * out_channels, 1, 1), split into
        (scale, shift), and applied only inside the first Block; they broadcast over (H, W).
        The skip path is nn.Identity when in_channels == out_channels, otherwise a 1x1 conv.

        Args:
            in_channels: channels of the input x.
            out_channels: channels of the output.
            time_emb_dim: size of the time embedding passed to forward; None (or 0) disables
                the time modulation.
            groups: GroupNorm groups used inside each Block.
    """
    def __init__(self, in_channels, out_channels, time_emb_dim=None, groups=8):
        super().__init__()
        if time_emb_dim:
            self.mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_emb_dim, 2 * out_channels)
            )
        else:
            self.mlp = None

        self.block1 = Block(in_channels, out_channels, groups)
        self.block2 = Block(out_channels, out_channels, groups)

        if in_channels == out_channels:
            self.res_conv = nn.Identity()
        else:
            self.res_conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x, time_emb=None):
        """
            Args:
                x: input feature map with in_channels channels.
                time_emb: optional time embedding, assumed (B, time_emb_dim). It is ignored when
                    the block was built without time_emb_dim.
            Returns:
                block2(block1(x, scale_shift)) + res_conv(x), with out_channels channels.
        """
        scale_shift = None
        # Explicit `is not None` checks: bool() of a tensor with more than one element raises
        # RuntimeError, so a plain truthiness test cannot be used on a batch of embeddings.
        if self.mlp is not None and time_emb is not None:
            time_emb = self.mlp(time_emb)
            time_emb = time_emb.view(*time_emb.shape, 1, 1)  # (batch, 2 * out_channels, 1, 1)
            scale_shift = time_emb.chunk(2, dim=1)

        h = self.block1(x, scale_shift=scale_shift)  # conv + GroupNorm (+ time modulation) + SiLU; channels -> out_channels
        h = self.block2(h)                           # conv + GroupNorm + SiLU; no dimension change

        return h + self.res_conv(x)                  # residual connection

In [None]:
class PreGroupNorm(nn.Module):
    """
        Wrap a module so that GroupNorm runs on the input first: forward(x) = func(GroupNorm(x)).
        In this notebook the wrapped func is meant to be LinearAttention.

        Args:
            dim: number of channels of x (GroupNorm's num_channels).
            func: module applied after normalization.
            groups: number of GroupNorm groups; the default 1 normalizes across all channels jointly.
    """
    def __init__(self, dim, func, groups=1):
        super().__init__()
        # Attribute names and registration order are part of the state_dict layout.
        self.func = func
        self.group_norm = nn.GroupNorm(groups, dim)

    def forward(self, x):
        return self.func(self.group_norm(x))

In [None]:
class diffusion_unet(nn.Module):
    """
        U-Net backbone for a diffusion model.

        NOTE(review): this definition is unfinished in the notebook. Only steps 1-3 of the
        constructor are written, the down-sampling stage contains a literal placeholder
        (`................`, not valid Python), and no forward() is visible.
    """
    def __init__(self, init_dim, output_dim, dim_mults, resnet_block_groups, input_channel=3, time_mult = 4):
        """
            Args:
                init_dim: channel count after the initial conv; also the sinusoidal time-embedding size.
                output_dim: not used in the visible code yet (presumably the output channel count -- TODO confirm).
                dim_mults: per-level channel multipliers applied to init_dim.
                resnet_block_groups: GroupNorm groups passed to every ResnetBlock.
                input_channel: channels of the input image (default 3).
                time_mult: time-MLP width is time_mult * init_dim.
        """
        super().__init__()

        # step 1

        ## initial conv layer: (B, input_channel, H, W) ----------> (B, init_dim, H, W); 1x1 kernel so H, W are unchanged
        self.init_conv = nn.Conv2d(input_channel, init_dim, kernel_size=1)
        
        ## set up downsampling/upsampling channel dimension changes: [(init_dim, init_dim*dim_mults[0]), (init_dim*dim_mults[0], init_dim*dim_mults[1]), ...]
        ## e.g. dim_mults=(1, 2, 4) -> [(d, d), (d, 2d), (2d, 4d)] with d = init_dim
        dims = [init_dim] + [n * init_dim for n in dim_mults]
        dims = list(zip(dims[:-1], dims[1:]))

        # step 2: time embedding: (batch, ) --------> (batch, time_dim)
        # sinusoidal TimeEmbedding (size init_dim) followed by a Linear-GELU-Linear MLP
        time_dim = time_mult * init_dim

        self.time_embed = nn.Sequential(
            TimeEmbedding(init_dim),
            nn.Linear(init_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim,time_dim)
        )

        # step 3: downsampling
        self.down_layers = nn.ModuleList([])

        # NOTE(review): dim_out and is_last are not used yet -- presumably meant for the
        # unfinished down-sample layer at the end of each level.
        for i, (dim_in, dim_out) in enumerate(dims, 1):
            is_last = (i == len(dims))
            self.down_layers.append(
                nn.ModuleList([
                    ResnetBlock(dim_in, dim_in, time_emb_dim=time_dim, groups=resnet_block_groups), # residual(W_std_conv1+GN1+time_e, W_std_conv2+GN2)
                    # FIXME: `time_emd_dim` is a typo for `time_emb_dim`; ResnetBlock.__init__ will raise TypeError on it.
                    ResnetBlock(dim_in, dim_in, time_emd_dim=time_dim, groups=resnet_block_groups), # residual(W_std_conv1+GN1+time_e, W_std_conv2+GN2)
                    # NOTE(review): Residual and LinearAttention are not defined in the visible notebook.
                    Residual(PreGroupNorm(dim_in, LinearAttention(dim_in))),
                    # TODO: placeholder below is not valid Python; the remaining layer(s) of this level still need to be written.
                    ................
                ])
            )

