## Description

In this notebook, we demonstrate how to train and evaluate ASM-related models on a specific fold, as detailed in Section 4 of the paper.

In [2]:
import os
import sys
import warnings

# Silence FutureWarning noise before pandas is imported below.
warnings.filterwarnings("ignore", category=FutureWarning)

# Put the parent of the first sys.path entry at the front of the import path,
# so that the project packages imported in later cells can be resolved.
sys.path.insert(0, sys.path[0] + '/..')

import pandas as pd

In [3]:
import torch

# The value of the last expression is displayed as the cell output.
torch.cuda.is_available()  # make sure torch is installed correctly

True

In [4]:
# Reproducibility: random seed shared by every experiment in this notebook.
seed = 42

# Fold assignment: one fold of papers is held out for validation, another for testing.
val_fold, test_fold = 'img_class', 'misc'

# Locations on disk.
data_dir = '../data'    # the directory to put all downloaded training related files
output_dir = '../_tmp'  # the directory to store model outputs and predictions

# Create the output directory up front; a pre-existing directory is not an error.
os.makedirs(output_dir, exist_ok=True)

### Data Preparation

Data files we need to run this notebook:
- `attributed_source.jsonl` (download from https://github.com/allenai/S2abEL/blob/main/data/release_data.tar.gz)
- `papers.jsonl` (download from https://github.com/allenai/S2abEL/blob/main/data/release_data.tar.gz)
- `entity_linking.jsonl` (download from https://github.com/allenai/S2abEL/blob/main/data/release_data.tar.gz)
- `CTC.pkl` (need to run [cell type classification notebook](cell_type_classification.ipynb) to get this)
- `ref_related_ents.pkl` (contains entities related to each paper, download [here](https://github.com/allenai/S2abEL/blob/main/data/train_data.tar.gz))
- `EL_bm25f_ent_can.pkl` (contains top rated candidates using bm25f, download [here](https://github.com/allenai/S2abEL/blob/main/data/train_data.tar.gz))
- `datasets.json`  (download from [Papers with Code](https://github.com/paperswithcode/paperswithcode-data))
- `methods.json`  (download from [Papers with Code](https://github.com/paperswithcode/paperswithcode-data))

Generate data for Attributed Source Matching and Attributed Source Retrieval

In [31]:
# Raw inputs: three JSON-lines files plus the pickled CTC frame.
asm = pd.read_json(f'{data_dir}/attributed_source.jsonl', lines=True)
papers = pd.read_json(f'{data_dir}/papers.jsonl', lines=True)
ctc = pd.read_pickle(f'{data_dir}/CTC.pkl')
el = pd.read_json(f'{data_dir}/entity_linking.jsonl', lines=True)

# Inner joins on `cell_id`: a cell survives only if it appears in the CTC frame
# and in the entity-linking frame. The second join brings in `pwc_url`, because
# a cell needs to have GT PwC link to evaluate CER performance.
asm = (
    asm
    .merge(ctc, on='cell_id', how='inner')
    .merge(el[['cell_id', 'pwc_url']], on='cell_id', how='inner')
)

In [179]:
from ASM.utils import get_ref_extract
from ASM.utils import generate_RPI_ML

# Build the ASM data from the papers frame and the merged `asm` frame.
# NOTE(review): presumably `get_ref_extract` pulls reference information out of
# `papers` -- confirm against ASM.utils.
ref_extract = get_ref_extract(papers)
asm_data = generate_RPI_ML(ref_extract, asm)

# Persist the result; this is the file the ASM training Config reads as `input_file`.
asm_pickle_path = f'{data_dir}/ASM.pkl'
asm_data.to_pickle(asm_pickle_path)

100%|████████████████████████████████████████████████████████████████████████| 8429/8429 [00:09<00:00, 919.39it/s]
100%|██████████████████████████████████████████████████████████████████| 345373/345373 [00:05<00:00, 59429.17it/s]


Generate data for Direct Retrieval

In [45]:
from ED.utils import convert_EL_cans_to_triplet_training_data
import json


def _text_field(entity, key):
    """Return ``entity[key]``, or '' when the key is absent or its value is None.

    The previous code indexed ``entity[key]`` for the None check and only then
    used ``entity.get(key, '')``, so a missing key raised KeyError and the
    ``.get`` default was never reached. This helper treats "missing" and "null"
    the same way for every text field.
    """
    value = entity.get(key)
    return '' if value is None else value


# Load the Papers with Code dumps. The encoding is stated explicitly so the
# result does not depend on the platform's default text encoding.
with open(f'{data_dir}/methods.json', encoding='utf-8') as f:
    methods = json.load(f)

with open(f'{data_dir}/datasets.json', encoding='utf-8') as f:
    datasets = json.load(f)


ents = methods + datasets
ent_map = {}

## For each PwC entity, we get its name, full name, description, and pwc url.
## Entities are keyed by url, so a later entry with the same url replaces an earlier one.
for m in ents:
    ent_map[m['url']] = (
        _text_field(m, 'name'),
        _text_field(m, 'full_name'),
        _text_field(m, 'description'),
        m['url'],
    )

# One row per PwC entity; the entity type is derived from the url path.
pwc_entities = pd.DataFrame(ent_map.values(), columns=['name', 'full_name', 'description', 'url'])
pwc_entities['type'] = pwc_entities['url'].apply(lambda x: 'method' if '/method/' in x else 'dataset')
pwc_entities.to_pickle(f'{data_dir}/PwC_entities.pkl')

# Join the bm25f candidates onto the cells (inner join on `cell_id`) and convert
# them into the Direct Retrieval training file.
# NOTE(review): the meaning of the two 'candidates_100' arguments and of
# `top_n=50` is defined in ED.utils -- confirm there before changing them.
EL_bm25f_ent_can = pd.read_pickle(f'{data_dir}/EL_bm25f_ent_can.pkl')
DR_train = convert_EL_cans_to_triplet_training_data(
    asm.merge(EL_bm25f_ent_can, on='cell_id', how='inner'),
    ent_map,
    'candidates_100',
    'candidates_100',
    top_n=50,
)
DR_train.to_pickle(f'{data_dir}/DR_train.pkl')

### Train an ASM model

The Config sets up the training and test configurations.

In [6]:
from common_utils.common_ML_utils import Config
from common_utils.common_data_processing_utils import (
    cell_rep_features,
    paper_rep_features,
)

# Training and evaluation settings for the ASM model on the selected folds.
config = Config(
    test_fold=test_fold,
    valid_fold=val_fold,
    input_file=f'{data_dir}/ASM.pkl',
    seed=seed,
    lr=2e-5,
    epoch=2,
    # training batch size
    BS=32,
    # gradient accumulation steps, in case the GPU doesn't fit a whole batch
    grad_accum_step=1,
    # the pooling method used at the last layer of the transformer
    pool_mtd='avg',
    # binary-cross-entropy loss
    loss='BCE',
    # the underlying pretrained model
    pretrained="allenai/scibert_scivocab_uncased",
    # features used in the cell representation in paper Sec 4
    input_cols=cell_rep_features + paper_rep_features,
    # supervised training
    use_labels=True,
    # whether or not to drop duplicates in the training set
    drop_duplicates=True,
    eval_steps=300,
    # evaluation batch size
    eval_BS=256,
    # drop cells that have in-cell references, because the attributed source
    # for those cells is best found by the in-cell reference instead of a ML model
    drop_ones_with_cell_ref=True,
    name=f'ASM_{val_fold}_{test_fold}_{seed}',
    save_dir=output_dir,
)

In [7]:
from ED.trainers import train_cross_encoder_notebook

# Trains a model from `config` and saves it to save_dir/name, i.e. the
# `save_dir` and `name` values given to the Config.
train_cross_encoder_notebook(config)

Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


dropping ones with cell_reference
345373
268617
1    238175
0      5138
Name: labels, dtype: int64
Dropping duplicates!
243313
Index(['region_type', 'row_pos', 'reverse_row_pos', 'col_pos',
       'reverse_col_pos', 'has_reference', 'cell_content', 'row_context',
       'col_context', 'context_sentences', 'idx', 'year', 'author', 'title',
       'abstract', 'labels'],
      dtype='object')
243313


Map:   0%|          | 0/243313 [00:00<?, ? examples/s]

Map:   0%|          | 0/243313 [00:00<?, ? examples/s]

Map:   0%|          | 0/243313 [00:00<?, ? examples/s]

Map:   0%|          | 0/243313 [00:00<?, ? examples/s]

Map:   0%|          | 0/243313 [00:00<?, ? examples/s]

Map:   0%|          | 0/243313 [00:00<?, ? examples/s]

Map:   0%|          | 0/243313 [00:00<?, ? examples/s]

Map:   0%|          | 0/243313 [00:00<?, ? examples/s]

Map:   0%|          | 0/243313 [00:00<?, ? examples/s]

Map:   0%|          | 0/243313 [00:00<?, ? examples/s]

Map:   0%|          | 0/243313 [00:00<?, ? examples/s]

Map:   0%|          | 0/243313 [00:00<?, ? examples/s]

Map:   0%|          | 0/243313 [00:00<?, ? examples/s]

Map:   0%|          | 0/243313 [00:00<?, ? examples/s]

Map:   0%|          | 0/243313 [00:00<?, ? examples/s]

Map:   0%|          | 0/15967 [00:00<?, ? examples/s]

Map:   0%|          | 0/15967 [00:00<?, ? examples/s]

Map:   0%|          | 0/15967 [00:00<?, ? examples/s]

Map:   0%|          | 0/15967 [00:00<?, ? examples/s]

Map:   0%|          | 0/15967 [00:00<?, ? examples/s]

Map:   0%|          | 0/15967 [00:00<?, ? examples/s]

Map:   0%|          | 0/15967 [00:00<?, ? examples/s]

Map:   0%|          | 0/15967 [00:00<?, ? examples/s]

Map:   0%|          | 0/15967 [00:00<?, ? examples/s]

Map:   0%|          | 0/15967 [00:00<?, ? examples/s]

Map:   0%|          | 0/15967 [00:00<?, ? examples/s]

Map:   0%|          | 0/15967 [00:00<?, ? examples/s]

Map:   0%|          | 0/15967 [00:00<?, ? examples/s]

Map:   0%|          | 0/15967 [00:00<?, ? examples/s]

Map:   0%|          | 0/15967 [00:00<?, ? examples/s]

Map:   0%|          | 0/9337 [00:00<?, ? examples/s]

Map:   0%|          | 0/9337 [00:00<?, ? examples/s]

Map:   0%|          | 0/9337 [00:00<?, ? examples/s]

Map:   0%|          | 0/9337 [00:00<?, ? examples/s]

Map:   0%|          | 0/9337 [00:00<?, ? examples/s]

Map:   0%|          | 0/9337 [00:00<?, ? examples/s]

Map:   0%|          | 0/9337 [00:00<?, ? examples/s]

Map:   0%|          | 0/9337 [00:00<?, ? examples/s]

Map:   0%|          | 0/9337 [00:00<?, ? examples/s]

Map:   0%|          | 0/9337 [00:00<?, ? examples/s]

Map:   0%|          | 0/9337 [00:00<?, ? examples/s]

Map:   0%|          | 0/9337 [00:00<?, ? examples/s]

Map:   0%|          | 0/9337 [00:00<?, ? examples/s]

Map:   0%|          | 0/9337 [00:00<?, ? examples/s]

Map:   0%|          | 0/9337 [00:00<?, ? examples/s]

100%|█████████████████████████████████████████████████████████████████████| 15208/15208 [3:08:13<00:00,  1.35it/s]


### Generate ASM predictions

In [16]:
from ASM.experiments import ASMEvalExpNB

# Directory of the ASM model trained above (save_dir/name of its Config).
asm_model_dir = f'{output_dir}/ASM_{val_fold}_{test_fold}_{seed}'
exp = ASMEvalExpNB(asm_model_dir)

# Generate predictions for the ASM task for each cell, i.e., each cell will
# have the papers in the reference section and the current paper ranked by the
# probability of being the attributed source for the cell.
# `save_path=None`: the predictions are only kept in memory as `ASM_preds`.
ASM_preds = exp.make_predictions(save_path=None, enhance=True)

Map:   0%|          | 0/38823 [00:00<?, ? examples/s]

Map:   0%|          | 0/38823 [00:00<?, ? examples/s]

Map:   0%|          | 0/38823 [00:00<?, ? examples/s]

Map:   0%|          | 0/38823 [00:00<?, ? examples/s]

Map:   0%|          | 0/38823 [00:00<?, ? examples/s]

Map:   0%|          | 0/38823 [00:00<?, ? examples/s]

Map:   0%|          | 0/38823 [00:00<?, ? examples/s]

Map:   0%|          | 0/38823 [00:00<?, ? examples/s]

Map:   0%|          | 0/38823 [00:00<?, ? examples/s]

Map:   0%|          | 0/38823 [00:00<?, ? examples/s]

Map:   0%|          | 0/38823 [00:00<?, ? examples/s]

Map:   0%|          | 0/38823 [00:00<?, ? examples/s]

Map:   0%|          | 0/38823 [00:00<?, ? examples/s]

Map:   0%|          | 0/38823 [00:00<?, ? examples/s]

Map:   0%|          | 0/38823 [00:00<?, ? examples/s]

100%|███████████████████████████████████████████████████████████████████████████| 152/152 [03:12<00:00,  1.27s/it]
100%|██████████████████████████████████████████████████████████████████████████| 904/904 [00:03<00:00, 277.13it/s]


### Generate ASR predictions

In [17]:
from ED.candidate_gen import get_ASR_candidates_encoders

# Entities related to each paper (see the data-preparation list above).
ref_related_ents = pd.read_pickle(f'{data_dir}/ref_related_ents.pkl')

# Use the ASM predictions to fetch candidates from the PwC KB to generate ASR
# candidates, as detailed in Section 4 of the paper.
ASR_preds = get_ASR_candidates_encoders(
    ASM_preds,
    ref_related_ents,
    ASM_pred_cols=['ASM_preds'],
    output_cols=['ASR_candidates'],
    top_ns=[100],
)

asr_preds_path = f'{output_dir}/{val_fold}_{test_fold}_{seed}_ASR_preds'
ASR_preds.to_pickle(asr_preds_path)

100%|███████████████████████████████████████████████████████████████████████████████| 1/1 [00:23<00:00, 23.02s/it]


Evaluate ASR performance on one fold 

In [29]:
from ED.candidate_gen import test_candidate_effectiveness

# Attach the entity-linking annotations to the ASR predictions.
# NOTE: `ASR_preds` is reassigned from itself, so re-running this cell merges
# `el` a second time; re-run the cell that creates `ASR_preds` first.
el = pd.read_json(f'{data_dir}/entity_linking.jsonl', lines=True)
ASR_preds = ASR_preds.merge(el, on='cell_id', how='inner')

# Score only the rows that belong to the held-out test fold.
asr_test_rows = ASR_preds[ASR_preds.fold == test_fold]
asr_effectiveness = test_candidate_effectiveness(
    asr_test_rows, 'ASR_candidates', 'ASR_candidates', top_n=50
)
recall_at_K = asr_effectiveness[1]
print(recall_at_K)

85.36585365853658


### Train a DR model

In [34]:
from ASM.experiments import DirectRetrievalTripletTrain

# Train a Direct Retrieval model with a Bi-encoder arch to rank all candidates
# in the PwC ontology, detailed in Section 4 of the paper.
exp = DirectRetrievalTripletTrain(
    seed=seed,
    test_fold=test_fold,
    valid_fold=val_fold,
    # training batch size
    BS=32,
    # gradient accumulation steps, in case the GPU doesn't fit a whole batch
    grad_accum_step=2,
    eval_BS=64,
    epoch=2,
    lr=2e-5,
    eval_steps=300,
    input_file=f'{data_dir}/DR_train.pkl',
    save_dir=output_dir,
    name=f'DR_{val_fold}_{test_fold}_{seed}',
)
exp.train()

Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing Ber

Dropping duplicates!
183240
183240


Map:   0%|          | 0/183240 [00:00<?, ? examples/s]

Map:   0%|          | 0/183240 [00:00<?, ? examples/s]

Map:   0%|          | 0/183240 [00:00<?, ? examples/s]

Map:   0%|          | 0/183240 [00:00<?, ? examples/s]

Map:   0%|          | 0/183240 [00:00<?, ? examples/s]

Map:   0%|          | 0/183240 [00:00<?, ? examples/s]

Map:   0%|          | 0/183240 [00:00<?, ? examples/s]

Map:   0%|          | 0/183240 [00:00<?, ? examples/s]

Map:   0%|          | 0/183240 [00:00<?, ? examples/s]

Map:   0%|          | 0/183240 [00:00<?, ? examples/s]

Map:   0%|          | 0/183240 [00:00<?, ? examples/s]

Map:   0%|          | 0/183240 [00:00<?, ? examples/s]

Map:   0%|          | 0/183240 [00:00<?, ? examples/s]

Map:   0%|          | 0/183240 [00:00<?, ? examples/s]

Map:   0%|          | 0/183240 [00:00<?, ? examples/s]

Map:   0%|          | 0/12067 [00:00<?, ? examples/s]

Map:   0%|          | 0/12067 [00:00<?, ? examples/s]

Map:   0%|          | 0/12067 [00:00<?, ? examples/s]

Map:   0%|          | 0/12067 [00:00<?, ? examples/s]

Map:   0%|          | 0/12067 [00:00<?, ? examples/s]

Map:   0%|          | 0/12067 [00:00<?, ? examples/s]

Map:   0%|          | 0/12067 [00:00<?, ? examples/s]

Map:   0%|          | 0/12067 [00:00<?, ? examples/s]

Map:   0%|          | 0/12067 [00:00<?, ? examples/s]

Map:   0%|          | 0/12067 [00:00<?, ? examples/s]

Map:   0%|          | 0/12067 [00:00<?, ? examples/s]

Map:   0%|          | 0/12067 [00:00<?, ? examples/s]

Map:   0%|          | 0/12067 [00:00<?, ? examples/s]

Map:   0%|          | 0/12067 [00:00<?, ? examples/s]

Map:   0%|          | 0/12067 [00:00<?, ? examples/s]

Training length 183240
Validation length 12067


11454it [5:31:23,  1.74s/it]                                                                                      


### Compute DR outputs

In [None]:
from common_utils.common_exp_utils import EvalExperimentNB

# Paths shared by the two calls below.
dr_model_dir = f'{output_dir}/DR_{val_fold}_{test_fold}_{seed}'
pwc_entities_file = f'{data_dir}/PwC_entities.pkl'
dr_embeddings_file = f'{output_dir}/DR_{val_fold}_{test_fold}_{seed}_embed'

exp = EvalExperimentNB(model_path=dr_model_dir)

## Use the trained DR model to generate PwC entity embeddings once
ent_embeds = exp.compute_ent_enmbeddings(
    ent_file_path=pwc_entities_file,
    ent_emb_save_path=dr_embeddings_file,
    mode='Triplet',
)

## For each cell, use the trained DR model to embed it and then find the PwC
## entities with closest distance in the embedding space
DR_preds = exp.generate_candidates(
    ent_file_path=pwc_entities_file,
    ent_emb_save_path=dr_embeddings_file,
    el_df=asm,
    mode='Triplet',
    top_k=100,
)

dr_preds_path = f'{output_dir}/{val_fold}_{test_fold}_{seed}_DR_preds'
DR_preds.to_pickle(dr_preds_path)

Evaluate DR performance on one fold

In [50]:
from ED.candidate_gen import test_candidate_effectiveness

# Score the DR candidates on the held-out test fold only.
dr_test_rows = DR_preds[DR_preds.fold == test_fold]
dr_effectiveness = test_candidate_effectiveness(
    dr_test_rows, 'DR_candidates', 'DR_candidates', top_n=50
)
recall_at_K = dr_effectiveness[1]
print(recall_at_K)

68.29268292682927


### Interleave DR and ASR candidates

In [52]:
from ED.candidate_gen import mix_candidate_set

# Combine the ASR and DR candidate lists into one candidate set per cell.
outputs = mix_candidate_set(
    ASR_preds,
    'ASR_candidates',
    DR_preds,
    'DR_candidates',
    top_ks_each=(100,),
)

cer_preds_path = f'{output_dir}/{val_fold}_{test_fold}_{seed}_CER_preds'
outputs.to_pickle(cer_preds_path)

Evaluate DR+ASR performance

In [57]:
from ED.candidate_gen import test_candidate_effectiveness

# Score the interleaved DR+ASR candidates on the held-out test fold only.
# NOTE(review): this uses top_n=30 while the ASR-only and DR-only evaluations
# use top_n=50 -- confirm the difference is intended.
mixed_test_rows = outputs[outputs.fold == test_fold]
mixed_effectiveness = test_candidate_effectiveness(
    mixed_test_rows, 'candidates_100', 'candidates_100', top_n=30
)
recall_at_K = mixed_effectiveness[1]
print(recall_at_K)

97.5609756097561
