## Description

In this notebook, we demonstrate how to train and evaluate ASM-related models on a specific fold, as detailed in Section 4 of the paper.

In [2]:
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

import sys
sys.path.insert(0, sys.path[0] + '/..')

import os
import pandas as pd

In [3]:
import torch

torch.cuda.is_available()  # make sure torch is installed correctly and can see a CUDA GPU

True

In [9]:
val_fold = 'img_class'  # this fold of papers is used as the validation set
test_fold = 'misc'  # this fold of papers is used as the test set
seed = 42  # random seed
data_dir = '../data/intermediate_data'  # the directory holding all downloaded training-related files
output_dir = '../_tmp'  # the directory to store model outputs and predictions
os.makedirs(output_dir, exist_ok=True)
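
The two fold names above decide which papers are held out. A minimal sketch of such a split, assuming each record carries a `fold` field (the prediction tables later in this notebook are filtered with `fold == test_fold`). The `split_by_fold` helper and the toy records are hypothetical and not part of the repository:

```python
def split_by_fold(records, val_fold, test_fold):
    """Partition records into train/val/test lists by their 'fold' field."""
    train, val, test = [], [], []
    for rec in records:
        if rec["fold"] == test_fold:
            test.append(rec)
        elif rec["fold"] == val_fold:
            val.append(rec)
        else:
            train.append(rec)  # every remaining fold is used for training
    return train, val, test

# Toy records; 'other_fold' is a made-up fold name for illustration.
toy_records = [
    {"cell": "a", "fold": "img_class"},
    {"cell": "b", "fold": "misc"},
    {"cell": "c", "fold": "other_fold"},
    {"cell": "d", "fold": "misc"},
]
train_split, val_split, test_split = split_by_fold(toy_records, "img_class", "misc")
print(len(train_split), len(val_split), len(test_split))  # 1 1 2
```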

### Train an ASM model

The Config sets up the training and test configurations.

In [5]:
from common_utils.common_ML_utils import Config

config = Config(
        test_fold=test_fold,
        valid_fold=val_fold,
        input_file=f'{data_dir}/ASM.pkl',
        seed=seed,
        lr=2e-5,
        epoch=2,
        BS=32,  # training batch size
        grad_accum_step=1,  # gradient accumulation steps, in case the GPU can't fit a whole batch
        pool_mtd='avg',  # the pooling method used at the last layer of the transformer
        loss='BCE',  # binary-cross-entropy loss
        pretrained = "allenai/scibert_scivocab_uncased",  # the underlying pretrained model
        input_cols = ['region_type', 'row_id', 'reverse_row_id', 
                    'col_id', 'reverse_col_id', 'has_reference', 
                    'cell_content', 'row_context', 'col_context', 
                    'text_sentence_no_mask'],  
                    # features used in the cell representation in paper Sec 4
        use_labels=True,  # supervised training
        drop_duplicates=True,  # whether or not to drop duplicates in the training set
        eval_steps=300,
        eval_BS=256,  # evaluation batch size
        drop_ones_with_cell_ref=True,  # drop cells that have in-cell references, because the
                                        # attributed source for those cells is best found via
                                        # the in-cell reference rather than an ML model
        name=f'ASM_{val_fold}_{test_fold}_{seed}',
        save_dir=output_dir
    )

In [None]:
from ED.trainers import train_cross_encoder_notebook

# trains a model and saves it to save_dir/name
train_cross_encoder_notebook(config)

Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


dropping ones with cell_reference
345373
268617
1    238172
0      5141
Name: labels, dtype: int64
Dropping duplicates!
243313
243313


Map:   0%|          | 0/243313 [00:00<?, ? examples/s]

Map:   0%|          | 0/15967 [00:00<?, ? examples/s]

Map:   0%|          | 0/9337 [00:00<?, ? examples/s]

 64%|████████████████████████████████████████████████████████████████████████▏                                        | 9708/15208 [2:00:38<43:45,  2.09it/s]
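
The total of 15208 in the progress bar above is consistent with counting mini-batches: 243313 training examples remain after dropping duplicates, and the config uses `BS=32` and `epoch=2`. This is an inference from the logged numbers, not a statement about the trainer's internals. The helper below is hypothetical:

```python
import math

def total_mini_batches(n_examples, batch_size, epochs):
    """Mini-batches processed over all epochs (the last batch of an epoch may be partial)."""
    return math.ceil(n_examples / batch_size) * epochs

asm_batches = total_mini_batches(n_examples=243313, batch_size=32, epochs=2)
print(asm_batches)  # 15208

# With gradient accumulation, the optimizer sees BS * grad_accum_step examples per update.
effective_batch_size = 32 * 1
print(effective_batch_size)  # 32
```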

### Generate ASM predictions

In [10]:
from ASM.experiments import ASMEvalExpNB

exp = ASMEvalExpNB(f'{output_dir}/ASM_{val_fold}_{test_fold}_{seed}')

# generate predictions for the ASM task for each cell
# i.e., for each cell, the papers in the reference section and the current paper are ranked by
# their likelihood of being the attributed source for the cell
ASM_preds = exp.make_predictions(save_path=None, enhance=True)

Map:   0%|          | 0/38823 [00:00<?, ? examples/s]

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 152/152 [03:11<00:00,  1.26s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 904/904 [00:03<00:00, 252.74it/s]
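
The ranking step described in the cell above can be illustrated as follows. This is only a sketch, assuming each candidate paper for a cell receives a single score from the model. The actual structure returned by `make_predictions` is not shown in this notebook, and `rank_candidates` and the toy scores are hypothetical:

```python
def rank_candidates(scores):
    """Return candidate ids sorted by descending score (ties broken by id for determinism)."""
    return [cid for cid, _ in sorted(scores.items(), key=lambda kv: (-kv[1], kv[0]))]

# Toy scores for one cell: two papers from the reference section plus the current paper.
toy_scores = {"ref_3": 0.12, "current_paper": 0.81, "ref_1": 0.47}
ranking = rank_candidates(toy_scores)
print(ranking)  # ['current_paper', 'ref_1', 'ref_3']
```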


### Generate ASR predictions

In [13]:
from ED.candidate_gen import get_ASR_candidates_encoders


# Use the ASM predictions to fetch candidates from the PwC KB and generate ASR candidates, as detailed in
# Section 4 of the paper
ASR_preds = get_ASR_candidates_encoders(ASM_preds, 
                                        pd.read_pickle(f'{data_dir}/ref_related_ents.pkl'), 
                                        ASM_pred_cols=['ASM_preds'], 
                                        output_cols=['ASR_candidates'],
                                        top_ns=[100])

ASR_preds.to_pickle(f'{output_dir}/{val_fold}_{test_fold}_{seed}_ASR_preds')

Evaluate ASR performance on one fold 

In [12]:
from ED.candidate_gen import test_candidate_effectiveness

recall_at_K = test_candidate_effectiveness(ASR_preds[ASR_preds.fold==test_fold], 'ASR_candidates', 'ASR_candidates', top_n=50)[1]
print(recall_at_K)

85.36585365853658
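
For reference, a common definition of recall@K is the percentage of cells whose gold entity appears among the top K candidates. Whether `test_candidate_effectiveness` computes exactly this is an assumption; the function and toy data below are hypothetical and only illustrate the metric:

```python
def recall_at_k(candidates_per_cell, gold_per_cell, k):
    """Percentage of cells whose gold entity is among the first k candidates."""
    if not gold_per_cell:
        return 0.0
    hits = sum(
        1
        for cands, gold in zip(candidates_per_cell, gold_per_cell)
        if gold in cands[:k]
    )
    return 100.0 * hits / len(gold_per_cell)

# Three toy cells, each with a ranked candidate list and one gold entity.
toy_candidates = [["e1", "e2", "e3"], ["e4", "e5", "e6"], ["e7", "e8", "e9"]]
toy_gold = ["e2", "e6", "e0"]
recall_top2 = recall_at_k(toy_candidates, toy_gold, k=2)  # only the first cell is a hit
recall_top3 = recall_at_k(toy_candidates, toy_gold, k=3)  # first and second cells are hits
print(recall_top2, recall_top3)
```

A larger K can only keep or increase the value, which is why the candidate lists are generated with 100 entries even though they are evaluated at 50.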


### Train a DR model

In [14]:
from ASM.experiments import DirectRetrievalTripletTrain

# Train a Direct Retrieval model with a Bi-encoder arch to rank all candidates in PwC ontology
# detailed in Section 4 of the paper
exp = DirectRetrievalTripletTrain(
            seed=seed, 
            test_fold=test_fold,
            valid_fold=val_fold,
            BS=32,  # training batch size
            grad_accum_step=2,  # gradient accumulation steps, in case the GPU can't fit a whole batch
            eval_BS=64,
            epoch=2,
            lr=2e-5,
            eval_steps=300,
            input_file=f'{data_dir}/DR_train.pkl', 
            save_dir=output_dir,
            name=f'DR_{val_fold}_{test_fold}_{seed}')
exp.train()

Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).

Dropping duplicates!
183240
183240


Map:   0%|          | 0/183240 [00:00<?, ? examples/s]

Map:   0%|          | 0/12067 [00:00<?, ? examples/s]

Training length 183240
Validation length 12067


11454it [5:32:04,  1.74s/it]                                                                                                                                 
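
A triplet objective pulls a cell's embedding toward its matching entity and pushes it away from a non-matching one. The sketch below shows the standard triplet margin loss with Euclidean distance in plain Python. The margin value, the distance function, and the toy vectors are assumptions, and `DirectRetrievalTripletTrain` may use a different variant:

```python
import math

def euclidean(u, v):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def triplet_margin_loss(anchor, positive, negative, margin=1.0):
    """max(0, d(anchor, positive) - d(anchor, negative) + margin)."""
    return max(0.0, euclidean(anchor, positive) - euclidean(anchor, negative) + margin)

cell_vec = [0.0, 0.0]   # anchor: embedding of a table cell
pos_ent = [0.0, 1.0]    # positive: the matching entity (distance 1.0)
far_neg = [3.0, 4.0]    # an easy negative (distance 5.0)
near_neg = [0.0, 1.5]   # a hard negative (distance 1.5)

easy_loss = triplet_margin_loss(cell_vec, pos_ent, far_neg)   # 1.0 - 5.0 + 1.0 < 0, clipped to 0.0
hard_loss = triplet_margin_loss(cell_vec, pos_ent, near_neg)  # 1.0 - 1.5 + 1.0 = 0.5
print(easy_loss, hard_loss)
```

Only triplets that violate the margin contribute to the loss, so negatives that are already far from the cell produce no gradient.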


### Compute DR outputs

In [15]:
from common_utils.common_exp_utils import EvalExperimentNB

exp = EvalExperimentNB(model_path=f'{output_dir}/DR_{val_fold}_{test_fold}_{seed}')

## Use the trained DR model to generate PwC entity embeddings once
ent_embeds = exp.compute_ent_enmbeddings(ent_file_path=f'{data_dir}/PwC_entities.pkl',
                                        ent_emb_save_path=f'{output_dir}/DR_{val_fold}_{test_fold}_{seed}_embed', 
                                        mode='Triplet')

## For each cell, use the trained DR model to embed it and then find the PwC entities with
## closest distance in the embedding space
el = pd.read_pickle(f'{data_dir}/EL_bm25f_ent_can.pkl')
ctc = pd.read_pickle(f'{data_dir}/CTC.pkl')
el_df = el[['ext_id', 'pwc_url']].merge(ctc, on='ext_id', how='inner')  # get context data for cells
DR_preds = exp.generate_candidates(ent_file_path=f'{data_dir}/PwC_entities.pkl',
                                    ent_emb_save_path=f'{output_dir}/DR_{val_fold}_{test_fold}_{seed}_embed', 
                                    el_df = el_df,
                                    mode='Triplet', 
                                    top_k=100)

DR_preds.to_pickle(f'{output_dir}/{val_fold}_{test_fold}_{seed}_DR_preds')

Map:   0%|          | 0/1942 [00:00<?, ? examples/s]

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 31/31 [00:09<00:00,  3.26it/s]


Map:   0%|          | 0/6550 [00:00<?, ? examples/s]

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 103/103 [00:25<00:00,  4.04it/s]


Map:   0%|          | 0/904 [00:00<?, ? examples/s]

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:04<00:00,  3.34it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 904/904 [05:14<00:00,  2.88it/s]
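
The retrieval step in the cell above (embed a cell, then find the closest entities in the embedding space) can be sketched as a brute-force nearest-neighbour search. This is an illustration only: the distance metric used by `generate_candidates` is not stated in this notebook, so Euclidean distance is an assumption, and `top_k_nearest` and the toy embeddings are hypothetical:

```python
import heapq
import math

def top_k_nearest(query, entity_embeddings, k):
    """Return the ids of the k entities closest to the query, nearest first."""
    def dist(vec):
        return math.sqrt(sum((q - x) ** 2 for q, x in zip(query, vec)))
    # Iterating a dict yields its keys, so nsmallest returns entity ids.
    return heapq.nsmallest(k, entity_embeddings, key=lambda eid: dist(entity_embeddings[eid]))

# Toy 2-d embeddings for three entities and one cell.
toy_entity_embeddings = {"ent_a": [0.0, 0.0], "ent_b": [1.0, 0.0], "ent_c": [5.0, 5.0]}
toy_cell_embedding = [0.9, 0.1]
nearest = top_k_nearest(toy_cell_embedding, toy_entity_embeddings, k=2)
print(nearest)  # ['ent_b', 'ent_a']
```

Because the entity embeddings do not depend on the cell, they can be computed once and reused for every query, which matches the two-step structure of the cell above.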


Evaluate DR performance on one fold

In [17]:
from ED.candidate_gen import test_candidate_effectiveness

recall_at_K = test_candidate_effectiveness(DR_preds[DR_preds.fold==test_fold], 'DR_candidates', 'DR_candidates', top_n=50)[1]
print(recall_at_K)

70.73170731707317


### Interleave DR and ASR candidates

In [19]:
from ED.candidate_gen import mix_candidate_set

outputs = mix_candidate_set(ASR_preds, 'ASR_candidates', DR_preds, 'DR_candidates', top_ks_each=(100, ))
outputs.to_pickle(f'{output_dir}/{val_fold}_{test_fold}_{seed}_CER_preds')
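
One simple way to interleave two ranked candidate lists is to alternate between them while skipping entities already taken. The exact strategy used by `mix_candidate_set` is not shown here, so the sketch below is an assumption about the general idea, with a hypothetical `interleave` helper and toy lists:

```python
from itertools import zip_longest

def interleave(ranked_a, ranked_b, top_k):
    """Alternate items from two ranked lists, skipping duplicates, and keep the first top_k."""
    merged, seen = [], set()
    for a, b in zip_longest(ranked_a, ranked_b):
        for item in (a, b):
            # zip_longest pads the shorter list with None, which is ignored here.
            if item is not None and item not in seen:
                seen.add(item)
                merged.append(item)
    return merged[:top_k]

toy_asr = ["e1", "e2", "e3"]
toy_dr = ["e2", "e4"]
mixed = interleave(toy_asr, toy_dr, top_k=3)
print(mixed)  # ['e1', 'e2', 'e4']
```

Alternating keeps the best candidates of both sources near the top, so an entity that only one of the two methods finds can still survive a top-K cutoff.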

Evaluate DR+ASR performance

In [25]:
from ED.candidate_gen import test_candidate_effectiveness

recall_at_K = test_candidate_effectiveness(outputs[outputs.fold==test_fold], 'candidates_100', 'candidates_100', top_n=50)[1]
print(recall_at_K)

98.78048780487805
