In [None]:
from typing import Optional, List

import torch
import torch.nn  as nn 
import torch.nn.functional as F

import math
import numpy as np 

from stable_diffusion_model import Encoder, Decoder, GaussianDistribution, ResNet, Downsample, AttnBlock, Upsample
from helper import Swish

In [None]:
class Upsample(nn.Module):
    """Nearest-neighbour x2 spatial upsampling followed by a 3x3 convolution.

    The channel count is unchanged; only height and width are doubled.
    """

    def __init__(self,
                 channels: int
                 ):
        super().__init__()
        # padding=1 with a 3x3 kernel keeps the (already doubled) H and W intact
        self.conv = nn.Conv2d(channels, channels, padding=1, stride=1, kernel_size=3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        enlarged = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(enlarged)

In [None]:
class Downsample(nn.Module):
    """Spatial downsampling with a stride-2 3x3 convolution (channel count unchanged).

    Args:
        channels: number of input (and output) channels.
    """

    def __init__(self,
                 channels: int
                 ):
        super().__init__()
        # No built-in padding: forward() pads asymmetrically (right/bottom only).
        # Attribute is named `conv` to match Upsample and the pretrained layout.
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pad one zero column on the right and one zero row at the bottom so an
        # even H/W maps exactly to H/2, W/2 with the padding=0 stride-2 conv.
        x = F.pad(x, (0, 1, 0, 1), mode='constant', value=0)
        return self.conv(x)


# Backward-compatible alias for the original (misspelled) class name.
Downsaple = Downsample

In [None]:
class ResNet(nn.Module):
    """Residual block: (GroupNorm -> SiLU -> 3x3 conv) twice, plus a skip connection.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        num_groups: GroupNorm group count; must divide both in_channels and
            out_channels (default 32, as in the original block).
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 num_groups: int = 32):
        super().__init__()
        self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # norm2 normalises the output of conv1, which has out_channels channels
        self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # 1x1 projection on the skip path only when the channel count changes
        if in_channels != out_channels:
            self.short_cut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.short_cut = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual branch works on `h` so the block input `x` survives for the skip path
        h = self.norm1(x)
        h = F.silu(h)  # SiLU is the swish activation x * sigmoid(x)
        h = self.conv1(h)

        h = self.norm2(h)
        h = F.silu(h)
        h = self.conv2(h)

        # Skip connection uses the ORIGINAL input, projected if channels differ
        return self.short_cut(x) + h

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder producing the latent moments of an image.

    The output of ``forward`` has ``2 * z_channels`` channels, which downstream
    code splits into mean and log-variance.
    """

    def __init__(self,
                 channels: int,
                 channels_multiplier: List[int],
                 num_blocks: int,
                 in_channels: int,
                 z_channels: int,
                 ):
        """
        channels: the channels in the first CNN layer,
        channels_multiplier: per-level multipliers; level i outputs channels * channels_multiplier[i],
        num_blocks: the number of ResNet blocks at EACH resolution level,
        in_channels: input image channels,
        z_channels: latent channels (the encoder output has 2 * z_channels channels)
        """
        super().__init__()
        n_resolution = len(channels_multiplier)

        self.conv_in = nn.Conv2d(in_channels, channels, kernel_size=3, stride=1, padding=1)
        # Channel counts per level; index 0 is the conv_in output width
        channels_list = [m * channels for m in [1] + list(channels_multiplier)]

        self.down = nn.ModuleList()
        for i in range(n_resolution):
            # nn.ModuleList (not a plain list) so the blocks' parameters are
            # registered: seen by .parameters(), .to(), and state_dict()
            resnet_blocks = nn.ModuleList()
            for _ in range(num_blocks):
                resnet_blocks.append(ResNet(channels, channels_list[i + 1]))
                # the next block consumes this block's output channels
                channels = channels_list[i + 1]

            down = nn.Module()
            down.block = resnet_blocks

            # every level except the last reduces the spatial resolution
            if i != n_resolution - 1:
                down.downsample = Downsample(channels)
            else:
                down.downsample = nn.Identity()

            self.down.append(down)

        self.mid = nn.Module()
        self.mid.block_1 = ResNet(channels, channels)
        self.mid.attn_1 = AttnBlock(channels)
        self.mid.block_2 = ResNet(channels, channels)

        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=channels)
        self.conv_out = nn.Conv2d(channels, 2 * z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` into a tensor with ``2 * z_channels`` channels.

        Spatial size is reduced once per level except the last (presumably by 2
        each, depending on the Downsample in scope).
        """
        h = self.conv_in(x)

        # Resolution levels: ResNet blocks, then (optional) downsampling
        for down in self.down:
            for block in down.block:
                h = block(h)
            h = down.downsample(h)

        # Bottleneck and output head run ONCE, after all levels
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        h = self.norm_out(h)
        h = F.silu(h)  # SiLU is the swish activation x * sigmoid(x)
        return self.conv_out(h)

In [None]:
class Decoder(nn.Module):
    """Convolutional decoder mapping a latent tensor back to an image."""

    def __init__(self,
                 channels: int,
                 channels_multiplier: List[int],
                 num_blocks: int,
                 out_channels: int,
                 z_channels: int,
                 ):
        """
        channels: base channel count; level i uses channels * channels_multiplier[i],
        channels_multiplier: per-level channel multipliers (same list as the Encoder),
        num_blocks: Encoder ResNet blocks per level; the Decoder uses num_blocks + 1 per level,
        out_channels: the output channels of the generated image,
        z_channels: channels of the latent tensor fed to conv_in
        """
        super().__init__()

        num_resolutions = len(channels_multiplier)

        channels_list = [m * channels for m in channels_multiplier]

        # Decoding starts at the last level's channel count
        channels = channels_list[-1]

        self.conv_in = nn.Conv2d(z_channels, channels, kernel_size=3, stride=1, padding=1)

        self.mid = nn.Module()
        self.mid.block_1 = ResNet(channels, channels)
        self.mid.attn_1 = AttnBlock(channels)
        self.mid.block_2 = ResNet(channels, channels)

        self.up = nn.ModuleList()

        # Build levels from the last index down to 0 (reversed must wrap the range, not the int)
        for i in reversed(range(num_resolutions)):
            resnet_blocks = nn.ModuleList()
            for _ in range(num_blocks + 1):
                resnet_blocks.append(ResNet(channels, channels_list[i]))
                channels = channels_list[i]

            up = nn.Module()
            up.block = resnet_blocks
            # Upsample at the end of every level except level 0
            if i != 0:
                up.upsample = Upsample(channels)
            else:
                up.upsample = nn.Identity()

            # Prepend so self.up[i] holds level i, for consistency with the pretrained model's module layout
            self.up.insert(0, up)

        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=channels)
        self.conv_out = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent ``z`` (with ``z_channels`` channels) into an image with ``out_channels`` channels."""
        h = self.conv_in(z)
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # self.up[0] is level 0, so walk it backwards: last level first, level 0 last
        for up in reversed(self.up):
            for block in up.block:
                h = block(h)
            h = up.upsample(h)

        h = self.norm_out(h)
        h = F.silu(h)  # SiLU is the swish activation x * sigmoid(x)
        return self.conv_out(h)

In [None]:
class GaussianDistribution(nn.Module):
    """Diagonal Gaussian built from stacked (mean, log-variance) parameters."""

    def __init__(self,
                 parameters: torch.Tensor,
                 upper_bound: float = 20,
                 lower_bound: float = -30
                 ):
        """
        parameters: tensor of shape (batch_size, z_channels * 2, z_height, z_width);
            the first half of dim 1 is the mean, the second half the log-variance.
            This is the output shape of the Encoder; mean and variance are needed
            for the reparameterization trick.
        upper_bound / lower_bound: clamp range for the log-variance, keeping
            exp(0.5 * log_var) finite and non-degenerate.
        """
        super().__init__()
        self.mean, log_var = torch.chunk(parameters, 2, dim=1)
        self.log_var = torch.clamp(log_var, lower_bound, upper_bound)
        self.std = torch.exp(0.5 * self.log_var)

    def sample(self) -> torch.Tensor:
        """Draw one reparameterized sample: mean + std * eps, with eps ~ N(0, I)."""
        return self.mean + self.std * torch.randn_like(self.std)

In [None]:
class AttnBlock(nn.Module):
    """Single-head spatial self-attention over all H*W positions, with a residual connection.

    Args:
        channels: number of channels of the feature map (also the attention dim).
        num_groups: GroupNorm group count; must divide ``channels`` (default 32).
    """

    def __init__(self,
                 channels: int,
                 num_groups: int = 32
                 ):
        super().__init__()
        self.channels = channels
        self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=channels)
        # 1x1 convs act as per-position linear projections
        self.to_q = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)
        self.to_k = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)
        self.to_v = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)

        # Channelwise (1x1) output projection
        self.proj_out = nn.Conv2d(channels, channels, 1)
        # Scaled dot-product attention factor 1/sqrt(d)
        self.scale = channels ** -0.5

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        # Keep `x` untouched: the residual must add the block input, not its normalisation
        x_norm = self.norm(x)

        q = self.to_q(x_norm).reshape(b, c, h * w)
        k = self.to_k(x_norm).reshape(b, c, h * w)
        v = self.to_v(x_norm).reshape(b, c, h * w)

        # attn[b, i, j]: weight of position j when updating position i
        attn = torch.einsum('bci,bcj->bij', q, k) * self.scale
        attn = torch.softmax(attn, dim=-1)
        out = torch.einsum('bij,bcj->bci', attn, v)

        # reshape (not view): the einsum result is not guaranteed to be contiguous
        out = out.reshape(b, c, h, w)
        # Project the attention OUTPUT, then add the residual
        out = self.proj_out(out)
        return x + out

In [None]:
class AutoEncoder(nn.Module):
    """VAE wrapper: encoder -> moments -> Gaussian, and latent -> decoder -> image.

    Args:
        encoder: module whose output has ``2 * z_channels`` channels.
        decoder: module whose input has ``z_channels`` channels.
        emb_channels: channels of the sampled latent (embedding) space.
        z_channels: channels on the encoder/decoder side.
    """

    def __init__(self,
                 encoder: 'Encoder',
                 decoder: 'Decoder',
                 emb_channels: int,
                 z_channels: int
                 ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

        # 1x1 conv projecting encoder moments (2*z_channels) to latent moments (2*emb_channels)
        self.quant_conv = nn.Conv2d(2 * z_channels, 2 * emb_channels, kernel_size=1, stride=1, padding=0)

        # 1x1 conv projecting a latent sample (emb_channels) back to the decoder input (z_channels)
        self.post_quant_conv = nn.Conv2d(emb_channels, z_channels, kernel_size=1, stride=1, padding=0)

    def encode(self, x: torch.Tensor) -> 'GaussianDistribution':
        """Encode image ``x`` into a GaussianDistribution over the latent space."""
        z = self.encoder(x)
        # Project the ENCODER OUTPUT, not the raw image, into moment space
        moments = self.quant_conv(z)
        return GaussianDistribution(moments)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode a latent ``z`` (``emb_channels`` channels) back to image space."""
        h = self.post_quant_conv(z)
        return self.decoder(h)