<a href="https://colab.research.google.com/github/MugiwaraNoRushi/Nlp-Project-2022/blob/main/Nlp_final_project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import and Copy

In [1]:
# google drive imports
from google.colab import drive

In [2]:
# Mount Google Drive so the project folder can be copied from it in the next cell.
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [3]:
# Copy the project folder from Drive into the working directory.
# NOTE(review): the relative path assumes the working directory is /content (where Drive was mounted above).
# Follow the GitHub instructions to create the Project-NLP folder and save it to your Drive first!
!cp -R gdrive/MyDrive/Project-NLP ./

In [None]:
# The first time this cell runs, Colab will ask you to restart the runtime; do so and it will then work.
!pip install -r Project-NLP/requirements.txt

In [None]:
# The runtime environment is now ready to go:
# the models can be run in this notebook.
# Note: Colab GPU sessions have a time limit of roughly 4 hours.

In [17]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

from allennlp.nn.util import batched_index_select
from allennlp.nn import util, Activation
from allennlp.modules import FeedForward

import numpy as np

from transformers import BertTokenizer, BertPreTrainedModel, BertModel
from transformers import AlbertTokenizer, AlbertPreTrainedModel, AlbertModel

import os
import json
import logging

In [19]:
# Data Set import
import copy
from collections import Counter
from torch.utils.data import DataLoader, TensorDataset

In [20]:
# Main function import
import argparse
import sys
import random
import time
from tqdm import tqdm
from transformers import AdamW, get_linear_schedule_with_warmup

# Model

In [8]:
# Shared logger used by the model, batching and data-conversion helpers in this notebook.
# NOTE(review): no logging handler/level is configured in the visible cells, so with
# Python's default WARNING level the logger.info(...) calls produce no output; call
# logging.basicConfig(level=logging.INFO) beforehand if you want to see them.
logger = logging.getLogger('root')

In [6]:
class BertForEntity(BertPreTrainedModel):
    """Span classifier for named entities on top of a BERT encoder.

    Each candidate span is represented by concatenating three pieces:
    the (dropout-regularised) encoder hidden state at the span's start word
    piece, the hidden state at its end word piece, and a learned embedding of
    the span width. A two-layer feed-forward head followed by a linear layer
    maps that representation to ``num_ner_labels`` logits.
    """

    def __init__(self, config, num_ner_labels, head_hidden_dim=150, width_embedding_dim=150, max_span_length=8):
        super().__init__(config)

        self.bert = BertModel(config)
        self.hidden_dropout = nn.Dropout(config.hidden_dropout_prob)
        # Widths produced by convert_dataset_to_samples are 1..max_span_length, so
        # max_span_length + 1 rows cover every width index.
        self.width_embedding = nn.Embedding(max_span_length+1, width_embedding_dim)

        self.ner_classifier = nn.Sequential(
            FeedForward(input_dim=config.hidden_size*2+width_embedding_dim,
                        num_layers=2,
                        hidden_dims=head_hidden_dim,
                        activations=F.relu,
                        dropout=0.2),
            nn.Linear(head_hidden_dim, num_ner_labels)
        )

        self.init_weights()

    def _get_span_embeddings(self, input_ids, spans, token_type_ids=None, attention_mask=None):
        """Build one embedding per candidate span.

        Args:
            input_ids: word-piece ids, [batch_size, seq_len].
            spans: LongTensor [batch_size, num_spans, 3]; column 0 is the start
                word-piece index, column 1 the (inclusive) end index, column 2 the width.

        Returns:
            Tensor [batch_size, num_spans, 2 * hidden_size + width_embedding_dim].
        """
        outputs = self.bert(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        # Index rather than tuple-unpack: recent `transformers` versions return a
        # dict-like ModelOutput, and unpacking one yields its *keys*, not tensors.
        # outputs[0] is the last hidden state for both tuples and ModelOutput.
        sequence_output = self.hidden_dropout(outputs[0])

        batch_size = spans.size(0)
        spans_start_embedding = batched_index_select(sequence_output, spans[:, :, 0].view(batch_size, -1))
        spans_end_embedding = batched_index_select(sequence_output, spans[:, :, 1].view(batch_size, -1))
        spans_width_embedding = self.width_embedding(spans[:, :, 2].view(batch_size, -1))

        # Concatenate embeddings of left/right end points and the width embedding.
        return torch.cat((spans_start_embedding, spans_end_embedding, spans_width_embedding), dim=-1)

    def forward(self, input_ids, spans, spans_mask, spans_ner_label=None, token_type_ids=None, attention_mask=None):
        """Score every span and, when labels are given, compute the summed cross-entropy loss.

        Args:
            spans_mask: [batch_size, num_spans], 1 for real spans and 0 for padding.
                Padded spans are excluded from the loss.
            spans_ner_label: optional [batch_size, num_spans] gold label ids.

        Returns:
            ``(loss, logits, spans_embedding)`` when ``spans_ner_label`` is given.
            Otherwise ``(logits, spans_embedding, spans_embedding)``; the repeated
            third element keeps the 3-tuple shape that callers unpack.
        """
        spans_embedding = self._get_span_embeddings(input_ids, spans, token_type_ids=token_type_ids, attention_mask=attention_mask)
        logits = self.ner_classifier(spans_embedding)

        if spans_ner_label is None:
            return logits, spans_embedding, spans_embedding

        loss_fct = CrossEntropyLoss(reduction='sum')
        flat_logits = logits.view(-1, logits.shape[-1])
        flat_labels = spans_ner_label.view(-1)
        # Decide on masking from spans_mask, which is what marks padded spans.
        # The previous check on attention_mask let padded spans into the loss
        # whenever attention_mask was omitted.
        if spans_mask is not None:
            flat_labels = torch.where(
                spans_mask.view(-1) == 1,
                flat_labels,
                torch.full_like(flat_labels, loss_fct.ignore_index),
            )
        loss = loss_fct(flat_logits, flat_labels)
        return loss, logits, spans_embedding

In [7]:
class EntityModel():
    """Wrapper around BertForEntity handling tokenisation, batching, device placement and decoding.

    ``args`` must provide:
    - ``model``: pretrained model name;
    - ``bert_model_dir``: optional local directory that overrides ``model``;
    - ``max_span_length``.
    """

    def __init__(self, args, num_ner_labels):
        super().__init__()

        bert_model_name = args.model
        vocab_name = bert_model_name

        if args.bert_model_dir is not None:
            # A local checkpoint directory supplies both the weights and the vocabulary.
            bert_model_name = str(args.bert_model_dir) + '/'
            vocab_name = bert_model_name

        self.tokenizer = BertTokenizer.from_pretrained(vocab_name)
        self.bert_model = BertForEntity.from_pretrained(bert_model_name, num_ner_labels=num_ner_labels, max_span_length=args.max_span_length)

        self._model_device = 'cpu'
        self.move_model_to_cuda()

    def move_model_to_cuda(self):
        """Move the model to the GPU; wrap it in DataParallel when several GPUs are visible.

        Raises:
            RuntimeError: if CUDA is not available. Raising (instead of calling
                ``exit(-1)``) fails only the current cell rather than shutting
                down the notebook kernel.
        """
        if not torch.cuda.is_available():
            raise RuntimeError('No CUDA found!')
        logger.info('Moving to CUDA...')
        self._model_device = 'cuda'
        self.bert_model.cuda()
        logger.info('# GPUs = %d'%(torch.cuda.device_count()))
        if torch.cuda.device_count() > 1:
            self.bert_model = torch.nn.DataParallel(self.bert_model)

    def _get_input_tensors(self, tokens, spans, spans_ner_label):
        """Convert one sample to word-piece ids and word-piece span boundaries.

        Args:
            tokens: list of words.
            spans: (start_word, end_word, width) triples, word indices inclusive.
            spans_ner_label: one label id per span.

        Returns:
            tokens_tensor: LongTensor [1, num_wordpieces] including [CLS] and [SEP].
            bert_spans_tensor: LongTensor [1, num_spans, 3] of (start, end, width).
            spans_ner_label_tensor: LongTensor [1, num_spans].
        """
        start2idx = []
        end2idx = []
        bert_tokens = [self.tokenizer.cls_token]
        for token in tokens:
            # Remember the first and last word piece of every word.
            start2idx.append(len(bert_tokens))
            bert_tokens += self.tokenizer.tokenize(token)
            end2idx.append(len(bert_tokens) - 1)
        bert_tokens.append(self.tokenizer.sep_token)

        tokens_tensor = torch.tensor([self.tokenizer.convert_tokens_to_ids(bert_tokens)])

        bert_spans = [[start2idx[span[0]], end2idx[span[1]], span[2]] for span in spans]
        # Explicit dtype and shape keep a sample with no spans at [1, 0, 3] / [1, 0]
        # (long), so it can still be padded alongside the rest of the batch.
        bert_spans_tensor = torch.tensor(bert_spans, dtype=torch.long).view(1, len(bert_spans), 3)
        spans_ner_label_tensor = torch.tensor([spans_ner_label], dtype=torch.long)

        return tokens_tensor, bert_spans_tensor, spans_ner_label_tensor

    def _get_input_tensors_batch(self, samples_list, training=True):
        """Tokenise a list of samples and pad them into batch tensors.

        ``training`` is accepted for API compatibility and is not used.

        Returns:
            (tokens [B, T], attention_mask [B, T], spans [B, S, 3], spans_mask [B, S],
            spans_ner_label [B, S], sentence_length [B] as float).

            Tokens are padded with the tokenizer's pad id; padded spans and labels
            are 0 and marked 0 in spans_mask. The first five entries are None when
            ``samples_list`` is empty.
        """
        per_sample = []
        sentence_length = []
        for sample in samples_list:
            tensors = self._get_input_tensors(sample['tokens'], sample['spans'], sample['spans_label'])
            assert(tensors[1].shape[1] == tensors[2].shape[1])
            per_sample.append(tensors)
            sentence_length.append(sample['sent_length'])
        sentence_length = torch.Tensor(sentence_length)

        if not per_sample:
            return None, None, None, None, None, sentence_length

        max_tokens = max(tokens_tensor.shape[1] for tokens_tensor, _, _ in per_sample)
        max_spans = max(bert_spans_tensor.shape[1] for _, bert_spans_tensor, _ in per_sample)

        all_tokens, all_attention, all_spans, all_spans_mask, all_labels = [], [], [], [], []
        for tokens_tensor, bert_spans_tensor, spans_ner_label_tensor in per_sample:
            # Pad word pieces; attention is 1 on real tokens only.
            num_tokens = tokens_tensor.shape[1]
            tokens_pad_length = max_tokens - num_tokens
            attention_tensor = torch.full([1, num_tokens], 1, dtype=torch.long)
            if tokens_pad_length > 0:
                pad = torch.full([1, tokens_pad_length], self.tokenizer.pad_token_id, dtype=torch.long)
                tokens_tensor = torch.cat((tokens_tensor, pad), dim=1)
                attention_pad = torch.full([1, tokens_pad_length], 0, dtype=torch.long)
                attention_tensor = torch.cat((attention_tensor, attention_pad), dim=1)

            # Pad spans with all-zero rows; spans_mask marks which spans are real.
            num_spans = bert_spans_tensor.shape[1]
            spans_pad_length = max_spans - num_spans
            spans_mask_tensor = torch.full([1, num_spans], 1, dtype=torch.long)
            if spans_pad_length > 0:
                pad = torch.full([1, spans_pad_length, bert_spans_tensor.shape[2]], 0, dtype=torch.long)
                bert_spans_tensor = torch.cat((bert_spans_tensor, pad), dim=1)
                mask_pad = torch.full([1, spans_pad_length], 0, dtype=torch.long)
                spans_mask_tensor = torch.cat((spans_mask_tensor, mask_pad), dim=1)
                spans_ner_label_tensor = torch.cat((spans_ner_label_tensor, mask_pad), dim=1)

            all_tokens.append(tokens_tensor)
            all_attention.append(attention_tensor)
            all_spans.append(bert_spans_tensor)
            all_spans_mask.append(spans_mask_tensor)
            all_labels.append(spans_ner_label_tensor)

        # Concatenate once instead of re-copying the growing batch for every sample.
        return (torch.cat(all_tokens, dim=0),
                torch.cat(all_attention, dim=0),
                torch.cat(all_spans, dim=0),
                torch.cat(all_spans_mask, dim=0),
                torch.cat(all_labels, dim=0),
                sentence_length)

    def run_batch(self, samples_list, try_cuda=True, training=True):
        """Run the model on one batch of samples.

        Args:
            samples_list: samples as produced by convert_dataset_to_samples.
            try_cuda: unused; kept for API compatibility.
            training: if True, compute the loss in train mode. Otherwise run in
                eval mode under no_grad and decode per-span predictions.

        Returns:
            dict. Training: ``ner_loss`` (scalar tensor) and ``ner_llh``
            (log-softmax over labels). Evaluation: ``ner_loss`` = 0 plus
            per-sample lists ``pred_ner`` (argmax label ids), ``ner_probs``
            (raw per-span logits, despite the name) and ``ner_last_hidden``
            (the third model output).
        """
        # convert samples to input tensors
        tokens_tensor, attention_mask_tensor, bert_spans_tensor, spans_mask_tensor, spans_ner_label_tensor, sentence_length = self._get_input_tensors_batch(samples_list, training)

        output_dict = {
            'ner_loss': 0,
        }

        if training:
            self.bert_model.train()
            ner_loss, ner_logits, spans_embedding = self.bert_model(
                input_ids = tokens_tensor.to(self._model_device),
                spans = bert_spans_tensor.to(self._model_device),
                spans_mask = spans_mask_tensor.to(self._model_device),
                spans_ner_label = spans_ner_label_tensor.to(self._model_device),
                attention_mask = attention_mask_tensor.to(self._model_device),
            )
            # Sum in case DataParallel returns one loss per GPU.
            output_dict['ner_loss'] = ner_loss.sum()
            output_dict['ner_llh'] = F.log_softmax(ner_logits, dim=-1)
        else:
            self.bert_model.eval()
            with torch.no_grad():
                ner_logits, spans_embedding, last_hidden = self.bert_model(
                    input_ids = tokens_tensor.to(self._model_device),
                    spans = bert_spans_tensor.to(self._model_device),
                    spans_mask = spans_mask_tensor.to(self._model_device),
                    spans_ner_label = None,
                    attention_mask = attention_mask_tensor.to(self._model_device),
                )
            _, predicted_label = ner_logits.max(2)
            predicted_label = predicted_label.cpu().numpy()
            # One device-to-host copy for the whole batch instead of one per span.
            ner_logits = ner_logits.cpu().numpy()
            last_hidden = last_hidden.cpu().numpy()

            predicted = []
            pred_prob = []
            hidden = []
            for i, sample in enumerate(samples_list):
                # Only the sample's real spans; padded positions are dropped.
                num_spans = len(sample['spans'])
                predicted.append([predicted_label[i][j] for j in range(num_spans)])
                pred_prob.append([ner_logits[i][j] for j in range(num_spans)])
                hidden.append([last_hidden[i][j] for j in range(num_spans)])
            output_dict['pred_ner'] = predicted
            output_dict['ner_probs'] = pred_prob
            output_dict['ner_last_hidden'] = hidden

        return output_dict

# Utils for Models

In [13]:
def batchify(samples, batch_size, max_single_len=350):
    """Split samples into batches of at most ``batch_size``.

    Samples with more than ``max_single_len`` tokens are each placed in a batch
    of their own, emitted first, to avoid GPU OOM. The remaining samples keep
    their relative order.

    Args:
        samples: list of sample dicts with a 'tokens' list. Long samples also
            need 'doc_key' and 'sentence_ix', which are logged.
        batch_size: maximum number of regular samples per batch.
        max_single_len: token count above which a sample is batched alone.

    Returns:
        List of batches (lists of samples) containing every input sample exactly once.
    """
    # A set keeps the membership test below O(1) per sample.
    long_ixs = {i for i, sample in enumerate(samples) if len(sample['tokens']) > max_single_len}

    list_samples_batches = []
    for i in sorted(long_ixs):
        logger.info('Single batch sample: %s-%d', samples[i]['doc_key'], samples[i]['sentence_ix'])
        list_samples_batches.append([samples[i]])

    remaining = [sample for i, sample in enumerate(samples) if i not in long_ixs]
    for start in range(0, len(remaining), batch_size):
        list_samples_batches.append(remaining[start:start + batch_size])

    assert(sum(len(batch) for batch in list_samples_batches) == len(samples))

    return list_samples_batches

def overlap(s1, s2):
    """Return True if spans ``s1`` and ``s2`` share at least one token.

    Both arguments need ``start_sent`` / ``end_sent`` attributes holding
    inclusive sentence-level token offsets (as on Span).

    Uses the standard interval-intersection test, which is symmetric. The
    previous check only tested whether an endpoint of ``s2`` fell inside
    ``s1``, so it returned False when ``s2`` strictly contained ``s1``.
    """
    return s1.start_sent <= s2.end_sent and s2.start_sent <= s1.end_sent

In [14]:
def convert_dataset_to_samples(dataset, max_span_length, ner_label2id=None, context_window=0, split=0):
    """Extract sentence-level samples with every candidate span and its gold label.

    Args:
        dataset: sequence of Document objects. Each sentence must expose
            ``text``, ``ner``, ``sentence_ix`` and ``sentence_start``.
        max_span_length: longest span, in words, to enumerate.
        ner_label2id: mapping from entity label to integer id. Required whenever
            the data contains gold entities; id 0 is used for "no entity".
        context_window: if > 0, extend each sentence with words from the
            neighbouring sentences of the same document, up to ``context_window``
            words in total, split between left and right.
        split: 0 = all documents, 1 = first 90% (train), 2 = last 10% (dev);
            used e.g. for ACE04.

    Returns:
        (samples, num_ner): list of sample dicts, and the number of gold
        entities in the selected documents.

    Raises:
        ValueError: if ``split`` is not 0, 1 or 2.
    """
    if split == 0:
        data_range = (0, len(dataset))
    elif split == 1:
        data_range = (0, int(len(dataset)*0.9))
    elif split == 2:
        data_range = (int(len(dataset)*0.9), len(dataset))
    else:
        raise ValueError('split must be 0, 1 or 2, got %r' % (split,))

    samples = []
    num_ner = 0
    max_len = 0
    max_ner = 0
    num_overlap = 0  # never incremented; kept so the log output is unchanged

    for doc_idx, doc in enumerate(dataset):
        if doc_idx < data_range[0] or doc_idx >= data_range[1]:
            continue
        for sent_idx, sent in enumerate(doc):
            num_ner += len(sent.ner)
            sample = {
                'doc_key': doc._doc_key,
                'sentence_ix': sent.sentence_ix,
            }
            if context_window != 0 and len(sent.text) > context_window:
                logger.info('Long sentence: {} {}'.format(sample, len(sent.text)))
            sample['tokens'] = sent.text
            sample['sent_length'] = len(sent.text)
            # [sent_start, sent_end) locates the sentence inside sample['tokens'].
            sent_start = 0
            sent_end = len(sample['tokens'])

            max_len = max(max_len, len(sent.text))
            max_ner = max(max_ner, len(sent.ner))

            if context_window > 0:
                add_left = (context_window-len(sent.text)) // 2
                add_right = (context_window-len(sent.text)) - add_left

                # Prepend words from preceding sentences, nearest first.
                ctx_idx = sent_idx - 1
                while ctx_idx >= 0 and add_left > 0:
                    context_to_add = doc[ctx_idx].text[-add_left:]
                    sample['tokens'] = context_to_add + sample['tokens']
                    add_left -= len(context_to_add)
                    sent_start += len(context_to_add)
                    sent_end += len(context_to_add)
                    ctx_idx -= 1

                # Append words from following sentences.
                ctx_idx = sent_idx + 1
                while ctx_idx < len(doc) and add_right > 0:
                    context_to_add = doc[ctx_idx].text[:add_right]
                    sample['tokens'] = sample['tokens'] + context_to_add
                    add_right -= len(context_to_add)
                    ctx_idx += 1

            sample['sent_start'] = sent_start
            sample['sent_end'] = sent_end
            sample['sent_start_in_doc'] = sent.sentence_start

            sent_ner = {ner.span.span_sent: ner.label for ner in sent.ner}

            # Enumerate every span of up to max_span_length words. Offsets are
            # shifted by sent_start so they index into sample['tokens'].
            sample['spans'] = []
            sample['spans_label'] = []
            num_words = len(sent.text)
            for start in range(num_words):
                for end in range(start, min(num_words, start + max_span_length)):
                    sample['spans'].append((start + sent_start, end + sent_start, end - start + 1))
                    if (start, end) in sent_ner:
                        sample['spans_label'].append(ner_label2id[sent_ner[(start, end)]])
                    else:
                        sample['spans_label'].append(0)
            samples.append(sample)

    # Guard the statistics against an empty selection (e.g. an empty dev split).
    if samples:
        avg_length = sum(len(sample['tokens']) for sample in samples) / len(samples)
        max_length = max(len(sample['tokens']) for sample in samples)
    else:
        avg_length, max_length = 0.0, 0
    logger.info('# Overlap: %d'%num_overlap)
    logger.info('Extracted %d samples from %d documents, with %d NER labels, %.3f avg input length, %d max length'%(len(samples), data_range[1]-data_range[0], num_ner, avg_length, max_length))
    logger.info('Max Length: %d, max NER: %d'%(max_len, max_ner))
    return samples, num_ner

In [15]:
class NpEncoder(json.JSONEncoder):
    """JSON encoder that also serialises NumPy integers, floats, booleans and arrays.

    Use as ``json.dumps(obj, cls=NpEncoder)``. Predictions built by this notebook
    hold NumPy values, which the stock encoder rejects.
    """
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            # np.bool_ is not a Python bool, so the base encoder would reject it.
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super(NpEncoder, self).default(obj)

# Dataset Classes

In [10]:
def fields_to_batches(d, keys_to_ignore=[]):
    """Transpose a dict of equal-length sequences into a list of per-index dicts.

    ``{"a": [1, 2], "b": [3, 4]}`` becomes ``[{"a": 1, "b": 3}, {"a": 2, "b": 4}]``.
    Keys in ``keys_to_ignore`` are left out; the remaining fields must all have
    the same length (asserted).
    """
    kept = [k for k in d if k not in keys_to_ignore]
    field_lengths = {len(d[k]) for k in kept}
    assert len(field_lengths) == 1
    (n_items,) = field_lengths
    return [dict((k, d[k][idx]) for k in kept) for idx in range(n_items)]

def get_sentence_of_span(span, sentence_starts, doc_tokens):
    """Return the index of the sentence that fully contains ``span``.

    Args:
        span: inclusive (start, end) document token offsets.
        sentence_starts: document offset of each sentence's first token.
        doc_tokens: total number of tokens in the document.

    Asserts that exactly one sentence contains the span.
    """
    # Sentence k ends (inclusively) just before sentence k+1 starts; the last one ends at doc_tokens - 1.
    ends = [next_start - 1 for next_start in sentence_starts[1:]]
    ends.append(doc_tokens - 1)
    contains = [start <= span[0] and span[1] <= end
                for start, end in zip(sentence_starts, ends)]
    assert sum(contains) == 1
    return contains.index(True)


In [11]:
class Dataset:
    """A collection of Documents loaded from a JSON-lines file.

    Args:
        json_file: path to the gold data, one JSON document per line.
        pred_file: optional path to predictions for the same documents in the
            same order. Every key containing "predicted" is merged into the
            corresponding gold document.
        doc_range: optional (start, end) slice of documents to keep.
    """
    def __init__(self, json_file, pred_file=None, doc_range=None):
        self.js = self._read(json_file, pred_file)
        if doc_range is not None:
            self.js = self.js[doc_range[0]:doc_range[1]]
        self.documents = [Document(js) for js in self.js]

    def update_from_js(self, js):
        """Replace the raw documents with ``js`` and rebuild the Document objects."""
        self.js = js
        self.documents = [Document(doc_js) for doc_js in self.js]

    @staticmethod
    def _load_jsonl(path):
        """Parse a JSON-lines file, skipping blank lines; the file is closed afterwards."""
        with open(path) as f:
            return [json.loads(line) for line in f if line.strip()]

    def _read(self, json_file, pred_file=None):
        gold_docs = self._load_jsonl(json_file)
        if pred_file is None:
            return gold_docs

        pred_docs = self._load_jsonl(pred_file)
        # zip() would silently drop unmatched documents; require the files to line up.
        assert len(gold_docs) == len(pred_docs)
        merged_docs = []
        for gold, pred in zip(gold_docs, pred_docs):
            assert gold["doc_key"] == pred["doc_key"]
            assert gold["sentences"] == pred["sentences"]
            merged = copy.deepcopy(gold)
            for k, v in pred.items():
                if "predicted" in k:
                    merged[k] = v
            merged_docs.append(merged)

        return merged_docs

    def __getitem__(self, ix):
        return self.documents[ix]

    def __len__(self):
        return len(self.documents)

####################

# Helpers for evaluating predictions on a loaded dataset.

def safe_div(num, denom):
    """Return ``num / denom``, or 0 when the denominator is not positive."""
    return num / denom if denom > 0 else 0


def compute_f1(predicted, gold, matched):
    """Precision, recall and F1 from raw counts.

    Args:
        predicted: number of predicted items.
        gold: number of gold items.
        matched: number of predictions that match a gold item.

    Returns:
        dict with keys ``precision``, ``recall`` and ``f1``. A ratio whose
        denominator is zero is reported as 0.
    """
    p = safe_div(matched, predicted)
    r = safe_div(matched, gold)
    return {"precision": p, "recall": r, "f1": safe_div(2 * p * r, p + r)}


def evaluate_sent(sent, counts):
    """Add one sentence's entity and relation tallies to ``counts`` and return it.

    Updates gold / predicted / matched counts for entities and relations.
    ``strict_relations_matched`` counts matched relations whose two argument
    spans were also correctly predicted entities.

    Sentence only sets annotation attributes for keys present in the JSON, so
    a missing list is treated as empty. NER-only data therefore no longer
    raises AttributeError.
    """
    gold_ner = getattr(sent, "ner", [])
    pred_ner = getattr(sent, "predicted_ner", [])
    gold_rel = getattr(sent, "relations", [])
    pred_rel = getattr(sent, "predicted_relations", [])

    # Entities.
    correct_ner = set()
    counts["ner_gold"] += len(gold_ner)
    counts["ner_predicted"] += len(pred_ner)
    for prediction in pred_ner:
        if any(prediction == actual for actual in gold_ner):
            counts["ner_matched"] += 1
            correct_ner.add(prediction.span)

    # Relations.
    counts["relations_gold"] += len(gold_rel)
    counts["relations_predicted"] += len(pred_rel)
    for prediction in pred_rel:
        if any(prediction == actual for actual in gold_rel):
            counts["relations_matched"] += 1
            if (prediction.pair[0] in correct_ner) and (prediction.pair[1] in correct_ner):
                counts["strict_relations_matched"] += 1

    # Return the updated counts.
    return counts

def evaluate_predictions(dataset):
    """Micro-averaged entity, relation and strict-relation P/R/F1 over a whole dataset.

    Counts are summed over every sentence of every document before scoring.
    ``strict_relation`` only credits relations whose argument spans were also
    correctly predicted entities.
    """
    counts = Counter()
    for sent in (s for doc in dataset for s in doc):
        counts = evaluate_sent(sent, counts)

    rel_predicted = counts["relations_predicted"]
    rel_gold = counts["relations_gold"]
    return dict(
        ner=compute_f1(counts["ner_predicted"], counts["ner_gold"], counts["ner_matched"]),
        relation=compute_f1(rel_predicted, rel_gold, counts["relations_matched"]),
        strict_relation=compute_f1(rel_predicted, rel_gold, counts["strict_relations_matched"]),
    )

def analyze_relation_coverage(dataset):
    """Print how many gold relations are reachable from candidate entity sets.

    For every gold relation it checks three things:
    - whether both argument spans are among the sentence's predicted entities;
    - whether both are among its ``top_spans``;
    - whether the two argument spans overlap.

    It also reports the number of ordered candidate pairs each entity set
    would produce.
    """

    def overlap(s1, s2):
        # Inclusive intervals intersect iff each starts no later than the other ends.
        # The endpoint-only test this replaces missed s2 strictly containing s1.
        return s1.start_sent <= s2.end_sent and s2.start_sent <= s1.end_sent

    def pct(part, whole):
        # Report 0 instead of raising ZeroDivisionError when there are no gold relations.
        return part / whole * 100.0 if whole else 0.0

    nrel_gold = 0
    nrel_pred_cover = 0
    nrel_top_cover = 0

    npair_pred = 0
    npair_top = 0

    nrel_overlap = 0

    for d in dataset:
        for s in d:
            pred = set([ner.span for ner in s.predicted_ner])
            top = set([ner.span for ner in s.top_spans])
            # Ordered pairs of distinct spans.
            npair_pred += len(s.predicted_ner) * (len(s.predicted_ner) - 1)
            npair_top += len(s.top_spans) * (len(s.top_spans) - 1)
            for r in s.relations:
                nrel_gold += 1
                if (r.pair[0] in pred) and (r.pair[1] in pred):
                    nrel_pred_cover += 1
                if (r.pair[0] in top) and (r.pair[1] in top):
                    nrel_top_cover += 1

                if overlap(r.pair[0], r.pair[1]):
                    nrel_overlap += 1

    print('Coverage by predicted entities: %.3f (%d / %d), #candidates: %d'%(pct(nrel_pred_cover, nrel_gold), nrel_pred_cover, nrel_gold, npair_pred))
    print('Coverage by top 0.4 spans: %.3f (%d / %d), #candidates: %d'%(pct(nrel_top_cover, nrel_gold), nrel_top_cover, nrel_gold, npair_top))
    print('Overlap: %.3f (%d / %d)'%(pct(nrel_overlap, nrel_gold), nrel_overlap, nrel_gold))


In [12]:
class Document:
    """A document: its key, its Sentence objects and optional coreference clusters.

    ``js`` is one JSON document whose per-sentence fields are parallel lists;
    they are split into one dict per sentence with ``fields_to_batches``.
    """

    def __init__(self, js):
        self._doc_key = js["doc_key"]
        entries = fields_to_batches(js, ["doc_key", "clusters", "predicted_clusters", "section_starts"])
        # Document offset of each sentence = running total of the sentence lengths before it.
        starts = np.roll(np.cumsum([len(entry["sentences"]) for entry in entries]), 1)
        starts[0] = 0
        self.sentence_starts = starts
        self.sentences = [
            Sentence(entry, start, ix)
            for ix, (entry, start) in enumerate(zip(entries, starts))
        ]
        # NOTE(review): Cluster is not defined in the visible cells; presumably defined elsewhere. Confirm.
        if "clusters" in js:
            self.clusters = [Cluster(c, i, self) for i, c in enumerate(js["clusters"])]
        if "predicted_clusters" in js:
            self.predicted_clusters = [Cluster(c, i, self) for i, c in enumerate(js["predicted_clusters"])]

    def __repr__(self):
        return "\n".join("{}: {}".format(i, " ".join(sent.text)) for i, sent in enumerate(self.sentences))

    def __getitem__(self, ix):
        return self.sentences[ix]

    def __len__(self):
        return len(self.sentences)

    def print_plaintext(self):
        """Print each sentence on its own line, tokens separated by spaces."""
        for sent in self:
            print(" ".join(sent.text))

    def find_cluster(self, entity, predicted=True):
        """Return the coreference cluster containing an entry whose span equals ``entity.span``.

        Searches the predicted clusters by default, or the gold clusters when
        ``predicted`` is False. Returns None if no cluster matches. The
        corresponding attribute only exists if the JSON contained that key.
        """
        clusters = self.predicted_clusters if predicted else self.clusters
        return next(
            (clust for clust in clusters if any(entry.span == entity.span for entry in clust)),
            None,
        )

    @property
    def n_tokens(self):
        """Total number of tokens across all sentences."""
        return sum(len(sent) for sent in self.sentences)


class Sentence:
    """One sentence of a Document with its gold and (optionally) predicted annotations.

    ``entry`` is the per-sentence dict produced by ``fields_to_batches``. The
    annotation attributes (``ner``, ``relations``, ``events``, ``predicted_*``,
    ``top_spans``) are only set when the matching key is present in ``entry``.
    """
    def __init__(self, entry, sentence_start, sentence_ix):
        self.sentence_start = sentence_start
        self.text = entry["sentences"]
        self.sentence_ix = sentence_ix
        # Gold
        if "ner_flavor" in entry:
            self.ner = [NER(this_ner, self.text, sentence_start, flavor=this_flavor)
                        for this_ner, this_flavor in zip(entry["ner"], entry["ner_flavor"])]
        elif "ner" in entry:
            self.ner = [NER(this_ner, self.text, sentence_start)
                        for this_ner in entry["ner"]]
        if "relations" in entry:
            self.relations = [Relation(this_relation, self.text, sentence_start) for
                              this_relation in entry["relations"]]
        if "events" in entry:
            self.events = Events(entry["events"], self.text, sentence_start)

        # Predicted
        if "predicted_ner" in entry:
            self.predicted_ner = [NER(this_ner, self.text, sentence_start, flavor=None) for
                                  this_ner in entry["predicted_ner"]]
        if "predicted_relations" in entry:
            self.predicted_relations = [Relation(this_relation, self.text, sentence_start) for
                                        this_relation in entry["predicted_relations"]]
        if "predicted_events" in entry:
            self.predicted_events = Events(entry["predicted_events"], self.text, sentence_start)

        # Top spans
        if "top_spans" in entry:
            self.top_spans = [NER(this_ner, self.text, sentence_start, flavor=None) for
                                this_ner in entry["top_spans"]]

    def __repr__(self):
        """Sentence text, then a line of token indices aligned under each token."""
        the_text = " ".join(self.text)
        tok_ixs = ""
        for i, token in enumerate(self.text):
            label = str(i)
            # Each token takes len(token) + 1 columns (including the joining space);
            # pad the index label to the same width. The previous rule only
            # subtracted one column for i >= 10, so indices >= 100 drifted right.
            tok_ixs += label + " " * (len(token) + 1 - len(label))

        return the_text + "\n" + tok_ixs

    def __len__(self):
        return len(self.text)

    def get_flavor(self, argument):
        """Return the flavor of the gold entity whose span equals ``argument.span``, or None."""
        the_ner = [x for x in self.ner if x.span == argument.span]
        if len(the_ner) > 1:
            print("Weird")
        if the_ner:
            the_flavor = the_ner[0].flavor
        else:
            the_flavor = None
        return the_flavor


class Span:
    """A contiguous token span with both document- and sentence-level offsets.

    ``start`` / ``end`` are inclusive document token indices. ``sentence_start``
    is the document index of the sentence's first token, and ``text`` is the
    sentence's token list.
    """

    def __init__(self, start, end, text, sentence_start):
        self.start_doc, self.end_doc = start, end
        self.span_doc = (self.start_doc, self.end_doc)
        self.start_sent, self.end_sent = start - sentence_start, end - sentence_start
        self.span_sent = (self.start_sent, self.end_sent)
        # Inclusive end, hence + 1.
        self.text = text[self.start_sent:self.end_sent + 1]

    def __repr__(self):
        return str((self.start_sent, self.end_sent, self.text))

    def __eq__(self, other):
        if self.span_doc != other.span_doc:
            return False
        if self.span_sent != other.span_sent:
            return False
        return self.text == other.text

    def __hash__(self):
        # Hash the same fields __eq__ compares (tokens joined into one string).
        return hash(self.span_doc + self.span_sent + (" ".join(self.text),))


class Token:
    """A single token given by its document index ``ix``; stores its sentence-level index and text."""

    def __init__(self, ix, text, sentence_start):
        local_ix = ix - sentence_start
        self.ix_doc = ix
        self.ix_sent = local_ix
        self.text = text[local_ix]

    def __repr__(self):
        return str((self.ix_sent, self.text))


class Trigger:
    """An event trigger: a Token together with its event label."""

    def __init__(self, token, label):
        self.token, self.label = token, label

    def __repr__(self):
        # Splice the label into the token's "(ix, text)" repr.
        return "".join([repr(self.token)[:-1], ", ", self.label, ")"])


class Argument:
    """An event argument: a Span filling ``role`` for an event of type ``event_type``."""

    def __init__(self, span, role, event_type):
        self.span, self.role, self.event_type = span, role, event_type

    def __repr__(self):
        # Re-open the span's tuple repr to append the event type and role.
        return ", ".join([repr(self.span)[:-1], self.event_type, self.role]) + ")"

    def __eq__(self, other):
        if not self.span == other.span:
            return False
        if not self.role == other.role:
            return False
        return self.event_type == other.event_type

    def __hash__(self):
        return hash(self.span) + hash((self.role, self.event_type))


class NER:
    """An entity mention: a Span with a type label and an optional flavor.

    ``ner`` is ``[start, end, label]`` in document token offsets.
    """

    def __init__(self, ner, text, sentence_start, flavor=None):
        self.span = Span(ner[0], ner[1], text, sentence_start)
        self.label = ner[2]
        self.flavor = flavor

    def __repr__(self):
        return repr(self.span) + ": " + self.label

    # Defining __eq__ without __hash__ makes NER instances unhashable.
    def __eq__(self, other):
        same_span = self.span == other.span
        return same_span and self.label == other.label and self.flavor == other.flavor


class Relation:
    """A labelled relation between two Spans of the same sentence.

    ``relation`` is ``[start1, end1, start2, end2, label]`` in document token offsets.
    """

    def __init__(self, relation, text, sentence_start):
        first = Span(relation[0], relation[1], text, sentence_start)
        second = Span(relation[2], relation[3], text, sentence_start)
        self.label = relation[4]
        self.pair = (first, second)

    def __repr__(self):
        return ", ".join(repr(span) for span in self.pair) + ": " + self.label

    def __eq__(self, other):
        return self.pair == other.pair and self.label == other.label


class AtomicRelation:
    """A relation reduced to plain strings: two entity surface forms and a label."""

    def __init__(self, ent0, ent1, label):
        self.ent0, self.ent1, self.label = ent0, ent1, label

    @classmethod
    def from_relation(cls, relation):
        """Build from a Relation by joining each argument span's tokens with spaces."""
        surface = [" ".join(span.text) for span in relation.pair[:2]]
        return cls(surface[0], surface[1], relation.label)

    def __repr__(self):
        return "({} | {} | {})".format(self.ent0, self.ent1, self.label)



class Event:
    """One event: a Trigger plus the Arguments attached to it.

    Args:
        event: sequence whose first item is the trigger spec (``trig[0]`` goes to Token,
            ``trig[1]`` to Trigger) and whose remaining items are ``(start, end, role)``
            argument specs.
        text: sentence tokens handed to Token/Span.
        sentence_start: offset handed to Token/Span.
    """

    def __init__(self, event, text, sentence_start):
        trig = event[0]
        arg_specs = event[1:]
        self.trigger = Trigger(Token(trig[0], text, sentence_start), trig[1])
        # Every argument is tagged with the trigger's label as its event type.
        self.arguments = [
            Argument(Span(spec[0], spec[1], text, sentence_start), spec[2], self.trigger.label)
            for spec in arg_specs
        ]

    def __repr__(self):
        # "<trigger>" alone, or "<trigger:\n      arg1;\n      arg2>" when arguments exist.
        header = "<" + repr(self.trigger)
        if not self.arguments:
            return header + ">"
        body = ";\n".join(6 * " " + repr(arg) for arg in self.arguments)
        return header + ":\n" + body + ">"


class Events:
    """All events built from one ``events_json`` list, with set-based argument lookups.

    ``triggers`` and ``arguments`` are sets gathered across every event, so the lookups
    below rely on Trigger/Argument objects being hashable.
    """

    def __init__(self, events_json, text, sentence_start):
        self.event_list = [Event(spec, text, sentence_start) for spec in events_json]
        self.triggers = {ev.trigger for ev in self.event_list}
        self.arguments = {arg for ev in self.event_list for arg in ev.arguments}

    def __len__(self):
        return len(self.event_list)

    def __getitem__(self, i):
        return self.event_list[i]

    def __repr__(self):
        return "\n\n".join(ev.__repr__() for ev in self.event_list)

    def span_matches(self, argument):
        """Stored arguments whose ``span.span_sent`` equals ``argument``'s."""
        return {cand for cand in self.arguments
                if cand.span.span_sent == argument.span.span_sent}

    def event_type_matches(self, argument):
        """Span matches that also share ``argument``'s event type."""
        return {cand for cand in self.span_matches(argument)
                if cand.event_type == argument.event_type}

    def matches_except_event_type(self, argument):
        """Span matches with the same role but a different event type."""
        return {cand for cand in self.span_matches(argument)
                if cand.event_type != argument.event_type and cand.role == argument.role}

    def exact_match(self, argument):
        """True when some stored argument compares equal (``==``) to ``argument``."""
        return any(cand == argument for cand in self.arguments)


class Cluster:
    """A group of mention spans sharing ``cluster_id`` (presumably a coreference cluster).

    Each entry of ``cluster`` is located in ``document`` via ``get_sentence_of_span`` and
    wrapped in a ClusterMember, together with the sentence's NER whose span equals it
    (or None when there is no such NER).
    """

    def __init__(self, cluster, cluster_id, document):
        collected = []
        for mention in cluster:
            sent_idx = get_sentence_of_span(mention, document.sentence_starts, document.n_tokens)
            sent = document[sent_idx]
            mention_span = Span(mention[0], mention[1], sent.text, sent.sentence_start)
            matching = [cand for cand in sent.ner if cand.span == mention_span]
            assert len(matching) <= 1
            # Attach an NER only when exactly one matches (kept explicit so the result is
            # the same even if asserts are disabled).
            matched_ner = matching[0] if len(matching) == 1 else None
            collected.append(ClusterMember(mention_span, matched_ner, sent, cluster_id))

        self.members = collected
        self.cluster_id = cluster_id

    def __repr__(self):
        return "{}: {}".format(self.cluster_id, self.members.__repr__())

    def __getitem__(self, ix):
        return self.members[ix]


class ClusterMember:
    """One mention inside a Cluster.

    Attributes:
        span: the mention's Span.
        ner: the NER in the same sentence whose span equals ``span``, or None.
        sentence: the sentence object containing the mention.
        cluster_id: id of the owning Cluster.
    """

    def __init__(self, span, ner, sentence, cluster_id):
        self.span, self.ner = span, ner
        self.sentence, self.cluster_id = sentence, cluster_id

    def __repr__(self):
        return "<{}> {}".format(self.sentence.sentence_ix, self.span.__repr__())


# Utils for Data Sets 

In [16]:
# Entity label inventory per supported task (--task choices).
task_ner_labels = {
    'ace04': ['FAC', 'WEA', 'LOC', 'VEH', 'GPE', 'ORG', 'PER'],
    'ace05': ['FAC', 'WEA', 'LOC', 'VEH', 'GPE', 'ORG', 'PER'],
    'scierc': ['Method', 'OtherScientificTerm', 'Task', 'Generic', 'Material', 'Metric'],
}


def get_labelmap(label_list):
    """Return ``(label2id, id2label)`` dictionaries for ``label_list``.

    Ids start at 1; id 0 is reserved for spans that are not entities.
    """
    label2id = {label: idx for idx, label in enumerate(label_list, start=1)}
    id2label = {idx: label for idx, label in enumerate(label_list, start=1)}
    return label2id, id2label


# Main function

In [21]:
# Configure root logging once for the notebook: INFO and above, timestamped, to stderr.
logging.basicConfig(
    level=logging.INFO,
    datefmt='%m/%d/%Y %H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
)
logger = logging.getLogger('root')

In [22]:
def save_model(model, args):
    """
    Save the model to the output directory

    Writes ``model.bert_model`` (unwrapped from a ``.module`` wrapper when it has one) and
    ``model.tokenizer`` into ``args.output_dir`` via their ``save_pretrained`` methods.
    """
    logger.info('Saving model to %s...'%(args.output_dir))
    encoder = getattr(model.bert_model, 'module', model.bert_model)
    encoder.save_pretrained(args.output_dir)
    model.tokenizer.save_pretrained(args.output_dir)

def output_ner_predictions(model, batches, dataset, output_file, id2label=None):
    """
    Save the prediction as a json file

    Runs ``model`` over ``batches`` in inference mode, keeps every span whose predicted
    label id is non-zero, attaches the results to the documents in ``dataset.js``
    (mutated in place) as ``predicted_ner`` plus empty ``predicted_relations``, and writes
    one JSON document per line to ``output_file`` (overwritten).

    Args:
        model: object whose ``run_batch(batch, training=False)`` returns a dict whose
            ``'pred_ner'`` holds one list of label ids per sample, aligned with
            ``sample['spans']``.
        batches: iterable of batches; each batch is a list of sample dicts.
        dataset: object exposing ``js``, the list of document dicts.
        output_file: destination path.
        id2label: label-id -> label-name map. Defaults to the notebook-level
            ``ner_id2label`` (the original implicit dependency).
    """
    if id2label is None:
        id2label = ner_id2label

    ner_result = {}
    tot_pred_ett = 0
    for batch in batches:
        pred_ner = model.run_batch(batch, training=False)['pred_ner']
        for sample, preds in zip(batch, pred_ner):
            # Shift span indices by this offset before writing them out
            # (presumably from the context-window frame to document positions -- confirm).
            off = sample['sent_start_in_doc'] - sample['sent_start']
            k = sample['doc_key'] + '-' + str(sample['sentence_ix'])
            # Label id 0 means "no entity", so those spans are not emitted.
            ner_result[k] = [
                [span[0] + off, span[1] + off, id2label[pred]]
                for span, pred in zip(sample['spans'], preds)
                if pred != 0
            ]
            tot_pred_ett += len(ner_result[k])

    logger.info('Total pred entities: %d'%tot_pred_ett)

    js = dataset.js
    for doc in js:
        doc["predicted_ner"] = []
        doc["predicted_relations"] = []
        for j in range(len(doc["sentences"])):
            k = doc['doc_key'] + '-' + str(j)
            if k in ner_result:
                doc["predicted_ner"].append(ner_result[k])
            else:
                logger.info('%s not in NER results!'%k)
                doc["predicted_ner"].append([])
            doc["predicted_relations"].append([])

    logger.info('Output predictions to %s..'%(output_file))
    with open(output_file, 'w') as f:
        f.write('\n'.join(json.dumps(doc, cls=NpEncoder) for doc in js))

def evaluate(model, batches, tot_gold):
    """
    Evaluate the entity model

    Scores every candidate span in ``batches`` and logs span-level accuracy (null label
    included) and entity-level precision / recall / F1 (non-null labels only).

    Args:
        model: object whose ``run_batch(batch, training=False)`` returns a dict whose
            ``'pred_ner'`` holds one list of label ids per sample, aligned with
            ``sample['spans_label']``.
        batches: iterable of batches; each batch is a list of sample dicts.
        tot_gold: number of gold (non-null) entities, used as the recall denominator.

    Returns:
        F1 as a float; 0.0 when nothing was predicted correctly or the set is empty.
    """
    logger.info('Evaluating...')
    c_time = time.time()
    cor = 0       # spans predicted with the correct non-null label
    tot_pred = 0  # spans predicted with any non-null label
    l_cor = 0     # spans (null included) whose predicted label equals gold
    l_tot = 0     # all candidate spans scored

    for batch in batches:
        pred_ner = model.run_batch(batch, training=False)['pred_ner']
        for sample, preds in zip(batch, pred_ner):
            for gold, pred in zip(sample['spans_label'], preds):
                l_tot += 1
                if pred == gold:
                    l_cor += 1
                    if pred != 0:
                        cor += 1
                if pred != 0:
                    tot_pred += 1

    # An empty evaluation set previously raised ZeroDivisionError here.
    acc = l_cor / l_tot if l_tot > 0 else 0.0
    logger.info('Accuracy: %5f'%acc)
    logger.info('Cor: %d, Pred TOT: %d, Gold TOT: %d'%(cor, tot_pred, tot_gold))
    # cor > 0 implies tot_pred > 0; tot_gold comes from the caller, so guard it too.
    p = cor / tot_pred if cor > 0 else 0.0
    r = cor / tot_gold if cor > 0 and tot_gold > 0 else 0.0
    f1 = 2 * (p * r) / (p + r) if cor > 0 else 0.0
    logger.info('P: %.5f, R: %.5f, F1: %.5f'%(p, r, f1))
    logger.info('Used time: %f'%(time.time()-c_time))
    return f1

def setseed(seed):
    """Seed every global RNG the notebook may draw from, for reproducible runs.

    Covers Python's ``random`` (used for batch shuffling during training), NumPy's global
    RNG (numpy is imported by this notebook but was previously left unseeded), and
    torch's CPU and, when available, CUDA generators.
    """
    random.seed(seed)
    # NumPy only accepts seeds in [0, 2**32); fold into that range so any int that
    # random/torch accept keeps working.
    np.random.seed(seed % (2 ** 32))
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


In [24]:

parser = argparse.ArgumentParser()

parser.add_argument('--task', type=str, default=None, required=True, choices=['ace04', 'ace05', 'scierc'])

parser.add_argument('--data_dir', type=str, default=None, required=True,
                    help="path to the preprocessed dataset")
parser.add_argument('--output_dir', type=str, default='entity_output',
                    help="output directory of the entity model")

parser.add_argument('--max_span_length', type=int, default=8,
                    help="spans w/ length up to max_span_length are considered as candidates")
parser.add_argument('--train_batch_size', type=int, default=32,
                    help="batch size during training")
parser.add_argument('--eval_batch_size', type=int, default=32,
                    help="batch size during inference")
parser.add_argument('--learning_rate', type=float, default=1e-5,
                    help="learning rate for the BERT encoder")
parser.add_argument('--task_learning_rate', type=float, default=1e-4,
                    help="learning rate for task-specific parameters, i.e., classification head")
parser.add_argument('--warmup_proportion', type=float, default=0.1,
                    help="the ratio of the warmup steps to the total steps")
parser.add_argument('--num_epoch', type=int, default=100,
                    help="number of the training epochs")
parser.add_argument('--print_loss_step', type=int, default=100,
                    help="how often logging the loss value during training")
parser.add_argument('--eval_per_epoch', type=int, default=1,
                    help="how often evaluating the trained model on dev set during training")
parser.add_argument("--bertadam", action="store_true", help="If bertadam, then set correct_bias = False")

parser.add_argument('--do_train', action='store_true',
                    help="whether to run training")
parser.add_argument('--train_shuffle', action='store_true',
                    help="whether to train with randomly shuffled data")
parser.add_argument('--do_eval', action='store_true',
                    help="whether to run evaluation")
parser.add_argument('--eval_test', action='store_true',
                    help="whether to evaluate on test set")
parser.add_argument('--dev_pred_filename', type=str, default="ent_pred_dev.json", help="the prediction filename for the dev set")
parser.add_argument('--test_pred_filename', type=str, default="ent_pred_test.json", help="the prediction filename for the test set")

parser.add_argument('--use_albert', action='store_true',
                    help="whether to use ALBERT model")
parser.add_argument('--model', type=str, default='bert-base-uncased',
                    help="the base model name (a huggingface model)")
parser.add_argument('--bert_model_dir', type=str, default=None,
                    help="the base model directory")

parser.add_argument('--seed', type=int, default=0)

parser.add_argument('--context_window', type=int, required=True, default=None,
                    help="the context window size W for the entity model")

# Inside Jupyter/Colab, sys.argv belongs to the kernel launcher (ipykernel_launcher.py),
# not to this script, so parse_args() on it exits with a usage error. To run this cell
# interactively, set NOTEBOOK_ARGS to an explicit flag list, for example:
#   NOTEBOOK_ARGS = ['--task', 'scierc', '--data_dir', '<path to processed json>',
#                    '--context_window', '0', '--do_train', '--do_eval']
# Leaving it as None keeps script behaviour: argparse reads sys.argv[1:].
NOTEBOOK_ARGS = None
args = parser.parse_args(NOTEBOOK_ARGS)
args.train_data = os.path.join(args.data_dir, 'train.json')
args.dev_data = os.path.join(args.data_dir, 'dev.json')
args.test_data = os.path.join(args.data_dir, 'test.json')

if 'albert' in args.model:
    logger.info('Use Albert: %s'%args.model)
    args.use_albert = True

setseed(args.seed)

os.makedirs(args.output_dir, exist_ok=True)

# Re-running this cell used to stack one more FileHandler on the shared logger each time,
# duplicating every log line and leaving old log files open. Drop (and close) any file
# handlers from a previous run before attaching the one for this run.
for stale_handler in [h for h in logger.handlers if isinstance(h, logging.FileHandler)]:
    logger.removeHandler(stale_handler)
    stale_handler.close()
log_filename = "train.log" if args.do_train else "eval.log"
logger.addHandler(logging.FileHandler(os.path.join(args.output_dir, log_filename), 'w'))

logger.info(sys.argv)
logger.info(args)

ner_label2id, ner_id2label = get_labelmap(task_ner_labels[args.task])

# +1 for label id 0, which get_labelmap leaves for "not an entity".
num_ner_labels = len(task_ner_labels[args.task]) + 1
model = EntityModel(args, num_ner_labels=num_ner_labels)

dev_data = Dataset(args.dev_data)
dev_samples, dev_ner = convert_dataset_to_samples(dev_data, args.max_span_length, ner_label2id=ner_label2id, context_window=args.context_window)
dev_batches = batchify(dev_samples, args.eval_batch_size)

usage: ipykernel_launcher.py [-h] --task {ace04,ace05,scierc} --data_dir
                             DATA_DIR [--output_dir OUTPUT_DIR]
                             [--max_span_length MAX_SPAN_LENGTH]
                             [--train_batch_size TRAIN_BATCH_SIZE]
                             [--eval_batch_size EVAL_BATCH_SIZE]
                             [--learning_rate LEARNING_RATE]
                             [--task_learning_rate TASK_LEARNING_RATE]
                             [--warmup_proportion WARMUP_PROPORTION]
                             [--num_epoch NUM_EPOCH]
                             [--print_loss_step PRINT_LOSS_STEP]
                             [--eval_per_epoch EVAL_PER_EPOCH] [--bertadam]
                             [--do_train] [--train_shuffle] [--do_eval]
                             [--eval_test]
                             [--dev_pred_filename DEV_PRED_FILENAME]
                             [--test_pred_filename TEST_PRED_FILENAME]
                

SystemExit: ignored

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [25]:
if args.do_train:
    train_data = Dataset(args.train_data)
    train_samples, train_ner = convert_dataset_to_samples(train_data, args.max_span_length, ner_label2id=ner_label2id, context_window=args.context_window)
    train_batches = batchify(train_samples, args.train_batch_size)
    best_result = 0.0

    # Two parameter groups: names containing 'bert' (the encoder) use args.learning_rate,
    # everything else (the task-specific head) uses args.task_learning_rate.
    param_optimizer = list(model.bert_model.named_parameters())
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer
            if 'bert' in n]},
        {'params': [p for n, p in param_optimizer
            if 'bert' not in n], 'lr': args.task_learning_rate}]
    # --bertadam switches AdamW's bias correction off (correct_bias=False).
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, correct_bias=not(args.bertadam))
    t_total = len(train_batches) * args.num_epoch
    scheduler = get_linear_schedule_with_warmup(optimizer, int(t_total*args.warmup_proportion), t_total)

    tr_loss = 0
    tr_examples = 0
    global_step = 0
    # Evaluate roughly eval_per_epoch times per epoch. Clamp both operands: with fewer
    # batches than eval_per_epoch (or eval_per_epoch == 0) the step used to be 0, and the
    # `global_step % eval_step` check below raised ZeroDivisionError.
    eval_step = max(1, len(train_batches) // max(1, args.eval_per_epoch))
    for epoch in tqdm(range(args.num_epoch)):
        if args.train_shuffle:
            random.shuffle(train_batches)
        for i, batch in enumerate(tqdm(train_batches)):
            output_dict = model.run_batch(batch, training=True)
            loss = output_dict['ner_loss']
            loss.backward()

            tr_loss += loss.item()
            tr_examples += len(batch)
            global_step += 1

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            if global_step % args.print_loss_step == 0:
                # Accumulated loss.item() values divided by examples seen since the last report.
                logger.info('Epoch=%d, iter=%d, loss=%.5f'%(epoch, i, tr_loss / tr_examples))
                tr_loss = 0
                tr_examples = 0

            if global_step % eval_step == 0:
                f1 = evaluate(model, dev_batches, dev_ner)
                if f1 > best_result:
                    best_result = f1
                    logger.info('!!! Best valid (epoch=%d): %.2f' % (epoch, f1*100))
                    save_model(model, args)

NameError: ignored

In [None]:
if args.do_eval:
    # Point the model loader at output_dir, where save_model() writes checkpoints
    # (presumably EntityModel loads weights from args.bert_model_dir -- confirm).
    args.bert_model_dir = args.output_dir
    model = EntityModel(args, num_ner_labels=num_ner_labels)

    # Score the test split with --eval_test, otherwise the dev split.
    if args.eval_test:
        split_path, pred_filename = args.test_data, args.test_pred_filename
    else:
        split_path, pred_filename = args.dev_data, args.dev_pred_filename
    test_data = Dataset(split_path)
    prediction_file = os.path.join(args.output_dir, pred_filename)

    test_samples, test_ner = convert_dataset_to_samples(
        test_data,
        args.max_span_length,
        ner_label2id=ner_label2id,
        context_window=args.context_window,
    )
    test_batches = batchify(test_samples, args.eval_batch_size)
    evaluate(model, test_batches, test_ner)
    output_ner_predictions(model, test_batches, test_data, output_file=prediction_file)