diff --git a/keras_transformer/transformer.py b/keras_transformer/transformer.py index 6960f1c..6b66912 100644 --- a/keras_transformer/transformer.py +++ b/keras_transformer/transformer.py @@ -464,7 +464,7 @@ def decode(model, if top_k == 1: last_token = predicts[i][-1].argmax(axis=-1) else: - probs = [(prob, i) for i, prob in enumerate(predicts[i][-1])] + probs = [(prob, j) for j, prob in enumerate(predicts[i][-1])] probs.sort(reverse=True) probs = probs[:top_k] indices, probs = list(map(lambda x: x[1], probs)), list(map(lambda x: x[0], probs))