In [1]:
import torch
from torch import nn
import math

In [2]:
def unet_conv(ic, oc, ks=3, stride=1, act=nn.SiLU, norm=None, bias=True):
    layers = nn.Sequential()
    if norm: layers.append(norm(ic))
    if act : layers.append(act())
    layers.append(nn.Conv2d(ic, oc, stride=stride, kernel_size=ks, padding=ks//2, bias=bias))
    return layers

In [3]:
def lin(ic, oc, stride=1, act=nn.SiLU, norm=None, bias=True):
    layers = nn.Sequential()
    if norm: layers.append(norm(ic))
    if act : layers.append(act())
    layers.append(nn.Linear(ic, oc, bias=bias))
    return layers

In [4]:
def timestep_embedding(tsteps, emb_dim, max_period= 10000):
    exponent = -math.log(max_period) * torch.linspace(0, 1, emb_dim//2, device=tsteps.device)
    emb = tsteps[:,None].float() * exponent.exp()[None,:]
    emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
    return F.pad(emb, (0,1,0,0)) if emb_dim%2==1 else emb

In [5]:
class SelfAttentionMultiHead(nn.Module):
    def __init__(self, ic, nheads, transpose=True):
        super().__init__()
        self.nheads = ic//nheads
        self.scale = math.sqrt(ic/self.nheads)
        self.norm = nn.LayerNorm(ic)
        self.qkv = nn.Linear(ic, ic*3)
        self.proj = nn.Linear(ic, ic)
        self.t = transpose
    
    def forward(self, inp):
        n,c,h,w = inp.shape
        if self.t: x = x.transpose(1, 2)
        x = self.norm(inp).view(n, c, -1).transpose(1, 2)
        x = self.qkv(x)
        x = rearrange(x, 'n s (h d) -> (n h) s d', h=self.nheads)
        q,k,v = torch.chunk(x, 3, dim=-1)
        s = (q@k.transpose(1,2))/self.scale
        x = s.softmax(dim=-1) @ v
        x = rearrange(x, '(n h) s d -> n s (h d)', h=self.nheads)
        x = self.proj(x)
        if self.t: x = x.transpose(1,2).reshape(n,c,h,w)
        return x

In [6]:
class ResBlock(nn.Module):
    def __init__(self, n_emb, ic, oc=None, ks=3, act=nn.SiLU, norm=nn.BatchNorm2d, attn_chans=0):
        super().__init__()
        self.emb_proj = nn.Linear(n_emb, oc*2)
        self.conv1 = unet_conv(ic, oc, ks, act=act, norm=norm)
        self.conv2 = unet_conv(oc, oc, ks, act=act, norm=norm)
        self.idconv = nn.Identity() if ic==oc else nn.Conv2d(ic, oc, kernel_size=1)
        self.attn = False
        if attn_chans: self.attn = SelfAttentionMultiHead(oc, attn_chans)
    
    def forward(self, x, t):
        inp = x
        x = self.conv1(x)
        emb = self.emb_proj(act(t))[:, :, None, None]
        scale, shift = torch.chunk(emb, 2, dim=1)
        x = x*(1+scale) + shift
        x = self.conv2(x) + self.idconv(inp)
        if self.attn: x = x + self.attn(x)
        return x

In [7]:
class UNET_Encoder(nn.Module):
    def __init__(self, n_emb, in_channels, nf, attn_chans):
        super().__init__()
        self.encoder = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        for f in nf:
            self.encoder.append(ResBlock(n_emb, in_channels, oc=f, attn_chans=attn_chans))
            in_channels = f
        
    def forward(self, x, t):
        skips = []
        for enc in self.encoder:
            x = enc(x, t)
            skips.append(x)
            x = self.pool(x)
        return x, skips

In [8]:
class UNET_Decoder(nn.Module):
    def __init__(self, n_emb, nf, attn_chans):
        super().__init__()
        self.decoder = nn.ModuleList()
        for f in nf:
            self.decoder.append(nn.ConvTranspose2d(f*2, f, kernel_size=2, stride=2))
            self.decoder.append(ResBlock(n_emb, f*2, oc=f, attn_chans=attn_chans))
    
    def forward(self, x, skips, t):
        for i in range(0, len(self.decoder), 2):
            x = self.decoder[i](x)
            x = torch.cat((skips[i//2], x), dim=1)
            x = self.decoder[i+1](x, t)
        return x

In [9]:
class UNET_dummy(nn.Module):
    def __init__(self, n_classes, in_channels, out_channels, nf=[64, 128, 256, 512], attn_chans=8):
        super().__init__()
        self.t_emb = nf[0]
        n_emb = self.t_emb*4
        self.cond_emb = nn.Embedding(n_classes, n_emb)

        self.emb_mlp = nn.Sequential(
            lin(self.t_emb, n_emb, act=None, norm=nn.BatchNorm1d),
            lin(n_emb, n_emb)
        )

        self.unet_encoder = UNET_Encoder(n_emb, in_channels, nf, attn_chans)

        self.bottle_neck = ResBlock(n_emb, nf[-1], nf[-1])

        self.unet_decoder = UNET_Decoder(n_emb, nf[::-1], attn_chans)
        
        self.final_conv = nn.Conv2d(nf[0], out_channels, kernel_size=1)

    def forward(self, inp):
        x, t, c = inp
        temb = timestep_embedding(t, self.t_emb)
        cemb = self.cond_emb(c)
        emb = self.emb_mlp(temb) + cemb
        x, skips = self.unet_encoder(x, emb)
        x = self.bottle_neck(x, emb)
        x = self.unet_decoder(x, skips[::-1], emb)
        return self.final_conv(x)

In [10]:
ud = UNET_dummy(10, 1, 1)
ud

UNET_dummy(
  (cond_emb): Embedding(10, 256)
  (emb_mlp): Sequential(
    (0): Sequential(
      (0): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): Linear(in_features=64, out_features=256, bias=True)
    )
    (1): Sequential(
      (0): SiLU()
      (1): Linear(in_features=256, out_features=256, bias=True)
    )
  )
  (unet_encoder): UNET_Encoder(
    (encoder): ModuleList(
      (0): ResBlock(
        (emb_proj): Linear(in_features=256, out_features=128, bias=True)
        (conv1): Sequential(
          (0): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): SiLU()
          (2): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (conv2): Sequential(
          (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): SiLU()
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (id

In [11]:
class UNET(nn.Module):
    def __init__(self, n_classes, in_channels, out_channels, nf=[64, 128, 256, 512], attn_chans=8):
        super().__init__()
        self.encoder = nn.ModuleList()
        self.t_emb = nf[0]
        n_emb = self.t_emb*4
        self.cond_emb = nn.Embedding(n_classes, n_emb)

        self.emb_mlp = nn.Sequential(
            lin(self.t_emb, n_emb, act=None, norm=nn.BatchNorm1d),
            lin(n_emb, n_emb)
        )
        
        for f in nf:
            self.encoder.append(ResBlock(n_emb, in_channels, oc=f, attn_chans=attn_chans))
            in_channels = f

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottle_neck = ResBlock(n_emb, nf[-1], nf[-1])

        self.decoder = nn.ModuleList()
        rnf = nf[::-1]
        for f in rnf:
            self.decoder.append(nn.ConvTranspose2d(f*2, f, kernel_size=2, stride=2))
            self.decoder.append(ResBlock(n_emb, f*2, oc=f, attn_chans=attn_chans))
        
        self.final_conv = nn.Conv2d(nf[0], out_channels, kernel_size=1)

    def forward(self, inp):
        x, t, c = inp
        temb = timestep_embedding(t, self.t_emb)
        cemb = self.cond_emb(c)
        emb = self.emb_mlp(temb) + cemb

        skips = []
        for enc in self.encoder:
            x = enc(x, emb)
            skips.append(x)
            x = self.pool(x)
              
        x = self.bottle_neck(x)
        skips = skips[::-1]
        
        for i in range(0, len(self.decoder), 2):
            x = self.decoder[i](x)
            x = torch.cat((skips[i//2], x), dim=1)
            x = self.decoder[i+1](x, emb)

        return self.final_conv(x)

In [12]:
model = UNET(10, 1, 1)
model

UNET(
  (encoder): ModuleList(
    (0): ResBlock(
      (emb_proj): Linear(in_features=256, out_features=128, bias=True)
      (conv1): Sequential(
        (0): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): SiLU()
        (2): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (conv2): Sequential(
        (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): SiLU()
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (idconv): Conv2d(1, 64, kernel_size=(1, 1), stride=(1, 1))
      (attn): SelfAttentionMultiHead(
        (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (qkv): Linear(in_features=64, out_features=192, bias=True)
        (proj): Linear(in_features=64, out_features=64, bias=True)
      )
    )
    (1): ResBlock(
      (emb_proj): Linear(in_features=256, out_features=256, bias=True)
      (conv1): Se