In [92]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange
from einops.layers.torch import Rearrange


In [93]:
def conv_3x3_bn(inp, oup, image_size, downsample=False):
    stride = 1 if downsample == False else 2
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.GELU()
    )
class PreNorm(nn.Module):
    def __init__(self, dim, fn, norm):
        super().__init__()
        self.norm = norm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

# PreNorm(inp, self.attn, nn.LayerNorm)

class SE(nn.Module):
    def __init__(self, inp, oup, expansion=0.25):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(oup, int(inp * expansion), bias=False),
            nn.GELU(),
            nn.Linear(int(inp * expansion), oup, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class MBConv(nn.Module):
    def __init__(self, inp, oup, image_size, downsample=False, expansion=4):
        super().__init__()
        self.downsample = downsample
        stride = 1 if self.downsample == False else 2
        hidden_dim = int(inp * expansion)

        if self.downsample:
            self.pool = nn.MaxPool2d(3, 2, 1)
            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)

        if expansion == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
                          1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                # down-sample in the first conv
                nn.Conv2d(inp, hidden_dim, 1, stride, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1,
                          groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                SE(inp, hidden_dim),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        
        self.conv = PreNorm(inp, self.conv, nn.BatchNorm2d)

    def forward(self, x):
        if self.downsample:
            return self.proj(self.pool(x)) + self.conv(x)
        else:
            return x + self.conv(x)


class Attention(nn.Module):
    def __init__(self, inp, oup, image_size, heads=8, dim_head=32, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == inp)

        self.ih, self.iw = image_size

        self.heads = heads
        self.scale = dim_head ** -0.5

        # parameter table of relative position bias
        self.relative_bias_table = nn.Parameter(
            torch.zeros((2 * self.ih - 1) * (2 * self.iw - 1), heads))

        coords = torch.meshgrid((torch.arange(self.ih), torch.arange(self.iw)))
        coords = torch.flatten(torch.stack(coords), 1)
        relative_coords = coords[:, :, None] - coords[:, None, :]

        relative_coords[0] += self.ih - 1
        relative_coords[1] += self.iw - 1
        relative_coords[0] *= 2 * self.iw - 1
        relative_coords = rearrange(relative_coords, 'c h w -> h w c')
        relative_index = relative_coords.sum(-1).flatten().unsqueeze(1)
        self.register_buffer("relative_index", relative_index)

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(inp, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, oup),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(
            t, 'b n (h d) -> b h n d', h=self.heads), qkv)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # Use "gather" for more efficiency on GPUs
        relative_bias = self.relative_bias_table.gather(
            0, self.relative_index.repeat(1, self.heads))
        relative_bias = rearrange(
            relative_bias, '(h w) c -> 1 c h w', h=self.ih*self.iw, w=self.ih*self.iw)
        dots = dots + relative_bias

        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out

class Transformer(nn.Module):
    def __init__(self, inp, oup, image_size, heads=8, dim_head=32, downsample=False, dropout=0.):
        super().__init__()
        hidden_dim = int(inp * 4)
        self.ih, self.iw = image_size
        self.downsample = downsample

        if self.downsample:
            self.pool1 = nn.MaxPool2d(3, 2, 1)
            self.pool2 = nn.MaxPool2d(3, 2, 1)
            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)

        self.attn = Attention(inp, oup, image_size, heads, dim_head, dropout)
        self.ff = FeedForward(oup, hidden_dim, dropout)

        self.attn = nn.Sequential(
            Rearrange('b c ih iw -> b (ih iw) c'),
            PreNorm(inp, self.attn, nn.LayerNorm),
            Rearrange('b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw)
        )

        self.ff = nn.Sequential(
            Rearrange('b c ih iw -> b (ih iw) c'),
            PreNorm(oup, self.ff, nn.LayerNorm),
            Rearrange('b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw)
        )

    def forward(self, x):
        if self.downsample:
            x = self.proj(self.pool1(x)) + self.attn(self.pool2(x))
        else:
            x = x + self.attn(x)
        x = x + self.ff(x)
        return x

    
class TransformerAlt(nn.Module):
    def __init__(self, inp, oup, image_size, heads=8, dim_head=32, downsample=False, dropout=0.):
        super().__init__()
        hidden_dim = int(inp * 4)
        self.ih, self.iw = image_size
        self.downsample = downsample

        if self.downsample:
            self.pool1 = nn.MaxPool2d(3, 2, 1)
            self.pool2 = nn.MaxPool2d(3, 2, 1)
            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)

        self.attn = Attention(inp, oup, image_size, heads, dim_head, dropout)
        self.ff = FeedForward(oup, hidden_dim, dropout)
        
        self.pre_attn = Rearrange('b c ih iw -> b (ih iw) c')
        self.attn = nn.Sequential(
            PreNorm(inp, self.attn, nn.LayerNorm)
        )
        self.post_attn = Rearrange('b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw)

        self.ff = nn.Sequential(
            Rearrange('b c ih iw -> b (ih iw) c'),
            PreNorm(oup, self.ff, nn.LayerNorm),
            Rearrange('b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw)
        )

    def forward(self, x):
        if self.downsample:
            tmp = self.attn(self.pre_attn(self.pool2(x)))
            print(tmp.shape)
            x = self.proj(self.pool1(x)) + self.post_attn(tmp)
        else:
            x = x + self.post_attn(self.attn(self.pre_attn(x)))
        x = x + self.ff(x)
        return x

    

In [94]:
x = torch.rand(1, 128, 112, 112)

In [95]:
x.view(x.shape[0], -1, x.shape[1]).shape

torch.Size([1, 12544, 128])

In [96]:
class CrossConv(pl.LightningModule):
    def __init__(self, image_size, in_channels, num_blocks, channels, num_classes=100):
        super().__init__()
        ih, iw = image_size
        block = {'C': MBConv, 'T': Transformer}

        self.s0 = self._make_layer(
            conv_3x3_bn, in_channels, channels[0], num_blocks[0], (ih // 2, iw // 2))
        
        self.s1_c = self._make_layer(
            MBConv, channels[0], channels[1], num_blocks[1], (ih // 4, iw // 4))
        self.s1_a = self._make_layer(
            Transformer, channels[0], channels[1], num_blocks[1], (ih // 4, iw // 4))
        
        self.s2_c = self._make_layer(
            MBConv, channels[1], channels[2], num_blocks[2], (ih // 8, iw // 8))
        self.s2_a = self._make_layer(
            Transformer, channels[1], channels[2], num_blocks[2], (ih // 8, iw // 8))
            
        self.s3_c = self._make_layer(
            MBConv, channels[2], channels[3], num_blocks[3], (ih // 16, iw // 16))
        self.s3_a = self._make_layer(
            Transformer, channels[2], channels[3], num_blocks[3], (ih // 16, iw // 16))
        
        self.s4_c = self._make_layer(
            MBConv, channels[3], channels[4], num_blocks[4], (ih // 32, iw // 32))
        self.s4_a = self._make_layer(
            Transformer, channels[3], channels[4], num_blocks[4], (ih // 32, iw // 32))

        self.pool = nn.AvgPool2d(ih // 32, 1)
        self.fc = nn.Linear(channels[-1], num_classes, bias=False)

    def forward(self, x):
        x = self.s0(x)
        print("IN:", x.shape)
        x_c = self.s1_c(x)
        x_a = self.s1_a(x)
        print("OUTC: ", x_c.shape)
        print("OUTA: ", x_a.shape)
        # send x_c to x_a
        x_a += x_c
        
        x_c = self.s2_c(x_c)
        x_a = self.s2_a(x_a)
        
        x_a += x_c
        
        x_c = self.s3_c(x_c)
        x_a = self.s3_a(x_a)
        
        x_a += x_c
        
        x_c = self.s4_c(x_c)
        x_a = self.s4_a(x_a)
        
        x_a += x_c

        x = self.pool(x_a).view(-1, x_a.shape[1])
        x = self.fc(x)
        return x

    def _make_layer(self, block, inp, oup, depth, image_size):
        layers = nn.ModuleList([])
        for i in range(depth):
            if i == 0:
                layers.append(block(inp, oup, image_size, downsample=True))
            else:
                layers.append(block(oup, oup, image_size))
        return nn.Sequential(*layers)


In [97]:
model = CrossConv((224,224), 3, [2,2,3,5,2], [64,96, 192, 384, 768]).cuda()

In [98]:
model(torch.rand(1,3,224,224).cuda())

IN: torch.Size([1, 64, 112, 112])
OUTC:  torch.Size([1, 96, 56, 56])
OUTA:  torch.Size([1, 96, 56, 56])


tensor([[ 1.8699e-01,  1.8376e-01,  6.4260e-02,  4.9565e-01,  2.0549e-01,
          6.4272e-01, -1.1592e+00,  8.1238e-01,  9.9985e-02, -3.1417e-01,
          4.7120e-01,  7.2618e-01,  7.3724e-01,  3.0669e-01,  3.3929e-01,
          2.6543e-01, -6.0442e-01,  6.6925e-01, -1.5300e-01,  3.1523e-01,
          4.0570e-01, -4.7409e-01,  3.8161e-01, -2.1652e-02,  4.4554e-01,
         -4.4789e-01, -7.5438e-03,  3.1133e-01, -2.1799e-01, -1.5996e+00,
         -1.0367e+00, -3.4649e-01,  5.1882e-01, -5.7754e-02, -1.3137e-01,
         -1.3219e-01,  3.8770e-01,  9.3633e-01,  1.5024e-01,  4.4008e-01,
          7.0525e-02,  4.0046e-01, -1.9880e-02, -2.7394e-01,  3.9517e-01,
          2.9172e-01, -3.2222e-01,  3.5538e-01, -1.5972e-01, -5.8321e-01,
         -6.7887e-01,  1.9498e-01, -7.2924e-01,  2.6174e-01,  4.1293e-02,
         -2.6217e-01,  3.0124e-01, -8.2174e-01,  1.2576e-01,  1.2329e+00,
          8.3843e-01, -2.1704e-01,  2.6598e-01,  7.0405e-01,  1.7839e-01,
         -3.5959e-01, -4.2120e-01, -1.