
Transformer-Rankers

Transformer-rankers is a library to conduct ranking experiments with transformers.

Most of the research experiments run with it have focused on conversation response ranking; see the EACL'21 and ECIR'20 papers. This repo is intended for research experiments, not for building production-ready systems. For general-purpose ranking models, pyterrier and pyserini are better alternatives.

Examples

Colab notebook: fine-tune pointwise BERT for conversation response ranking.

Wandb report: fine-tuning BERT for conversation response ranking.

Setup

The following commands clone the repo, create a virtual environment, and install the library along with its requirements.

#Clone the repo
git clone https://github.com/Guzpenha/transformer_rankers.git
cd transformer_rankers    

#Optionally use a virtual environment
python3 -m venv env
source env/bin/activate

#Install the library and its requirements
pip install -e .
pip install -r requirements.txt

Code example: BERT-ranker for dialogue

The following example uses BERT for conversation response ranking on the MANtIS corpus. We can download the data as follows:

from transformer_rankers.datasets import downloader

#Download the data with DataDownloader
data_folder = "data"
dataDownloader = downloader.DataDownloader("mantis", data_folder)
dataDownloader.download_and_preprocess()

And train BERT for pointwise learning to rank with randomly sampled negative samples:

import logging

import pandas as pd
from transformers import BertTokenizer
from transformer_rankers.models import pointwise_bert
from transformer_rankers.trainers import transformer_trainer
from transformer_rankers.datasets import dataset, preprocess_crr
from transformer_rankers.negative_samplers import negative_sampling
from transformer_rankers.eval import results_analyses_tools

#Show INFO-level log messages (Python's default level is WARNING)
logging.basicConfig(level=logging.INFO)

#Load the dataset (files are expected under <data_folder>/<task>/)
task = "mantis"
train = pd.read_csv("{}/{}/train.tsv".format(data_folder, task), sep="\t")
valid = pd.read_csv("{}/{}/valid.tsv".format(data_folder, task), sep="\t")

#Instantiate random negative samplers (1 negative candidate for training, 9 for validation/test).
#The library also supports BM25 and sentenceBERT negative samplers.
ns_train = negative_sampling.RandomNegativeSampler(list(train["response"].values), 1)
ns_val = negative_sampling.RandomNegativeSampler(list(valid["response"].values) + \
    list(train["response"].values), 9)

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
special_tokens_dict = {'additional_special_tokens': ['[UTTERANCE_SEP]', '[TURN_SEP]'] }
tokenizer.add_special_tokens(special_tokens_dict)

#Create the loaders for the datasets, with the respective negative samplers        
dataloader = dataset.QueryDocumentDataLoader(train_df=train, val_df=valid, test_df=valid,
                                tokenizer=tokenizer, negative_sampler_train=ns_train, 
                                negative_sampler_val=ns_val, task_type='classification', 
                                train_batch_size=6, val_batch_size=6, max_seq_len=512, 
                                sample_data=-1, cache_path="{}/{}".format(data_folder, task))

train_loader, val_loader, test_loader = dataloader.get_pytorch_dataloaders()


model = pointwise_bert.BertForPointwiseLearning.from_pretrained('bert-base-cased')
# we added [UTTERANCE_SEP] and [TURN_SEP] to the vocabulary so we need to resize the token embeddings
model.resize_token_embeddings(len(dataloader.tokenizer)) 

#Instantiate trainer that handles fitting.
trainer = transformer_trainer.TransformerTrainer(model=model,train_loader=train_loader,
                                val_loader=val_loader, test_loader=test_loader, 
                                num_ns_eval=9, task_type="classification", tokenizer=tokenizer,
                                validate_every_epoch=1, num_validation_batches=-1,
                                num_epochs=1, lr=0.0005, sacred_ex=None,
                                validate_every_steps=-1, num_training_instances=-1)

#Train the model
logging.info("Fitting BERT-ranker for MANtIS")
trainer.fit()

#Predict for test (in our example the validation set)
logging.info("Predicting")
preds, labels, _ = trainer.test()
res = results_analyses_tools.\
    evaluate_and_aggregate(preds, labels, ['ndcg_cut_10'])

for metric, v in res.items():
    logging.info("Test {} : {:.4f}".format(metric, v))
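
The evaluation setup used above (one relevant response ranked against 9 randomly sampled negatives and scored with nDCG@10) can be sketched in plain Python. This is only an illustration, not the library's implementation: `sample_negatives`, `rank_labels`, and `ndcg_at_k` are hypothetical helper names, the scores are made up, and the DCG uses a linear gain, which is an assumption.

```python
import math
import random


def sample_negatives(pool, relevant, k, rng):
    """Draw k distinct candidates from the pool, skipping the relevant response."""
    candidates = [r for r in pool if r != relevant]
    return rng.sample(candidates, k)


def rank_labels(scores, labels):
    """Return the labels ordered by descending model score."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [labels[i] for i in order]


def ndcg_at_k(ranked_labels, k=10):
    """nDCG@k with linear gain for labels listed in ranked order."""
    def dcg(labels):
        return sum(l / math.log2(i + 2) for i, l in enumerate(labels[:k]))

    ideal = dcg(sorted(ranked_labels, reverse=True))
    return dcg(ranked_labels) / ideal if ideal > 0 else 0.0


rng = random.Random(0)
pool = ["response {}".format(i) for i in range(50)]  # toy response pool
relevant = pool[3]
negatives = sample_negatives(pool, relevant, 9, rng)

# One relevant candidate (label 1) followed by 9 negatives (label 0).
labels = [1] + [0] * 9
# Made-up scores: one negative outscores the relevant response.
scores = [0.8, 0.9] + [0.1] * 8

ranked = rank_labels(scores, labels)
print(ndcg_at_k(ranked, k=10))  # relevant at rank 2, so 1 / log2(3), about 0.6309
```

With a single relevant candidate per query, nDCG@10 reduces to 1 / log2(rank + 1), so it depends only on where the model places the relevant response among the 10 candidates.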
