In [1]:
import torch

In [2]:
from torch import nn,optim
import torch.nn.functional as F

In [3]:
class ResBolck(nn.Module):
    def __init__(self,in_channels,out_channels,use1conv=True,stride=1):
        super(ResBolck,self).__init__()
        self.con1=nn.Conv2d(in_channels,out_channels,kernel_size=3,padding=1)
        self.con2=nn.Conv2d(out_channels,out_channels,kernel_size=3,padding=1,stride=stride)
        
        if use1conv:
            self.con3=nn.Conv2d(in_channels,out_channels,kernel_size=3,padding=1,stride=stride)
        else:
            self.con3=None
        
        self.bn1=nn.BatchNorm2d(out_channels)
        self.bn2=nn.BatchNorm2d(out_channels)
    def forward(self,x):
        y=F.relu(self.bn1(self.con1(x)))
        y=self.bn2(self.con2(y))
        if self.con3:
            x=self.con3(x)
        return F.relu(y+x)

In [4]:
blk=ResBolck(3,3)
x=torch.rand((4,3,6,6))
blk(x).shape

torch.Size([4, 3, 6, 6])

In [5]:
net=nn.Sequential(
    nn.Conv2d(1,64,kernel_size=7,stride=2,padding=3),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
)

In [6]:
def resnet_block(in_channels,out_channels,num_residuals,first_block=False):
    if first_block:
        assert in_channels==out_channels
    blk=[]
    for i in range(num_residuals):
        if i==0 and not first_block:
            blk.append(ResBolck(in_channels,out_channels,use1conv=True,stride=2))
        else:
            blk.append(ResBolck(out_channels,out_channels))
    return nn.Sequential(*blk)

In [7]:
net.add_module("resnet_block1",resnet_block(64,64,2,first_block=True))
net.add_module("resnet_block2",resnet_block(64,128,2,first_block=False))
net.add_module("resnet_block3",resnet_block(128,256,2,first_block=False))
net.add_module("resnet_block4",resnet_block(256,512,2,first_block=False))

In [8]:
class Global_Avg_Pool(nn.Module):
    def __init__(self):
        super(Global_Avg_Pool,self).__init__()
    def forward(self,x):
        x=x.view(x.shape[0],x.shape[1],-1)
        x=x.mean(dim=2,keepdim=True)
        x=x.unsqueeze(2)
        return x

In [9]:
class FlaternLayer(nn.Module):
    def __init__(self):
        super(FlaternLayer,self).__init__()
    def forward(self,x):
        x=x.view((x.shape[0],-1))
        print(x.shape)
        return x

In [10]:
net.add_module("globalavg",Global_Avg_Pool())
net.add_module("fc",nn.Sequential(FlaternLayer(),nn.Linear(512,10)))

In [11]:
x=torch.rand((1,1,224,224))
for name,layer in net.named_children():
    x=layer(x)
    print(name,x.shape)

0 torch.Size([1, 64, 112, 112])
1 torch.Size([1, 64, 112, 112])
2 torch.Size([1, 64, 112, 112])
3 torch.Size([1, 64, 56, 56])
resnet_block1 torch.Size([1, 64, 56, 56])
resnet_block2 torch.Size([1, 128, 28, 28])
resnet_block3 torch.Size([1, 256, 14, 14])
resnet_block4 torch.Size([1, 512, 7, 7])
globalavg torch.Size([1, 512, 1, 1])
torch.Size([1, 512])
fc torch.Size([1, 10])


In [12]:
import torchvision