# Running the relation model
In this notebook we will train and evaluate, from scratch, the relation model proposed in the research paper [A Frustratingly Easy Approach for Entity and Relation Extraction](https://arxiv.org/pdf/2010.12812.pdf).

This is a reproduction based on the instructions left by the authors in their [GitHub repo](https://github.com/princeton-nlp/PURE).

**Environment information**

- Windows 11
- Python 3.6.13
- pip 21.2.2

## Basic setup
First, run the notebook `relation_setup.ipynb` to load the required classes and functions.

In [1]:
%run relation_setup.ipynb



## Training and evaluating the relation model from scratch

Now we train our own relation model from scratch on the same dataset, then evaluate it on the test data and compare its performance to the pre-trained model.

In [28]:
model_name = 'allenai/scibert_scivocab_uncased'
add_new_tokens = False
no_cuda = False
do_train = True
do_eval = True
eval_test = True
do_lower_case = True
entity_output_dir = os.getcwd() + '/scierc_models/from-scratch/ent-scib-ctx0/'
entity_predictions_dev = 'ent_pred_dev.json'
eval_with_gold = True
context_window = 0
max_seq_length = 128
entity_predictions_test = 'ent_pred_test.json'
seed = 0
output_dir = os.getcwd() + '/scierc_models/from-scratch/rel-scib-ctx0/'
negative_label = 'no_relation'
task = 'scierc'
train_mode = 'random_sorted'
train_batch_size = 8
eval_batch_size = 8
num_train_epochs = 10
train_file = "/scierc_data/processed_data/json/train.json"
eval_per_epoch = 10
learning_rate = 2e-5
prediction_file = 'predictions.json'
BertLayerNorm = torch.nn.LayerNorm
bertadam = True
warmup_proportion = 0.1
eval_metric = 'f1'
task_rel_labels = {
    'ace04': ['PER-SOC', 'OTHER-AFF', 'ART', 'GPE-AFF', 'EMP-ORG', 'PHYS'],
    'ace05': ['ART', 'ORG-AFF', 'GEN-AFF', 'PHYS', 'PER-SOC', 'PART-WHOLE'],
    'scierc': ['PART-OF', 'USED-FOR', 'FEATURE-OF', 'CONJUNCTION', 'EVALUATE-FOR', 'HYPONYM-OF', 'COMPARE'],
}
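
The `train_mode = 'random_sorted'` setting is handled in the training cell that follows: features are sorted by the number of non-padding tokens, split into consecutive batches, and the order of those batches is reshuffled at the start of every epoch. The sketch below imitates that behaviour with plain Python lists. The `make_batches` helper and the token counts are hypothetical and are not part of the PURE code.

```python
import random


def make_batches(lengths, batch_size):
    """Sort example indices by length, then cut them into consecutive batches."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    return [order[i:i + batch_size] for i in range(0, len(order), batch_size)]


# Hypothetical token counts for ten examples (illustrative values only).
lengths = [12, 45, 7, 30, 8, 44, 13, 29, 9, 46]
batches = make_batches(lengths, batch_size=4)

rng = random.Random(0)  # local RNG so the sketch does not touch global state
for epoch in range(2):
    rng.shuffle(batches)  # only the order of the batches changes between epochs
    print(epoch, [[lengths[i] for i in b] for b in batches])
```

Each batch keeps examples of similar length together, while the shuffled batch order prevents the model from always seeing short examples first.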

In [29]:
CLS = "[CLS]"
SEP = "[SEP]"

RelationModel = BertForRelation

device = torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
n_gpu = torch.cuda.device_count()

# train set
if do_train:
    train_dataset, train_examples, train_nrel = generate_relation_data(train_file, use_gold=True, context_window=context_window)
# dev set
if (do_eval and do_train) or (do_eval and not(eval_test)):
    eval_dataset, eval_examples, eval_nrel = generate_relation_data(os.path.join(entity_output_dir, entity_predictions_dev), use_gold=eval_with_gold, context_window=context_window)
# test set
if eval_test:
    test_dataset, test_examples, test_nrel = generate_relation_data(os.path.join(entity_output_dir, entity_predictions_test), use_gold=eval_with_gold, context_window=context_window)
    
setseed(seed)

if not do_train and not do_eval:
    raise ValueError("At least one of `do_train` or `do_eval` must be True.")

if not os.path.exists(output_dir):
    os.makedirs(output_dir)
if do_train:
    logger.addHandler(logging.FileHandler(os.path.join(output_dir, "train.log"), 'w'))
else:
    logger.addHandler(logging.FileHandler(os.path.join(output_dir, "eval.log"), 'w'))
    
# get label_list
if os.path.exists(os.path.join(output_dir, 'label_list.json')):
    with open(os.path.join(output_dir, 'label_list.json'), 'r') as f:
        label_list = json.load(f)
else:
    label_list = [negative_label] + task_rel_labels[task]
    with open(os.path.join(output_dir, 'label_list.json'), 'w') as f:
        json.dump(label_list, f)
label2id = {label: i for i, label in enumerate(label_list)}
id2label = {i: label for i, label in enumerate(label_list)}
num_labels = len(label_list)

tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=do_lower_case)
if add_new_tokens:
    add_marker_tokens(tokenizer, task_ner_labels[task])

if os.path.exists(os.path.join(output_dir, 'special_tokens.json')):
    with open(os.path.join(output_dir, 'special_tokens.json'), 'r') as f:
        special_tokens = json.load(f)
else:
    special_tokens = {}
    
if do_eval and (do_train or not(eval_test)):
    eval_features = convert_examples_to_features(
        eval_examples, label2id, max_seq_length, tokenizer, special_tokens, unused_tokens=not(add_new_tokens))
    logger.info("***** Dev *****")
    logger.info("  Num examples = %d", len(eval_examples))
    logger.info("  Batch size = %d", eval_batch_size)
    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
    all_sub_idx = torch.tensor([f.sub_idx for f in eval_features], dtype=torch.long)
    all_obj_idx = torch.tensor([f.obj_idx for f in eval_features], dtype=torch.long)
    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_sub_idx, all_obj_idx)
    eval_dataloader = DataLoader(eval_data, batch_size=eval_batch_size)
    eval_label_ids = all_label_ids

    
if do_train:
    train_features = convert_examples_to_features(
        train_examples, label2id, max_seq_length, tokenizer, special_tokens, unused_tokens=not(add_new_tokens))
    if train_mode == 'sorted' or train_mode == 'random_sorted':
        train_features = sorted(train_features, key=lambda f: np.sum(f.input_mask))
    else:
        random.shuffle(train_features)
    all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
    all_sub_idx = torch.tensor([f.sub_idx for f in train_features], dtype=torch.long)
    all_obj_idx = torch.tensor([f.obj_idx for f in train_features], dtype=torch.long)
    train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_sub_idx, all_obj_idx)
    train_dataloader = DataLoader(train_data, batch_size=train_batch_size)
    train_batches = [batch for batch in train_dataloader]

    num_train_optimization_steps = len(train_dataloader) * num_train_epochs

    logger.info("***** Training *****")
    logger.info("  Num examples = %d", len(train_examples))
    logger.info("  Batch size = %d", train_batch_size)
    logger.info("  Num steps = %d", num_train_optimization_steps)

    best_result = None
    eval_step = max(1, len(train_batches) // eval_per_epoch)

    lr = learning_rate
    # Use the configured checkpoint name instead of repeating the literal string
    model = RelationModel.from_pretrained(
        model_name, cache_dir=str(PYTORCH_PRETRAINED_BERT_CACHE), num_rel_labels=num_labels)
    if hasattr(model, 'bert'):
        model.bert.resize_token_embeddings(len(tokenizer))
    elif hasattr(model, 'albert'):
        model.albert.resize_token_embeddings(len(tokenizer))
    else:
        raise TypeError("Unknown model class")

    model.to(device)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=lr, correct_bias=not(bertadam))
    scheduler = get_linear_schedule_with_warmup(optimizer, int(num_train_optimization_steps * warmup_proportion), num_train_optimization_steps)

    start_time = time.time()
    global_step = 0
    tr_loss = 0
    nb_tr_examples = 0
    nb_tr_steps = 0
    for epoch in range(int(num_train_epochs)):
        model.train()
        logger.info("Start epoch #{} (lr = {})...".format(epoch, lr))
        if train_mode == 'random' or train_mode == 'random_sorted':
            random.shuffle(train_batches)
        for step, batch in enumerate(train_batches):
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids, sub_idx, obj_idx = batch
            loss = model(input_ids, segment_ids, input_mask, label_ids, sub_idx, obj_idx)
            if n_gpu > 1:
                loss = loss.mean()

            loss.backward()

            tr_loss += loss.item()
            nb_tr_examples += input_ids.size(0)
            nb_tr_steps += 1

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            global_step += 1

            if (step + 1) % eval_step == 0:
                logger.info('Epoch: {}, Step: {} / {}, used_time = {:.2f}s, loss = {:.6f}'.format(
                            epoch, step + 1, len(train_batches),
                            time.time() - start_time, tr_loss / nb_tr_steps))
                save_model = False
                if do_eval:
                    preds, result, logits = evaluate(model, device, eval_dataloader, eval_label_ids, num_labels, e2e_ngold=eval_nrel)
                    model.train()
                    result['global_step'] = global_step
                    result['epoch'] = epoch
                    result['learning_rate'] = lr
                    result['batch_size'] = train_batch_size

                    if (best_result is None) or (result[eval_metric] > best_result[eval_metric]):
                        best_result = result
                        logger.info("!!! Best dev %s (lr=%s, epoch=%d): %.2f" %
                                    (eval_metric, str(lr), epoch, result[eval_metric] * 100.0))
                        save_trained_model(output_dir, model, tokenizer)

if do_eval:
    logger.info(special_tokens)
    if eval_test:
        eval_dataset = test_dataset
        eval_examples = test_examples
        eval_features = convert_examples_to_features(
            test_examples, label2id, max_seq_length, tokenizer, special_tokens, unused_tokens=not(add_new_tokens))
        eval_nrel = test_nrel
        logger.info(special_tokens)
        logger.info("***** Test *****")
        logger.info("  Num examples = %d", len(test_examples))
        logger.info("  Batch size = %d", eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
        all_sub_idx = torch.tensor([f.sub_idx for f in eval_features], dtype=torch.long)
        all_obj_idx = torch.tensor([f.obj_idx for f in eval_features], dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_sub_idx, all_obj_idx)
        eval_dataloader = DataLoader(eval_data, batch_size=eval_batch_size)
        eval_label_ids = all_label_ids
    model = RelationModel.from_pretrained(output_dir, num_rel_labels=num_labels)
    model.to(device)
    preds, result, logits = evaluate(model, device, eval_dataloader, eval_label_ids, num_labels, e2e_ngold=eval_nrel)

    logger.info('*** Evaluation Results ***')
    for key in sorted(result.keys()):
        logger.info("  %s = %s", key, str(result[key]))

    print_pred_json(eval_dataset, eval_examples, preds, id2label, os.path.join(output_dir, prediction_file))

11/25/2023 19:32:21 - INFO - run_relation - Generate relation data from /scierc_data/processed_data/json/train.json
11/25/2023 19:32:22 - INFO - run_relation - #samples: 16872, max #sent.samples: 156
11/25/2023 19:32:22 - INFO - run_relation - Generate relation data from C:\Users\odaim\Documents\PURE reproduction/scierc_models/from-scratch/ent-scib-ctx0/ent_pred_dev.json
11/25/2023 19:32:22 - INFO - run_relation - #samples: 2488, max #sent.samples: 110
11/25/2023 19:32:22 - INFO - run_relation - Generate relation data from C:\Users\odaim\Documents\PURE reproduction/scierc_models/from-scratch/ent-scib-ctx0/ent_pred_test.json
11/25/2023 19:32:22 - INFO - run_relation - #samples: 5062, max #sent.samples: 156
11/25/2023 19:32:24 - INFO - transformers.configuration_utils - loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/allenai/scibert_scivocab_uncased/config.json from cache at C:\Users\odaim/.cache\torch\transformers\199e28e62d2210c23d63625bd9eecc20cf72a156b2
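
The log reports 16872 training samples, which makes it possible to work out by hand the schedule that the training cell sets up. The arithmetic below is a rough check and not output from the run. It assumes that `convert_examples_to_features` yields exactly one feature per sample, which the notebook does not show.

```python
import math

num_samples = 16872        # "#samples" reported for train.json in the log above
train_batch_size = 8
num_train_epochs = 10
eval_per_epoch = 10
warmup_proportion = 0.1

# Assumption: one feature per sample, so the DataLoader yields ceil(n / batch) batches.
batches_per_epoch = math.ceil(num_samples / train_batch_size)
total_steps = batches_per_epoch * num_train_epochs
warmup_steps = int(total_steps * warmup_proportion)
eval_step = max(1, batches_per_epoch // eval_per_epoch)
evals_per_epoch = batches_per_epoch // eval_step

print(batches_per_epoch, total_steps, warmup_steps, eval_step, evals_per_epoch)
```

Under that assumption the run has 2109 batches per epoch, 21090 optimizer steps in total, 2109 warmup steps, and a dev-set evaluation every 210 steps (10 per epoch).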

## Results

**Accuracy**: 0.8954

**Evaluation Loss**: 0.5672

**Precision**: 0.7106\
**Recall**: 0.6581\
**F1 Score**: 0.6833
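
As a quick consistency check, the reported F1 score can be recomputed from the precision and recall above, since F1 is their harmonic mean:

```python
precision = 0.7106
recall = 0.6581

# F1 is the harmonic mean of precision and recall.
f1 = 2 * precision * recall / (precision + recall)
print(round(f1, 4))  # 0.6833
```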

Implications:

- Our model has a higher accuracy, indicating that it makes correct predictions for a larger proportion of the test set.
- The pre-trained model has a lower evaluation loss, suggesting better performance in terms of model fit.
- Our model has a higher F1 score, indicating a better balance between precision and recall.
- Our model has higher precision, suggesting that a larger proportion of its predicted positive instances are correct.
- The pre-trained model has slightly higher recall, indicating that it identifies a larger proportion of actual positive instances.
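
Accuracy (0.8954) is much higher than F1 (0.6833) here, and one plausible reason is that accuracy also counts correctly predicted `no_relation` pairs. The sketch below illustrates that effect on made-up data. It assumes that `no_relation` has label id 0, as in `label_list`, and that precision and recall are computed over non-negative predictions only, a common convention for relation extraction. The `evaluate` function loaded by `relation_setup` is not shown in this notebook, so `micro_prf` is a hypothetical stand-in and not the authors' implementation.

```python
def micro_prf(preds, labels, negative_id=0):
    """Accuracy over all pairs; precision/recall/F1 over non-negative labels only."""
    n_pred = sum(p != negative_id for p in preds)
    n_gold = sum(g != negative_id for g in labels)
    n_correct = sum(p == g and g != negative_id for p, g in zip(preds, labels))
    accuracy = sum(p == g for p, g in zip(preds, labels)) / len(labels)
    precision = n_correct / n_pred if n_pred else 0.0
    recall = n_correct / n_gold if n_gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return accuracy, precision, recall, f1


# Toy data (made up): 10 candidate pairs, 7 of which are negatives in the gold labels.
labels = [0, 0, 0, 0, 0, 0, 0, 1, 2, 3]
preds = [0, 0, 0, 0, 0, 0, 1, 1, 0, 3]
print(micro_prf(preds, labels))
```

On this toy data the accuracy is 0.8 while precision, recall, and F1 are all 2/3, because six of the eight correct predictions are negatives.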

Overall, our model outperforms the pre-trained model on most metrics: it has higher accuracy, F1 score, and precision, at the cost of slightly lower recall. Which model is preferable depends on the requirements of the application:

- If precision is crucial (minimizing false positives), our model is more favorable.
- If recall is crucial (minimizing false negatives), the pre-trained model has a slight edge.
- The difference in evaluation loss suggests that the pre-trained model may have better overall model fit.

In summary, our model performs well at identifying relations between entities, but there is still room for improvement, particularly in recall (0.6581), which is lower than precision (0.7106). The results indicate how well the model generalizes to unseen data. Further fine-tuning or changes to the model architecture could be considered to improve performance.