In [None]:
import einops as ein
from einops.layers.torch import Rearrange
import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision.utils import save_image
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np


In [None]:
class TimeEmbedding(nn.Module):
    """Sinusoidal timestep embedding followed by a two-layer Mish MLP.

    Args:
        n_channels: width of the returned embedding. The sinusoidal part has
            ``2 * (n_channels // 8)`` features (a sin half and a cos half);
            ``n_channels`` must be at least 8.

    Raises:
        ValueError: if ``n_channels < 8`` (no sinusoid frequencies).
    """

    def __init__(self, n_channels):
        super().__init__()
        half_dim = n_channels // 8
        if half_dim < 1:
            raise ValueError(f"n_channels must be >= 8, got {n_channels}")
        # Size lin1 from the actual sinusoid width; `n_channels // 4` only
        # equals 2 * half_dim when n_channels is a multiple of 8.
        self.lin1 = nn.Linear(2 * half_dim, n_channels)
        self.act = nn.Mish()
        self.lin2 = nn.Linear(n_channels, n_channels)
        # max(..., 1) avoids the 0-division (-> NaN embedding) when half_dim == 1.
        emb_scale = np.log(10000) / max(half_dim - 1, 1)
        # Non-persistent buffer: moves with .to()/.cuda() but leaves state_dict unchanged.
        self.register_buffer(
            "emb",
            torch.exp(torch.arange(half_dim) * -emb_scale)[None, :],
            persistent=False,
        )

    def forward(self, t):
        """Embed 1-D timesteps ``t`` of shape ``(batch,)`` into ``(batch, n_channels)``."""
        # .to(t.device) is a no-op when the buffer already lives there.
        freqs = self.emb.to(t.device)
        emb = t[:, None].to(freqs.dtype) * freqs
        emb = torch.cat((emb.sin(), emb.cos()), dim=1)
        return self.lin2(self.act(self.lin1(emb)))


class Block(nn.Module):
    """Residual conv block conditioned on a time embedding.

    Two (Conv3x3 -> GroupNorm -> Mish) stages. A linear projection of the time
    embedding is added between them, and a 1x1-conv shortcut of the input is
    added to the output.

    Args:
        in_channels: channels of the input feature map.
        out_channels: output channels; must be divisible by ``num_groups``.
        time_channels: width of the time embedding ``t`` passed to ``forward``.
        num_groups: GroupNorm group count (default 32, the former hard-coded value).
    """

    def __init__(self, in_channels, out_channels, time_channels, num_groups=32):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.GroupNorm(num_groups, out_channels),
            nn.Mish(),
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.GroupNorm(num_groups, out_channels),
            nn.Mish(),
        )
        self.shortcut = nn.Conv2d(in_channels, out_channels, 1)
        self.time_emb = nn.Linear(time_channels, out_channels)

    def forward(self, x, t):
        """Apply the block to ``x`` of shape (b, in_channels, H, W) with embedding ``t`` of shape (b, time_channels)."""
        h = self.block1(x)
        # Out-of-place add: does not rely on block1's final op leaving its
        # output safe to mutate under autograd.
        h = h + self.time_emb(t)[:, :, None, None]
        h = self.block2(h)
        return h + self.shortcut(x)


class Encoder(nn.Module):
    """Contracting path: a stack of time-conditioned Blocks joined by 2x2 max-pooling.

    Block ``i`` maps ``chs[i] -> chs[i + 1]`` channels. ``forward`` returns every
    block's output, highest resolution first, for use as skip connections.
    """

    def __init__(self, time_channels, chs):
        super().__init__()
        channel_pairs = zip(chs[:-1], chs[1:])
        self.enc_blocks = nn.ModuleList(
            Block(c_in, c_out, time_channels) for c_in, c_out in channel_pairs
        )
        self.pool = nn.MaxPool2d(2)

    def forward(self, x, t):
        features = []
        for depth, block in enumerate(self.enc_blocks):
            if depth:
                # Downsample between blocks; the pooled output after the last
                # block was never used, so it is simply not computed.
                x = self.pool(x)
            x = block(x, t)
            features.append(x)
        return features


class Decoder(nn.Module):
    """Expanding path of the U-Net.

    For each consecutive pair ``(c_in, c_out)`` in ``chs`` there is a 2x2,
    stride-2 transposed conv (``c_in -> c_out``) and a Block (``c_in -> c_out``).
    The Block consumes the upsampled map concatenated with a skip feature, so
    each skip must carry ``c_in - c_out`` channels.
    """

    def __init__(self, time_channels, chs):
        super().__init__()
        self.upconvs = nn.ModuleList()
        self.dec_blocks = nn.ModuleList()
        # Build each stage's upconv and Block together (same creation order as before).
        for c_in, c_out in zip(chs, chs[1:]):
            self.upconvs.append(nn.ConvTranspose2d(c_in, c_out, 2, 2))
            self.dec_blocks.append(Block(c_in, c_out, time_channels))

    def forward(self, x, t, encoder_features):
        """Upsample ``x`` stage by stage, fusing one skip feature per stage.

        Skip features are resized bilinearly to the upsampled map's spatial size
        before channel-wise concatenation.
        """
        stages = zip(self.upconvs, self.dec_blocks, encoder_features)
        for upsample, block, skip in stages:
            x = upsample(x)
            skip = nn.functional.interpolate(
                skip, size=x.shape[2:], mode="bilinear", align_corners=True
            )
            x = block(torch.cat((x, skip), dim=1), t)
        return x


class UNet(nn.Module):
    """Time-conditioned U-Net noise predictor.

    Args:
        image_channels: channels of the input image and of the predicted output.
        n_channels: base width. Encoder widths are ``n_channels * (1, 2, 4, 8, 16, 32)``.
            Block's GroupNorm is built with 32 groups, so this should be a
            multiple of 32.

    NOTE(review): the encoder pools 4 times between its 5 blocks; inputs whose
    H/W are not divisible by 16 rely on Decoder's bilinear skip resizing — confirm
    that is intended for non-32x32 inputs.
    """

    def __init__(self, image_channels, n_channels):
        super().__init__()
        time_channels = n_channels * 4
        # Derive widths from n_channels instead of hard-coding 32..1024, so the
        # constructor arguments are honoured (identical for n_channels=32).
        enc_chs = tuple(n_channels * 2 ** i for i in range(6))
        dec_chs = tuple(reversed(enc_chs[1:]))
        self.image_proj = nn.Conv2d(image_channels, n_channels, 3, 1, 1)
        self.time_emb = TimeEmbedding(time_channels)
        self.encoder = Encoder(time_channels, enc_chs)
        self.decoder = Decoder(time_channels, dec_chs)
        # Map back to image_channels (previously hard-coded to 1 output channel).
        self.final = nn.Conv2d(dec_chs[-1], image_channels, 3, 1, 1)
        self.act = nn.Mish()

    def forward(self, x, t):
        """Predict noise for images ``x`` at timesteps ``t`` (one per batch item)."""
        x = self.image_proj(x)
        time = self.time_emb(t)
        # Deepest encoder output is the bottleneck; the rest are skips, fed
        # deepest-first to the decoder.
        *skips, bottleneck = self.encoder(x, time)
        out = self.decoder(bottleneck, time, skips[::-1])
        return self.final(self.act(out))


In [None]:
class Diffusion(nn.Module):
    """DDPM with a linear beta schedule and a UNet noise predictor.

    Args:
        noise_steps: number of diffusion steps ``T``.
        device: device the schedule tensors are created on. They are registered
            as non-persistent buffers, so a later ``.to(...)`` moves them too.
    """

    def __init__(self, noise_steps=1000, device="cpu"):
        super().__init__()
        beta = torch.linspace(1e-4, 0.02, noise_steps, device=device)
        alpha = 1 - beta
        # persistent=False keeps state_dict() keys unchanged; buffers still follow .to().
        self.register_buffer("beta", beta, persistent=False)
        self.register_buffer("alpha", alpha, persistent=False)
        self.register_buffer("alpha_bar", torch.cumprod(alpha, dim=0), persistent=False)
        self.T = noise_steps
        self.eps_model = UNet(image_channels=1, n_channels=32)
        self.device = device

    @staticmethod
    def _per_sample(values, t):
        """Gather ``values[t]`` (one scalar per batch item) shaped (b, 1, 1, 1) to broadcast over images."""
        return values[t].view(-1, 1, 1, 1)

    def q(self, x0, t):
        """Mean and variance of q(x_t | x_0): (sqrt(alpha_bar_t) * x0, 1 - alpha_bar_t)."""
        alpha_bar = self._per_sample(self.alpha_bar, t)
        mean = torch.sqrt(alpha_bar) * x0
        var = 1 - alpha_bar
        return mean, var

    def sample_q(self, x0, t, epsilon=None):
        """Draw x_t from q(x_t | x_0); uses fresh standard-normal noise when ``epsilon`` is None."""
        if epsilon is None:
            epsilon = torch.randn_like(x0)

        mean, var = self.q(x0, t)
        return mean + torch.sqrt(var) * epsilon

    def p_sample(self, xt, t):
        """One reverse step x_t -> x_{t-1} using the predicted noise (variance beta_t).

        Args:
            xt: batch of noisy images.
            t: long tensor of shape (batch,) with each item's current timestep.
        """
        eps_theta = self.eps_model(xt, t)
        alpha_bar = self._per_sample(self.alpha_bar, t)
        alpha = self._per_sample(self.alpha, t)
        eps_coef = (1 - alpha) / torch.sqrt((1 - alpha_bar))
        mean = (1 / torch.sqrt(alpha)) * (xt - eps_coef * eps_theta)
        # Reshape per sample; a bare (b,) tensor would broadcast across the
        # image width instead of the batch.
        var = self._per_sample(self.beta, t)
        epsilon = torch.randn(xt.shape, device=xt.device)
        # No noise on the final step (t == 0), as in DDPM sampling (Algorithm 2).
        add_noise = (t != 0).to(xt.dtype).view(-1, 1, 1, 1)
        return (mean + add_noise * torch.sqrt(var) * epsilon).detach()

    def loss(self, x0, noise=None):
        """Noise-prediction MSE at uniformly random timesteps (one per batch item)."""
        t = torch.randint(0, self.T, (x0.size(0),),
                          device=x0.device, dtype=torch.long)
        if noise is None:
            noise = torch.randn_like(x0)
        xt = self.sample_q(x0, t, noise)
        eps_theta = self.eps_model(xt, t)
        return F.mse_loss(noise, eps_theta)

    def show_image(self, img, title=""):
        """Plot channel 0 of the first image in ``img`` in grayscale."""
        img = img.cpu().detach().numpy()
        plt.imshow(img[0, 0, :, :], cmap="gray")
        plt.title(title)
        plt.show()

    @torch.no_grad()
    def _sample_x0(self, xt: torch.Tensor, n_steps: int):
        """Run ``n_steps`` reverse steps from timestep ``n_steps - 1`` down to 0.

        Raises:
            ValueError: if ``n_steps`` exceeds the schedule length ``T``.
        """
        if n_steps > self.T:
            raise ValueError(f"n_steps={n_steps} exceeds noise_steps T={self.T}")
        n_samples = xt.shape[0]
        for t in reversed(range(n_steps)):
            xt = self.p_sample(xt, xt.new_full((n_samples,), t, dtype=torch.long))
        return xt

    def sample(self, n_steps: int):
        """Generate and display one 1x32x32 sample; returns it."""
        # Use the schedule's current device so sampling follows model.to(...).
        xt = torch.randn([1, 1, 32, 32], device=self.beta.device)
        x0 = self._sample_x0(xt, n_steps)
        self.show_image(x0, title="Sampled image")
        return x0


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Diffusion(device=device).to(device)
optimizer = optim.RAdam(model.parameters(), lr=4e-5)

batch_size = 128


# Make images 32x32
def pad_image(img):
    """Zero-pad the last two (H, W) dims of ``img`` by 2 pixels on every side (28x28 MNIST -> 32x32)."""
    return F.pad(img, (2, 2, 2, 2))


# Build the transform once and hand it to the datasets directly, instead of
# patching `.dataset.transform` after the loaders already exist.
# NOTE(review): ToTensor yields values in [0, 1]; DDPM is usually trained on
# data scaled to [-1, 1] — consider rescaling.
mnist_transform = transforms.Compose([transforms.ToTensor(), pad_image])

# Get train and test data
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../data", train=True, download=True, transform=mnist_transform
    ),
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST("../data", train=False, transform=mnist_transform),
    batch_size=batch_size,
    shuffle=False,  # evaluation order does not matter; no need to shuffle
)


In [None]:
def train(epoch):
    """Run one training epoch and report the mean per-batch loss.

    Uses the notebook-level ``model``, ``optimizer``, ``train_loader`` and
    ``device`` defined in earlier cells.

    Args:
        epoch: epoch number, used only in the log line.

    Returns:
        Mean of the per-batch losses for this epoch (0.0 if the loader is empty).
    """
    model.train()
    train_loss = 0.0
    n_batches = 0
    for data, _ in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        loss = model.loss(data)  # already a batch mean (F.mse_loss default reduction)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        n_batches += 1
    # Report the epoch average. Previously only the last batch's loss was shown,
    # divided by the batch size even though it is already a mean.
    avg_loss = train_loss / max(n_batches, 1)
    print("Train Epoch: {}\tAverage loss: {:.6f}".format(epoch, avg_loss))
    return avg_loss


def test(epoch):
    """Estimate the diffusion loss on ``test_loader`` without gradient tracking.

    Uses the notebook-level ``model``, ``test_loader`` and ``device``.
    ``model.loss`` draws random timesteps and noise, so the result is a
    stochastic estimate.

    Args:
        epoch: unused; kept so the call site mirrors ``train(epoch)``.

    Returns:
        Per-sample average test loss (0.0 if the loader is empty).
    """
    model.eval()
    total = 0.0
    n_seen = 0
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            # model.loss is a batch mean; weight by batch size so dividing by the
            # sample count gives a true average even with a short final batch.
            total += model.loss(data).item() * data.size(0)
            n_seen += data.size(0)

    test_loss = total / max(n_seen, 1)
    print("====> Test set loss: {:.4f}".format(test_loss))
    return test_loss


In [None]:
epochs = 10
# Each epoch: one optimisation pass, then one evaluation pass (epochs are 1-based).
for epoch in range(1, epochs + 1):
    for run_phase in (train, test):
        run_phase(epoch)


In [None]:
# Pickles the whole module object; loading it requires the class definitions
# above to be importable. NOTE(review): `torch.save(model.state_dict(), ...)` is
# the more portable checkpoint format if downstream loaders can be updated.
checkpoint_path = "diffusion.pth"
torch.save(obj=model, f=checkpoint_path)
