In [73]:
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
%matplotlib inline
from tqdm.auto import tqdm
import concurrent.futures
from multiprocessing import Pool
import copy,os,sys,psutil
from collections import Counter,deque
import itertools
import os

In [74]:
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_hub as hub
import numpy as np
import matplotlib.pyplot as plt

In [16]:
baseDir="/home/zhoutong/notebook_collection/tmp/NLP_ner/lstmcrf"
checkpoint_dir = baseDir+"/ckpt"

# SubclassedM

## BiLSTMCRF

In [380]:
# 参考： https://github.com/saiwaiyanyu/bi-lstm-crf-ner-tf2.0/blob/master/train.py
class LSTMCRF(tf.keras.Model):
    def __init__(self, label_size, lstm_units, vocab_size, emb_dim):
        super().__init__()
        self.label_size = label_size
        self.lstm_units = lstm_units
        self.vocab_size = vocab_size
        self.emb_dim = emb_dim

        self.embedding = tf.keras.layers.Embedding(self.vocab_size, self.emb_dim)
        self.bilstm = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(self.lstm_units, return_sequences=True))
#         self.dense = tf.keras.layers.Dense(label_size,activation="softmax")
        self.dense = tf.keras.layers.Dense(label_size) # 正统做法是没有激活函数没做归一化
        self.dropout = tf.keras.layers.Dropout(0.5)

        self.transition_params = tf.Variable(tf.random.uniform(shape=(self.label_size, self.label_size)),trainable=False)

    # @tf.function
    def call(self, text,labels=None,training=None):
        seq_lens = tf.math.reduce_sum(tf.cast(tf.math.not_equal(text, 0), dtype=tf.int32), axis=-1)
        # -1 change 0·
        inputs = self.embedding(text)
        inputs = self.dropout(inputs, training)
        logits = self.dense(self.bilstm(inputs))

        if labels is not None:
            label_sequences = tf.convert_to_tensor(labels, dtype=tf.int32)
            log_likelihood, self.transition_params = tfa.text.crf_log_likelihood(logits, label_sequences, seq_lens)
            self.transition_params = tf.Variable(self.transition_params, trainable=False)
            return logits, seq_lens, log_likelihood
        else:
            return logits, seq_lens
        

## BERTLayer

In [20]:
import tensorflow_hub as hub

In [None]:
max_seq_length = 128  # Your choice here.
input_word_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32,
                                       name="input_word_ids")
input_mask = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32,
                                   name="input_mask")
segment_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32,
                                    name="segment_ids")
bert_layer = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_zh_L-12_H-768_A-12/1",
                            trainable=True)
pooled_output, sequence_output = bert_layer([input_word_ids, input_mask, segment_ids])

In [None]:
class BertLayer(tf.layers.Layer):
    def __init__(self, n_fine_tune_layers=10, **kwargs):
        self.n_fine_tune_layers = n_fine_tune_layers
        self.trainable = True
        self.output_size = 768
        super(BertLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        self.bert = hub.Module(
            bert_path,
            trainable=self.trainable,
            name="{}_module".format(self.name)
        )
        trainable_vars = self.bert.variables
        
        # Remove unused layers
        trainable_vars = [var for var in trainable_vars if not "/cls/" in var.name]
        
        # Select how many layers to fine tune
        trainable_vars = trainable_vars[-self.n_fine_tune_layers :]
        
        # Add to trainable weights
        for var in trainable_vars:
            self._trainable_weights.append(var)
        
        # Add non-trainable weights
        for var in self.bert.variables:
            if var not in self._trainable_weights:
                self._non_trainable_weights.append(var)
        
        super(BertLayer, self).build(input_shape)

    def call(self, inputs):
        inputs = [K.cast(x, dtype="int32") for x in inputs]
        input_ids, input_mask, segment_ids = inputs
        bert_inputs = dict(
            input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids
        )
        result = self.bert(inputs=bert_inputs, signature="tokens", as_dict=True)[
            "pooled_output"
        ]
        return result

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_size)

# LoadData

In [489]:
pad_length = 100
target_fp=baseDir+"/data/renminribao/target_BIO_2014_cropus.txt"
source_fp=baseDir+"/data/renminribao/source_BIO_2014_cropus.txt"
with open(target_fp,"r") as fr:
    target=[i.strip() for i in fr.readlines()]
with open(source_fp,"r") as fr:
    source=[i.strip() for i in fr.readlines()]

# 所有样本
sample_iter=zip(source,target)
# 词表
all_words=set([char for sentence in source for char in sentence.split(" ") if char != ''])
word2idx = dict((word,idx+1) for idx,word in enumerate(all_words))
word2idx.update({"PAD":0})
# 标注label
all_tags=set([tag for tags in target for tag in tags.split(" ") if tag != ''])
tag2idx = dict((tag,idx) for idx,tag in enumerate(all_tags))
# tag2idx.update({"PAD":0})

# 训练集测试集
def _yield_samples(source_inp,target_inp):
    for sen,labels in zip(source_inp,target_inp):
        X = ([w for w in sen.split(" ")]+['PAD']*pad_length)[:pad_length]
        X = [word2idx[w] for w in X]
        Y = ([w for w in labels.split(" ")]+['O']*pad_length)[:pad_length]
        Y = [tag2idx[l] for l in Y]
        yield (X,Y)

total_size = len(source)
train_size = int(total_size*0.7)
train_dataset = _yield_samples(source[:train_size],target[:train_size])
test_dataset = _yield_samples(source[train_size:],target[train_size:])

In [482]:
idx2word=dict([(v,k) for k,v in word2idx.items()])
idx2tag=dict([(v,k) for k,v in tag2idx.items()])

In [475]:
sample = sample_iter.__next__()
sample
sen,tag = sample
sen_idx=[word2idx[i] for i in sen.split(" ")]
tag_idx=[tag2idx[i] for i in tag.split(" ")]
print(sen,"\n",sen_idx)
print(tag,"\n",tag_idx)

sen_idx = sen_idx[:4]
tag_idx = tag_idx[:4]
sen_idx,tag_idx


('人 民 网 1 月 1 日 讯 据 《 纽 约 时 报 》 报 道 , 美 国 华 尔 街 股 市 在 2 0 1 3 年 的 最 后 一 天 继 续 上 涨 , 和 全 球 股 市 一 样 , 都 以 最 高 纪 录 或 接 近 最 高 纪 录 结 束 本 年 的 交 易 。',
 'O O O B_T I_T I_T I_T O O O B_LOC I_LOC O O O O O O B_LOC I_LOC I_LOC I_LOC I_LOC O O O B_T I_T I_T I_T I_T O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O')

人 民 网 1 月 1 日 讯 据 《 纽 约 时 报 》 报 道 , 美 国 华 尔 街 股 市 在 2 0 1 3 年 的 最 后 一 天 继 续 上 涨 , 和 全 球 股 市 一 样 , 都 以 最 高 纪 录 或 接 近 最 高 纪 录 结 束 本 年 的 交 易 。 
 [5235, 3477, 628, 3238, 1756, 3238, 4392, 2325, 1762, 18, 2960, 2119, 2604, 394, 2827, 394, 3983, 1376, 1321, 4679, 6015, 2643, 5759, 4508, 2073, 5981, 1198, 3487, 3238, 4295, 2235, 1129, 1090, 1079, 5964, 1241, 1091, 4083, 3167, 3457, 1376, 5860, 1698, 5248, 4508, 2073, 5964, 3910, 1376, 4979, 843, 1090, 3128, 4074, 5030, 3197, 1537, 4389, 1090, 3128, 4074, 5030, 5425, 4839, 525, 2235, 1129, 377, 2916, 2900]
O O O B_T I_T I_T I_T O O O B_LOC I_LOC O O O O O O B_LOC I_LOC I_LOC I_LOC I_LOC O O O B_T I_T I_T I_T I_T O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O 
 [7, 7, 7, 6, 4, 4, 4, 7, 7, 7, 8, 3, 7, 7, 7, 7, 7, 7, 8, 3, 3, 3, 3, 7, 7, 7, 6, 4, 4, 4, 4, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7]


([5235, 3477, 628, 3238], [7, 7, 7, 6])

# InitModel | Params

In [511]:
config={
    "label_size":len(tag2idx),
    "lstm_units":8,
    "vocab_size":len(word2idx),
    "emb_dim":300
}
print(config)
M = LSTMCRF(**config)

M.build((None,20))
M.summary()

{'label_size': 9, 'lstm_units': 8, 'vocab_size': 6364, 'emb_dim': 300}
Model: "lstmcrf_12"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding_12 (Embedding)     multiple                  1909200   
_________________________________________________________________
bidirectional_12 (Bidirectio multiple                  19776     
_________________________________________________________________
dense_12 (Dense)             multiple                  153       
_________________________________________________________________
dropout_12 (Dropout)         multiple                  0         
Total params: 1,929,210
Trainable params: 1,929,129
Non-trainable params: 81
_________________________________________________________________


# Train

## pre

In [525]:
##################################################
# tf.train.Checkpoint 不能存模型内部有Model类变量的模型
# tf.train.CheckpointManager 不能在保存时指定prefix 
# 自行包装一个类，主要使用model.save_weights
##################################################
class CustomCkpt:
    def __init__(self, ckpt_dir, model=None,max_keep=20):
        self.ckpt_dir = ckpt_dir
        self.best_valid_acc = 0.0
        self.best_valid_loss = 1e10
        self.total_saved_fps = []
        self.history = []
        self.model = model
        self.max_keep=max_keep
         
    def save(self,val_acc,val_loss,model=None,fileName=None):
        if model is not None:
            self.model = model
        assert self.model is not None, "初始化时未指定model则.save()必须提供model"
        if fileName is None:
            save_fp = os.path.join(self.ckpt_dir,"ckpt_"+len(self.total_saved_fps))
        else:
            save_fp = os.path.join(self.ckpt_dir,fileName)
        saved = False # 因为acc和loss提升时都要打log并更新best，但是保存只存一次
        if val_acc > self.best_valid_acc:
            print(f"acc improved [from]:{self.best_valid_acc:.4f} [to]:{val_acc:.4f}.")
            self.best_valid_acc = val_acc
            if not saved:
                self.model.save_weights(save_fp)
                saved=True
        if val_loss < self.best_valid_loss:
            print(f"loss improved [from]:{self.best_valid_loss:.4f} [to]:{val_loss:.4f}.")
            self.best_valid_loss = val_loss
            if not saved:
                self.model.save_weights(save_fp)
                saved=True
        # 限制最多保存文件个数
        if saved:
            print(f"[ckpt-path]: {save_fp}")
            self.total_saved_fps.append(save_fp)
            if len(self.total_saved_fps) >= self.max_keep:
                toDel_fp = self.total_saved_fps.pop(0)
                status,output=subprocess.getstatusoutput(f"rm {toDel_fp}*")
        
        self.history.append({"val_acc":val_acc, "val_loss":val_loss})
        

In [526]:
######
# tbd
######
opt=tf.keras.optimizers.Adam()
tbd_dir = baseDir+"/tensorboard"
if os.path.exists(tbd_dir):
    import shutil
    shutil.rmtree(tbd_dir)
    print("历史tbd信息已删除")
summary_writer = tf.summary.create_file_writer(tbd_dir)
with summary_writer.as_default():
    pass
#     _=tf.summary.image("Trainning Data", normal_flow_train[0][0], max_outputs=4, step=0)

    
#######
# ckpt
#######
ckpt_saver = CustomCkpt(checkpoint_dir)

##################
# 损失函数和评价指标
##################
ce_loss_fn = tf.keras.losses.categorical_crossentropy
acc_fn = tf.keras.metrics.categorical_accuracy

########
# 优化器
########
optimizer = tf.optimizers.Adam(1e-3)

历史tbd信息已删除


## train loop

In [527]:
data_batch = list(itertools.islice(train_dataset,batch_size))
text_batch = np.array([i[0] for i in data_batch])
label_batch = np.array([i[1] for i in data_batch])
text_batch
text_batch.shape

array([[6316, 1418, 3744, ...,    0,    0,    0],
       [5694, 3606, 3837, ...,    0,    0,    0],
       [2869,  818, 3837, ..., 2356, 1698, 2900],
       ...,
       [4692, 5235, 3392, ...,    0,    0,    0],
       [2543, 1363, 2604, ..., 1376, 1582,  401],
       [1523,  289, 5141, ...,    0,    0,    0]])

(20, 100)

In [528]:
def train_one_step(text_batch, labels_batch):
    with tf.GradientTape() as tape:
        logits, seq_len_batch, log_likelihood = M(text_batch, labels_batch,training=True)
        loss = - tf.reduce_mean(log_likelihood)
    gradients = tape.gradient(loss, M.trainable_variables)
    optimizer.apply_gradients(zip(gradients, M.trainable_variables))
    return loss,logits, seq_len_batch

def get_acc_one_step(logits, seq_len_batch, labels_batch):
    paths = []
    acc = 0
    for logit, seq_len, labels in zip(logits, seq_len_batch, labels_batch):
        viterbi_path, _ = tfa.text.viterbi_decode(logit[:seq_len], M.transition_params)
        paths.append(viterbi_path)
        correct_prediction = tf.equal(
            tf.convert_to_tensor(tf.keras.preprocessing.sequence.pad_sequences([viterbi_path], padding='post'),
                                 dtype=tf.int32),
            tf.convert_to_tensor(tf.keras.preprocessing.sequence.pad_sequences([labels[:seq_len]], padding='post'),
                                 dtype=tf.int32)
        )
        acc = acc + tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    acc = acc / len(paths)
    return acc


def get_validation_acc():
    while True:
        data_batch = list(itertools.islice(test_dataset,batch_size))
        if len(data_batch) < batch_size:
            break
        text_batch = np.array([i[0] for i in data_batch])
        labels_batch = np.array([i[1] for i in data_batch])
        step = step + 1
        acc = get_acc_one_step(logits, seq_len_batch, labels_batch)
        


best_acc = 0
step = 0
epoch=10
batch_size=20
for epoch in range(epoch):
    while True:
        data_batch = list(itertools.islice(train_dataset,batch_size))
        if len(data_batch) < batch_size:
            break
        text_batch = np.array([i[0] for i in data_batch])
        labels_batch = np.array([i[1] for i in data_batch])
        step = step + 1
        loss, logits, seq_len_batch = train_one_step(text_batch, labels_batch)
        if step % 20 == 0:
            acc = get_acc_one_step(logits, seq_len_batch, labels_batch)
            tqdm.write(f"[e]: {epoch} [step]: {step} [acc]: {acc:.4f} [loss]:{loss:.4f}")
            ckpt_saver.save(val_loss=loss,val_acc=acc,model=M,fileName=f"ckpt_e{epoch}_trainLoss{loss:.4f}_trainAcc{acc:.4f}")

[e]: 0 [step]: 20 [acc]: 0.9326 [loss]:23.6377
acc improved [from]:0.0000 [to]:0.9326.
loss improved [from]:10000000000.0000 [to]:23.6377.
[ckpt-path]: /home/zhoutong/notebook_collection/tmp/NLP_ner/lstmcrf/ckpt/ckpt_e0_trainLoss23.6377_trainAcc0.9326
[e]: 0 [step]: 40 [acc]: 0.9235 [loss]:25.9836
[e]: 0 [step]: 60 [acc]: 0.9280 [loss]:18.2528
loss improved [from]:23.6377 [to]:18.2528.
[ckpt-path]: /home/zhoutong/notebook_collection/tmp/NLP_ner/lstmcrf/ckpt/ckpt_e0_trainLoss18.2528_trainAcc0.9280
[e]: 0 [step]: 80 [acc]: 0.8746 [loss]:38.0338
[e]: 0 [step]: 100 [acc]: 0.9223 [loss]:21.5593
[e]: 0 [step]: 120 [acc]: 0.9270 [loss]:26.2218
[e]: 0 [step]: 140 [acc]: 0.8553 [loss]:26.2525
[e]: 0 [step]: 160 [acc]: 0.7975 [loss]:46.1039
[e]: 0 [step]: 180 [acc]: 0.9197 [loss]:22.3032
[e]: 0 [step]: 200 [acc]: 0.9179 [loss]:29.2866
[e]: 0 [step]: 220 [acc]: 0.9482 [loss]:15.3642
acc improved [from]:0.9326 [to]:0.9482.
loss improved [from]:18.2528 [to]:15.3642.
[ckpt-path]: /home/zhoutong/note

KeyboardInterrupt: 

# Valid

# Test

# CRF 计算参数实验

```
loglikelihood = sequence_score - log_norm
= (unary_score + binary_score) - logsumexp(alphas)
= (unary_score + binary_score) - logsumexp(crf_forward)
```

## loglikelihood

In [466]:
sen_idx=[500,501,502,503]
# tag_idx=[1,2,3,10]
tag_idx=[1,2,3,9]
inp = np.array([sen_idx])
"inp"
inp
# logits=M.dense(M.dropout(M.bilstm(M.embedding(inp))))
logits,seq_lens=M(inp)
"logits",logits.numpy().shape
logits.numpy()
label_sequences = np.array([tag_idx])
"label_sequences",label_sequences.shape
label_sequences
text_lens=np.array([len(sen_idx)])
loglikelihood,trans_params = tfa.text.crf_log_likelihood(logits, label_sequences, text_lens,M.transition_params)
"loglikelihood"
loglikelihood.numpy()
"seq_score"
tfa.text.crf_sequence_score(logits,label_sequences,text_lens,M.transition_params).numpy()
"log_norm"
tfa.text.crf_log_norm(logits, text_lens, M.transition_params).numpy()

'inp'

array([[500, 501, 502, 503]])

('logits', (1, 4, 10))

array([[[ 0.02905709, -0.0079308 ,  0.00130202, -0.01221373,
          0.00268621,  0.01556253, -0.00820034,  0.01168286,
         -0.01324397, -0.02614049],
        [ 0.01999625,  0.00837216,  0.02181312, -0.02321985,
          0.00284035,  0.02077706,  0.00205434,  0.01210876,
          0.00665217, -0.01442756],
        [ 0.01611964,  0.008289  ,  0.02945245, -0.01523286,
          0.00243924,  0.01745703, -0.0071438 ,  0.00392865,
          0.00447142,  0.00135649],
        [ 0.00934999, -0.0086754 , -0.01002031, -0.0018477 ,
          0.01812736,  0.01125069, -0.00853377,  0.00692784,
         -0.01058302, -0.00975374]]], dtype=float32)

('label_sequences', (1, 4))

array([[1, 2, 3, 9]])

'loglikelihood'

array([-8.901098], dtype=float32)

'seq_score'

array([1.7472095], dtype=float32)

'log_norm'

array([10.648308], dtype=float32)

## sequence_socre

In [252]:
f"crf_sequence_score: {tfa.text.crf_sequence_score(logits,label_sequences,text_lens,M.transition_params)}"
"== unary_score+bianry_score"
f"unary_score: {tfa.text.crf_unary_score(label_sequences, text_lens, logits)}"
f"== sum(emission of labelSeq): {sum([logits.numpy()[0,idx,i] for idx,i in enumerate(label_sequences[0])])}"
f"bianry_score: {tfa.text.crf_binary_score(label_sequences, text_lens, M.transition_params)}"
f"== sum(transition of labelSeq): {sum([M.transition_params.numpy()[a,b] for a,b in list(zip(label_sequences[0][:-1],label_sequences[0][1:]))])}"

'crf_sequence_score: [1.0131763]'

'== unary_score+bianry_score'

'unary_score: [0.3015548]'

'== sum(emission of labelSeq): 0.301554799079895'

'bianry_score: [0.7116215]'

'== sum(transition of labelSeq): 0.7116215229034424'

## log_norm

In [271]:
f"log_norm: {tfa.text.crf_log_norm(logits, text_lens, M.transition_params).numpy()}"

'log_norm: [8.306995]'

### first_input&rest_input

In [297]:
inputs=logits
sequence_lengths=text_lens
first_input = tf.slice(inputs, [0, 0, 0], [-1, 1, -1])
first_input = tf.squeeze(first_input, [1])
rest_of_input = tf.slice(inputs, [0, 1, 0], [-1, -1, -1])
f"[logits]: {logits.numpy().shape}"
logits.numpy()
f"[first_input]: {first_input.numpy().shape}"
first_input.numpy()
f"[rest_of_input]: {rest_of_input.numpy().shape}"
rest_of_input.numpy()

'[logits]: (1, 3, 10)'

array([[[ 0.00217672,  0.0098424 ,  0.00291217, -0.00717967,
          0.01139859, -0.01732886,  0.00449238, -0.0078244 ,
         -0.01410343,  0.00459101],
        [ 0.01717858,  0.00824076,  0.0123352 ,  0.02027788,
          0.01698909, -0.03285388, -0.02621158, -0.01658068,
         -0.01112588,  0.01208735],
        [-0.0097875 , -0.00283362,  0.01374137,  0.01923978,
          0.01765317, -0.01433519, -0.02876953, -0.03505114,
         -0.0089741 , -0.01326887]]], dtype=float32)

'[first_input]: (1, 10)'

array([[ 0.00217672,  0.0098424 ,  0.00291217, -0.00717967,  0.01139859,
        -0.01732886,  0.00449238, -0.0078244 , -0.01410343,  0.00459101]],
      dtype=float32)

'[rest_of_input]: (1, 2, 10)'

array([[[ 0.01717858,  0.00824076,  0.0123352 ,  0.02027788,
          0.01698909, -0.03285388, -0.02621158, -0.01658068,
         -0.01112588,  0.01208735],
        [-0.0097875 , -0.00283362,  0.01374137,  0.01923978,
          0.01765317, -0.01433519, -0.02876953, -0.03505114,
         -0.0089741 , -0.01326887]]], dtype=float32)

### crf_forward
前向计算和viterbi类似，只不过是把求最大改成了求和

In [306]:
alphas = tfa.text.crf_forward(rest_of_input, first_input, M.transition_params, sequence_lengths)
alphas
tf.reduce_logsumexp(alphas,[1])

<tf.Tensor: shape=(1, 10), dtype=float32, numpy=
array([[5.750242 , 5.76396  , 5.7077093, 5.6722016, 5.826963 , 5.8499684,
        5.7189903, 5.7958345, 5.6517015, 5.6982207]], dtype=float32)>

<tf.Tensor: shape=(1,), dtype=float32, numpy=array([8.048121], dtype=float32)>

In [305]:
# 模拟计算所有序列情况各自的sequence_score，并求和
# 从结果来看不是这样计算的
length=text_lens[0]
ps=[]
for i in range(length):
    for j in range(length): 
        for m in range(length):
                ps.append([i,j,m])
ps
res=[]
for seq in ps:
    unary_score=sum([logits.numpy()[0][idx,i] for idx,i in enumerate(seq)])
    binary_score = sum([M.transition_params[a][b].numpy() for a,b in list(zip(seq[:-1],seq[1:]))])
    res.append((unary_score, binary_score))

len(res)
res
# tf.reduce_logsumexp(res).numpy()
# for word_idx,emit_list in enumerate(logits.numpy()[0]):
#     for emit_idx,emit in enumerate(emit_list):
#         emit

[[0, 0, 0],
 [0, 0, 1],
 [0, 0, 2],
 [0, 1, 0],
 [0, 1, 1],
 [0, 1, 2],
 [0, 2, 0],
 [0, 2, 1],
 [0, 2, 2],
 [1, 0, 0],
 [1, 0, 1],
 [1, 0, 2],
 [1, 1, 0],
 [1, 1, 1],
 [1, 1, 2],
 [1, 2, 0],
 [1, 2, 1],
 [1, 2, 2],
 [2, 0, 0],
 [2, 0, 1],
 [2, 0, 2],
 [2, 1, 0],
 [2, 1, 1],
 [2, 1, 2],
 [2, 2, 0],
 [2, 2, 1],
 [2, 2, 2]]

27

[(0.009567800909280777, 1.563323736190796),
 (0.016521684359759092, 1.2504353523254395),
 (0.03309667017310858, 0.9342750310897827),
 (0.0006299763917922974, 1.1747034788131714),
 (0.007583859842270613, 0.6671038866043091),
 (0.024158845655620098, 1.3209037780761719),
 (0.004724414087831974, 0.8684554100036621),
 (0.01167829753831029, 1.0337573289871216),
 (0.028253283351659775, 0.5662949085235596),
 (0.017233476042747498, 1.4875918626785278),
 (0.024187359493225813, 1.1747034788131714),
 (0.0407623453065753, 0.8585431575775146),
 (0.008295651525259018, 0.9042603969573975),
 (0.015249534975737333, 0.39666080474853516),
 (0.03182452078908682, 1.050460696220398),
 (0.012390089221298695, 1.5679725408554077),
 (0.01934397267177701, 1.7332744598388672),
 (0.035918958485126495, 1.2658120393753052),
 (0.010303251445293427, 1.4975041151046753),
 (0.017257134895771742, 1.1846157312393188),
 (0.03383212070912123, 0.8684554100036621),
 (0.001365426927804947, 1.5870741605758667),
 (0.0083193103782

In [296]:
rest_of_input

<tf.Tensor: shape=(1, 3, 10), dtype=float32, numpy=
array([[[0.10040308, 0.10123939, 0.09852703, 0.09806626, 0.10067997,
         0.09858277, 0.10091217, 0.09968802, 0.10062727, 0.10127402],
        [0.09987237, 0.10105918, 0.09948634, 0.09686752, 0.09984942,
         0.09828502, 0.10078489, 0.10046219, 0.10212169, 0.10121138],
        [0.10036152, 0.0987159 , 0.09999987, 0.09854174, 0.09934476,
         0.09669184, 0.10196146, 0.10054415, 0.1020008 , 0.10183792]]],
      dtype=float32)>

In [211]:
# crf_forward
inputs = rest_of_input
stat = first_input
transition_params = M.transition_params
"shape: rest_of_input,first_input,transition_params"
[i.numpy().shape for i in [inputs,stat,transition_params]]

inputs = tf.transpose(inputs, [1, 0, 2])
transition_params = tf.expand_dims(transition_params, 0)
"shape(after): rest_of_input,first_input,transition_params"
[i.numpy().shape for i in [inputs,stat,transition_params]]


last_index = tf.maximum(tf.constant(0, dtype=sequence_lengths.dtype), sequence_lengths - 1)


stat
tf.expand_dims(stat, 2)
transition_params
tf.expand_dims(stat, 2) + transition_params
tf.reduce_logsumexp(transition_scores, [1])

'shape: rest_of_input,first_input,transition_params'

[(1, 3, 10), (1, 10), (10, 10)]

'shape(after): rest_of_input,first_input,transition_params'

[(3, 1, 10), (1, 10), (1, 10, 10)]

<tf.Tensor: shape=(1, 10), dtype=float32, numpy=
array([[0.1006834 , 0.10105516, 0.09909209, 0.09861839, 0.10069787,
        0.09951755, 0.0996674 , 0.10011183, 0.100015  , 0.10054132]],
      dtype=float32)>

<tf.Tensor: shape=(1, 10, 1), dtype=float32, numpy=
array([[[0.1006834 ],
        [0.10105516],
        [0.09909209],
        [0.09861839],
        [0.10069787],
        [0.09951755],
        [0.0996674 ],
        [0.10011183],
        [0.100015  ],
        [0.10054132]]], dtype=float32)>

<tf.Tensor: shape=(1, 10, 10), dtype=float32, numpy=
array([[[0.44697106, 0.2418617 , 0.8227587 , 0.4104979 , 0.782711  ,
         0.01426017, 0.40244222, 0.11317647, 0.1748898 , 0.6239077 ],
        [0.20628917, 0.99719846, 0.12566197, 0.8589554 , 0.01937366,
         0.8665881 , 0.59559464, 0.00556386, 0.9480275 , 0.35019696],
        [0.9995408 , 0.9722867 , 0.5364071 , 0.971385  , 0.3729949 ,
         0.6110319 , 0.9596269 , 0.5348549 , 0.06864536, 0.6061748 ],
        [0.19580531, 0.8521589 , 0.43386817, 0.62382567, 0.6366831 ,
         0.86621857, 0.06924629, 0.8040581 , 0.6113379 , 0.9791328 ],
        [0.58327127, 0.62254274, 0.4878497 , 0.5879643 , 0.3653742 ,
         0.0238682 , 0.3072735 , 0.34319937, 0.31553054, 0.5738963 ],
        [0.342664  , 0.01548791, 0.08265579, 0.3064649 , 0.49686992,
         0.83524024, 0.50523055, 0.95652676, 0.46592188, 0.51819336],
        [0.63470256, 0.2315464 , 0.24391973, 0.30125952, 0.42490733,
         0.63373923, 0.6383817 , 0.35166633,

<tf.Tensor: shape=(1, 10, 10), dtype=float32, numpy=
array([[[0.54765445, 0.3425451 , 0.92344207, 0.5111813 , 0.8833944 ,
         0.11494357, 0.5031256 , 0.21385986, 0.2755732 , 0.7245911 ],
        [0.30734432, 1.0982536 , 0.22671713, 0.9600105 , 0.12042882,
         0.96764326, 0.6966498 , 0.10661902, 1.0490826 , 0.4512521 ],
        [1.0986329 , 1.0713788 , 0.6354992 , 1.0704771 , 0.472087  ,
         0.71012396, 1.058719  , 0.63394696, 0.16773745, 0.7052669 ],
        [0.2944237 , 0.9507773 , 0.53248656, 0.72244406, 0.7353015 ,
         0.96483696, 0.16786468, 0.90267646, 0.7099563 , 1.0777512 ],
        [0.68396914, 0.7232406 , 0.5885476 , 0.6886622 , 0.46607208,
         0.12456607, 0.40797138, 0.44389725, 0.4162284 , 0.67459416],
        [0.44218156, 0.11500546, 0.18217334, 0.40598246, 0.5963875 ,
         0.9347578 , 0.60474813, 1.0560443 , 0.56543946, 0.61771095],
        [0.73437   , 0.3312138 , 0.34358713, 0.40092692, 0.52457476,
         0.73340666, 0.73804915, 0.45133373,

<tf.Tensor: shape=(1, 10), dtype=float32, numpy=
array([[2.9214406, 2.9880924, 2.808669 , 2.9367938, 2.8524787, 2.950623 ,
        2.892315 , 2.9182906, 2.941332 , 2.918378 ]], dtype=float32)>