We've changed the ordering of the LSTM gates in the mLSTM cell from (i, f, o, j) to (i, f, j, o). To make models you trained with older versions compatible with versions >= v0.5, please use the following script.
The Google Drive models will be updated automatically and ready for download within a couple of hours of this post.
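```python
import os
import sys

import torch

# Usage: python <this script> <input checkpoint> <output checkpoint>
inpath = sys.argv[1]
outpath = sys.argv[2]

# Create the output directory if it does not exist yet.
dp = os.path.dirname(outpath)
if dp != '' and not os.path.exists(dp):
    os.makedirs(dp)

sd = torch.load(inpath)
# Some checkpoints wrap the weights in a 'state_dict' entry.
if 'state_dict' in sd:
    sd = sd['state_dict']
rnn = sd['encoder']['rnn']

def reorder_gates(param):
    # Gate weights stack four blocks along dim 0. Only tensors shaped
    # (4 * n, n) are treated as such and reordered from (i, f, o, j) to (i, f, j, o).
    if param.size(0) == param.size(1) * 4:
        i, f, o, j = torch.chunk(param, 4)
        return torch.cat([i, f, j, o])
    return param

# mLSTM checkpoints are recognised by their multiplicative weights (w_mih).
is_mlstm = any('w_mih' in k for k in rnn.keys())
if not is_mlstm:
    print('no conversion needed')
    sys.exit()

# Reorder the input-to-hidden and hidden-to-hidden gate weights.
for k, v in rnn.items():
    if 'w_ih' in k or 'w_hh' in k:
        rnn[k] = reorder_gates(v)

torch.save(sd, outpath)
```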
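To see what the reordering does without loading a checkpoint, here is a small plain-Python sketch. It assumes only that, as in the script above, `torch.chunk(param, 4)` splits a gate matrix into four equal blocks of rows; lists of rows stand in for tensors, and `reorder_gate_rows` is a hypothetical helper, not part of PyTorch or of this project. Because the change only swaps the last two blocks, the permutation is its own inverse: converting an already-converted checkpoint restores the old order, so each checkpoint should be converted exactly once.

```python
# Illustration only: a plain-Python stand-in for the script's reorder_gates.
# A "matrix" is a list of rows; splitting along dim 0 means splitting into
# consecutive blocks of rows, one block per gate.

def reorder_gate_rows(rows):
    """Reorder four stacked gate blocks from (i, f, o, j) to (i, f, j, o)."""
    if len(rows) % 4 != 0:
        raise ValueError("expected the row count to be a multiple of 4")
    n = len(rows) // 4
    i, f, o, j = (rows[k * n:(k + 1) * n] for k in range(4))
    return i + f + j + o


# Four one-row "gates" labelled by name make the permutation easy to see.
old = [['i'], ['f'], ['o'], ['j']]
new = reorder_gate_rows(old)
print(new)  # [['i'], ['f'], ['j'], ['o']]

# Applying the conversion a second time undoes it.
print(reorder_gate_rows(new) == old)  # True
```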
raulpuric changed the title from "Convert version 0.4 to 0.5" to "Convert version 0.2 to 0.3" on Apr 6, 2018.