In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import numpy as np
from torch.autograd import Variable as Variable
import random
import torch.nn.init as init
import copy
import tqdm

In [2]:
START_TAG = "START"
STOP_TAG = "STOP"


class BiLSTM_CRF(nn.Module):
    """BiLSTM encoder with a linear-chain CRF output layer for sequence tagging.

    ``tag_to_ix`` must contain ``START_TAG`` and ``STOP_TAG``; they act as
    virtual boundary tags.  ``self.transitions[i, j]`` is the score of moving
    *from* tag ``j`` *to* tag ``i``.

    Args:
        tag_to_ix: mapping from tag string to tag index.
        batch_size: initial batch size; ``forward`` and ``neg_log_likelihood``
            overwrite it with the size of each incoming batch.
        vocab_size: number of rows in the embedding table (unused when
            ``weight`` is given).
        embedding_dim: size of each word embedding.
        hidden_dim: total BiLSTM output size, split as ``hidden_dim // 2`` per
            direction (so it should be even).
        weight: optional pretrained embedding matrix, passed to
            ``nn.Embedding.from_pretrained`` (which freezes it by default).
    """

    def __init__(self, tag_to_ix, batch_size, vocab_size, embedding_dim, hidden_dim, weight=None):
        super(BiLSTM_CRF, self).__init__()
        self.tag_to_ix = tag_to_ix
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.tag_size = len(tag_to_ix)

        if weight is not None:
            self.word_embeddings = nn.Embedding.from_pretrained(weight)
        else:
            self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
        # Each direction gets hidden_dim // 2 units; concatenated they give hidden_dim.
        self.lstm = nn.LSTM(self.embedding_dim, self.hidden_dim // 2, num_layers=1,
                            bidirectional=True, batch_first=True)
        self.hidden2tag = nn.Linear(self.hidden_dim, self.tag_size)
        self.hidden = self.init_hidden()

        # transitions[to, from]: effectively forbid entering START and leaving STOP.
        self.transitions = nn.Parameter(torch.randn(self.tag_size, self.tag_size))
        self.transitions.data[self.tag_to_ix[START_TAG], :] = -10000
        self.transitions.data[:, self.tag_to_ix[STOP_TAG]] = -10000

    def init_hidden(self):
        """Return a zero ``(h0, c0)`` pair shaped ``(2, batch_size, hidden_dim // 2)``.

        Zeros (instead of random noise) keep decoding deterministic: the same
        sentence always produces the same tags.  The tensors live on the same
        device as the model parameters.
        """
        device = self.hidden2tag.weight.device
        shape = (2, self.batch_size, self.hidden_dim // 2)
        return (torch.zeros(shape, device=device),
                torch.zeros(shape, device=device))

    def _get_lstm_features(self, sentences):
        """Map a ``(batch, seq_len)`` LongTensor of word ids to emission scores.

        Returns a ``(batch, seq_len, tag_size)`` tensor.  Any padding tokens in
        ``sentences`` are run through the LSTM like ordinary tokens.
        """
        self.hidden = self.init_hidden()
        length = sentences.shape[1]
        embeddings = self.word_embeddings(sentences).view(self.batch_size, length, self.embedding_dim)
        lstm_out, self.hidden = self.lstm(embeddings, self.hidden)
        lstm_out = lstm_out.view(self.batch_size, -1, self.hidden_dim)
        lstm_feats = self.hidden2tag(lstm_out)
        return lstm_feats

    def _score_sentence(self, feats, label):
        """Score one gold tag sequence under the CRF.

        Args:
            feats: ``(seq_len, tag_size)`` emission scores for one sentence.
            label: ``(seq_len,)`` LongTensor of gold tag indices.

        Returns:
            1-element tensor: emission plus transition scores along
            ``START -> label[0] -> ... -> label[-1] -> STOP``.
        """
        score = torch.zeros(1, device=feats.device)
        start = torch.tensor([self.tag_to_ix[START_TAG]], dtype=torch.long, device=label.device)
        label = torch.cat([start, label])
        for index, feat in enumerate(feats):
            emission_score = feat[label[index + 1]]
            transition_score = self.transitions[label[index + 1], label[index]]
            score += emission_score + transition_score
        score += self.transitions[self.tag_to_ix[STOP_TAG], label[-1]]
        return score

    def _forward_alg(self, feats):
        """Compute the log partition function log Z for one sentence.

        Args:
            feats: ``(seq_len, tag_size)`` emission scores.

        Returns:
            0-d tensor: logsumexp of the scores of every tag path
            ``START -> y_1 -> ... -> y_n -> STOP``.
        """
        forward_var = torch.full([self.tag_size], -10000., device=feats.device)
        forward_var[self.tag_to_ix[START_TAG]] = 0.

        for feat in feats:
            # next_tag_var[i, j] = alpha[j] + transitions[i, j] + emit[i]
            next_tag_var = forward_var.unsqueeze(0) + self.transitions + feat.unsqueeze(1)
            forward_var = torch.logsumexp(next_tag_var, dim=1)
        terminal_val = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]
        return torch.logsumexp(terminal_val, dim=0)

    def neg_log_likelihood(self, sentences, tags, lengths):
        """CRF negative log-likelihood summed over a batch.

        Args:
            sentences: ``(batch, seq_len)`` LongTensor of word ids.
            tags: ``(batch, seq_len)`` LongTensor of gold tag ids.
            lengths: per-sentence true lengths; positions past a sentence's
                length are ignored.

        Returns:
            1-element tensor ``sum(log Z - gold_score)``.  Also sets
            ``self.batch_size`` to the batch size of ``sentences``.
        """
        self.batch_size = sentences.size(0)
        featss = self._get_lstm_features(sentences)
        gold_score = torch.zeros(1, device=featss.device)
        forward_score = torch.zeros(1, device=featss.device)
        for feats, tag, length in zip(featss, tags, lengths):
            feats = feats[:length]
            tag = tag[:length]
            gold_score += self._score_sentence(feats, tag)
            forward_score += self._forward_alg(feats)
        return forward_score - gold_score

    def forward(self, sentences, lengths=None):
        """Viterbi-decode a batch of sentences.

        Args:
            sentences: ``(batch, seq_len)`` word ids, as a tensor or nested list.
            lengths: optional per-sentence true lengths (list or 1-D tensor);
                positions past a sentence's length are ignored.  Defaults to
                the full padded length for every sentence.

        Returns:
            ``(scores, paths)``: per sentence, the best-path score (0-d tensor)
            and the best tag-index path (list of ints).
        """
        device = self.hidden2tag.weight.device
        # as_tensor avoids a needless copy (and warning) when given a tensor.
        sentences = torch.as_tensor(sentences, dtype=torch.long, device=device)
        # `not lengths` is ambiguous for a multi-element tensor; test explicitly.
        if lengths is None or len(lengths) == 0:
            lengths = [sentence.size(-1) for sentence in sentences]
        self.batch_size = sentences.size(0)
        logits = self._get_lstm_features(sentences)
        scores = []
        paths = []
        for logit, leng in zip(logits, lengths):
            logit = logit[:leng]
            score, path = self._viterbi_decode(logit)
            scores.append(score)
            paths.append(path)
        return scores, paths

    def _viterbi_decode(self, feats):
        """Find the highest-scoring tag path for one sentence.

        Args:
            feats: ``(seq_len, tag_size)`` emission scores.

        Returns:
            ``(path_score, best_path)``: 0-d tensor score and a list of
            ``seq_len`` tag indices (START/STOP boundaries excluded).
        """
        backpointers = []
        forward_var = torch.full((1, self.tag_size), -10000., device=feats.device)
        forward_var[0][self.tag_to_ix[START_TAG]] = 0

        for feat in feats:
            # (1, T) + (T, T): next_tag_vars[i, j] = best score ending in j + transitions[i, j]
            next_tag_vars = forward_var + self.transitions
            viterbivars_t, bptrs_t = torch.max(next_tag_vars, dim=1)
            forward_var = (viterbivars_t + feat).unsqueeze(0)
            backpointers.append(bptrs_t.tolist())

        terminal_val = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]
        best_tag_id = torch.argmax(terminal_val).item()
        path_score = terminal_val[0][best_tag_id]
        best_path = [best_tag_id]
        for bptrs_t in reversed(backpointers):
            best_tag_id = bptrs_t[best_tag_id]
            best_path.append(best_tag_id)
        # The last backpointer followed leads to START; drop it.
        start = best_path.pop()
        assert start == self.tag_to_ix[START_TAG]
        best_path.reverse()
        return path_score, best_path

In [3]:
def prepare_sequence(seq, to_ix):
    """Convert a sequence of tokens into a 1-D LongTensor of their indices.

    Every token must be a key of ``to_ix``; a missing token raises KeyError.
    """
    indices = []
    for token in seq:
        indices.append(to_ix[token])
    return torch.tensor(indices, dtype=torch.long)

In [4]:
# Toy sizes. NOTE(review): the model cell hard-codes embedding_dim=50 and
# hidden_dim=64 and does not read these two constants.
EMBEDDING_DIM = 5
HIDDEN_DIM = 4

# Made-up training data: (words, B/I/O tags) pairs.
training_data = [
    ("the wall street journal reported today that apple corporation made money".split(),
     "B I I I O O O B I O O".split()),
    ("georgia tech is a university in georgia".split(),
     "B I O O O O B".split()),
]

# Vocabulary: word -> index in first-seen order, then '<UNK>' (which the
# padding helper also uses as its default pad word).
word_to_ix = {word: ix for ix, word in enumerate(
    dict.fromkeys(word for words, _ in training_data for word in words))}
word_to_ix['<UNK>'] = len(word_to_ix)

tag_to_ix = {"B": 0, "I": 1, "O": 2, START_TAG: 3, STOP_TAG: 4}

In [5]:
def padding(seq, tag, maxlen, pad='<UNK>'):
    """Pad or truncate one (words, tags) example to exactly ``maxlen`` tokens.

    Short examples are right-padded with ``pad`` words and ``STOP_TAG`` tags;
    long ones keep only their first ``maxlen`` tokens.  Inputs are not mutated.
    Words and tags are looked up in the global ``word_to_ix`` / ``tag_to_ix``.

    Args:
        seq: list of word strings.
        tag: list of tag strings (same length as ``seq``).
        maxlen: target length.
        pad: word used at padding positions; must be in ``word_to_ix``.

    Returns:
        ``(word_ids, tag_ids, true_len)``: two LongTensors of length ``maxlen``
        and the number of real (non-padding) tokens, never more than ``maxlen``.
    """
    # Count only tokens that survive truncation, so slicing with true_len
    # never claims more real tokens than the returned tensors hold.
    true_len = min(len(seq), maxlen)
    n_pad = maxlen - true_len
    p_seq = list(seq[:maxlen]) + [pad] * n_pad
    p_tag = list(tag[:maxlen]) + [STOP_TAG] * n_pad
    p_seq = prepare_sequence(p_seq, word_to_ix)
    p_tag = torch.tensor([tag_to_ix[t] for t in p_tag], dtype=torch.long)
    return p_seq, p_tag, true_len

In [6]:
# Pad every example to the longest sentence's length (11 words for the toy
# data), derived from the data instead of hard-coded.
MAX_LEN = max((len(words) for words, _ in training_data), default=0)
training_data_pad = [padding(words, tags, MAX_LEN) for words, tags in training_data]

In [7]:
# Mini-batches of 2 examples in dataset order (shuffle=False is the default).
data_iter = torch.utils.data.DataLoader(dataset=training_data_pad, batch_size=2, shuffle=False)

In [8]:
# Seed torch's RNG so parameter initialisation, and therefore the printed
# predictions and the training run below, are reproducible on Run All.
torch.manual_seed(1)
model = BiLSTM_CRF(
    tag_to_ix=tag_to_ix,
    batch_size=2,
    vocab_size=len(word_to_ix),
    embedding_dim=50,
    hidden_dim=64,
)
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-4)

In [9]:
# Decode the first training sentence with the untrained model (baseline).
with torch.no_grad():
    example_ids, _, example_len = training_data_pad[0]
    prediction = model(example_ids.unsqueeze(0), [example_len])
    print(prediction)

([tensor(12.1246)], [[1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]])




In [10]:
# Train for 20 epochs, one SGD step per mini-batch on the summed CRF NLL.
for _epoch in range(20):
    for batch_sentences, batch_tags, batch_lengths in data_iter:
        model.zero_grad()
        batch_loss = model.neg_log_likelihood(batch_sentences, batch_tags, batch_lengths)
        batch_loss.backward()
        optimizer.step()

In [11]:
# Decode the same first sentence again after training, for comparison.
with torch.no_grad():
    example_ids, _, example_len = training_data_pad[0]
    prediction = model(example_ids.unsqueeze(0), [example_len])
    print(prediction)

([tensor(7.9958)], [[0, 1, 1, 1, 2, 2, 2, 0, 1, 2, 2]])


