In [1]:
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import PatchEmbed

In [2]:
# Dummy batch of one 3x512x512 image on the first GPU.
# Note: despite the "randn" in the name, torch.rand samples uniformly from [0, 1).
randn_im_th_cu = torch.rand(1, 3, 512, 512, device='cuda:0')

In [3]:
# Patch embedding for 512x512 inputs cut into 16x16 patches, with a 512-wide embedding.
patchembed = PatchEmbed(img_size=512, patch_size=16, embed_dim=512)
patchembed = patchembed.to('cuda:0')

In [4]:
# Embed the dummy image into patch tokens.
# Presumably (1, 1024, 512) = (batch, 32*32 patches, embed_dim) with timm's default flattening — TODO confirm.
embedding = patchembed(randn_im_th_cu)

In [6]:
from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_

class Attention(nn.Module):
    """Multi-head self-attention over a sequence of token embeddings.

    Args:
        dim (int): size of the last (channel) axis of the input; must be
            divisible by ``num_heads``.
        num_heads (int): number of heads the channels are split across.
        qkv_bias (bool): add a bias to the fused query/key/value projection.
        attn_drop (float): dropout probability applied to the attention weights.
        proj_drop (float): dropout probability applied after the output projection.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        # Fail early: otherwise the per-head reshape in forward() errors out
        # with an opaque size-mismatch message.
        if dim % num_heads != 0:
            raise ValueError(f'dim ({dim}) should be divisible by num_heads ({num_heads})')
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Scaled dot-product attention: logits are scaled by 1 / sqrt(head_dim).
        self.scale = head_dim ** -0.5

        # One linear layer produces queries, keys and values together.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention.

        Args:
            x (torch.Tensor): tensor of shape (B, N, C) with C == ``dim``.

        Returns:
            torch.Tensor: tensor of shape (B, N, C).
        """
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)

        # (B, heads, N, N) attention weights, normalised over the key axis.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Merge the heads back into the channel axis: (B, heads, N, head_dim) -> (B, N, C).
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class Block(nn.Module):
    """Transformer block: normalised self-attention, then a normalised MLP.

    Both sub-layers are wrapped in a residual connection, and the same
    ``drop_path`` module is applied to each residual branch.

    Args:
        dim (int): channel width of the tokens.
        num_heads (int): number of attention heads.
        mlp_ratio (float): MLP hidden width as a multiple of ``dim``.
        qkv_bias (bool): forwarded to ``Attention``.
        drop (float): dropout rate for the attention projection and the MLP.
        attn_drop (float): dropout rate for the attention weights.
        drop_path (float): stochastic-depth rate; an identity is used when it is not > 0.
        act_layer: activation class passed to ``Mlp``.
        norm_layer: callable building a normalisation layer from ``dim``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x):
        """Run the attention and MLP sub-layers, each added back onto its input."""
        attn_branch = self.drop_path(self.attn(self.norm1(x)))
        x = x + attn_branch
        mlp_branch = self.drop_path(self.mlp(self.norm2(x)))
        return x + mlp_branch

class StyleVisionTransformer(nn.Module):
    """Transformer trunk that also emits one square "condition" map per block.

    After each transformer block the tokens are squeezed by a per-stage
    bottleneck ``Linear``, regrouped into ``linear_channels[i]`` rows, and
    projected to ``linear_dims[i] ** 2`` values per row, giving a
    ``(B, linear_channels[i], linear_dims[i], linear_dims[i])`` tensor.

    The projection sizes are built for 1024 input tokens (hard-coded below).

    NOTE(review): the stage tables have 14 entries while ``depth`` defaults to
    12. ``forward`` zips blocks with stages, so only ``min(depth, 14)``
    conditions are produced and the trailing stage layers are allocated but
    unused. The last two projections are 4096 -> 262144 ``Linear`` layers
    (about 1e9 weights each) — confirm this is intended.
    """

    def __init__(self, embed_dim=128, depth=12, num_heads=16, mlp_ratio=4.,
                qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
                norm_layer=None, act_layer=None):
        """
        Args:
            embed_dim (int): embedding dimension of the incoming tokens
            depth (int): depth of transformer (number of blocks)
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            norm_layer: (nn.Module): normalization layer; defaults to LayerNorm with eps=1e-6
            act_layer: (nn.Module): activation layer; defaults to nn.GELU
        """
        super().__init__()
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
                attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Per-stage channel count and spatial side length of the emitted condition maps.
        self.linear_channels = [256, 256, 256, 256, 256, 256, 128, 128, 64, 64, 32, 32, 16, 16]
        self.linear_dims = [8, 8, 16, 16, 32, 32, 64, 64, 128, 128, 256, 256, 512, 512]

        num_tokens = 1024        # sequence length the stage projections are sized for
        bottleneck_dim = 8 ** 2  # per-token width after the bottleneck

        # The bottleneck consumes block outputs, so its input width must be embed_dim
        # (it was hard-coded to 512, which cannot accept the default 128-wide tokens).
        # These containers are only iterated over, never called as a Sequential.
        self.bottlenecks = nn.Sequential(*[
            nn.Linear(embed_dim, bottleneck_dim) for _ in range(len(self.linear_channels))
        ])

        # Stage i maps each of its linear_channel rows (num_tokens * bottleneck_dim values
        # split evenly across the rows) to a flattened linear_dim x linear_dim map.
        linear_to_convs = []
        for linear_channel, linear_dim in zip(self.linear_channels, self.linear_dims):
            linear_to_convs.append(nn.Linear(num_tokens // linear_channel * bottleneck_dim, linear_dim ** 2))
        self.linear_to_convs = nn.Sequential(*linear_to_convs)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform init for Linear weights, zero biases, unit LayerNorm scale."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """Run the blocks and collect one condition map per stage.

        Args:
            x (torch.Tensor): tokens of shape (B, N, embed_dim). N times the
                bottleneck width must equal ``linear_channel`` times the
                matching projection's ``in_features`` for every stage used
                (N == 1024 for the layers built in ``__init__``).

        Returns:
            tuple: ``(x, conditions)`` where ``x`` is the normalised token
            tensor (B, N, embed_dim) and ``conditions`` is a list of
            (B, linear_channel, linear_dim, linear_dim) tensors, one per
            block/stage pair.
        """
        B = x.shape[0]

        conditions = []
        stages = zip(self.blocks, self.bottlenecks, self.linear_to_convs,
                     self.linear_channels, self.linear_dims)
        for block, bottleneck, linear_to_conv, linear_channel, linear_dim in stages:
            x = block(x)
            # Keep `x` as the embed_dim-wide token stream for the next block;
            # only the side branch goes through the bottleneck.
            squeezed = bottleneck(x)
            # Regroup all bottleneck features into linear_channel rows before projecting.
            condition = linear_to_conv(squeezed.reshape(B, linear_channel, -1))
            condition = condition.reshape(B, linear_channel, linear_dim, linear_dim)
            conditions.append(condition)
        x = self.norm(x)
        return x, conditions


In [7]:
# Build the model with its default hyper-parameters, then move it to the first GPU.
stylevit = StyleVisionTransformer()
stylevit = stylevit.to('cuda:0')

In [8]:
# Bare expression: the notebook displays the module tree as the cell output.
stylevit

StyleVisionTransformer(
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=128, out_features=384, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=128, out_features=128, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=128, out_features=512, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=512, out_features=128, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=128, out_features=384, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_featu