Installing Requirements

In [None]:
%%writefile requirements.txt
boto3==1.12.36
botocore==1.15.36
certifi==2020.4.5.1
chardet==3.0.4
click==7.1.1
docutils==0.15.2
filelock==3.0.12
idna==2.9
jmespath==0.9.5
joblib==0.14.1
jsonlines==1.2.0
numpy==1.18.2
pandas==1.0.3
python-dateutil==2.8.1
regex==2020.4.4
requests==2.23.0
s3transfer==0.3.3
sacremoses==0.0.38
scikit-learn==0.22.2.post1
scipy==1.4.1
scispacy==0.2.5
https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.2.5/en_core_sci_sm-0.2.5.tar.gz
sentencepiece==0.1.85
six==1.14.0
tokenizers==0.5.2
torch==1.5.0
tqdm==4.45.0
transformers==2.7.0

In [None]:
pip install -r requirements.txt

Downloading the SciFact dataset

In [None]:
wget https://scifact.s3-us-west-2.amazonaws.com/release/latest/data.tar.gz
tar -xvf data.tar.gz

### Training

In [None]:
import torch
import jsonlines
import os

from torch.utils.data import Dataset, DataLoader
from transformers import get_cosine_schedule_with_warmup, RobertaTokenizer, RobertaForSequenceClassification
from tqdm import tqdm
from typing import List
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score, balanced_accuracy_score

In [None]:
# Run on the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device "{device}"')

In [None]:
class SciFactRationaleSelectionDataset(Dataset):
    """Claim/sentence pairs for rationale selection, one sample per abstract sentence.

    For every claim, up to ``max_docs`` distinct candidate documents are collected,
    taking retrieved ids first and cited ids second. Each sentence of each candidate
    abstract becomes a sample ``{'claim', 'sentence', 'evidence'}``, where
    ``evidence`` is True when the sentence index is listed in the claim's evidence
    for that document.

    Args:
        corpus: Path to the corpus JSONL file; records provide ``doc_id`` and ``abstract``.
        dataset: Path to the claims JSONL file; records provide ``id``, ``claim``
            and ``evidence`` (a mapping keyed by document id as a string).
        abstracts: Path to the abstract-retrieval JSONL file, line-aligned with
            ``dataset``; records provide ``id``, ``retrieved_doc_ids`` and
            ``cited_doc_ids``.
        max_docs: Maximum number of distinct candidate documents per claim.

    Raises:
        AssertionError: If a claim and its retrieval record carry different ids.
        KeyError: If a candidate document id is missing from the corpus.
    """

    def __init__(self, corpus: str, dataset: str, abstracts, max_docs: int = 4):
        self.samples = []

        # Read every file inside a context manager so the handles are closed.
        with jsonlines.open(corpus) as corpus_reader:
            docs_by_id = {doc['doc_id']: doc for doc in corpus_reader}
        with jsonlines.open(dataset) as claim_reader, \
                jsonlines.open(abstracts) as retrieval_reader:
            pairs = list(zip(claim_reader, retrieval_reader))

        for data, retrieval in tqdm(pairs):
            assert data['id'] == retrieval['id']

            # Candidate documents: retrieved abstracts first, then cited documents,
            # capped at max_docs distinct ids overall.
            docs = set()
            for source in ('retrieved_doc_ids', 'cited_doc_ids'):
                for candidate in retrieval[source]:
                    if len(docs) >= max_docs:
                        break
                    docs.add(candidate)

            for doc_id in docs:
                doc_id = str(doc_id)
                doc = docs_by_id[int(doc_id)]
                # Evidence is keyed by the string id; a document without an entry
                # (an incorrectly retrieved one) yields an empty set, so all of
                # its sentences are labelled False.
                evidence_sentence_idx = {
                    s
                    for es in data['evidence'].get(doc_id, [])
                    for s in es['sentences']
                }
                for i, sentence in enumerate(doc['abstract']):
                    self.samples.append({
                        'claim': data['claim'],
                        'sentence': sentence,
                        'evidence': i in evidence_sentence_idx
                    })

    def __len__(self):
        """Return the number of claim/sentence samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the sample dict stored at position ``idx``."""
        return self.samples[idx]

def encode(claims: List[str], sentences: List[str]):
    """Tokenize (sentence, claim) pairs into padded tensors placed on ``device``.

    The batch is first encoded without a length limit. If the padded width
    exceeds 512 tokens, it is encoded again with ``max_length=512`` and the
    ``'only_first'`` truncation strategy.

    Args:
        claims: Claim texts, aligned element-wise with ``sentences``.
        sentences: Sentence texts; each is the first element of its pair.

    Returns:
        Dict mapping each tokenizer output name to a tensor moved to ``device``.
    """
    def _tokenize(**limits):
        # A fresh zip is needed per call: the iterator is consumed by the tokenizer.
        return tokenizer.batch_encode_plus(
            zip(sentences, claims),
            pad_to_max_length=True,
            return_tensors='pt',
            **limits)

    batch = _tokenize()
    too_long = batch['input_ids'].size(1) > 512
    if too_long:
        # Too long for the model: re-encode with truncation.
        batch = _tokenize(max_length=512, truncation_strategy='only_first')

    on_device = {}
    for name, values in batch.items():
        on_device[name] = values.to(device)
    return on_device

def evaluate(model, dataset):
    """Score ``model`` on ``dataset`` one sample at a time.

    The model is switched to eval mode and run without gradient tracking. The
    predicted class is the argmax over the logits.

    Args:
        model: Classifier whose first output element holds the logits.
        dataset: Samples with ``'claim'``, ``'sentence'`` and ``'evidence'`` fields.

    Returns:
        Tuple ``(f1, precision, recall, accuracy, balanced_accuracy)``.
    """
    gold, predicted = [], []

    model.eval()
    with torch.no_grad():
        loader = DataLoader(dataset, batch_size=1)
        for batch in loader:
            inputs = encode(batch['claim'], batch['sentence'])
            scores = model(**inputs)[0]
            predicted += scores.argmax(dim=1).tolist()
            gold += batch['evidence'].float().tolist()

    # zero_division=0 reports 0 when a metric's denominator is zero.
    f1 = f1_score(gold, predicted, zero_division=0)
    precision = precision_score(gold, predicted, zero_division=0)
    recall = recall_score(gold, predicted, zero_division=0)
    accuracy = accuracy_score(gold, predicted)
    balanced = balanced_accuracy_score(gold, predicted)
    return f1, precision, recall, accuracy, balanced

In [None]:
# SciFact corpus and claim splits, read from the ./data directory
# (presumably the directory unpacked from data.tar.gz — confirm).
corpus = './data/corpus.jsonl'
claim_train = './data/claims_train.jsonl'
claim_dev = './data/claims_dev.jsonl'
# Predicted abstract-retrieval files for the train and dev claims.
# NOTE(review): no visible cell creates these files; they appear to be expected
# in the working directory before the datasets are built — confirm.
abstract_train = './abstract_retrieval_Train.jsonl'
abstract_dev = './abstract_retrieval_Dev.jsonl'

In [None]:
# Build the sentence-level train and dev sets from the claims and their
# matching abstract-retrieval predictions.
trainset = SciFactRationaleSelectionDataset(
    corpus=corpus,
    dataset=claim_train,
    abstracts=abstract_train,
)
devset = SciFactRationaleSelectionDataset(
    corpus=corpus,
    dataset=claim_dev,
    abstracts=abstract_dev,
)

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# The tokenizer and the sequence classifier are loaded from the same checkpoint;
# the classifier is then moved to the selected device.
checkpoint = 'dmis-lab/biobert-large-cased-v1.1'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
model = model.to(device)

In [None]:
# Two parameter groups: a small learning rate for the pretrained encoder and a
# larger one for the classification head.
optimizer = torch.optim.Adam([
            {'params': model.bert.parameters(), 'lr': 5e-5},
            {'params': model.classifier.parameters(), 'lr': 1e-3}])
# The training loop calls scheduler.step() once per epoch for 50 epochs, so the
# schedule is expressed in epochs. The previous (128 warmup, 50 * 128 total)
# setting kept the learning rate at exactly 0 for the whole first epoch and
# never left warmup within 50 scheduler steps.
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=50)

In [None]:
# Fine-tune for 50 epochs. After each epoch: advance the learning-rate schedule,
# report train/dev metrics, and save a checkpoint named after the dev F1.
for e in range(50):
    model.train()
    t = tqdm(DataLoader(trainset, batch_size=1, shuffle=True))
    for i, batch in enumerate(t):
        encoded_dict = encode(batch['claim'], batch['sentence'])
        # With labels supplied, the first output element is the loss; the logits
        # are not needed during training.
        loss = model(**encoded_dict, labels=batch['evidence'].long().to(device))[0]
        loss.backward()
        # Gradient accumulation: one optimizer update per 128 single-sample batches.
        # NOTE(review): gradients from a trailing partial group are not applied at
        # the end of the epoch and carry over into the next one — confirm intended.
        if (i + 1) % (128 // 1) == 0:
            optimizer.step()
            optimizer.zero_grad()
            t.set_description(f'Epoch {e}, iter {i}, loss: {round(loss.item(), 4)}')
    # The schedule advances once per epoch, not once per optimizer update.
    scheduler.step()
    train_score = evaluate(model, trainset)
    print(f'Epoch {e}, train f1: %.4f, precision: %.4f, recall: %.4f, acc: %.4f, balanced_acc: %.4f' % train_score)
    dev_score = evaluate(model, devset)
    print(f'Epoch {e}, dev f1: %.4f, precision: %.4f, recall: %.4f, acc: %.4f, balanced_acc: %.4f' % dev_score)
    # Save
    save_path = os.path.join('./saved_models', f'rationale_selection_epoch-{e}-f1-{int(dev_score[0] * 1e4)}')
    # exist_ok avoids a FileExistsError when the notebook is re-run and an
    # epoch produces the same checkpoint directory name as before.
    os.makedirs(save_path, exist_ok=True)
    tokenizer.save_pretrained(save_path)
    model.save_pretrained(save_path)