In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from Utils.U_net_Modules import *

In [9]:
# Randomly initialised ResNet-152 (no pretrained weights). `weights=None` is the
# torchvision >= 0.13 spelling of the deprecated `pretrained=False`.
model = torchvision.models.resnet152(weights=None)



In [10]:
# Drop the last two top-level children (for torchvision ResNets: the global
# average pool and the fc classifier) to keep a convolutional feature extractor.
backbone_children = list(model.children())[:-2]
new_model = nn.Sequential(*backbone_children)

In [11]:
# Show the layer structure of the truncated backbone.
print(str(new_model))

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)


In [2]:
# Dummy 3-channel 224x224 input in (C, H, W) layout; the batch dimension is added
# where it is used. A dedicated seeded generator makes the probe reproducible
# without altering the global RNG state.
inp = torch.randn((3, 224, 224), generator=torch.Generator().manual_seed(0))

In [13]:
# Shape probe: run in eval mode without autograd so this forward pass neither
# updates the BatchNorm running statistics nor builds a gradient graph. The
# previous train/eval mode is restored afterwards.
_was_training = new_model.training
new_model.eval()
with torch.no_grad():
    out = new_model(inp.unsqueeze(0))  # add batch dim -> (1, 3, 224, 224)
new_model.train(_was_training)

In [14]:
# Feature-map size of the truncated backbone for a single 224x224 input.
out.size()

torch.Size([1, 2048, 7, 7])

In [15]:
# Collapse every dimension, including the batch dimension, into one vector.
# NOTE(review): for batches larger than 1 use `torch.flatten(out, start_dim=1)`
# to keep one row per sample.
out = torch.flatten(out)

In [16]:
# Length of the flattened feature vector.
out.size()

torch.Size([100352])

In [3]:
class BayarConv(nn.Module):
    """Constrained convolution layer with a fixed -1 center tap.

    Each filter is assembled from ``kernel_size**2 - 1`` learnable off-center
    weights plus a fixed center value of -1. On every forward pass the
    learnable weights are rescaled in place so each group sums to 1, which
    makes every assembled filter sum to 0.

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels.
        kernel_size: spatial size of the square filter. Assumed odd so that a
            single center tap exists (not checked).
        stride: stride passed to ``F.conv2d``.
        padding: padding passed to ``F.conv2d``.
    """

    def __init__(self, in_channel, out_channel, kernel_size = 5, stride = 1, padding = 0):
        super(BayarConv, self).__init__()

        self.in_channel = in_channel
        self.out_channel = out_channel
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        # Fixed -1 center tap. Registered as a buffer so `.to(device)` /
        # `.double()` move it together with `self.kernel`. As a plain tensor
        # attribute it stayed on CPU/float32, and `torch.cat` in `constraint`
        # failed once the module was moved to another device.
        # persistent=False keeps the state_dict keys (only "kernel") unchanged.
        self.register_buffer(
            "mid_ele",
            torch.full((self.in_channel, self.out_channel, 1), -1.0),
            persistent=False,
        )

        # Learnable off-center weights, stored as (in, out, k*k - 1).
        self.kernel = nn.Parameter(torch.randn((self.in_channel,self.out_channel,kernel_size**2-1)), requires_grad= True)


    def constraint(self):
        """Normalise the learnable weights and assemble the full filter bank.

        Side effect: ``self.kernel.data`` is divided by its sum over the last
        axis, so afterwards each group of off-center weights sums to 1.

        Returns:
            Tensor of shape ``(out_channel, in_channel, kernel_size, kernel_size)``
            whose center taps are -1.
        """
        # NOTE(review): the kernel is initialised with randn, so a group's sum can
        # be close to 0 and this division can produce very large weights.
        self.kernel.data = self.kernel.data.div(self.kernel.data.sum(dim = -1, keepdim=True))
        center = self.kernel_size**2//2
        # Insert the -1 tap at the flat center position of every filter.
        real_kernel = torch.cat((self.kernel[:,:,:center], self.mid_ele, self.kernel[:,:,center:]),dim=2)
        # This is a reshape, not a transpose. The (in, out, k*k) tensor is re-read
        # in row-major order, so each k*k filter stays intact (center -1,
        # off-center sum 1). Stored (i, o) slots, however, map to conv (o, i)
        # slots by a permutation rather than a transpose. Kept as-is so existing
        # checkpoints behave identically.
        real_kernel = real_kernel.reshape((self.out_channel, self.in_channel, self.kernel_size, self.kernel_size))
        return real_kernel

    def forward(self, x):
        """Convolve ``x`` with the freshly constrained filter bank."""
        return F.conv2d(x, self.constraint(), stride = self.stride, padding=self.padding)


class BayarConvBlock(nn.Module):
    """Conv -> MaxPool(2) -> ReLU stack used after the constrained Bayar layer.

    Channels grow ``in_channel -> 64 -> 128 -> 256 -> out_channel`` over four
    stages. Each stage halves the spatial size with a 2x2 max-pool.

    Args:
        in_channel: channels of the incoming feature map.
        out_channel: channels produced by the last convolution.
        stride: stride of the 2nd-4th convolutions (the first always uses 2).
        padding: zero-padding of the 2nd-4th convolutions (the first uses 0).
    """

    def __init__(self, in_channel = 12, out_channel = 1024, stride = 1, padding = 0):

        super(BayarConvBlock, self).__init__()

        self.in_channel = in_channel
        self.out_channel = out_channel
        self.padding = padding
        self.stride = stride

        # `padding` used to be stored but never applied. It is now forwarded to
        # the same convolutions that honour `stride`; the default of 0 matches
        # the previous hard-coded behaviour.
        self.mod = nn.Sequential(
            nn.Conv2d(self.in_channel, 64, 7, stride=2),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, stride=self.stride, padding=self.padding),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, 5, stride=self.stride, padding=self.padding),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, self.out_channel, 3, stride=self.stride, padding=self.padding),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Run ``x`` through the four conv/pool/ReLU stages."""
        return self.mod(x)


class BayarBlock(nn.Module):
    """Constrained BayarConv front-end followed by a BayarConvBlock.

    The input passes through ``BayarConv(in_channel, 12)`` (its default 5x5
    filter, stride 1, no padding), then through ``BayarConvBlock(12, out_channel)``.

    NOTE(review): ``kernel_size``, ``stride`` and ``padding`` are stored on the
    instance but not forwarded to either sub-module. Forwarding ``kernel_size``
    as-is would silently switch the default filter from 5x5 to 3x3, so decide
    on the intended wiring before using these arguments.
    """

    def __init__(self, in_channel = 3, out_channel = 1024, kernel_size = 3, stride = 1, padding = 0):
        super(BayarBlock, self).__init__()

        # Configuration, kept for introspection (see NOTE in the class docstring).
        self.in_channel, self.out_channel = in_channel, out_channel
        self.kernel_size, self.stride, self.padding = kernel_size, stride, padding

        # Submodules, registered in this order: constrained conv, then conv stack.
        self.Bayarconv = BayarConv(self.in_channel, 12)
        self.rest = BayarConvBlock(12, self.out_channel)

    def forward(self, x):
        """Apply the constrained convolution, then the downsampling stack."""
        return self.rest(self.Bayarconv(x))

In [4]:
# Bayar branch: 3-channel input -> 1024-channel feature map.
mod = BayarBlock(in_channel=3, out_channel=1024)

In [7]:
# Shape probe only: disable autograd so no gradient graph is kept for `out1`.
# (BayarConv still re-normalises its kernel in place during this call.)
with torch.no_grad():
    out1 = mod(inp.unsqueeze(0))  # add batch dim -> (1, 3, 224, 224)

In [8]:
# Output size of the Bayar branch for a single 224x224 input.
out1.size()

torch.Size([1, 1024, 4, 4])

In [None]:
class Concat(nn.Module):

    def __init__(self):
        pass