# Finetuning a language model to predict document ranking

In this notebook, we'll attempt to finetune a pre-trained language model on the task of ranking prediction.
We'll use the CORD-19 dataset compiled by the Allen Institute for Artificial Intelligence, which comprises more than 200,000 papers related to the coronavirus pandemic.

We'll build the references graph and use the PageRank static ranking algorithm to assess the relevance of each paper.
Then, we'll put a classifier head on top of a pre-trained language model and finetune it using the ranking scores as supervision in a semi-supervised way.

## Building the references graph

Refer to the following notebooks to inspect the source code used to build the references graph.

- https://github.com/Inria-Chile/risotto/blob/master/01_references.ipynb
- https://github.com/Inria-Chile/risotto/blob/master/05_cook_artifacts.ipynb
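
The notebooks above hold the actual implementation. As a rough illustration of the idea only, the following sketch runs PageRank by power iteration on a tiny, made-up citation graph. The `pagerank` helper, the damping factor of 0.85 and the toy adjacency matrix are all assumptions for this example and are not taken from the risotto code.

```python
import numpy as np


def pagerank(adjacency, damping=0.85, tol=1e-10, max_iter=200):
    """Power-iteration PageRank on a dense adjacency matrix.

    adjacency[i, j] = 1 means paper i cites paper j.
    """
    adjacency = np.asarray(adjacency, dtype=float)
    n = adjacency.shape[0]
    out_degree = adjacency.sum(axis=1)

    # Row-normalize the matrix. Papers without outgoing references
    # (dangling nodes) spread their score uniformly over all papers.
    transition = np.where(
        out_degree[:, None] > 0,
        adjacency / np.maximum(out_degree[:, None], 1.0),
        1.0 / n,
    )

    scores = np.full(n, 1.0 / n)
    for _ in range(max_iter):
        updated = (1.0 - damping) / n + damping * scores @ transition
        converged = np.abs(updated - scores).sum() < tol
        scores = updated
        if converged:
            break
    return scores


# Hypothetical toy graph: papers 0, 1 and 2 cite paper 3; paper 0 also cites paper 1.
citations = np.array([
    [0, 1, 0, 1],
    [0, 0, 0, 1],
    [0, 0, 0, 1],
    [0, 0, 0, 0],
])
scores = pagerank(citations)
print(scores.round(3))
```

The scores sum to one, and the most cited paper (index 3) receives the highest score. On a graph the size of CORD-19 a sparse implementation would be preferable to this dense one.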

In [None]:
import numpy as np

from risotto.artifacts import load_papers_artifact


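def get_papers():
    # Missing text fields become "N/A" so titles and abstracts can be concatenated later
    papers = load_papers_artifact().fillna("N/A")
    # Log-transform the raw PageRank scores to compress their range
    papers["pagerank"] = np.log(papers["pagerank"])
    # Standardize to zero mean and unit variance
    mean_pagerank = papers["pagerank"].mean()
    std_pagerank = papers["pagerank"].std()
    papers["pagerank"] = (papers["pagerank"] - mean_pagerank) / std_pagerank
    # Rescale to [0, 1] so the scores can be compared with a probability
    min_pagerank = papers["pagerank"].min()
    max_pagerank = papers["pagerank"].max()
    papers["pagerank"] = (papers["pagerank"] - min_pagerank) / (max_pagerank - min_pagerank)
    return papers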

papers = get_papers()
display(papers.head())
papers.columns, papers.shape

cord_uid,pagerank,affiliation,country,sha,source_x,title,doi,pmcid,pubmed_id,license,...,publish_time,authors,journal,mag_id,who_covidence_id,arxiv_id,pdf_json_files,pmc_json_files,url,s2_id
ug7v899j,0.0,,,d1aafb70c066a2068b02786f8929fd9c900897fb,PMC,Clinical features of culture-proven Mycoplasma...,10.1186/1471-2334-1-6,PMC35282,11472600.0,no-cc,...,2001-07-04,"Madani, Tariq A; Al-Ghamdi, Aisha A",BMC Infect Dis,,,,document_parses/pdf_json/d1aafb70c066a2068b027...,document_parses/pmc_json/PMC35282.xml.json,https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3...,
02tnwd4m,0.030459,University of Alabama at Birmingham,USA,6b0567729c2143a66d737eb0a2f63f2dce2e5a7d,PMC,Nitric oxide: a pro-inflammatory mediator in l...,10.1186/rr14,PMC59543,11668000.0,no-cc,...,2000-08-15,"Vliet, Albert van der; Eiserich, Jason P; Cros...",Respir Res,,,,document_parses/pdf_json/6b0567729c2143a66d737...,document_parses/pmc_json/PMC59543.xml.json,https://www.ncbi.nlm.nih.gov/pmc/articles/PMC5...,
ejv2xln0,0.21586,Washington University School of Medicine,USA,06ced00a5fc04215949aa72528f2eeaae1d58927,PMC,Surfactant protein-D and pulmonary host defense,10.1186/rr19,PMC59549,11668000.0,no-cc,...,2000-08-25,"Crouch, Erika C",Respir Res,,,,document_parses/pdf_json/06ced00a5fc04215949aa...,document_parses/pmc_json/PMC59549.xml.json,https://www.ncbi.nlm.nih.gov/pmc/articles/PMC5...,
2b73a28n,0.043255,,,348055649b6b8cf2b9a376498df9bf41f7123605,PMC,Role of endothelin-1 in lung disease,10.1186/rr44,PMC59574,11686900.0,no-cc,...,2001-02-22,"Fagan, Karen A; McMurtry, Ivan F; Rodman, David M",Respir Res,,,,document_parses/pdf_json/348055649b6b8cf2b9a37...,document_parses/pmc_json/PMC59574.xml.json,https://www.ncbi.nlm.nih.gov/pmc/articles/PMC5...,
9785vg6d,0.036417,National Institutes of Health (Laboratory of H...,USA,5f48792a5fa08bed9f56016f4981ae2ca6031b32,PMC,Gene expression in epithelial cells in respons...,10.1186/rr61,PMC59580,11686900.0,no-cc,...,2001-05-11,"Domachowske, Joseph B; Bonville, Cynthia A; Ro...",Respir Res,,,,document_parses/pdf_json/5f48792a5fa08bed9f560...,document_parses/pmc_json/PMC59580.xml.json,https://www.ncbi.nlm.nih.gov/pmc/articles/PMC5...,


(Index(['pagerank', 'affiliation', 'country', 'sha', 'source_x', 'title', 'doi',
        'pmcid', 'pubmed_id', 'license', 'abstract', 'publish_time', 'authors',
        'journal', 'mag_id', 'who_covidence_id', 'arxiv_id', 'pdf_json_files',
        'pmc_json_files', 'url', 's2_id'],
       dtype='object'),
 (62427, 21))
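
To make the normalization in `get_papers` concrete, here is a small sketch that applies the same three steps (log, standardization, min-max scaling) to a handful of made-up scores. The values in `raw` are invented for illustration. The sketch also checks that min-max scaling is unaffected by the preceding standardization, so the final values depend only on the log scores.

```python
import numpy as np

# Made-up raw PageRank scores, for illustration only
raw = np.array([1e-6, 5e-6, 2e-5, 1e-3])

# Same steps as get_papers(): log, z-score, then min-max scaling
log_scores = np.log(raw)
# ddof=1 matches the default of pandas' Series.std()
standardized = (log_scores - log_scores.mean()) / log_scores.std(ddof=1)
scaled = (standardized - standardized.min()) / (standardized.max() - standardized.min())

# Min-max scaling applied directly to the log scores
direct = (log_scores - log_scores.min()) / (log_scores.max() - log_scores.min())

print(scaled.round(3))
print(np.allclose(scaled, direct))
```

The lowest score maps to exactly 0 and the highest to exactly 1, which is consistent with the `0.0` shown in the first row of the preview above.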

## Finetuning a pre-trained language model

We'll implement our model using PyTorch Lightning.

In [None]:
!pip install -U transformers torch pytorch-lightning

Requirement already up-to-date: transformers in ./venv-risotto/lib/python3.8/site-packages (3.0.2)
Requirement already up-to-date: torch in ./venv-risotto/lib/python3.8/site-packages (1.6.0)
Requirement already up-to-date: pytorch-lightning in ./venv-risotto/lib/python3.8/site-packages (0.8.5)
You should consider upgrading via the '/Users/rodolfo/repos/risotto/venv-risotto/bin/python3 -m pip install --upgrade pip' command.


In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.core.lightning import LightningModule

from transformers import AutoModelForSequenceClassification, AutoTokenizer, AdamW


class PapersDataset(Dataset):
    def __init__(self, df):
        self._df = df
    
    def __getitem__(self, idx):
        item = self._df.iloc[idx]
        text = item["title"] + " " + item["abstract"]
        pagerank = item["pagerank"]
        return {
            "text": text,
            "pagerank": torch.tensor(pagerank, dtype=torch.float)
        }
    
    def __len__(self):
        return len(self._df)
        
        
class RankingPredictor(LightningModule):
    def __init__(self, base_model, learning_rate=1e-5):
        super().__init__()
        
        self.language_model = AutoModelForSequenceClassification.from_pretrained(base_model)
        self.tokenizer = AutoTokenizer.from_pretrained(base_model)
        
        self.learning_rate = learning_rate
        
    
    def forward(self, papers):
        papers_encoded = self.tokenizer(
            papers,
            return_tensors="pt",
            padding=True,
            truncation=True
        )
        input_ids = papers_encoded["input_ids"]
        attention_mask = papers_encoded["attention_mask"]
        outputs = self.language_model(input_ids, attention_mask=attention_mask)
        return outputs
    
    def configure_optimizers(self):
        # Use this module's own parameters instead of the global `model` variable
        optimizer = AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer
    
    def loss(self, predicted, target):
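        # Treat the softmax probability of class 1 as the predicted normalized PageRank
        predicted_activated = F.softmax(predicted, dim=1)
        # Indexing with [:, 1] already yields shape (batch,), matching `target`
        predicted_sliced = predicted_activated[:, 1]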
        mse = F.mse_loss(predicted_sliced, target)
        return mse
    
    def _inference(self, batch, _):
        text = batch["text"]
        target = batch["pagerank"]
        
        predicted = self(text)
        loss = self.loss(predicted, target)
        
        return {
            "loss": loss,
        }
    
    def training_step(self, batch, batch_idx):
        return self._inference(batch, batch_idx)
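
The `loss` method above can be reproduced outside PyTorch to see what it computes. The following NumPy sketch assumes two output logits per paper (the default `num_labels` of the classification head) and uses invented logits and targets. The `softmax` and `ranking_loss` helpers are written for this example only.

```python
import numpy as np


def softmax(logits):
    # Subtract the row-wise maximum for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def ranking_loss(logits, target):
    # Probability assigned to class 1, read as the predicted normalized PageRank
    positive_prob = softmax(logits)[:, 1]
    return np.mean((positive_prob - target) ** 2)


# Invented logits for three papers and their normalized PageRank targets
logits = np.array([[2.0, -1.0], [0.0, 0.0], [-1.5, 3.0]])
target = np.array([0.1, 0.5, 0.9])

# With two logits, the class-1 probability equals sigmoid(z1 - z0)
sigmoid_of_diff = 1.0 / (1.0 + np.exp(-(logits[:, 1] - logits[:, 0])))

print(ranking_loss(logits, target))
print(np.allclose(softmax(logits)[:, 1], sigmoid_of_diff))
```

Because the prediction reduces to a sigmoid of the logit difference, it always lies in (0, 1), the same range as the min-max scaled PageRank targets.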

In [None]:
ds = PapersDataset(papers)
dl = DataLoader(
    ds,
    batch_size=4,
    shuffle=True,
)
model = RankingPredictor("bert-base-uncased")
trainer = Trainer(fast_dev_run=True)
trainer.fit(model, train_dataloader=dl)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

Epoch 1:   0%|          | 0/1 [00:00<?, ?it/s] 



AttributeError: 'tuple' object has no attribute 'softmax'
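
The error comes from `forward` returning the model output unchanged. With the `transformers` version installed above (3.0.2), `AutoModelForSequenceClassification` models return a tuple whose first element holds the logits when no labels are passed, so `F.softmax` ends up being called on a tuple. One possible fix is to pick the first element before computing the loss. The sketch below uses a plain tuple as a stand-in for the model output, and `extract_logits` is a hypothetical helper name.

```python
def extract_logits(outputs):
    """Return the logits from a sequence-classification forward pass.

    The first element of the output holds the logits when no labels are given.
    """
    return outputs[0]


# Stand-in for the model's return value: a one-element tuple with 4 x 2 logits
fake_outputs = ([[0.1, -0.2], [0.3, 0.4], [0.0, 0.0], [1.0, -1.0]],)
logits = extract_logits(fake_outputs)
print(len(logits), len(logits[0]))
```

In `RankingPredictor.forward`, this corresponds to ending the method with `return outputs[0]` instead of `return outputs`.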