In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable
from torch.nn.utils import weight_norm

In [3]:
%run utils.ipynb
#from utils import *
%run config.ipynb
#from config import *

In [4]:
class SequenceNet(nn.Module):
    """Sequence classifier: per-timestep projection -> stacked GRU -> MLP head.

    Pipeline (see ``forward``):
      1. scale raw features by ``1 / (tokens_range - 1)``,
      2. project each timestep with ``pre_fc`` followed by ``pre_bn``,
      3. run a ``num_layers``-deep GRU,
      4. concatenate the final hidden state of every GRU layer,
      5. apply ``fc1`` -> ``bn`` -> ``fc2`` to get ``num_class`` scores.

    Every constructor argument is optional. ``SequenceNet()`` builds exactly the
    original architecture: 3 input features, a 512-wide projection, a 2-layer
    GRU with 256 hidden units, a 1024-wide head and 100 classes. Parameter names
    are unchanged, so existing state_dicts still load.

    Args:
        num_class: number of output scores produced by ``fc2``.
        input_size: number of features per timestep (last dim of ``input_feat``).
        feature_size: width of the per-timestep projection fed to the GRU.
            The GRU hidden size is ``feature_size // 2``.
        num_layers: number of stacked GRU layers.
        dropout: dropout between GRU layers (not applied when ``num_layers == 1``).
        tokens_range: number of distinct raw input values. Inputs are divided
            by ``tokens_range - 1`` (255 by default).
        max_length: stored as an attribute; not used inside this class.
        fc_hidden: width of the hidden layer of the classification head.
    """

    def __init__(self, num_class=100, input_size=3, feature_size=512, num_layers=2,
                 dropout=0.3, tokens_range=256, max_length=200, fc_hidden=1024):
        super(SequenceNet, self).__init__()
        self.num_class = num_class
        self.tokens_range = tokens_range
        self.feature_size = feature_size
        self.max_length = max_length
        self.dropout = dropout
        self.input_size = input_size
        self.num_layers = num_layers
        self.hidden_size = feature_size // 2

        # Per-timestep feature projection (the same Linear/BN applied to every step).
        self.pre_fc = nn.Linear(input_size, feature_size)
        self.pre_bn = nn.BatchNorm1d(feature_size)
        # NOTE: the attribute is called `lstm` for state_dict compatibility,
        # but it is a GRU.
        self.lstm = nn.GRU(input_size=feature_size,
                           hidden_size=self.hidden_size,
                           num_layers=num_layers,
                           # nn.GRU only applies dropout *between* layers. With a
                           # single layer a non-zero value only emits a warning.
                           dropout=dropout if num_layers > 1 else 0.0,
                           batch_first=True)
        self._init_lstm()
        # Head input: concatenation of the final hidden state of every GRU layer.
        self.fc1 = nn.Linear(num_layers * self.hidden_size, fc_hidden)
        self.bn = nn.BatchNorm1d(fc_hidden)
        self.fc2 = nn.Linear(fc_hidden, num_class)
        print('sequence_model')

    def _init_lstm(self):
        """Xavier-initialise every GRU weight matrix (per gate) and zero every GRU bias.

        Iterates over all layers instead of hard-coding ``_l0``/``_l1``. Weights
        are visited in the same order as before (ih_l0, hh_l0, ih_l1, hh_l1, ...),
        so results under a fixed seed are unchanged.
        """
        for name, param in self.lstm.named_parameters():
            if name.startswith('weight'):
                self._init_weight(param)
            elif name.startswith('bias'):
                with torch.no_grad():
                    param.zero_()

    def _init_weight(self, weight, num_gates=3):
        """Apply Xavier-uniform init separately to each gate block of a GRU weight.

        nn.GRU stores the reset/update/new gate matrices stacked along dim 0.
        Initialising each chunk separately gives Xavier the fan-in/fan-out of a
        single gate rather than of the 3x-tall stacked matrix.
        """
        with torch.no_grad():
            for gate_weight in weight.chunk(num_gates, 0):
                init.xavier_uniform_(gate_weight)

    def forward(self, _, input_feat):
        """Return class scores for a batch of feature sequences.

        Args:
            _: ignored. NOTE(review): presumably kept so that models share a
                two-argument ``forward`` signature. Confirm with the caller.
            input_feat: tensor of shape ``(batch_size, seq_len, input_size)`` of
                any numeric dtype; it may be non-contiguous. Values are divided
                by ``tokens_range - 1``, so they are presumably raw codes in
                ``[0, tokens_range - 1]`` (TODO confirm).

        Returns:
            Float tensor of shape ``(batch_size, num_class)``: the raw output of
            ``fc2``. No softmax is applied here.
        """
        input_feat = input_feat.float() / float(self.tokens_range - 1)
        batch_size, seq_len, input_size = input_feat.shape

        # Project every timestep independently. Use reshape (not view) so that
        # non-contiguous inputs, e.g. transposed tensors, are accepted.
        x = self.pre_fc(input_feat.reshape(-1, input_size))
        x = self.pre_bn(x)
        x = x.reshape(batch_size, seq_len, -1)

        # h_n has shape (num_layers, batch, hidden) even with batch_first=True.
        # Move the batch dim first and concatenate every layer's final state.
        _, h_n = self.lstm(x)
        x = h_n.transpose(0, 1).reshape(batch_size, -1)

        x = self.fc1(x)
        x = self.bn(x)
        # NOTE(review): there is no non-linearity between bn and fc2, so in eval
        # mode fc1 -> bn -> fc2 composes to a single affine map. Adding an
        # activation (e.g. F.relu) here would change trained-model outputs, so
        # it is intentionally left as-is.
        return self.fc2(x)