In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable
from torch.nn.utils import weight_norm

In [2]:
%run utils.ipynb
#from utils import *
%run config.ipynb
#from config import *

In [3]:
class ImageNet(nn.Module):
    """Small Inception-style CNN classifier for single-channel square images.

    Trunk: 3x3 conv (1 -> 4 channels) -> InceptionBlock(4, 8) -> 2x2 max-pool
           -> InceptionBlock(8, 16) -> ReLU -> 2x2 max-pool.
    Head:  Linear -> BatchNorm1d -> ReLU -> Dropout -> Linear (num_class logits).
    """

    def __init__(self):
        super(ImageNet, self).__init__()
        self.num_class = 100
        self.img_size = 28
        self.dropout = 0.3
        self.maxpool = nn.MaxPool2d(2, stride=2)
        # nn.ReLU's only argument is `inplace`; passing the dropout rate there
        # made an in-place ReLU and never applied dropout. Keep the activation
        # and the dropout as two separate modules.
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(self.dropout)

        self.pre_conv = nn.Conv2d(1, 4, kernel_size=3, stride=1, padding=1)
        self.block1 = InceptionBlock(4, 8)
        self.block2 = InceptionBlock(8, 16)

        # Two 2x2/stride-2 max-pools floor-divide each spatial side by 4
        # (28 -> 7), and block2 emits 16 channels.
        self.feat_dim = 16 * (self.img_size // 4) ** 2
        self.fc1 = nn.Linear(self.feat_dim, 1024)
        self.bn = nn.BatchNorm1d(1024)
        self.fc2 = nn.Linear(1024, self.num_class)

    def forward(self, input_feat, _=None):
        """Compute class logits.

        Args:
            input_feat: tensor holding batch_size * img_size * img_size values,
                e.g. [batch_size, 28, 28]; reshaped to [batch_size, 1, 28, 28].
            _: unused; accepted so call sites passing a second argument keep working.

        Returns:
            Logits of shape [batch_size, num_class] ([batch_size, 100] here).
        """
        # reshape (not view) also accepts non-contiguous inputs.
        input_feat = input_feat.reshape(-1, 1, self.img_size, self.img_size)
        feat = self.pre_conv(input_feat)
        feat = self.block1(feat)
        feat = self.maxpool(feat)
        feat = self.block2(feat)
        feat = self.relu(feat)
        feat = self.maxpool(feat)

        # flatten(1) keeps the batch dimension intact; view(-1, N) would
        # silently regroup rows if the feature size were ever wrong.
        feat = self.fc1(feat.flatten(1))
        feat = self.bn(feat)
        # Without a nonlinearity, fc1 -> bn -> fc2 reduces to one affine map
        # at inference time; ReLU + dropout make the hidden layer meaningful.
        feat = self.relu(feat)
        feat = self.drop(feat)
        output = self.fc2(feat)
        return output

In [4]:
class InceptionBlock(nn.Module):
    """Inception-style block with four parallel branches fused by a 3x3 conv.

    Branches (each input_size -> output_size channels, spatial size kept):
      * conv1 / conv2 / conv3: 3x3, 5x5 and 7x7 convolutions with "same" padding;
      * maxpool + conv4: stride-1 3x3 max-pool followed by a 1x1 projection.
    The 4 * output_size concatenated channels pass through ReLU, then
    output_conv (3x3, back to output_size channels) and BatchNorm2d.

    Args:
        input_size: number of input channels.
        output_size: number of channels per branch and of the block output.
    """

    def __init__(self, input_size, output_size):
        super().__init__()

        # Registered one by one in this fixed order so state_dict keys and
        # seeded parameter initialization stay deterministic.
        # Odd kernel k with padding k // 2 preserves height and width.
        for attr, k in (("conv1", 3), ("conv2", 5), ("conv3", 7)):
            setattr(
                self,
                attr,
                nn.Conv2d(input_size, output_size, kernel_size=k, stride=1, padding=k // 2),
            )

        # Pooling branch.
        self.maxpool = nn.MaxPool2d(3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(input_size, output_size, kernel_size=1, stride=1, padding=0)

        fused_channels = 4 * output_size
        self.output_conv = nn.Conv2d(fused_channels, output_size, kernel_size=3, stride=1, padding=1)
        self.bn = nn.BatchNorm2d(output_size)
        self.relu = nn.ReLU()

    def forward(self, input_feat):
        """Return the fused, batch-normalized features (output_size channels, same H x W)."""
        branch_outputs = [
            self.conv1(input_feat),
            self.conv2(input_feat),
            self.conv3(input_feat),
            self.conv4(self.maxpool(input_feat)),
        ]
        stacked = self.relu(torch.cat(branch_outputs, dim=1))
        return self.bn(self.output_conv(stacked))