Installing requirements

In [None]:
%%writefile requirements.txt
bert-score==0.3.7
chardet==4.0.0
click==7.1.2
cycler==0.10.0
# dataclasses is part of the standard library from Python 3.7 on; the PyPI
# backport is only needed on 3.6 and must not be installed on newer interpreters.
dataclasses; python_version < "3.7"
datasets==1.3.0
dill==0.3.3
filelock==3.0.12
fsspec==0.8.7
gensim==3.8.3
huggingface-hub==0.0.2
idna==2.10
importlib-metadata==3.7.0
joblib==1.0.1
jsonlines==2.0.0
kiwisolver==1.3.1
matplotlib==3.3.4
multiprocess==0.70.11.1
nltk==3.5
numpy==1.19.5
packaging==20.9
pandas==1.1.5
pillow==8.1.0
pyarrow==3.0.0
pyparsing==2.4.7
python-dateutil==2.8.1
pytz==2021.1
regex==2020.11.13
requests==2.25.1
sacremoses==0.0.43
scikit-learn==0.24.1
scipy==1.5.4
six==1.15.0
sklearn==0.0
smart-open==4.2.0
threadpoolctl==2.1.0
tokenizers==0.10.1
torch==1.7.1
tqdm==4.49.0
transformers==4.3.3
typing-extensions==3.7.4.3
urllib3==1.26.3
xxhash==2.0.0
zipp==3.4.0

In [None]:
pip install -r requirements.txt

Downloading the SciFact dataset

In [None]:
wget https://scifact.s3-us-west-2.amazonaws.com/release/latest/data.tar.gz
tar -xvf data.tar.gz

In [None]:
import torch
import jsonlines
import random
import os
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_cosine_schedule_with_warmup
from tqdm import tqdm
from typing import List
from sklearn.metrics import f1_score, precision_score, recall_score

In [None]:
# Seed the RNGs used below: `random` drives the sampling of non-rationale
# sentences when the datasets are built, and torch drives DataLoader shuffling.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device "{device}"')
# batch_size: number of examples contributing to one optimizer step.
# gpu_batch: number of examples per forward/backward pass (DataLoader batch size);
# the training loop steps the optimizer every `batch_size // gpu_batch` iterations.
batch_size = 1
gpu_batch = 1

### Training Neutral classifier

In [None]:
class LabelPredictionDataset(Dataset):
    """Claim/rationale pairs for the binary "neutral" classifier.

    Each sample is a dict with keys ``'claim'``, ``'rationale'`` and ``'label'``.
    The label is 1 for NOT_ENOUGH_INFO and 0 for SUPPORT or CONTRADICT.

    For every (claim, evidence document) pair the dataset contains:
      * one sample per evidence set, labelled from that set,
      * one sample joining all rationale sentences of the document, labelled
        from the first evidence set,
      * one NOT_ENOUGH_INFO sample built from up to two randomly chosen
        non-rationale sentences.
    Claims without evidence yield one NOT_ENOUGH_INFO sample per cited
    document, built from up to two randomly chosen abstract sentences.
    """

    def __init__(self, corpus: str, claims: str):
        """Build the sample list.

        Args:
            corpus: path to a JSON-lines file of documents; each record is read
                for its ``'doc_id'`` and ``'abstract'`` (a list of sentences).
            claims: path to a JSON-lines file of claims; each record is read
                for ``'claim'``, ``'evidence'`` and ``'cited_doc_ids'``.
        """
        self.samples = []
        label_encodings = {'CONTRADICT': 0, 'NOT_ENOUGH_INFO': 1, 'SUPPORT': 0}

        # Context managers make sure both files are closed once they are read.
        with jsonlines.open(corpus) as corpus_reader:
            docs = {doc['doc_id']: doc for doc in corpus_reader}

        with jsonlines.open(claims) as claims_reader:
            for claim in claims_reader:
                if claim['evidence']:
                    for doc_id, evidence_sets in claim['evidence'].items():
                        doc = docs[int(doc_id)]
                        # One sample per individual evidence set.
                        for evidence_set in evidence_sets:
                            rationale = [doc['abstract'][i].strip() for i in evidence_set['sentences']]
                            self.samples.append({
                                'claim': claim['claim'],
                                'rationale': ' '.join(rationale),
                                'label': label_encodings[evidence_set['label']]
                            })
                        # One sample with every rationale sentence of the document, in abstract order.
                        rationale_idx = {s for es in evidence_sets for s in es['sentences']}
                        rationale_sentences = [doc['abstract'][i].strip() for i in sorted(rationale_idx)]
                        self.samples.append({
                            'claim': claim['claim'],
                            'rationale': ' '.join(rationale_sentences),
                            'label': label_encodings[evidence_sets[0]['label']]
                        })
                        # One negative sample from sentences that are not part of any rationale.
                        non_rationale_idx = set(range(len(doc['abstract']))) - rationale_idx
                        # random.sample() requires a sequence: passing a set is
                        # deprecated since Python 3.9 and a TypeError from 3.11.
                        sampled_idx = random.sample(
                            sorted(non_rationale_idx),
                            k=min(random.randint(1, 2), len(non_rationale_idx)))
                        non_rationale_sentences = [doc['abstract'][i].strip() for i in sorted(sampled_idx)]
                        self.samples.append({
                            'claim': claim['claim'],
                            'rationale': ' '.join(non_rationale_sentences),
                            'label': label_encodings['NOT_ENOUGH_INFO']
                        })
                else:
                    for doc_id in claim['cited_doc_ids']:
                        doc = docs[int(doc_id)]
                        num_sentences = len(doc['abstract'])
                        # Cap k at the abstract length so a one-sentence abstract
                        # cannot make random.sample() raise ValueError.
                        sampled_idx = random.sample(
                            range(num_sentences),
                            k=min(random.randint(1, 2), num_sentences))
                        non_rationale_sentences = [doc['abstract'][i].strip() for i in sampled_idx]
                        self.samples.append({
                            'claim': claim['claim'],
                            'rationale': ' '.join(non_rationale_sentences),
                            'label': label_encodings['NOT_ENOUGH_INFO']
                        })

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the sample dict at position ``idx``."""
        return self.samples[idx]

def encode(claim: List[str], rationale: List[str]):
    """Tokenize claim/rationale pairs with the module-level ``tokenizer``.

    Sequences are padded to the longest pair in the batch and truncated to at
    most 512 tokens.

    Returns:
        A tuple ``(input_ids, attention_mask)`` of PyTorch tensors.
    """
    batch_encoding = tokenizer(
        claim,
        rationale,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    return batch_encoding['input_ids'], batch_encoding['attention_mask']

def evaluate(model, dataset):
    """Score ``model`` on ``dataset`` and return F1 / precision / recall.

    The model is switched to eval mode and is left in that mode on return.

    Args:
        model: sequence-classification model whose output exposes ``.logits``.
        dataset: dataset yielding dicts with ``'claim'``, ``'rationale'`` and
            ``'label'``.

    Returns:
        A dict with ``'macro_f1'`` (float) and per-class tuples ``'f1'``,
        ``'precision'`` and ``'recall'``.
    """
    model.eval()
    targets = []
    outputs = []
    with torch.no_grad():
        for batch in DataLoader(dataset, batch_size=gpu_batch):
            input_ids, attention_mask = encode(batch['claim'], batch['rationale'])
            # Pass the attention mask so padding tokens are ignored; without it
            # predictions depend on the padding whenever gpu_batch > 1.
            logits = model(
                input_ids=input_ids.to(device),
                attention_mask=attention_mask.to(device),
            ).logits
            targets.extend(batch['label'].float().tolist())
            outputs.extend(logits.argmax(dim=1).tolist())
    return {
        'macro_f1': f1_score(targets, outputs, zero_division=0, average='macro'),
        'f1': tuple(f1_score(targets, outputs, zero_division=0, average=None)),
        'precision': tuple(precision_score(targets, outputs, zero_division=0, average=None)),
        'recall': tuple(recall_score(targets, outputs, zero_division=0, average=None))
    }

In [None]:
# Build the train and dev splits (in that order) over the same corpus file.
trainset, devset = (
    LabelPredictionDataset('./data/corpus.jsonl', claims_path)
    for claims_path in ('./data/claims_train.jsonl', './data/claims_dev.jsonl')
)

In [None]:
# Load the pretrained tokenizer and sequence-classification model for the
# neutral classifier, then move the model to the selected device.
tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-large-cased-v1.1-mnli")
model = AutoModelForSequenceClassification.from_pretrained("dmis-lab/biobert-large-cased-v1.1-mnli")
model = model.to(device)

In [None]:
# Adam with two parameter groups: a lower learning rate for the BERT encoder
# and a higher one for the classification head.
optimizer = torch.optim.Adam([
    {'params': model.bert.parameters(), 'lr': 5e-5},
    {'params': model.classifier.parameters(), 'lr': 1e-4}
])
# Positional arguments: 128 warmup steps, 50 * batch_size total training steps.
# NOTE(review): the training loop calls scheduler.step() once per epoch for 50
# epochs, i.e. fewer steps than the 128 warmup steps — presumably the schedule
# never leaves its warmup phase; confirm this is intended.
scheduler = get_cosine_schedule_with_warmup(optimizer, 128, 50 * batch_size)

In [None]:
# Number of DataLoader batches whose gradients are accumulated before each
# optimizer step. max(1, ...) avoids a modulo-by-zero when gpu_batch > batch_size.
accum_steps = max(1, batch_size // gpu_batch)

for e in range(50):
    model.train()  # evaluate() below leaves the model in eval mode
    t = tqdm(DataLoader(trainset, batch_size=gpu_batch, shuffle=True))
    for i, batch in enumerate(t):
        input_ids, attention_mask = encode(batch['claim'], batch['rationale'])
        outputs = model(input_ids=input_ids.to(device), attention_mask=attention_mask.to(device), labels=batch['label'].long().to(device))
        loss = outputs.loss
        loss.backward()
        if (i + 1) % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            t.set_description(f'Epoch {e}, iter {i}, loss: {round(loss.item(), 4)}')
    # The learning-rate schedule advances once per epoch.
    scheduler.step()
    # Eval
    train_score = evaluate(model, trainset)
    print(f'Epoch {e} train score:')
    print(train_score)
    dev_score = evaluate(model, devset)
    print(f'Epoch {e} dev score:')
    print(dev_score)
    # Save one checkpoint directory per epoch, tagged with the dev macro-F1.
    save_path = os.path.join('./saved_models', f'neutral_classifier-epoch-{e}-f1-{int(dev_score["macro_f1"] * 1e4)}')
    # exist_ok=True lets the cell be re-run without failing on an existing directory.
    os.makedirs(save_path, exist_ok=True)
    tokenizer.save_pretrained(save_path)
    model.save_pretrained(save_path)

### Training Support classifier

In [None]:
class LabelPredictionDataset(Dataset):
    """Claim/rationale pairs for the binary "support" classifier.

    Each sample is a dict with keys ``'claim'``, ``'rationale'`` and ``'label'``.
    The label is 1 for SUPPORT and 0 for CONTRADICT or NOT_ENOUGH_INFO.

    For every (claim, evidence document) pair the dataset contains:
      * one sample per evidence set, labelled from that set,
      * one sample joining all rationale sentences of the document, labelled
        from the first evidence set,
      * one NOT_ENOUGH_INFO sample built from up to two randomly chosen
        non-rationale sentences.
    Claims without evidence yield one NOT_ENOUGH_INFO sample per cited
    document, built from up to two randomly chosen abstract sentences.
    """

    def __init__(self, corpus: str, claims: str):
        """Build the sample list.

        Args:
            corpus: path to a JSON-lines file of documents; each record is read
                for its ``'doc_id'`` and ``'abstract'`` (a list of sentences).
            claims: path to a JSON-lines file of claims; each record is read
                for ``'claim'``, ``'evidence'`` and ``'cited_doc_ids'``.
        """
        self.samples = []
        label_encodings = {'CONTRADICT': 0, 'NOT_ENOUGH_INFO': 0, 'SUPPORT': 1}

        # Context managers make sure both files are closed once they are read.
        with jsonlines.open(corpus) as corpus_reader:
            docs = {doc['doc_id']: doc for doc in corpus_reader}

        with jsonlines.open(claims) as claims_reader:
            for claim in claims_reader:
                if claim['evidence']:
                    for doc_id, evidence_sets in claim['evidence'].items():
                        doc = docs[int(doc_id)]
                        # One sample per individual evidence set.
                        for evidence_set in evidence_sets:
                            rationale = [doc['abstract'][i].strip() for i in evidence_set['sentences']]
                            self.samples.append({
                                'claim': claim['claim'],
                                'rationale': ' '.join(rationale),
                                'label': label_encodings[evidence_set['label']]
                            })
                        # One sample with every rationale sentence of the document, in abstract order.
                        rationale_idx = {s for es in evidence_sets for s in es['sentences']}
                        rationale_sentences = [doc['abstract'][i].strip() for i in sorted(rationale_idx)]
                        self.samples.append({
                            'claim': claim['claim'],
                            'rationale': ' '.join(rationale_sentences),
                            'label': label_encodings[evidence_sets[0]['label']]
                        })
                        # One negative sample from sentences that are not part of any rationale.
                        non_rationale_idx = set(range(len(doc['abstract']))) - rationale_idx
                        # random.sample() requires a sequence: passing a set is
                        # deprecated since Python 3.9 and a TypeError from 3.11.
                        sampled_idx = random.sample(
                            sorted(non_rationale_idx),
                            k=min(random.randint(1, 2), len(non_rationale_idx)))
                        non_rationale_sentences = [doc['abstract'][i].strip() for i in sorted(sampled_idx)]
                        self.samples.append({
                            'claim': claim['claim'],
                            'rationale': ' '.join(non_rationale_sentences),
                            'label': label_encodings['NOT_ENOUGH_INFO']
                        })
                else:
                    for doc_id in claim['cited_doc_ids']:
                        doc = docs[int(doc_id)]
                        num_sentences = len(doc['abstract'])
                        # Cap k at the abstract length so a one-sentence abstract
                        # cannot make random.sample() raise ValueError.
                        sampled_idx = random.sample(
                            range(num_sentences),
                            k=min(random.randint(1, 2), num_sentences))
                        non_rationale_sentences = [doc['abstract'][i].strip() for i in sampled_idx]
                        self.samples.append({
                            'claim': claim['claim'],
                            'rationale': ' '.join(non_rationale_sentences),
                            'label': label_encodings['NOT_ENOUGH_INFO']
                        })

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the sample dict at position ``idx``."""
        return self.samples[idx]
    
def encode(claim: List[str], rationale: List[str]):
    """Turn claim/rationale pairs into model inputs via the module-level ``tokenizer``.

    Pairs are padded to the longest one in the batch and truncated to a
    maximum of 512 tokens.

    Returns:
        ``(input_ids, attention_mask)`` as PyTorch tensors.
    """
    tokenized = tokenizer(
        claim,
        rationale,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    ids = tokenized['input_ids']
    mask = tokenized['attention_mask']
    return ids, mask

def evaluate(model, dataset):
    """Score ``model`` on ``dataset`` and return F1 / precision / recall.

    The model is switched to eval mode and is left in that mode on return.

    Args:
        model: sequence-classification model whose output exposes ``.logits``.
        dataset: dataset yielding dicts with ``'claim'``, ``'rationale'`` and
            ``'label'``.

    Returns:
        A dict with ``'macro_f1'`` (float) and per-class tuples ``'f1'``,
        ``'precision'`` and ``'recall'``.
    """
    model.eval()
    targets = []
    outputs = []
    with torch.no_grad():
        # Use the shared gpu_batch setting (currently 1) instead of a literal.
        for batch in DataLoader(dataset, batch_size=gpu_batch):
            input_ids, attention_mask = encode(batch['claim'], batch['rationale'])
            # Pass the attention mask so padding tokens are ignored; without it
            # predictions depend on the padding whenever the batch holds more
            # than one example.
            logits = model(
                input_ids=input_ids.to(device),
                attention_mask=attention_mask.to(device),
            ).logits
            targets.extend(batch['label'].float().tolist())
            outputs.extend(logits.argmax(dim=1).tolist())
    return {
        'macro_f1': f1_score(targets, outputs, zero_division=0, average='macro'),
        'f1': tuple(f1_score(targets, outputs, zero_division=0, average=None)),
        'precision': tuple(precision_score(targets, outputs, zero_division=0, average=None)),
        'recall': tuple(recall_score(targets, outputs, zero_division=0, average=None))
    }

In [None]:
# Rebuild the train and dev splits (in that order) with the LabelPredictionDataset
# class that is currently defined in the kernel.
trainset, devset = (
    LabelPredictionDataset('./data/corpus.jsonl', claims_file)
    for claims_file in ('./data/claims_train.jsonl', './data/claims_dev.jsonl')
)

In [None]:
# Load the pretrained tokenizer and sequence-classification model for the
# support classifier, then move the model to the selected device.
tokenizer = AutoTokenizer.from_pretrained("roberta-large-mnli")
model = AutoModelForSequenceClassification.from_pretrained("roberta-large-mnli")
model = model.to(device)

In [None]:
# Adam with two parameter groups: a lower learning rate for the RoBERTa encoder
# and a higher one for the classification head.
optimizer = torch.optim.Adam([
            {'params': model.roberta.parameters(), 'lr': 5e-5},
            {'params': model.classifier.parameters(), 'lr': 1e-4}
        ])
# Positional arguments: 128 warmup steps, 50 * 1 total training steps.
# NOTE(review): the training loop calls scheduler.step() once per epoch for 50
# epochs, i.e. fewer steps than the 128 warmup steps — presumably the schedule
# never leaves its warmup phase; confirm this is intended.
scheduler = get_cosine_schedule_with_warmup(optimizer, 128, 50 * 1)

In [None]:
# Number of DataLoader batches whose gradients are accumulated before each
# optimizer step, taken from the shared batch_size / gpu_batch settings (both 1)
# instead of hard-coded literals. max(1, ...) avoids a modulo-by-zero when
# gpu_batch > batch_size.
accum_steps = max(1, batch_size // gpu_batch)

for e in range(50):
    model.train()  # evaluate() below leaves the model in eval mode
    t = tqdm(DataLoader(trainset, batch_size=gpu_batch, shuffle=True))
    for i, batch in enumerate(t):
        input_ids, attention_mask = encode(batch['claim'], batch['rationale'])
        outputs = model(input_ids=input_ids.to(device), attention_mask=attention_mask.to(device), labels=batch['label'].long().to(device))
        loss = outputs.loss
        loss.backward()
        if (i + 1) % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            t.set_description(f'Epoch {e}, iter {i}, loss: {round(loss.item(), 4)}')
    # The learning-rate schedule advances once per epoch.
    scheduler.step()
    # Snapshot of the epoch counter and optimizer state; it is not written to
    # disk in this cell.
    checkpoint = {
        'epoch': e + 1,
        'optimizer': optimizer.state_dict()
    }
    # Eval
    train_score = evaluate(model, trainset)
    print(f'Epoch {e} train score:')
    print(train_score)
    dev_score = evaluate(model, devset)
    print(f'Epoch {e} dev score:')
    print(dev_score)

    # Save one checkpoint directory per epoch, tagged with the dev macro-F1.
    save_path = os.path.join('./saved_models', f'support_classifier-epoch-{e}-f1-{int(dev_score["macro_f1"] * 1e4)}')
    # exist_ok=True lets the cell be re-run without failing on an existing directory.
    os.makedirs(save_path, exist_ok=True)
    tokenizer.save_pretrained(save_path)
    model.save_pretrained(save_path)