In [1]:
import torch

class LinearNoiseScheduler():
    """DDPM noise scheduler with a linear beta schedule.

    Pre-computes the schedule quantities used by the forward (noising) and
    reverse (sampling) processes.

    Args:
        num_steps: number of diffusion timesteps T; valid timesteps are 0..T-1.
        beta_start: beta at the first timestep.
        beta_end: beta at the last timestep.
    """

    def __init__(self, num_steps, beta_start, beta_end):
        self.num_steps = num_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        # not read by any method of this class
        self.step = 0

        # pre-compute alphas and betas; betas are evenly spaced in [beta_start, beta_end]
        self.betas = torch.linspace(beta_start, beta_end, num_steps)
        self.alphas = 1 - self.betas
        # \bar{\alpha}_t = \prod_{s=0}^{t} \alpha_s
        self.alpha_cum_prod = torch.cumprod(self.alphas, 0)
        # \sqrt{\bar{\alpha}_t}
        self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
        # \sqrt{1-\bar{\alpha}_t}
        self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)

    @staticmethod
    def _coeff_for_batch(values, t, reference):
        """Return values[t] shaped (B, 1, ..., 1) so it broadcasts against `reference`.

        `t` may be a Python int / 0-d tensor (one timestep for the whole batch)
        or a 1-D tensor holding one timestep per batch element.
        """
        batch_size = reference.shape[0]
        # move both the index and the schedule to the data's device
        t = torch.as_tensor(t, device=reference.device)
        coeff = values.to(reference.device)[t].reshape(-1)
        # a single timestep is shared by every sample in the batch
        coeff = coeff.expand(batch_size)
        return coeff.reshape((batch_size,) + (1,) * (reference.dim() - 1))

    # forward process
    def add_noise(self, original, noise, t):
        r"""Sample x_t ~ q(x_t | x_0) in closed form.

        Args:
            original: clean batch x_0; dim 0 is the batch dimension.
            noise: epsilon, same shape as `original`.
            t: timestep(s) -- an int / 0-d tensor, or a 1-D tensor whose length
               equals the batch size.

        Returns:
            \sqrt{\bar{\alpha}_t} * x_0 + \sqrt{1-\bar{\alpha}_t} * \epsilon,
            with the same shape as `original`.
        """
        sqrt_alpha_cum_prod = self._coeff_for_batch(self.sqrt_alpha_cum_prod, t, original)
        sqrt_one_minus_alpha_cum_prod = self._coeff_for_batch(
            self.sqrt_one_minus_alpha_cum_prod, t, original
        )

        # \sqrt{\bar{\alpha}_t} * x_0 + \sqrt{1-\bar{\alpha}_t} * \epsilon
        return sqrt_alpha_cum_prod * original + sqrt_one_minus_alpha_cum_prod * noise

    def sample_prev_timestep(self, xt, noise_pred, t):
        r"""One reverse step: sample x_{t-1} given x_t and the predicted noise.

        Args:
            xt: current noisy batch x_t.
            noise_pred: predicted epsilon, same shape as `xt`.
            t: a single timestep for the whole batch (int or 0-d tensor); it is
               compared with ``t == 0`` below, so it must be a scalar.

        Returns:
            (x_{t-1}, x0) where x0 is the predicted clean sample clamped to
            [-1, 1]. At t == 0 the mean is returned without added noise.
        """
        # x0 = (xt - \sqrt{1-\bar{\alpha}_t} * \epsilon_t) / \sqrt{\bar{\alpha}_t}
        x0 = (
            xt - self.sqrt_one_minus_alpha_cum_prod[t] * noise_pred
        ) / self.sqrt_alpha_cum_prod[t]
        x0 = torch.clamp(x0, -1, 1)

        # \mu = (xt - \beta_t / \sqrt{1-\bar{\alpha}_t} * \epsilon_t) / \sqrt{\alpha_t}
        mean = xt - (self.betas[t] * noise_pred) / self.sqrt_one_minus_alpha_cum_prod[t]
        mean = mean / torch.sqrt(self.alphas[t])

        if t == 0:
            return mean, x0

        # posterior variance: (1-\bar{\alpha}_{t-1}) / (1-\bar{\alpha}_t) * \beta_t
        variance = (1 - self.alpha_cum_prod[t - 1]) / (1 - self.alpha_cum_prod[t])
        variance = variance * self.betas[t]
        sigma = torch.sqrt(variance)
        # sample from Gaussian distribution (same shape/dtype/device as xt)
        z = torch.randn_like(xt)
        return mean + sigma * z, x0

## UNet model

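* Time steps are encoded with a sinusoidal position embedding, as in the Transformer. Here $pos$ is the time step, $i$ indexes the frequency, and $d_{model}$ is the embedding size:

$$PE_{(pos,\,2i)} = \sin\left(pos / 10000^{2i / d_{model}}\right)$$
$$PE_{(pos,\,2i+1)} = \cos\left(pos / 10000^{2i / d_{model}}\right)$$

`get_time_embedding` below places all sine terms in the first half of the embedding vector and all cosine terms in the second half, rather than interleaving them.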

In [2]:
import torch
import torch.nn as nn

def get_time_embedding(time_steps, t_emb_dim):
    """Sinusoidal embedding of diffusion time steps.

    Args:
        time_steps: 1-D tensor of time steps, one per batch element.
        t_emb_dim: embedding size; the output has 2 * (t_emb_dim // 2) columns.

    Returns:
        Tensor of shape (len(time_steps), 2 * (t_emb_dim // 2)) whose first half
        holds sin(t / 10000^(k / half)) and whose second half holds the matching cos.
    """
    half_dim = t_emb_dim // 2
    # exponent k / half_dim for k = 0 .. half_dim-1, on the same device as the inputs
    exponents = torch.arange(0, half_dim, device=time_steps.device) / half_dim
    denominators = 10000 ** exponents

    angles = time_steps.unsqueeze(-1).repeat(1, half_dim) / denominators
    return torch.cat((angles.sin(), angles.cos()), dim=-1)


### Down-Block

In [4]:
class DownBlock(nn.Module):
    """UNet encoder stage: time-conditioned ResNet block, self-attention, optional 2x downsampling.

    Args:
        in_channels: channels of the input feature map (must be divisible by 8,
            the GroupNorm group count).
        out_channels: channels produced by the block (divisible by 8 and by num_heads).
        t_emb_dim: size of the time embedding passed to `forward`.
        down_sample: if truthy, halve H and W with a stride-2 4x4 convolution.
        num_heads: number of attention heads.
    """
    def __init__(
            self,
            in_channels,
            out_channels,
            t_emb_dim,
            down_sample,
            num_heads
        ):
        super().__init__()

        self.down_sample = down_sample
        self.resnet_conv_first = nn.Sequential(
            nn.GroupNorm(8, in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        )

        # projects the time embedding to one bias value per output channel
        self.t_emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(t_emb_dim, out_channels),
        )

        self.resnet_conv_second = nn.Sequential(
            nn.GroupNorm(8, out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1)
        )

        self.attn_norm = nn.GroupNorm(8, out_channels)
        self.attn = nn.MultiheadAttention(
            out_channels, num_heads, batch_first=True
        )
        # 1x1 conv so the skip connection matches out_channels
        self.resid_input_conv = nn.Conv2d(in_channels, out_channels, 1)
        self.down_sample_conv = nn.Conv2d(
            out_channels, out_channels, 4, 2, 1
        ) if self.down_sample else nn.Identity()

    def forward(self, x, t_emb):
        """
        Args:
            x: feature map of shape (B, in_channels, H, W).
            t_emb: time embedding whose last dim is t_emb_dim (e.g. (B, t_emb_dim)).

        Returns:
            (B, out_channels, H // 2, W // 2) when down_sample is set (for even
            H, W), otherwise (B, out_channels, H, W).
        """
        # ResNet block. All additions are out-of-place: in-place updates of a
        # tensor that a GroupNorm has already consumed break autograd.
        out = self.resnet_conv_first(x)
        out = out + self.t_emb_layers(t_emb).unsqueeze(-1).unsqueeze(-1)
        out = self.resnet_conv_second(out)
        out = out + self.resid_input_conv(x)

        # self-attention over the H*W spatial positions (tokens of size C)
        B, C, H, W = out.shape
        attn_in = self.attn_norm(out.reshape(B, C, H * W)).transpose(1, 2)
        attn_out, _ = self.attn(attn_in, attn_in, attn_in, need_weights=False)
        out = out + attn_out.transpose(1, 2).reshape(B, C, H, W)

        return self.down_sample_conv(out)


### Mid-Block

In [5]:
class MidBlock(nn.Module):
    """UNet bottleneck: ResNet block -> self-attention -> ResNet block, all time-conditioned.

    Spatial size is preserved.

    Args:
        in_channels: channels of the input feature map (divisible by 8).
        out_channels: channels produced by the block (divisible by 8 and by num_heads).
        t_emb_dim: size of the time embedding passed to `forward`.
        num_heads: number of attention heads.
    """
    def __init__(
            self,
            in_channels,
            out_channels,
            t_emb_dim,
            num_heads
        ):
        super().__init__()

        # index 0: first ResNet block (in -> out), index 1: second (out -> out)
        self.resnet_conv_first = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(8, in_channels),
                nn.ReLU(),
                nn.Conv2d(in_channels, out_channels, 3, 1, 1)
            ),
            nn.Sequential(
                nn.GroupNorm(8, out_channels),
                nn.ReLU(),
                nn.Conv2d(out_channels, out_channels, 3, 1, 1)
            )
        ])

        self.t_emb_layers = nn.ModuleList([
            nn.Sequential(
                nn.SiLU(),
                nn.Linear(t_emb_dim, out_channels),
            ),
            nn.Sequential(
                nn.SiLU(),
                nn.Linear(t_emb_dim, out_channels),
            )
        ])

        self.resnet_conv_second = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(8, out_channels),
                nn.ReLU(),
                nn.Conv2d(out_channels, out_channels, 3, 1, 1)
            ),
            nn.Sequential(
                nn.GroupNorm(8, out_channels),
                nn.ReLU(),
                nn.Conv2d(out_channels, out_channels, 3, 1, 1)
            )
        ])

        self.attn_norm = nn.GroupNorm(8, out_channels)
        self.attn = nn.MultiheadAttention(
            out_channels, num_heads, batch_first=True
        )
        self.resid_input_conv = nn.ModuleList([
            nn.Conv2d(in_channels, out_channels, 1),
            nn.Conv2d(out_channels, out_channels, 1)
        ])

    def _resnet_block(self, idx, x, t_emb):
        """Apply the idx-th time-conditioned ResNet block (out-of-place adds only)."""
        out = self.resnet_conv_first[idx](x)
        out = out + self.t_emb_layers[idx](t_emb).unsqueeze(-1).unsqueeze(-1)
        out = self.resnet_conv_second[idx](out)
        return out + self.resid_input_conv[idx](x)

    def forward(self, x, t_emb):
        """
        Args:
            x: feature map of shape (B, in_channels, H, W).
            t_emb: time embedding whose last dim is t_emb_dim (e.g. (B, t_emb_dim)).

        Returns:
            Tensor of shape (B, out_channels, H, W).
        """
        out = self._resnet_block(0, x, t_emb)

        # self-attention over the H*W spatial positions. The residual add is
        # out-of-place: updating `out` in place after attn_norm consumed a view
        # of it would invalidate the tensor GroupNorm saved for backward.
        B, C, H, W = out.shape
        attn_in = self.attn_norm(out.reshape(B, C, H * W)).transpose(1, 2)
        attn_out, _ = self.attn(attn_in, attn_in, attn_in, need_weights=False)
        out = out + attn_out.transpose(1, 2).reshape(B, C, H, W)

        return self._resnet_block(1, out, t_emb)



### Up-Block

In [6]:
class UpBlock(nn.Module):
    """UNet decoder stage: optional 2x upsampling, skip concat, time-conditioned ResNet block, self-attention.

    Args:
        in_channels: channels after concatenating the (upsampled) input with the
            skip tensor (divisible by 8). With up_sample, the input `x` must have
            in_channels // 2 channels.
        out_channels: channels produced by the block (divisible by 8 and by num_heads).
        t_emb_dim: size of the time embedding passed to `forward`.
        up_sample: if truthy, double H and W of `x` with a stride-2 4x4 transposed conv.
        num_heads: number of attention heads.
    """
    def __init__(
            self,
            in_channels,
            out_channels,
            t_emb_dim,
            up_sample,
            num_heads
        ):
        super().__init__()

        self.up_sample = up_sample
        self.resnet_conv_first = nn.Sequential(
            nn.GroupNorm(8, in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        )

        # projects the time embedding to one bias value per output channel
        self.t_emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(t_emb_dim, out_channels),
        )

        self.resnet_conv_second = nn.Sequential(
            nn.GroupNorm(8, out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1)
        )

        self.attn_norm = nn.GroupNorm(8, out_channels)
        self.attn = nn.MultiheadAttention(
            out_channels, num_heads, batch_first=True
        )
        # 1x1 conv so the skip connection matches out_channels
        self.resid_input_conv = nn.Conv2d(in_channels, out_channels, 1)
        self.up_sample_conv = nn.ConvTranspose2d(
            in_channels // 2, in_channels // 2, 4, 2, 1
        ) if self.up_sample else nn.Identity()

    def forward(self, x, out_down, t_emb):
        """
        Args:
            x: decoder feature map (B, C_x, H, W); with up_sample C_x == in_channels // 2.
            out_down: encoder skip tensor whose spatial size matches `x` after
                upsampling and whose channels bring the concat to in_channels.
            t_emb: time embedding whose last dim is t_emb_dim (e.g. (B, t_emb_dim)).

        Returns:
            Tensor of shape (B, out_channels, H', W') where H', W' is the spatial
            size of `out_down`.
        """
        x = self.up_sample_conv(x)
        x = torch.cat([x, out_down], dim=1)

        # ResNet block. All additions are out-of-place: in-place updates of a
        # tensor that a GroupNorm has already consumed break autograd.
        out = self.resnet_conv_first(x)
        out = out + self.t_emb_layers(t_emb).unsqueeze(-1).unsqueeze(-1)
        out = self.resnet_conv_second(out)
        out = out + self.resid_input_conv(x)

        # self-attention over the H*W spatial positions (tokens of size C)
        B, C, H, W = out.shape
        attn_in = self.attn_norm(out.reshape(B, C, H * W)).transpose(1, 2)
        attn_out, _ = self.attn(attn_in, attn_in, attn_in, need_weights=False)
        out = out + attn_out.transpose(1, 2).reshape(B, C, H, W)

        return out