In [30]:
import math
from inspect import isfunction
from functools import partial

%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

In [2]:
# helper functions
def exists(x):
    """Return True for any value except ``None`` (falsy values like 0 or [] still 'exist')."""
    return not (x is None)

def default(val, d):
    """Return ``val`` when it is not None, otherwise the fallback ``d``.

    If ``d`` is callable it is only invoked (with no arguments) when the
    fallback is actually needed, and its result is returned.
    """
    if val is not None:
        return val
    if callable(d):
        return d()
    return d

def num_to_groups(num, divisor):
    """Split ``num`` into as many full chunks of ``divisor`` as fit, plus the leftover.

    Example: ``num_to_groups(10, 4) -> [4, 4, 2]``. A zero remainder adds no
    trailing chunk, so ``num_to_groups(8, 4) -> [4, 4]``.
    """
    full_chunks, leftover = divmod(num, divisor)
    tail = [leftover] if leftover > 0 else []
    return [divisor] * full_chunks + tail

In [None]:
# helper modules
def Upsample(dim, dim_out = None):
    """Double the spatial resolution, then mix channels with a 3x3 conv.

    Args:
        dim: input channel count.
        dim_out: output channel count; falls back to ``dim`` when None.

    Returns:
        ``nn.Sequential`` of nearest-neighbour x2 upsampling followed by a
        padded 3x3 ``Conv2d(dim, dim_out)``.
    """
    out_channels = default(dim_out, dim)
    layers = [
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        nn.Conv2d(dim, out_channels, 3, padding = 1),
    ]
    return nn.Sequential(*layers)
    
def Downsample(dim, dim_out=None):
    """Halve the spatial resolution by folding each 2x2 patch into channels.

    The rearrange step maps (B, dim, H, W) -> (B, 4*dim, H/2, W/2), so H and W
    must be even. A 1x1 conv then projects 4*dim channels down to ``dim_out``.

    Args:
        dim: input channel count.
        dim_out: output channel count; falls back to ``dim`` when None.
    """
    out_channels = default(dim_out, dim)
    space_to_depth = Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2)
    projection = nn.Conv2d(dim * 4, out_channels, 1)
    return nn.Sequential(space_to_depth, projection)
    
class RMSNorm(nn.Module):
    """Root-mean-square normalisation across dim 1 with a learnable per-channel gain.

    Each position is divided by sqrt(mean over dim 1 of x**2 + eps) and then
    multiplied by ``gamma`` (initialised to ones).

    Args:
        dim: size of dim 1 (channels); length of ``gamma``.
        eps: added inside the square root for numerical stability.
    """

    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # gamma is reshaped to (C, 1, 1), so x is presumably (B, C, H, W) — TODO confirm for other ranks
        rms = (x.pow(2).mean(dim=1, keepdim=True) + self.eps).sqrt()
        gain = self.gamma.view(-1, 1, 1)
        return x / rms * gain

Positional embeddings

In [23]:
# Sanity check: a mean over dim 1 with keepdim=True leaves a singleton channel axis,
# which is what lets RMSNorm broadcast its per-position statistic back over all channels.
torch.randn(23, 3, 23, 23).mean(1, keepdim=True).shape

torch.Size([23, 1, 23, 23])

In [None]:
class SinusoidalPositionEmbeddings(nn.Module):
    """Map scalar positions to ``dim``-dimensional sinusoidal feature vectors.

    For frequencies f_k = theta ** (-k / (half_dim - 1)), k = 0..half_dim-1, the
    output row for position p is [sin(p*f_0), ..., sin(p*f_{K-1}), cos(p*f_0), ...].
    When ``dim`` is odd, a trailing zero column is appended so the output always
    has exactly ``dim`` features.

    Args:
        dim: number of output features.
        theta: base controlling the lowest frequency (default 10000).
    """

    def __init__(self, dim, theta = 10000):
        super().__init__()
        self.dim = dim
        self.theta = theta

    def forward(self, x):
        """Embed ``x``.

        Args:
            x: 1-D tensor of positions (the ``x[:, None]`` indexing assumes shape (batch,)).

        Returns:
            Tensor of shape (batch, dim).
        """
        half_dim = self.dim // 2
        # For dim in {2, 3} half_dim - 1 == 0; clamp the divisor to avoid ZeroDivisionError.
        step = math.log(self.theta) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -step)
        args = x[:, None] * freqs[None, :]
        emb = torch.cat((args.sin(), args.cos()), dim = -1)
        if self.dim % 2 == 1:
            # sin/cos pairs only fill an even width; pad so downstream layers sized by
            # `dim` receive exactly `dim` features.
            emb = torch.cat((emb, torch.zeros_like(emb[:, :1])), dim = -1)
        return emb
        
                

In [None]:
class Block(nn.Module):
    """3x3 conv -> RMSNorm -> optional scale/shift modulation -> SiLU -> Dropout.

    Args:
        dim: input channel count.
        dim_out: output channel count.
        dropout: dropout probability applied after the activation.
    """

    def __init__(self, dim, dim_out, dropout = 0.):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        # The norm runs on the conv output, so it is sized for dim_out channels
        # (RMSNorm requires its channel count; calling it with no argument raised TypeError).
        self.norm = RMSNorm(dim_out)
        self.act = nn.SiLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, scale_shift = None):
        """Apply the block.

        Args:
            x: input feature map with ``dim`` channels.
            scale_shift: optional ``(scale, shift)`` pair, broadcast against the
                normalised activations — presumably (B, dim_out, 1, 1) each; verify against caller.
        """
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            # (scale + 1) so that a zero scale leaves the normalised features unchanged.
            x = x * (scale + 1) + shift

        return self.dropout(self.act(x))

In [None]:
class ResnetBlock(nn.Module):
    """Two ``Block``s with a residual connection and optional time conditioning.

    When ``time_emb_dim`` is given, the time embedding is projected to
    ``2 * dim_out`` values and split into a per-channel (scale, shift) pair
    that modulates the first block.

    Args:
        dim: input channel count.
        dim_out: output channel count.
        time_emb_dim: width of the time embedding, or None to disable conditioning.
        dropout: dropout probability used in the first block.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim = None, dropout = 0.):
        super().__init__()
        # nn.Sequential needs module *instances*; the class object nn.SiLU raised TypeError.
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, dim_out * 2)
        ) if exists(time_emb_dim) else None

        self.block1 = Block(dim, dim_out, dropout=dropout)
        self.block2 = Block(dim_out, dim_out)
        # 1x1 conv matches channel counts for the skip path when dim != dim_out.
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        """Return ``block2(block1(x, scale_shift)) + res_conv(x)``.

        Args:
            x: input feature map with ``dim`` channels.
            time_emb: optional (B, time_emb_dim) embedding; ignored if the block
                was built without ``time_emb_dim``.
        """
        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
            scale_shift = time_emb.chunk(2, dim = 1)

        h = self.block1(x, scale_shift=scale_shift)
        # block2 must consume block1's output (it expects dim_out channels); feeding x
        # discarded block1 and the time conditioning entirely.
        h = self.block2(h)
        return h + self.res_conv(x)

In [None]:
class LinearAttention(nn.Module):
    """Multi-head attention whose cost is linear in the number of pixels.

    Rather than forming an (n x n) attention matrix over the n = H*W positions,
    keys are softmax-normalised over positions and contracted with values into a
    (dim_head x dim_head) context per head, which the queries then read out.
    ``num_mem_kv`` learned memory key/value slots are prepended to every sequence.

    Args:
        dim: input/output channel count.
        heads: number of attention heads.
        dim_head: channels per head.
        num_mem_kv: number of learned memory key/value slots per head.
    """

    def __init__(
        self,
        dim,
        heads = 4,
        dim_head = 32,
        num_mem_kv = 4
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.norm = RMSNorm(dim)

        # mem_kv[0] holds memory keys, mem_kv[1] memory values; each (heads, dim_head, num_mem_kv).
        self.mem_kv = nn.Parameter(torch.randn(2, heads, dim_head, num_mem_kv))
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, 1),
            RMSNorm(dim)
        )

    def forward(self, x):
        """Apply linear attention to ``x`` of shape (B, dim, H, W); returns (B, dim, H, W)."""
        b, c, h, w = x.shape

        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = 1)
        # (B, heads * dim_head, H, W) -> (B, heads, dim_head, H*W) for each of q, k, v
        q, k, v = (rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads) for t in qkv)

        # Broadcast the learned memory slots over the batch and prepend them on the sequence axis.
        mk, mv = (repeat(t, 'h c n -> b h c n', b = b) for t in self.mem_kv)
        k = torch.cat((mk, k), dim = -1)
        v = torch.cat((mv, v), dim = -1)

        # Softmax q over features and k over positions; this allows contracting k with v
        # first and never materialising the (n x n) attention matrix.
        q = q.softmax(dim = -2)
        k = k.softmax(dim = -1)

        q = q * self.scale

        # (B, heads, d, e): sum over positions of k[d] * v[e]
        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)

        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
        return self.to_out(out)
        
        
class Attention(nn.Module):
    """Standard multi-head softmax self-attention over the H*W spatial positions.

    Args:
        dim: input/output channel count.
        heads: number of attention heads.
        dim_head: channels per head; q/k/v projections use ``heads * dim_head`` channels.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        # Each head is dim_head channels wide, so the projection width is heads * dim_head;
        # this keeps it consistent with self.scale (heads * dim ignored dim_head entirely).
        hidden_dim = dim_head * heads

        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        """Apply attention to ``x`` of shape (B, dim, H, W); returns (B, dim, H, W).

        The full (H*W x H*W) attention matrix is materialised per head, so memory is
        quadratic in the number of pixels.
        """
        b, _, height, width = x.shape
        n = height * width

        # (B, heads * dim_head, H, W) -> (B, heads, dim_head, H*W) for each of q, k, v
        q, k, v = (t.reshape(b, self.heads, -1, n) for t in self.to_qkv(x).chunk(3, dim = 1))
        q = q * self.scale

        sim = torch.einsum('b h d i, b h d j -> b h i j', q, k)
        attn = sim.softmax(dim = -1)

        out = torch.einsum('b h i j, b h d j -> b h i d', attn, v)
        # (B, heads, H*W, dim_head) -> (B, heads * dim_head, H, W)
        out = out.transpose(-1, -2).reshape(b, -1, height, width)
        return self.to_out(out)
        
        
        

(tensor([[ 0.4440,  0.3920,  0.6530, -0.9605,  0.2679, -0.5658, -0.9110,  1.4983,
           0.0041,  0.8113,  2.2044,  1.1604, -0.7538, -2.2985, -0.3477,  1.7214,
           0.5118,  0.4109, -0.2641,  1.2216, -1.1585,  0.2219,  1.3619,  0.2927,
          -0.5375,  0.7451, -0.0975,  0.8557, -1.5112,  1.5189,  0.0909, -1.8848,
           0.6316,  0.3313,  0.7537,  0.3601,  0.9153,  0.0710, -0.1831,  0.8386,
          -0.4594,  0.5228, -0.8754,  0.7014,  1.9111,  1.7775, -0.3256, -0.4987,
          -0.1167, -0.7170,  0.4379, -0.2896, -0.5781,  0.5635, -0.0133,  1.1437,
           0.9590,  1.1001,  1.2521, -1.4323,  0.0745, -0.2219,  0.6893,  0.3561,
          -0.3258,  0.5482, -0.3023, -0.0561, -1.3465, -1.7644, -0.0964, -0.3883,
           0.4494, -0.2762, -0.4611, -0.9451,  1.1704, -0.5131,  0.5018, -0.3176,
          -1.2073,  0.6973,  0.1539, -0.0280, -0.2244, -0.3203,  0.0479, -1.2624,
          -0.2860,  0.2154, -1.1691, -2.3402, -0.1102,  0.1076, -0.1814,  1.1576,
          -1.026