In [52]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Self Attention

Introduced in the now famous paper, *Attention Is All You Need* (2017), self-attention is a method to enhance contextual information between elements in an input. It was first proposed for U-Nets in *Attention U-Net: Learning Where to Look for the Pancreas* (2018).

Let $\boldsymbol{x}$ be a sequence of $T$ input vectors. The query matrix $\boldsymbol{W}_q$ can be thought of as embedding a question about the input sequence. The key matrix $\boldsymbol{W}_k$ embeds how relevant each token in the sequence is in query space.

$$\boldsymbol{q}_i = \boldsymbol{W}_q\boldsymbol{x}_i$$
$$\boldsymbol{k}_i = \boldsymbol{W}_k\boldsymbol{x}_i$$

The dot product between the $i$-th query vector and the $j$-th key vector indicates how well the key *answers* the query. These are called the unnormalised attention weights $\boldsymbol{\omega}$.

$$\boldsymbol{\omega}_{ij} = \boldsymbol{q}_i^\top \cdot \boldsymbol{k}_j$$

These weights are scaled by $\sqrt{d_k}$, where $d_k$ is the dimensionality of the key vectors $\boldsymbol{k}_j$, and normalised over $j$ by the softmax function to get the normalised attention scores $\boldsymbol{\alpha}$:

$$\boldsymbol{\alpha}_{ij} = \text{softmax}\left(\frac{\boldsymbol{\omega}_{ij}}{\sqrt{d_k}}\right)$$

Scaling by $\sqrt{d_k}$ keeps the magnitude of the dot products roughly independent of the vector dimension, so that the softmax does not saturate. The last step is to compute the context vector $\boldsymbol{z}$, which is an attention-weighted version of the original input $\boldsymbol{x}$. First, each input is multiplied by the value matrix $\boldsymbol{W}_v$ to get the value sequence:

 $$\boldsymbol{v}_i = \boldsymbol{W}_v\boldsymbol{x}_i$$
 
And the context vector is taken to be the attention-weighted sum of this sequence:

$$\boldsymbol{z}_i = \sum_{j = 1}^T \boldsymbol{\alpha}_{ij}\boldsymbol{v}_j$$

In multi-headed attention, multiple query, key, and value matrices are applied in parallel.

In [55]:
class SelfAttention(nn.Module):
    """Multi-head self-attention over the spatial positions of a square feature map.

    The feature map is flattened into a sequence of ``size * size`` tokens with
    ``channels`` features each. The sequence goes through a pre-norm multi-head
    attention layer with a residual connection, then a feed-forward sub-block
    with a second residual connection, and is finally reshaped back into a
    ``(batch, channels, size, size)`` feature map.

    Args:
        channels: Number of feature channels; also the attention embedding size.
        size: Side length of the square feature map this block is applied to.
        heads: Number of attention heads given to ``nn.MultiheadAttention``.
        activation: Zero-argument callable returning the activation module used
            inside the feed-forward sub-block.
    """

    def __init__(self, channels, size, heads=4, activation=nn.GELU):
        super().__init__()
        self.channels = channels
        self.size = size

        self.attention = nn.MultiheadAttention(channels, heads, batch_first=True)
        self.norm = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            activation(),
            nn.Linear(channels, channels)
        )

    def forward(self, x):
        """Apply self-attention to ``x`` and return a tensor shaped ``(-1, channels, size, size)``.

        Args:
            x: Tensor whose elements can be arranged as
                ``(batch, channels, size * size)``.
        """
        # (batch, channels, size*size) -> (batch, size*size, channels): one token per pixel.
        # reshape (rather than view) also accepts non-contiguous inputs.
        tokens = x.reshape(-1, self.channels, self.size ** 2).swapaxes(1, 2)

        normed = self.norm(tokens)
        attended, _ = self.attention(normed, normed, normed)

        # Residual connections are written out-of-place: an in-place add would
        # overwrite the tensor that ff_self's LayerNorm keeps for the backward pass.
        attended = attended + tokens
        out = attended + self.ff_self(attended)

        return out.swapaxes(2, 1).reshape(-1, self.channels, self.size, self.size)

In [56]:
class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by single-group ``GroupNorm``.

    The activation is applied between the two convolutions. With
    ``residual=True`` the input is added to the convolution output and the sum
    is passed through GELU; this requires the output to have the same shape as
    the input (i.e. ``in_channels == out_channels``).

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
        mid_channels: Channels between the two convolutions; defaults to
            ``out_channels`` when ``None``.
        residual: Whether to add the input to the output before a final GELU.
        activation: Zero-argument callable returning the activation module
            placed between the two convolutions. The residual branch always
            uses GELU, independent of this argument.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False, activation=nn.GELU):
        super().__init__()
        self.residual = residual

        if mid_channels is None:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, mid_channels),
            activation(),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels)
        )

    def forward(self, x):
        """Return the convolved features, with a GELU-activated skip connection if ``residual``."""
        out = self.double_conv(x)
        if self.residual:
            return F.gelu(x + out)
        return out
    

class DownConv(nn.Module):
    """Encoder stage: 2x max-pooling, two ``DoubleConv`` blocks, plus a time embedding.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by this stage.
        emb_dim: Size of the time-embedding vector fed to the linear projection.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()

        # Pool first, refine at constant width (residual), then change the channel count.
        stages = [
            nn.MaxPool2d(2),
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels),
        ]
        self.maxpool_conv = nn.Sequential(*stages)

        # Projects the time embedding to one scalar per output channel.
        self.embedding = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, out_channels))

    def forward(self, x, t):
        """Downsample ``x`` and add the projected time embedding ``t`` at every pixel."""
        features = self.maxpool_conv(x)
        height, width = features.shape[-2], features.shape[-1]

        projected = self.embedding(t)
        # Tile the per-channel values over the spatial grid of the feature map.
        time_map = projected[:, :, None, None].repeat(1, 1, height, width)
        return features + time_map


class UpConv(nn.Module):
    """Decoder stage: 2x bilinear upsampling, skip concatenation, two ``DoubleConv`` blocks, plus a time embedding.

    Args:
        in_channels: Channels after concatenating the skip connection with the
            upsampled input.
        out_channels: Channels produced by this stage.
        emb_dim: Size of the time-embedding vector fed to the linear projection.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)

        # Refine at constant width (residual), then reduce via a half-width bottleneck.
        refine = DoubleConv(in_channels, in_channels, residual=True)
        reduce = DoubleConv(in_channels, out_channels, mid_channels=in_channels // 2)
        self.conv = nn.Sequential(refine, reduce)

        # Projects the time embedding to one scalar per output channel.
        self.embedding = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, out_channels))

    def forward(self, x, skip_x, t):
        """Upsample ``x``, merge it with ``skip_x`` along channels, convolve, and add the time embedding ``t``."""
        upsampled = self.up(x)
        merged = torch.cat([skip_x, upsampled], dim=1)
        features = self.conv(merged)

        height, width = features.shape[-2], features.shape[-1]
        projected = self.embedding(t)
        # Tile the per-channel values over the spatial grid of the feature map.
        time_map = projected[:, :, None, None].repeat(1, 1, height, width)
        return features + time_map

In [57]:
class UNet(nn.Module):
    """Time-conditioned U-Net with self-attention after every down/up stage.

    The ``SelfAttention`` sizes are hard-coded (32, 16, 8 on the way down and
    16, 32, 64 on the way up), which corresponds to 64x64 inputs.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the predicted output.
        time_dim: Size of the sinusoidal timestep encoding; forwarded to every
            down/up stage as its embedding dimension.
    """

    def __init__(self, in_channels=3, out_channels=3, time_dim=256):
        super().__init__()

        self.time_dim = time_dim

        # emb_dim is passed explicitly so a non-default time_dim matches the
        # linear projections inside the down/up stages.
        self.inc = DoubleConv(in_channels, 64)
        self.down1 = DownConv(64, 128, emb_dim=time_dim)
        self.sa1 = SelfAttention(128, 32)
        self.down2 = DownConv(128, 256, emb_dim=time_dim)
        self.sa2 = SelfAttention(256, 16)
        self.down3 = DownConv(256, 256, emb_dim=time_dim)
        self.sa3 = SelfAttention(256, 8)

        self.bot1 = DoubleConv(256, 512)
        self.bot2 = DoubleConv(512, 512)
        self.bot3 = DoubleConv(512, 256)

        self.up1 = UpConv(512, 128, emb_dim=time_dim)
        self.sa4 = SelfAttention(128, 16)
        self.up2 = UpConv(256, 64, emb_dim=time_dim)
        self.sa5 = SelfAttention(64, 32)
        self.up3 = UpConv(128, 64, emb_dim=time_dim)
        self.sa6 = SelfAttention(64, 64)

        self.outc = nn.Conv2d(64, out_channels, kernel_size=1)

    def pos_encoding(self, t, channels):
        """Return the sinusoidal encoding of ``t`` as ``[sin(t * f), cos(t * f)]``.

        Args:
            t: Float tensor of shape ``(batch, 1)`` holding the timesteps.
            channels: Width of the encoding; half sine and half cosine terms.

        Returns:
            Tensor of shape ``(batch, channels)`` on the same device as ``t``.
        """
        # Build the frequencies on t's device so the encoding also works off-CPU.
        exponents = torch.arange(0, channels, 2, device=t.device).float() / channels
        inv_freq = 1.0 / (10000 ** exponents)
        angles = t.repeat(1, channels // 2) * inv_freq
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    def forward(self, x, t):
        """Run the network on images ``x`` conditioned on timesteps ``t``.

        Args:
            x: Input image batch.
            t: 1-D tensor of timesteps, one per batch element.
        """
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)

        # Encoder; x1..x3 are kept as skip connections for the decoder.
        x1 = self.inc(x)
        x2 = self.down1(x1, t)
        x2 = self.sa1(x2)
        x3 = self.down2(x2, t)
        x3 = self.sa2(x3)
        x4 = self.down3(x3, t)
        x4 = self.sa3(x4)

        # Bottleneck.
        x4 = self.bot1(x4)
        x4 = self.bot2(x4)
        x4 = self.bot3(x4)

        # Decoder.
        x = self.up1(x4, x3, t)
        x = self.sa4(x)
        x = self.up2(x, x2, t)
        x = self.sa5(x)
        x = self.up3(x, x1, t)
        x = self.sa6(x)
        return self.outc(x)

In [None]:
class Diffusion:
    """Forward (noising) process of a diffusion model with a linear beta schedule.

    Args:
        noise_steps: Number of diffusion timesteps.
        beta_start: First value of the linear beta schedule.
        beta_end: Last value of the linear beta schedule.
        img_size: Image side length; stored on the instance for later use.
    """

    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256):
        self.noise_steps = noise_steps

        self.beta = torch.linspace(beta_start, beta_end, self.noise_steps)
        self.alpha = 1. - self.beta
        # Cumulative product of alpha up to each timestep.
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

        self.img_size = img_size

    def noise_images(self, x, t):
        """Noise ``x`` to timesteps ``t``.

        Computes ``sqrt(alpha_hat[t]) * x + sqrt(1 - alpha_hat[t]) * err`` with
        ``err`` drawn from a standard normal distribution.

        Args:
            x: Image batch with four dimensions, batch first.
            t: 1-D integer indices into the schedule, one per batch element.

        Returns:
            Tuple ``(noised_x, err)``, both shaped like ``x``.
        """
        # Look the coefficients up on the schedule's own device, then move them
        # next to x, so x and t may live on a different device than the schedule.
        steps = torch.as_tensor(t, device=self.alpha_hat.device)
        alpha_hat_t = self.alpha_hat[steps].to(x.device)

        sqrt_alpha_hat = torch.sqrt(alpha_hat_t)[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - alpha_hat_t)[:, None, None, None]
        err = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * err, err

    def sample_timesteps(self, n):
        """Draw ``n`` integer timesteps uniformly from ``[1, noise_steps)``."""
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def sample(self):
        """Placeholder for the reverse (sampling) process; not implemented, returns ``None``."""
        pass