In [31]:
! pip install nltk



In [1]:
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\odaim\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [11]:
import json
import os
from unidecode import unidecode

from nltk.tokenize import word_tokenize

In [100]:
def normalize_nyt_sample(sample):
    data = json.loads(sample)
    sentence = data['sentText'].replace('Jr.', 'Jr')
    sentence = sentence.replace('U.S.A.', 'U.S.A')
    sentence = sentence.replace('P.M.', 'P.M')
    tokens = word_tokenize(sentence)
    norm = {}
    norm['doc_key'] = data['articleId']
    norm['sentences'] = [tokens]
    norm['ner'] = []
    norm['relations'] = []
    norm['clusters'] = []
    
    entities = []
    rels = []
    # print(tokens)
    for entity in data['entityMentions']:
        ent = entity['text'].replace('Jr.', 'Jr')
        ent = ent.replace('U.S.A.', 'U.S.A')
        ent = ent.replace('P.M.', 'P.M')
        ner = word_tokenize(ent)
        # print(ner)
        label = entity['label'].title()
        entities.append([tokens.index(ner[0]), tokens.index(ner[-1]), label])
    norm['ner'].append(entities)
        
    for relation in data['relationMentions']:
        label = relation['label'].split('/')[-1].replace('_','-').upper()
        source = unidecode(relation['em1Text'].replace('Jr.', 'Jr'))
        source = word_tokenize(source)
        # print(source)
        target = unidecode(relation['em2Text'].replace('Jr.', 'Jr'))
        target = word_tokenize(target)
        # print(target)
        rels.append([tokens.index(source[0]), tokens.index(source[-1]), tokens.index(target[0]), tokens.index(target[-1]), label])
    norm['relations'].append(rels)
        
    return norm

In [101]:
samp = """{"sentText": "Was it just last month that Wal-Mart 's chief executive , H. Lee Scott Jr. , said his company would be a kinder , gentler corporate citizen and never again bulldoze a local government to let it open more stores ?", "articleId": "/m/vinci8/data1/riedel/projects/relation/kb/nyt1/docstore/nyt-2005-2006.backup/1674506.xml.pb", "relationMentions": [{"em1Text": "H. Lee Scott Jr.", "em2Text": "Wal-Mart", "label": "/business/person/company"}], "entityMentions": [{"start": 0, "label": "ORGANIZATION", "text": "Wal-Mart"}, {"start": 1, "label": "PERSON", "text": "H. Lee Scott Jr."}], "sentId": "1"}"""
normalize_nyt_sample(samp)

{'doc_key': '/m/vinci8/data1/riedel/projects/relation/kb/nyt1/docstore/nyt-2005-2006.backup/1674506.xml.pb',
 'sentences': [['Was',
   'it',
   'just',
   'last',
   'month',
   'that',
   'Wal-Mart',
   "'s",
   'chief',
   'executive',
   ',',
   'H.',
   'Lee',
   'Scott',
   'Jr',
   ',',
   'said',
   'his',
   'company',
   'would',
   'be',
   'a',
   'kinder',
   ',',
   'gentler',
   'corporate',
   'citizen',
   'and',
   'never',
   'again',
   'bulldoze',
   'a',
   'local',
   'government',
   'to',
   'let',
   'it',
   'open',
   'more',
   'stores',
   '?']],
 'ner': [[[6, 6, 'Organization'], [11, 14, 'Person']]],
 'relations': [[[11, 14, 6, 6, 'COMPANY']]],
 'clusters': []}

In [4]:
nyt_data_dir = os.getcwd() + '/other_data/nyt_er_dataset/'

def write_normal_data(in_dir, out_dir):
    with open(in_dir) as f:
        for line in f:
            try:
                maped_sample = normalize_nyt_sample(line)
            except:
                print(line)
                break
            with open(out_dir, 'a') as normalized:
                normalized.write(json.dumps(maped_sample) + "\n")

In [110]:
nyt_train_data_path = nyt_data_dir + 'train.json'
nyt_train_norm_data_path = nyt_data_dir + 'norm_train.json'
            
write_normal_data(nyt_train_data_path, nyt_train_norm_data_path)

In [111]:
nyt_valid_data_path = nyt_data_dir + 'valid.json'
nyt_valid_norm_data_path = nyt_data_dir + 'norm_valid.json'

write_normal_data(nyt_valid_data_path, nyt_valid_norm_data_path)

In [112]:
nyt_test_data_path = nyt_data_dir + 'test.json'
nyt_test_norm_data_path = nyt_data_dir + 'norm_test.json'

write_normal_data(nyt_test_data_path, nyt_test_norm_data_path)

In [1]:
%run entity_model/entity_setup.ipynb

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
task_ner_labels = {
    'ace04': ['FAC', 'WEA', 'LOC', 'VEH', 'GPE', 'ORG', 'PER'],
    'ace05': ['FAC', 'WEA', 'LOC', 'VEH', 'GPE', 'ORG', 'PER'],
    'scierc': ['Method', 'OtherScientificTerm', 'Task', 'Generic', 'Material', 'Metric'],
    'nyt': ['Location', 'Person', 'Organization']
}

In [5]:
data_dir = nyt_data_dir
output_dir = os.getcwd() + '/nyt_models/ent-scib-ctx0/'
task = 'nyt'
max_span_length = 8
context_window = 0
eval_batch_size = 32
test_pred_filename = 'ent_pred_test.json'
dev_pred_filename = 'ent_pred_dev.json'

In [15]:
train_data = os.path.join(data_dir, 'norm_train.json')
dev_data = os.path.join(data_dir, 'norm_valid.json')
test_data = os.path.join(data_dir, 'norm_test.json')

In [7]:
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

In [8]:
ner_label2id, ner_id2label = get_labelmap(task_ner_labels[task])

In [9]:
dev_data = Dataset(dev_data)
dev_samples, dev_ner = convert_dataset_to_samples(dev_data, max_span_length, ner_label2id=ner_label2id, context_window=context_window)
dev_batches = batchify(dev_samples, eval_batch_size)

01/26/2024 14:16:01 - INFO - root - # Overlap: 0
01/26/2024 14:16:01 - INFO - root - Extracted 5000 samples from 5000 documents, with 15923 NER labels, 37.763 avg input length, 100 max length
01/26/2024 14:16:01 - INFO - root - Max Length: 100, max NER: 13


In [10]:
data_dir = nyt_data_dir
output_dir = os.getcwd() + '/nyt_models/from-scratch/ent-scib-ctx0/'
task = 'scierc'
num_ner_labels = len(task_ner_labels[task]) + 1
max_span_length = 8
context_window = 300
eval_batch_size = 32
train_batch_size = 2
learning_rate = 1e-5
task_learning_rate = 5e-4
bertadam = True # If bertadam, then set correct_bias = False
num_epoch = 10 # number of the training epochs
warmup_proportion = 0.1 # the ratio of the warmup steps to the total steps
eval_per_epoch = 1 # how often evaluating the trained model on dev set during training
train_shuffle = True # whether to train with randomly shuffled data
print_loss_step = 100 # how often logging the loss value during training

In [11]:
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

In [140]:
model = EntityModel(model='allenai/scibert_scivocab_uncased', use_albert=False, max_span_length=max_span_length, num_ner_labels=num_ner_labels)

01/25/2024 23:22:25 - INFO - transformers.tokenization_utils_base - Model name 'allenai/scibert_scivocab_uncased' not found in model shortcut name list (bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese, bert-base-german-cased, bert-large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-uncased-whole-word-masking-finetuned-squad, bert-large-cased-whole-word-masking-finetuned-squad, bert-base-cased-finetuned-mrpc, bert-base-german-dbmdz-cased, bert-base-german-dbmdz-uncased, TurkuNLP/bert-base-finnish-cased-v1, TurkuNLP/bert-base-finnish-uncased-v1, wietsedv/bert-base-dutch-cased). Assuming 'allenai/scibert_scivocab_uncased' is a path, a model identifier, or url to a directory containing tokenizer files.
01/25/2024 23:22:29 - INFO - transformers.tokenization_utils_base - loading file https://s3.amazonaws.com/models.huggingface.co/bert/allenai/scibert_scivoca

In [12]:
train_data = Dataset(train_data)

In [142]:
train_samples, train_ner = convert_dataset_to_samples(train_data, max_span_length, ner_label2id=ner_label2id, context_window=context_window)
train_batches = batchify(train_samples, train_batch_size)
best_result = 0.0

param_optimizer = list(model.bert_model.named_parameters())
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer
        if 'bert' in n]},
    {'params': [p for n, p in param_optimizer
        if 'bert' not in n], 'lr': task_learning_rate}]
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, correct_bias=not(bertadam))
t_total = len(train_batches) * num_epoch
scheduler = get_linear_schedule_with_warmup(optimizer, int(t_total*warmup_proportion), t_total)

tr_loss = 0
tr_examples = 0
global_step = 0
eval_step = len(train_batches) // eval_per_epoch
for _ in tqdm(range(num_epoch), position=0, leave=True):
    if train_shuffle:
        random.shuffle(train_batches)
    for i in tqdm(range(len(train_batches)), position=0, leave=True):
        output_dict = model.run_batch(train_batches[i], training=True)
        loss = output_dict['ner_loss']
        loss.backward()

        tr_loss += loss.item()
        tr_examples += len(train_batches[i])
        global_step += 1

        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        if global_step % print_loss_step == 0:
            logger.info('Epoch=%d, iter=%d, loss=%.5f'%(_, i, tr_loss / tr_examples))
            tr_loss = 0
            tr_examples = 0

        if global_step % eval_step == 0:
            f1 = evaluate(model, dev_batches, dev_ner)
            if f1 > best_result:
                best_result = f1
                logger.info('!!! Best valid (epoch=%d): %.2f' % (_, f1*100))
                save_model(model, output_dir)

01/25/2024 23:24:22 - INFO - root - # Overlap: 0
01/25/2024 23:24:22 - INFO - root - Extracted 56196 samples from 56196 documents, with 177461 NER labels, 37.817 avg input length, 100 max length
01/25/2024 23:24:22 - INFO - root - Max Length: 100, max NER: 20
  0%|          | 99/28098 [00:19<1:19:37,  5.86it/s]01/25/2024 23:24:42 - INFO - root - Epoch=0, iter=99, loss=526.39888
  1%|          | 199/28098 [00:36<1:17:48,  5.98it/s]01/25/2024 23:24:59 - INFO - root - Epoch=0, iter=199, loss=489.87780
  1%|          | 299/28098 [00:53<1:13:44,  6.28it/s]01/25/2024 23:25:16 - INFO - root - Epoch=0, iter=299, loss=419.22339
  1%|▏         | 399/28098 [01:11<1:16:48,  6.01it/s]01/25/2024 23:25:34 - INFO - root - Epoch=0, iter=399, loss=133.45635
  2%|▏         | 499/28098 [01:28<1:21:15,  5.66it/s]01/25/2024 23:25:51 - INFO - root - Epoch=0, iter=499, loss=28.05477
  2%|▏         | 599/28098 [01:46<1:21:01,  5.66it/s]01/25/2024 23:26:09 - INFO - root - Epoch=0, iter=599, loss=22.22239
  2%|▏

KeyboardInterrupt: 

In [13]:
bert_model_dir = output_dir
num_ner_labels = len(task_ner_labels[task]) + 1
model = EntityModel(model='allenai/scibert_scivocab_uncased', bert_model_dir=bert_model_dir, use_albert=False, max_span_length=max_span_length, num_ner_labels=num_ner_labels)

01/26/2024 14:16:14 - INFO - root - Loading BERT model from C:\Users\odaim\Documents\PURE reproduction/nyt_models/from-scratch/ent-scib-ctx0//
01/26/2024 14:16:14 - INFO - transformers.tokenization_utils_base - Model name 'C:\Users\odaim\Documents\PURE reproduction/nyt_models/from-scratch/ent-scib-ctx0//' not found in model shortcut name list (bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese, bert-base-german-cased, bert-large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-uncased-whole-word-masking-finetuned-squad, bert-large-cased-whole-word-masking-finetuned-squad, bert-base-cased-finetuned-mrpc, bert-base-german-dbmdz-cased, bert-base-german-dbmdz-uncased, TurkuNLP/bert-base-finnish-cased-v1, TurkuNLP/bert-base-finnish-uncased-v1, wietsedv/bert-base-dutch-cased). Assuming 'C:\Users\odaim\Documents\PURE reproduction/nyt_models/from-scratch/ent-scib-ct

In [16]:
test_data = Dataset(dev_data)
prediction_file = os.path.join(output_dir, dev_pred_filename)
    
test_samples, test_ner = convert_dataset_to_samples(test_data, max_span_length, ner_label2id=ner_label2id, context_window=context_window)
test_batches = batchify(test_samples, eval_batch_size)
evaluate(model, test_batches, test_ner)
output_ner_predictions(model, test_batches, test_data, output_file=prediction_file)

01/26/2024 14:17:20 - INFO - root - # Overlap: 0
01/26/2024 14:17:20 - INFO - root - Extracted 5000 samples from 5000 documents, with 15923 NER labels, 37.763 avg input length, 100 max length
01/26/2024 14:17:20 - INFO - root - Max Length: 100, max NER: 13
01/26/2024 14:17:20 - INFO - root - Evaluating...
01/26/2024 14:20:31 - INFO - root - Accuracy: 0.998733
01/26/2024 14:20:31 - INFO - root - Cor: 14324, Pred TOT: 15405, Gold TOT: 15923
01/26/2024 14:20:31 - INFO - root - P: 0.92983, R: 0.89958, F1: 0.91445
01/26/2024 14:20:31 - INFO - root - Used time: 191.038483
01/26/2024 14:23:42 - INFO - root - Total pred entities: 15405
01/26/2024 14:23:42 - INFO - root - Output predictions to C:\Users\odaim\Documents\PURE reproduction/nyt_models/from-scratch/ent-scib-ctx0/ent_pred_dev.json..
