In [None]:
# -*- coding:utf-8 -*-
# Created by LuoJie at 11/23/19

from utils.config import save_wv_model_path, vocab_path
import tensorflow as tf
from utils.gpu_utils import config_gpu
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import GRU, Input, Dense, TimeDistributed, Activation, RepeatVector, Bidirectional
from tensorflow.keras.layers import Embedding
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import sparse_categorical_crossentropy
import tensorflow as tf
from utils.wv_loader import load_embedding_matrix, Vocab

from tensorflow.python.ops import nn_ops


class Encoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, embedding_matrix, enc_units, batch_sz):
        super(Encoder, self).__init__()
        self.batch_sz = batch_sz
        self.enc_units = enc_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim, weights=[embedding_matrix],
                                                   trainable=False)
        self.gru = tf.keras.layers.GRU(self.enc_units,
                                       return_sequences=True,
                                       return_state=True,
                                       recurrent_initializer='glorot_uniform')

    def __call__(self, x, hidden):
        x = self.embedding(x)
        output, hidden = self.gru(x, initial_state=hidden)
        return output, hidden

    def initialize_hidden_state(self):
        return tf.zeros((self.batch_sz, self.enc_units))


def masked_attention(enc_padding_mask, attn_dist):
    """Take softmax of e then apply enc_padding_mask and re-normalize"""
    attn_dist = tf.squeeze(attn_dist, axis=2)
    mask = tf.cast(enc_padding_mask, dtype=attn_dist.dtype)
    attn_dist *= mask  # apply mask
    masked_sums = tf.reduce_sum(attn_dist, axis=1)  # shape (batch_size)
    attn_dist = attn_dist / tf.reshape(masked_sums, [-1, 1])  # re-normalize
    attn_dist = tf.expand_dims(attn_dist, axis=2)
    return attn_dist


class BahdanauAttention(tf.keras.layers.Layer):
    def __init__(self, units):
        super(BahdanauAttention, self).__init__()
        self.W_s = tf.keras.layers.Dense(units)
        self.W_h = tf.keras.layers.Dense(units)
        self.W_c = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def __call__(self, dec_hidden, enc_output, enc_pad_mask, use_coverage, prev_coverage=None):
        # query为上次的GRU隐藏层
        # values为编码器的编码结果enc_output
        # 在seq2seq模型中，St是后面的query向量，而编码过程的隐藏状态hi是values。

        # hidden shape == (batch_size, hidden size)
        # hidden_with_time_axis shape == (batch_size, 1, hidden size)
        # we are doing this to perform addition to calculate the score
        hidden_with_time_axis = tf.expand_dims(dec_hidden, 1)

        if use_coverage and prev_coverage is not None:
            # Multiply coverage vector by w_c to get coverage_features.
            # self.W_s(values) [batch_sz, max_len, units] self.W_h(hidden_with_time_axis) [batch_sz, 1, units]
            # self.W_c(prev_coverage) [batch_sz, max_len, units]  score [batch_sz, max_len, 1]
            score = self.V(tf.nn.tanh(self.W_s(enc_output) + self.W_h(hidden_with_time_axis) + self.W_c(prev_coverage)))
            # attention_weights shape (batch_size, max_len, 1)

            # attention_weights sha== (batch_size, max_length, 1)
            attention_weights = tf.nn.softmax(score, axis=1)

            # attention_weights = masked_attention(enc_pad_mask, attention_weights)
            coverage = attention_weights + prev_coverage
        else:
            # score shape == (batch_size, max_length, 1)
            # we get 1 at the last axis because we are applying score to self.V
            # the shape of the tensor before applying self.V is (batch_size, max_length, units)
            # 计算注意力权重值
            score = self.V(tf.nn.tanh(
                self.W_s(enc_output) + self.W_h(hidden_with_time_axis)))

            attention_weights = tf.nn.softmax(score, axis=1)
            # attention_weights = masked_attention(enc_pad_mask, attention_weights)
            if use_coverage:
                coverage = attention_weights
            else:
                coverage = []

        # # 使用注意力权重*编码器输出作为返回值，将来会作为解码器的输入
        # context_vector shape after sum == (batch_size, hidden_size)
        context_vector = attention_weights * enc_output
        context_vector = tf.reduce_sum(context_vector, axis=1)
        return context_vector, tf.squeeze(attention_weights, -1), coverage


class Decoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, embedding_matrix, dec_units, batch_sz):
        super(Decoder, self).__init__()
        self.batch_sz = batch_sz
        self.dec_units = dec_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim, weights=[embedding_matrix],
                                                   trainable=False)
        self.gru = tf.keras.layers.GRU(self.dec_units,
                                       return_sequences=True,
                                       return_state=True,
                                       recurrent_initializer='glorot_uniform')
        self.fc = tf.keras.layers.Dense(vocab_size, activation=tf.keras.activations.softmax)

    def __call__(self, x, hidden, enc_output, context_vector):
        # 使用上次的隐藏层（第一次使用编码器隐藏层）、编码器输出计算注意力权重
        # enc_output shape == (batch_size, max_length, hidden_size)

        # x shape after passing through embedding == (batch_size, 1, embedding_dim)
        # print('x:{}'.format(x))
        x = self.embedding(x)

        # 将上一循环的预测结果跟注意力权重值结合在一起作为本次的GRU网络输入
        # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)
        dec_x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)

        # passing the concatenated vector to the GRU
        output, state = self.gru(dec_x)

        # output shape == (batch_size * 1, hidden_size)
        output = tf.reshape(output, (-1, output.shape[2]))

        # output shape == (batch_size, vocab)
        prediction = self.fc(output)

        return dec_x, prediction, state


class Pointer(tf.keras.layers.Layer):

    def __init__(self):
        super(Pointer, self).__init__()
        self.w_s_reduce = tf.keras.layers.Dense(1)
        self.w_i_reduce = tf.keras.layers.Dense(1)
        self.w_c_reduce = tf.keras.layers.Dense(1)

    def __call__(self, context_vector, dec_hidden, dec_inp):
        return tf.nn.sigmoid(self.w_s_reduce(dec_hidden) + self.w_c_reduce(context_vector) + self.w_i_reduce(dec_inp))


if __name__ == '__main__':
    # GPU资源配置
    config_gpu()
    # 读取vocab训练
    vocab = Vocab(vocab_path)
    # 计算vocab size
    vocab_size = vocab.count

    # 使用GenSim训练好的embedding matrix
    embedding_matrix = load_embedding_matrix()

    enc_max_len = 200
    dec_max_len = 41
    batch_size = 64
    embedding_dim = 300
    units = 1024

    # 编码器结构
    encoder = Encoder(vocab_size, embedding_dim, embedding_matrix, units, batch_size)
    # encoder input
    enc_inp = tf.ones(shape=(batch_size, enc_max_len), dtype=tf.int32)
    # decoder input
    dec_inp = tf.ones(shape=(batch_size, 1, dec_max_len), dtype=tf.int32)
    # enc pad mask
    enc_pad_mask = tf.ones(shape=(batch_size, enc_max_len), dtype=tf.int32)

    # encoder hidden
    enc_hidden = encoder.initialize_hidden_state()

    enc_output, enc_hidden = encoder(enc_inp, enc_hidden)
    # 打印结果
    print('Encoder output shape: (batch size, sequence length, units) {}'.format(enc_output.shape))
    print('Encoder Hidden state shape: (batch size, units) {}'.format(enc_hidden.shape))

    attention_layer = BahdanauAttention(10)
    context_vector, attention_weights, coverage = attention_layer(enc_hidden, enc_output, enc_pad_mask)

    print("Attention context_vector shape: (batch size, units) {}".format(context_vector.shape))
    print("Attention weights shape: (batch_size, sequence_length, 1) {}".format(attention_weights.shape))
    print("Attention coverage shape: (batch_size, ) {}".format(coverage.shape))

    decoder = Decoder(vocab_size, embedding_dim, embedding_matrix, units, batch_size)

    dec_x, dec_out, dec_hidden, = decoder(tf.random.uniform((64, 1)),
                                          enc_hidden,
                                          enc_output,
                                          context_vector)
    print('Decoder output shape: (batch_size, vocab size) {}'.format(dec_out.shape))
    print('Decoder dec_x shape: (batch_size, 1,embedding_dim + units) {}'.format(dec_x.shape))

    pointer = Pointer()
    p_gen = pointer(context_vector, dec_hidden, dec_inp)
    print('Pointer p_gen shape: (batch_size,1) {}'.format(p_gen.shape))


In [1]:
# 初始化
%load_ext autoreload
%autoreload 2
import sys
import os
os.chdir('E:\GitHub\QA-abstract-and-reasoning')
sys.path.append('E:\GitHub\QA-abstract-and-reasoning')

In [2]:
import tensorflow as tf
from pgn.layers import Encoder, Decoder, Pointer, BahdanauAttention
from pgn.model import PGN
from pgn.batcher import batcher
from utils.saveLoader import load_embedding_matrix
from utils.saveLoader import Vocab
from utils.config import VOCAB_PAD
from utils.config_gpu import config_gpu
config_gpu()

1 Physical GPUs, 1 Logical GPUs


## 模拟model.call里的情况
### 构建输入

In [3]:
%run utils/params.py
# 产生输入数据
vocab = Vocab(VOCAB_PAD)
ds = batcher(vocab, params)
enc_data, dec_data = next(iter(ds))

## model用到的层

In [4]:
embedding_matrix = load_embedding_matrix()
encoder = Encoder(vocab_size = params["vocab_size"],
                               embedding_dim = params["embed_size"],
                               embedding_matrix = embedding_matrix,
                               enc_units = params["enc_units"],
                               batch_size = params["batch_size"])
attention = BahdanauAttention(units = params["attn_units"])
decoder = Decoder(vocab_size =  params["vocab_size"],
                               embedding_dim = params["embed_size"],
                               embedding_matrix = embedding_matrix,
                               dec_units = params["dec_units"],
                               batch_size = params["batch_size"])
pointer = Pointer()

## model.call的参数

In [5]:
enc_inp = enc_data["enc_input"]
dec_inp = dec_data["dec_input"]
enc_extended_inp = enc_data["extended_enc_input"]
batch_oov_len = enc_data["max_oov_len"]
enc_pad_mask = enc_data["enc_mask"]
use_coverage = True
prev_coverage=None

In [6]:
predictions = []
attentions = []
p_gens = []
coverages = []

# enc_output (batch_size, enc_len, enc_units)
# enc_hidden (batch_size, enc_units)
enc_output, enc_hidden = encoder(enc_inp)
dec_hidden = enc_hidden

# context_vector (batch_size, enc_units)
# coverage_ret (batch_size, enc_len, 1)
context_vector, _, coverage_ret = attention(dec_hidden,
                                             enc_output,
                                             enc_pad_mask,
                                             use_coverage,
                                             prev_coverage)

## 进入调用decoder的循环

In [7]:
for t in range(dec_inp.shape[1]):

    # dec_inp[:, t] (batch_size, )
    # dec_x (batch_size, 1, embedding_dim + dec_units)
    # dec_pred (batch_size, vocab_size)
    # dec_hidden (batch_size, dec_units)
    dec_x, dec_pred, dec_hidden = decoder(tf.expand_dims(dec_inp[:, t], 1),
                                           dec_hidden,
                                           enc_output,
                                           context_vector)
    context_vector, attn, coverage_ret = attention(dec_hidden,
                                                        enc_output,
                                                        enc_pad_mask,
                                                        use_coverage,
                                                        coverage_ret)
    # p_gen (batch_size, 1)
    p_gen = pointer(context_vector, dec_hidden, tf.squeeze(dec_x, axis=1))
    coverages.append(coverage_ret)
    attentions.append(attn)
    predictions.append(dec_pred)
    p_gens.append(p_gen)

## **cal_final_dist**
ps 要在上面的函数循环之后才进入

这一步才是今天的重点，先宏观的看下这个函数怎么用

In [16]:
# 宏观了解
from pgn.model import _calc_final_dist
# final_dists list (batch_size, vocab_size+batch_oov_len)
final_dists = _calc_final_dist(enc_extended_inp,
                               predictions,
                               attentions,
                               p_gens,
                               batch_oov_len,
                               params["vocab_size"],
                               params["batch_size"])
vocab.count, enc_data["max_oov_len"].numpy()

(32233, 5)

### 进入内部前配置参数

In [9]:
_enc_batch_extend_vocab = enc_extended_inp  # (batch_size, enc_len)
vocab_dists = predictions  # (batch_size, vocab_size)
attn_dists = attentions  # (batch_size, enc_len)
p_gens = p_gens  # (batch_size, 1)
batch_oov_len = batch_oov_len  # 5(for example)
vocab_size = params["vocab_size"]  # 32233
batch_size = params["batch_size"]  # 64 

### 原版内部

In [37]:
# 计算加权后的原词表分布
# shape保持不变
vocab_dists = [p_gen * dist for (p_gen, dist) in zip(p_gens, vocab_dists)]
# 计算注意力分布
# shape保持不变
attn_dists = [(1 - p_gen) * dist for (p_gen, dist) in zip(p_gens, attn_dists)]
# extended_vocab_size 扩展后的词表大小
extended_vocab_size = vocab_size + batch_oov_len 
# 给vocab_dists额外拼接的全0列
# extra_zeros (batch_size, batch_oov_len)
extra_zeros = tf.zeros((batch_size, batch_oov_len))

# 在vocab_dists尾部拼接oov全0列
# vocab_dists_extended (batch_size, vocab_size+batch_oov_len)
vocab_dists_extended = [tf.concat(axis=1, values=[dist, extra_zeros]) for dist in vocab_dists]

# 0-63 的索引数组
batch_nums = tf.range(0, limit=batch_size)
# batch_nums (batch_size, 1)
batch_nums = tf.expand_dims(batch_nums, 1)
# attn_len value: enc_len(200)
attn_len = tf.shape(_enc_batch_extend_vocab)[1]
# batch_nums (batch_size, enc_len) 
batch_nums = tf.tile(batch_nums, [1, attn_len])

# 给每一批次的样本加上了批次号标签
# indices (batch_size, enc_len, 2)
indices = tf.stack((batch_nums, _enc_batch_extend_vocab), axis=2)

# extended_vocab_size = vocab_size + batch_oov_len 
shape = [batch_size, extended_vocab_size]

attn_dists_projected = [tf.scatter_nd(indices, copy_dist, shape) for copy_dist in attn_dists]

final_dists2 = [vocab_dist + copy_dist for (vocab_dist, copy_dist) in
                   zip(vocab_dists_extended, attn_dists_projected)]

### 改进版
感觉不需要这么多列表生成式

好吧是需要的

但是我感觉可以改成不需要


In [11]:
# attentions_ (batch_size, dec_len, enc_len)
# predictions_ (batch_size, dec_len, vocab_size)
# p_gens_ (batch_size, dec_len, 1)
attentions_ = tf.stack(attentions, 1)
predictions_ = tf.stack(predictions, 1)
p_gens_ = tf.stack(p_gens, 1)

$$
P(w) = p_{gen}P_{vocab}(w)+(1-P_{gen})\sum_{i:w_i=w}a_i^t
$$


In [14]:
# 确定的修改代码
# 先计算公式的左半部分
# _vocab_dists_pgn (batch_size, dec_len, vocab_size)
_vocab_dists_pgn = predictions_ * p_gens_
# 根据oov表的长度补齐原词表
# _extra_zeros (batch_size, dec_len, batch_oov_len)
_extra_zeros = tf.zeros((batch_size, p_gens_.shape[1],batch_oov_len))
# 拼接后公式的左半部分完成了
# _vocab_dists_extended (batch_size, dec_len, vocab_size+batch_oov_len)
_vocab_dists_extended = tf.concat([_vocab_dists_pgn, _extra_zeros], axis=-1)

# 公式右半部分
# 乘以权重后的注意力
# _attn_dists_pgn (batch_size, dec_len, enc_len)
_attn_dists_pgn = attentions_ * (1-p_gens_)
# 拓展后的长度
_extended_vocab_size = vocab_size + batch_oov_len

# 要更新的数组 _attn_dists_pgn
# 更新之后数组的形状与 公式左半部分一致
# shape=[batch_size, dec_len, vocab_size+batch_oov_len]
shape = _vocab_dists_extended.shape

enc_len = tf.shape(_enc_batch_extend_vocab)[1]
dec_len = tf.shape(_vocab_dists_extended)[1]

# batch_nums (batch_size, )
batch_nums = tf.range(0, limit=batch_size)
# batch_nums (batch_size, 1)
batch_nums = tf.expand_dims(batch_nums, 1)
# batch_nums (batch_size, 1, 1)
batch_nums = tf.expand_dims(batch_nums, 2)

# tile 在第1,2个维度上分别复制batch_nums dec_len,enc_len次
# batch_nums (batch_size, dec_len, enc_len) 
batch_nums = tf.tile(batch_nums, [1, dec_len, enc_len])

# (dec_len, )
dec_len_nums = tf.range(0, limit=dec_len)
# (1, dec_len)
dec_len_nums = tf.expand_dims(dec_len_nums, 0)
# (1, dec_len, 1)
dec_len_nums = tf.expand_dims(dec_len_nums, 2)
# tile是用来在不同维度上复制张量的
# dec_len_nums (batch_size, dec_len, enc_len) 
dec_len_nums = tf.tile(dec_len_nums, [batch_size, 1, enc_len])

# _enc_batch_extend_vocab_expand (batch_size, 1, enc_len)
_enc_batch_extend_vocab_expand = tf.expand_dims(_enc_batch_extend_vocab, 1)
# _enc_batch_extend_vocab_expand (batch_size, dec_len, enc_len) 
_enc_batch_extend_vocab_expand = tf.tile(_enc_batch_extend_vocab_expand, [1, dec_len, 1])

# 因为要scatter到一个3Dtensor上，所以最后一维是3
# indices (batch_size, dec_len, enc_len, 3) 
indices = tf.stack((batch_nums, 
                    dec_len_nums, 
                    _enc_batch_extend_vocab_expand), 
                   axis=3)

# 开始更新
attn_dists_projected = tf.scatter_nd(indices, _attn_dists_pgn, shape)
# 至此完成了公式的右半边

# 计算最终分布
final_dists = _vocab_dists_extended + attn_dists_projected

In [18]:
final_dists = tf.stack(final_dists, 1)

In [24]:
64*39*32238

80466048

In [26]:
80466048-80460945

5103

In [25]:
(final_dists2 == final_dists).numpy().sum()

80460945

两种方法计算得到的final_dist结果一致

### 调试中

难点是公式的这部分 $\sum_{i:w_i=w}a_i^t$

如何从代码层面把相同词的注意力加和到一起

我看[有篇文章](https://blog.csdn.net/zlrai5895/article/details/80551056)说：**函数`tf.scatter_nd`更新应用的顺序是非确定性的，所以如果indices包含重复项的话，则输出将是不确定的。**



#### 学习tf.scatter_nd
```
tf.scatter_nd(
    indices,
    updates,
    shape,
    name=None
)

```
[这篇文章](https://blog.csdn.net/zlrai5895/article/details/80551056)的例子和图片比较丰富，是翻译了[官方文档](https://tensorflow.google.cn/api_docs/python/tf/scatter_nd?hl=en&version=stable)的中文版解释。


In [2]:
import tensorflow as tf

#### 基本用法
`updates`是原始的张量

`shape` 是要变成怎样长度的张量

`indices` 是原始张量在新张量的位置

In [9]:
indices = tf.constant([[1], [3], [5], [7]])
updates = tf.constant([1, 2, 3, 4])
shape = tf.constant([8])

scatter = tf.scatter_nd(indices, updates, shape)
scatter

<tf.Tensor: id=17, shape=(8,), dtype=int32, numpy=array([0, 1, 0, 2, 0, 3, 0, 4])>

#### indices重复项的验证情况
可以看到，如果indices包含重复项，那这些重复项的数字会加和然后放到指定位置

之所以说可能会数值不确定，是因为浮点数的数值精度问题，按不同顺序加和的数据，得到的最终数值可能不同，但其实大差不差

In [15]:
indices = tf.constant([[1], [1], [3], [3]])
scatter = tf.scatter_nd(indices, updates, shape)
scatter

<tf.Tensor: id=29, shape=(8,), dtype=int32, numpy=array([0, 3, 0, 7, 0, 0, 0, 0])>

#### 问题迁移step1
先从简单的例子开始
_attn_dists_pgn 是 `(64, 40, 200)`维度的注意力数值 

需要转变为`(64, 40, vocab_size+batch_oov_len)`

先实现(200,) --> (vocab_size+batch_oov_len,)

In [56]:
# step1
# 终极目标
_attn_dists_pgn.shape

# 构建要更新的数组
step1_updates = _attn_dists_pgn[0,0,:]
step1.shape

# 用于更新的索引，维度要比updates高一维
step1_indices = tf.expand_dims(_enc_batch_extend_vocab[0], axis=-1)
# 更新完成后的形状
step1_shape = tf.constant([vocab_size]) + batch_oov_len
step1_indices.shape, step1_shape

# 开始更新
scatter = tf.scatter_nd(step1_indices, step1_updates, step1_shape)
scatter.numpy().sum()

# 更新前注意力之和
# 在小数点5位之后有微笑的差别，不知道是否关键
step1_updates.numpy().sum()

TensorShape([64, 40, 200])

#### step2 (64, 200) -> (64, vocab_size+batch_oov_len)

In [58]:
# step2 构建要更新的数组
step2_updates = _attn_dists_pgn[:,0,:]
step2_updates.shape

# 0-63 的索引数组
# batch_nums (batch_size,)
batch_nums = tf.range(0, limit=batch_size)

# batch_nums (batch_size, 1)
batch_nums = tf.expand_dims(batch_nums, 1)

# attn_len : enc_len(200)
attn_len = tf.shape(_enc_batch_extend_vocab)[1]

# batch_nums (batch_size, enc_len) 
batch_nums = tf.tile(batch_nums, [1, attn_len])

# 给每一批次的样本加上了批次号标签
# indices (batch_size, enc_len, 2)
indices = tf.stack((batch_nums, _enc_batch_extend_vocab), axis=2)


# 用于更新的索引，维度要比updates高一维
step2_indices = indices
# 更新完成后的形状
_extended_vocab_size = vocab_size + batch_oov_len
step2_shape =[batch_size, _extended_vocab_size]
step2_indices.shape, step2_shape

# 开始更新
scatter = tf.scatter_nd(step2_indices, step2_updates, step2_shape)
scatter.numpy().sum()

step2_updates.numpy().sum()

TensorShape([64, 200])

#### step3 (64, 40, 200)-> ((64, 40, vocab_size+batch_oov_len))

In [44]:
# step3 构建要更新的数组
step3_updates = _attn_dists_pgn[:]
step3_shape = _vocab_dists_extended.shape
# step3_updates.shape, step3_shape

enc_len = tf.shape(_enc_batch_extend_vocab)[1]
dec_len = _vocab_dists_extended.shape[1]

# batch_nums (batch_size, )
batch_nums = tf.range(0, limit=batch_size)
# batch_nums (batch_size, 1)
batch_nums = tf.expand_dims(batch_nums, 1)
# batch_nums (batch_size, 1, 1)
batch_nums = tf.expand_dims(batch_nums, 2)
# batch_nums (batch_size, dec_len, enc_len) 
batch_nums = tf.tile(batch_nums, [1, dec_len, enc_len])

# (dec_len, )
dec_len_nums = tf.range(0, limit=dec_len)
# (1, dec_len)
dec_len_nums = tf.expand_dims(dec_len_nums, 0)
# (1, dec_len, 1)
dec_len_nums = tf.expand_dims(dec_len_nums, 2)
# tile是用来重复的
# dec_len_nums (batch_size, dec_len, enc_len) 
dec_len_nums = tf.tile(dec_len_nums, [batch_size, 1, enc_len])

# (batch_size, 1, enc_len)
_enc_batch_extend_vocab_expand = tf.expand_dims(_enc_batch_extend_vocab, 1)
_enc_batch_extend_vocab_expand = tf.tile(_enc_batch_extend_vocab_expand, 
                                         [1, dec_len, 1])

indices = tf.stack((batch_nums, 
                    dec_len_nums, 
                    _enc_batch_extend_vocab_expand), 
                   axis=3)

# 开始更新
scatter = tf.scatter_nd(indices, step3_updates, step3_shape)
# scatter.numpy().sum()

final_dists3 = scatter + _vocab_dists_extended

(TensorShape([64, 39, 200]), TensorShape([64, 39, 32242]))