In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from utils import *
from learner import *
from time_embedding import *
from einops import rearrange
import math

In [2]:
import torch
from utils import init_attr

class DDPM:
    def __init__(self, beta_min, beta_max, n_steps):
        init_attr(self, locals=locals())
        self.beta = torch.linspace(self.beta_min, self.beta_max, self.n_steps)
        self.alpha = 1 - self.beta
        self.alpha_bar = self.alpha.cumprod(dim=0)
        self.sigma = self.beta.sqrt()
    
    def schedule(self, x0):
        device = "cuda"
        n = len(x0)
        t = torch.randint(0, self.n_steps, (n,), dtype=torch.long)
        noise = torch.randn(x0.shape)
        alpha_bar_t = self.alpha_bar[t].reshape(-1, 1, 1, 1).to(device)
        mean = alpha_bar_t.sqrt().to(device) * x0.to(device) 
        variance = (1-alpha_bar_t).sqrt().to(device) * noise.to(device) 
        xt = mean + variance
        return (xt, t.to(device)), noise.to(device)

In [3]:
class SelfAttention(nn.Module):
    def __init__(self, ni, attn_chans, transpose=True):
        super().__init__()
        self.nheads = ni//attn_chans
        self.scale = math.sqrt(ni/self.nheads)
        self.norm = nn.LayerNorm(ni)
        self.qkv = nn.Linear(ni, ni*3)
        self.proj = nn.Linear(ni, ni)
        self.t = transpose
    
    def forward(self, x):
        n,c,s = x.shape
        if self.t: x = x.transpose(1, 2)
        x = self.norm(x)
        x = self.qkv(x)
        x = rearrange(x, 'n s (h d) -> (n h) s d', h=self.nheads)
        q,k,v = torch.chunk(x, 3, dim=-1)
        s = (q@k.transpose(1,2))/self.scale
        x = s.softmax(dim=-1)@v
        x = rearrange(x, '(n h) s d -> n s (h d)', h=self.nheads)
        x = self.proj(x)
        if self.t: x = x.transpose(1, 2)
        return x

class SelfAttention2D(SelfAttention):
    def forward(self, x):
        n,c,h,w = x.shape
        return super().forward(x.view(n, c, -1)).reshape(n,c,h,w)

In [4]:
def unet_conv(in_channels, out_channels, kernel_size=3, stride=1, act=nn.SiLU, norm=None, bias=True):
    layers = nn.Sequential()
    if norm: layers.append(norm(in_channels))
    if act: layers.append(act())
    layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size//2, bias=bias))
    return layers

In [5]:
class TimestepEmbedding(nn.Module):
    def __init__(self, n_channels):
        super().__init__()
        self.n_channels = n_channels
        self.n_embedding = self.n_channels * 4
        self.timestep_mlp = nn.Sequential(
            nn.Linear(self.n_channels, self.n_embedding),
            nn.SiLU(),
            nn.Linear(self.n_embedding, self.n_embedding)
        )

    def forward(self, t):
        half_dim = self.n_embedding // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None].float() * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=1)
        emb = self.timestep_mlp(emb)
        return emb

In [6]:
class ResBlock(nn.Module):
    def __init__(self, n_embedding, in_channels, out_channels=None, kernel_size=3, act=nn.SiLU, norm=nn.BatchNorm2d, attn_channs=0):
        super().__init__()
        if out_channels is None: out_channels = in_channels
        self.emb_proj = nn.Linear(n_embedding, out_channels*2)
        self.conv_1 = unet_conv(in_channels, out_channels, kernel_size=kernel_size, act=act, norm=norm)
        self.conv_2 = unet_conv(out_channels, out_channels, kernel_size=kernel_size, act=act, norm=norm)
        if in_channels == out_channels:
            self.id_conv = nn.Identity()  
        else:
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.attn = False
        if attn_channs:
            self.attn = SelfAttention2D(out_channels, attn_channs)
        
    def forward(self, x, t):
        inp = x
        x = self.conv_1(x)
        # x += self.emb_proj(act(t))[:, :, None, None]
        emb = self.emb_proj(act(t))[:, :, None, None]
        scale, shift = torch.chunk(emb, 2, dim=1)
        x = x*(1+scale) + shift
        x = self.conv2(x) + self.idconv(inp)
        if self.attn:
            x += self.attn(x)
        return x

In [7]:
class UNET_Encoder(nn.Module):
    def __init__(self, n_embedding, channels, attn_channs=0, attn_start=1):
        super().__init__()
        self.down = nn.ModuleList()
        self.down_sample = nn.ModuleList()
        
        n_resolutions = len(channels)
        out_channels = channels[0]
        for i in range(n_resolutions):
            in_channels = out_channels
            out_channels = channels[i]
            for i in range(2):
                self.down.append(
                    ResBlock(
                        n_embedding, 
                        in_channels if i==0 else out_channels, 
                        out_channels=out_channels, 
                        attn_channs=0 if i<attn_start else attn_channs
                    )
                )
            self.down_sample.append(
                nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=1) if i < (n_resolutions-1) else nn.Identity()
            )
    
    def forward(self, x, t):
        skips = []
        for i in range(len(self.down)):
            x = self.down[i](x, t)
            skips.append(x)
            x = self.down_sample[i](x)
        return x, skips

In [8]:
# unet_encoder = UNET_Encoder(256, (32, 64, 128, 256), attn_channs=8)
# unet_encoder

In [9]:
class UNET_Bottleneck(nn.Module):
    def __init__(self, n_embedding, in_channels):
        super().__init__()
        self.unet_bottleneck = nn.Sequential(
            ResBlock(n_embedding, in_channels, attn_channs=8),
            ResBlock(n_embedding, in_channels)
        )
    
    def forward(self, x):
        x = self.unet_bottleneck(x)
        return x

In [10]:
# unet_bottleneck = UNET_Bottleneck(256, 256)
# unet_bottleneck

In [11]:
class UNET_Decoder(nn.Module):
    def __init__(self, n_embedding, channels, attn_channs=0, attn_start=1):
        super().__init__()
        self.up = nn.ModuleList()
        self.up_sample = nn.ModuleList()

        n_resolutions = len(channels)
        out_channels = channels[0]
        for i in range(n_resolutions):
            prev_channels = out_channels
            in_channels = channels[min(i+1, n_resolutions-1)]
            out_channels = channels[i]
            for i in range(3):
                self.up.append(
                    ResBlock(
                        n_embedding, 
                        (prev_channels if i==0 else out_channels) + (in_channels if i==2 else out_channels), 
                        out_channels=out_channels, 
                        attn_channs=0 if i>=n_resolutions-attn_start else attn_channs
                    )
                )
            self.up_sample.append(
                nn.ConvTranspose2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1) if i < (n_resolutions-1) else nn.Identity()
            )

    def forward(self, x, t, skips):
        for i in range(len(self.up)):
            x = self.up_sample[i](x)
            x = torch.cat((skips[i], x), dim=1)
            x = self.up[i](x, t)
        return x


In [12]:
# unet_decoder = UNET_Decoder(256, (256, 128, 64, 32), attn_channs=8)
# unet_decoder

In [13]:
class UNET(nn.Module):
    def __init__(self, n_classes, in_channels, out_channels, channels=(64, 128, 256, 512), attn_channs=8):
        super().__init__()
        self.n_embedding = channels[0]*4
        self.timestep_embedding = TimestepEmbedding(channels[0])
        self.condition_embedding = nn.Embedding(n_classes, self.n_embedding)

        self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1)
        self.encoder = UNET_Encoder(self.n_embedding, channels, attn_channs=attn_channs)
        self.bottleneck = UNET_Bottleneck(self.n_embedding, channels[-1])
        self.decoder = UNET_Decoder(self.n_embedding, channels[::-1], attn_channs=attn_channs)
        self.conv_out = unet_conv(channels[0], out_channels, act=nn.SiLU, norm=nn.BatchNorm2d, bias=False)

    def forward(self, inp):
        x, t, c = inp
        t = self.timestep_embedding(t)
        c = self.condition_embedding(c)
        emb = t + c
        x = self.conv_in(x)
        x, skips = self.encoder(x, emb)
        x = self.bottleneck(x, emb)
        x = self.decoder(x, emb, skips[::-1])
        x = self.conv_out(x)
        return x

In [14]:
unet = UNET(10, 1, 1, (32,64,128,256))
unet

UNET(
  (timestep_embedding): TimestepEmbedding(
    (timestep_mlp): Sequential(
      (0): Linear(in_features=32, out_features=128, bias=True)
      (1): SiLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
    )
  )
  (condition_embedding): Embedding(10, 128)
  (conv_in): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (encoder): UNET_Encoder(
    (down): ModuleList(
      (0): ResBlock(
        (emb_proj): Linear(in_features=128, out_features=64, bias=True)
        (conv_1): Sequential(
          (0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): SiLU()
          (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (conv_2): Sequential(
          (0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): SiLU()
          (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (id_conv): Ide