Skip to content

Commit 078d4aa

Browse files
misc
1 parent aa0b509 commit 078d4aa

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

Core/CModelWrapper.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,10 @@ def load(self, folder=None, postfix='', embeddings=False):
8787
self._model.load_weights(path)
8888
if embeddings:
8989
embeddings = np.load(path.replace('.h5', '-embeddings.npz'))
90-
for nm in self._embeddings.keys(): # recreate embeddings
90+
for nm, emb in self._embeddings.items():
9191
w = embeddings[nm]
92-
emb = L.Embedding(w.shape[0], w.shape[1])
93-
emb.build((None, 1))
94-
emb.set_weights([w])
95-
self._embeddings[nm] = emb # replace
92+
if not emb.built: emb.build((None, w.shape[0]))
93+
emb.set_weights([w]) # replace embeddings
9694
continue
9795
return
9896

0 commit comments

Comments
 (0)