In [3]:
import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as F

Stem 영역을 세 구간(`stem1`, `stem2`, `stem3`)으로 나누어 선언한다.

In [6]:
class stem1(nn.Module):
    """Inception-v4 stem, part 1.

    A shared trunk of three 3x3 convolutions feeds two parallel stride-2
    branches (a 3x3 convolution and a 3x3 max-pool). Their outputs are
    concatenated along the channel axis.

    Channels: 3 -> 160 (96 from the conv branch, then 64 from the pool branch).
    Spatial size, e.g. for a 299x299 input: 299 -> 149 -> 147 -> 147 -> 73.
    """

    def __init__(self):
        # nn.Module.__init__ must run before any submodule is assigned;
        # without it the first `self.a1 = ...` raises AttributeError.
        super(stem1, self).__init__()
        # Shared trunk: 3x3 stride-2 (no padding), 3x3 (no padding), 3x3 "same".
        self.a1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2),
            nn.ReLU(True),
            nn.Conv2d(32, 32, 3, stride=1),
            nn.ReLU(True),
            nn.Conv2d(32, 64, 3, stride=1, padding=1),
            nn.ReLU(True),
        )
        # Conv branch: 3x3, stride 2, no padding -> 96 channels.
        self.a2 = nn.Sequential(
            nn.Conv2d(64, 96, 3, stride=2),
            nn.ReLU(True),
        )
        # Pool branch: 3x3 max-pool, stride 2, no padding -> 64 channels.
        # It matches the conv branch's spatial size, so the two can be concatenated.
        self.a3 = nn.Sequential(
            nn.MaxPool2d(3, stride=2),
        )

    def forward(self, x):
        """Return cat([conv_branch, pool_branch], dim=1) for a 3-channel NCHW input."""
        trunk = self.a1(x)
        return torch.cat([self.a2(trunk), self.a3(trunk)], dim=1)
        

In [7]:
class stem2(nn.Module):
    """Inception-v4 stem, part 2.

    There are two parallel branches, each ending in an unpadded 3x3 convolution,
    and their outputs are concatenated along the channel axis:

    * a1: 1x1 (64) -> 7x1 (64) -> 1x7 (64) -> 3x3 (96)
    * a2: 1x1 (64) -> 3x3 (96)

    Channels: 160 -> 192. Spatial size: H x W -> (H-2) x (W-2).
    """

    def __init__(self):
        # Required before assigning submodules.
        super(stem2, self).__init__()
        self.a1 = nn.Sequential(
            nn.Conv2d(160, 64, 1, stride=1),
            nn.ReLU(True),
            # Asymmetric kernels are passed as (kH, kW) tuples. The padding keeps
            # the spatial size unchanged. (The original passed `7, 1` positionally,
            # so 1 landed on `stride` and clashed with stride=1, raising TypeError.)
            nn.Conv2d(64, 64, (7, 1), stride=1, padding=(3, 0)),
            nn.ReLU(True),
            nn.Conv2d(64, 64, (1, 7), stride=1, padding=(0, 3)),
            nn.ReLU(True),
            nn.Conv2d(64, 96, 3, stride=1),
            # Every other conv is followed by ReLU; this one was missing it.
            nn.ReLU(True),
        )
        self.a2 = nn.Sequential(
            nn.Conv2d(160, 64, 1, stride=1),
            nn.ReLU(True),
            nn.Conv2d(64, 96, 3, stride=1),
            nn.ReLU(True),
        )

    def forward(self, x):
        """Return cat([a1(x), a2(x)], dim=1) for a 160-channel NCHW input."""
        return torch.cat([self.a1(x), self.a2(x)], dim=1)

In [9]:
class stem3(nn.Module):
    """Inception-v4 stem, part 3.

    A 3x3 stride-2 convolution (192 channels) runs in parallel with a 3x3
    stride-2 max-pool (which passes the 192 input channels through). Their
    outputs are concatenated along the channel axis.

    Channels: 192 -> 384. Spatial size: H -> (H - 3) // 2 + 1 (e.g. 71 -> 35).
    """

    def __init__(self):
        # Required before assigning submodules.
        super(stem3, self).__init__()
        self.a1 = nn.Sequential(
            nn.Conv2d(192, 192, 3, stride=2),
            nn.ReLU(True),
        )
        self.a2 = nn.Sequential(
            nn.MaxPool2d(3, stride=2),
        )

    def forward(self, x):
        """Return cat([conv_branch, pool_branch], dim=1) for a 192-channel NCHW input."""
        return torch.cat([self.a1(x), self.a2(x)], dim=1)

Inception-A 블록을 선언한다. (Inception-v4에서는 이 블록을 4번(4 x) 반복해 쌓는다.)

In [13]:
class InceptionA(nn.Module):
    """Inception-A block.

    It has four parallel branches, each producing 96 channels, concatenated on
    the channel axis:

    * a1: 1x1 (64) -> 3x3 (96) -> 3x3 (96)
    * a2: 1x1 (64) -> 3x3 (96)
    * a3: 1x1 (96)
    * a4: 3x3 average pool (stride 1) -> 1x1 (96)

    Channels: 384 -> 384. Every branch preserves the spatial size, which is
    what lets the outputs be concatenated and lets blocks be stacked.
    """

    def __init__(self):
        # Required before assigning submodules.
        super(InceptionA, self).__init__()
        # The 3x3 convs use padding=1 so all branches keep the input H x W;
        # without it torch.cat fails on mismatched spatial sizes.
        self.a1 = nn.Sequential(
            nn.Conv2d(384, 64, 1, stride=1),
            nn.ReLU(True),
            nn.Conv2d(64, 96, 3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(96, 96, 3, stride=1, padding=1),
            nn.ReLU(True),
        )
        self.a2 = nn.Sequential(
            nn.Conv2d(384, 64, 1, stride=1),
            nn.ReLU(True),
            nn.Conv2d(64, 96, 3, stride=1, padding=1),
            nn.ReLU(True),
        )
        self.a3 = nn.Sequential(
            nn.Conv2d(384, 96, 1, stride=1),
            nn.ReLU(True),
        )
        # AvgPool2d's stride defaults to kernel_size (3), which would downsample,
        # so stride=1 is set explicitly. Zero padding is excluded from the
        # average so border pixels are not biased toward 0.
        self.a4 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            nn.Conv2d(384, 96, 1, stride=1),
            nn.ReLU(True),
        )

    def forward(self, x):
        """Return the channel-wise concatenation of the four branches (384 channels)."""
        return torch.cat([self.a1(x), self.a2(x), self.a3(x), self.a4(x)], dim=1)

ReductionA를 선언한다.

In [17]:
class ReductionA(nn.Module):
    """Reduction-A block: halves the spatial size and grows channels 384 -> 1024.

    Three parallel stride-2 branches are concatenated on the channel axis:

    * a1: 1x1 (192) -> 3x3 "same" (224) -> 3x3 stride 2 (256)
    * a2: 3x3 stride 2 (384)
    * a3: 3x3 stride-2 max-pool (passes the 384 input channels through)

    Spatial size: H -> (H - 3) // 2 + 1 (e.g. 35 -> 17).
    """

    def __init__(self):
        # Required before assigning submodules.
        super(ReductionA, self).__init__()
        self.a1 = nn.Sequential(
            nn.Conv2d(384, 192, 1, stride=1),
            nn.ReLU(True),
            # padding=1 keeps H x W here, so the final stride-2 conv lands on the
            # same size as branches a2/a3. Without it: 35 -> 33 -> 16 vs. 17.
            nn.Conv2d(192, 224, 3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(224, 256, 3, stride=2),
            nn.ReLU(True),
        )
        self.a2 = nn.Sequential(
            nn.Conv2d(384, 384, 3, stride=2),
            nn.ReLU(True),
        )
        self.a3 = nn.Sequential(
            nn.MaxPool2d(3, stride=2),
        )

    def forward(self, x):
        """Return cat([a1, a2, a3], dim=1): 256 + 384 + 384 = 1024 channels."""
        return torch.cat([self.a1(x), self.a2(x), self.a3(x)], dim=1)
        

Inception-B 블록을 선언한다. (Inception-v4에서는 이 블록을 7번(7 x) 반복해 쌓는다.)

In [20]:
class InceptionB(nn.Module):
    """Inception-B block.

    Four parallel branches are concatenated on the channel axis:

    * a1: 1x1 (192) -> 1x7 (192) -> 7x1 (224) -> 1x7 (224) -> 7x1 (256)
    * a2: 1x1 (192) -> 1x7 (224) -> 7x1 (256)
    * a3: 1x1 (384)
    * a4: 3x3 average pool (stride 1) -> 1x1 (128)

    Channels: 1024 -> 256 + 256 + 384 + 128 = 1024. Every branch preserves the
    spatial size, so the outputs can be concatenated and blocks can be stacked.
    """

    def __init__(self):
        # Required before assigning submodules.
        super(InceptionB, self).__init__()
        # Asymmetric kernels must be passed as (kH, kW) tuples. The original
        # `Conv2d(c_in, c_out, 1, 7, stride=1)` bound 7 to `stride` positionally
        # and raised TypeError. Padding of (0, 3) / (3, 0) keeps H x W unchanged.
        self.a1 = nn.Sequential(
            nn.Conv2d(1024, 192, 1, stride=1),
            nn.ReLU(True),
            nn.Conv2d(192, 192, (1, 7), stride=1, padding=(0, 3)),
            nn.ReLU(True),
            nn.Conv2d(192, 224, (7, 1), stride=1, padding=(3, 0)),
            nn.ReLU(True),
            nn.Conv2d(224, 224, (1, 7), stride=1, padding=(0, 3)),
            nn.ReLU(True),
            nn.Conv2d(224, 256, (7, 1), stride=1, padding=(3, 0)),
            nn.ReLU(True),
        )
        # The factorized 7x7 alternates 1x7 and 7x1 (the original repeated 1x7 here).
        self.a2 = nn.Sequential(
            nn.Conv2d(1024, 192, 1, stride=1),
            nn.ReLU(True),
            nn.Conv2d(192, 224, (1, 7), stride=1, padding=(0, 3)),
            nn.ReLU(True),
            nn.Conv2d(224, 256, (7, 1), stride=1, padding=(3, 0)),
            nn.ReLU(True),
        )
        self.a3 = nn.Sequential(
            nn.Conv2d(1024, 384, 1, stride=1),
            nn.ReLU(True),
        )
        # stride=1 + padding=1 keeps H x W. The original stride=2 halved the
        # size and broke the concat. Zero padding is excluded from the average.
        self.a4 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            nn.Conv2d(1024, 128, 1, stride=1),
            nn.ReLU(True),
        )

    def forward(self, x):
        """Return the channel-wise concatenation of the four branches (1024 channels)."""
        return torch.cat([self.a1(x), self.a2(x), self.a3(x), self.a4(x)], dim=1)
        
    

ReductionB를 선언한다.

In [None]:
class ReductionB(nn.Module):
    """Reduction-B block: halves the spatial size and grows channels 1024 -> 1536.

    Three parallel stride-2 branches are concatenated on the channel axis:

    * a1: 1x1 (256) -> 1x7 (256) -> 7x1 (320) -> 3x3 stride 2 (320)
    * a2: 1x1 (192) -> 3x3 stride 2 (192)
    * a3: 3x3 stride-2 max-pool (passes the 1024 input channels through)

    Spatial size: H -> (H - 3) // 2 + 1 (e.g. 17 -> 8).
    """

    def __init__(self):
        # Required before assigning submodules.
        super(ReductionB, self).__init__()
        # Asymmetric kernels are passed as (kH, kW) tuples. The original
        # `Conv2d(c_in, c_out, 1, 7, stride=1)` raised TypeError because 7
        # landed on `stride`. The padding keeps H x W unchanged before the
        # final stride-2 conv.
        self.a1 = nn.Sequential(
            nn.Conv2d(1024, 256, 1, stride=1),
            nn.ReLU(True),
            nn.Conv2d(256, 256, (1, 7), stride=1, padding=(0, 3)),
            nn.ReLU(True),
            nn.Conv2d(256, 320, (7, 1), stride=1, padding=(3, 0)),
            nn.ReLU(True),
            nn.Conv2d(320, 320, 3, stride=2),
            nn.ReLU(True),
        )
        self.a2 = nn.Sequential(
            nn.Conv2d(1024, 192, 1, stride=1),
            nn.ReLU(True),
            nn.Conv2d(192, 192, 3, stride=2),
            nn.ReLU(True),
        )
        self.a3 = nn.Sequential(
            nn.MaxPool2d(3, stride=2),
        )

    def forward(self, x):
        """Return cat([a1, a2, a3], dim=1): 320 + 192 + 1024 = 1536 channels."""
        return torch.cat([self.a1(x), self.a2(x), self.a3(x)], dim=1)


In [21]:
class Inceptionv4(nn.Module):
    def __init__(self, num_inception_a=4, num_inception_b=7):
        """Assemble the Inception-v4 stages that ``forward`` runs in sequence.

        Args:
            num_inception_a: number of Inception-A blocks stacked in ``a4``
                (the notebook heading and the Inception-v4 design use 4).
            num_inception_b: number of Inception-B blocks stacked in ``a6``
                (the notebook heading and the Inception-v4 design use 7).
        """
        super(Inceptionv4,self).__init__()

        self.a1 = stem1()
        self.a2 = stem2()
        self.a3 = stem3()
        # InceptionA maps 384 -> 384 channels and InceptionB maps 1024 -> 1024,
        # so each can be chained back to back. nn.Sequential keeps `self.a4(out)`
        # and `self.a6(out)` callable exactly as before.
        self.a4 = nn.Sequential(*[InceptionA() for _ in range(num_inception_a)])
        self.a5 = ReductionA()
        self.a6 = nn.Sequential(*[InceptionB() for _ in range(num_inception_b)])
        
        
    def forward(self,x):
        # Runs the input through the stages registered in __init__, in order:
        # stem1 -> stem2 -> stem3 -> Inception-A stage -> Reduction-A -> Inception-B stage.
        # NOTE(review): no `return out` is visible here. If the method really ends
        # after `self.a6`, forward returns None. Confirm, and add `return out`
        # (plus the remaining Reduction-B / Inception-C / pooling / classifier
        # stages) if they are missing.
        out = self.a1(x)
        out = self.a2(out)
        out = self.a3(out)
        out = self.a4(out)
        out = self.a5(out)
        out = self.a6(out)