In [None]:
import argparse
import math
import os

import torch
from torch.nn import DataParallel
from torch.optim import Optimizer
import transformers
from transformers import AdamW

from tqdm import tqdm

import spert.models
from spert.entities import Dataset
#from spert.evaluator import Evaluator
from spert.input_reader import JsonInputReader, BaseInputReader
from spert.loss import SpERTLoss, Loss
#from spert.sampling import Sampler
#from trainer import BaseTrainer
from bert.tokenization.bert_tokenization import FullTokenizer# TF TOKENIZER
from spert import sampling
from transformers import BertTokenizer #PYTORCH TOKENIZER

from torch.utils.data import DataLoader

#SCRIPT_PATH = os.path.dirname(os.path.realpath(__file__))
from transformers import AutoTokenizer #, AutoModelForMaskedLM
# Portuguese BERT WordPiece tokenizer; passed to the input reader and used by the
# decoding/debugging cells below (convert_ids_to_tokens, convert_tokens_to_ids).
tokenizer = AutoTokenizer.from_pretrained("neuralmind/bert-base-portuguese-cased")
#model = AutoModelForMaskedLM.from_pretrained("neuralmind/bert-base-portuguese-cased")

## Loading datasets in the input reader

In [None]:
# Type definitions and the train/validation splits (JSON files read by JsonInputReader).
types_path = "../spert-data/datasets/types.json"
train_path = "../spert-data/datasets/train.json"
valid_path = "../spert-data/datasets/evaluate.json"

# Reader built on the Portuguese tokenizer. max_span_size and the neg_* counts presumably bound
# span enumeration and the number of negative samples per document — TODO confirm in spert.input_reader.
input_reader = JsonInputReader(types_path = types_path, 
                               tokenizer = tokenizer, 
                               encoding = 'utf-8', 
                               max_span_size = 10,
                               neg_entity_count = 100,
                               neg_rel_count = 100)
# Parse both splits and register them under the labels 'train' and 'validation'.
input_reader.read(dataset_path = train_path, dataset_label = 'train')
input_reader.read(dataset_path = valid_path, dataset_label = 'validation')
train_dataset = input_reader.get_dataset('train')
validation_dataset = input_reader.get_dataset('validation')
print("Number of documents in train = ", train_dataset.document_count)
print("Number of documents in validation = ", validation_dataset.document_count)
# NOTE(review): presumably makes the dataset yield training samples (with negatives) — confirm in spert.entities.
train_dataset.switch_mode(Dataset.TRAIN_MODE)
# Small shuffled loader (2 documents per batch, padded by collate_fn_padding) used only to inspect
# a batch below; no seed is set, so the inspected batch changes between runs.
data_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, drop_last=True, num_workers=0, 
                         collate_fn=sampling.collate_fn_padding)

## Understanding data used for training

In [None]:
import numpy as np
# Inspect one padded batch. Decode its token ids, print the entity and relation tensors, and assert
# that both documents were padded to matching sizes. The loop stops after the first batch; `batch`
# stays bound and is passed to test_rel_and_ents below. ('ENTRADA DA RN' = "network input".)
for batch in data_loader:
    print('--------------        ENTRADA DA RN            ------------------')
    print('--------------          encodings          ------------------')
    print('encodings = ', tokenizer.convert_ids_to_tokens(batch['encodings'][0]))
    print('encodings = ', tokenizer.convert_ids_to_tokens(batch['encodings'][1]))
    assert len(batch['encodings'][0]) == len(batch['encodings'][1])
    print('O documento deste batch com texto maior possui', 
          len(batch['encodings'][1]), 'tokens.')
    print('context_masks = ', batch['context_masks'])
    assert len(batch['context_masks'][0]) == len(batch['context_masks'][1])
    assert len(batch['context_masks'][0]) == len(batch['encodings'][0])
    print('size = ', len(batch['context_masks'][1]))
    print('--------------          entity          ------------------')
    print('entity_masks = ', batch['entity_masks'])
    print('entity_types = ', batch['entity_types'])
    assert len(batch['entity_masks'][0]) == len(batch['entity_masks'][1])
    print('size = ', len(batch['entity_masks'][1]))
    print('entity_sizes = ', batch['entity_sizes'])
    assert len(batch['entity_sizes'][0]) == len(batch['entity_sizes'][1])
    assert len(batch['entity_sizes'][0]) == len(batch['entity_masks'][0])
    assert len(batch['entity_sizes'][0]) == len(batch['entity_types'][0])
    print('O documento com mais entidades neste batch possui ', 
          len(batch['entity_sizes'][1]), ' entidades. Cada entidade tem seu tamanho',
          'definido no batch[entity_sizes] e sua localização no batch[entity_masks].',
          'Há muitas entidades pq foram geradas randomicamente para os exemplos negativos.')
    print('--------------          rels          ------------------')
    print('rels = ', batch['rels'])
    assert len(batch['rels'][0]) == len(batch['rels'][1])
    print('size = ', len(batch['rels'][1]))
    print('rel_types = ', batch['rel_types'])
    print('rel_masks = ', batch['rel_masks'])
    assert len(batch['rel_masks'][0]) == len(batch['rel_masks'][1]) # same batch, so equal padded length
    assert len(batch['rel_masks'][0]) == len(batch['rels'][0]) # both are the number of relations
    assert len(batch['rel_masks'][0]) == len(batch['rel_types'][0]) # both are the number of relations
    assert len(batch['rel_masks'][0][0]) == len(batch['encodings'][0]) # both are the text length
    print('Há', len(batch['rel_masks'][0]), 'relações. O rel_masks,',
          'marca a posição entre duas entidades, mas sem as entidades. Isso pode',
          'não ser bom para minha aplicação por haver pouca informação.')
    break

def test_rel_and_ents(batch, batch_n):
    encoded_text = np.array(batch['encodings'][batch_n])
    print(tokenizer.convert_ids_to_tokens(encoded_text))
    print('Showing only relations and entities with TYPE different from 0:')
    for entity_mask, entity_type in zip(batch['entity_masks'][batch_n], batch['entity_types'][batch_n]):
        if entity_type.numpy() > 0:
            mask = np.array(entity_mask, dtype = np.bool)
            print('   entity type =', entity_type.numpy(),
                  ' - entity =', tokenizer.convert_ids_to_tokens(list(encoded_text[mask])))
    for rel_mask, rel_type in zip(batch['rel_masks'][batch_n], batch['rel_types'][batch_n]):
        if rel_type.numpy() > 0:
            mask = np.array(rel_mask, dtype = np.bool)
            print('   rel type =', rel_type.numpy(),
                  ' - rel =', tokenizer.convert_ids_to_tokens(list(encoded_text[mask])))
# Decode the labelled entities/relations of both documents of the inspected batch
# (`batch` is the one left bound by the inspection loop above).
test_rel_and_ents(batch, 0)
test_rel_and_ents(batch, 1)

### Visualizing the datasets

In [None]:
# Spot-check training documents 10-19. For each one, print the decoded tokens, the raw token ids,
# the relation (head, tail) pairs and the entity phrases.
docs = train_dataset.documents[10:20] 

for doc in docs:
    print("\n---------------------------------------")
    print(tokenizer.convert_ids_to_tokens(doc.encoding))
    print(doc.encoding)
    print('\nRelations:')
    for rel in doc.relations:
        print(rel.head_entity, "   -   ", rel.tail_entity)
    print('\nEntities:')
    for entity in doc.entities:
        print(entity.phrase)

# Rodando o SpERT
## Problemas:
- (ARRUMADO) Problema: Estou usando a base em inglês (bert-base-cased) ainda e o Tokenizer é tão ruim que meu dataset ultrapassa o limite de 510 tokens. E ultrapassa muito. Preciso usar o "neuralmind/bert-base-portuguese-cased".
- (ARRUMADO) É IMPORTANTE QUE O 'NONE' (tanto nas relations quanto nas entidades) NO ARQUIVO spert-data/datasets/types.json ESTEJA POR ÚLTIMO! Caso contrário dá o erro: IndexError: Target 13 is out of bounds.
- RuntimeError: [enforce fail at ..\c10\core\CPUAllocator.cpp:79] data. DefaultCPUAllocator: not enough memory: you tried to allocate 3480791040 bytes. Ocorreu ao dar evaluate após treinar com duas epochs. Batch de evaluate estava em 1 e de treino em 2. Subi ambos pra 2.

config='Namespace(cache_path=None, config='configs/example_train.conf', cpu=False, debug=False, epochs=20, eval_batch_size=1, example_count=None, final_eval=True, freeze_transformer=False, init_eval=False, label='conll04_train', log_path='data/log/', lowercase=False, lr=5e-05, lr_warmup=0.1, max_grad_norm=1.0, max_pairs=1000, max_span_size=10, model_path='bert-base-cased', model_type='spert', neg_entity_count=100, neg_relation_count=100, no_overlapping=False, prop_drop=0.1, rel_filter_threshold=0.4, sampling_processes=4, save_optimizer=False, save_path='data/save/', seed=None, size_embedding=25, store_examples=True, store_predictions=True, tokenizer_path='bert-base-cased', train_batch_size=2, train_log_iter=100, train_path='../spert-data/datasets/train.json', types_path='../spert-data/datasets/types.json', valid_path='../spert-data/datasets/evaluate.json', weight_decay=0.01)'

In [None]:
# Train via the SpERT command-line entry point with the example config (runs as a shell command).
!python ./spert.py train --config configs/example_train.conf

In [1]:
%%timeit
# REQUIRES NOTHING ABOVE. RUN SOLO.
import argparse
from spert import input_reader
from spert.spert_trainer import SpERTTrainer
import configparser

def eval_dict(config):
    """Convert the string values of a config dict into Python literals, in place.

    Every value except the one under ``'model_type'`` is parsed with ``ast.literal_eval``.
    Values that are not Python literals (paths, model names, labels) are kept as strings.

    Args:
        config: mapping of option name -> raw string value (e.g. a configparser section).

    Returns:
        The same dict object, mutated in place.
    """
    import ast

    for key, value in config.items():
        if key == 'model_type':
            continue
        try:
            # literal_eval only accepts literals (numbers, strings, booleans, None, containers).
            # Unlike eval() it cannot run code from the config file, and it cannot turn a
            # value such as 'id' or 'max' into a builtin object.
            config[key] = ast.literal_eval(value)
        except (ValueError, SyntaxError, TypeError):
            # Not a literal (e.g. 'data/log/' or 'bert-base-cased'): keep the raw string.
            pass
    return config
#from config_reader import _read_config
#config = _read_config('configs/example_train.conf')[0][1]
# Read section '1' of the example config, convert its values to literals and wrap the result in
# an argparse.Namespace so it can be passed to SpERTTrainer.
config = configparser.ConfigParser()
config.read('configs/example_train.conf')
config = dict(config['1'].items())
config = eval_dict(config)
run_args = argparse.Namespace(**config)
print('run_args:', run_args)
# Train four times, each with a freshly constructed SpERTTrainer.
# NOTE(review): presumably to compare repeated runs — confirm the intent.
for i in range (4):
    print('i = ', i)
    trainer = SpERTTrainer(run_args)
    print('------------------------')
    print('------training----------')
    print('------------------------')
    trainer.train(train_path=run_args.train_path, valid_path=run_args.valid_path,
                  types_path=run_args.types_path, input_reader_cls=input_reader.JsonInputReader)

run_args: Namespace(cache_path=None, config='configs/example_train.conf', cpu=True, debug=False, epochs=2, eval_batch_size=2, example_count=1, final_eval=True, freeze_transformer=False, init_eval=True, label='conll04_train', log_path='data/log/', lowercase=False, lr=5e-05, lr_warmup=0.1, max_grad_norm=1.0, max_pairs=1000, max_span_size=10, model_path='bert-base-cased', model_type='spert', neg_entity_count=100, neg_relation_count=100, no_overlapping=False, prop_drop=0.1, rel_filter_threshold=0.4, sampling_processes=0, save_optimizer=False, save_path='data/save/', seed=123, size_embedding=25, store_examples=True, store_predictions=True, tokenizer_path='neuralmind/bert-base-portuguese-cased', train_batch_size=2, train_log_iter=100, train_path='../spert-data/datasets/train.json', types_path='../spert-data/datasets/types.json', valid_path='../spert-data/datasets/evaluate.json', weight_decay=0.01)
i =  0
save_path: data/save/conll04_train\2022-03-15_10_48_55.422105  - self._args.label: conll

Parse dataset 'train': 100%|██████████| 4522/4522 [00:42<00:00, 105.35it/s]
Parse dataset 'valid': 100%|██████████| 1131/1131 [00:10<00:00, 111.54it/s]

2022-03-15 10:49:51,984 [MainThread  ] [INFO ]  Relation type count: 2
2022-03-15 10:49:51,985 [MainThread  ] [INFO ]  Entity type count: 14
2022-03-15 10:49:51,986 [MainThread  ] [INFO ]  Entities:
2022-03-15 10:49:51,987 [MainThread  ] [INFO ]  None=14
2022-03-15 10:49:51,988 [MainThread  ] [INFO ]  Artigo=1
2022-03-15 10:49:51,988 [MainThread  ] [INFO ]  Parágrafo=2
2022-03-15 10:49:51,989 [MainThread  ] [INFO ]  Inciso=3
2022-03-15 10:49:51,989 [MainThread  ] [INFO ]  Alínea=4
2022-03-15 10:49:51,990 [MainThread  ] [INFO ]  Diploma=5
2022-03-15 10:49:51,991 [MainThread  ] [INFO ]  Tema do STJ=6
2022-03-15 10:49:51,991 [MainThread  ] [INFO ]  Súmula do STJ=7
2022-03-15 10:49:51,992 [MainThread  ] [INFO ]  Tema do STF=8
2022-03-15 10:49:51,993 [MainThread  ] [INFO ]  Súmula do STF=9
2022-03-15 10:49:51,993 [MainThread  ] [INFO ]  Súmula do TRF3=10
2022-03-15 10:49:51,994 [MainThread  ] [INFO ]  Súmula Vinculante=11
2022-03-15 10:49:51,995 [MainThread  ] [INFO ]  Item=12
2022-03-15 10




2022-03-15 10:49:54,607 [MainThread  ] [INFO ]  Evaluate: valid


To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  ..\aten\src\ATen\native\BinaryOps.cpp:467.)
  return torch.floor_divide(self, other)
Evaluate epoch 0: 100%|██████████| 566/566 [27:14<00:00,  2.89s/it]  


Evaluation

--- Entities (named entity recognition (NER)) ---
An entity is considered correct if the entity type and span is predicted correctly

                type    precision       recall     f1-score      support
              Artigo        89.74        99.00        94.15         1803
              Inciso        91.53        99.56        95.38          684
      Súmula do TRF3         0.00         0.00         0.00            1
          Precedente        40.00        72.10        51.45          319
           Parágrafo        80.25        97.01        87.84          603
         Tema do STF        25.00        87.50        38.89            8
         Tema do STJ         0.00         0.00         0.00            9
       Súmula do STF        58.33        60.87        59.57           23
              Alínea        68.91        90.11        78.10           91
   Súmula Vinculante         0.00         0.00         0.00            1
       Súmula do STJ        62.65        85.25     

Train epoch 0:  10%|█         | 229/2261 [10:14<1:30:50,  2.68s/it]


KeyboardInterrupt: 

## Documentation

### input_reader

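input_reader.read({'train': train_path, 'validation': valid_path}) — older signature; the loading cell at the top now calls `input_reader.read(dataset_path=..., dataset_label=...)` once per split.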

.context_size = 512

.vocabulary_size = 29794

.datasets = {'train': <entities.Dataset>, 'validation': <entities.Dataset>}

    train_dataset = input_reader.get_dataset('train')
    
.entity_types = OrderedDict([('None', <entities.EntityType>), ('Artigo', <entities.EntityType>)...)])

    .entity_type_count = 13
    entity_type_none = input_reader.get_entity_type(0)

.relation_types = OrderedDict([('P', <entities.RelationType>), ('None', <entities.RelationType>)])
    
    .relation_type_count = 2
    relation_type_none = input_reader.get_relation_type(1)










### dataset

(train_dataset)

.label = train

.documents = [<entities.Document>, <entities.Document>, ...]

.document_count = 707

.entities = [<entities.Entity>, <entities.Entity>, ...]

    .entity_count = 3686

.relations = [<entities.Relation>, <entities.Relation>, ...]

    .relation_count = 2151


.input_reader = 

.iterate_documents

.iterate_relations

.create_document

.create_entity

.create_relation

.create_token



### document

.doc_id = 1

.encoding = [101, 298, 8746, 14643, 442, ...]

tokenizer.convert_ids_to_tokens(train_dataset.documents[1].encoding) = ['[CLS]', 'dos', 'dispositivos', ...]

.entities = [<entities.Entity>, <entities.Entity>, ...]

.relations = [<entities.Relation>, <entities.Relation>, ...]

.tokens = <entities.TokenSpan object at 0x7f2eab851710>


### TokenSpan

len(train_dataset.documents[1].encoding) = 156

.span = (1, 155)

.span_start = 1

.span_end = 155

len() = number of tokens

### Entity

.as_tuple

.entity_type = <entities.EntityType object at 0x7f2ed3d606d0>

.phrase = 202

.span = (54, 55)

    .span_start = 54
    .span_end = 55
    .tokens = <entities.TokenSpan>

### relation

.as_tuple

.first_entity = decreto - lei [UNK] 28 ##8 / 67

.second_entity = ##o

.reverse = True

.head_entity = ##o (when .reverse = True, .first_entity = .tail_entity and .second_entity = .head_entity)

.tail_entity = decreto - lei [UNK] 28 ##8 / 67

.relation_type = <entities.RelationType>


In [None]:
# Hyperparameters for the manual training/evaluation cells below.
train_batch_size = 4
eval_batch_size = 1
neg_entity_count = 100
neg_relation_count = 100
lr = 5e-5
# Fraction of the total number of updates used for linear LR warm-up.
lr_warmup = 0.1
weight_decay = 0.01

# Passed to the Evaluator; presumably relation scores below it are discarded — TODO confirm.
rel_filter_threshold = 0.4
size_embedding = 25
prop_drop = 0.1
max_span_size = 10
store_examples = True
# NOTE(review): store_examples, sampling_processes and sampling_limit are not read by the
# manual cells below (they construct Sampler(0, 0)).
sampling_processes = 4
sampling_limit = 100
max_pairs = 500

In [None]:
%%time
epochs = 10
# Overrides the size_embedding = 25 set in the hyperparameter cell.
size_embedding = 100
freeze_transformer = True

context_size = input_reader.context_size
rel_type_count = input_reader.relation_type_count

train_sample_count = train_dataset.document_count
print('train_sample_count=',train_sample_count)
# Optimizer steps per epoch (floor division, so a trailing partial batch is not counted) and over
# the whole run; updates_total sizes the LR schedule below.
updates_epoch = train_sample_count // train_batch_size
updates_total = updates_epoch * epochs

#use model_path to start fresh
#use _save_path to continue the training
# NOTE(review): model_path is not used below. from_pretrained loads checkpoint_path (= _save_path),
# so this cell always continues from the saved model. The hub id used for the tokenizer is
# "neuralmind/bert-base-portuguese-cased"; confirm whether this bare name is a local directory.
model_path='bert-base-portuguese-cased'
_save_path = "data/saved_model/final_model"
checkpoint_path = _save_path

def create_train_sampler():
    """Return a fresh training batch iterator over ``train_dataset``.

    Builds ``Sampler(0, 0)`` and calls its ``create_train_sampler`` with the notebook-level batch
    size, span limit, context size and negative-sample counts, passing ``truncate=True``.
    """
    # NOTE(review): `Sampler` is not imported in this notebook (the spert.sampling Sampler
    # import is commented out), so this raises NameError unless it is defined elsewhere.
    return Sampler(0, 0).create_train_sampler(dataset=train_dataset,
                                              batch_size=train_batch_size,
                                              max_span_size=max_span_size,
                                              context_size=context_size,
                                              neg_entity_count=neg_entity_count,
                                              neg_rel_count=neg_relation_count,
                                              truncate=True)
sampler = create_train_sampler()

# `import spert.models` at the top binds only the package name `spert`, so the module is reached
# as `spert.models` (a bare `models` is undefined). The model registry key is the lowercase model
# type, matching model_type='spert' in the trainer's run arguments.
model_class = spert.models.get_model('spert')

# from_pretrained is a method from the superclass BertPreTrainedModel that loads BERT's weights.
# Here it loads checkpoint_path, i.e. the previously saved SpERT model.
model = model_class.from_pretrained(checkpoint_path,
                                    cache_dir="",
                                    # SpERT model parameters
                                    cls_token=tokenizer.convert_tokens_to_ids(['[CLS]'])[0],
                                    # no node for 'none' class
                                    relation_types=input_reader.relation_type_count - 1,
                                    entity_types=input_reader.entity_type_count,
                                    max_pairs=max_pairs,
                                    # use the configured dropout instead of a duplicated literal
                                    prop_drop=prop_drop,
                                    size_embedding=size_embedding,
                                    freeze_transformer=freeze_transformer)
model.zero_grad()

# `iteration` counts optimizer steps in the training cell. `total` (full batches per epoch) is
# the tqdm length there.
iteration = 0
total = train_dataset.document_count // train_batch_size
def _get_optimizer_params(model):
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_params = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': weight_decay},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]

    return optimizer_params

# Per-element losses: BCE-with-logits over relation classes ('none' has no output node) and
# cross-entropy over entity types. reduction='none' presumably leaves masking/averaging to SpERTLoss.
rel_criterion = torch.nn.BCEWithLogitsLoss(reduction='none')
entity_criterion = torch.nn.CrossEntropyLoss(reduction='none')
optimizer_params = _get_optimizer_params(model)
# correct_bias=False disables Adam bias correction, as in the original BERT optimizer.
optimizer = AdamW(optimizer_params, lr=lr, weight_decay=weight_decay, correct_bias=False)
# Linear warm-up over the first lr_warmup fraction of updates_total steps, then linear decay to 0.
scheduler = transformers.get_linear_schedule_with_warmup(optimizer,
                                                         num_warmup_steps=lr_warmup * updates_total,
                                                         num_training_steps=updates_total)

from transformers import PreTrainedModel
import util
def _save_model(save_path: str, model: PreTrainedModel, iteration: int, optimizer: Optimizer = None,
                save_as_best: bool = False, extra: dict = None, include_iteration: bool = True, name: str = 'model'):
    """Save ``model`` with ``save_pretrained`` plus an ``extra.state`` file of training metadata.

    The target folder is ``<save_path>/<name>_best`` when ``save_as_best`` is set. Otherwise it is
    ``<save_path>/<name>_<iteration>``, or ``<save_path>/<name>`` when ``include_iteration`` is False.

    Args:
        save_path: parent directory of the checkpoint folder.
        model: model to save; a ``DataParallel`` wrapper is unwrapped first.
        iteration: global iteration, stored in ``extra.state`` and optionally in the folder name.
        optimizer: if given, its ``state_dict()`` is stored under ``'optimizer'``.
        save_as_best: write to the ``<name>_best`` folder instead.
        extra: additional entries merged into the saved state (may override ``iteration``).
        include_iteration: append ``_<iteration>`` to the folder name (ignored when ``save_as_best``).
        name: base folder name.
    """
    extra_state = dict(iteration=iteration)

    if optimizer:
        extra_state['optimizer'] = optimizer.state_dict()

    if extra:
        extra_state.update(extra)

    if save_as_best:
        dir_path = os.path.join(save_path, '%s_best' % name)
    else:
        dir_name = '%s_%s' % (name, iteration) if include_iteration else name
        dir_path = os.path.join(save_path, dir_name)

    # presumably creates dir_path (and parents) when missing — see spert's util module
    util.create_directories_dir(dir_path)

    # Save the wrapped module so the checkpoint can be reloaded without DataParallel.
    if isinstance(model, DataParallel):
        model.module.save_pretrained(dir_path)
    else:
        model.save_pretrained(dir_path)
    state_path = os.path.join(dir_path, 'extra.state')
    torch.save(extra_state, state_path)
    print("model saved")

# save final model
# Metadata passed to _save_model by the training cell. It always records the final epoch and the
# final global iteration, even though the training cell saves after every epoch.
save_optimizer = True
extra = dict(epoch=epochs, updates_epoch=updates_epoch, epoch_iteration=0)
global_iteration = epochs * updates_epoch

In [None]:
#for n in sampler:
    #print(n)

# SAMPLER
iterator of <sampling.TrainTensorBatch>

# TrainTensorBatch

.ctx_masks = tensor shape (batch, context_size)

.encodings = tensor shape (batch, context_size)



.entity_types = tensor shape (batch, num_of_entities) of integers

.entity_masks = tensor shape (batch, num_of_entities, 512) of 0 or 1, with num_of_entities = max number of entities in the batch

.entity_sizes = tensor shape (batch, num_of_entities) of integers

.entity_sample_masks = vector, True if an entity exists at that position, False elsewhere. Needed because each example in the batch has a different number of entities, so when I stacked them in a matrix I padded them. The padded ones are False, the real ones are True. shape (6, num_of_entities)



.rels = tensor shape (batch, num_of_relations, 2), with num_of_relations = max number of relations in the batch

    .rel_types

    .rel_masks = tensor shape (batch, num_of_relations, 512)

    .rel_sample_masks = vector, all True. shape (6, num_of_relations)



In [None]:
%%time

max_grad_norm = 1.0

# SpERTLoss receives the model, optimizer, scheduler and max_grad_norm. compute() presumably runs
# backward, gradient clipping and the optimizer/scheduler steps — TODO confirm in spert.loss.
compute_loss = SpERTLoss(rel_criterion, entity_criterion, model, optimizer, scheduler, max_grad_norm)

for _ in range(epochs):
    # New sampler every epoch (the previous iterator is presumably exhausted).
    sampler = create_train_sampler()
    for batch in tqdm(sampler, total=total):
        # Nothing in this loop switches to eval mode, so this call could be hoisted out of it.
        model.train()

        # relation types to one-hot encoding
        rel_types_onehot = torch.zeros([batch.rel_types.shape[0], batch.rel_types.shape[1],
                                        rel_type_count], dtype=torch.float32)
        rel_types_onehot.scatter_(2, batch.rel_types.unsqueeze(2), 1)
        # NOTE(review): dropping column 0 assumes 'None' is relation type 0, but the input_reader
        # notes in this notebook list 'None' at index 1 (get_relation_type(1)). Verify the type order.
        rel_types_onehot = rel_types_onehot[:, :, 1:]  # all zeros for 'none' relation

        # forward step
        entity_logits, rel_logits = model(batch.encodings, batch.ctx_masks, batch.entity_masks,
                                          batch.entity_sizes, batch.rels, batch.rel_masks)

        # compute loss and optimize parameters
        batch_loss = compute_loss.compute(rel_logits, rel_types_onehot, entity_logits,
                                          batch.entity_types, batch.rel_sample_masks, batch.entity_sample_masks)

        # logging
        iteration += 1
        print(batch_loss)

    # Checkpoint after every epoch into _save_path/final_model (include_iteration=False, so each
    # epoch overwrites the same folder). NOTE(review): the model is loaded from _save_path itself,
    # one level above where it is saved — confirm this is intended for "continue the training".
    _save_model(_save_path, model = model, iteration = global_iteration,
                             optimizer=optimizer if save_optimizer else None, extra=extra,
                             include_iteration=False, name='final_model')

In [None]:
%%time
from evaluator import Evaluator
# Evaluate the current model on the validation split. `epoch` is passed to the Evaluator and used
# in the progress-bar label and in global_iteration below.
example_count=1
epoch=1

evaluator = Evaluator(validation_dataset, input_reader, tokenizer,
                              rel_filter_threshold, example_count,
                              valid_path, epoch, validation_dataset.label)

        # create batch sampler
# NOTE(review): `Sampler` is not imported in this notebook; this raises NameError unless defined elsewhere.
eval_sampler = Sampler(0,0)
eval_sampler = eval_sampler.create_eval_sampler(validation_dataset, eval_batch_size, max_span_size,
                                            input_reader.context_size, truncate=False)


with torch.no_grad():
    model.eval()

    # iterate batches
    # NOTE(review): overwrites the `total` used as the training loop's tqdm length.
    total = math.ceil(validation_dataset.document_count / eval_batch_size)
    for batch in tqdm(eval_sampler, total=total, desc='Evaluate epoch %s' % epoch):
        # move batch to selected device
        # NOTE(review): no device transfer happens here; the batch is used as the sampler yields it.
        # run model (forward pass)
        entity_clf, rel_clf, rels = model(batch.encodings, batch.ctx_masks, batch.entity_masks,
                                          batch.entity_sizes, batch.entity_spans, batch.entity_sample_masks,
                                          evaluate=True)

        # evaluate batch
        evaluator.eval_batch(entity_clf, rel_clf, rels, batch)

# Overwrites the global_iteration set in the training setup cell.
global_iteration = epoch * updates_epoch + iteration
ner_eval, rel_eval, rel_nec_eval = evaluator.compute_scores()
#self._log_eval(*ner_eval, *rel_eval, *rel_nec_eval,
               #epoch, iteration, global_iteration, validation_dataset.label)
evaluator.store_examples()

In [None]:
# Display the tokenizer object. A bare trailing `.` here is a SyntaxError that stops Restart & Run All.
tokenizer

In [None]:
# Scratch shape experiment, presumably debugging the entity-representation repeat.
# NOTE(review): as written, the last line raises RuntimeError. entity_ctx.unsqueeze(1) is 4-D here
# ([0, 1, 2048, 768]) but only 3 repeat factors are given. The repeat only lines up for a 2-D
# entity_ctx, e.g. (batch, hidden) = (4, 768) to match entity_spans_pool — TODO confirm.
entity_ctx = torch.zeros([0, 2048, 768])
entity_spans_pool = torch.zeros([4, 818, 768])
size_embeddings = torch.zeros([4, 818, 100])
print(entity_ctx.unsqueeze(1).shape)
print(entity_spans_pool.shape[1])
entity_repr = entity_ctx.unsqueeze(1).repeat(1, entity_spans_pool.shape[1], 1)

In [None]:
# Scratch check of Tensor.repeat. Repeating a 1-D tensor with two factors adds a leading
# dimension, giving [[2, 2, 2, 2], [2, 2, 2, 2]].
a = torch.tensor([2,2])
b = a.repeat(2,2)
b

In [None]:
# Appears to be copied from the trainer's epoch loop for reference.
# NOTE(review): this cannot run at notebook top level because `self` and `args` are undefined.
for epoch in range(epochs):
    # train epoch
    self._train_epoch(model, compute_loss, optimizer, train_dataset, updates_epoch, epoch,
                      input_reader.context_size, input_reader.relation_type_count)
    # context_size is the size of the largest paragraph, considering:
    # byte-pair document encoding including special tokens ([CLS] and [SEP])

    # eval validation sets
    # (every epoch, unless final_eval is set, in which case only after the last epoch)
    if not args.final_eval or (epoch == args.epochs - 1):
        self._eval(model, validation_dataset, input_reader, epoch + 1, updates_epoch)

In [None]:
# Look up the vocabulary id of each token in a hard-coded WordPiece token list ('##' marks continuation pieces).
jtokens=['controle', 'de', 'constitucional', '##idade', 'de', 'normas', ':', 'reserva', 'de', 'plena', '##rio', 'constitui', '##ca', '##o', 'federal', 'art', '97', ':', 'repu', '##ta', '-', 'se', 'declara', '##tor', '##io', 'de', 'incons', '##titu', '##cional', '##idade', 'o', 'acorda', '##o', 'que', '-', 'embora', 'sem', 'o', 'explic', '##itar', '-', 'afas', '##ta', 'a', 'inc', '##iden', '##cia', 'da', 'norma', 'ord', '##inar', '##ia', 'per', '##tin', '##ente', 'a', 'li', '##de', 'para', 'decid', '##i', '-', 'la', 'sob', 'crit', '##eri', '##os', 'diversos', 'alega', '##damente', 'extra', '##idos', 'da', 'constitui', '##ca', '##o']

for i, token_phrase in enumerate(jtokens):
    # Given a single token string, convert_tokens_to_ids returns a single id (not a list).
    token_encoding = tokenizer.convert_tokens_to_ids(token_phrase)
    print(token_encoding, ' - ', token_phrase)
    # NOTE(review): the commented-out reader logic below treats token_encoding as a list (len(), +=).
    # That only holds if it comes from something like tokenizer.encode(..., add_special_tokens=False).
    #span_start, span_end = (len(doc_encoding), len(doc_encoding) + len(token_encoding))

    # Creates a token, adds it to the dataset and returns it here
    #token = dataset.create_token(i, span_start, span_end, token_phrase)

    #doc_tokens.append(token)
    #doc_encoding += token_encoding

In [None]:
# Vocabulary id of the single-character token 'q'.
tokenizer.convert_tokens_to_ids('q')

In [None]:
# Tokenize a sample legal paragraph and print the vocabulary id of each WordPiece token.
paragrafo ='Implicitamente, considera-se que o art. 60 da Constituição é inalterável, pois alterações neste artigo permitiriam uma revisão completa da Constituição. Nos casos não abordados pelo art. 60, é possível propor emendas. Os órgãos competentes para submeter emendas são: a Câmara dos Deputados, o Senado Federal, o Presidente da República e de mais da metade das Assembleias Legislativas das unidades da Federação, manifestando-se, cada uma delas, pela maioria relativa de seus membros. '

a=tokenizer.tokenize(paragrafo)
for b in a:
    c=tokenizer.convert_tokens_to_ids(b)
    print(c)
# NOTE(review): len(a) is the token count, but `b` is the last token left over from the loop, so
# len(b) is only that token's character length — confirm which length was intended.
print(len(a), ' - ', len(b))

In [None]:
# Display `b`: in top-to-bottom execution, the last token from the paragraph-tokenization loop above.
b

In [None]:
# Inspect the reader's metadata and one raw entity of the training split.
print(input_reader.context_size)
print(input_reader.vocabulary_size)
print(input_reader.entity_type_count)
print(input_reader.relation_types)
print(input_reader.datasets)
for rel_type in input_reader.relation_types:
    print(rel_type)
dataset = input_reader.datasets['train']
# NOTE(review): reads the private attributes _documents and _entities. The public .documents and
# .entities accessors may be preferable — check that they index the same way.
docs = dataset._documents
print(docs[0]._entities[1])



In [None]:
import torch
# Use a distinct name so this English tokenizer does not overwrite the global Portuguese
# `tokenizer` (neuralmind/bert-base-portuguese-cased) that every cell above relies on.
hub_tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', 'bert-base-cased')

text_1 = "Who was Jim Henson ?"
text_2 = "Jim Henson was a puppeteer"

# Tokenized input with special tokens around it (for BERT: [CLS] at the beginning and [SEP] at the end)
indexed_tokens = hub_tokenizer.encode(text_1, text_2, add_special_tokens=True)


In [None]:
class SpERTTrainer(BaseTrainer):
    """Joint entity and relation extraction training and evaluation.

    NOTE(review): ``BaseTrainer``, ``Sampler`` and ``Evaluator`` are used here, but their imports
    are commented out in the notebook's import cell; they must be in scope before this cell runs.
    ``self._logger``, ``self._device``, ``self._log_path``, ``self._save_path`` and the
    ``_log_tensorboard``/``_log_csv``/``_add_dataset_logging``/``_save_model``/``_get_lr``
    helpers are presumably provided by ``BaseTrainer`` — confirm.
    """

    def __init__(self, args: argparse.Namespace):
        """Create the tokenizers, the examples export path and the batch sampler.

        :param args: parsed arguments; this method reads ``tokenizer_path``, ``lowercase``,
            ``cache_path``, ``sampling_processes`` and ``sampling_limit``
        """
        super().__init__(args)

        # TensorFlow word-piece tokenizer built from a vocab file located relative to the
        # current working directory.
        # NOTE(review): not used by any method of this class — confirm it is still needed.
        model_dir = os.path.abspath(os.getcwd())+"/../../../../LIAA-3R/SINARA/bert-base-pt-cased-tensorflow"
        self._tf_bert_tokenizer = FullTokenizer(vocab_file=os.path.join(model_dir, "vocab.txt"))

        # Hugging Face BERT tokenizer used to encode the model inputs
        self._tokenizer = BertTokenizer.from_pretrained(args.tokenizer_path,
                                                        do_lower_case=args.lowercase,
                                                        cache_dir=args.cache_path)

        # path template to export relation extraction examples to
        self._examples_path = os.path.join(self._log_path, 'examples_%s_%s_epoch_%s.html')

        # sampler (creates and batches training/evaluation samples)
        self._sampler = Sampler(processes=args.sampling_processes, limit=args.sampling_limit)

    def _create_input_reader(self, types_path: str) -> JsonInputReader:
        """Build a JsonInputReader using the keyword-argument reader API.

        :param types_path: path of the entity/relation type definitions (JSON)
        :return: a reader with no datasets loaded yet; call ``read`` once per dataset label
        """
        return JsonInputReader(types_path=types_path,
                               tokenizer=self._tokenizer,
                               encoding='utf-8',
                               max_span_size=self.args.max_span_size,
                               neg_entity_count=self.args.neg_entity_count,
                               neg_rel_count=self.args.neg_relation_count)

    def _create_model(self, input_reader: JsonInputReader) -> torch.nn.Module:
        """Load the configured model class from ``args.model_path`` and move it to ``self._device``.

        :param input_reader: reader whose entity/relation type counts size the classifier heads
        :return: the loaded model, already on ``self._device``
        """
        model_class = spert.models.get_model(self.args.model_type)

        model = model_class.from_pretrained(self.args.model_path,
                                            cache_dir=self.args.cache_path,
                                            # SpERT model parameters
                                            cls_token=self._tokenizer.convert_tokens_to_ids('[CLS]'),
                                            # no node for 'none' class
                                            relation_types=input_reader.relation_type_count - 1,
                                            entity_types=input_reader.entity_type_count,
                                            max_pairs=self.args.max_pairs,
                                            prop_drop=self.args.prop_drop,
                                            size_embedding=self.args.size_embedding,
                                            freeze_transformer=self.args.freeze_transformer)

        # batches are moved to self._device in _train_epoch/_eval, so the model must live there too
        model.to(self._device)
        return model

    def train(self, train_path: str, valid_path: str, types_path: str):
        """Train on ``train_path``, evaluate on ``valid_path`` and save the final model.

        NOTE(review): reads ``args.lr``, ``args.lr_warmup`` and ``args.max_grad_norm`` in addition
        to the other training arguments — confirm these exist in the argument parser.

        :param train_path: path of the training dataset (JSON)
        :param valid_path: path of the validation dataset (JSON)
        :param types_path: path of the entity/relation type definitions (JSON)
        """
        args = self.args

        self._logger.info("Datasets: %s, %s" % (train_path, valid_path))
        self._logger.info("Model type: %s" % args.model_type)

        # create log csv files
        self._init_train_logging('train')
        self._init_eval_logging('validation')

        # read datasets (one read() call per dataset label)
        input_reader = self._create_input_reader(types_path)
        input_reader.read(dataset_path=train_path, dataset_label='train')
        input_reader.read(dataset_path=valid_path, dataset_label='validation')
        self._log_datasets(input_reader)

        train_dataset = input_reader.get_dataset('train')
        validation_dataset = input_reader.get_dataset('validation')

        train_sample_count = train_dataset.document_count
        updates_epoch = train_sample_count // args.train_batch_size
        updates_total = updates_epoch * args.epochs

        self._logger.info("Updates per epoch: %s" % updates_epoch)
        self._logger.info("Updates total: %s" % updates_total)

        # create and load model (already moved to self._device)
        model = self._create_model(input_reader)

        # create optimizer; biases and LayerNorm parameters get no weight decay (see _get_optimizer_params)
        optimizer_params = self._get_optimizer_params(model)
        # correct_bias=False follows the original BERT optimizer (no Adam bias correction)
        optimizer = AdamW(optimizer_params, lr=args.lr, weight_decay=args.weight_decay, correct_bias=False)
        # linear warm-up over the first `lr_warmup` fraction of all updates, then linear decay
        scheduler = transformers.get_linear_schedule_with_warmup(optimizer,
                                                                 num_warmup_steps=args.lr_warmup * updates_total,
                                                                 num_training_steps=updates_total)

        # create loss function; it receives the optimizer/scheduler and, per the training loop,
        # also optimizes the parameters when computing the loss
        rel_criterion = torch.nn.BCEWithLogitsLoss(reduction='none')
        entity_criterion = torch.nn.CrossEntropyLoss(reduction='none')
        compute_loss = SpERTLoss(rel_criterion, entity_criterion, model, optimizer, scheduler, args.max_grad_norm)

        # eval validation set before any training
        if args.init_eval:
            self._eval(model, validation_dataset, input_reader, 0, updates_epoch)

        # train
        for epoch in range(args.epochs):
            # train epoch; context_size is the size of the largest paragraph, measured as its
            # byte-pair document encoding including special tokens ([CLS] and [SEP])
            self._train_epoch(model, compute_loss, optimizer, train_dataset, updates_epoch, epoch,
                              input_reader.context_size, input_reader.relation_type_count)

            # eval validation set (every epoch, or only after the last one when final_eval is set)
            if not args.final_eval or (epoch == args.epochs - 1):
                self._eval(model, validation_dataset, input_reader, epoch + 1, updates_epoch)

        # save final model
        extra = dict(epoch=args.epochs, updates_epoch=updates_epoch, epoch_iteration=0)
        global_iteration = args.epochs * updates_epoch
        self._save_model(self._save_path, model, global_iteration,
                         optimizer=optimizer if self.args.save_optimizer else None, extra=extra,
                         include_iteration=False, name='final_model')

        self._logger.info("Logged in: %s" % self._log_path)
        self._logger.info("Saved in: %s" % self._save_path)

        self._sampler.join()

    def eval(self, dataset_path: str, types_path: str):
        """Evaluate the model stored at ``args.model_path`` on a single dataset.

        :param dataset_path: path of the dataset to evaluate (JSON)
        :param types_path: path of the entity/relation type definitions (JSON)
        """
        # read dataset
        input_reader = self._create_input_reader(types_path)
        input_reader.read(dataset_path=dataset_path, dataset_label='evaluate')

        # create and load model (already moved to self._device)
        model = self._create_model(input_reader)

        # evaluate
        self._eval(model, input_reader.get_dataset('evaluate'), input_reader)
        self._logger.info("Logged in: %s" % self._log_path)

        self._sampler.join()

    def _train_epoch(self, model: torch.nn.Module, compute_loss: Loss, optimizer: Optimizer, dataset: Dataset,
                     updates_epoch: int, epoch: int, context_size: int, rel_type_count: int):
        """Run one training epoch over ``dataset`` in a random document order.

        :param model: model to train (on ``self._device``)
        :param compute_loss: loss object; its ``compute`` call returns the batch loss that is logged
        :param optimizer: optimizer, used here only to log the current learning rate
        :param dataset: training dataset
        :param updates_epoch: number of updates per epoch (used for the global iteration counter)
        :param epoch: zero-based epoch index
        :param context_size: maximum document encoding length, passed to the sampler
        :param rel_type_count: number of relation types including 'none'
        :return: number of batches processed in this epoch
        """
        self._logger.info("Train epoch: %s" % epoch)

        # randomly shuffle data
        order = torch.randperm(dataset.document_count)
        sampler = self._sampler.create_train_sampler(dataset, self.args.train_batch_size, self.args.max_span_size,
                                                     context_size, self.args.neg_entity_count,
                                                     self.args.neg_relation_count, order=order, truncate=True)

        model.zero_grad()

        iteration = 0
        total = dataset.document_count // self.args.train_batch_size
        for batch in tqdm(sampler, total=total, desc='Train epoch %s' % epoch):
            model.train()
            # move batch to the model's device (as _eval does); otherwise its tensors stay on the
            # CPU while the model and rel_types_onehot live on self._device
            batch = batch.to(self._device)

            # relation types to one-hot encoding
            rel_types_onehot = torch.zeros([batch.rel_types.shape[0], batch.rel_types.shape[1],
                                            rel_type_count], dtype=torch.float32).to(self._device)
            rel_types_onehot.scatter_(2, batch.rel_types.unsqueeze(2), 1)
            rel_types_onehot = rel_types_onehot[:, :, 1:]  # all zeros for 'none' relation

            # forward step
            entity_logits, rel_logits = model(batch.encodings, batch.ctx_masks, batch.entity_masks,
                                              batch.entity_sizes, batch.rels, batch.rel_masks)

            # compute loss and optimize parameters
            batch_loss = compute_loss.compute(rel_logits, rel_types_onehot, entity_logits,
                                              batch.entity_types, batch.rel_sample_masks, batch.entity_sample_masks)

            # logging
            iteration += 1
            global_iteration = epoch * updates_epoch + iteration

            if global_iteration % self.args.train_log_iter == 0:
                self._log_train(optimizer, batch_loss, epoch, iteration, global_iteration, dataset.label)

        return iteration

    def _eval(self, model: torch.nn.Module, dataset: Dataset, input_reader: JsonInputReader,
              epoch: int = 0, updates_epoch: int = 0, iteration: int = 0):
        """Evaluate ``model`` on ``dataset`` without gradients and log NER/relation scores.

        :param model: model to evaluate; a DataParallel wrapper is unwrapped first
        :param dataset: dataset to evaluate on
        :param input_reader: reader that produced ``dataset`` (provides the context size)
        :param epoch: epoch index used for logging and example export
        :param updates_epoch: updates per epoch (used for the global iteration counter)
        :param iteration: iteration within the epoch (used for logging)
        """
        self._logger.info("Evaluate: %s" % dataset.label)

        if isinstance(model, DataParallel):
            # currently no multi GPU support during evaluation
            model = model.module

        # create evaluator
        evaluator = Evaluator(dataset, input_reader, self._tokenizer,
                              self.args.rel_filter_threshold, self.args.example_count,
                              self._examples_path, epoch, dataset.label)

        # create batch sampler
        sampler = self._sampler.create_eval_sampler(dataset, self.args.eval_batch_size, self.args.max_span_size,
                                                    input_reader.context_size, truncate=False)

        with torch.no_grad():
            model.eval()

            # iterate batches
            total = math.ceil(dataset.document_count / self.args.eval_batch_size)
            for batch in tqdm(sampler, total=total, desc='Evaluate epoch %s' % epoch):
                # move batch to selected device
                batch = batch.to(self._device)

                # run model (forward pass)
                entity_clf, rel_clf, rels = model(batch.encodings, batch.ctx_masks, batch.entity_masks,
                                                  batch.entity_sizes, batch.entity_spans, batch.entity_sample_masks,
                                                  evaluate=True)

                # evaluate batch
                evaluator.eval_batch(entity_clf, rel_clf, rels, batch)

        global_iteration = epoch * updates_epoch + iteration
        ner_eval, rel_eval, rel_nec_eval = evaluator.compute_scores()
        self._log_eval(*ner_eval, *rel_eval, *rel_nec_eval,
                       epoch, iteration, global_iteration, dataset.label)

        if self.args.store_examples:
            evaluator.store_examples()

    def _get_optimizer_params(self, model):
        """Split the model parameters into two optimizer groups.

        Parameters whose name contains 'bias', 'LayerNorm.bias' or 'LayerNorm.weight' get a
        weight decay of 0.0; all others use ``args.weight_decay``.

        :param model: model whose ``named_parameters()`` are grouped
        :return: list of two parameter-group dicts for a torch optimizer
        """
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_params = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
             'weight_decay': self.args.weight_decay},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]

        return optimizer_params

    def _log_train(self, optimizer: Optimizer, loss: float, epoch: int,
                   iteration: int, global_iteration: int, label: str):
        """Log the batch loss, the per-sample average loss and the current learning rate."""
        # average loss
        avg_loss = loss / self.args.train_batch_size
        # get current learning rate
        lr = self._get_lr(optimizer)[0]

        # log to tensorboard
        self._log_tensorboard(label, 'loss', loss, global_iteration)
        self._log_tensorboard(label, 'loss_avg', avg_loss, global_iteration)
        self._log_tensorboard(label, 'lr', lr, global_iteration)

        # log to csv
        self._log_csv(label, 'loss', loss, epoch, iteration, global_iteration)
        self._log_csv(label, 'loss_avg', avg_loss, epoch, iteration, global_iteration)
        self._log_csv(label, 'lr', lr, epoch, iteration, global_iteration)

    def _log_eval(self, ner_prec_micro: float, ner_rec_micro: float, ner_f1_micro: float,
                  ner_prec_macro: float, ner_rec_macro: float, ner_f1_macro: float,

                  rel_prec_micro: float, rel_rec_micro: float, rel_f1_micro: float,
                  rel_prec_macro: float, rel_rec_macro: float, rel_f1_macro: float,

                  rel_nec_prec_micro: float, rel_nec_rec_micro: float, rel_nec_f1_micro: float,
                  rel_nec_prec_macro: float, rel_nec_rec_macro: float, rel_nec_f1_macro: float,
                  epoch: int, iteration: int, global_iteration: int, label: str):
        """Log micro/macro precision, recall and F1 for NER, relations and relations-without-NEC."""

        # log to tensorboard
        self._log_tensorboard(label, 'eval/ner_prec_micro', ner_prec_micro, global_iteration)
        self._log_tensorboard(label, 'eval/ner_recall_micro', ner_rec_micro, global_iteration)
        self._log_tensorboard(label, 'eval/ner_f1_micro', ner_f1_micro, global_iteration)
        self._log_tensorboard(label, 'eval/ner_prec_macro', ner_prec_macro, global_iteration)
        self._log_tensorboard(label, 'eval/ner_recall_macro', ner_rec_macro, global_iteration)
        self._log_tensorboard(label, 'eval/ner_f1_macro', ner_f1_macro, global_iteration)

        self._log_tensorboard(label, 'eval/rel_prec_micro', rel_prec_micro, global_iteration)
        self._log_tensorboard(label, 'eval/rel_recall_micro', rel_rec_micro, global_iteration)
        self._log_tensorboard(label, 'eval/rel_f1_micro', rel_f1_micro, global_iteration)
        self._log_tensorboard(label, 'eval/rel_prec_macro', rel_prec_macro, global_iteration)
        self._log_tensorboard(label, 'eval/rel_recall_macro', rel_rec_macro, global_iteration)
        self._log_tensorboard(label, 'eval/rel_f1_macro', rel_f1_macro, global_iteration)

        self._log_tensorboard(label, 'eval/rel_nec_prec_micro', rel_nec_prec_micro, global_iteration)
        self._log_tensorboard(label, 'eval/rel_nec_recall_micro', rel_nec_rec_micro, global_iteration)
        self._log_tensorboard(label, 'eval/rel_nec_f1_micro', rel_nec_f1_micro, global_iteration)
        self._log_tensorboard(label, 'eval/rel_nec_prec_macro', rel_nec_prec_macro, global_iteration)
        self._log_tensorboard(label, 'eval/rel_nec_recall_macro', rel_nec_rec_macro, global_iteration)
        self._log_tensorboard(label, 'eval/rel_nec_f1_macro', rel_nec_f1_macro, global_iteration)

        # log to csv
        self._log_csv(label, 'eval', ner_prec_micro, ner_rec_micro, ner_f1_micro,
                      ner_prec_macro, ner_rec_macro, ner_f1_macro,

                      rel_prec_micro, rel_rec_micro, rel_f1_micro,
                      rel_prec_macro, rel_rec_macro, rel_f1_macro,

                      rel_nec_prec_micro, rel_nec_rec_micro, rel_nec_f1_micro,
                      rel_nec_prec_macro, rel_nec_rec_macro, rel_nec_f1_macro,
                      epoch, iteration, global_iteration)

    def _log_datasets(self, input_reader):
        """Log type counts, type names with their indices, per-dataset counts and the context size."""
        self._logger.info("Relation type count: %s" % input_reader.relation_type_count)
        self._logger.info("Entity type count: %s" % input_reader.entity_type_count)

        self._logger.info("Entities:")
        for e in input_reader.entity_types.values():
            self._logger.info(e.verbose_name + '=' + str(e.index))

        self._logger.info("Relations:")
        for r in input_reader.relation_types.values():
            self._logger.info(r.verbose_name + '=' + str(r.index))

        for k, d in input_reader.datasets.items():
            self._logger.info('Dataset: %s' % k)
            self._logger.info("Document count: %s" % d.document_count)
            self._logger.info("Relation count: %s" % d.relation_count)
            self._logger.info("Entity count: %s" % d.entity_count)

        self._logger.info("Context size: %s" % input_reader.context_size)

    def _init_train_logging(self, label):
        """Register the CSV columns for training metrics (lr, loss, loss_avg) under ``label``."""
        self._add_dataset_logging(label,
                                  data={'lr': ['lr', 'epoch', 'iteration', 'global_iteration'],
                                        'loss': ['loss', 'epoch', 'iteration', 'global_iteration'],
                                        'loss_avg': ['loss_avg', 'epoch', 'iteration', 'global_iteration']})

    def _init_eval_logging(self, label):
        """Register the CSV columns for evaluation metrics under ``label`` (same order as _log_eval)."""
        self._add_dataset_logging(label,
                                  data={'eval': ['ner_prec_micro', 'ner_rec_micro', 'ner_f1_micro',
                                                 'ner_prec_macro', 'ner_rec_macro', 'ner_f1_macro',
                                                 'rel_prec_micro', 'rel_rec_micro', 'rel_f1_micro',
                                                 'rel_prec_macro', 'rel_rec_macro', 'rel_f1_macro',
                                                 'rel_nec_prec_micro', 'rel_nec_rec_micro', 'rel_nec_f1_micro',
                                                 'rel_nec_prec_macro', 'rel_nec_rec_macro', 'rel_nec_f1_macro',
                                                 'epoch', 'iteration', 'global_iteration']})
