In [None]:
import numpy as np
import spacy
import torch
import os
from Core.Constants import *
from Core import Annotate, IO, Util, Data, HMM
from tqdm.auto import tqdm
from transformers import *

---
## Load data and corresponding labels

In [None]:
# CoNLL-2003 split to process: "train", "dev", "test", or "all"
# ("all" concatenates train, dev and test, in that order).
DATA_PARTITION = 'dev'

In [None]:
# Resolve which split files to read: all three for "all", otherwise just the
# requested one. Documents and labels from each file are concatenated in order.
partitions = ["train", "dev", "test"] if DATA_PARTITION == "all" else [DATA_PARTITION]

articles, labels = [], []
for partition in partitions:
    data = torch.load(f'CoNLL03-{partition}.pt')
    articles.extend(data['documents'])
    labels.extend(data['labels'])

---
## Use spaCy to obtain weak labels and HMM training priors

### Construct SpaCy documents from plain text

In [None]:
nlp = spacy.load('en_core_web_md')
# One SpaCy Doc per article; each article is passed as its list of sentences.
docs = [Annotate.construct_doc(sents, nlp) for sents in tqdm(articles)]

### Load annotators and annotate documents with weak labels

In [None]:
united_annotator = Annotate.UnitedAnnotator().add_all()

# Annotate every document with the weak-label sources.
# Rebinding only the loop variable (`doc = united_annotator.annotate(doc)`)
# never writes back into `docs`, so a Doc *returned* by `annotate` would be
# lost before saving. Store the returned Doc explicitly; a `None` return is
# tolerated in case `annotate` only mutates the Doc in place.
for i, doc in enumerate(tqdm(docs)):
    annotated = united_annotator.annotate(doc)
    if annotated is not None:
        docs[i] = annotated

# Cache the annotated docs (torch.save pickles the SpaCy Doc objects).
torch.save(docs, f"CoNLL03-SpaCy-{DATA_PARTITION}.pt")

### Extract and save the HMM's initial statistics (training priors)

In [None]:
# Reload the annotated docs cached by the annotation cell, so the prior
# extraction below can run without re-annotating.
# NOTE(review): the file holds pickled SpaCy Doc objects; torch releases that
# default torch.load to weights_only=True may refuse to unpickle them — if so,
# pass weights_only=False (only for files you produced yourself).
spacy_docs_path = f"CoNLL03-SpaCy-{DATA_PARTITION}.pt"
docs = torch.load(spacy_docs_path)

In [None]:
# Keep every labelling source except those whose name mentions "conll2003"
# (presumably sources derived from the CoNLL-2003 data itself — TODO confirm).
sources_to_use = [source for source in SOURCE_NAMES if "conll2003" not in source]
hmm_model = HMM.HMMAnnotator(sources_to_keep=sources_to_use)

In [None]:
# One extracted observation sequence per annotated document.
x = [hmm_model.extract_sequence(doc) for doc in docs]

# Estimate start, transition and emission statistics from the sequences,
# in the same order as before (start -> transition -> emission).
for initialise in (
    hmm_model._initialise_startprob,
    hmm_model._initialise_transmat,
    hmm_model._initialise_emissions,
):
    initialise(x)

# Snapshot the HMM's initial parameters and their prior counts/strengths.
initial_statistics = dict(
    state_prior_count=hmm_model.startprob_prior,
    state_prior=hmm_model.startprob_,
    transition_count=hmm_model.transmat_prior,
    transition_matrix=hmm_model.transmat_,
    emission_strength=hmm_model.emission_priors,
    emission_matrix=hmm_model.emission_probs,
)
torch.save(initial_statistics, f'CoNLL03-init-stat-{DATA_PARTITION}.pt')

---
## Convert SpaCy annotation spans to the original sentence-level spans

In [None]:
# Section checkpoint: reload the annotated SpaCy docs so this section can be
# run without re-running the annotation/prior cells above.
# NOTE(review): torch.load unpickles SpaCy Doc objects here; with a torch that
# defaults to weights_only=True this may need weights_only=False (trusted files only).
spacy_docs_path = f"CoNLL03-SpaCy-{DATA_PARTITION}.pt"
docs = torch.load(spacy_docs_path)

In [None]:
# Flatten the per-article data into sentence-level lists: `sentences` holds
# every sentence of every article, `sent_level_annos` the weak-label
# annotations mapped back onto those sentences.
#
# zip() stops at the shorter input, so a cached SpaCy file that does not match
# the loaded articles (e.g. built from another partition) would silently drop
# documents. Check the document counts up front instead.
assert len(docs) == len(articles), (
    f"{len(docs)} SpaCy docs but {len(articles)} articles; "
    f"is CoNLL03-SpaCy-{DATA_PARTITION}.pt stale?"
)

sentences = list()
sent_level_annos = list()
for doc, article in zip(docs, articles):
    # SpaCy's sentence split must agree with the original sentence boundaries.
    n_spacy_sents = len(list(doc.sents))
    assert n_spacy_sents == len(article), (
        f"SpaCy found {n_spacy_sents} sentences, article has {len(article)}"
    )
    sentences += article
    sent_level_annos += Data.annotate_doc_with_spacy(article, doc)

In [None]:
# Gold labels as spans, one entry per sentence, in document-then-sentence order
# (the same order in which `sentences` was flattened).
lb_spans = [
    Data.label_to_span(sent_labels)
    for doc_labels in labels
    for sent_labels in doc_labels
]

In [None]:
# Bundle sentences, their weak-label annotations and their gold label spans
# into one sentence-level dataset and save it.
data = dict(
    sentences=sentences,
    annotations=sent_level_annos,
    labels=lb_spans,
)
linked_path = f"Co03-linked-{DATA_PARTITION}.pt"
torch.save(data, linked_path)

---
## Build BERT embeddings for each sentence

In [None]:
# Pretrained encoder used to embed every sentence.
# Alternative checkpoint previously considered here: 'bert-base-cased'.
model_class = BertModel
tokenizer_class = BertTokenizer
pretrained_model_name = 'bert-base-uncased'

# Use the GPU when one is present; fall back to CPU instead of crashing on
# machines without CUDA (behaviour on GPU machines is unchanged).
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = tokenizer_class.from_pretrained(pretrained_model_name)
model = model_class.from_pretrained(pretrained_model_name).to(device)

In [None]:
# Pass the device type the model's weights actually live on ('cuda' on a GPU
# machine, 'cpu' otherwise) instead of a hard-coded 'cuda', so the inputs are
# sent wherever the model cell placed the model.
# NOTE(review): assumes build_bert_emb accepts 'cpu' as well as 'cuda' — confirm.
bert_embs = Data.build_bert_emb(sentences, tokenizer, model, next(model.parameters()).device.type)

In [None]:
# Persist the sentence embeddings.
# NOTE(review): this file uses the "Co03-" prefix while the SpaCy/statistics
# files use "CoNLL03-"; kept as-is since downstream loaders presumably expect it.
emb_path = f"Co03-emb-{DATA_PARTITION}.pt"
torch.save(bert_embs, emb_path)