In [None]:
from abc import ABC, abstractmethod
import math
import numpy as np
import torch
from torch import utils
from torch import nn
from torch import optim
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import ToPILImage
from IPython.display import Image

Stable Diffusion Example

In [None]:
# Filesystem root for data, hard-coded as an absolute path on the author's machine.
# NOTE(review): presumably where the datasets used later in the notebook live;
# adjust it to your own environment before running.
data_dir = '/Users/armandli/data/'

In [None]:
# Single shared ToPILImage transform instance, created once at module level.
# NOTE(review): presumably used by later cells to turn tensors into displayable
# images — confirm against the cells that call it.
to_img = ToPILImage()

In [None]:
# Accelerator detection. Both flags stay at module level for later cells.
use_cuda = torch.cuda.is_available()
# NOTE: is_built() reports whether this PyTorch build includes MPS support.
use_mps = torch.backends.mps.is_built()

# MPS is deliberately never selected: mps has bugs that cannot handle
# ConvTranspose2d, so anything that is not CUDA runs on the CPU.
cpu = torch.device('cpu')
device = torch.device('cuda') if use_cuda else torch.device('cpu')

In [None]:
# Batch size shared by the training and scoring configurations below.
default_batch_size = 256

# Keyword-argument dictionaries (presumably for DataLoader construction later on):
# the training variant shuffles, the scoring variant keeps the original order.
loader_args = dict(batch_size=default_batch_size, shuffle=True)
score_args = dict(batch_size=default_batch_size, shuffle=False)

# pin_memory is only requested when CUDA is in use.
if use_cuda:
    loader_args['pin_memory'] = True
    score_args['pin_memory'] = True

In [None]:
class Reporter(ABC):
    """Abstract interface for objects that collect reported metrics.

    Concrete subclasses must implement both ``report`` and ``reset``; the base
    implementations do nothing and return ``None``.
    """

    @abstractmethod
    def report(self, typ, **metric):
        """Record one report of kind ``typ`` with arbitrary keyword values."""
        return None

    @abstractmethod
    def reset(self):
        """Reset the reporter; the concrete effect is defined by the subclass."""
        return None

In [None]:
class SReporter(Reporter):
    """In-memory reporter that stores every report in a list.

    Each ``report`` call appends a ``(typ, data)`` tuple to ``self.log``, where
    ``data`` is the dict of keyword values passed by the caller.

    Python has no method overloading: a second ``def`` with the same name in a
    class body replaces the first. The list-returning and index-taking variants
    of ``loss`` / ``eval_loss`` / ``train_loss`` are therefore merged into single
    methods with an optional ``idx`` argument.
    """

    def __init__(self):
        # Chronological list of (typ, data) tuples.
        self.log = []

    def report(self, typ, **data):
        """Append one record of kind ``typ`` carrying the keyword values ``data``."""
        self.log.append((typ, data))

    def reset(self):
        """Remove every stored record."""
        self.log.clear()

    def _records(self, t):
        """Return the data dicts of all records of kind ``t`` in insertion order."""
        return [data for (typ, data) in self.log if typ == t]

    def loss(self, t, idx=None):
        """Return the ``'loss'`` value(s) reported for kind ``t``.

        Args:
            t: record kind to filter on (e.g. ``'train'`` or ``'eval'``).
            idx: ``None`` to get the list of all losses of that kind; otherwise
                an integer position among the records of that kind, where
                negative values count from the most recent record.

        Returns:
            A list of losses when ``idx`` is ``None``. Otherwise the loss of the
            selected record, or ``float("inf")`` if no such record exists.

        Raises:
            KeyError: if a selected record was reported without a ``loss`` value.
        """
        records = self._records(t)
        if idx is None:
            return [data['loss'] for data in records]
        if -len(records) <= idx < len(records):
            return records[idx]['loss']
        return float("inf")

    def eval_loss(self, idx=None):
        """Shortcut for ``loss('eval', idx)``."""
        return self.loss('eval', idx)

    def train_loss(self, idx=None):
        """Shortcut for ``loss('train', idx)``."""
        return self.loss('train', idx)

    def get_record(self, t, idx):
        """Return the data dict of the ``idx``-th record of kind ``t``.

        Negative ``idx`` counts from the most recent record. An empty dict is
        returned when no such record exists.
        """
        records = self._records(t)
        if -len(records) <= idx < len(records):
            return records[idx]
        return dict()

    def eval_record(self, idx):
        """Shortcut for ``get_record('eval', idx)``."""
        return self.get_record('eval', idx)

    def train_record(self, idx):
        """Shortcut for ``get_record('train', idx)``."""
        return self.get_record('train', idx)

In [None]:
class TimeEmbeddingV1(nn.Module):
    """Sinusoidal embedding of a timestep followed by a two-layer MLP.

    ``forward`` builds ``cdim // 8`` frequencies, concatenates their sines and
    cosines into ``2 * (cdim // 8)`` features, and maps them to ``cdim`` features
    through Linear -> SiLU -> Linear.

    Args:
        cdim: width of the returned embedding.
    """

    def __init__(self, cdim):
        super().__init__()
        self.cdim = cdim
        # Size the first Linear from the features forward() really produces.
        # This equals cdim // 4 whenever cdim % 8 < 4, and fixes the shape
        # mismatch that cdim // 4 caused for the remaining values of cdim.
        sin_cos_dim = 2 * (self.cdim // 8)
        self.layers = nn.Sequential(
            nn.Linear(sin_cos_dim, self.cdim),
            nn.SiLU(),
            nn.Linear(self.cdim, self.cdim),
        )

    def forward(self, t):
        """Embed timesteps.

        Args:
            t: 1-D tensor of timesteps (it is indexed as ``t[:, None]``).

        Returns:
            Tensor of shape ``(len(t), cdim)``.
        """
        half_dim = self.cdim // 8
        # Guard the denominator: with a single frequency (half_dim == 1) the
        # unguarded half_dim - 1 is zero and raised ZeroDivisionError.
        emb = math.log(10_000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=1)
        emb = self.layers(emb)
        return emb


In [None]:
class ResidualBlockV1(nn.Module):
    """Two-convolution residual block conditioned on a time embedding.

    Args:
        cdim_in: channels of the input feature map.
        cdim_out: channels of the output feature map.
        tdim: width of the time embedding passed to ``forward``.
        gdim: number of groups for both GroupNorm layers.
        dropout: dropout probability applied before the second convolution.
    """

    def __init__(self, cdim_in, cdim_out, tdim, gdim=32, dropout=0.1):
        super().__init__()
        # First stage: normalise, activate, change the channel count.
        self.layer1 = nn.Sequential(
            nn.GroupNorm(gdim, cdim_in),
            nn.SiLU(),
            nn.Conv2d(cdim_in, cdim_out, kernel_size=3, padding=1),
        )
        # Second stage: normalise, activate, dropout, 3x3 convolution.
        self.layer2 = nn.Sequential(
            nn.GroupNorm(gdim, cdim_out),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            nn.Conv2d(cdim_out, cdim_out, kernel_size=3, padding=1),
        )
        # Projects the time embedding to one value per output channel.
        self.time_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(tdim, cdim_out),
        )
        # The shortcut needs a 1x1 convolution only when the channel count changes.
        channels_change = cdim_in != cdim_out
        self.skip = nn.Conv2d(cdim_in, cdim_out, kernel_size=1) if channels_change else nn.Identity()

    def forward(self, x, t):
        """Apply the block to feature map ``x`` with time embedding ``t``.

        The projected time embedding is broadcast over the two spatial axes and
        added between the two stages; the (possibly projected) input is added
        to the result.
        """
        time_bias = self.time_layer(t)[:, :, None, None]
        hidden = self.layer1(x) + time_bias
        return self.layer2(hidden) + self.skip(x)

In [None]:
class AttentionBlockV1(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    Args:
        cdim: channels of the input (and output) feature map.
        hdim: number of attention heads.
        d_k: per-head width of queries, keys and values; defaults to ``cdim``.
        gdim: accepted for signature compatibility; not used by this block.
    """

    def __init__(self, cdim, hdim=1, d_k=None, gdim=32):
        super().__init__()
        if d_k is None:
            d_k = cdim
        # One projection produces queries, keys and values for all heads.
        self.proj = nn.Linear(cdim, hdim * d_k * 3)
        self.output = nn.Linear(hdim * d_k, cdim)
        self.hdim = hdim
        self.d_k = d_k
        self.scale = d_k ** -0.5

    #NOTE: t is not used
    def forward(self, x, t=None):
        """Attend over all ``h * w`` positions of ``x`` and add the result to ``x``.

        Args:
            x: tensor of shape ``(batch, cdim, h, w)``.
            t: ignored; accepted so the block can be called like the time-aware blocks.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch_sz, cdim, h, w = x.shape
        # (batch, cdim, h, w) -> (batch, h*w, cdim); reshape also accepts
        # non-contiguous inputs, which view would reject.
        x = x.reshape(batch_sz, cdim, -1).permute(0, 2, 1)
        qkv = self.proj(x).reshape(batch_sz, -1, self.hdim, 3 * self.d_k)
        q, k, v = torch.chunk(qkv, 3, dim=-1)
        # Scores are laid out as (batch, query i, key j, head h).
        attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale
        # Normalise over the key axis j.
        attn = attn.softmax(dim=2)
        # The subscripts here must match the (b, i, j, h) layout built above.
        # Reading the scores as 'bihj' swapped the key and head axes.
        res = torch.einsum('bijh,bjhd->bihd', attn, v)
        res = res.reshape(batch_sz, -1, self.hdim * self.d_k)
        res = self.output(res)
        # Residual connection in sequence layout.
        res = res + x
        res = res.permute(0, 2, 1).reshape(batch_sz, cdim, h, w)
        return res

In [None]:
class DownBlockV1(nn.Module):
    """Residual block optionally followed by self-attention.

    Args:
        cdim_in: channels of the incoming feature map.
        cdim_out: channels produced by the block.
        tdim: width of the time embedding given to the residual block.
        has_attn: when true an AttentionBlockV1 follows the residual block,
            otherwise an identity is used.
    """

    def __init__(self, cdim_in, cdim_out, tdim, has_attn):
        super().__init__()
        self.res = ResidualBlockV1(cdim_in, cdim_out, tdim)
        self.attn = AttentionBlockV1(cdim_out) if has_attn else nn.Identity()

    def forward(self, x, t):
        """Run the residual block with time embedding ``t``, then the attention stage."""
        return self.attn(self.res(x, t))

In [None]:
class UpBlockV1(nn.Module):
    """Residual block optionally followed by self-attention.

    Args:
        cdim_in: channels of the tensor actually passed to ``forward``. When a
            skip connection is concatenated before the call, this is already
            the concatenated width.
        cdim_out: channels produced by the block.
        tdim: width of the time embedding given to the residual block.
        has_attn: when true an AttentionBlockV1 follows the residual block,
            otherwise an identity is used.
    """

    def __init__(self, cdim_in, cdim_out, tdim, has_attn):
        super().__init__()
        # DDPMV1 already passes the concatenated width for the first block of a
        # stage and feeds cdim_out-wide tensors to the rest, so adding cdim_out
        # here counted the skip twice and mismatched the real input.
        self.res = ResidualBlockV1(cdim_in, cdim_out, tdim)
        if has_attn:
            self.attn = AttentionBlockV1(cdim_out)
        else:
            self.attn = nn.Identity()

    def forward(self, x, t):
        """Run the residual block with time embedding ``t``, then the attention stage."""
        x = self.res(x, t)
        x = self.attn(x)
        return x

In [None]:
class MiddleBlockV1(nn.Module):
    """Residual block, self-attention, residual block, all at ``cdim`` channels.

    Args:
        cdim: channels of the input and output feature map.
        tdim: width of the time embedding given to both residual blocks.
    """

    def __init__(self, cdim, tdim):
        super().__init__()
        self.res1 = ResidualBlockV1(cdim, cdim, tdim)
        self.attn = AttentionBlockV1(cdim)
        self.res2 = ResidualBlockV1(cdim, cdim, tdim)

    def forward(self, x, t):
        """Apply the three stages in order; only the residual blocks receive ``t``."""
        attended = self.attn(self.res1(x, t))
        return self.res2(attended, t)

In [None]:
class UpSampleV1(nn.Module):
    """Learned upsampling with a transposed convolution that keeps ``cdim`` channels.

    With kernel 4, stride 2 and padding 1 the layer doubles both spatial sizes.
    """

    def __init__(self, cdim):
        super().__init__()
        self.layer = nn.ConvTranspose2d(cdim, cdim, kernel_size=4, stride=2, padding=1)

    def forward(self, x, t):
        """Upsample ``x``; ``t`` is accepted for a uniform call signature and ignored."""
        upsampled = self.layer(x)
        return upsampled

In [None]:
class DownSampleV1(nn.Module):
    """Learned downsampling with a stride-2 3x3 convolution that keeps ``cdim`` channels.

    With kernel 3, stride 2 and padding 1 the layer halves both spatial sizes
    (rounding up for odd sizes).
    """

    def __init__(self, cdim):
        # The zero-argument form always binds to this class. The earlier
        # super(UpSampleV1, self) named a different class and raised TypeError,
        # because a DownSampleV1 instance is not an UpSampleV1.
        super().__init__()
        self.layer = nn.Conv2d(cdim, cdim, (3,3), (2,2), (1,1))

    def forward(self, x, t):
        """Downsample ``x``; ``t`` is accepted for a uniform call signature and ignored."""
        return self.layer(x)

In [None]:
class DDPMV1(nn.Module):
    """U-Net style network assembled from the V1 blocks.

    Layout: an input convolution, then for every entry of ``cmults`` a run of
    ``n_blocks`` DownBlockV1 followed by one DownSampleV1; a MiddleBlockV1; then
    for every entry of ``cmults`` in reverse an UpSampleV1 followed by
    ``n_blocks`` UpBlockV1; finally GroupNorm -> SiLU -> convolution back to
    ``cdim_in`` channels.

    Args:
        cdim_in: channels of the input image (also the output channel count).
        cdim: base channel count after the input convolution.
        cmults: per-level channel multipliers, applied cumulatively.
        is_attn: per-level flags enabling attention in that level's blocks;
            indexed in parallel with ``cmults``.
        n_blocks: number of Down/Up blocks per level.
            NOTE(review): the channel bookkeeping assumes ``n_blocks >= 1``.
    """

    def __init__(self, cdim_in, cdim, cmults, is_attn, n_blocks):
        super().__init__()
        n = len(cmults)
        # Width of the time embedding shared by every block.
        tdim = cdim * 4
        self.time_emb = TimeEmbeddingV1(tdim)
        self.image_proj = nn.Conv2d(cdim_in, cdim, kernel_size=(3,3), padding=(1,1))

        # Encoder: widen the channels level by level, downsampling after each level.
        down = []
        out_channels = in_channels = cdim
        for i in range(n):
            out_channels = in_channels * cmults[i]
            for _ in range(n_blocks):
                down.append(DownBlockV1(in_channels, out_channels, tdim, is_attn[i]))
                in_channels = out_channels
            down.append(DownSampleV1(in_channels))
        self.down = nn.ModuleList(down)

        self.middle = MiddleBlockV1(out_channels, tdim)

        # Decoder: mirror of the encoder. The first block after each upsample
        # receives the upsampled tensor concatenated with the skip of the same
        # level; out_channels * cmults[i] is that skip's width.
        up = []
        in_channels = out_channels
        for i in reversed(range(n)):
            up.append(UpSampleV1(in_channels))
            out_channels = in_channels // cmults[i]
            up.append(UpBlockV1(in_channels+out_channels*cmults[i], out_channels, tdim, is_attn[i]))
            for _ in range(1, n_blocks):
                up.append(UpBlockV1(out_channels, out_channels, tdim, is_attn[i]))
            in_channels = out_channels
        self.up = nn.ModuleList(up)

        # nn.SiLU is the correct class name (nn.SilU does not exist and raised
        # AttributeError). The norm uses in_channels, the width the decoder ends
        # on; it equals cdim after undoing every multiplier.
        self.final = nn.Sequential(
            nn.GroupNorm(8, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, cdim_in, kernel_size=(3,3), padding=(1,1)),
        )

    def forward(self, x, t):
        """Run the network.

        Args:
            x: image batch with ``cdim_in`` channels.
            t: 1-D tensor of timesteps, one per sample (TimeEmbeddingV1 indexes
                it as ``t[:, None]``).

        Returns:
            Tensor with ``cdim_in`` channels.
            NOTE(review): the skip concatenation needs each upsampled map to
            match the stored skip, i.e. spatial sizes that halve cleanly at
            every level — confirm for the image sizes used.
        """
        t = self.time_emb(t)
        x = self.image_proj(x)

        # Store one skip per level: the feature map just before it is downsampled.
        h = []
        for m in self.down:
            if isinstance(m, DownSampleV1):
                h.append(x)
            x = m(x, t)

        x = self.middle(x, t)

        # Only the first block after each upsample gets the skip concatenated.
        is_first = False
        for m in self.up:
            if isinstance(m, UpSampleV1):
                x = m(x, t)
                is_first = True
            else:
                if is_first:
                    s = h.pop()
                    x = torch.cat((x, s), dim=1)
                    is_first = False
                x = m(x, t)
        return self.final(x)

In [None]:
## datasets