In [86]:
import pandas as pd
import numpy as np
import torch

# Location of the NER dataset CSV. Kept as a single constant so the notebook
# can be pointed at another copy of the file by editing one line.
DATA_PATH = "/Users/huangyuhao/Downloads/1014_4361_bundle_archive/ner_dataset.csv"

data = pd.read_csv(DATA_PATH, encoding="latin1")

# No forward fill of "Sentence #" is needed: a row whose value starts with
# "Sentence:" opens a new sentence and every other row continues the current one.
sentence = data["Sentence #"].tolist()
words = data["Word"].tolist()
tags = data["Tag"].tolist()

# Regroup the flat token / tag columns into one list per sentence.
word_data = []
tag_data = []
for marker, word, tag in zip(sentence, words, tags):
    # Also open a sentence on the very first row, so a file whose first row
    # carries no marker does not fail with IndexError on word_data[-1].
    if str(marker).startswith("Sentence:") or not word_data:
        word_data.append([])
        tag_data.append([])
    word_data[-1].append(word)
    tag_data[-1].append(tag)

# Vocabularies in first-occurrence order. Going through set() would make the
# order (and therefore every word / tag id) depend on per-process string hash
# randomisation, so ids would change between kernel restarts.
word_set = list(dict.fromkeys(words))
tag_set = list(dict.fromkeys(tags))

# Model dimensions: N hidden states (tags), M observable symbols (words).
N = len(tag_set)
M = len(word_set)

# Zero-initialised HMM parameter tensors, filled in by train().
A = torch.zeros(N, N)   # transition, indexed [from_tag_id, to_tag_id]
B = torch.zeros(N, M)   # emission, indexed [tag_id, word_id]
Pi = torch.zeros(N)     # initial state, indexed [tag_id]

# Lookup tables from vocabulary entry to its position in word_set / tag_set.
word2id = {word: idx for idx, word in enumerate(word_set)}
tag2id = {tag: idx for idx, tag in enumerate(tag_set)}

def train(A, B, Pi, words, tags, word2id, tag2id):
    """Estimate HMM parameters from tagged sentences by counting.

    Counts are added on top of whatever values ``A``, ``B`` and ``Pi`` already
    hold, entries that are still exactly zero are replaced by 1e-10, and the
    results are normalised into probabilities.

    Args:
        A: 2-D tensor of transition counts, indexed [from_tag_id, to_tag_id].
        B: 2-D tensor of emission counts, indexed [tag_id, word_id].
        Pi: 1-D tensor of sentence-initial tag counts, indexed [tag_id].
        words: list of sentences, each a list of words.
        tags: list of tag lists, parallel to ``words``.
        word2id: mapping from word to column index of ``B``.
        tag2id: mapping from tag to index into ``A``, ``B`` and ``Pi``.

    Returns:
        Tuple ``(A, B, Pi)`` of new tensors: ``A`` and ``B`` have each row
        summing to 1, ``Pi`` sums to 1. The tensors passed in are not modified.

    Raises:
        AssertionError: if ``words`` and ``tags`` differ in length, or a
            sentence and its tag list differ in length.
        KeyError: if a word or tag is missing from its id mapping.
    """
    assert len(words) == len(tags)

    # Accumulate into copies so the caller's tensors are not left holding
    # raw (unnormalised) counts after the call.
    A, B, Pi = A.clone(), B.clone(), Pi.clone()

    for taglist, wordlist in zip(tags, words):
        assert len(taglist) == len(wordlist)
        if not taglist:
            # An empty sentence has no initial tag, transitions or emissions.
            continue
        tag_ids = [tag2id[tag] for tag in taglist]

        Pi[tag_ids[0]] += 1
        for currentid, nextid in zip(tag_ids, tag_ids[1:]):
            A[currentid, nextid] += 1
        for tagid, word in zip(tag_ids, wordlist):
            B[tagid, word2id[word]] += 1

    # Replace exact zeros with a tiny value so that no probability is zero
    # (the logarithm of each entry stays finite).
    for counts in (A, B, Pi):
        counts[counts == 0.] = 1e-10

    A = A / A.sum(dim=1, keepdim=True)
    B = B / B.sum(dim=1, keepdim=True)
    Pi = Pi / Pi.sum()

    return A, B, Pi

# Fit the HMM on the sentence-grouped corpus; A, B and Pi are rebound to the
# normalised tensors that train() returns.
A, B, Pi = train(
    A=A,
    B=B,
    Pi=Pi,
    words=word_data,
    tags=tag_data,
    word2id=word2id,
    tag2id=tag2id,
)

In [116]:
def decoding(A, B, Pi, N, M, words, word2id, tag2id):
    """Return the most likely tag sequence for ``words`` (Viterbi, log space).

    Args:
        A: 2-D tensor of transition probabilities, indexed [from_tag_id, to_tag_id].
        B: 2-D tensor of emission probabilities, indexed [tag_id, word_id].
        Pi: 1-D tensor of initial tag probabilities.
        N: number of tags.
        M: number of words in the vocabulary (unused; kept for call compatibility).
        words: list of tokens to tag.
        word2id: mapping from word to column index of ``B``. Words that are
            not in the mapping get a uniform emission probability of 1/N.
        tag2id: mapping from tag to tag id.

    Returns:
        List of tags, one per entry of ``words``; an empty list for empty input.
    """
    length = len(words)
    if length == 0:
        # Nothing to tag; the original indexing of words[0] would fail here.
        return []

    # Work in log space so long sentences do not underflow.
    log_A = torch.log(A)
    log_B = torch.log(B)
    log_Pi = torch.log(Pi)
    unknown_emission = torch.log(torch.ones(N) / N)

    def emission(word):
        # Log emission scores of `word` for every tag.
        wordid = word2id.get(word)
        if wordid is None:
            return unknown_emission
        return log_B[:, wordid]

    viterbi = torch.zeros(N, length)
    # Integer dtype: the entries are tag ids used for indexing and dict lookup.
    backpointer = torch.zeros(N, length, dtype=torch.long)

    viterbi[:, 0] = log_Pi + emission(words[0])
    backpointer[:, 0] = -1

    for step in range(1, length):
        # scores[i, j]: best log-probability of being in tag i at step-1 and
        # then moving to tag j.
        scores = viterbi[:, step - 1].unsqueeze(1) + log_A
        max_prob, max_id = torch.max(scores, dim=0)
        viterbi[:, step] = max_prob + emission(words[step])
        backpointer[:, step] = max_id

    # Follow the back-pointers from the best final tag to the first position.
    _, best_pointer = torch.max(viterbi[:, -1], dim=0)
    best_pointer = int(best_pointer)
    best_path = [best_pointer]
    for back in range(length - 1, 0, -1):
        best_pointer = int(backpointer[best_pointer, back])
        best_path.append(best_pointer)
    best_path.reverse()

    id2tag = {id_: tag for tag, id_ in tag2id.items()}
    return [id2tag[id_] for id_ in best_path]
    
# Demo: tag one sentence. str.split() breaks on whitespace only, so
# punctuation stays attached to its token (e.g. "Huang,").
words = 'My girlfriend is Lisa Huang, and she starts to work at Blackrock Co. from 2020 May 20th in the United States'.split()
taglist = decoding(A, B, Pi, N, M, words=words, word2id=word2id, tag2id=tag2id)
print(taglist)

['O', 'O', 'O', 'B-per', 'I-per', 'O', 'O', 'O', 'O', 'O', 'O', 'B-org', 'I-org', 'O', 'B-tim', 'I-tim', 'I-tim', 'O', 'O', 'B-geo', 'I-geo']
