In [1]:
import torch
import torchcrf
import nltk
from torchcrf import CRF
import numpy as np
from sklearn.model_selection import train_test_split
import random

In [2]:
nltk.download('treebank')

[nltk_data] Downloading package treebank to /home/lizirui/nltk_data...
[nltk_data]   Package treebank is already up-to-date!


True

In [3]:
sents = nltk.corpus.treebank.tagged_sents()
#print(sents)
random.seed(66)

# Get the length of the longest sequence and pad the sentence

In [4]:
#train_sents, test_sents = train_test_split(sents, test_size = 0.05, random_state = 1000)
max_len = 0
for sent in sents:
    max_len = max(max_len, len(sent))
pad_token = "<PAD>"
pad_label = "PAD"
padded_sents = []
for i, sent in enumerate(sents):
    curr_len = len(sent)
    padded_sents.append(sent + [(pad_token, pad_label)] * (max_len - curr_len))
#print(padded_sents)
sents = padded_sents

# Turn the word to lower case

In [5]:
for i in range(len(sents)):
    for j in range(len(sent)):
        if sents[i][j][0] != pad_token:
            sents[i][j] = (sents[i][j][0].lower(), sents[i][j][1])
#print(sents[0])
train_data, test_data = train_test_split(sents, shuffle = True, test_size = 0.2, random_state = 17)

# compute number of labels and tags

In [6]:
tag_set = set([tup[1] for sent in sents for tup in sent])
num_tag = len(tag_set)

word_set = set([tup[0] for sent in sents for tup in sent])
num_word = len(word_set)
print(num_tag, num_word)

47 12410


# Count word frequency

In [7]:
def count_word_freq(sents):
    word_tag_pair_freq = dict()
    word_freq = dict()
    for sent in sents:
        for pair in sent:
            if pair not in word_tag_pair_freq.keys():
                word_tag_pair_freq[pair] = 1
            else:
                word_tag_pair_freq[pair] = word_tag_pair_freq[pair] + 1
            word = pair[0]
            if word not in word_freq.keys():
                word_freq[word] = 1
            else:
                word_freq[word] = word_freq[word] + 1
    return word_freq, word_tag_pair_freq

word_freq, word_tag_pair_freq = count_word_freq(sents)

# Build label-idx dict and label-idx reverse dict

In [8]:
def build_label_idx(label_set):
    i = 0
    label2idx = dict()
    idx2label = dict()
    for label in label_set:
        label2idx[label] = i
        idx2label[i] = label
        i += 1
    return label2idx, idx2label

label2idx, idx2label = build_label_idx(tag_set)
print(label2idx, idx2label)

{'WRB': 0, 'JJR': 1, '-LRB-': 2, '-NONE-': 3, 'RBR': 4, 'CD': 5, 'WP$': 6, 'CC': 7, 'DT': 8, 'POS': 9, 'VBN': 10, 'NNS': 11, 'SYM': 12, '-RRB-': 13, 'PRP$': 14, '.': 15, 'RB': 16, 'NNPS': 17, 'EX': 18, 'JJS': 19, 'RP': 20, 'IN': 21, 'VBP': 22, 'LS': 23, '``': 24, 'TO': 25, 'FW': 26, 'MD': 27, '#': 28, 'PRP': 29, ',': 30, 'RBS': 31, 'NNP': 32, 'VBG': 33, 'VBZ': 34, '$': 35, 'UH': 36, ':': 37, 'WP': 38, 'VB': 39, 'PAD': 40, 'WDT': 41, 'PDT': 42, 'VBD': 43, 'JJ': 44, 'NN': 45, "''": 46} {0: 'WRB', 1: 'JJR', 2: '-LRB-', 3: '-NONE-', 4: 'RBR', 5: 'CD', 6: 'WP$', 7: 'CC', 8: 'DT', 9: 'POS', 10: 'VBN', 11: 'NNS', 12: 'SYM', 13: '-RRB-', 14: 'PRP$', 15: '.', 16: 'RB', 17: 'NNPS', 18: 'EX', 19: 'JJS', 20: 'RP', 21: 'IN', 22: 'VBP', 23: 'LS', 24: '``', 25: 'TO', 26: 'FW', 27: 'MD', 28: '#', 29: 'PRP', 30: ',', 31: 'RBS', 32: 'NNP', 33: 'VBG', 34: 'VBZ', 35: '$', 36: 'UH', 37: ':', 38: 'WP', 39: 'VB', 40: 'PAD', 41: 'WDT', 42: 'PDT', 43: 'VBD', 44: 'JJ', 45: 'NN', 46: "''"}


# Build word-idx dict and reverse dict

In [9]:
def build_word_idx(word_set):
    i = 0
    word2idx = dict()
    idx2word = dict()
    for word in word_set:
        word2idx[word] = i
        idx2word[i] = word
        i += 1
    return word2idx, idx2word

word2idx, idx2word = build_label_idx(word_set)
print(word2idx, idx2word)



# Compute the emission probability of every word

In [11]:
def compute_prob(word_freq, word_tag_pair_freq):
    freq_table = np.zeros((num_word, num_tag))
    for tup in word_tag_pair_freq.keys():
        #print(tup)
        r, c = word2idx[tup[0]], label2idx[tup[1]]
        #print(r, c)
        freq_table[r, c] = float(word_tag_pair_freq[tup] / word_freq[tup[0]])
        #print(freq_table[r,c])
    return freq_table

prob_dict = compute_prob(word_freq, word_tag_pair_freq)

# Function that used to build the probability matrix of a sentence

In [12]:
def build_sent_prob(prob_dict, sent):
    sent_prob = []
    sent_tags = np.zeros(max_len, dtype=np.long)
    for i, tup in enumerate(sent):
        sent_prob.append(prob_dict[word2idx[tup[0]], :])
        sent_tags[i] = label2idx[tup[1]]
    sent_lst = np.array(sent_prob)
    return sent_prob, sent_tags
#build_sent_prob(prob_dict, sents[0])

# Compute per-word-error

In [21]:
def per_word_err(predict, label):
    total = 0
    err = 0
    for i in range(len(predict)):
        for j in range(len(predict[i])):
            if label[i][j] != predict[i][j]:
                err += 1
            total += 1
    return err / total

# build probablity and tags matrix for the corpus

In [13]:
train_sent_lst = []
tr_tags_lst = []
test_sent_lst = []
te_tags_lst = []
count = 0
for i, sent in enumerate(train_data):
    tr_tmp1, tr_tmp2 = build_sent_prob(prob_dict, sent)
    """
    if count == 1:
        print(tmp1)
    """
    train_sent_lst.append(tr_tmp1)
    tr_tags_lst.append(tr_tmp2)
    #count += 1
for i, sent in enumerate(test_data):
    te_tmp1, te_tmp2 = build_sent_prob(prob_dict, sent)
    test_sent_lst.append(te_tmp1)
    te_tags_lst.append(te_tmp2)
    

x_train = np.array(train_sent_lst)
y_train = np.array(tr_tags_lst)
x_test = np.array(test_sent_lst)
y_test = np.array(te_tags_lst)

# Generate masks

In [14]:
def gen_mask(sents):
    mask = torch.zeros((len(sents), max_len), dtype=torch.bool)
    for i, sent in enumerate(sents):
        for j, tup in enumerate(sent):
            if tup[0] != pad_token:
                mask[i, j] = 1
    return mask

mask_train = gen_mask(train_data)

# Using cuda for learning

In [15]:
use_cuda = torch.cuda.is_available()
print(use_cuda)
torch.cuda.empty_cache()

#np.save('./prob.npy', sent_all_prob)

if use_cuda:
    ten_x_train, ten_y_train = torch.from_numpy(x_train), torch.from_numpy(y_train)
    #print(mask[:, 0].all())
    ten_y_train = ten_y_train.long().cuda()
    ten_x_train = ten_x_train.cuda()
    mask_train = mask_train.cuda()

True


In [19]:
model = CRF(num_tag, batch_first=True)
model.reset_parameters()
model = model.cuda()

loss = model(ten_x_train, ten_y_train, mask=mask_train)
print("The loss is {}".format(loss.detach().item()))

The loss is -236747.2512728565


In [20]:
mask_test = gen_mask(test_data)
ten_x_test = torch.from_numpy(x_test)
#print(mask[:, 0].all())
ten_x_test = ten_x_test.cuda()
mask_test = mask_test.cuda()

decode_seq = model.decode(ten_x_test, mask = mask_test)

In [23]:
per_word_err(decode_seq, y_test)

0.049856928076046365