In [142]:
import torch
import torch.nn as nn
from einops_exts.torch import EinopsToAndFrom

In [211]:
class AttentionTypes(nn.Module):
    def __init__(self, dim, dim_head = 64, heads = 8, dropout = 0.05, context_dim = None, norm_context = False, att_type='normal'):
        super().__init__()
        
        self.scale         = dim_head ** -0.5
        self.heads         = heads
        self.att_type      = att_type
        self.norm_context  = norm_context
        # Normal Attention Initialization #############################################################################################################
        
        if att_type == 'normal':     
            self.norm = nn.LayerNorm(dim)

            self.null_kv = nn.Parameter(torch.randn(2, dim_head))
            self.Q       = nn.Linear(dim, dim_head * heads, bias = False)
            self.KV      = nn.Linear(dim, dim_head * 2, bias = False)

            self.context      = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, dim_head * 2)) if context_dim is not None else None
            self.output_layer = nn.Sequential(nn.Linear(dim_head * heads, dim, bias = False), nn.LayerNorm(dim))
            
        # Linear Attention Initialization #############################################################################################################
            
        elif att_type == 'linear':
            
            self.g1 = nn.Parameter(torch.ones(1, dim, 1, 1))
            
            self.Q = nn.Sequential(nn.Dropout(dropout),
                                   nn.Conv2d(dim, dim_head*heads, 1, bias=False),
                                   nn.Conv2d(dim_head*heads, dim_head*heads, 3, bias=False, padding=1, groups=dim_head*heads))

            self.K = nn.Sequential(nn.Dropout(dropout),
                                   nn.Conv2d(dim, dim_head*heads, 1, bias=False),
                                   nn.Conv2d(dim_head*heads, dim_head*heads, 3, bias=False, padding=1, groups=dim_head*heads))

            self.V = nn.Sequential(nn.Dropout(dropout),
                                   nn.Conv2d(dim, dim_head*heads, 1, bias=False),
                                   nn.Conv2d(dim_head*heads, dim_head*heads, 3, bias=False, padding=1, groups=dim_head*heads))
            
            self.context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, dim_head*heads*2, bias=False)) if context_dim is not None else None

            self.activation   = Swish()
            self.output_layer = nn.Conv2d(dim_head*heads, dim, 1, bias=False)            
            self.g2           = nn.Parameter(torch.ones(1, dim, 1, 1))
            
        # Cross and Linear Cross Attention Initialization #############################################################################################

        elif att_type == 'cross' or att_type == 'linear-cross':

            context_dim = context_dim if context_dim is not None else dim

            self.norm         = nn.LayerNorm(dim)
            if norm_context:
                self.norm_context = nn.LayerNorm(context_dim)

            self.null_kv = nn.Parameter(torch.randn(2, dim_head))
            self.Q       = nn.Linear(dim, dim_head*heads, bias=False)
            self.KV      = nn.Linear(context_dim, dim_head*heads*2, bias=False)

            self.output_layer = nn.Sequential(nn.Linear(dim_head*heads, dim, bias=False), nn.LayerNorm(dim))
    
    #===================================================================================#
    #                               ATTENTION FUNCTION                                  #
    #===================================================================================#    
        
    def attention(self, q, k, v, mask = None, att_bias = None):
        
        # Normal Attention ######################################################
        
        if self.att_type == 'normal':
            score = einsum('b h i d, b j d -> b h i j', q, k)

            if att_bias is not None: score = score + att_bias

            if mask is not None:            
                max_neg = -torch.finfo(score.dtype).max            
                mask    = mask[:, None, None, :]
                score   = score.masked_fill(~mask, max_neg)

            probs = score.softmax(dim = -1, dtype = torch.float32)
            
            print('probs', probs.shape)
            print('v', v.shape)
            
            att = einsum('b h i j, b j d -> b h i d', probs, v)
            att = rearrange(att, 'b h n d -> b n (h d)')
            
        # Linear Attention ######################################################
            
        elif self.att_type == 'linear':
            
            att = torch.matmul(k.transpose(-1, -2), v)
            att = torch.matmul(q, att)
            
        # Cross Attention #######################################################
            
        elif self.att_type == 'cross':
            score = torch.matmul(q, k.transpose(-1, -2))

            if mask is not None:            
                max_neg = -torch.finfo(score.dtype).max            
                mask    = mask[:, None, None, :]
                score   = score.masked_fill(~mask, max_neg)

            probs = score.softmax(dim = -1, dtype = torch.float32)

            att = torch.matmul(probs, v)
            att = rearrange(att, 'b h n d -> b n (h d)')
            
        # Linear Cross Attention ################################################
            
        elif self.att_type == 'linear-cross':           

            if mask is not None: 
                max_neg = -torch.finfo(score.dtype).max
                mask = mask[:, :, None]
                k, v = k.masked_fill(~mask, max_neg), v.masked_fill(~mask, 0.)

            q, k = q.softmax(dim = -1)*self.scale, k.softmax(dim = -2)

            att = torch.matmul(k.transpose(-1, -2), v)
            att = torch.matmul(q, att)
            att = rearrange(att, '(b h) n d -> b n (h d)', h = self.heads)
        
        return att
    
    #===================================================================================#
    #                                      FORWARD                                      #
    #===================================================================================#      

    def forward(self, x, context = None, mask = None, att_bias = None):
        
        # Normal Attention #####################################################################################
        
        if self.att_type == 'normal': 
        
            b, n, device = x.shape[0], x.shape[1], x.device

            x = self.norm(x)

            q = self.Q(x)
            print(q.shape)
            q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads) * self.scale
            
            k, v = self.KV(x).chunk(2, dim = -1)

            nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b 1 d', b = b)
            k, v = torch.cat((nk, k), dim = -2), torch.cat((nv, v), dim = -2)

            if context is not None:
                assert self.context is not None
                ck, cv = self.context(context).chunk(2, dim = -1)
                k, v = torch.cat((ck, k), dim = -2), torch.cat((cv, v), dim = -2)

            # ATTENTION

            if mask is not None:
                mask = F.pad(mask, (1, 0), value = True)
                att = self.attention(q=q, k=k, v=v, mask=mask, att_bias=att_bias)
            else:
                att = self.attention(q=q, k=k, v=v, att_bias=att_bias)
                
            out = self.output_layer(att) 
            
        # Linear Attention #####################################################################################
        
        elif self.att_type == 'linear': 

            x1, y1, h = x.shape[-2], x.shape[-1], self.heads
            
            var  = torch.var(x, dim = 1, unbiased = False, keepdim = True)
            mean = torch.mean(x, dim = 1, keepdim = True)
            x    = (x - mean) / (var + 1e-15).sqrt() * self.g1   
            
            q, k, v = self.Q(x), self.K(x), self.V(x)
            q, k, v = rearrange_many((q, k, v), 'b (h c) x y -> (b h) (x y) c', h = h)

            if context is not None:
                assert self.context is not None
                ck, cv = self.context(context).chunk(2, dim = -1)
                ck, cv = rearrange_many((ck, cv), 'b n (h d) -> (b h) n d', h = h)
                k, v = torch.cat((k, ck), dim = -2), torch.cat((v, cv), dim = -2)

            q, k = q.softmax(dim = -1)*self.scale, k.softmax(dim = -2)
            
            att = self.attention(q, k, v)

            out = rearrange(att, '(b h) (x y) d -> b (h d) x y', h = h, x = x1, y = y1)
            out = self.activation(out)
            out = self.output_layer(out)
            
            var_out  = torch.var(out, dim = 1, unbiased = False, keepdim = True)
            mean_out = torch.mean(out, dim = 1, keepdim = True)
            out      = (x - mean_out) / (var_out + 1e-15).sqrt() * self.g2  
            
        # Cross and Linear Cross Attention #####################################################################
            
        elif self.att_type == 'cross' or self.att_type == 'linear-cross':
        # elif self.att_type == 'cross':
            
            b, n, device = x.shape[0], x.shape[1], x.device

            x       = self.norm(x)
            context = self.norm_context(context) if self.norm_context else context

            q    = self.Q(x)
            k, v = self.KV(context).chunk(2, dim = -1)

            if self.att_type == 'cross':    
                
                q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)
                nk, nv  = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads,  b = b)
                k, v, q = torch.cat((nk, k), dim = -2), torch.cat((nv, v), dim = -2), q*self.scale
                
            else:
                q, k, v = rearrange_many((q, k, v), 'b n (h d) -> (b h) n d', h = self.heads)
                nk, nv  = repeat_many(self.null_kv.unbind(dim = -2), 'd -> (b h) 1 d', h = self.heads,  b = b)
                k, v    = torch.cat((nk, k), dim = -2), torch.cat((nv, v), dim = -2)
            
            if mask is not None:
                mask = F.pad(mask, (1, 0), value = True)
                att = self.attention(q=q, k=k, v=v, mask=mask)
            else:
                att = self.attention(q=q, k=k, v=v)

            out = self.output_layer(att)
            
        return out

In [None]:
class CrossEmbedLayer(nn.Module):
    def __init__(self, dim_in, kernel_sizes, dim_out = None, stride = 2):
        super().__init__()
        assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
        dim_out = default(dim_out, dim_in)

        kernel_sizes = sorted(kernel_sizes)
        num_scales = len(kernel_sizes)

        # calculate the dimension at each scale
        dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
        dim_scales = [*dim_scales, dim_out - sum(dim_scales)]

        self.convs = nn.ModuleList([])
        for kernel, dim_scale in zip(kernel_sizes, dim_scales):
            self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))

    def forward(self, x):
        fmaps = tuple(map(lambda conv: conv(x), self.convs))
        return torch.cat(fmaps, dim = 1)

In [2]:
kernel_sizes = [2, 4]
stride = 2

all([ks % 2 == stride % 2 for ks in kernel_sizes])

True

In [23]:
class CrossEmbedding(nn.Module):
    
    def __init__(self, in_channels, out_channels, stride, kernel_sizes):
        
        super(CrossEmbedding, self).__init__()        
        assert all([ks % 2 == stride % 2 for ks in kernel_sizes]), 'Kernel and Stride must be odd or even, both'
        assert all([ks >= stride for ks in kernel_sizes]), 'All Kernels must be larger than Stride'
        
        kernel_sizes, n_kernels = sorted(kernel_sizes), len(kernel_sizes)
        
        dim_scales = [int(out_channels / (2 ** i)) for i in range(1, n_kernels)]
        dim_scales.append(out_channels - sum(dim_scales))
        
        self.conv_layer = nn.ModuleList([nn.Conv2d(in_channels, dim_scale, kernel,
                                                   stride = stride, padding = (kernel - stride) // 2)
                                        for kernel, dim_scale in zip(kernel_sizes, dim_scales)])
        
    def forward(self, x):
        out = [conv(x) for conv in self.conv_layer]
        return torch.cat(out, dim = 1)

In [28]:
ce = CrossEmbedding(in_channels=3, out_channels=3, stride=4, kernel_sizes=[2, 4])
ce

AssertionError: All Kernels must be larger than Stride

In [25]:
x = torch.rand((1, 3, 10, 10))
x1, x2 = ce(x)

In [26]:
x1

tensor([[[[ 0.1607,  0.0354,  0.0207,  0.3303, -0.0195],
          [ 0.3290, -0.0417,  0.4365, -0.1227, -0.0551],
          [-0.1933,  0.1280, -0.1405, -0.1524, -0.1422],
          [ 0.1780, -0.0020, -0.2729,  0.0030, -0.0936],
          [ 0.2232,  0.0090, -0.0836,  0.2321,  0.0535]],

         [[-0.0992, -0.0362,  0.0603, -0.1646, -0.2857],
          [-0.2236,  0.1094,  0.0615, -0.0907, -0.0036],
          [ 0.0175, -0.0200, -0.1888,  0.0071,  0.0780],
          [ 0.2778,  0.1466,  0.0529,  0.0920,  0.1258],
          [ 0.3529,  0.2027,  0.0930,  0.1575,  0.2072]],

         [[-0.3897, -0.7092, -0.3089, -0.3905, -0.5074],
          [-0.4342,  0.0895,  0.0833, -0.0590,  0.0147],
          [-0.5677, -0.3688, -0.0239, -0.2113, -0.1807],
          [-0.2386, -0.0758, -0.1090, -0.1430, -0.0996],
          [ 0.2304,  0.3404,  0.1393, -0.0335,  0.0495]]]],
       grad_fn=<CatBackward0>)

In [27]:
x2

tensor([[[[ 0.1607,  0.0354,  0.0207,  0.3303, -0.0195],
          [ 0.3290, -0.0417,  0.4365, -0.1227, -0.0551],
          [-0.1933,  0.1280, -0.1405, -0.1524, -0.1422],
          [ 0.1780, -0.0020, -0.2729,  0.0030, -0.0936],
          [ 0.2232,  0.0090, -0.0836,  0.2321,  0.0535]],

         [[-0.0992, -0.0362,  0.0603, -0.1646, -0.2857],
          [-0.2236,  0.1094,  0.0615, -0.0907, -0.0036],
          [ 0.0175, -0.0200, -0.1888,  0.0071,  0.0780],
          [ 0.2778,  0.1466,  0.0529,  0.0920,  0.1258],
          [ 0.3529,  0.2027,  0.0930,  0.1575,  0.2072]],

         [[-0.3897, -0.7092, -0.3089, -0.3905, -0.5074],
          [-0.4342,  0.0895,  0.0833, -0.0590,  0.0147],
          [-0.5677, -0.3688, -0.0239, -0.2113, -0.1807],
          [-0.2386, -0.0758, -0.1090, -0.1430, -0.0996],
          [ 0.2304,  0.3404,  0.1393, -0.0335,  0.0495]]]],
       grad_fn=<CatBackward0>)

In [4]:
dim_scales = [i for i in range(1,3)]
dim_scales = [*dim_scales, 4]
dim_scales

[1, 2, 4]

In [5]:
dim_scales = [i for i in range(1,3)]
dim_scales.append(4)
dim_scales

[1, 2, 4]

In [143]:
class Always():
    def __init__(self, val):
        self.val = val

    def __call__(self, *args, **kwargs):
        return self.val

In [144]:
class Swish(nn.Module):

    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)

In [145]:
class Identity(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

In [5]:
class CrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        context_dim = None,
        dim_head = 64,
        heads = 8,
        norm_context = False
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        context_dim = context_dim if context_dim is not None else dim

        self.norm = nn.LayerNorm(dim)
        self.norm_context = nn.LayerNorm(context_dim) if norm_context else Identity()

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.LayerNorm(dim)
        )

    def forward(self, x, context, mask = None):
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)
        context = self.norm_context(context)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)

        # add null key / value for classifier free guidance in prior net

        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads,  b = b)

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        q = q * self.scale

        sim = einsum('b h i d, b h j d -> b h i j', q, k)
        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        attn = sim.softmax(dim = -1, dtype = torch.float32)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

In [146]:
def exists(val):
    return val is not None

In [13]:
class Block(nn.Module):
    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.norm = nn.GroupNorm(groups, dim)
        self.acti = Swish()
        self.conv = nn.Conv2d(dim, dim_out, 3, padding=1)

    def forward(self, x, scale_shift = None):
        
        x = self.norm(x)
        
        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.activation(x)
        return self.project(x)

In [123]:
class GlobalContext(nn.Module):
    """ basically a superior form of squeeze-excitation that is attention-esque """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        
        self.to_k = nn.Conv2d(dim_in, 1, 1)
        hidden_dim = max(3, dim_out // 2)

        self.net = nn.Sequential(nn.Conv2d(dim_in, hidden_dim, 1),
                                 nn.SiLU(),
                                 nn.Conv2d(hidden_dim, dim_out, 1),
                                 nn.Sigmoid())

    def forward(self, x):
        context = self.to_k(x)
        x, context = rearrange_many((x, context), 'b n ... -> b n (...)')
        print(context.shape)
        out = einsum('b i n, b c n -> b c i', context.softmax(dim = -1), x)
        out = rearrange(out, '... -> ... 1')
        print(out.shape)
        return self.net(out)

In [137]:
class ResnetBlock(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        *,
        cond_dim = None,
        time_cond_dim = None,
        groups = 8,
        linear_attn = False,
        use_gca = False
    ):
        super().__init__()

        self.time_mlp = None

        if exists(time_cond_dim):
            self.time_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_cond_dim, dim_out * 2)
            )

        self.cross_attn = None

        if exists(cond_dim):
            attn_klass = CrossAttention if not linear_attn else LinearCrossAttention

            self.cross_attn = EinopsToAndFrom(
                'b c h w',
                'b (h w) c',
                attn_klass(
                    dim = dim_out,
                    context_dim = cond_dim
                )
            )

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)

        self.gca = GlobalContext(dim_in = dim_out, dim_out = dim_out) if use_gca else Always(1)

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else Identity()


    def forward(self, x, time_emb = None, cond = None):

        scale_shift = None
        if exists(self.time_mlp) and exists(time_emb):
            time_emb = self.time_mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
            scale_shift = time_emb.chunk(2, dim = 1)

        h = self.block1(x)

        if exists(self.cross_attn):
            assert exists(cond)
            h = self.cross_attn(h, context = cond) + h

        print(h.shape)
            
        h = self.block2(h, scale_shift = scale_shift)

        h = h * self.gca(h)

        return h + self.res_conv(x)

In [132]:
class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, cond_dim = None, time_cond_dim = None,
                 groups = 8, linear_att = False, use_gca = False):
        
        super().__init__()
        
        self.use_gca = use_gca

        self.time_NN = nn.Sequential(Swish(), nn.Linear(time_cond_dim, dim_out*2))

        self.att_layer = None
        if cond_dim is not None:
            att_type = 'cross' if not linear_att else 'linear-cross'    
            self.att_layer = AttentionTypes(dim = dim_out, context_dim = cond_dim, att_type=att_type)

        self.seq1 = nn.Sequential(nn.GroupNorm(groups, dim),
                                  Swish(),
                                  nn.Conv2d(dim, dim_out, kernel_size=3, padding=1))
        
        self.norm = nn.GroupNorm(groups, dim_out)
        self.seq2 = nn.Sequential(Swish(),
                                  nn.Conv2d(dim_out, dim_out, kernel_size=3, padding=1))
        
        if use_gca:
            self.conv = nn.Conv2d(dim_out, 1, 1)
            self.seq3 = nn.Sequential(nn.Conv2d(dim_out, max(3, dim_out//2), 1),
                                      nn.SiLU(),
                                      nn.Conv2d(max(3, dim_out//2), dim_out, 1),
                                      nn.Sigmoid())

        self.output_layer = None
        if dim != dim_out:
            self.output_layer = nn.Conv2d(dim, dim_out, 1)


    def forward(self, x, time_emb = None, cond = None):

        scale, shift = None, None
        if self.time_NN is not None and time_emb is not None:
            
            time_emb = self.time_NN(time_emb)
            scale, shift = time_emb[:, :, None, None].chunk(2, dim = 1)

        h = self.seq1(x)

        if self.att_layer is not None:
            assert cond is not None
            
            w = h.shape[-1]            
            m = rearrange(h, 'b c h w -> b (h w) c')
            m = self.att_layer(m, context = cond) 
            h = rearrange(m, 'b (h w) c -> b c h w',w=w) + h                    
                
        print(h.shape)
                
        h = self.norm(h)*(scale + 1) + shift if scale is not None else self.norm(h)        
        h = self.seq2(h)
        
        if self.use_gca:
            c    = self.conv(h)
            y, c = rearrange_many((h, c), 'b n ... -> b n (...)')

            c    = c.softmax(dim = -1) 
            print(c.shape)
            o = torch.matmul(y, c.transpose(1, 2))[:, :, :, None]
            
            print(o.shape)
        
            h = h*self.seq3(o)
        
        return h + self.output_layer(x) if self.output_layer is not None else h + x

In [135]:
rb = ResnetBlock(dim = 64, dim_out=16, cond_dim = 4, time_cond_dim = 4,
                 groups = 8, linear_att = True, use_gca = True)

In [136]:
x = torch.rand((1, 64, 20, 20))
t = torch.rand((1, 4)) 
c = torch.rand((1, 16, 4)) 

rb(x=x, time_emb = t, cond = c).shape

torch.Size([1, 16, 20, 20])
torch.Size([1, 1, 400])
torch.Size([1, 16, 1, 1])


torch.Size([1, 16, 20, 20])

In [138]:
rb = ResnetBlock(dim = 64, dim_out=16, cond_dim = 4, time_cond_dim = 4,
                 groups = 8, linear_attn = True, use_gca = True)

NameError: name 'LinearCrossAttention' is not defined

In [126]:
x = torch.rand((1, 64, 20, 20))
t = torch.rand((1, 4)) 
c = torch.rand((1, 16, 4)) 

rb(x=x, time_emb = t, cond = c).shape

torch.Size([1, 16, 20, 20])
torch.Size([1, 1, 400])
torch.Size([1, 16, 1, 1])


torch.Size([1, 16, 20, 20])

In [86]:
v = torch.rand((1, 2, 3))
r = torch.rand((1, 4, 3))

einsum('b i n, b c n -> b c i', v, r).shape

torch.Size([1, 4, 2])

In [74]:
rearrange(v, '... -> ... 1').shape

torch.Size([1, 2, 3, 1])

In [75]:
v[:, :, :, None].shape

torch.Size([1, 2, 3, 1])

In [73]:
torch.matmul(v, r.transpose(1, 2)).shape

torch.Size([1, 2, 4])

In [83]:
rb = ResnetBlock(dim = 64, dim_out=16, cond_dim = 4, time_cond_dim = 4,
                 groups = 8, linear_att = False, use_gca = False, squeeze_excite = False)

In [70]:
x = torch.rand((1, 64, 20, 20))
t = torch.rand((1, 4)) 
c = torch.rand((1, 16, 4)) 

rb(x=x, time_emb = t, cond = c).shape

torch.Size([1, 400, 16])
torch.Size([1, 400, 16])
torch.Size([1, 16, 20, 20])


torch.Size([1, 16, 20, 20])

In [84]:
x = torch.rand((1, 64, 20, 20))
t = torch.rand((1, 4)) 
c = torch.rand((1, 16, 4)) 

rb(x=x, time_emb = t, cond = c).shape

RuntimeError: Expected weight to be a vector of size equal to the number of channels in input, but got weight of shape [64] and input of shape [1, 16, 20, 20]

In [35]:
from einops import rearrange
from einops_exts import rearrange_many, repeat_many
from torch import nn, einsum

In [57]:
rb = ResnetBlock(dim = 64, dim_out=16, cond_dim = 4, time_cond_dim = 4,
            groups = 8, linear_attn = False, use_gca = False, squeeze_excite = False)

rb

ResnetBlock(
  (time_NN): Sequential(
    (0): Swish()
    (1): Linear(in_features=4, out_features=32, bias=True)
  )
  (cross_attn): CrossAttention(
    (norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
    (norm_context): Identity()
    (to_q): Linear(in_features=16, out_features=512, bias=False)
    (to_kv): Linear(in_features=4, out_features=1024, bias=False)
    (to_out): Sequential(
      (0): Linear(in_features=512, out_features=16, bias=False)
      (1): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
    )
  )
  (block1): Block(
    (groupnorm): GroupNorm(8, 64, eps=1e-05, affine=True)
    (activation): SiLU()
    (project): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (block2): Block(
    (groupnorm): GroupNorm(8, 16, eps=1e-05, affine=True)
    (activation): SiLU()
    (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (res_conv): Conv2d(64, 16, kernel_size=(1, 1), stride=(1, 1))
)

In [215]:
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
        context_dim = None
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = LayerNorm(dim)

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)

        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, dim_head * 2)) if exists(context_dim) else None

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(self, x, context = None, mask = None, attn_bias = None):
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
        print('v', v.shape)
        
        print('q', q.shape)

        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
        q = q * self.scale

        # add null key / value for classifier free guidance in prior net

        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b 1 d', b = b)
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)
        print('v', v.shape)

        # add text conditioning, if present

        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            k = torch.cat((ck, k), dim = -2)
            v = torch.cat((cv, v), dim = -2)
            print('v', v.shape)

        # calculate query / key similarities

        sim = einsum('b h i d, b j d -> b h i j', q, k)

        # relative positional encoding (T5 style)

        if exists(attn_bias):
            sim = sim + attn_bias

        # masking

        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # attention

        attn = sim.softmax(dim = -1, dtype = torch.float32)

        # aggregate values

        print('attn', attn.shape)
        print('v', v.shape)
        
        out = einsum('b h i j, b j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        print('out', out.shape)
        return out

In [235]:
class LinearAttention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 32,
        heads = 8,
        dropout = 0.05,
        context_dim = None
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads
        self.norm = ChanLayerNorm(dim)

        self.nonlin = nn.SiLU()

        self.to_q = nn.Sequential(
            nn.Dropout(dropout),
            nn.Conv2d(dim, inner_dim, 1, bias = False),
            nn.Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_k = nn.Sequential(
            nn.Dropout(dropout),
            nn.Conv2d(dim, inner_dim, 1, bias = False),
            nn.Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_v = nn.Sequential(
            nn.Dropout(dropout),
            nn.Conv2d(dim, inner_dim, 1, bias = False),
            nn.Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, inner_dim * 2, bias = False)) if exists(context_dim) else None

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1, bias = False),
            ChanLayerNorm(dim)
        )

    def forward(self, fmap, context = None):
        h, x, y = self.heads, *fmap.shape[-2:]

        fmap = self.norm(fmap)
        q, k, v = map(lambda fn: fn(fmap), (self.to_q, self.to_k, self.to_v))
        q, k, v = rearrange_many((q, k, v), 'b (h c) x y -> (b h) (x y) c', h = h)

        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            ck, cv = rearrange_many((ck, cv), 'b n (h d) -> (b h) n d', h = h)
            k = torch.cat((k, ck), dim = -2)
            v = torch.cat((v, cv), dim = -2)

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)

        out = self.nonlin(out)
        return self.to_out(out)

In [None]:
m = rearrange(h, 'b c h w -> b (h w) c')
m = self.att_layer(m, context = cond) 
h = rearrange(m, 'b (h w) c -> b c h w',w=w) + h

self.cross_attn = EinopsToAndFrom(
                'b c h w',
                'b (h w) c',
                attn_klass(
                    dim = dim_out,
                    context_dim = cond_dim
                )
            )

In [162]:
class ChanLayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.g

In [216]:
class TransformerBlock(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 32, ff_mult = 2, context_dim = None):
        super().__init__()
        self.attn = EinopsToAndFrom('b c h w', 'b (h w) c',
                                    Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim))
        self.ff = ChanFeedForward(dim = dim, mult = ff_mult)

    def forward(self, x, context = None):
        x = self.attn(x, context = context) + x
        x = self.ff(x) + x
        return x

In [236]:
class LinearAttentionTransformerBlock(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 32, ff_mult = 2, context_dim = None):
        super().__init__()
        self.attn = LinearAttention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim)
        self.ff = ChanFeedForward(dim = dim, mult = ff_mult)

    def forward(self, x, context = None):
        x = self.attn(x, context = context) + x
        x = self.ff(x) + x
        return x

In [240]:
class TransformerBlock(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 32, ff_mult = 2, att_type='normal', context_dim = None):
        super().__init__()
        
        self.att_type = att_type
        
        self.att = AttentionTypes(dim=dim, heads=heads, dim_head=dim_head, context_dim=context_dim, att_type=att_type)
        self.g1    = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.conv1 = nn.Conv2d(dim, dim*ff_mult, 1, bias = False)
        self.activ = Swish()
        self.g2    = nn.Parameter(torch.ones(1, dim*ff_mult, 1, 1))
        self.conv2 = nn.Conv2d(dim*ff_mult, dim, 1, bias = False)

    def forward(self, x, context = None):
        
        if self.att_type == 'normal':
            w = x.shape[-1]    
            x = rearrange(x, 'b c h w -> b (h w) c')
            x = self.att(x, context = context) + x
            x = rearrange(x, 'b (h w) c -> b c h w',w=w)
            
        elif self.att_type == 'linear':
            x = self.att(x, context = context) + x
        
        num = x - torch.mean(x, dim=1, keepdim=True)
        den = (torch.var(x, dim=1, unbiased=False, keepdim=True) + 1e-5).sqrt()
        x1  = (num/den)*self.g1
        
        x1 = self.activ(self.conv1(x1))
        
        num = x1 - torch.mean(x1, dim=1, keepdim=True)
        den = (torch.var(x1, dim=1, unbiased=False, keepdim=True) + 1e-5).sqrt()
        x1  = (num/den)*self.g2
        
        out = self.conv2(x1) + x
        return out

In [237]:
li = LinearAttentionTransformerBlock(dim=4, heads = 8, dim_head = 32, ff_mult = 2, context_dim = None)

In [248]:
x = torch.rand((1, 8, 15))
li(x).shape

torch.Size([1, 4, 8, 15])

In [241]:
tb = TransformerBlock(dim=4, heads = 8, dim_head = 32, ff_mult = 2, context_dim = None, att_type='linear')

In [247]:
x = torch.rand((1, 8, 15))
tb(x).shape

torch.Size([1, 4, 8, 15])

In [217]:
tb = TransformerBlock(dim=4, heads = 8, dim_head = 32, ff_mult = 2, context_dim = None)

In [218]:
x = torch.rand((1, 4, 4, 4))
tb(x).shape

v torch.Size([1, 16, 32])
q torch.Size([1, 16, 256])
v torch.Size([1, 17, 32])
attn torch.Size([1, 8, 16, 17])
v torch.Size([1, 17, 32])
out torch.Size([1, 16, 4])


torch.Size([1, 4, 4, 4])

In [233]:
tb = TransformerBlock(dim=4, heads = 8, dim_head = 32, ff_mult = 2, context_dim = None)

In [234]:
x = torch.rand((1, 4, 4, 4))
tb(x).shape

torch.Size([1, 16, 256])
probs torch.Size([1, 8, 16, 17])
v torch.Size([1, 17, 32])


torch.Size([1, 4, 4, 4])

In [170]:
import torch.nn.functional as F

In [149]:
x    = torch.tensor(4)
mean = torch.tensor(2)
var  = torch.tensor(9)
g1   = torch.tensor(15)

(x - mean)/(var + 1e-5).sqrt()*g1

tensor(10.0000)

In [153]:
num = x - mean
den = (var + 1e-5).sqrt()
(num/den)*g1

tensor(10.0000)

In [151]:
a = (var + 1e-5).sqrt()
b = (x - mean)
c = b/a
c*g1

tensor(10.0000)

In [141]:
tb = TransformerBlock(dim=5)

NameError: name 'ChanFeedForward' is not defined

In [None]:
def ChanFeedForward(dim, mult = 2):  # in paper, it seems for self attention layers they did feedforwards with twice channel width
    hidden_dim = int(dim * mult)
    return nn.Sequential(
        ChanLayerNorm(dim),
        nn.Conv2d(dim, hidden_dim, 1, bias = False),
        nn.GELU(),
        ChanLayerNorm(hidden_dim),
        nn.Conv2d(hidden_dim, dim, 1, bias = False)
    )

In [164]:
class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer('beta', torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

In [166]:
def ChanFeedForward(dim, mult = 2):  # in paper, it seems for self attention layers they did feedforwards with twice channel width
    hidden_dim = int(dim * mult)
    return nn.Sequential(
        ChanLayerNorm(dim),
        nn.Conv2d(dim, hidden_dim, 1, bias = False),
        nn.GELU(),
        ChanLayerNorm(hidden_dim),
        nn.Conv2d(hidden_dim, dim, 1, bias = False)
    )

In [None]:
def Downsample(dim, dim_out = None):
    dim_out = default(dim_out, dim)
    return nn.Conv2d(dim, dim_out, 4, 2, 1)