In [1]:
import torch

class LinearNoiseScheduler():
    def __init__(self, num_timesteps, beta_start, beta_end):
        self.num_timesteps = num_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.step = 0

        # pre-compute alphas and betas
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1 - self.betas
        # \bar{\alpha}_t}
        self.alpha_cum_prod = torch.cumprod(self.alphas, 0)
        # \sqrt{\bar{\alpha}_t}}
        self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
        # \sqrt{1-\bar{\alpha}_t}}
        self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)

    # forward process
    def add_noise(self, original, noise, t):
        original_shape = original.shape
        batch_size = original_shape[0]

        sqrt_alph_cum_prod = self.sqrt_alpha_cum_prod[t].reshape(batch_size)
        sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod[t].reshape(batch_size)

        for _ in range(original.dim() - 1):
            sqrt_alph_cum_prod = sqrt_alph_cum_prod.unsqueeze(-1)
            sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1)

        # \sqrt{\bar{\alpha}_t}} * x_0 + (1-\sqrt{\bar{\alpha}_t}) * \epsilon_t
        return sqrt_alph_cum_prod * original + sqrt_one_minus_alpha_cum_prod * noise
    
    def sample_prev_timestep(self, xt, noise_pred, t):
        # x0 = (xt - \sqrt{1-\bar{\alpha}_t}} * \epsilon_t) / \sqrt{\bar{\alpha}_t}}
        x0 = (
            xt - self.sqrt_one_minus_alpha_cum_prod[t] * noise_pred
        ) / self.sqrt_alpha_cum_prod[t]

        x0 = torch.clamp(x0, -1, 1)

        mean = xt - (self.betas[t] * noise_pred) / self.sqrt_one_minus_alpha_cum_prod[t]
        mean = mean / torch.sqrt(self.alphas[t])

        if t == 0:
            return mean, x0

        variance = (1 - self.alpha_cum_prod[t-1]) / (1 - self.alpha_cum_prod[t])
        variance *= self.betas[t]
        sigma = torch.sqrt(variance)
        # sample from Gaussian distribution
        z = torch.randn(xt.shape).to(xt.device)
        return mean + sigma * z, x0

## UNet model

 * Using sinusoidal position embedding for time-embeddings

$$sin\left(pos / 10000^{2i / d_{model}}\right)$$
$$cos\left(pos / 10000^{2i+1 / d_{model}}\right)$$

In [2]:
import torch
import torch.nn as nn

def get_time_embedding(time_steps, t_emb_dim):
    device = time_steps.device
    factor = 10000 ** ((
        torch.arange(0, t_emb_dim//2, device=device) / (t_emb_dim // 2)
    ))

    t_emb = time_steps.unsqueeze(-1).repeat(1, t_emb_dim//2) / factor
    t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)

    return t_emb


### Down-block

In [3]:
class DownBlock(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            t_emb_dim,
            down_sample,
            num_heads
        ):
        super().__init__()

        self.down_sample = down_sample
        self.resnet_conv_first = nn.Sequential(
            nn.GroupNorm(8, in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        )

        self.t_emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(t_emb_dim, out_channels),
        )

        self.resnet_conv_second = nn.Sequential(
            nn.GroupNorm(8, out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1)
        )

        self.attn_norm = nn.GroupNorm(8, out_channels)
        self.attn = nn.MultiheadAttention(
            out_channels, num_heads, batch_first=True
        )
        self.resid_input_conv = nn.Conv2d(in_channels, out_channels, 1)
        self.down_sample_conv = nn.Conv2d(
            out_channels, out_channels, 4, 2, 1
        ) if self.down_sample else nn.Identity()

    def forward(self, x, t_emb):
        out = x
        # ResNet block 1
        resnet_input = out
        out = self.resnet_conv_first(out)
        # time embedding
        out += self.t_emb_layers(t_emb).unsqueeze(-1).unsqueeze(-1)
        # ResNet block 2
        out = self.resnet_conv_second(out)
        # residual
        out = out + self.resid_input_conv(resnet_input)

        # self-attention block
        B, C, H, W = out.shape
        input_for_attn = out.view(B, C, -1)
        input_for_attn = self.attn_norm(input_for_attn)
        input_for_attn = input_for_attn.transpose(1, 2)
        out_attn, _ = self.attn(input_for_attn, input_for_attn, input_for_attn)
        out_attn = out_attn.transpose(1, 2).view(B, C, H, W)
        out = out + out_attn

        out = self.down_sample_conv(out)
        return out


### Mid-Block

In [4]:
class MidBlock(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            t_emb_dim,
            num_heads
        ):
        super().__init__()

        self.resnet_conv_first = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(8, in_channels),
                nn.ReLU(),
                nn.Conv2d(in_channels, out_channels, 3, 1, 1)
            ),
            nn.Sequential(
                nn.GroupNorm(8, out_channels),
                nn.ReLU(),
                nn.Conv2d(out_channels, out_channels, 3, 1, 1)
            )
        ])

        self.t_emb_layers = nn.ModuleList([
            nn.Sequential(
                nn.SiLU(),
                nn.Linear(t_emb_dim, out_channels),
            ),
            nn.Sequential(
                nn.SiLU(),
                nn.Linear(t_emb_dim, out_channels),
            )
        ])

        self.resnet_conv_second = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(8, out_channels),
                nn.ReLU(),
                nn.Conv2d(out_channels, out_channels, 3, 1, 1)
            ),
            nn.Sequential(
                nn.GroupNorm(8, out_channels),
                nn.ReLU(),
                nn.Conv2d(out_channels, out_channels, 3, 1, 1)
            )
        ])

        self.attn_norm = nn.GroupNorm(8, out_channels)
        self.attn = nn.MultiheadAttention(
            out_channels, num_heads, batch_first=True
        )
        self.resid_input_conv = nn.ModuleList([
            nn.Conv2d(in_channels, out_channels, 1),
            nn.Conv2d(out_channels, out_channels, 1)
        ])

    def forward(self, x, t_emb):
        out = x
        # First ResNet block
        resnet_input = out
        out = self.resnet_conv_first[0](out)
        out = out + self.t_emb_layers[0](t_emb).unsqueeze(-1).unsqueeze(-1)
        out = self.resnet_conv_second[0](out)
        out = out + self.resid_input_conv[0](resnet_input)

        # self-attention block
        B, C, H, W = out.shape
        input_for_attn = out.view(B, C, -1)
        input_for_attn = self.attn_norm(input_for_attn)
        input_for_attn = input_for_attn.transpose(1, 2)
        out_attn, _ = self.attn(input_for_attn, input_for_attn, input_for_attn)
        out_attn = out_attn.transpose(1, 2).view(B, C, H, W)
        out = out + out_attn

        # Second ResNet block
        resnet_input = out
        out = self.resnet_conv_first[1](out)
        out = out + self.t_emb_layers[1](t_emb).unsqueeze(-1).unsqueeze(-1)
        out = self.resnet_conv_second[1](out)
        out = out + self.resid_input_conv[1](resnet_input)

        return out



### Upsample-Block

In [5]:
class UpBlock(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            t_emb_dim,
            up_sample,
            num_heads
        ):
        super().__init__()

        self.up_sample = up_sample
        self.resnet_conv_first = nn.Sequential(
            nn.GroupNorm(8, in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        )

        self.t_emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(t_emb_dim, out_channels),
        )

        self.resnet_conv_second = nn.Sequential(
            nn.GroupNorm(8, out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1)
        )

        self.attn_norm = nn.GroupNorm(8, out_channels)
        self.attn = nn.MultiheadAttention(
            out_channels, num_heads, batch_first=True
        )
        self.resid_input_conv = nn.Conv2d(in_channels, out_channels, 1)
        self.up_sample_conv = nn.ConvTranspose2d(
            in_channels // 2, in_channels // 2, 4, 2, 1
        ) if self.up_sample else nn.Identity()

    def forward(self, x, out_down, t_emb):
        print(x.shape, out_down.shape)
        x = self.up_sample_conv(x)
        print("after up_sample_conv", x.shape)
        x = torch.cat([x, out_down], dim=1)

        # ResNet block
        out = x
        resnet_input = out
        out = self.resnet_conv_first(out)
        out = out + self.t_emb_layers(t_emb).unsqueeze(-1).unsqueeze(-1)
        out = self.resnet_conv_second(out)
        out = out + self.resid_input_conv(resnet_input)

        # self-attention block
        B, C, H, W = out.shape
        input_for_attn = out.view(B, C, -1)
        input_for_attn = self.attn_norm(input_for_attn)
        input_for_attn = input_for_attn.transpose(1, 2)
        out_attn, _ = self.attn(input_for_attn, input_for_attn, input_for_attn)
        out_attn = out_attn.transpose(1, 2).view(B, C, H, W)
        out = out + out_attn

        return out

### Putting everything together in UNet

In [6]:
class UNet(nn.Module):
    def __init__(self, img_channels):
        super().__init__()

        self.down_channels = [32, 64, 128, 256]
        self.mid_channels = [256, 256, 128]
        self.t_emb_dim = 128

        self.down_samples = [True, True, False]

        self.t_proj = nn.Sequential(
            nn.Linear(self.t_emb_dim, self.t_emb_dim),
            nn.SiLU(),
            nn.Linear(self.t_emb_dim, self.t_emb_dim),
        )
        self.conv_inp_layer = nn.Conv2d(img_channels, self.down_channels[0], 3, 1, 1)

        self.downs = nn.ModuleList([])
        for i in range(len(self.down_channels) - 1):
            self.downs.append(
                DownBlock(
                    self.down_channels[i],
                    self.down_channels[i+1],
                    self.t_emb_dim,
                    down_sample=self.down_samples[i],
                    num_heads=4
                )
            )

        self.mids = nn.ModuleList([])
        for i in range(len(self.mid_channels)-1):
            self.mids.append(
                MidBlock(
                    self.mid_channels[i],
                    self.mid_channels[i+1],
                    self.t_emb_dim,
                    num_heads=4
                )
            )

        self.ups = nn.ModuleList([])
        for i in reversed(range(len(self.down_channels)-1)):
            self.ups.append(
                UpBlock(
                    self.down_channels[i] * 2,
                    self.down_channels[i-1] if i != 0 else 16,
                    self.t_emb_dim,
                    up_sample=self.down_samples[i],
                    num_heads=4
                )
            )

        self.norm_out = nn.GroupNorm(8, 16)
        self.conv_out = nn.Conv2d(16, img_channels, 3, 1, 1)

    def forward(self, x, t):
        t_emb = get_time_embedding(t, self.t_emb_dim)
        t_emb = self.t_proj(t_emb)
        print("Time embedding:", t_emb.shape)
        print("Input:", x.shape)

        out = self.conv_inp_layer(x)
        out_downs = []
        for down_layer in self.downs:
            print("Down:", out.shape)
            out_downs.append(out)
            out = down_layer(out, t_emb)

        for mid_layer in self.mids:
            print("Mid:", out.shape)
            #print("mid_layer:", mid_layer)
            out = mid_layer(out, t_emb)

        for up_layer in self.ups:
            out_down = out_downs.pop()
            print("Up:", out.shape, out_down.shape)
            print("up_layer:", up_layer.up_sample)
            out = up_layer(out, out_down, t_emb)

        out = self.norm_out(out)
        out = nn.SiLU()(out)
        out = self.conv_out(out)
        return out

## Dataset

In [7]:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# scale the pixels to the range [-1, 1]
def scale(x):
    return x * 2 - 1

def unscale(x):
    return (x + 1) / 2

mnist = MNIST(
    root="data",
    download=True,
    transform=ToTensor()
)


class MNISTDataset(torch.utils.data.Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, _ = self.dataset[idx]
        img = scale(img)
        return img
    
# testing:
dataset = MNISTDataset(mnist)
print(next(iter(dataset)).shape)
print(next(iter(dataset))[0].min(), next(iter(dataset))[0].max())


torch.Size([1, 28, 28])
tensor(-1.) tensor(1.)


In [8]:
# read the config file:
import yaml
with open("config.yaml") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)


In [9]:
# create the noise scheduler:
scheduler = LinearNoiseScheduler(
    config["diffusion"]["num_timesteps"],
    config["diffusion"]["beta_start"],
    config["diffusion"]["beta_end"]
)

# create the model:
model = UNet(config["model"]["img_channels"])

In [10]:
num_epochs = config["train"]["num_epochs"]
optimizer = torch.optim.Adam(
    model.parameters(), lr=config["train"]["lr"]
)
critereon = nn.MSELoss()

data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=config["train"]["batch_size"],
    shuffle=True
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
for epoch in range(num_epochs):
    losses = []
    for imgs in data_loader:
        optimizer.zero_grad()

        imgs = imgs.to(device)
        print(imgs.shape)
        # sample noise
        noise = torch.randn_like(imgs).to(device)
        # sample timestep
        t = torch.randint(
            0, config["diffusion"]["num_timesteps"],
            (imgs.shape[0],)
        ).to(device)

        # add noise to imgs: add_noise(original, noise, t)
        noisy_imgs = scheduler.add_noise(imgs, noise, t)
        print(noisy_imgs.shape)

        # predict noise:
        noise_pred = model(noisy_imgs, t)

        # compute loss and backpropagate
        loss = critereon(noise_pred, noise)

        loss.backward()
        break
    break

# save the model
#torch.save(model.state_dict(), "model.pth")

torch.Size([64, 1, 28, 28])
torch.Size([64, 1, 28, 28])
Time embedding: torch.Size([64, 128])
Input: torch.Size([64, 1, 28, 28])
Down: torch.Size([64, 32, 28, 28])
Down: torch.Size([64, 64, 14, 14])
Down: torch.Size([64, 128, 7, 7])
Mid: torch.Size([64, 256, 7, 7])
Mid: torch.Size([64, 256, 7, 7])
Up: torch.Size([64, 128, 7, 7]) torch.Size([64, 128, 7, 7])
up_layer: False
torch.Size([64, 128, 7, 7]) torch.Size([64, 128, 7, 7])
after up_sample_conv torch.Size([64, 128, 7, 7])
Up: torch.Size([64, 64, 7, 7]) torch.Size([64, 64, 14, 14])
up_layer: True
torch.Size([64, 64, 7, 7]) torch.Size([64, 64, 14, 14])
after up_sample_conv torch.Size([64, 64, 14, 14])
Up: torch.Size([64, 32, 14, 14]) torch.Size([64, 32, 28, 28])
up_layer: True
torch.Size([64, 32, 14, 14]) torch.Size([64, 32, 28, 28])
after up_sample_conv torch.Size([64, 32, 28, 28])


## Inference

In [None]:
def inference(args):
    # read configs
    with open(args.config, "r") as f:
        try:
            config = yaml.safe_load(f)
        except yaml.YAMLError as exc:
            print(exc)
            return None
        
    diffusion_config = config["diffusion_params"]
    model_config = config["model_params"]
    train_config = config["train_params"]

    # loade model
    model = UNet(model_config["img_channels"])
    model.to(device)
    model.load_state_dict(
        torch.load(args.model_path),
        map_location=device
    )
    model.eval()

    # create the noise scheduler:
    scheduler = LinearNoiseScheduler(
        diffusion_config["num_timesteps"],
        diffusion_config["beta_start"],
        diffusion_config["beta_end"]
    )

    # TODO