In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.utils as vutils
from einops import rearrange,repeat
import math

In [17]:
class CONVFFN(nn.Module):

    def __init__(self,stride=1,in_channels=1,hidden_dimension=1):
        super(CONVFFN,self).__init__()
        self.stride=stride
        self.in_channels=in_channels
        self.hidden_dimension=hidden_dimension
        self.act=nn.GELU()

        # first block
        self.dwconv1=nn.Conv2d(
            in_channels= self.in_channels,
            out_channels= self.in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=self.in_channels

        )
        self.pwconv1=nn.Conv2d(
            in_channels=self.in_channels,
            out_channels= self.hidden_dimension,
            kernel_size=1,
            stride=1
        )

        self.norm1=nn.BatchNorm2d(self.hidden_dimension)
        # second block 

        self.dwconv2=nn.Conv2d(
            in_channels=self.hidden_dimension,
            out_channels=self.hidden_dimension,
            kernel_size=2,
            stride=1,
            groups=self.hidden_dimension
        )
        self.pwconv2=nn.Conv2d(
            in_channels=self.hidden_dimension,
            out_channels=self.hidden_dimension*2,
            kernel_size=1,
            stride=1,
            padding=0
        )

        self.norm2=nn.BatchNorm2d(self.hidden_dimension*2)

        #third block

        self.dwconv3=nn.Conv2d(
            in_channels=self.hidden_dimension*2,
            out_channels=self.hidden_dimension*4,
            kernel_size=2,
            stride=1,
            padding=1,
            groups=self.hidden_dimension*2
        )

        self.pwconv3=nn.Conv2d(
            in_channels=self.hidden_dimension*4,
            out_channels=self.hidden_dimension*4,
            kernel_size=1,
            stride=1
        )

        self.norm3=nn.BatchNorm2d(self.hidden_dimension*4)

        # down sample
        self.downsample1=nn.Conv2d(
            in_channels=self.hidden_dimension*4,
            out_channels=self.hidden_dimension*4,
            kernel_size=3,
            padding=1,
            groups=self.hidden_dimension*4
        )

        self.pwdsample1=nn.Conv2d(
            in_channels=self.hidden_dimension*4,
            out_channels=self.in_channels,
            kernel_size=1
        )

        self.norm4=nn.BatchNorm2d(self.in_channels)

        self.apply(self._init_weights)
    
    def _init_weights(self,m):
        if isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
    def forward(self,x,H,W):
        x=rearrange(x,'b (H W) C -> b C H W',H=H,W=W)   # B N C -> B C H W

        x1= self.dwconv1(x)                 # C -> 4*C
        x1=self.pwconv1(x1)

        x1=self.norm1(x1)
        x1=self.act(x1)


        x1=self.dwconv2(x1)                 # 4*C -> 6*C
        x1=self.pwconv2(x1)


        x1=self.norm2(x1)
        x1=self.act(x1)



        x1=self.dwconv3(x1)                 # 6*C -> 8*C
        x1=self.pwconv3(x1)
        
        x1=self.norm3(x1)
        x1=self.act(x1)



        x1=self.downsample1(x1)             # 8*C -> C
        x1=self.pwdsample1(x1)

        x1=self.norm4(x1)
        x1=self.act(x1)

        print(x1.shape,x.shape)
        x=x+x1

        x=rearrange(x,'b c H W -> b (H W) c')

        return x




In [24]:
model=CONVFFN(in_channels=512,hidden_dimension=512*4)
model(torch.rand(32,3136,512),56,56).shape

torch.Size([32, 512, 56, 56]) torch.Size([32, 512, 56, 56])
torch.Size([32, 4096, 55, 55]) torch.Size([32, 512, 56, 56])
torch.Size([32, 8192, 56, 56]) torch.Size([32, 512, 56, 56])
torch.Size([32, 512, 56, 56]) torch.Size([32, 512, 56, 56])


torch.Size([32, 3136, 512])

In [20]:
total_params = sum(p.numel() for p in model.parameters())

In [22]:
total_params/10e6

18.1939968

In [7]:
class CONVFFN(nn.Module):

    def __init__(self,stride=1):
        super(CONVFFN,self).__init__()
        self.conv1 = nn.Conv1d(
            in_channels=1,
            kernel_size=7,
            out_channels=1,
            stride=stride,
            groups=1,
            padding=0
        )

        self.conv2=nn.Conv1d(
            in_channels=1,
            out_channels=1,
            kernel_size=3,
            stride=1,
            padding=1
        )

        self.conv3=nn.Conv1d(
            in_channels=1,
            out_channels=1,
            kernel_size=1,
            stride=1,
            padding=1
        )
        self.conv4=nn.Conv1d(
            in_channels=1,
            out_channels=1,
            kernel_size=1,
            stride=1,
            padding=2
        )

        self.gelu=nn.GELU()
    
        self.init_weights()

    def init_weights(self):
        nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out',nonlinearity='leaky_relu')
        nn.init.kaiming_normal_(self.conv1.bias, mode='fan_out',nonlinearity='leaky_relu')
        nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out',nonlinearity='leaky_relu')
        nn.init.kaiming_normal_(self.conv2.bias, mode='fan_out',nonlinearity='leaky_relu')
        nn.init.kaiming_normal_(self.conv3.weight, mode='fan_out',nonlinearity='leaky_relu')
        nn.init.kaiming_normal_(self.conv3.bias, mode='fan_out',nonlinearity='leaky_relu')
        nn.init.kaiming_normal_(self.conv4.weight, mode='fan_out',nonlinearity='leaky_relu')
        nn.init.kaiming_normal_(self.conv4.bias, mode='fan_out',nonlinearity='leaky_relu')
    
    def forward(self,x):
        out=self.conv1(x)
        out=self.gelu(out)
        out=self.conv2(out)
        out=self.gelu(out)
        out=self.conv3(out)
        out=self.gelu(out)
        out=self.conv4(out)
        return out

In [1]:
import random
c=0
r=0
for i in range(1_000_000):
    r=random.random()
    if r>0.5:
        c+=1
    else:
        r+=1

In [4]:
c/1_000_000

0.500611

In [5]:
r/1_000_000

1.4222659786170484e-06

In [6]:
c/r

351981.2802432166