In [None]:
from transformers import BartTokenizer, BartForConditionalGeneration
from torch.utils.data import DataLoader
from transformers import AdamW, get_linear_schedule_with_warmup

from tqdm import tqdm
import numpy as np
import torch
import time
import datetime
import sys
import json

In [None]:
# Prefer the first CUDA device when one is available, otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The tokenizer is created with the extra "<sentencemissing>" special token.
tokenizer = BartTokenizer.from_pretrained(
    'facebook/bart-large',
    additional_special_tokens=["<sentencemissing>"],
)

# Load the pretrained seq2seq weights and move them to the selected device.
model = BartForConditionalGeneration.from_pretrained('facebook/bart-large').to(device)

In [None]:
# Resize the model's token embeddings to the tokenizer's current vocabulary size
# (the tokenizer was created with an additional "<sentencemissing>" special token).
model.resize_token_embeddings(len(tokenizer))

In [None]:
def concatenate_premise(data):
    """Render a record's premise as one string, masking unlabeled sentences.

    Walks ``data['label']`` in order: a truthy label contributes the sentence
    stored at the same index of ``data['premise']``; a falsy label contributes
    the literal ``<sentencemissing>`` placeholder instead. The pieces are
    separated by single spaces and trailing whitespace is stripped.

    Args:
        data: mapping with a ``'label'`` sequence and a ``'premise'`` sequence
            of strings.

    Returns:
        The concatenated premise string ("" when there are no labels).
    """
    pieces = [
        data['premise'][position] if keep else "<sentencemissing>"
        for position, keep in enumerate(data['label'])
    ]
    return " ".join(pieces).rstrip()

In [None]:
# Container for the training examples; each entry is a [premise, conclusion] pair of strings.
train_data = []

In [None]:
# Start from an empty list so that re-running this cell does not append a
# second copy of the training set to `train_data`.
train_data = []

# Read as UTF-8 explicitly instead of relying on the platform's default encoding.
with open('../dataset/PMCOA-Feb23-2022-train-mask.jsonl', 'r', encoding='utf-8') as f:
    for line in tqdm(f):
        if not line.strip():
            continue  # tolerate blank lines instead of failing in json.loads
        data = json.loads(line)
        # Keep only the premise sentences whose label is truthy, in their original
        # order (assumes data["label"] is a flat per-sentence sequence, as
        # concatenate_premise does).
        premise = " ".join(
            data["premise"][idx] for idx, keep in enumerate(data["label"]) if keep
        )
        conclusion = " ".join(data['conclusion'])
        train_data.append([premise, conclusion])

In [None]:
# Container for the development examples; each entry is a [premise, conclusion] pair of strings.
dev_data = []

In [None]:
# Start from an empty list so that re-running this cell does not append a
# second copy of the development set to `dev_data`.
dev_data = []

# Read as UTF-8 explicitly instead of relying on the platform's default encoding.
with open('../dataset/PMCOA-Feb23-2022-dev-mask.jsonl', 'r', encoding='utf-8') as f:
    for line in tqdm(f):
        if not line.strip():
            continue  # tolerate blank lines instead of failing in json.loads
        data = json.loads(line)
        # Keep only the premise sentences whose label is truthy, in their original
        # order (assumes data["label"] is a flat per-sentence sequence, as
        # concatenate_premise does).
        premise = " ".join(
            data["premise"][idx] for idx, keep in enumerate(data["label"]) if keep
        )
        conclusion = " ".join(data['conclusion'])
        dev_data.append([premise, conclusion])

In [None]:
batch_size = 2
# Training batches are reshuffled on every pass over the data.
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# The dev set is only used to compute a loss, so it is kept in file order:
# shuffling adds nothing there and makes batch composition (and therefore the
# padding inside each batch) differ from run to run.
dev_dataloader = DataLoader(dev_data, batch_size=batch_size, shuffle=False)

In [None]:
num_epochs = 5
learning_rate = 1e-5
# Warm up over the first 10% of all optimizer steps. len(train_dataloader)
# already counts batches (one optimizer step each), so it must not be divided
# by batch_size a second time -- doing so shrank the warmup to 5% of the steps
# at batch_size = 2.
warmup_steps = int(0.1 * len(train_dataloader) * num_epochs)
epsilon = 1e-8
sample_every = 1e4  # NOTE(review): not referenced in the visible cells

optimizer = AdamW(model.parameters(),
                  lr=learning_rate,
                  eps=epsilon
                 )

In [None]:
# One scheduler step is taken per training batch, so the schedule spans
# (number of epochs) x (batches per epoch) steps in total.
total_steps = num_epochs * len(train_dataloader)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

In [None]:
def format_time(elapsed):
    """Format a duration given in seconds as an ``h:mm:ss`` style string.

    The value is rounded to the nearest whole second before formatting;
    the text is whatever ``str(datetime.timedelta)`` produces.
    """
    whole_seconds = int(round(elapsed))
    return str(datetime.timedelta(seconds=whole_seconds))

In [None]:
def _seq2seq_loss(model, tokenizer, batch, device):
    """Return the model's loss for one (premises, conclusions) batch.

    Args:
        model: seq2seq model that accepts ``labels`` and returns the loss as
            its first output.
        tokenizer: tokenizer used for both the inputs and the targets.
        batch: pair ``(premises, conclusions)`` of equally long sequences of
            strings, as yielded by the DataLoaders built over
            [premise, conclusion] rows.
        device: device the encoded tensors are moved to; must match the model's.

    Returns:
        Scalar loss tensor (attached to the autograd graph unless the call is
        made under ``torch.no_grad()``).
    """
    inputs = tokenizer(list(batch[0]), padding=True, truncation=True, return_tensors='pt')
    targets = tokenizer(list(batch[1]), padding=True, truncation=True, return_tensors='pt')

    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    target_ids = targets['input_ids'].to(device)
    target_attention_mask = targets['attention_mask'].to(device)

    # Padding positions in the targets are replaced by -100, the index the loss
    # ignores; passing the raw padded ids made the model also get scored on
    # predicting padding tokens.
    labels = target_ids.masked_fill(target_attention_mask == 0, -100)

    outputs = model(input_ids=input_ids, attention_mask=attention_mask,
                    labels=labels, decoder_attention_mask=target_attention_mask)

    # .mean() is a no-op for a scalar loss and reduces a per-replica loss vector.
    return outputs[0].mean()


total_t0 = time.time()
# Use the device the model already lives on rather than a hard-coded "cuda:0",
# so the loop also runs on machines without a GPU.
device = next(model.parameters()).device

for epoch in range(num_epochs):
    # Re-enable training mode at the start of EVERY epoch: model.eval() is called
    # at the end of each epoch, so calling train() only once before the loop left
    # epochs 2+ training with dropout disabled.
    model.train()
    total_train_loss = 0
    t0 = time.time()

    print('======== Epoch {:} / {:} ========'.format(epoch + 1, num_epochs))
    print('Training...')
    for batch in train_dataloader:
        loss = _seq2seq_loss(model, tokenizer, batch, device)
        total_train_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()  # the learning-rate schedule advances once per batch

    # max(..., 1) only guards the degenerate case of an empty training loader.
    avg_train_loss = total_train_loss / max(len(train_dataloader), 1)
    training_time = format_time(time.time() - t0)

    print("  Average training loss: {0:.2f}".format(avg_train_loss))
    print("  Training epoch took: {:}".format(training_time))

    model.eval()  # Switch to evaluation mode

    total_eval_loss = 0
    eval_steps = 0

    # No gradients are needed while scoring the dev set.
    with torch.no_grad():
        for batch in dev_dataloader:
            total_eval_loss += _seq2seq_loss(model, tokenizer, batch, device).item()
            eval_steps += 1

    # Report NaN instead of raising ZeroDivisionError when the dev set is empty.
    avg_eval_loss = total_eval_loss / eval_steps if eval_steps else float('nan')

    print("  Average evaluation loss: {0:.2f}".format(avg_eval_loss))

    # One checkpoint directory per epoch: ".../<epoch number>-epoch/".
    output_dir = '../fine-tuned-models/Bart-large/nppl/' + str(epoch+1) + '-epoch/'
    model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
    model_to_save.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)


#### Inference

In [None]:
def concatenate_premise(data):
    """Join a record's premise sentences, substituting a placeholder where the label is falsy.

    For every index of ``data['label']``: a truthy label keeps the sentence at
    that index of ``data['premise']``; a falsy label is rendered as the literal
    ``<sentencemissing>`` token. Pieces are joined with single spaces and
    trailing whitespace is removed from the result.

    NOTE(review): a function with this same name and behavior is defined earlier
    in this notebook; running this cell simply rebinds the name.
    """
    pieces = [
        data['premise'][position] if keep else "<sentencemissing>"
        for position, keep in enumerate(data['label'])
    ]
    return " ".join(pieces).rstrip()

In [None]:
# Container for the test examples; each entry is a [premise, conclusion, pubmed_id] triple.
test_data = []

In [None]:
# Start from an empty list so that re-running this cell does not append a
# second copy of the test set to `test_data`.
test_data = []

# Read as UTF-8 explicitly instead of relying on the platform's default encoding.
with open('../dataset/PMCOA-Feb23-2022-test-mask-nppl.jsonl', 'r', encoding='utf-8') as f:
    for line in tqdm(f):
        if not line.strip():
            continue  # tolerate blank lines instead of failing in json.loads
        data = json.loads(line)
        # Keep only the premise sentences whose label is truthy, in their original
        # order (assumes data['label'] is a flat per-sentence sequence, as
        # concatenate_premise does).
        premise = " ".join(
            data['premise'][idx] for idx, keep in enumerate(data['label']) if keep
        )
        conclusion = " ".join(data['conclusion'])
        # Named `pubmed_id` rather than `id` so the builtin id() is not shadowed
        # for the rest of the notebook session.
        pubmed_id = data['pubmed_id']
        test_data.append([premise, conclusion, pubmed_id])

In [None]:
# Test batches are served in file order (no shuffling).
# NOTE: this rebinds the notebook-wide `batch_size` name.
batch_size = 10
test_dataloader = DataLoader(
    test_data,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
import evaluate

# Load the ROUGE metric once; the same object is used to score each checkpoint's predictions.
rouge = evaluate.load("rouge")

In [None]:
import os 
import re

# List the checkpoint folders and order them by the first number in their name
# (numeric order, so "10-epoch" sorts after "9-epoch").
# Entries whose name contains no digits (e.g. ".ipynb_checkpoints", ".DS_Store")
# are skipped; previously re.search() returned None for them and the sort key
# raised AttributeError.
model_dirs = sorted(
    (
        entry
        for entry in os.listdir('../fine-tuned-models/Bart-large/nppl/')
        if re.search(r'\d+', entry)
    ),
    key=lambda x: int(re.search(r'\d+', x).group()),
)

In [None]:
# NOTE(review): this hard-coded list replaces whatever `model_dirs` held before
# (including the directory listing computed in an earlier cell); keep it in sync
# with the checkpoints that training actually saved.
model_dirs = ['1-epoch', '2-epoch', '3-epoch', '4-epoch', '5-epoch']

In [None]:
# Same device selection as when the base model was loaded, chosen once for all
# checkpoints; the previously hard-coded 'cuda:0' failed on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

for model_dir in tqdm(model_dirs):

    print('======== Epoch {:} ========'.format(model_dir.split("-")[0]))

    # Model and tokenizer are both restored from the same per-epoch checkpoint folder.
    checkpoint_path = '../fine-tuned-models/Bart-large/nppl/' + model_dir + '/'
    model = BartForConditionalGeneration.from_pretrained(checkpoint_path).to(device)
    tokenizer = BartTokenizer.from_pretrained(checkpoint_path)

    print('Model loaded!')
    model.eval()

    # Flat, index-aligned lists: predictions[i] was generated for the example
    # whose reference conclusion is references[i] and whose id is reference_ids[i].
    predictions = []
    references = []
    reference_ids = []

    for batch in test_dataloader:
        # batch is (premises, conclusions, ids) for up to `batch_size` examples.
        inputs = tokenizer(list(batch[0]), padding=True, truncation=True, return_tensors='pt')

        input_ids = inputs['input_ids'].to(device)
        attention_mask = inputs['attention_mask'].to(device)

        with torch.no_grad():
            # NOTE(review): num_beams=1 selects non-beam decoding; early_stopping is
            # documented as a beam-search option, so it presumably has no effect here.
            generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                num_beams=1,
                max_length=128,
                early_stopping=True
            )

        predictions.extend(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))
        references.extend(batch[1])
        reference_ids.extend(batch[2])

    # Every test example must have exactly one prediction, one reference and one id.
    assert len(predictions) == len(references)
    assert len(predictions) == len(reference_ids)

    score = rouge.compute(predictions=predictions, references=references)

    print(f"Rouge: {score}")
