In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable
from torch.nn.utils import weight_norm

In [3]:
# Execute the helper notebooks so that the names they define are available to
# the cells below. The model cells reference InceptionBlock and FUSION_MODE,
# neither of which is defined in this notebook -- presumably they come from
# utils.ipynb and config.ipynb respectively (TODO confirm).
%run utils.ipynb
#from utils import *
%run config.ipynb
#from config import *

In [4]:
class JointNet(nn.Module):
    """Classifier that fuses an image branch (CNNFeature) with a sequence branch (RNNFeature).

    Both branches emit a 1024-dimensional feature per sample. How the two
    features are combined is selected by the notebook-level ``FUSION_MODE``
    setting, which is read once at construction time:

    * ``'concatenate'``  - cat -> Linear(2048, 1024) -> BatchNorm -> Linear
    * ``'element_wise'`` - relu(cnn + rnn) -> Linear(1024, 1024) -> BatchNorm -> Linear
    * ``'bilinear'``     - each branch reduced to 256 -> Bilinear -> BatchNorm -> Linear
    * ``'relu_sum'``     - relu(cnn + rnn) -> Linear
    * ``'crazy_fusion'`` - relu(cnn + rnn) - (cnn - rnn)**2 -> Linear -> BatchNorm -> Linear

    Raises:
        ValueError: if ``FUSION_MODE`` is not one of the modes listed above.
    """

    # Every fusion strategy this class knows how to build.
    FUSION_MODES = ('concatenate', 'element_wise', 'bilinear', 'relu_sum', 'crazy_fusion')

    def __init__(self):
        super(JointNet, self).__init__()
        self.num_class = 100
        self.dropout = 0.3  # stored but not applied anywhere in this class

        # Snapshot the setting so forward() always matches the layers built
        # here, even if the config cell is re-run with a different value later.
        self.fusion_mode = FUSION_MODE
        if self.fusion_mode not in self.FUSION_MODES:
            # Fail at construction instead of with an UnboundLocalError in forward().
            raise ValueError('Unknown FUSION_MODE {!r}; expected one of {}'.format(
                self.fusion_mode, self.FUSION_MODES))

        # CNN Model
        self.cnn = CNNFeature()
        # RNN Model
        self.rnn = RNNFeature()

        if self.fusion_mode == 'concatenate':
            self.fusion = nn.Linear(2048, 1024)
            self.bn = nn.BatchNorm1d(1024)
            self.final = nn.Linear(1024, self.num_class)
        elif self.fusion_mode == 'element_wise':
            self.fusion = nn.Linear(1024, 1024)
            self.bn = nn.BatchNorm1d(1024)
            self.final = nn.Linear(1024, self.num_class)
        elif self.fusion_mode == 'bilinear':
            print('bilinear')
            self.reduce1 = nn.Linear(1024, 256)
            self.reduce2 = nn.Linear(1024, 256)
            self.fusion = nn.Bilinear(256, 256, 256)
            self.bn = nn.BatchNorm1d(256)
            self.final = nn.Linear(256, self.num_class)
        elif self.fusion_mode == 'relu_sum':
            self.final = nn.Linear(1024, self.num_class)
        elif self.fusion_mode == 'crazy_fusion':
            self.fusion = nn.Linear(1024, 1024)
            self.bn = nn.BatchNorm1d(1024)
            self.final = nn.Linear(1024, self.num_class)

        self.relu = nn.ReLU()

    def forward(self, image_input, sequence_input):
        """Compute class scores for a batch of (image, sequence) pairs.

        Args:
            image_input: image batch, [batch_size, 28, 28].
            sequence_input: sequence batch, [batch_size, MAXI_LENGTH, 3].

        Returns:
            Tensor of shape [batch_size, 100] (``self.num_class`` scores per sample).

        Raises:
            ValueError: if the fusion mode is not a supported one.
        """
        # Instances pickled before `fusion_mode` existed fall back to the global.
        mode = getattr(self, 'fusion_mode', FUSION_MODE)

        cnn_feat = self.cnn(image_input)
        rnn_feat = self.rnn(sequence_input)

        if mode == 'relu_sum':
            # The only mode without a fusion layer or batch norm.
            return self.final(self.relu(cnn_feat + rnn_feat))

        if mode == 'bilinear':
            feat = self.fusion(self.reduce1(cnn_feat), self.reduce2(rnn_feat))
        else:
            if mode == 'concatenate':
                feat = torch.cat((cnn_feat, rnn_feat), dim=1)
            elif mode == 'element_wise':
                feat = self.relu(cnn_feat + rnn_feat)
            elif mode == 'crazy_fusion':
                feat = self.relu(cnn_feat + rnn_feat) - (cnn_feat - rnn_feat)**2
            else:
                raise ValueError('Unknown FUSION_MODE {!r}; expected one of {}'.format(
                    mode, self.FUSION_MODES))
            feat = self.fusion(feat)

        # Shared tail of every mode that has a fusion layer.
        feat = self.bn(feat)
        return self.final(feat)

In [5]:
class CNNFeature(nn.Module):
    """Image branch: turns a batch of 28x28 single-channel images into 1024-d features.

    Pipeline: 3x3 conv (1 -> 4 channels) -> InceptionBlock(4, 8) -> 2x2 max-pool
    -> InceptionBlock(8, 16) -> ReLU -> 2x2 max-pool -> Linear(16*7*7, 1024)
    -> BatchNorm1d.

    InceptionBlock is defined outside this cell. The size of ``fc1`` assumes
    each block keeps the spatial resolution and emits the channel count given
    as its second argument -- TODO confirm against its definition.
    """

    def __init__(self):
        super(CNNFeature, self).__init__()
        self.num_class = 100  # not read by this class
        self.img_size = 28
        self.maxpool = nn.MaxPool2d(2, stride=2)
        self.relu = nn.ReLU()

        self.pre_conv = nn.Conv2d(1, 4, kernel_size=3, stride=1, padding=1)
        self.block1 = InceptionBlock(4, 8)
        self.block2 = InceptionBlock(8, 16)

        # 28 -> 14 -> 7 after the two 2x2 poolings, 16 channels expected from block2.
        self.fc1 = nn.Linear(16 * 7 * 7, 1024)
        self.bn = nn.BatchNorm1d(1024)

    def forward(self, input_feat):
        """Extract image features.

        Args:
            input_feat: tensor holding batch_size * 28 * 28 values, e.g.
                [batch_size, 28, 28]. It is fed to the convolution as-is, so it
                presumably has to be floating point already -- verify against caller.

        Returns:
            Tensor of shape [batch_size, 1024].
        """
        # reshape (not view) so a non-contiguous batch, e.g. a slice or a
        # transposed tensor, is accepted instead of raising.
        images = input_feat.reshape(-1, 1, self.img_size, self.img_size)
        feat = self.pre_conv(images)
        # NOTE(review): block1 is not followed by a ReLU while block2 is; kept
        # as in the original since InceptionBlock may apply its own activation.
        feat = self.block1(feat)
        feat = self.maxpool(feat)
        feat = self.block2(feat)
        feat = self.relu(feat)
        feat = self.maxpool(feat)

        # Flatten per sample. Keeping the batch dimension fixed means an
        # unexpected channel count surfaces as a shape error in fc1 rather
        # than as a silently altered batch size.
        feat = self.fc1(feat.reshape(feat.size(0), -1))
        feat = self.bn(feat)
        return feat

In [6]:
class RNNFeature(nn.Module):
    """Sequence branch: encodes a batch of 3-value-per-step sequences into 1024-d features.

    Each time step is scaled by 1/255, lifted to 512 dimensions
    (Linear + BatchNorm1d), and fed to a 2-layer GRU with 256 hidden units.
    The final hidden states of both layers are concatenated (2 * 256 = 512)
    and projected to 1024 dimensions (Linear + BatchNorm1d).
    """

    def __init__(self):
        super(RNNFeature, self).__init__()
        # num_class, tokens_range and max_length are stored but not read by this class.
        self.num_class = 100
        self.tokens_range = 256
        self.feature_size = 512
        self.max_length = 200
        self.dropout = 0.3

        # feature preprocessing
        self.pre_fc = nn.Linear(3, 512)
        self.pre_bn = nn.BatchNorm1d(512)
        # Recurrent encoder. It is a GRU; the attribute keeps the name `lstm`
        # so that existing state_dict keys remain valid.
        self.lstm = nn.GRU(input_size=512,
                           hidden_size=self.feature_size // 2,
                           num_layers=2,
                           dropout=self.dropout,
                           batch_first=True)
        self._init_lstm()
        # Projection of the concatenated per-layer final hidden states.
        self.fc1 = nn.Linear(self.feature_size, 1024)
        self.bn = nn.BatchNorm1d(1024)

    def _init_lstm(self):
        """Xavier-initialise every GRU weight matrix and zero every GRU bias."""
        # Weights first, then biases, layer by layer -- the same order as the
        # original hand-written version, so seeded initialisation is unchanged.
        for layer in range(self.lstm.num_layers):
            for kind in ('ih', 'hh'):
                self._init_weight(getattr(self.lstm, 'weight_{}_l{}'.format(kind, layer)))
        for layer in range(self.lstm.num_layers):
            for kind in ('ih', 'hh'):
                getattr(self.lstm, 'bias_{}_l{}'.format(kind, layer)).data.zero_()

    def _init_weight(self, weight):
        """Apply Xavier-uniform init separately to each third of ``weight`` along dim 0.

        A GRU stacks its three gate matrices along dim 0, so initialising each
        chunk on its own computes the fan sizes from a single gate's matrix.
        """
        # The chunks are views of the parameter; create and fill them without
        # autograd tracking so the in-place writes are always permitted.
        with torch.no_grad():
            for gate_weight in weight.chunk(3, 0):
                init.xavier_uniform_(gate_weight)

    def forward(self, input_feat):
        """Extract sequence features.

        Args:
            input_feat: tensor of shape [batch_size, MAXI_LENGTH, 3]; any numeric
                dtype (it is cast to float and divided by 255). It does not
                need to be contiguous.

        Returns:
            Tensor of shape [batch_size, 1024].
        """
        input_feat = input_feat.float() / 255.0
        batch_size, max_length, input_size = input_feat.shape

        # Fold time into the batch dimension so BatchNorm1d sees one row per
        # time step. reshape (not view): the cast and division keep the
        # input's strides, so view would raise on a non-contiguous batch.
        x = self.pre_fc(input_feat.reshape(-1, input_size))
        x = self.pre_bn(x)
        x = x.view(batch_size, max_length, -1)

        # Only the final hidden state is used: [num_layers, batch_size, hidden].
        _, hidden = self.lstm(x)
        # -> [batch_size, num_layers * hidden]
        x = torch.transpose(hidden, 0, 1).contiguous().view(batch_size, -1)

        # fc
        x = self.fc1(x)
        output = self.bn(x)

        return output