In [38]:
# Hugging Face hub id of the BioBERT backbone, and the path to the fine-tuned
# yes/no classifier checkpoint that is loaded on top of it.
model_name = 'dmis-lab/biobert-base-cased-v1.1'
checkpoint_path = "../checkpoints/checkpoint_bio_yn_balanced_seed.pt"

In [35]:
# One golden input file and one prediction CSV per batch number 1..5.
bioasq_paths = [
    f'../preproc_datasets/8B_golden/8B{batch}_golden.json'
    for batch in range(1, 6)
]
output_paths = [
    f'./predictions_yesno/8b_{batch}_pred.csv'
    for batch in range(1, 6)
]

In [40]:
def get_yn_predictions_csv(model_name,checkpoint_path, bioasq_test_path,output_path):
    """Predict yes/no for every (question, snippet) pair of a BioASQ file and save a CSV.

    The function rebuilds the classifier (a 4-logit ``BertForSequenceClassification``
    followed by a small dense head), restores its weights from ``checkpoint_path``,
    runs it on CPU over each snippet of each ``'yesno'`` question found in
    ``bioasq_test_path``, and writes one CSV row per pair.

    Parameters
    ----------
    model_name : str
        Name or path passed to ``from_pretrained`` for both the model and the tokenizer.
    checkpoint_path : str
        File loaded with ``torch.load``; it must contain a ``'model_state_dict'`` entry.
        A ``'module.'`` substring in its keys is stripped before loading.
    bioasq_test_path : str
        JSON file with a top-level ``'questions'`` list. Each yes/no question must
        provide ``'body'``, ``'exact_answer'`` and ``'snippets'`` (items with ``'text'``).
    output_path : str
        Destination CSV. Its parent directory is created when missing.

    Returns
    -------
    None
        The result is the CSV file, with columns ``gold`` (1 when ``exact_answer``
        equals ``'yes'``, else 0) and ``pred`` (argmax over the two output classes).

    Raises
    ------
    KeyError
        If an expected key is missing from the JSON file or the checkpoint.
    """
    import json
    from pathlib import Path

    import pandas as pd
    import torch
    from torch import nn
    from transformers import BertForSequenceClassification, BertTokenizer

    class BERT_Arch(nn.Module):
        """Dense head (4 -> 512 -> 2, log-softmax) stacked on the BERT classifier logits.

        Attribute names must stay as they are: they are the keys used to match
        the tensors stored in the checkpoint's state dict.
        """

        def __init__(self, model):
            super().__init__()
            self.model = model
            self.dropout = nn.Dropout(0.1)
            self.relu = nn.ReLU()
            self.fc1 = nn.Linear(4, 512)
            self.fc2 = nn.Linear(512, 2)
            self.softmax = nn.LogSoftmax(dim=1)

        def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
            # `labels` is only forwarded to the wrapped model; the head below
            # uses nothing but the logits, so it is optional at prediction time.
            outputs = self.model(input_ids,
                                 attention_mask=attention_mask,
                                 token_type_ids=token_type_ids,
                                 labels=labels)
            x = self.fc1(outputs.logits)
            x = self.relu(x)
            x = self.dropout(x)
            x = self.fc2(x)
            return self.softmax(x)

    # ---- model -----------------------------------------------------------
    model = BertForSequenceClassification.from_pretrained(model_name, num_labels=4)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

    model_full = BERT_Arch(model)
    checkpoint_model_dict = {
        key.replace('module.', ''): tensor
        for key, tensor in checkpoint['model_state_dict'].items()
    }
    model_full.load_state_dict(checkpoint_model_dict)
    # A freshly constructed nn.Module is in training mode; without eval() the
    # head's Dropout stays active and makes the predictions nondeterministic.
    model_full.eval()

    # NOTE(review): do_lower_case=True is unusual for a "cased" model name; it is
    # kept as-is on the assumption that it matches training - confirm.
    tokenizer = BertTokenizer.from_pretrained(model_name,
                                              do_lower_case=True)

    # ---- data ------------------------------------------------------------
    with open(bioasq_test_path, 'rb') as f:
        all_questions = json.load(f)['questions']

    # One entry per (yes/no question, snippet) pair, in file order. This yields
    # the same rows as the earlier id-keyed merge, assuming question ids are unique.
    val_a = []       # question text, repeated for each of its snippets
    val_b = []       # snippet text
    val_labels = []  # 1 for 'yes', 0 otherwise
    for question in all_questions:
        if question['type'] != 'yesno':
            continue
        body = question['body']
        label = int(question['exact_answer'] == 'yes')
        for snippet in question['snippets']:
            val_a.append(body)
            val_b.append(snippet['text'])
            val_labels.append(label)

    # ---- inference ---------------------------------------------------------
    val_predictions = []
    with torch.no_grad():  # no gradients needed; avoids building autograd graphs
        for question_text, snippet_text in zip(val_a, val_b):
            inputs = tokenizer(question_text, snippet_text,
                               add_special_tokens=True,
                               max_length=500,
                               truncation=True, padding=True, return_tensors='pt')
            # Gold labels are deliberately not fed to the model while predicting.
            output = model_full(**inputs)
            val_predictions.append(int(torch.argmax(output, dim=1).item()))

    # ---- output ------------------------------------------------------------
    test_df = pd.DataFrame({'gold': val_labels,
                            'pred': val_predictions})

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    test_df.to_csv(output_path)