In [1]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
# Name of the compute device: 'cuda' when PyTorch reports a usable CUDA
# device, otherwise 'cpu'.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [4]:
import sys
# Make the local 'models' directory importable (resolved relative to the
# kernel's working directory). The membership check keeps the cell idempotent:
# re-running it no longer stacks duplicate entries at the front of sys.path.
if 'models' not in sys.path:
    sys.path.insert(0, 'models')

In [10]:
# Load (or re-load) IPython's autoreload extension so edits to the local model
# modules are picked up without restarting the kernel.
%reload_ext autoreload
# Mode 1: only modules registered with %aimport are reloaded before each cell runs.
%autoreload 1
# Register the two local modules for reloading.
# NOTE(review): `rnn` is not referenced anywhere else in the visible cells;
# presumably `lstm` depends on it -- confirm.
%aimport rnn, lstm
# LSTM is the model class benchmarked below.
from lstm import LSTM

In [11]:
# Build the LSTM under test. Its repr shows 5 stacked LSTMBase layers whose
# Linear maps are 128 -> 128, so the arguments are presumably
# (input_size, hidden_size, num_layers) -- TODO confirm against LSTM.__init__.
model = LSTM(128, 128, 5)

In [7]:
# Show the module tree (layers, per-gate Linear submodules, dropout).
model

LSTM(
  (layers): ModuleList(
    (0-4): 5 x LSTMBase(
      (sigmoid): Sigmoid()
      (tanh): Tanh()
      (fc_ii): Linear(in_features=128, out_features=128, bias=True)
      (fc_hi): Linear(in_features=128, out_features=128, bias=True)
      (fc_if): Linear(in_features=128, out_features=128, bias=True)
      (fc_hf): Linear(in_features=128, out_features=128, bias=True)
      (fc_ig): Linear(in_features=128, out_features=128, bias=True)
      (fc_hg): Linear(in_features=128, out_features=128, bias=True)
      (fc_io): Linear(in_features=128, out_features=128, bias=True)
      (fc_ho): Linear(in_features=128, out_features=128, bias=True)
    )
  )
  (dropout): Dropout(p=0.2, inplace=False)
)

Here we test the inference time cost of two ways of implementing an LSTM: one is to create an nn.Linear submodule for every transformation; the other is to create big chunks of weights first and split them into smaller chunks. Many choose the second approach, but it is actually much slower.  
In this experiment, the model always creates both the nn.Linear submodules and the big chunks of weights (which are split into smaller weights), whether `split` is enabled or not. The difference in inference time is therefore due purely to the computational efficiency of the two approaches, since the number of parameters is the same.  

In [16]:
# Time 10 forward passes using the model's default `split` setting
# (presumably split=True -- confirm in lstm.py).
# perf_counter() is a monotonic, high-resolution clock meant for measuring
# intervals; time.time() follows the wall clock and can jump if it is adjusted.
# NOTE(review): the random input is generated inside the timed loop, and no
# model.eval() / torch.no_grad() is visible in this notebook, so dropout and
# autograd bookkeeping appear to be included in the measurement -- confirm
# that this is intended for an "inference" benchmark.
t0 = time.perf_counter()
for _ in range(10):
    output, state = model(torch.randn([1000, 100, 128]))
t1 = time.perf_counter()
print(t1-t0)

18.103795528411865


In [17]:
# Time 10 forward passes with split=False, on the same input shape as the
# default-setting benchmark.
# perf_counter() is a monotonic, high-resolution clock meant for measuring
# intervals; time.time() follows the wall clock and can jump if it is adjusted.
# NOTE(review): the random input is generated inside the timed loop, and no
# model.eval() / torch.no_grad() is visible in this notebook, so dropout and
# autograd bookkeeping appear to be included in the measurement -- confirm
# that this is intended for an "inference" benchmark.
t0 = time.perf_counter()
for _ in range(10):
    output, state = model(torch.randn([1000, 100, 128]), split=False)
t1 = time.perf_counter()
print(t1-t0)

8.983369588851929


We show that the number of parameters is the same whether or not split is used.

In [21]:
# Tally parameters separately for the two weight representations.
# Parameter names are dotted paths such as 'layers.<idx>.<attr>...'; component 2
# is the attribute on the layer. Attributes starting with 'fc' are the nn.Linear
# submodules; everything else is counted as the other representation
# (presumably the directly registered weight chunks -- confirm in lstm.py).
linear_param_nums = 0
parameter_param_nums = 0
for pn, p in model.named_parameters():
    if pn.split('.')[2].startswith('fc'):
        linear_param_nums += p.numel()
    else:
        parameter_param_nums += p.numel()

# Turn the "same number of parameters" claim into a check that fails loudly on
# a fresh run instead of relying on comparing two printed numbers by eye.
assert linear_param_nums == parameter_param_nums, (
    f"parameter counts differ: {linear_param_nums} (nn.Linear) vs "
    f"{parameter_param_nums} (other parameters)"
)

In [22]:
# Total element count of the parameters whose layer attribute starts with 'fc'.
linear_param_nums

660480

In [23]:
# Total element count of all remaining (non-'fc') parameters.
parameter_param_nums

660480