In [None]:
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim

# Seed PyTorch's global random number generator so that the randomly drawn
# values used further down (e.g. torch.randn) repeat from run to run.
torch.manual_seed(1)

<torch._C.Generator at 0x7f3c92e40090>

In [None]:
def argmax(vec):
  """Return the column index of the largest entry of a one-row 2-D tensor.

  The reduction runs along dimension 1 and the result is converted with
  ``.item()``, so ``vec`` must have exactly one row; the return value is a
  plain Python int.
  """
  best_column = vec.argmax(dim=1)
  return best_column.item()

def prepare_sequence(seq, to_idx):
  """Encode a sequence of symbols as a 1-D ``torch.long`` index tensor.

  Args:
    seq: iterable of keys (e.g. words or tag strings).
    to_idx: mapping from each key to its integer index.

  Returns:
    A ``torch.long`` tensor holding ``to_idx[item]`` for every item, in order.

  Raises:
    KeyError: if an item of ``seq`` is missing from ``to_idx``.
  """
  return torch.tensor([to_idx[item] for item in seq], dtype=torch.long)

def log_sum_exp(vec):
  """Numerically stable ``log(sum(exp(vec)))`` for a 1 x N row of scores.

  Delegates to ``torch.logsumexp``, which subtracts the maximum internally.
  Compared with a hand-rolled max-shift this needs no Python-side ``.item()``
  round-trip and yields ``-inf`` (rather than NaN) when every score is
  ``-inf``.

  Args:
    vec: 2-D tensor of shape ``(1, N)``.

  Returns:
    A 0-dim tensor holding the log-sum-exp of the row.
  """
  # Reducing dim 1 leaves shape (1,); squeeze it so callers get a scalar.
  return torch.logsumexp(vec, dim=1).squeeze(0)

In [None]:
# Sentinel tag names marking the beginning and the end of a tag sequence.
START_TAG, STOP_TAG = "<START>", "<STOP>"

In [None]:
class BiLSTM_CRF(nn.Module):
  """Bidirectional LSTM tagger with a linear-chain CRF output layer.

  The LSTM maps a sentence to per-token emission scores (one score per tag)
  and the CRF adds learned tag-to-tag transition scores. Training minimises
  ``neg_log_likelihood``; calling the module runs Viterbi decoding.

  The model processes one unbatched sentence (a 1-D index tensor) at a time.
  """

  # Score used to make a transition or a start state effectively impossible.
  IMPOSSIBLE_SCORE = -10000.

  def __init__(self, vocab_size, tag_to_idx, embed_dim, hidden_dim):
    """Create the embedding, BiLSTM, projection and transition parameters.

    Args:
      vocab_size: number of rows in the word-embedding table.
      tag_to_idx: mapping from tag string to integer index; it must contain
        ``START_TAG`` and ``STOP_TAG``.
      embed_dim: size of each word embedding.
      hidden_dim: size of the concatenated BiLSTM output. Each direction gets
        ``hidden_dim // 2`` units, so the value must be even.

    Raises:
      ValueError: if ``hidden_dim`` is odd.
      KeyError: if ``tag_to_idx`` lacks ``START_TAG`` or ``STOP_TAG``.
    """
    super(BiLSTM_CRF, self).__init__()
    if hidden_dim % 2 != 0:
      # An odd size would only surface later as an opaque view() error in
      # _get_lstm_features, because 2 * (hidden_dim // 2) != hidden_dim.
      raise ValueError("hidden_dim must be even, got {}".format(hidden_dim))

    self.hidden_dim = hidden_dim
    self.vocab_size = vocab_size
    self.tag_to_idx = tag_to_idx
    self.tagset_size = len(tag_to_idx)

    self.word_embed = nn.Embedding(vocab_size, embed_dim)
    self.lstm = nn.LSTM(embed_dim, hidden_dim // 2, num_layers=1, bidirectional=True)
    self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)

    # Matrix of transition parameters.  Entry i,j is the score of
    # transitioning `to` i `from` j.
    self.transitions = nn.Parameter(torch.randn(self.tagset_size, self.tagset_size))

    # These two statements enforce the constraint that we never transfer
    # to the start tag and we never transfer from the stop tag
    self.transitions.data[tag_to_idx[START_TAG], :] = self.IMPOSSIBLE_SCORE
    self.transitions.data[:, tag_to_idx[STOP_TAG]] = self.IMPOSSIBLE_SCORE

    self.hidden = self.init_hidden()

  def init_hidden(self):
    """Return a fresh ``(h_0, c_0)`` pair, each of shape ``(2, 1, hidden_dim // 2)``.

    NOTE(review): the state is drawn from a standard normal on every call, so
    decoding the same sentence twice is not guaranteed to give identical
    scores unless the RNG is reseeded. Zeros are the more common choice;
    confirm that the random start is intended.
    """
    return torch.randn(2, 1, self.hidden_dim // 2), torch.randn(2, 1, self.hidden_dim // 2)

  def _forward_alg(self, feats):
    """Compute the log partition function over all tag paths.

    Args:
      feats: ``(seq_len, tagset_size)`` emission scores.

    Returns:
      0-dim tensor: log of the summed exponentiated scores of every path.
    """
    # Scratch tensors inherit dtype/device from feats so the model also works
    # after .to(device) or .double().
    init_alphas = feats.new_full((1, self.tagset_size), self.IMPOSSIBLE_SCORE)
    init_alphas[0, self.tag_to_idx[START_TAG]] = 0.

    forward_var = init_alphas
    for feat in feats:
      # Broadcasting (1, T) + (T, T): scores[nxt, prev] is the log-score of
      # reaching `prev` so far plus the transition prev -> nxt.
      scores = forward_var + self.transitions
      # Sum out the previous tag, then add the emission of the next tag.
      forward_var = (torch.logsumexp(scores, dim=1) + feat).view(1, -1)

    terminal_var = forward_var + self.transitions[self.tag_to_idx[STOP_TAG]]
    return torch.logsumexp(terminal_var, dim=1).squeeze(0)

  def _get_lstm_features(self, sentence):
    """Run the embedding, BiLSTM and projection to get emission scores.

    Args:
      sentence: 1-D ``torch.long`` tensor of word indices.

    Returns:
      ``(len(sentence), tagset_size)`` tensor of emission scores.
    """
    embeds = self.word_embed(sentence).view(len(sentence), 1, -1)
    # Match the initial state to the dtype/device of the LSTM input.
    self.hidden = tuple(state.to(embeds) for state in self.init_hidden())
    lstm_out, self.hidden = self.lstm(embeds, self.hidden)
    lstm_out = lstm_out.view(len(sentence), self.hidden_dim)
    return self.hidden2tag(lstm_out)

  def _score_sentence(self, feats, tags):
    """Score one given tag sequence (emissions plus transitions).

    Args:
      feats: ``(seq_len, tagset_size)`` emission scores.
      tags: 1-D integer tensor with the tag index of every token.

    Returns:
      Tensor of shape ``(1,)`` with the unnormalised path score.
    """
    score = feats.new_zeros(1)
    start = torch.tensor([self.tag_to_idx[START_TAG]], dtype=torch.long, device=tags.device)
    # Prepend START so that tags[i] is the predecessor of tags[i + 1].
    tags = torch.cat([start, tags])
    for i, feat in enumerate(feats):
      score = score + self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]]
    score = score + self.transitions[self.tag_to_idx[STOP_TAG], tags[-1]]
    return score

  def _viterbi_decode(self, feats):
    """Find the highest-scoring tag path.

    Args:
      feats: ``(seq_len, tagset_size)`` emission scores.

    Returns:
      ``(path_score, best_path)``: a 0-dim tensor with the score of the best
      path and that path as a list of Python ints, one per token.
    """
    backpointers = []

    init_vvars = feats.new_full((1, self.tagset_size), self.IMPOSSIBLE_SCORE)
    init_vvars[0, self.tag_to_idx[START_TAG]] = 0.

    forward_var = init_vvars
    for feat in feats:
      # scores[nxt, prev]: best score ending in `prev` plus transition to `nxt`.
      scores = forward_var + self.transitions
      best_scores, best_prev = scores.max(dim=1)
      backpointers.append(best_prev.tolist())
      # The emission does not depend on the previous tag, so add it afterwards.
      forward_var = (best_scores + feat).view(1, -1)

    terminal_var = forward_var + self.transitions[self.tag_to_idx[STOP_TAG]]
    best_tag_id = terminal_var.argmax(dim=1).item()
    path_score = terminal_var[0][best_tag_id]

    # Follow the back pointers from the last token to the first.
    best_path = [best_tag_id]
    for bptrs_t in reversed(backpointers):
      best_tag_id = bptrs_t[best_tag_id]
      best_path.append(best_tag_id)
    # The walk ends on the virtual predecessor of the first token.
    start = best_path.pop()
    assert start == self.tag_to_idx[START_TAG]
    best_path.reverse()
    return path_score, best_path

  def neg_log_likelihood(self, sentence, tags):
    """Return the CRF training loss for one sentence.

    Args:
      sentence: 1-D ``torch.long`` tensor of word indices.
      tags: 1-D integer tensor of gold tag indices, one per word.

    Returns:
      Tensor of shape ``(1,)``: log partition function minus gold-path score.
    """
    feats = self._get_lstm_features(sentence)
    forward_score = self._forward_alg(feats)
    gold_score = self._score_sentence(feats, tags)
    return forward_score - gold_score

  def forward(self, sentence):
    """Decode a sentence; returns ``(score, tag_seq)`` from Viterbi decoding."""
    lstm_feats = self._get_lstm_features(sentence)
    score, tag_seq = self._viterbi_decode(lstm_feats)
    return score, tag_seq

In [None]:
# Model sizes: word-embedding width and total (both directions) LSTM width.
EMBED_DIM = 5
HIDDEN_DIM = 4

# Toy corpus: each item pairs a token list with a same-length list of
# B/I/O tag strings.
training_data = [(
    "the wall street journal reported today that apple corporation made money".split(),
    "B I I I O O O B I O O".split()
), (
    "georgia tech is a university in georgia".split(),
    "B I O O O O B".split()
)]

# Iteration order of a set of strings can change between interpreter runs
# (string hashing is randomised), which would hand every word a different
# embedding row on each kernel restart. Number the words in sorted order so
# the mapping is reproducible.
vocab = set(word for data in training_data for word in data[0])
word_to_idx = {word: idx for idx, word in enumerate(sorted(vocab))}
# Tag inventory: the three B/I/O labels plus the START/STOP sentinels.
tag_to_idx = {"B": 0, "I": 1, "O": 2, START_TAG: 3, STOP_TAG: 4}
# Inverse lookup as a list: position i holds the tag whose index is i.
idx_to_tag = sorted(tag_to_idx, key=tag_to_idx.get)

In [None]:
import pandas as pd
def print_beauty_tag(sentence, tag_seq):
  """Print a two-column table pairing every word with its tag name.

  Args:
    sentence: sequence of words.
    tag_seq: sequence of integer tag indices, one per word; each index is
      looked up in the module-level ``idx_to_tag``.
  """
  tag_names = [idx_to_tag[tag] for tag in tag_seq]
  table = pd.DataFrame({'word': sentence, 'tag': tag_names})
  print(table)

In [None]:
# Build the tagger and a plain SGD optimiser with a small L2 penalty.
model = BiLSTM_CRF(
    vocab_size=len(word_to_idx),
    tag_to_idx=tag_to_idx,
    embed_dim=EMBED_DIM,
    hidden_dim=HIDDEN_DIM,
)
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-4)

# Baseline: decode the first training sentence with the untrained model.
# No gradients are needed for this check.
sample_words = training_data[0][0]
with torch.no_grad():
    precheck_sent = prepare_sequence(sample_words, word_to_idx)
    _, tag_seq = model(precheck_sent)
    print_beauty_tag(sample_words, tag_seq)

           word tag
0           the   B
1          wall   O
2        street   O
3       journal   O
4      reported   O
5         today   O
6          that   O
7         apple   O
8   corporation   O
9          made   O
10        money   O


In [None]:
# Train for 300 passes over the corpus, taking one SGD step per sentence.
for epoch in range(300):
  for sentence, tags in training_data:
    # Encode the words and their gold tags as index tensors.
    sentence_in = prepare_sequence(sentence, word_to_idx)
    targets = prepare_sequence(tags, tag_to_idx)

    # Gradients accumulate across backward() calls, so clear them first.
    model.zero_grad()

    # Loss is the CRF negative log-likelihood of the gold tag sequence.
    loss = model.neg_log_likelihood(sentence_in, targets)
    loss.backward()
    optimizer.step()

In [None]:
# Decode the first training sentence again, now with the trained model.
sample_words = training_data[0][0]
with torch.no_grad():
    precheck_sent = prepare_sequence(sample_words, word_to_idx)
    _, tag_seq = model(precheck_sent)
    print_beauty_tag(sample_words, tag_seq)

           word tag
0           the   B
1          wall   I
2        street   I
3       journal   I
4      reported   O
5         today   O
6          that   O
7         apple   B
8   corporation   I
9          made   O
10        money   O
