In [1]:
import torch
from ae_ddpm import *

In [8]:
class EncoderM(BaseEncoder):

    def __init__(
            self,
            in_channel: StrictInt,
            channel: StrictInt,
            channel_multiplier: List[StrictInt],
            n_res_blocks: StrictInt,
            attn_strides: List[StrictInt],
            attn_heads: StrictInt = 1,
            dropout: StrictFloat = 0,
            fold: StrictInt = 1,
    ):
        super().__init__(
            in_channel,
            channel,
            channel_multiplier,
            n_res_blocks,
            attn_strides,
            attn_heads,
            dropout,
            fold)
        group_norm = channel // 4
        in_channel = channel * 4
        self.mid = nn.ModuleList(
            [
                EncResBlockWithAttention(
                    in_channel,
                    in_channel,
                    dropout=dropout,
                    use_attention=True,
                    attention_head=attn_heads,
                    group_norm=group_norm
                ),
                EncResBlockWithAttention(
                    in_channel,
                    in_channel,
                    dropout=dropout,
                    group_norm=group_norm
                ),
            ]
        )

        self.out = nn.Linear(channel * 4 * 8 * 8, 512)



    def forward(self, input):

        x = super().forward(input)
        print(x.shape)
        for layer in self.mid:
            x = layer(x)
        x = self.out(x.flatten(start_dim=1))
        return x

In [9]:
enc = EncoderM(in_channel=3,
                      channel=128,
                      channel_multiplier=[1, 2, 2, 4, 4],
                      n_res_blocks=2,
                      attn_strides=[8, 16],
                      attn_heads=4,
                      dropout=0,
                      fold=1)

In [10]:
enc(torch.randn(1, 3, 128, 128))

torch.Size([1, 512, 8, 8])


tensor([[ 1.7648e-01,  8.5962e-02,  3.8631e-02, -1.2128e-01,  3.5422e-02,
         -1.1440e-01,  1.4193e-01, -6.0567e-02, -1.1883e-01,  8.9824e-02,
         -6.4125e-02, -1.2256e-01, -7.8203e-02, -9.7997e-02, -1.0199e-01,
         -4.2376e-03,  2.7806e-02,  2.3010e-02, -1.4266e-02, -3.1796e-02,
         -2.1724e-02, -6.4066e-02,  1.2855e-02, -5.6851e-02, -7.0995e-02,
         -8.8030e-03, -1.7191e-01,  1.3001e-01,  4.5608e-02, -8.1250e-02,
          2.4362e-03, -1.3606e-01, -3.4377e-02,  1.2428e-01,  3.4799e-02,
          5.4412e-02,  1.1654e-01,  1.0832e-01,  4.9515e-02, -2.9227e-02,
         -1.4401e-02, -6.5456e-02,  9.4891e-02,  6.0969e-02,  5.7815e-02,
          3.5636e-02, -1.1265e-02,  5.3637e-02, -1.2299e-01, -7.7614e-02,
          2.8612e-02,  4.0059e-03,  2.0339e-01, -1.4915e-02, -1.0014e-01,
         -3.4283e-02, -2.1381e-01, -1.2323e-01,  3.2159e-02, -1.2821e-02,
          8.5107e-02, -4.6876e-02, -2.7548e-02,  3.0883e-02,  5.4227e-02,
         -1.2840e-02, -1.6232e-01,  1.

In [2]:
from vit_pytorch.twins_svt import TwinsSVT

encoder = TwinsSVT(
    num_classes=512,  # number of output classes
    s1_emb_dim=64,  # stage 1 - patch embedding projected dimension
    s1_patch_size=4,  # stage 1 - patch size for patch embedding
    s1_local_patch_size=7,  # stage 1 - patch size for local attention
    s1_global_k=7,  # stage 1 - global attention key / value reduction factor, defaults to 7 as specified in paper
    s1_depth=1,  # stage 1 - number of transformer blocks (local attn -> ff -> global attn -> ff)
    s2_emb_dim=128,  # stage 2 (same as above)
    s2_patch_size=2,
    s2_local_patch_size=7,
    s2_global_k=7,
    s2_depth=1,
    s3_emb_dim=256,  # stage 3 (same as above)
    s3_patch_size=2,
    s3_local_patch_size=7,
    s3_global_k=7,
    s3_depth=5,
    s4_emb_dim=512,  # stage 4 (same as above)
    s4_patch_size=2,
    s4_local_patch_size=7,
    s4_global_k=7,
    s4_depth=4,
    peg_kernel_size=3,  # positional encoding generator kernel size
    dropout=0.  # dropout
)

In [12]:
encoder(torch.randn(1, 3, 128, 128))

EinopsError:  Error while processing rearrange-reduction pattern "b c (x p1) (y p2) -> (b x y) c p1 p2".
 Input tensor shape: torch.Size([1, 64, 32, 32]). Additional info: {'p1': 7, 'p2': 7}.
 Shape mismatch, can't divide axis of length 32 in chunks of 7

In [9]:
128%28

16

In [10]:
128+16

144

In [30]:
from nystrom_attention import Nystromformer
from einops import repeat
from einops.layers.torch import Rearrange

class EncoderV3(BaseEncoder):

    def __init__(
            self,
            in_channel: StrictInt,
            channel: StrictInt,
            channel_multiplier: List[StrictInt],
            n_res_blocks: StrictInt,
            attn_strides: List[StrictInt],
            attn_heads: StrictInt = 1,
            dropout: StrictFloat = 0,
            fold: StrictInt = 1,
    ):
        super().__init__(
            in_channel,
            channel,
            channel_multiplier,
            n_res_blocks,
            attn_strides,
            attn_heads,
            dropout,
            fold)
        group_norm = channel // 4
        self.mid = nn.ModuleList(
            [
                EncResBlockWithAttention(
                    in_channel,
                    in_channel,
                    dropout=dropout,
                    use_attention=True,
                    attention_head=attn_heads,
                    group_norm=group_norm
                ),
                EncResBlockWithAttention(
                    in_channel,
                    in_channel,
                    dropout=dropout,
                    group_norm=group_norm
                ),
            ]
        )

        #self.out = nn.Linear(channel * 4 * 8 * 8, 512)
        self.skip = nn.Linear(channel * 4 * 8 * 8, 512)

        t_dim = channel * 4
        t_patches = 8 * 8
        self.pos_embedding = nn.Parameter(torch.randn(1, t_patches + 1, t_dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, t_dim))

        self.transformer = Nystromformer(
            dim = t_dim,
            depth = 8,
            heads = 4,
            num_landmarks = 256
        )


    def forward(self, input):
        x = super().forward(input)
        skip = self.skip(x.flatten(start_dim=1))
        #x = x.flatten(start_dim=2)
        #x = self.to_patch_embedding(x)
        #print("ggg", x.shape)
        x = x.flatten(start_dim=2).permute(0, 2, 1)

        b, n, _ = x.shape
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        #x = self.dropout(x)
        x = self.transformer(x)[:, 0]
        out = x + skip
        return out

In [31]:
encoder = EncoderV3(in_channel=3,
                  channel=128,
                  channel_multiplier=[1, 2, 2, 4, 4],
                  n_res_blocks=2,
                  attn_strides=[8, 16],
                  attn_heads=4,
                  dropout=0,
                  fold=1)

In [32]:
encoder(torch.randn(5, 3, 128, 128))

ggg torch.Size([5, 512]) torch.Size([5, 512])


tensor([[ 0.6818, -0.6286, -1.2552,  ...,  3.6689, -0.7544,  2.2791],
        [ 0.6940, -0.4346, -1.0720,  ...,  3.7116, -0.9664,  2.2288],
        [ 0.7017, -0.5301, -1.2037,  ...,  3.7276, -0.9122,  2.2274],
        [ 0.6755, -0.4856, -1.2206,  ...,  3.6690, -0.7638,  2.1814],
        [ 0.6901, -0.5135, -1.2251,  ...,  3.6426, -0.8047,  2.1940]],
       grad_fn=<AddBackward0>)