In [1]:
import torch
import torch.nn as nn
import sys
sys.path.append("../model_utils/")
import custom_gru

In [2]:
class test_model(nn.Module):
    """Wrap a recurrent sub-model and project its output with a linear layer.

    Args:
        sub_model: Module called as ``sub_model(x, hidden)``.  It may return a
            tensor, or a tuple/list whose first element is the output tensor
            (the convention of ``nn.LSTM`` and ``LSTMCell_format`` in this
            notebook).
        input_size: Size of the last dimension of the sub-model output.
        output_size: Number of features produced by the projection.
    """

    def __init__(self, sub_model, input_size, output_size):
        super(test_model, self).__init__()
        self.criteria = None
        self.sub_model = sub_model
        # nn.Linear only accepts in_features/out_features as keyword names.
        self.fc = nn.Linear(in_features=input_size, out_features=output_size)

    def forward(self, x, hidden):
        """Run the sub-model on ``(x, hidden)`` and return the projected output."""
        features = self.sub_model(x, hidden)
        if isinstance(features, (tuple, list)):
            # Recurrent modules return (output, state); only the output is projected.
            features = features[0]
        return self.fc(features)

In [3]:
# Reference modules under comparison: the multi-step nn.LSTM (batch-first
# layout) and the single-step nn.LSTMCell, both with 3 input features and a
# hidden size of 5.  The LSTM is constructed first, then the cell.
lstm_pt, lstm_cell_pt = (
    nn.LSTM(input_size=3, hidden_size=5, batch_first=True),
    nn.LSTMCell(input_size=3, hidden_size=5),
)

In [4]:
import copy 

def lstm_initialization(lstm_model, seed = 9):
    """Return a deterministically initialised deep copy of ``lstm_model``.

    The module passed in is left untouched.  On the copy, every parameter
    whose name contains ``'bias'`` is filled with 0.3; every other parameter
    whose name contains ``'weight'`` is drawn with Xavier-uniform
    initialisation after seeding the global torch RNG with ``seed``.
    Parameters matching neither name are kept as copied.

    Args:
        lstm_model: Module to copy and re-initialise.
        seed: Value passed to ``torch.manual_seed`` before initialising.

    Returns:
        The re-initialised copy.
    """
    clone = copy.deepcopy(lstm_model)
    torch.manual_seed(seed)
    for param_name, tensor in clone.named_parameters():
        if 'bias' in param_name:
            nn.init.constant_(tensor, 0.3)
            continue
        if 'weight' in param_name:
            nn.init.xavier_uniform_(tensor)
    return clone
    

# Swap both reference models for re-initialised deep copies (default seed);
# the cell is processed first, then the full LSTM.
lstm_cell_pt, lstm_pt = map(lstm_initialization, (lstm_cell_pt, lstm_pt))
# test 1. make sure weights are initialized the same
def parameter_checking(model_a, model_b):
    """Assert that two models hold identical parameters, compared in order.

    Args:
        model_a: First module; its ``parameters()`` are the reference.
        model_b: Second module, compared position by position.

    Raises:
        AssertionError: If the models have a different number of parameters,
            or if any pair differs in shape or in value.
    """
    params_a = list(model_a.parameters())
    params_b = list(model_b.parameters())
    # zip() alone would stop at the shorter list and hide missing parameters.
    assert len(params_a) == len(params_b), "parameter count mismatch: {} vs {}".format(
        len(params_a), len(params_b))
    for index, (param_a, param_b) in enumerate(zip(params_a, params_b)):
        # torch.equal requires the same shape, so broadcasting cannot mask a mismatch.
        assert torch.equal(param_a, param_b), "parameter {} differs".format(index)

In [11]:
# Define `test` in this cell so the display works on a fresh kernel; the only
# other assignment to `test` lives in a later cell.
test = torch.zeros([1, 3, 5])
test

tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]])

In [13]:
# Zero tensor of shape (1, 3, 5) with gradient tracking switched on.
test = torch.zeros([1, 3, 5], requires_grad=True)

In [32]:
# test 2 make sure outputs are the same
# A single nn.LSTM step is compared against one nn.LSTMCell step on the same
# input, starting from zero states and identically initialised weights.

# (3, 1, 3) tensor; lstm_pt is built with batch_first=True, so dim 1 is the
# sequence axis and has length 1.
# NOTE(review): `input` shadows the Python builtin of the same name.
input = torch.rand(3, 1, 3)
# Three integer class indices drawn from [0, 4].
target = torch.empty(3, dtype=torch.long).random_(5)
# Zero initial states that track gradients: (1, 3, 5) for nn.LSTM and (3, 5)
# for nn.LSTMCell.
initial_hidden_pt = [torch.zeros([1, 3, 5]).requires_grad_(), torch.zeros([1, 3, 5]).requires_grad_()]
initial_hidden_cell_pt = [torch.zeros([3, 5]).requires_grad_() for _ in initial_hidden_pt]

lstm_pt = nn.LSTM(input_size=3, hidden_size=5, batch_first=True)
lstm_cell_pt = nn.LSTMCell(input_size=3, hidden_size=5)

# Both calls use the default seed of lstm_initialization, which re-seeds the
# global torch RNG on every call.
lstm_pt = lstm_initialization(lstm_pt) 
lstm_cell_pt = lstm_initialization(lstm_cell_pt) 


output_pt = lstm_pt(input, initial_hidden_pt)
# The cell takes one time step, so the length-1 sequence axis is squeezed out.
output_cell_pt = lstm_cell_pt(torch.squeeze(input, 1) , initial_hidden_cell_pt)

# Exact (bitwise) comparison of the cell's new hidden state with the first
# element of the state pair returned by nn.LSTM.
assert torch.all(torch.eq(output_cell_pt[0], output_pt[1][0]))

# make sure the gradients are the same

criteria = nn.CrossEntropyLoss()

# Logits handed to the loss: the cell's new hidden state, and the LSTM output
# with dim 1 squeezed out.
hyp_cell_pt = output_cell_pt[0]
hyp_pt = torch.squeeze(output_pt[0], 1)


def get_gradient(model, criteria, hyp, ref):
    """Back-propagate the loss ``criteria(hyp, ref)``.

    Gradients are accumulated into the ``.grad`` of every leaf tensor that
    contributed to ``hyp``; nothing is returned and nothing is zeroed first.

    Args:
        model: Not used by this function; kept so call sites name the model
            whose gradients are being populated.
        criteria: Callable mapping ``(hyp, ref)`` to a scalar loss tensor.
        hyp: Predictions passed to ``criteria``.
        ref: Reference targets passed to ``criteria``.
    """
    criteria(hyp, ref).backward()
    
    
# Back-propagate through the cell first, then through the full LSTM.
for net, logits in ((lstm_cell_pt, hyp_cell_pt), (lstm_pt, hyp_pt)):
    get_gradient(net, criteria, logits, ref=target)

# Parameters are paired in iteration order; their gradients must match exactly.
for cell_param, pt_param in zip(lstm_cell_pt.parameters(), lstm_pt.parameters()):
    assert torch.all(torch.eq(cell_param.grad, pt_param.grad))

tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]], requires_grad=True)

In [33]:
# Gradient accumulated on the first initial-state tensor that was passed to nn.LSTM.
initial_hidden_pt[0].grad

tensor([[[-0.0112,  0.0251,  0.0491,  0.0502,  0.0306],
         [ 0.0422, -0.0152, -0.0086, -0.0186,  0.0391],
         [ 0.0008, -0.0027,  0.0045, -0.0043, -0.0243]]])

In [34]:
# First initial-state tensor that was passed to nn.LSTMCell.
initial_hidden_cell_pt[0]

tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]], requires_grad=True)

In [35]:
# Gradient accumulated on that tensor; shown for comparison with the nn.LSTM counterpart.
initial_hidden_cell_pt[0].grad

tensor([[-0.0112,  0.0251,  0.0491,  0.0502,  0.0306],
        [ 0.0422, -0.0152, -0.0086, -0.0186,  0.0391],
        [ 0.0008, -0.0027,  0.0045, -0.0043, -0.0243]])

In [6]:
class LSTMCell_format(nn.Module):
    """``nn.LSTMCell`` wrapper that returns ``(output, state)``.

    ``forward`` returns the new hidden state together with a two-element list
    ``[hidden, cell]`` instead of the bare ``(hidden, cell)`` pair produced by
    ``nn.LSTMCell``.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.lstm = nn.LSTMCell(input_size=input_size, hidden_size=hidden_size)

    def forward(self, input, hidden):
        """Advance one step; return ``(h_next, [h_next, c_next])``."""
        next_hidden, next_cell = self.lstm(input, hidden)
        next_state = [next_hidden, next_cell]
        return next_hidden, next_state
        

In [7]:
# Shape of the class-index targets currently in scope.
target.shape

torch.Size([3])

In [8]:
# Long-sequence comparison: custom_gru.RNNLayer wrapping LSTMCell_format versus
# nn.LSTM, both fed the same random batch and zero initial states.
length = 10000
# (3, length, 3) tensor; lstm_pt below is built with batch_first=True.
# NOTE(review): `input` shadows the Python builtin of the same name.
input = torch.rand(3, length, 3)
# Integer class indices in [0, 4], one per (sequence, time step).
target = torch.empty([3,length], dtype=torch.long).random_(5)
initial_hidden = [torch.zeros([1, 3, 5]), torch.zeros([1, 3, 5])]

# NOTE(review): RNNLayer is defined outside this notebook; from this call it
# takes a packed sequence plus a list of states with the leading axis removed
# -- confirm against custom_gru.
lstm_custom = custom_gru.RNNLayer(LSTMCell_format, input_size=3, hidden_size=5)
lstm_custom = lstm_initialization(lstm_custom)
# All three sequences are given the full length, so the packing adds no padding.
output_custom = lstm_custom(nn.utils.rnn.pack_padded_sequence(input, [length, length, length], batch_first=True), [torch.squeeze(i, 0) for i in initial_hidden])

lstm_pt = nn.LSTM(input_size=3, hidden_size=5, batch_first=True)
# Same default seed as used for lstm_custom above.
lstm_pt = lstm_initialization(lstm_pt) 
output_pt = lstm_pt(input, initial_hidden)

# lstm_cell_pt = nn.LSTMCell(input_size=3, hidden_size=5)
# lstm_cell_pt = lstm_initialization(lstm_cell_pt) 
# output_cell_pt = lstm_cell_pt(torch.squeeze(input, 1) , [torch.squeeze(i, 0) for i in initial_hidden])


# The second state tensor returned by each implementation must match exactly.
assert torch.all(torch.eq(output_custom[1][1], output_pt[1][1]))

# Unpack the custom layer's packed output into a padded batch-first tensor and
# require it to equal the nn.LSTM output element for element.
output_custom = nn.utils.rnn.pad_packed_sequence(output_custom[0], batch_first=True)[0]
assert torch.all(torch.eq(output_custom, output_pt[0]))

# Swap dims 1 and 2 of both outputs before they are handed to the loss.
hyp_pt = torch.transpose(output_pt[0], 2, 1)
hyp_custom = torch.transpose(output_custom, 2, 1)

# hyp_cell_pt = torch.unsqueeze(output_cell_pt[0], 2)
# assert torch.all(torch.eq(hyp_cell_pt, hyp_pt)), "{}, {}".format(hyp_cell_pt.shape, hyp_pt.shape)

# Mean-reduced cross-entropy over the transposed outputs and `target`.
criteria = nn.CrossEntropyLoss(reduction="mean")
    
# Back-propagate once through each model so every parameter gets a .grad.
get_gradient(lstm_pt, criteria, hyp_pt, ref=target)
# get_gradient(lstm_cell_pt, criteria, hyp_cell_pt, ref=target)
get_gradient(lstm_custom, criteria, hyp_custom, ref=target)

# for param_pair in zip(lstm_cell_pt.parameters(), lstm_pt.parameters()):
#     assert torch.all(torch.eq(param_pair[0].grad, param_pair[1].grad)), "{}, {}".format(param_pair[0].grad, param_pair[1].grad)
    
# Gradients are compared with a tolerance (atol=1e-08 plus allclose's default
# rtol) rather than exactly; on failure the message shows the difference.
for param_pair in zip(lstm_custom.parameters(), lstm_pt.parameters()):
    assert torch.allclose(param_pair[0].grad, param_pair[1].grad, atol=1e-08), "{}".format(param_pair[0].grad - param_pair[1].grad)

In [11]:
# Print, per parameter pair, the standard deviation of the gradient difference
# between the custom layer and nn.LSTM.
for custom_param, pt_param in zip(lstm_custom.parameters(), lstm_pt.parameters()):
    grad_gap = custom_param.grad - pt_param.grad
    print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
    print(torch.std(grad_gap))

++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
tensor(9.4702e-09)
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
tensor(0.)
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
tensor(1.0060e-08)
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
tensor(0.)


In [None]:
# Second tensor of the state pair returned by nn.LSTM.
output_pt[1][1]