In [1]:
# Mount Google Drive at /content/drive so the notebook can read and write files there.
# The `google.colab` import means this cell expects a Google Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

# Drive folder used by later cells as both the input and the output directory
# (it is assigned to `base_path` and `base_save` in the data-loading cell).
base_drive_path = '/content/drive/MyDrive/basic_benchmarking'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# %%
# Import necessary libraries
import torch
import pickle
import pandas as pd
import os
from collections import Counter
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, AdamW, get_linear_schedule_with_warmup
import random
import numpy as np
import argparse
import datetime
import time
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix, f1_score
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
import spacy
import re
import json
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from torch import nn

# Utility function to format elapsed time
def format_time(elapsed):
    """Render a duration given in seconds as an ``h:mm:ss`` string.

    The value is rounded to the nearest whole second before formatting.
    """
    whole_seconds = int(round(elapsed))
    delta = datetime.timedelta(seconds=whole_seconds)
    return str(delta)


In [3]:
# %%
# Define utility functions

def add_row(supp_set, new_line, labels_cat, original_line, data_df, index, pmid):
    """Return ``data_df`` with one extra record appended at the end.

    Runs of spaces in ``new_line`` are collapsed to a single space, and the
    text fields are stripped. The ``abstract`` column is the supporting set
    and the cleaned conclusion joined by the ``<exp>`` marker. The input
    frame is not modified; a new frame with a fresh 0..n-1 index is returned.
    """
    conclusion = re.sub(' +', ' ', new_line).strip()
    support = supp_set.strip()
    record = {
        'supp_set': support,
        'conclusion': conclusion,
        'abstract': f"{support} <exp> {conclusion}",
        'label_cat': labels_cat,
        'ori_conclusion': original_line.strip(),
        'pmid': pmid,
        'index': index,
    }
    return pd.concat([data_df, pd.DataFrame([record])], ignore_index=True)

def swap_entity_tags(sent):
    """Exchange the regulator and regulated tag pairs in ``sent``.

    ``<re> … <er>`` becomes ``<el> … <le>`` and vice versa; the entity text
    between the tags is left where it is.

    Returns:
        (new_sentence, changed): ``changed`` is False when ``sent`` contains
        none of the four tags, so callers do not record an unchanged sentence
        as a new sample.
    """
    new_sent = (
        sent
        # Park the regulator tags on temporary names so the next step cannot clobber them.
        .replace('<re>', '<el1>').replace('<er>', '<le1>')
        .replace('<el>', '<re>').replace('<le>', '<er>')
        # Resolve the temporary names to the regulated tags.
        .replace('<el1>', '<el>').replace('<le1>', '<le>')
    )
    return new_sent, new_sent != sent

def swap_entity_positions(sent):
    """Swap the tagged regulator and regulated spans, tags included.

    ``<re> A <er> ... <el> B <le>`` becomes ``<el> B <le> ... <re> A <er>``.
    Only the first regulator and first regulated entity found are used, and
    spans are matched in the single-space-padded form ``<tag> entity <tag>``.

    Returns:
        (new_sentence, changed): ``(sent, False)`` when either tag pair is
        missing or when no padded span could be swapped.
    """
    regulator_hits = re.findall("<re>(.*?)<er>", sent)
    regulated_hits = re.findall("<el>(.*?)<le>", sent)
    if not regulator_hits or not regulated_hits:
        return sent, False
    regulator = regulator_hits[0].strip()
    regulated = regulated_hits[0].strip()

    # Plain string replacement: entity text is literal, so it must not be
    # interpreted as a regex pattern or as a replacement template.
    # Step 1: move the regulated entity into the regulator slot under temporary tags.
    swapped = sent.replace('<re> ' + regulator + ' <er>', '<re1> ' + regulated + ' <er1>')
    # Step 2: put the regulator where the regulated entity was.
    swapped = swapped.replace('<el> ' + regulated + ' <le>', '<re> ' + regulator + ' <er>')
    # Step 3: resolve the temporary tags, keeping the closing <le> tag.
    swapped = swapped.replace('<re1> ' + regulated + ' <er1>', '<el> ' + regulated + ' <le>')
    return swapped, swapped != sent

def swap_entity_names(sent):
    """Exchange the entity names while leaving both tag pairs in place.

    The first regulator name is written inside ``<el> … <le>`` and the first
    regulated name inside ``<re> … <er>``. Returns ``(sent, False)`` when
    either tag pair is absent, otherwise ``(rewritten, True)``.
    """
    regulator_match = re.search("<re>(.*?)<er>", sent)
    regulated_match = re.search("<el>(.*?)<le>", sent)
    if regulator_match is None or regulated_match is None:
        return sent, False

    regulator = regulator_match.group(1).strip()
    regulated = regulated_match.group(1).strip()

    regulator_span = '<re> ' + re.escape(regulator) + ' <er>'
    regulated_span = '<el> ' + re.escape(regulated) + ' <le>'

    # First overwrite the regulator slot, then the regulated slot.
    regulator_renamed = re.sub(regulator_span, '<re> ' + regulated + ' <er>', sent)
    both_renamed = re.sub(regulated_span, '<el> ' + regulator + ' <le>', regulator_renamed)
    return both_renamed, True

def swapNumber(line, supp_set):
    """Replace one number in ``line`` with a number drawn from ``supp_set``.

    One number is picked at random from each text. Every stand-alone
    occurrence of the picked ``line`` number (not directly preceded or
    followed by a dot) is replaced by the picked ``supp_set`` number.
    Returns ``(line, False)`` when either text has no number or when the
    substitution leaves the text unchanged.
    """
    number_pattern = r'(?:\d*\.\d+|\d+)'
    line_nums = re.findall(number_pattern, line)
    abs_nums = re.findall(number_pattern, supp_set)
    if not (line_nums and abs_nums):
        return line, False

    # Draw order matters for reproducibility under a fixed seed: line first, then supp_set.
    ln_num = random.choice(line_nums)
    ab_num = random.choice(abs_nums)

    if ln_num == ab_num and f'{ln_num})' not in line:
        return line, False

    target = r'(?<!\.)\b{}\b(?!\.)'.format(re.escape(ln_num))
    new_line = re.sub(target, ab_num, line)
    return new_line, new_line != line

def word_replace(sent):
    """Flip the polarity of one relation word in ``sent``.

    The word pairs are tried in a fixed order (not by position in the
    sentence). The first pair whose word occurs has its first occurrence
    replaced by the opposite word. Returns ``(sent, False)`` when none of
    the words occurs.
    """
    opposite_pairs = (
        (r'\binhibits\b', 'promotes'),
        (r'\bpromotes\b', 'inhibits'),
        (r'\binhibition\b', 'promotion'),
        (r'\bpromotion\b', 'inhibition'),
        (r'\binhibitor\b', 'promoter'),
        (r'\bpromoter\b', 'inhibitor'),
        (r'\bincrease\b', 'decrease'),
        (r'\bdecrease\b', 'increase'),
    )
    for pattern, repl in opposite_pairs:
        flipped, n_subs = re.subn(pattern, repl, sent, count=1)
        if n_subs:
            return flipped, True
    return sent, False

def posToNeg(sent):
    """Remove one negation from ``sent``.

    The negated forms are tried in a fixed order; the first one found has
    its first occurrence replaced by the plain verb (e.g. ``was not`` ->
    ``was``). Returns ``(sent, False)`` when no negated form is present.

    NOTE(review): despite the name, this function strips a negation rather
    than adding one; ``negToPos`` is the function that inserts ``not``.
    Confirm which direction the 'posToNeg' label is meant to denote.
    """
    negated_forms = (
        (r'\bwas not\b', 'was'),
        (r'\bwere not\b', 'were'),
        (r'\bcannot\b', 'can'),
        (r'\bis not\b', 'is'),
        (r'\bisn\'t\b', 'is'),
        (r'\bwasn\'t\b', 'was'),
        (r'\baren\'t\b', 'are'),
        (r'\bweren\'t\b', 'were'),
    )
    for pattern, repl in negated_forms:
        rewritten, n_subs = re.subn(pattern, repl, sent, count=1)
        if n_subs:
            return rewritten, True
    return sent, False

def negToPos(sent):
    """Insert one negation into ``sent``.

    The verbs ``was``, ``were``, ``do``, ``can`` and ``is`` are tried in
    that order. The first verb that occurs un-negated has its first such
    occurrence replaced by its negated form. A verb already followed by
    ``not`` or by ``'t`` is skipped, so an existing negation is never
    doubled ("was not not") or mangled ("cannot't").

    Returns:
        (new_sentence, changed): ``(sent, False)`` when no un-negated verb
        from the list is present.
    """
    # Lookahead that rejects verbs which already carry a negation.
    not_already_negated = r"(?!\s+not\b|'t\b)"
    replacements = (
        (r'\bwas\b', 'was not'),
        (r'\bwere\b', 'were not'),
        (r'\bdo\b', 'do not'),
        (r'\bcan\b', 'cannot'),
        (r'\bis\b', 'is not'),
    )
    for pattern, repl in replacements:
        new_sent, n_subs = re.subn(pattern + not_already_negated, repl, sent, count=1)
        if n_subs:
            return new_sent, True
    return sent, False

def swap_random_entity_outside(sent, entity_type_map_reverse, entity_type_map):
    """Replace one tagged entity with a random entity of the same type.

    ``entity_type_map`` maps entity text to a type label, and
    ``entity_type_map_reverse`` maps a type label to a list of entity texts.
    The regulator is tried first, then the regulated entity. The first one
    that has a known type and draws a different replacement is swapped.
    Any other occurrence of the replacement text in the sentence is
    rewritten to the original entity.

    Returns ``(sent, False)`` when a tag pair is missing, when neither
    entity has a known type, or when no differing replacement was drawn.
    """
    regulator_hits = re.findall("<re>(.*?)<er>", sent)
    regulated_hits = re.findall("<el>(.*?)<le>", sent)
    if not (regulator_hits and regulated_hits):
        return sent, False

    regulator = regulator_hits[0].strip()
    regulated = regulated_hits[0].strip()
    regulator_type = entity_type_map.get(regulator)
    regulated_type = entity_type_map.get(regulated)
    if not (regulator_type or regulated_type):
        return sent, False

    # (open tag, close tag, entity text, entity type, temporary placeholder)
    candidates = (
        ('<re>', '<er>', regulator, regulator_type, 'temppp'),
        ('<el>', '<le>', regulated, regulated_type, 'tempp'),
    )
    for open_tag, close_tag, entity, entity_type, placeholder in candidates:
        if not entity_type:
            continue
        pool = entity_type_map_reverse.get(entity_type, [])
        new_ent = random.choice(pool) if pool else entity
        if new_ent.strip() == entity.strip():
            continue
        masked_span = f'{open_tag} {placeholder} {close_tag}'
        # Mask the tagged span so the cross-replacement below cannot touch it.
        new_sent = re.sub(open_tag + ' ' + re.escape(entity) + ' ' + close_tag, masked_span, sent)
        new_sent = re.sub(re.escape(new_ent), entity, new_sent)
        new_sent = re.sub(masked_span, f'{open_tag} {new_ent} {close_tag}', new_sent)
        return new_sent, True
    return sent, False

def generatedSamples_nd_SEN(line, gen_row):
    """Return the generated explanation from ``gen_row`` if it is usable.

    ``gen_row`` is indexed like a DataFrame: the first entry of its
    ``generated_nd`` column is sliced with ``[6:-16]`` (skipping the leading
    ``<exp>`` token and the trailing polarity label). The explanation is
    rejected when it tags more than one distinct regulated or regulator
    entity, or when the first ``sat_er_and_ed`` value is below 1.

    Returns ``(generated, True)`` on success, else ``(line, False)``. Any
    exception is printed and treated as a rejection.
    """
    try:
        gen_exp = gen_row['generated_nd'].iloc[0][6:-16]  # Skip first <exp> token and last polarity label
        for tag_pattern in ("<el>(.*?)<le>", "<re>(.*?)<er>"):
            distinct_entities = set(re.findall(tag_pattern, gen_exp))
            if len(distinct_entities) > 1:
                return line, False
        if gen_row['sat_er_and_ed'].iloc[0] < 1:
            return line, False
        return gen_exp, True
    except Exception as e:
        print('Error occurred:', e)
        return line, False

def generatedSamples_nd_SRE(line, gen_row, model, threshold=0.94):
    """Return the generated explanation if it is usable and not a near-copy of ``line``.

    Applies the same checks as ``generatedSamples_nd_SEN`` (at most one
    distinct regulated and one distinct regulator entity, first
    ``sat_er_and_ed`` value at least 1). It then embeds the generated text
    and ``line`` with ``model.encode`` and rejects the generated text when
    their cosine similarity exceeds 0.90.

    Args:
        line: original conclusion, returned unchanged on rejection.
        gen_row: DataFrame-like object with ``generated_nd`` and
            ``sat_er_and_ed`` columns.
        model: object exposing ``encode(list_of_str)``.
        threshold: accepted for interface compatibility.
            NOTE(review): this value is not used; the cut-off is the
            literal 0.90 below. Confirm which value is intended before
            wiring it in, since doing so changes default results.

    Returns:
        ``(generated, True)`` on success, else ``(line, False)``. Ordinary
        exceptions are treated as a rejection; ``KeyboardInterrupt`` and
        ``SystemExit`` propagate so a long run can be interrupted.
    """
    try:
        gen_exp = gen_row['generated_nd'].iloc[0][6:-16]  # Skip first <exp> token and last polarity label
        if len(set(re.findall("<el>(.*?)<le>", gen_exp))) > 1:
            return line, False
        if len(set(re.findall("<re>(.*?)<er>", gen_exp))) > 1:
            return line, False
        if gen_row['sat_er_and_ed'].iloc[0] < 1:
            return line, False
        sentence_embeddings = model.encode([gen_exp, line])
        if cosine_similarity([sentence_embeddings[1]], [sentence_embeddings[0]]) > 0.90:
            return line, False
        return gen_exp, True
    except Exception:
        return line, False

def generatedSamples(line, gen_row, scr_threshold=0.45, model=None):
    """Return the benchmark-model explanation from ``gen_row`` if it qualifies.

    The explanation (``BM_Exp``) is accepted only when all of these hold:
    its label differs from the true label, both predicted entities are
    present and equal to the true entities, and ``BM_scr`` does not exceed
    ``scr_threshold``. Otherwise ``(line, False)`` is returned. ``model``
    is not used.
    """
    field_order = ('BM_Exp', 'BM_lbl', 'BM_reg', 'BM_ele',
                   'True_lbl', 'True_reg', 'True_ele', 'BM_scr')
    (candidate, predicted_label, predicted_reg, predicted_ele,
     true_label, reference_reg, reference_ele, score) = (gen_row[key] for key in field_order)

    rejected = (
        predicted_label == true_label
        or pd.isna(predicted_reg) or pd.isna(predicted_ele)
        or predicted_reg != reference_reg or predicted_ele != reference_ele
        or score > scr_threshold
    )
    if rejected:
        return line, False
    return candidate, True


In [4]:
# %%
# Code from data_preparation.py

# Define paths and parameters
base_path = base_drive_path  # directory the CSV splits are read from (the mounted Drive folder)
base_save = base_drive_path   # directory later cells write entity files and pickles to (same folder)
suffix = ''  # Modify if needed; only echoed below in this part of the notebook
seq_length = 512  # maximum token length passed to sentence_encoder when the splits are encoded
tag = 'blnc'  # 'lrg' or 'blnc'; assuming 'blnc' for balanced

print('Suffix is:', suffix)
print('Tag is:', tag)

# Load the dataset: three CSV splits from base_path. Later cells read the
# 'conclusion', 'supp_set', 'label_cat' and 'pmid' columns from these frames.
train_df = pd.read_csv(os.path.join(base_path, 'train_balanced.csv'))
dev_df = pd.read_csv(os.path.join(base_path, 'dev_balanced.csv'))
test_df = pd.read_csv(os.path.join(base_path, 'test.csv'))

# Report the size of each split as a quick sanity check.
print(f'Training set loaded with {len(train_df)} records.')
print(f'Development set loaded with {len(dev_df)} records.')
print(f'Test set loaded with {len(test_df)} records.')


Suffix is: 
Tag is: blnc
Training set loaded with 5544 records.
Development set loaded with 12806 records.
Test set loaded with 6308 records.


In [5]:
# %%
# Load the tokenizer and add special tokens
model_name = "michiyasunaga/BioLinkBERT-base"  # checkpoint name, reused when the classifier is created
print('Used model is:', model_name)
print('Loading BERT tokenizer...')
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Markup tokens used in the data: <exp> joins supporting set and conclusion,
# <re>/<er> and <el>/<le> wrap the regulator and regulated entities.
# NOTE(review): <end> is registered here but not used elsewhere in this part of the notebook.
special_tokens = ['<exp>', '<re>', '<er>', '<el>', '<le>', '<end>']

# Register the markers as additional special tokens. The classifier cell calls
# model.resize_token_embeddings(len(tokenizer)) to account for the added entries.
special_tokens_dict = {'additional_special_tokens': special_tokens}
num_added_tokens = tokenizer.add_special_tokens(special_tokens_dict)
print(f'Added {num_added_tokens} special tokens.')


Used model is: michiyasunaga/BioLinkBERT-base
Loading BERT tokenizer...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Added 6 special tokens.


In [6]:
# %%
# Code from data_preparation.py

# %%
# Define the sentence encoder
def sentence_encoder(sentence1, sentence2, labels_map, labels=None, seq_length=128, tokenizer=tokenizer, lowercase=False):
    """Tokenize sentence pairs and map their category labels to class ids.

    Args:
        sentence1: sequence of first-segment strings.
        sentence2: sequence of second-segment strings aligned with
            ``sentence1``, or None to encode single sentences.
        labels_map: dict from label category to integer class id.
        labels: optional sequence of label categories aligned with
            ``sentence1``.
        seq_length: every example is padded/truncated to this many tokens.
        tokenizer: tokenizer exposing ``encode_plus``; defaults to the
            notebook-level tokenizer bound when this function was defined.
        lowercase: lower-case both segments before tokenizing.

    Returns:
        ``(input_ids, attention_masks, token_types, labels)``. The first
        three are tensors built by concatenating the per-example encodings
        along dim 0. ``labels`` is a tensor of class ids, or None when no
        labels were given or when mapping them failed (the failure is
        printed).
    """
    input_ids = []
    attention_masks = []
    token_types = []
    for i, sent1 in enumerate(sentence1):
        sent2_text = sentence2[i] if sentence2 is not None else None
        if lowercase:
            sent1 = sent1.lower()
            # Test against None rather than truthiness: an empty second segment
            # must stay a (possibly empty) pair input, as it does without lowercasing.
            sent2_text = sent2_text.lower() if sent2_text is not None else None

        # NOTE(review): `truncation=True` and the legacy `truncation_strategy`
        # keyword are both passed; confirm which one the installed
        # transformers version honours.
        encoded_dict = tokenizer.encode_plus(
            text=sent1,
            text_pair=sent2_text,
            add_special_tokens=True,
            max_length=seq_length,
            truncation=True,
            truncation_strategy='only_first',
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
            return_token_type_ids=True,
        )

        input_ids.append(encoded_dict['input_ids'])
        attention_masks.append(encoded_dict['attention_mask'])
        token_types.append(encoded_dict['token_type_ids'])

    input_ids = torch.cat(input_ids, dim=0)
    attention_masks = torch.cat(attention_masks, dim=0)
    token_types = torch.cat(token_types, dim=0)

    # No labels requested: a normal case, not an error to report.
    if labels is None:
        return input_ids, attention_masks, token_types, None

    try:
        labels_out = [labels_map[c] for c in labels]
        labels = torch.tensor(labels_out)
    except Exception as e:
        # Best-effort: e.g. a category missing from labels_map yields labels=None.
        print(f'Error occurred during label encoding: {e}')
        labels = None

    return input_ids, attention_masks, token_types, labels


In [7]:
# %%
# Code from data_preparation.py

# Set seed for reproducibility
seed_val = 42
random.seed(seed_val)                # Python's `random` (used by the sample-generation helpers)
np.random.seed(seed_val)             # NumPy global RNG
torch.manual_seed(seed_val)          # PyTorch RNG
torch.cuda.manual_seed_all(seed_val) # PyTorch RNG on every CUDA device


In [8]:
# %%
# Code from data_preparation.py

# %%
# Define labels mapping: 'pos' is class 1; every other category that occurs
# in the training split's 'label_cat' column is class 0.
labels_map = {'pos': 1}
labels_map.update({
    category: 0
    for category in set(train_df['label_cat'].values)
    if category != 'pos'
})

print(f'Labels mapping: {labels_map}')


Labels mapping: {'pos': 1, 'posToNeg': 0, 'swap_number': 0, 'generation_nd_SRE': 0, 'generation_nd_SEN': 0, 'generation_nd': 0, 'generation': 0, 'SRE': 0, 'LPR': 0, 'SREO': 0, 'SEP': 0, 'negToPos': 0, 'SEN': 0}


In [9]:
# %%
# Code from data_preparation.py

# Define dataset labeling function
def dataset_labeling(x, p1=0.9, p2=0.7, p3=0.4, p4=0.1):
    """Assign a record to 'train' or 'test' from its category and ``prob`` value.

    ``x`` is indexed with ``'prob'`` and ``'label_cat'``. Each known
    category has a cut-off (one of ``p1``..``p4``); the record goes to
    'train' when ``prob`` is strictly below that cut-off and to 'test'
    otherwise. Unknown categories always go to 'test'.
    """
    prob = x['prob']
    category = x['label_cat']
    cutoff_groups = (
        (('generation_nd_SEN', 'generation_nd_SRE', 'generation', 'generation_nd', 'SRE'), p1),
        (('pos',), p2),
        (('SEP', 'SEN', 'posToNeg', 'negToPos'), p4),
        (('LPR', 'swap_number'), p3),
        (('SREO',), p4),
    )
    for members, cutoff in cutoff_groups:
        if category in members:
            return 'train' if prob < cutoff else 'test'
    return 'test'


In [10]:
# %%
# Encode datasets
def _encode_split(split_df):
    """Encode one split: conclusion as first segment, supporting set as second."""
    return sentence_encoder(
        split_df['conclusion'].tolist(),
        split_df['supp_set'].tolist(),
        labels_map,
        labels=split_df['label_cat'].tolist(),
        seq_length=seq_length,
        tokenizer=tokenizer,
        lowercase=False
    )


train_inputs, train_masks, train_token_types, train_labels = _encode_split(train_df)
dev_inputs, dev_masks, dev_token_types, dev_labels = _encode_split(dev_df)
test_inputs, test_masks, test_token_types, test_labels = _encode_split(test_df)

# Create TensorDatasets; tensor order is (input ids, token types, attention masks, labels)
train_dataset = TensorDataset(train_inputs, train_token_types, train_masks, train_labels)
dev_dataset = TensorDataset(dev_inputs, dev_token_types, dev_masks, dev_labels)
test_dataset = TensorDataset(test_inputs, test_token_types, test_masks, test_labels)

print(f'Training dataset size: {len(train_dataset)}')
print(f'Development dataset size: {len(dev_dataset)}')
print(f'Test dataset size: {len(test_dataset)}')


Training dataset size: 5544
Development dataset size: 12806
Test dataset size: 6308


In [11]:
# %%
# Create DataLoaders
batch_size = 16      # examples per training batch
val_batch_size = 16  # examples per validation batch

# Training batches are drawn in random order.
train_dataloader = DataLoader(
    train_dataset,
    sampler=RandomSampler(train_dataset),
    batch_size=batch_size
)

# Validation batches keep the dataset order.
validation_dataloader = DataLoader(
    dev_dataset,
    sampler=SequentialSampler(dev_dataset),
    batch_size=val_batch_size
)

# len(dataloader) is the number of batches, not the number of examples.
print(f'Training DataLoader has {len(train_dataloader)} batches.')
print(f'Development DataLoader has {len(validation_dataloader)} batches.')


Training DataLoader has 347 batches.
Development DataLoader has 801 batches.


In [12]:
# %%
# Install SciSpaCy and the specific NER model
!pip install scispacy
!pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.0/en_ner_bionlp13cg_md-0.5.0.tar.gz


Collecting spacy<3.8.0,>=3.7.0 (from scispacy)
  Using cached spacy-3.7.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (27 kB)
Collecting thinc<8.3.0,>=8.2.2 (from spacy<3.8.0,>=3.7.0->scispacy)
  Using cached thinc-8.2.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (15 kB)
Using cached spacy-3.7.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.6 MB)
Using cached thinc-8.2.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (922 kB)
Installing collected packages: thinc, spacy
  Attempting uninstall: thinc
    Found existing installation: thinc 8.0.17
    Uninstalling thinc-8.0.17:
      Successfully uninstalled thinc-8.0.17
  Attempting uninstall: spacy
    Found existing installation: spacy 3.2.6
    Uninstalling spacy-3.2.6:
      Successfully uninstalled spacy-3.2.6
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the fo

Collecting https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.0/en_ner_bionlp13cg_md-0.5.0.tar.gz
  Using cached https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.0/en_ner_bionlp13cg_md-0.5.0.tar.gz (120.2 MB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting spacy<3.3.0,>=3.2.3 (from en_ner_bionlp13cg_md==0.5.0)
  Using cached spacy-3.2.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (23 kB)
Collecting thinc<8.1.0,>=8.0.12 (from spacy<3.3.0,>=3.2.3->en_ner_bionlp13cg_md==0.5.0)
  Using cached thinc-8.0.17-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (14 kB)
Using cached spacy-3.2.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.1 MB)
Using cached thinc-8.0.17-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (659 kB)
Installing collected packages: thinc, spacy
  Attempting uninstall: thinc
    Found existing installation: thinc 8.2.5
    Uninstalling thinc-8.2.5:
      Succe

In [13]:
# %%
# Import SpaCy and initialize the NER model
import spacy

def extract_and_save_entities(sentences, save_path):
    """Run the notebook-level ``nlp`` pipeline on each sentence and save the entities.

    For every sentence a dict mapping entity text to entity label is built.
    The dicts are written to ``save_path + '.json'`` as one JSON object per
    line, and to ``save_path + '.csv'`` as a DataFrame built from the list
    of dicts. Progress is printed every 1000 sentences.
    """
    json_path = save_path + '.json'
    csv_path = save_path + '.csv'

    entity_types = []
    for position, sentence in enumerate(sentences):
        if position % 1000 == 0:
            print(f'Processing sentence {position}')
        entity_types.append({ent.text: ent.label_ for ent in nlp(sentence).ents})

    # Save to JSON (one object per line) and CSV
    with open(json_path, 'w') as f_json, open(csv_path, 'w') as f_csv:
        for entities in entity_types:
            f_json.write(json.dumps(entities))
            f_json.write('\n')
        pd.DataFrame(entity_types).to_csv(f_csv, index=False)

    print(f'Entities saved to {save_path}.json and {save_path}.csv')

# Initialize SpaCy NER model (read as a global by extract_and_save_entities)
nlp = spacy.load("en_ner_bionlp13cg_md")

# Extract entities from the two text columns of the training set (conclusions
# and supporting sets). Each call writes <name>.json and <name>.csv under base_save.
# Only train_df is processed here; the development set is not.
extract_and_save_entities(train_df['conclusion'].tolist(), os.path.join(base_save, 'conclusion_entity_types'))
extract_and_save_entities(train_df['supp_set'].tolist(), os.path.join(base_save, 'supp_set_entity_types'))


  _C._set_default_tensor_type(t)


Processing sentence 0
Processing sentence 1000
Processing sentence 2000
Processing sentence 3000
Processing sentence 4000
Processing sentence 5000
Entities saved to /content/drive/MyDrive/basic_benchmarking/conclusion_entity_types.json and /content/drive/MyDrive/basic_benchmarking/conclusion_entity_types.csv
Processing sentence 0
Processing sentence 1000
Processing sentence 2000
Processing sentence 3000
Processing sentence 4000
Processing sentence 5000
Entities saved to /content/drive/MyDrive/basic_benchmarking/supp_set_entity_types.json and /content/drive/MyDrive/basic_benchmarking/supp_set_entity_types.csv


In [14]:
# %%
# Combine train and dev datasets for further processing.
# Rows of train_df come first, then dev_df; ignore_index=True renumbers them 0..n-1.
combined_input = pd.concat([train_df, dev_df], ignore_index=True)

print(f'Combined input dataset size: {len(combined_input)}')


Combined input dataset size: 18350


In [15]:
# %%
# Initialize counts and storage variables
# One counter per sample category; the sample-generation cell increments the
# ones it produces (SREO, SEP, SET, SEN, swap_number, LPR, posToNeg, negToPos).
LPR_count = posToNeg_count = negToPos_count = swap_number_count = generation_count = 0
generation_nd_count = generation_nd_SRE_count = generation_nd_SEN_count = 0
SRE_count = SEP_count = SET_count = SEN_count = SREO_count = 0
# Accumulators filled by the sample-generation cell and dumped to text/pickle files there.
negative_sents = []      # generated negative conclusions
positive_sents = []      # original conclusions
negative_ids = []        # pmid, appended once per processed input row
positive_ids = []        # pmid of each positive conclusion
negative_supp_set = []   # supporting set, appended once per processed input row
positive_supp_set = []   # supporting set of each positive conclusion
labels = []              # label category of every generated record, in creation order
# Empty frame with the output schema (the shape add_row expects for data_df).
data_df = pd.DataFrame(columns=['supp_set', 'conclusion', 'abstract', 'label_cat', 'ori_conclusion', 'pmid', 'index'])


In [16]:
# %%
# Load entity type mappings
# entity_type_map:         entity text -> label
# entity_type_map_reverse: label -> list of entity texts
entity_type_map_reverse = {}
entity_type_map = {}
# The file holds one JSON object per line ({entity text: label}), as written by
# extract_and_save_entities for the training supporting sets.
with open(os.path.join(base_save, 'supp_set_entity_types.json')) as f:
    lines = f.readlines()
    for line in lines:
        supp_set_entity_type = json.loads(line)
        for ent, label in supp_set_entity_type.items():
            # A later line overwrites the label stored for an entity seen earlier.
            entity_type_map[ent] = label
            # Appended on every occurrence, so an entity can appear several times
            # in a label's list (and under more than one label).
            entity_type_map_reverse.setdefault(label, []).append(ent)

# Note: the number printed is the count of distinct entity texts (keys of entity_type_map).
print(f'Loaded {len(entity_type_map)} entity types.')


Loaded 38758 entity types.


In [18]:
# %%
# Code from entailment_preparation.ipynb

# Initialize a list to collect all rows
rows = []

# Negative-sample generators, tried in this order for every usable conclusion.
# Each entry is (label_cat, callable(line, supp_set) -> (new_line, changed)).
negative_transforms = [
    ('SREO', lambda line, supp_set: swap_random_entity_outside(line, entity_type_map_reverse, entity_type_map)),
    ('SEP', lambda line, supp_set: swap_entity_positions(line)),
    ('SET', lambda line, supp_set: swap_entity_tags(line)),
    ('SEN', lambda line, supp_set: swap_entity_names(line)),
    ('swap_number', swapNumber),
    ('LPR', lambda line, supp_set: word_replace(line)),
    ('posToNeg', lambda line, supp_set: posToNeg(line)),
    ('negToPos', lambda line, supp_set: negToPos(line)),
]

# Per-category tally for this run; added to the notebook-level *_count variables at the end.
negative_counts = Counter()

# Line-per-item text dumps and the list each one mirrors.
text_outputs = {
    'positive_conclusion.txt': positive_sents,
    'negative_conclusion.txt': negative_sents,
    'positive_suppset.txt': positive_supp_set,
    'negative_suppset.txt': negative_supp_set,
}
# How many items of each list this cell run has already appended to its file.
lines_written = dict.fromkeys(text_outputs, 0)


def _build_row(supp_set, conclusion, labels_cat, original_line, pmid, index):
    """Build one output record; runs of spaces in the conclusion are collapsed."""
    cleaned = re.sub(' +', ' ', conclusion).strip()
    return {
        'supp_set': supp_set.strip(),
        'conclusion': cleaned,
        'abstract': f"{supp_set.strip()} <exp> {cleaned}",
        'label_cat': labels_cat,
        'ori_conclusion': original_line.strip(),
        'pmid': pmid,
        'index': index
    }


def _save_progress(pending_rows):
    """Append pending rows and not-yet-written text lines to disk; rewrite the pickles."""
    csv_path = os.path.join(base_path, 'entailment_data.csv')
    # Skip empty batches: writing an empty frame would create the file without a header row.
    if pending_rows:
        pd.DataFrame(pending_rows).to_csv(csv_path, mode='a', index=False, header=not os.path.exists(csv_path))

    # Append only the items added since the previous save, so each item is
    # written once per run instead of once per checkpoint.
    for file_name, items in text_outputs.items():
        with open(os.path.join(base_path, file_name), 'a') as f:
            for item in items[lines_written[file_name]:]:
                f.write(f"{item}\n")
        lines_written[file_name] = len(items)

    # The pickles always hold the complete lists, so they are overwritten.
    for file_name, payload in (('negative_ids.pk', negative_ids),
                               ('positive_ids.pk', positive_ids),
                               ('labels_cat.pk', labels)):
        with open(os.path.join(base_save, file_name), 'wb') as f:
            pickle.dump(payload, f)


# Iterate over each example to create positive and negative samples
for index, (line, supp_set, pmid) in enumerate(zip(combined_input['conclusion'], combined_input['supp_set'], combined_input['pmid'])):
    if index % 1000 == 0:
        print(f'Processed {index} items')

    # Skip rows with missing text or without both tagged entities.
    if pd.isna(line) or pd.isna(supp_set) or not re.search("<re>(.*?)<er>", line) or not re.search("<el>(.*?)<le>", line):
        continue

    # Add positive example
    positive_sents.append(line)
    positive_supp_set.append(supp_set)
    positive_ids.append(pmid)
    labels.append('pos')
    rows.append(_build_row(supp_set, line, 'pos', line, pmid, index))

    # Generate and add negative examples.
    # NOTE(review): negative_supp_set / negative_ids receive one entry per input
    # row while negative_sents receives one entry per generated negative, so the
    # negative text files are not line-aligned. Kept as-is; confirm the intent.
    negative_supp_set.append(supp_set)
    negative_ids.append(pmid)

    for labels_cat, transform in negative_transforms:
        new_line, negative_c = transform(line, supp_set)
        if not negative_c:
            continue
        negative_sents.append(new_line)
        labels.append(labels_cat)
        negative_counts[labels_cat] += 1
        rows.append(_build_row(supp_set, new_line, labels_cat, line, pmid, index))

    # Checkpoint every 1000 items (only reached for rows that were not skipped)
    if index % 1000 == 0 and index != 0:
        _save_progress(rows)
        # Clear the rows list after saving to avoid duplication
        rows = []

# After the loop, save any remaining rows and list items
_save_progress(rows)

# Fold this run's tallies into the notebook-level counters.
SREO_count += negative_counts['SREO']
SEP_count += negative_counts['SEP']
SET_count += negative_counts['SET']
SEN_count += negative_counts['SEN']
swap_number_count += negative_counts['swap_number']
LPR_count += negative_counts['LPR']
posToNeg_count += negative_counts['posToNeg']
negToPos_count += negative_counts['negToPos']


Processed 0 items
Processed 1000 items
Processed 2000 items
Processed 3000 items
Processed 4000 items
Processed 5000 items
Processed 6000 items
Processed 7000 items
Processed 8000 items
Processed 9000 items
Processed 10000 items
Processed 11000 items
Processed 12000 items
Processed 13000 items
Processed 14000 items
Processed 15000 items
Processed 16000 items
Processed 17000 items
Processed 18000 items


In [20]:
# %%
# Code from classifier.py

# Initialize the model for sequence classification: the pretrained checkpoint
# named by model_name with a two-class classification head.
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2,
    output_attentions=False,
    output_hidden_states=False,
)

# Resize token embeddings to accommodate new special tokens
# (the tokenizer was extended after loading, so its length exceeds the checkpoint's vocabulary).
model.resize_token_embeddings(len(tokenizer))

# Pick the device (GPU if available, otherwise CPU) and move the model to it.
# The call is the last expression of the cell, so the notebook displays the model structure.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at michiyasunaga/BioLinkBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28901, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [21]:
# Set up the optimizer and scheduler
epochs = 3
learning_rate = 2e-5
# torch.optim.AdamW is used instead of the deprecated transformers.AdamW.
# weight_decay=0.0 keeps the default of the optimizer it replaces
# (torch's own default is 0.01).
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, eps=1e-8, weight_decay=0.0)
# One scheduler step per training batch over all epochs.
total_steps = len(train_dataloader) * epochs
# Linear decay of the learning rate to zero, with no warm-up phase.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)



In [22]:
# Set the seed value for reproducibility.
# The same seeds were set in an earlier cell; setting them again here restarts
# the random streams immediately before training.
seed_val = 42
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
# Seed the CUDA generators only when the model runs on a GPU.
if device.type == 'cuda':
    torch.cuda.manual_seed_all(seed_val)



In [23]:
# %%
# Code from classifier.py

# Freeze certain layers for efficiency: only parameters whose name contains one
# of these markers (classifier head, 'cls', encoder layers 9-11) stay trainable.
trainable_markers = ('classifier', 'cls', 'layer.11', 'layer.10', 'layer.9')
for name, param in model.named_parameters():
    if not any(marker in name for marker in trainable_markers):
        param.requires_grad = False


In [24]:
# %%
# Code from classifier.py

# Evaluation function
def evaluate(dataloader, model, device):
    """
    Run `model` over `dataloader` and report binary classification metrics.

    Each batch is read positionally: element 0 is passed as the input ids,
    element 1 as `token_type_ids`, element 2 as `attention_mask`, and
    element 3 holds the gold labels. The predicted class of an example is the
    argmax over its logits.

    Prints the classification report, the confusion matrix and the macro F1
    (all restricted to labels 0 and 1) and returns `(accuracy, macro_f1)`.
    The model is switched to evaluation mode and left that way.
    """
    model.eval()

    class_ids = [0, 1]
    predicted, gold = [], []

    # Gradients are never needed here, so disable them for the whole pass.
    with torch.no_grad():
        for batch in dataloader:
            input_ids, token_types, input_mask, labels = (
                batch[i].to(device) for i in range(4)
            )
            logits = model(
                input_ids,
                token_type_ids=token_types,
                attention_mask=input_mask
            ).logits
            predicted.extend(logits.argmax(dim=1).cpu().numpy())
            gold.extend(labels.cpu().numpy())

    # Compute every metric first, then print them in a fixed order.
    report = classification_report(gold, predicted, labels=class_ids)
    conf_matrix = confusion_matrix(gold, predicted, labels=class_ids)
    acc = accuracy_score(gold, predicted)
    macro_f1 = f1_score(gold, predicted, labels=class_ids, average='macro')

    for block in (report, conf_matrix):
        print(block)
    print(f'Macro F1-score: {macro_f1:.2f}')

    return acc, macro_f1


In [25]:
# %%
# Code from classifier.py

# Training loop
#
# Fine-tunes `model` for `epochs` passes over `train_dataloader`, validates on
# `validation_dataloader` after every epoch, and checkpoints the model and
# tokenizer whenever the validation macro F1 improves.
training_stats = []

# Start below any attainable score so the first epoch is always checkpointed.
# Starting at 0 with a strict `>` meant a run whose macro F1 never rose above 0
# saved nothing, leaving no checkpoint on disk for the cells that reload it.
best_macro_f1 = float('-inf')

# The checkpoint directory is the same for every epoch, so build it once.
save_path = os.path.join(base_save, f'{model_name[:5]}_2sent_class_layer_{tag}{suffix}_top3')

total_t0 = time.time()

for epoch in range(epochs):
    print(f"\n======== Epoch {epoch + 1} / {epochs} ========")
    print("Training...")

    t0 = time.time()
    total_train_loss = 0
    # evaluate() leaves the model in eval mode, so switch back every epoch.
    model.train()

    for step, batch in enumerate(train_dataloader):
        # Progress report every 100 batches (skipping batch 0).
        if step % 100 == 0 and step != 0:
            elapsed = format_time(time.time() - t0)
            print(f'  Batch {step} of {len(train_dataloader)}. Elapsed: {elapsed}.')

        # Batch layout: input ids, token type ids, attention mask, labels.
        b_input_ids = batch[0].to(device)
        b_token_types = batch[1].to(device)
        b_input_mask = batch[2].to(device)
        b_labels = batch[3].to(device)

        # Clear gradients left over from the previous step.
        model.zero_grad()

        # Passing `labels` makes the model return the loss with the outputs.
        outputs = model(
            b_input_ids,
            token_type_ids=b_token_types,
            attention_mask=b_input_mask,
            labels=b_labels
        )

        loss = outputs.loss
        total_train_loss += loss.item()

        loss.backward()
        # Clip the gradient norm to 1.0 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        # The learning-rate schedule advances once per batch.
        scheduler.step()

    avg_train_loss = total_train_loss / len(train_dataloader)
    training_time = format_time(time.time() - t0)

    print(f"\n  Average training loss: {avg_train_loss:.2f}")
    print(f"  Training epoch took: {training_time}")

    # Validation
    print("\nRunning Validation...")
    acc, macro_f1 = evaluate(validation_dataloader, model, device)

    # Save the model (and its tokenizer) if it has the best macro F1 so far
    if macro_f1 > best_macro_f1:
        best_macro_f1 = macro_f1
        model.save_pretrained(save_path)
        tokenizer.save_pretrained(save_path)
        print(f'New best model saved with Macro F1: {best_macro_f1:.3f}')

    # Record statistics for this epoch
    training_stats.append({
        'epoch': epoch + 1,
        'avg_train_loss': avg_train_loss,
        'training_time': training_time,
        'validation_accuracy': acc,
        'validation_macro_f1': macro_f1
    })

print("\nTraining complete!")
total_training_time = format_time(time.time() - total_t0)
print(f"Total training time: {total_training_time}")
print(f"Best Macro F1 achieved: {best_macro_f1:.3f}")



Training...
  Batch 100 of 347. Elapsed: 0:01:14.
  Batch 200 of 347. Elapsed: 0:02:27.
  Batch 300 of 347. Elapsed: 0:03:40.

  Average training loss: 0.66
  Training epoch took: 0:04:14

Running Validation...
              precision    recall  f1-score   support

           0       0.96      0.58      0.73      9986
           1       0.38      0.91      0.54      2820

    accuracy                           0.66     12806
   macro avg       0.67      0.75      0.63     12806
weighted avg       0.83      0.66      0.68     12806

[[5829 4157]
 [ 251 2569]]
Macro F1-score: 0.63
New best model saved with Macro F1: 0.632

Training...
  Batch 100 of 347. Elapsed: 0:01:13.
  Batch 200 of 347. Elapsed: 0:02:25.
  Batch 300 of 347. Elapsed: 0:03:38.

  Average training loss: 0.54
  Training epoch took: 0:04:12

Running Validation...
              precision    recall  f1-score   support

           0       0.98      0.69      0.81      9986
           1       0.47      0.96      0.63      2

In [27]:
# %%
# Create model_predictions.csv

import pandas as pd
import os

# Location of the checkpoint written by the training loop
best_model_path = os.path.join(base_save, f'{model_name[:5]}_2sent_class_layer_{tag}{suffix}_top3')

# Stop right away when there is no checkpoint to load
if not os.path.exists(best_model_path):
    raise FileNotFoundError(f"Best model not found at {best_model_path}. Please ensure the model was saved correctly.")

# Reload the checkpoint, move it to the active device and switch it to
# evaluation mode (`.to()` and `.eval()` both return the module itself)
best_model = AutoModelForSequenceClassification.from_pretrained(best_model_path).to(device).eval()

# Premise / hypothesis text for every row of dev_df
premises = dev_df['supp_set'].values
hypotheses = dev_df['conclusion'].values

# Accumulators for the predicted class ids and the gold label ids
predicted_labels, true_labels = [], []

# Run the reloaded model over the validation DataLoader without gradients
with torch.no_grad():
    for batch in validation_dataloader:
        # Batch layout: input ids, token type ids, attention mask, labels
        input_ids, token_types, input_mask, gold_ids = (
            batch[i].to(device) for i in range(4)
        )
        batch_logits = best_model(
            input_ids,
            token_type_ids=token_types,
            attention_mask=input_mask
        ).logits
        predicted_labels.extend(batch_logits.argmax(dim=1).cpu().numpy())
        true_labels.extend(gold_ids.cpu().numpy())

# Inverse label mapping for readability
# Assuming labels_map = {'pos': 1, 'neg_type1': 0, 'neg_type2': 0, ...}
label_map_inv = {1: 'pos', 0: 'neg'}

# Turn predicted class ids into 'pos' / 'neg' ('unknown' for any other id)
predicted_labels_mapped = [label_map_inv.get(class_id, 'unknown') for class_id in predicted_labels]

# The true-label column keeps the original string categories from
# dev_df['label_cat']; the numeric ids gathered in `true_labels` are not
# written to the CSV.
true_labels_mapped = dev_df['label_cat'].values

# Assemble the output table. Predictions are paired with dev_df rows purely by
# position - assumes validation_dataloader yields dev_df rows in order; TODO confirm.
results_df = pd.DataFrame(dict(
    premise=premises,
    hypothesis=hypotheses,
    predicted_label=predicted_labels_mapped,
    true_label=true_labels_mapped,
))

# Optional: show the first few rows as a sanity check
print(results_df.head())

# Write the table to model_predictions.csv under base_save
results_csv_path = os.path.join(base_save, 'model_predictions.csv')
results_df.to_csv(results_csv_path, index=False)
print(f"Model predictions saved to {results_csv_path}")


                                             premise  \
0  CaCl2 suppresses the plasma renin activity (PR...   
1  CaCl2 suppresses the plasma renin activity (PR...   
2  CaCl2 suppresses the plasma renin activity (PR...   
3  CaCl2 suppresses the plasma renin activity (PR...   
4  CaCl2 suppresses the plasma renin activity (PR...   

                                          hypothesis predicted_label  \
0  In conclusion, promotion of <re> PRA <er> by <...             neg   
1  In conclusion, inhibition of <el> CaCl2 <le> b...             neg   
2  In conclusion, inhibition of <re> PRA <er> by ...             neg   
3  In conclusion, inhibition of <re> PRA <er> by ...             neg   
4  In conclusion, inhibition of <re> CaCl2 <er> b...             neg   

  true_label  
0        LPR  
1        SEP  
2        pos  
3   posToNeg  
4        SEN  
Model predictions saved to /content/drive/MyDrive/basic_benchmarking/model_predictions.csv
