In [1]:
import os,sys,inspect
# Resolve the directory holding this code, then put its parent at the front
# of sys.path so modules located one level up can be imported.
_this_file = inspect.getfile(inspect.currentframe())
currentdir = os.path.dirname(os.path.abspath(_this_file))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)

In [2]:
import torch

from wiser.data.dataset_readers import LaptopsDatasetReader
from wiser.rules import TaggingRule, LinkingRule, DictionaryMatcher
from wiser.generative import get_label_to_ix, get_rules
from labelmodels import *
from wiser.generative import train_generative_model
from labelmodels import LearningConfig
from wiser.generative import evaluate_generative_model
from wiser.data import save_label_distribution
from wiser.rules import ElmoLinkingRule
from wiser.eval import *
from collections import Counter

from nltk.tokenize import word_tokenize
from nltk.tokenize import sent_tokenize
from tokenizations import get_alignments, get_original_spans
from typing import List, Optional, Tuple

from transformers import *
from xml.etree import ElementTree

## Load data

In [3]:
# Data directory, relative to this notebook.
root = "../data/"
reader = LaptopsDatasetReader()

# Read the train and test XML files with the same reader (train first).
train_data, test_data = (
    reader.read(root + xml_name)
    for xml_name in ('LaptopReview/Laptop_Train_v2.xml',
                     'LaptopReview/Laptops_Test_Data_phaseB.xml')
)

# One combined list, train documents followed by test documents, so every
# rule below is applied to both portions in a single pass.
laptops_docs = train_data + test_data

HBox(children=(HTML(value='reading instances'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width…

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3045.0), HTML(value='')))







HBox(children=(HTML(value='reading instances'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width…

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=800.0), HTML(value='')))







## Tagging functions

In [4]:
# Core dictionary: every line is split on whitespace, the first column is
# dropped and the remaining tokens form one term (stored as a tuple).
dict_core = set()
with open(root + 'AutoNER_dicts/LaptopReview/dict_core.txt') as f:
    for line in f:
        term = tuple(line.split()[1:])
        # A blank line (or a line holding only the first column) would give
        # an empty term, which can never match anything -- skip it.
        if term:
            dict_core.add(term)


# Full dictionary: only multi-token lines are kept (whole line = one term).
dict_full = set()

with open(root + 'AutoNER_dicts/LaptopReview/dict_full.txt') as f:
    for line in f:
        tokens = line.split()
        if len(tokens) > 1:
            dict_full.add(tuple(tokens))

# Case-insensitive dictionary matchers voting "I" on matched tokens.
lf = DictionaryMatcher("CoreDictionary", dict_core, uncased=True, i_label="I")
lf.apply(laptops_docs)

# A few hand-picked single-token terms, matched the same way.
other_terms = [['BIOS'], ['color'], ['cord'], ['hinge'], ['hinges'],
               ['port'], ['speaker']]
lf = DictionaryMatcher("OtherTerms", other_terms, uncased=True, i_label="I")
lf.apply(laptops_docs)


class ReplaceThe(TaggingRule):
    """Votes on the pattern "replace the <NOUN>".

    The two trigger words receive 'O' and the noun receives 'I'; every other
    token keeps the default 'ABS' vote (presumably "abstain").
    """

    def apply_instance(self, instance):
        words = [token.text.lower() for token in instance['tokens']]
        labels = ['ABS'] * len(words)

        # Slide a three-token window over the sentence.
        for pos in range(len(words) - 2):
            if words[pos] != 'replace' or words[pos + 1] != 'the':
                continue
            if instance['tokens'][pos + 2].pos_ == "NOUN":
                labels[pos:pos + 3] = ['O', 'O', 'I']

        return labels


lf = ReplaceThe()
lf.apply(laptops_docs)


class iStuff(TaggingRule):
    """Votes 'I' on tokens that start with a lowercase 'i' followed by an
    uppercase character (e.g. "iPhoto"); all other tokens get 'ABS'."""

    def apply_instance(self, instance):
        texts = [token.text for token in instance['tokens']]
        # Slices never raise on short strings: tok[1:2] is '' for one-char
        # tokens, and ''.isupper() is False.
        return ['I' if tok[:1] == 'i' and tok[1:2].isupper() else 'ABS'
                for tok in texts]


lf = iStuff()
lf.apply(laptops_docs)


class Feelings(TaggingRule):
    """Votes on the pattern "<feeling word> the <NOUN>".

    The feeling word and "the" receive 'O', the noun receives 'I'; all other
    tokens keep the default 'ABS' vote.
    """

    feeling_words = {"like", "liked", "love", "dislike", "hate"}

    def apply_instance(self, instance):
        words = [token.text.lower() for token in instance['tokens']]
        labels = ['ABS'] * len(words)

        for pos in range(len(words) - 2):
            is_trigger = (words[pos] in self.feeling_words
                          and words[pos + 1] == 'the')
            if is_trigger and instance['tokens'][pos + 2].pos_ == "NOUN":
                labels[pos:pos + 3] = ['O', 'O', 'I']

        return labels


lf = Feelings()
lf.apply(laptops_docs)


class ProblemWithThe(TaggingRule):
    """Votes on the pattern "problem with the <NOUN>".

    The three trigger words receive 'O' and the noun receives 'I'; all other
    tokens keep the default 'ABS' vote.
    """

    def apply_instance(self, instance):
        words = [token.text.lower() for token in instance['tokens']]
        labels = ['ABS'] * len(words)
        trigger = ['problem', 'with', 'the']

        # Slide a four-token window: three trigger words plus the candidate.
        for pos in range(len(words) - 3):
            if words[pos:pos + 3] != trigger:
                continue
            if instance['tokens'][pos + 3].pos_ == "NOUN":
                labels[pos:pos + 4] = ['O', 'O', 'O', 'I']

        return labels


lf = ProblemWithThe()
lf.apply(laptops_docs)


class External(TaggingRule):
    """Votes 'I' on the word "external" and on the token right after it.

    An "external" in the very last position is not considered, because there
    is no following token to pair it with.
    """

    def apply_instance(self, instance):
        texts = [token.text for token in instance['tokens']]
        labels = ['ABS'] * len(texts)

        for pos, word in enumerate(texts[:-1]):
            if word.lower() == 'external':
                labels[pos:pos + 2] = ['I', 'I']

        return labels


lf = External()
lf.apply(laptops_docs)


# Lemmas that the StopWords rule votes 'O' on.
stop_words = set(
    "a and as be but do even "
    "for from "
    "had has have i in is its just "
    "my no not of on or "
    "that the these this those to very "
    "what which who with".split()
)


class StopWords(TaggingRule):
    """Votes 'O' on every token whose lemma is in `stop_words`; all other
    tokens get 'ABS'."""

    def apply_instance(self, instance):
        return ['O' if token.lemma_ in stop_words else 'ABS'
                for token in instance['tokens']]


lf = StopWords()
lf.apply(laptops_docs)


class Punctuation(TaggingRule):
    """Votes 'O' on tokens whose part-of-speech tag is PUNCT; all other
    tokens get 'ABS'."""

    pos = {"PUNCT"}

    def apply_instance(self, instance):
        return ['O' if token.pos_ in self.pos else 'ABS'
                for token in instance['tokens']]


lf = Punctuation()
lf.apply(laptops_docs)


class Pronouns(TaggingRule):
    """Votes 'O' on tokens whose part-of-speech tag is PRON; all other
    tokens get 'ABS'."""

    pos = {"PRON"}

    def apply_instance(self, instance):
        return ['O' if token.pos_ in self.pos else 'ABS'
                for token in instance['tokens']]


lf = Pronouns()
lf.apply(laptops_docs)


class NotFeatures(TaggingRule):
    """Votes 'O' on tokens whose lemma is one of `keywords` (words naming the
    device itself); all other tokens get 'ABS'."""

    keywords = {"laptop", "computer", "pc"}

    def apply_instance(self, instance):
        return ['O' if token.lemma_ in self.keywords else 'ABS'
                for token in instance['tokens']]


lf = NotFeatures()
lf.apply(laptops_docs)


class Adv(TaggingRule):
    """Votes 'O' on tokens whose part-of-speech tag is ADV; all other tokens
    get 'ABS'."""

    pos = {"ADV"}

    def apply_instance(self, instance):
        return ['O' if token.pos_ in self.pos else 'ABS'
                for token in instance['tokens']]


lf = Adv()
lf.apply(laptops_docs)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3845.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3845.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3845.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3845.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3845.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3845.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3845.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3845.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3845.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3845.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3845.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3845.0), HTML(value='')))




## Linking functions

In [5]:
class CompoundPhrase(LinkingRule):
    """Sets a link vote (1) on every token whose predecessor has the
    dependency label "compound"; all other positions stay 0."""

    def apply_instance(self, instance):
        tokens = instance['tokens']
        links = [0] * len(tokens)

        # Look at each token that has a successor; the vote is written on
        # the successor's position.
        for position in range(len(tokens) - 1):
            if tokens[position].dep_ == "compound":
                links[position + 1] = 1

        return links


lf = CompoundPhrase()
lf.apply(laptops_docs)


# Apply the ELMo-based linking rule to every document. The positional
# argument 0.8 is presumably a similarity threshold -- TODO confirm against
# wiser.rules.ElmoLinkingRule.
lf = ElmoLinkingRule(.8)
lf.apply(laptops_docs)


class ExtractedPhrase(LinkingRule):
    """Links the tokens of dictionary phrases found in a sentence.

    For every matched phrase, each token after the first receives a link
    vote (1); all other positions stay 0. Matching is case-insensitive.
    """

    def __init__(self, terms):
        # Maps a lowercased first token to all lowercased terms starting
        # with it.
        self.term_dict = {}

        for term in terms:
            lowered = [token.lower() for token in term]
            self.term_dict.setdefault(lowered[0], []).append(lowered)

        # Longest terms first, so the longest possible phrase wins a match.
        for group in self.term_dict.values():
            group.sort(key=len, reverse=True)

    def apply_instance(self, instance):
        tokens = [token.text.lower() for token in instance['tokens']]
        links = [0] * len(tokens)

        i = 0
        while i < len(tokens):
            for candidate in self.term_dict.get(tokens[i], ()):
                stop = i + len(candidate)
                # The candidate must fit in the sentence and match token by
                # token.
                if stop <= len(tokens) and tokens[i:stop] == candidate:
                    links[i + 1:stop] = [1] * (len(candidate) - 1)
                    # Jump to the last matched token; the increment below
                    # then moves past the whole phrase.
                    i = stop - 1
                    break
            i += 1

        return links


lf = ExtractedPhrase(dict_full)
lf.apply(laptops_docs)


class ConsecutiveCapitals(LinkingRule):
    """Links a capitalized token to a capitalized predecessor.

    A token gets a link vote (1) when it starts with an uppercase character,
    contains at least one lowercase character (i.e. is not all capitals) and
    the previous token also starts with an uppercase character.
    """

    def apply_instance(self, instance):
        tokens = instance['tokens']
        links = [0] * len(tokens)

        # Start at index 2: the first pair is skipped because the first
        # token is almost always capitalized.
        for idx in range(2, len(tokens)):
            text = tokens[idx].text
            has_lowercase = any(char.islower() for char in text)

            if (has_lowercase
                    and text[0].isupper()
                    and tokens[idx - 1].text[0].isupper()):
                links[idx] = 1

        return links


lf = ConsecutiveCapitals()
lf.apply(laptops_docs)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3845.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3845.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3845.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3845.0), HTML(value='')))




## Define Functions

In [6]:
def txt_to_token_span(tokens: List[str],
                      text: str,
                      txt_spans: List[tuple]):
    """
    Map character-level spans onto token-level spans.

    :param tokens: tokens of `text`
    :param text: the text the character offsets refer to
    :param txt_spans: span tuples whose first two items are (start, end)
        character offsets; any further items are ignored
    :return: a list of (start_token, end_token) tuples, end exclusive
    :raises AssertionError: if a span start or end falls in no token
    """
    token_indices = get_original_spans(tokens, text)
    tgt_spans = []
    for txt_span in txt_spans:
        char_start, char_end = txt_span[0], txt_span[1]
        start = end = None
        for tok_idx, (tok_start, tok_end) in enumerate(token_indices):
            # The span starts inside this token.
            if tok_start <= char_start < tok_end:
                start = tok_idx
            # The span ends inside, or exactly at the edge of, this token.
            if tok_start <= char_end <= tok_end:
                end = tok_idx + 1
            if start is not None and end is not None:
                break
        assert not (start is None or end is None), ValueError("input spans out of scope")
        tgt_spans.append((start, end))
    return tgt_spans

In [7]:
def respan(src_tokens: List[str],
           tgt_tokens: List[str],
           src_span: List[tuple]):
    """
    Transfer spans defined over the source tokens to the target tokens.

    :param src_tokens: source tokens
    :param tgt_tokens: target tokens
    :param src_span: an iterable of span tuples. The first element in the
        tuple should be the start index and the second the (exclusive) end
        index, both counted in source tokens
    :return: a list of (start, end) tuples counted in target tokens, end
        exclusive
    """
    s2t, _ = get_alignments(src_tokens, tgt_tokens)
    tgt_spans = list()
    for spans in src_span:
        start = s2t[spans[0]][0]
        # Index of the last source token covered by the span, clamped to the
        # last source token. A span reaching the end of the sentence must be
        # handled exactly like any other span: its exclusive end is one past
        # the last target token aligned with its last source token.
        last_src = min(spans[1], len(s2t)) - 1
        end = s2t[last_src][-1] + 1
        # Never emit an empty span.
        if end == start:
            end += 1
        tgt_spans.append((start, end))

    return tgt_spans

In [8]:
def label_to_span(labels: List[str],
                  scheme: Optional[str] = 'BIO') -> dict:
    """
    Convert a label sequence to labeled spans.

    :param labels: a list of labels such as 'O', 'ABS', 'B-X', 'I-X'
        (plus 'L-X' / 'U-X' for BILOU)
    :param scheme: labeling scheme, in ['BIO', 'BILOU'].
    :return: a dict mapping (start_idx, end_idx) to the span label, where
        end_idx is exclusive and the label is the text after the two-character
        prefix (e.g. 'X' for 'B-X')
    :raises AssertionError: if `scheme` is not one of the supported schemes
    """
    assert scheme in ['BIO', 'BILOU'], ValueError("unknown labeling scheme")

    labeled_spans = dict()
    i = 0
    while i < len(labels):
        if labels[i] == 'O' or labels[i] == 'ABS':
            i += 1
            continue

        # Slicing keeps an empty label from raising; it is simply skipped.
        prefix = labels[i][:1]

        if scheme == 'BIO':
            if prefix == 'B':
                start = i
                lb = labels[i][2:]
                i += 1
                # Consume the continuation labels of this span.
                while i < len(labels) and labels[i][:1] == 'I':
                    i += 1
                labeled_spans[(start, i)] = lb
            else:
                # A dangling 'I' or any unknown prefix: skip it. Always
                # advancing here guarantees the loop terminates.
                i += 1
        else:  # 'BILOU'
            if prefix == 'U':
                labeled_spans[(i, i + 1)] = labels[i][2:]
                i += 1
            elif prefix == 'B':
                start = i
                lb = labels[i][2:]
                i += 1
                # Scan forward to the closing 'L' label.
                while i < len(labels) and labels[i][:1] != 'L':
                    i += 1
                # The 'L' token belongs to the span, so the exclusive end is
                # one past it; with no 'L' the span runs to the sequence end.
                end = min(i + 1, len(labels))
                labeled_spans[(start, end)] = lb
                i = end
            else:
                i += 1

    return labeled_spans

In [9]:
def build_bert_emb(sents: List[str],
                   tokenizer,
                   model,
                   device: str):
    """
    Build one embedding matrix per sentence from the model's first output.

    :param sents: sentences; each one is joined with spaces and also passed
        to the aligner, so it is presumably a list of token strings (that is
        what the call site in this notebook passes) -- TODO confirm
    :param tokenizer: object providing `tokenize` and `encode`
    :param model: callable whose first output is indexed as hidden states
    :param device: device on which the input ids are created
    :return: a list with one tensor per sentence; row 0 is the hidden state
        at position 0 and every following row is the mean of the hidden
        states of the word pieces aligned with one original token
    """
    bert_embs = []
    for sent in sents:
        joint_sent = ' '.join(sent)
        bert_tokens = tokenizer.tokenize(joint_sent)
        token_ids = tokenizer.encode(joint_sent, add_special_tokens=True)
        input_ids = torch.tensor([token_ids], device=device)

        # calculate BERT last layer embeddings
        with torch.no_grad():
            last_hidden_states = model(input_ids)[0].squeeze(0).to('cpu')
        # Drop the first and last positions (the special tokens added by
        # `encode`) so the rows line up with `bert_tokens`.
        trunc_hidden_states = last_hidden_states[1:-1, :]

        ori2bert, _ = get_alignments(sent, bert_tokens)
        word_embs = [trunc_hidden_states[piece_ids, :].mean(dim=0)
                     for piece_ids in ori2bert]

        # TODO: using the embedding of [CLS] may not be the best idea
        # It does not matter since that embedding is not used in the training
        bert_embs.append(torch.stack([last_hidden_states[0, :]] + word_embs))
    return bert_embs

## Construct my dataset

In [10]:
# Entity type names used when building tagging spans and linking spans.
LABEL, LINK = 'TERM', 'LINK'

In [11]:
root = "../data/"
train_path = root + 'LaptopReview/Laptop_Train_v2.xml'
test_path = root + 'LaptopReview/Laptops_Test_Data_phaseB.xml'

# Each parsed tree gets its own name so that `root` keeps holding the data
# directory string. Rebinding `root` to an XML element would break any cell
# that builds a path with `root + '...'` when it is re-run.
train_xml_root = ElementTree.parse(train_path).getroot()
train_xml_sents = train_xml_root.findall("./sentence")
test_xml_root = ElementTree.parse(test_path).getroot()
test_xml_sents = test_xml_root.findall("./sentence")

In [12]:
# Raw sentence strings, train sentences first and test sentences last.
sentences = [
    xml_sent.find("text").text
    for xml_sent in list(train_xml_sents) + list(test_xml_sents)
]

In [13]:
def _io_to_bio(io_labels, entity_type):
    """Return a copy of `io_labels` in which every 'I' is replaced by
    'B-<entity_type>' (first 'I' of a run) or 'I-<entity_type>' (later ones).

    All other labels are copied unchanged.
    """
    original = list(io_labels)
    bio_labels = list(original)
    previous = 'O'
    for idx, lb in enumerate(original):
        if lb == 'I':
            bio_labels[idx] = ('I-' if previous == 'I' else 'B-') + entity_type
        previous = lb
    return bio_labels


src_token_list = list()
src_anno_list = list()
weak_anno_list = list()
link_anno_list = list()

allen_data = laptops_docs
mapping_dict = {0: 'O', 1: 'I'}

# zip() stops silently at the shorter input; a length mismatch would pair
# sentences with the wrong documents, so fail loudly instead.
assert len(sentences) == len(allen_data), \
    "number of raw sentences and number of documents differ"

for src_txt, allen_annos in zip(sentences, allen_data):

    # handle the data read from the source text: NLTK tokens, with the
    # `` and '' quote tokens mapped back to a plain double quote
    src_tokens = ['"' if tok in ('``', "''") else tok
                  for tok in word_tokenize(src_txt)]

    allen_tokens = list(map(str, allen_annos['tokens']))

    # reference tags -> spans over the NLTK tokens
    src_spans = label_to_span(_io_to_bio(allen_annos['tags'], LABEL))
    src_spans = respan(allen_tokens, src_tokens, src_spans)
    src_annos = {span: LABEL for span in src_spans}

    src_token_list.append(src_tokens)
    src_anno_list.append(src_annos)

    # handle the data constructed using Allennlp: one span dict per
    # tagging rule, every span carrying the vote (LABEL, 1.0)
    weak_anno = dict()
    for k in allen_annos['WISER_LABELS']:
        weak_span = label_to_span(
            _io_to_bio(allen_annos['WISER_LABELS'][k], LABEL))
        src_weak_span = respan(allen_tokens, src_tokens, weak_span)
        weak_anno[k] = {span: [(LABEL, 1.0)] for span in src_weak_span}
    weak_anno_list.append(weak_anno)

    # linking rules: 0/1 votes -> O/I labels -> spans
    linked_dict = dict()
    for src, link_votes in allen_annos['WISER_LINKS'].items():
        entity_lbs = _io_to_bio([mapping_dict[lb] for lb in link_votes], LINK)
        entity_spans = label_to_span(entity_lbs)

        # Extend every span one token to the left (unless it already starts
        # at 0) -- presumably because a link vote on a token ties it to the
        # token before it.
        complete_span = dict()
        for (start, end), lb in entity_spans.items():
            if start != 0:
                start = start - 1
            complete_span[(start, end)] = lb
        linked_dict[src] = respan(allen_tokens, src_tokens, complete_span)
    link_anno_list.append(linked_dict)

In [14]:
# Keep only the linking spans that are longer than one token and share at
# least one token position with some tagging-rule span of the same sentence.
updated_link_anno_list = list()
for tag_anno, link_anno in zip(weak_anno_list, link_anno_list):
    # Token positions covered by every tagging span, one set per span.
    tag_spans = [set(range(span_key[0], span_key[1]))
                 for spans in tag_anno.values()
                 for span_key in spans.keys()]

    link_entities = dict()
    for src, spans in link_anno.items():
        valid_anno = dict()
        for span in spans:
            if span[1] - span[0] == 1:
                continue
            covered = set(range(span[0], span[1]))
            if any(covered.intersection(tag_span) for tag_span in tag_spans):
                valid_anno[span] = [(LABEL, 1.0)]
        link_entities[src] = valid_anno
    updated_link_anno_list.append(link_entities)

In [15]:
# Per sentence, merge the tagging-rule annotations and the filtered
# linking-rule annotations into one dict; on a key clash the linking entry
# replaces the tagging entry.
combined_anno_list = [
    {**tag_anno, **link_anno}
    for tag_anno, link_anno in zip(weak_anno_list, updated_link_anno_list)
]

## Build Bert Embeddings

In [16]:
# Checkpoint to load; the cased variant is kept for reference.
# pretrained_model_name = 'bert-base-cased'
pretrained_model_name = 'bert-base-uncased'
model_class, tokenizer_class = BertModel, BertTokenizer

# Tokenizer first, then the model, both from the same checkpoint name.
tokenizer = tokenizer_class.from_pretrained(pretrained_model_name)
model = model_class.from_pretrained(pretrained_model_name)

In [17]:
# Use the GPU when one is available, otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# One embedding tensor per sentence, in the order of `src_token_list`.
embs = build_bert_emb(src_token_list, tokenizer, model, device)

In [18]:
# Sanity check: the trailing 800-item slice of `embs` (the same slice the
# split cell takes as the test portion) should contain exactly 800 entries.
len(embs[-800:])

800

In [19]:
# Seed for the train/dev shuffle, so Restart & Run All gives the same split.
SPLIT_SEED = 42

# `sentences` is built from the train XML sentences followed by the test XML
# sentences, so the last `num_test` items of every per-sentence list are the
# test portion.
num_test = len(test_xml_sents)
assert 0 < num_test <= len(src_token_list), "unexpected test set size"
num_trainval = len(src_token_list) - num_test

test_token_list = src_token_list[-num_test:]
test_anno_list = combined_anno_list[-num_test:]
test_lb_list = src_anno_list[-num_test:]
test_emb = embs[-num_test:]

# Shuffle the remaining indices with a dedicated, seeded generator (torch is
# imported explicitly at the top of the notebook) and split them 80% / 20%.
shuffle_generator = torch.Generator().manual_seed(SPLIT_SEED)
indices = torch.randperm(num_trainval, generator=shuffle_generator).tolist()
train_partition = num_trainval * 4 // 5

train_indices = indices[:train_partition]
dev_indices = indices[train_partition:]

train_token_list = [src_token_list[i] for i in train_indices]
train_anno_list = [combined_anno_list[i] for i in train_indices]
train_lb_list = [src_anno_list[i] for i in train_indices]
train_emb = [embs[i] for i in train_indices]

dev_token_list = [src_token_list[i] for i in dev_indices]
dev_anno_list = [combined_anno_list[i] for i in dev_indices]
dev_lb_list = [src_anno_list[i] for i in dev_indices]
dev_emb = [embs[i] for i in dev_indices]

## Save Data

In [20]:
# Bundle tokens, weak annotations and reference labels of each split.
train_data = dict(sentences=train_token_list,
                  annotations=train_anno_list,
                  labels=train_lb_list)
dev_data = dict(sentences=dev_token_list,
                annotations=dev_anno_list,
                labels=dev_lb_list)
test_data = dict(sentences=test_token_list,
                 annotations=test_anno_list,
                 labels=test_lb_list)

# Write each bundle and its embeddings to the current working directory.
for out_name, payload in (
        ("Laptop-linked-train.pt", train_data),
        ("Laptop-emb-train.pt", train_emb),
        ("Laptop-linked-dev.pt", dev_data),
        ("Laptop-emb-dev.pt", dev_emb),
        ("Laptop-linked-test.pt", test_data),
        ("Laptop-emb-test.pt", test_emb)):
    torch.save(payload, out_name)