code adapted from https://nn.labml.ai/diffusion/ddpm/unet.html

In [30]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from diffusion_examples.modules import Swish

from typing import Optional, Union, List, Tuple

In [54]:


from typing import Tuple, Union


class TimeEmbedding(nn.Module):
    
    def __init__(self, n_channels: int):
        super().__init__()
        
        self.n_channels = n_channels
        self.embedding_mlp = nn.Sequential(
            nn.Linear(n_channels // 4, n_channels),
            Swish(),
            nn.Linear(n_channels, n_channels)
        )
    
    def forward(self, t: torch.Tensor) -> torch.Tensor:
        
        """
        PE^{(1)}_{t, i} = sin(t/10000^(i/d-1))
        PE^{(2)}_{t, i} = cos(t/10000^(i/d-1))
        """
        
        half_dim = self.n_channels // 8
        embedding = math.log(10000) / (half_dim - 1)
        embedding = torch.exp(torch.arange(half_dim, device=t.device) * -embedding)
        embedding = t[:, None] * embedding[None, :]
        embedding = torch.cat((embedding.sin(), embedding.cos()), dim=1)
        
        return self.embedding_mlp(embedding)



class ResidualBlock(nn.Module):
    
    def __init__(self, 
        in_channels: int, 
        out_channels: int, 
        time_channels: int, 
        n_groups: int=32, 
        dropout: float=0.1,
    ):
        super().__init__()
        
        self.conv1 = nn.Sequential(
            nn.GroupNorm(n_groups, in_channels),
            Swish(),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        )
        
        self.conv2 = nn.Sequential(
            nn.GroupNorm(n_groups, in_channels),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        )
        
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.shortcut = nn.Identity()
        
        self.time_embedding = nn.Sequential(
            nn.Linear(time_channels, out_channels),
            Swish(),
        )
        
        
    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        
        h = self.conv1(x)
        h += self.time_embedding(t)[:, :, None, None]
        h = self.conv2(h)
        return h + self.shortcut(x) # residual connection



class AttentionBlock(nn.Module):
    
    def __init__(self, 
        n_channels: int, 
        n_heads: int=4,
        d_k: int=None,
        n_groups: int=32,
    ):
        super().__init__()
        
        if d_k is None:
            d_k = n_channels
        self.n_heads = n_heads
        self.d_k = d_k
        
        self.norm = nn.GroupNorm(n_groups, n_channels)
        
        self.proj_k = nn.Linear(n_channels, n_heads * d_k)
        self.proj_q = nn.Linear(n_channels, n_heads * d_k)
        self.proj_v = nn.Linear(n_channels, n_heads * d_k)
        
        self.out = nn.Linear(n_heads * d_k, n_channels)
        self.scale = d_k ** -0.5
    
    def forward(self, x: torch.Tensor, t: Optional[torch.Tensor]=None) -> torch.Tensor:
        
        batch_size, n_channels, height, width = x.shape
        x = x.view(batch_size, n_channels, -1).permute(0, 2, 1)  # reshape to (batch, seq, n_channels)
        
        k = self.proj_k(x).view(batch_size, -1, self.n_heads, self.d_k) # reshape to (batch, seq, n_heads, d_k)
        q = self.proj_q(x).view(batch_size, -1, self.n_heads, self.d_k)
        v = self.proj_v(x).view(batch_size, -1, self.n_heads, self.d_k)
        
        attention = torch.einsum('bihd,bjhd->bijh', q, k) # dot product attention Q@K^t
        attention = attention * self.scale # scale by inverse of sqrt(d_k)
        attention = attention.softmax(dim=2)
        attention = torch.einsum('bijh,bjhd->bihd', attention, v)
        
        attention = attention.view(batch_size, -1, self.n_heads * self.d_k) # reshape (concat heads)
        attention = self.output(attention) # project back down
        attention += x
        
        return attention.perumte(0, 2, 1).view(batch_size, n_channels, height, width)


class Block(nn.Module):
    
    def __init__(self, 
        in_channels: int, 
        out_channels: int, 
        time_channels: int, 
        use_attention: bool=False
    ):
        super().__init__()
        
        self.block = nn.Sequential(
            nn.ResidualBlock(in_channels, out_channels, time_channels),
            nn.AttentionBlock(out_channels) if use_attention else nn.Identity()
        )
    
    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        return self.block(x, t)


# class Bottleneck(nn.Module):
    
#     def __init__(self, n_channels: int, time_channels: int):
#         super().__init__()
        
#         self.block = nn.Sequential(
#             ResidualBlock
#         )

class UNet(nn.Module):
    
    def __init__(self,
        image_channels: int=3,
        n_channels: int=64,
        channel_multipliers: Union[Tuple[int, ...], List[int]]=(1, 2, 2, 4),
        uses_attention: Union[Tuple[bool, ...], List[bool]]=(False, False, True, True),
        n_blocks: int=2,
    ):
        super().__init__()
        self.image_chennels = image_channels
        self.n_channels = n_channels
        self.channel_multipliers = channel_multipliers
        self.uses_attention = uses_attention
        self.n_blocks = n_blocks
        
        self.time_channels = n_channels * 4
        self.time_embedding = TimeEmbedding(n_channels * self.time_channels)
        
        in_channels = out_channels = n_channels
        
        self.proj = nn.Conv2d(image_channels, n_channels, kernel_size=1)
        self.encoder_modules = nn.ModuleList()
        for i, channel_multiplier in enumerate(channel_multipliers):
            
            if i > 0:
                out_channels = in_channels * channel_multiplier
            else:
                out_channels = n_channels
            print(in_channels, out_channels)
            
            for _ in range(n_blocks):
                self.encoder_modules.append(nn.Sequential(
                    ResidualBlock(in_channels, out_channels, self.time_channels),
                    AttentionBlock(out_channels) if uses_attention[i] else nn.Identity(),
                ))
                in_channels = out_channels
            
            if i < len(uses_attention) - 1: # if not at end of list
                self.encoder_modules.append(nn.Upsample(scale_factor=0.5, mode='bilinear'))
        
        self.bottleneck = nn.Sequential(
            ResidualBlock(in_channels, in_channels, self.time_channels),
            AttentionBlock(in_channels),
            ResidualBlock(in_channels, in_channels, self.time_channels)
        )
        
        self.decoder_modules = nn.ModuleList()
        for i, channel_multiplier in enumerate(reversed(channel_multipliers)):
            
            i = len(channel_multipliers) - (i + 1)
            print(i)
            
            out_channels = in_channels // channel_multiplier
            
            for _ in range(n_blocks):
                self.decoder_modules.append(nn.Sequential(
                    ResidualBlock(in_channels, out_channels, self.time_channels),
                    AttentionBlock(out_channels) if uses_attention[i] else nn.Identity(),
                ))
                in_channels = out_channels
            
            if i > 0:
                self.encoder_modules.append(nn.Upsample(scale_factor=2, mode='bilinear'))
        
        self.final = nn.Conv2d(in_channels, image_channels, 1),
    
    
    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        
        t = self.time_embedding(t)
        h_list = [x]
        
        x = self.proj(x)
        for module in self.encoder_modules:
            x = module(x, t)
            h_list.append(x)
        
        x = self.bottleneck(x, t)
        
        for module in self.decoder_modules:
            if isinstance(module, nn.Upsample):
                x = module(x, t)
            else:
                s = h_list.pop()
                x = torch.cat((x, s), dim=1)
                x = module(x, t)
        
        return self.final(x)


model = UNet()
X = torch.rand((1, 3, 256, 256))
t = torch.randint(1, 10, (1, 1))

model(X, t)

64 64
64 128
128 256
256 1024
3
2
1
0


RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x2048 and 4096x16384)

In [None]:
from typing import Iterable


Iterable

NameError: name 'Iterable' is not defined

In [3]:
torch.arange(10)

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])