## CARe-BERT Pipeline and PubMedBERT Fine-Tuning

Alexander Xu, 2024

**Import Libraries**

In [None]:
import os
import gzip
import tarfile
from joblib import Parallel, delayed
import sys

import random
import json
import requests
from datetime import datetime
from collections import deque,defaultdict,Counter

from IPython.display import clear_output
from IPython.display import FileLink
from tqdm import tqdm
import logging
import argparse

import pandas as pd
import numpy as np
import pickle
import re

import nltk
nltk.download('wordnet','/root/nltk_data')
!unzip /root/nltk_data/corpora/wordnet.zip -d /root/nltk_data/corpora/
from nltk.stem import WordNetLemmatizer

import torch
from torch.utils.data import DataLoader
from torch.utils.data import IterableDataset
from torch.utils.data import Dataset

!pip install -U sentence-transformers
from sentence_transformers import SentenceTransformer, LoggingHandler, util, models, evaluation, losses, InputExample
from sentence_transformers.evaluation import TripletEvaluator

from scipy.spatial.distance import cityblock, euclidean, cosine
from scipy.stats import binom

# Register tqdm's pandas integration so DataFrame/Series.progress_apply is available.
tqdm.pandas()
# Seed Python's `random` module so shuffles/choices made with it are reproducible.
# NOTE(review): scipy.stats.binom.rvs (used for hard negatives) draws from NumPy's
# global RNG, which is not seeded here — confirm whether that is intended.
random.seed(117)
# Clear the cell output produced by the download/unzip/pip commands above.
clear_output()

## Functions

**Utility Functions**

In [None]:
# Flatten a matrix
def flatten_extend(matrix):
    """Concatenate the rows of *matrix* into a single new list.

    Exactly one level of nesting is removed; each row may be any iterable
    (its items are appended in iteration order).
    """
    return [element for row in matrix for element in row]

# Delete folder contents
def remove_folder_contents(folder):
    """Recursively delete everything inside *folder*, keeping *folder* itself.

    Best-effort: an error on any entry is printed and the sweep moves on to
    the next entry.
    """
    for entry_name in os.listdir(folder):
        target = os.path.join(folder, entry_name)
        try:
            if os.path.isfile(target):
                os.unlink(target)
                continue
            if not os.path.isdir(target):
                continue
            # Empty the sub-directory first so rmdir can succeed.
            remove_folder_contents(target)
            os.rmdir(target)
        except Exception as err:
            print(err)

**MIMIC Dataset Pre-Processing Functions**

In [None]:
# Parse MIMIC dataset row
def parse_row(data, ent_counts, rel_counts):
    """Convert one annotation mapping into flat entity and relation lists.

    Args:
        data: Mapping of entity id -> annotation with keys ``start_ix``,
            ``end_ix`` (inclusive), ``label``, ``tokens`` and ``relations``
            (a list whose items are indexed as ``[relation_label, target_id]``).
        ent_counts: Counter-like mapping (e.g. ``defaultdict(int)``); incremented
            once per entity, keyed by entity label.
        rel_counts: Counter-like mapping; incremented once per emitted relation,
            keyed by the (possibly renamed) relation label.

    Returns:
        ``(entities, relations)`` where each entity is
        ``{'id', 'start', 'end' (exclusive), 'label', 'text'}`` and each relation
        is ``{'head', 'child', 'label'}``. A ``modify`` relation is renamed
        ``described_as`` and its head is the entity that lists the relation.
        For every other label, the head is the relation's target.
    """
    entities = []
    for idx, ent in data.items():
        # ent_counts was previously accepted but never updated.
        ent_counts[ent['label']] += 1
        entities.append({
            'id': idx,
            'start': ent["start_ix"],
            'end': ent["end_ix"] + 1,  # make the end offset exclusive
            'label': ent["label"],
            'text': ent['tokens'],
        })

    relations = []
    for idx, ent in data.items():
        for relation in ent['relations']:
            label, target = relation[0], relation[1]
            is_modify = label == 'modify'
            rel_label = 'described_as' if is_modify else label
            rel_counts[rel_label] += 1
            relations.append({'head': idx if is_modify else target,
                              'child': target if is_modify else idx,
                              'label': rel_label})

    return (entities, relations,)

# Parse MIMIC dataset structure
def parse_dataset(dataset):
    """Return a copy of *dataset* with ``ents`` and ``rels`` columns added.

    Each value in the ``data`` column is parsed with ``parse_row``. The
    entity/relation label counters are local and are discarded afterwards.

    Args:
        dataset: DataFrame with a ``data`` column of per-report annotation mappings.

    Returns:
        A copy of *dataset* with object columns ``ents`` and ``rels``, aligned on
        the original index. An empty input gives empty columns instead of failing.
    """
    ent_counts = defaultdict(int)
    rel_counts = defaultdict(int)

    # Parse inside the tqdm loop so the bar reflects the actual parsing work;
    # wrapping the result of .apply() only iterated already-finished output.
    parsed = [parse_row(data, ent_counts, rel_counts) for data in tqdm(dataset['data'])]

    dataset_copy = dataset.copy()
    # Explicit object Series keep the original index and also work when there
    # are no rows (unpacking zip(*[]) into two names raises ValueError).
    dataset_copy['ents'] = pd.Series([ents for ents, _ in parsed], index=dataset.index, dtype=object)
    dataset_copy['rels'] = pd.Series([rels for _, rels in parsed], index=dataset.index, dtype=object)

    return dataset_copy

**Knowledge Graph Construction Functions**

In [None]:
# Map each node to its original sentence
def match_nodes_to_sentences(nodes, sentences):
    """Map each node id to the sentence containing its start token.

    Sentence extents are computed from whitespace token counts
    (``str.split()``). A node belongs to the first cumulative boundary that
    exceeds its ``start`` offset. Because the boundary list starts with 0, the
    returned numbers are 1-based: the first sentence maps to 1. A node starting
    at or beyond the total token count is left out of the result.
    NOTE(review): assumes node offsets index the same whitespace tokens — confirm.
    """
    running_total = 0
    boundaries = [running_total]
    for sentence in sentences:
        running_total += len(sentence.split())
        boundaries.append(running_total)

    mapping = {}
    for node in nodes:
        first_start = node['start']
        position = next(
            (i for i, limit in enumerate(boundaries) if first_start < limit),
            None,
        )
        if position is not None:
            mapping[node['id']] = position
    return mapping

# Identify roots of a knowledge graph
def find_roots(entities, relations):
    """Return the ids of entities that never appear as a relation ``child``.

    Ids come back once each, in the order they first appear in *entities*.
    The previous set-difference implementation returned them in hash order,
    which for string ids changes between interpreter runs. That made graph
    order, and everything derived from it, irreproducible.
    """
    children = {relation['child'] for relation in relations}
    # dict.fromkeys de-duplicates while preserving first-occurrence order.
    ordered_ids = dict.fromkeys(entity['id'] for entity in entities)
    return [entity_id for entity_id in ordered_ids if entity_id not in children]

# Create knowledge graph from root in BFS (breadth-first) manner iteratively
def fill_from_root(start_node, entities, relations, node_to_sentence):
    """Collect the graph reachable from *start_node* by following head -> child relations.

    Traversal is breadth-first (FIFO queue); each node is added once.

    Args:
        start_node: Id of the entity to start from.
        entities: Entity dicts with at least an ``id`` key.
        relations: Relation dicts with ``head``, ``child`` and ``label`` keys.
        node_to_sentence: Mapping of entity id -> sentence number.

    Returns:
        ``{'nodes', 'edges', 'sentences', 'root'}``. Each edge runs from the newly
        reached node (``from_node_id``) back to the node it was reached from
        (``to_node_id``). The edge carries the first matching relation's label.

    Raises:
        KeyError: if a reachable id has no entity or no sentence mapping.
    """
    # Index once instead of scanning the lists for every node. setdefault keeps
    # the first match, just like the next(...) scans used to.
    entity_by_id = {}
    for entity in entities:
        entity_by_id.setdefault(entity['id'], entity)
    children_of = defaultdict(list)
    label_of = {}
    for rel in relations:
        children_of[rel['head']].append(rel['child'])
        label_of.setdefault((rel['head'], rel['child']), rel['label'])

    graph = {'nodes': [], 'edges': [], 'sentences': set(), 'root': start_node}
    visited = set()
    queue = deque([(start_node, None)])  # (node_id, parent_node_id)

    while queue:
        node_id, parent_id = queue.popleft()
        if node_id in visited:
            continue
        visited.add(node_id)

        graph['nodes'].append(entity_by_id[node_id])
        graph['sentences'].add(node_to_sentence[node_id])

        # `is not None` so that a falsy id such as 0 still gets its edge.
        if parent_id is not None:
            graph['edges'].append({
                'from_node_id': node_id,
                'to_node_id': parent_id,
                'label': label_of[(parent_id, node_id)],
            })

        queue.extend((child, node_id) for child in children_of.get(node_id, ()))

    return graph

# Create all knowledge graphs
def create_knowledge_graphs(text, entities, relations):
    """Build one graph per root entity of a report.

    *text* is split into sentences with the module-level ``tokenizer``. Entities
    are mapped to those sentences, and a graph is grown from every entity that
    is never a relation child.

    Returns:
        A list of graph dicts as produced by ``fill_from_root``, one per root.
    """
    node_to_sentence = match_nodes_to_sentences(entities, tokenizer.tokenize(text))
    return [
        fill_from_root(root, entities, relations, node_to_sentence)
        for root in find_roots(entities, relations)
    ]

# Main function to preprocess all data
def load_MIMIC(filepath, file_name, save_csv=False):
    """Load an annotation JSON file and turn each report into knowledge graphs.

    The JSON top level is turned into one DataFrame row per key (via transpose).
    Its ``entities`` field is renamed to ``data``. Rows are parsed with
    ``parse_dataset`` and a ``graphs`` column is added with
    ``create_knowledge_graphs`` (which uses the module-level ``tokenizer``).

    Args:
        filepath: Path to the JSON file.
        file_name: Base name (without extension) for the optional CSV dump.
        save_csv: If true, also write the frame to
            ``/kaggle/working/data/<file_name>.csv``. The directory is created
            if needed.

    Returns:
        The processed DataFrame.
    """
    with open(filepath, encoding='utf-8') as file:
        json_data = json.load(file)
    data_mimic = pd.DataFrame(json_data).T.rename({'entities': 'data'}, axis=1)

    data_mimic = parse_dataset(data_mimic)
    data_mimic['graphs'] = data_mimic.progress_apply(
        lambda row: create_knowledge_graphs(row['text'], row['ents'], row['rels']), axis=1
    )

    if save_csv:
        output_dir = '/kaggle/working/data'
        # to_csv does not create missing parent directories.
        os.makedirs(output_dir, exist_ok=True)
        data_mimic.to_csv(os.path.join(output_dir, f'{file_name}.csv'), index=False)

    return data_mimic

**Knowledge Graph Permutation for Document Augmentation (Hard Negatives) and Synthetic Query Generation (Anchors) Functions**

In [None]:
# Synthetic query augmentation for a node
def visit_node(node,query,vocabulary,replaceables,visited,edge=None):
    """Append *node* to a query under construction and record it.

    Side effects:
      * *query*: gets "possible" (``OBS-U``) or "no" (``OBS-DA``) as a prefix
        where applicable, then the lowercased node text.
      * *vocabulary*: the lowercased text is added under the node label and the
        edge label. With no edge (a query root), it is added under all three
        relation buckets.
      * *replaceables*: gets ``(original text, label, edge label or None)``.
      * *visited*: gets the node id.
    """
    label = node['label']
    lowered = node['text'].lower()

    cue = {'OBS-U': 'possible', 'OBS-DA': 'no'}.get(label)
    if cue is not None:
        query.append(cue)
    query.append(lowered)

    buckets = vocabulary[label]
    if edge:
        relation_names = (edge['label'],)
    else:
        relation_names = ('described_as', 'suggestive_of', 'located_at')
    for relation_name in relation_names:
        buckets[relation_name].add(lowered)

    edge_label = edge['label'] if edge else None
    replaceables.append((node['text'], label, edge_label))
    visited.add(node['id'])

# Helper function to add children nodes and their relations to the query
def add_children_to_query(node, mask, node_id_to_number, graph, query, vocabulary, replaceables, visited):
    """Recursively extend *query* with the masked neighbours of *node*.

    Follows every edge whose ``from_node_id`` is *node*'s id and whose target is
    unvisited and has its bit set in *mask*. For each such target it appends
    "and" (for all but the first), then the edge label with spaces instead of
    underscores, then the target via ``visit_node``. It then recurses
    depth-first from that target.
    """
    emitted_sibling = False
    for edge in graph['edges']:
        if edge['from_node_id'] != node['id']:
            continue
        target_id = edge['to_node_id']
        if target_id in visited:
            continue
        if not mask & (1 << node_id_to_number[target_id]):
            continue

        target = next(candidate for candidate in graph['nodes'] if candidate['id'] == target_id)

        if emitted_sibling:
            query.append('and')
        emitted_sibling = True

        query.append(edge['label'].replace('_', ' '))
        visit_node(target, query, vocabulary, replaceables, visited, edge=edge)
        add_children_to_query(target, mask, node_id_to_number, graph, query, vocabulary, replaceables, visited)
    
# Build queries for a knowledge graph
def create_queries(graph, document, vocabulary):
    """Generate the synthetic queries obtainable from node subsets of *graph*.

    For every node used as the query root, and every bitmask over
    ``graph['nodes']`` that contains that root:

    * the query is assembled with ``visit_node`` on the root, then
      ``add_children_to_query`` restricted to the mask;
    * the mask is kept only if every selected node was reached (its set-bit
      count equals the number of visited nodes);
    * kept queries are lemmatized token by token with the module-level
      ``lemmatizer`` and lowercased.

    Note: ``2 ** len(graph['nodes'])`` masks are enumerated per root, so the
    cost grows exponentially with graph size.

    Returns:
        dict mapping query text -> ``(document, replaceables)``. A later root or
        mask that yields the same text overwrites the earlier entry.
    """
    generated = {}
    nodes = graph['nodes']
    position = {n['id']: i for i, n in enumerate(nodes)}
    mask_limit = 1 << len(nodes)

    for root in nodes:
        root_bit = 1 << position[root['id']]
        for mask in range(1, mask_limit):
            if (mask & root_bit) == 0:
                continue

            seen, swaps, tokens = set(), [], []
            visit_node(root, tokens, vocabulary, swaps, seen)
            add_children_to_query(root, mask, position, graph, tokens, vocabulary, swaps, seen)

            # Discard masks containing nodes not connected to the root.
            if bin(mask).count('1') != len(seen):
                continue

            text = ' '.join(lemmatizer.lemmatize(tok) for tok in tokens).lower()
            generated[text] = (document, swaps)

    return generated
            
# Augment original document to create hard negatives
def create_negatives(document, replaceables):
    """Return a perturbed copy of *document* to use as a hard negative.

    The document is lowercased and split on single spaces. Each token that
    matches a replaceable entity text is selected with probability
    ``REPLACEMENT_PROB``. A selected ``OBS-DP`` token may, on a second
    ``REPLACEMENT_PROB`` draw, be prefixed with "no" / "absence of" /
    "possible". Any other selected token is swapped for a vocabulary word with
    the same entity label and relation whose form, stem and lemma all differ.
    Up to ``MAX_ITERATIONS`` attempts are made to get at least one change. If
    none succeeds, the lowercased document is returned unchanged.

    Reads module-level ``vocabulary`` (including the ``'ALL'`` lists),
    ``MAX_ITERATIONS``, ``REPLACEMENT_PROB``, ``stemmer`` and ``lemmatizer``.

    Args:
        document: Document text.
        replaceables: ``(text, label, edge_label)`` tuples as built by ``visit_node``.

    Returns:
        The perturbed document joined with single spaces.
    """
    # The document is lowercased before splitting, so lowercase the keys too;
    # otherwise capitalised entity texts can never match a token.
    replaceable_dict = {token.lower(): (label, edge_label)
                        for token, label, edge_label in replaceables}
    tokens = document.lower().split(' ')
    res = list(tokens)  # defined even when no attempt is made
    changed = False
    iterations = 0

    # `<` makes MAX_ITERATIONS the number of attempts (`<=` allowed one extra).
    while not changed and iterations < MAX_ITERATIONS:
        iterations += 1
        res = list(tokens)
        probs = binom.rvs(1, REPLACEMENT_PROB, size=len(res))
        probs_inside = binom.rvs(1, REPLACEMENT_PROB, size=len(res))

        for idx, cur_token in enumerate(tokens):
            if not probs[idx] or cur_token not in replaceable_dict:
                continue
            label, edge_label = replaceable_dict[cur_token]

            if label == 'OBS-DP' and probs_inside[idx]:
                res[idx] = f"{random.choice(['no', 'absence of', 'possible'])} {cur_token}"
                changed = True
                continue

            pool = vocabulary[label][edge_label] if edge_label else vocabulary[label]['ALL']
            new_vocab = _pick_replacement(cur_token, pool)
            if new_vocab is not None:
                res[idx] = new_vocab
                changed = True

    return ' '.join(res)


def _pick_replacement(cur_token, pool, max_tries=100):
    """Return a random word from *pool* unlike *cur_token*, or None if none exists.

    "Unlike" means a different surface form, Porter stem and WordNet lemma.
    Rejection sampling is tried first because it is cheap for large pools. If it
    keeps failing, the pool is scanned exhaustively. A pool that holds only
    variants of *cur_token* (or is empty) therefore returns None rather than
    looping forever.
    """
    candidates = [word.lower() for word in pool]
    if not candidates:
        return None
    cur_stem = stemmer.stem(cur_token)
    cur_lemma = lemmatizer.lemmatize(cur_token)

    def is_different(word):
        return (word != cur_token
                and stemmer.stem(word) != cur_stem
                and lemmatizer.lemmatize(word) != cur_lemma)

    for _ in range(max_tries):
        word = random.choice(candidates)
        if is_different(word):
            return word

    remaining = [word for word in candidates if is_different(word)]
    return random.choice(remaining) if remaining else None

# Combine negatives, queries, positives (original documents) into triplets
def create_triplets(query, embed_docs, max_docs=10):
    """Register *query* with positive and hard-negative documents.

    Up to *max_docs* randomly chosen ``(document, replaceables)`` pairs from
    *embed_docs* become positives. Each positive is normalised: runs of
    underscores are collapsed and the text is lowercased. ``NEGATIVE_RATIO``
    negatives are drawn per positive with ``create_negatives``. A negative
    identical to its positive is dropped, since that would be a contradictory
    triplet. A positive left with no negatives is skipped. A query with no
    remaining positives is not registered.

    Updates module-level ``queries``/``query_set`` (query id <-> text),
    ``corpus``/``document_set`` (document id <-> text) and ``pos``/``neg``
    (query id -> document ids).

    Args:
        query: Query text (lowercased before use).
        embed_docs: Sequence of ``(document, replaceables)`` pairs. It is not
            modified; a shuffled copy is used.
        max_docs: Maximum number of positives sampled for this query.
    """
    query = query.lower()
    candidates = list(embed_docs)
    random.shuffle(candidates)

    triplets = []  # (positive_text, [negative_texts])
    for document, replaceables in candidates[:max_docs]:
        document = re.sub(r'_+', '_', document).lower()
        negatives = []
        for _ in range(NEGATIVE_RATIO):
            neg_doc = create_negatives(document, replaceables)
            if neg_doc != document:
                negatives.append(neg_doc)
        if negatives:
            triplets.append((document, negatives))

    if not triplets:
        return

    qid = _get_or_assign_id(query, query_set, queries)
    for document, negatives in triplets:
        pos[qid].append(_get_or_assign_id(document, document_set, corpus))
        for neg_doc in negatives:
            neg[qid].append(_get_or_assign_id(neg_doc, document_set, corpus))


def _get_or_assign_id(text, index, store):
    """Return *text*'s id in *index*, assigning the next sequential id if it is new."""
    if text not in index:
        new_id = len(index)
        index[text] = new_id
        store[new_id] = text
    return index[text]

## CARe-BERT Workflow

In [None]:
# Initialize dictionaries and sets
# vocabulary[entity_label][relation_label] -> set of lowercased entity texts
vocabulary = {
    entity_label: {relation: set() for relation in ('described_as', 'located_at', 'suggestive_of')}
    for entity_label in ('OBS-DP', 'OBS-U', 'OBS-DA', 'ANAT-DP')
}
# id -> text stores for queries and documents, plus per-query positive/negative id lists
queries = {}
corpus = {}
pos = defaultdict(list)
neg = defaultdict(list)
# text -> id reverse indexes and the assembled training records
query_set = {}
document_set = {}
train_queries = {}
# query text -> list of (document, replaceables)
qdr_dict = defaultdict(list)

# Initialize tools (sentence splitter, lemmatizer, stemmer)
tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
lemmatizer = WordNetLemmatizer()
stemmer = nltk.stem.PorterStemmer()

# CARe-BERT Pipeline Parameters
NEGATIVE_RATIO = 5      # negatives generated per positive document
MAX_ITERATIONS = 5      # retry budget used by create_negatives
REPLACEMENT_PROB = 0.3  # per-token replacement probability in create_negatives

**Load MIMIC Datasets and Process Them into Knowledge Forests**

In [None]:
# Load datasets
# load_MIMIC has no `test` parameter; the former test=True raised a TypeError.
data_pipeline = load_MIMIC("/kaggle/input/radgraph-data/MIMIC-CXR_graphs.json", "MIMIC_inference", save_csv=True)
data_gt_train = load_MIMIC("/kaggle/input/evaluation-data/MIMIC_train.json", "MIMIC_gt_train", save_csv=True)
data_gt_dev = load_MIMIC("/kaggle/input/evaluation-data/MIMIC_dev.json", "MIMIC_gt_dev", save_csv=True)
data_gt_test = load_MIMIC("/kaggle/input/evaluation-data/MIMIC_test.json", "MIMIC_gt_test", save_csv=True)

**Create Queries from Knowledge Graphs**

In [None]:
# Process data for the CARe-BERT pipeline.
# create_queries takes a single graph and its source document and returns
# {query: (document, replaceables)}. Call it for every graph of every report
# and collect the results per query. (Passing a whole DataFrame row as the
# graph raised KeyError 'nodes', and the return value was discarded.)
# NOTE(review): the full report text is used as the document — confirm that the
# graph's own sentences (graph['sentences']) were not intended instead.
for _, row in tqdm(data_pipeline.iterrows(), total=len(data_pipeline)):
    for graph in row['graphs']:
        for query, doc_info in create_queries(graph, row['text'], vocabulary).items():
            qdr_dict[query].append(doc_info)

# Update vocabulary with combined categories (the pool used when no edge label applies)
for key in vocabulary:
    vocabulary[key]['ALL'] = list(vocabulary[key]['located_at'] | vocabulary[key]['described_as'] | vocabulary[key]['suggestive_of'])

**Create Negatives and Form Triplets**

In [None]:
# Create negatives and triplets for every collected query
for query, embed_docs in tqdm(qdr_dict.items()):
    create_triplets(query, embed_docs)

# Prepare training records: one {'qid', 'query', 'pos', 'neg'} dict per query id
train_queries.update(
    (qid, {'qid': qid, 'query': text, 'pos': pos[qid], 'neg': neg[qid]})
    for qid, text in tqdm(queries.items())
)

## Fine-Tuning S-BERT

In [None]:
# Dataloader class
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``InputExample(texts=[query, positive, negative])``.

    The base class is named through ``torch.utils.data``. This class rebinds the
    module-level name ``Dataset``, so a bare ``Dataset`` base would resolve to
    the previous definition of this class if the cell were re-run.

    ``__getitem__`` is stateful: each call for a query takes the first positive
    and first negative id and moves them to the end of their lists. Repeated
    access therefore cycles through the query's documents.
    """

    def __init__(self, queries, corpus):
        """
        Args:
            queries: Mapping qid -> ``{'query': str, 'pos': ids, 'neg': ids, ...}``.
                The ``pos``/``neg`` entries are replaced in place by fresh lists
                (negatives shuffled).
            corpus: Mapping document id -> document text.

        Raises:
            ValueError: if a query has no positive or no negative ids. Such a
                query would make ``__getitem__`` fail on ``pop(0)``, which during
                plain iteration silently ends iteration early.
        """
        self.queries = queries
        self.queries_ids = list(queries.keys())
        self.corpus = corpus

        for qid, entry in self.queries.items():
            # Fresh lists so the rotation below never mutates the original sequences.
            entry['pos'] = list(entry['pos'])
            entry['neg'] = list(entry['neg'])
            if not entry['pos'] or not entry['neg']:
                raise ValueError(
                    f"query {qid!r} needs at least one positive and one negative document"
                )
            random.shuffle(entry['neg'])

    def __getitem__(self, item):
        query = self.queries[self.queries_ids[item]]
        query_text = query['query']

        # Rotate: take the head id, then move it to the back for the next call.
        pos_id = query['pos'].pop(0)
        pos_text = self.corpus[pos_id]
        query['pos'].append(pos_id)

        neg_id = query['neg'].pop(0)
        neg_text = self.corpus[neg_id]
        query['neg'].append(neg_id)

        return InputExample(texts=[query_text, pos_text, neg_text])

    def __len__(self):
        return len(self.queries)

In [None]:
# Split data into train, validation, and test sets.
# Keys are shuffled, then sliced at 40% / 50% / 60% of the total:
# 40% train, 10% validation, 10% test; the remaining 40% is not used.
all_keys = list(train_queries)
random.shuffle(all_keys)

n_total = len(train_queries)
train_end, val_end, test_end = (int(n_total * fraction) for fraction in (0.4, 0.5, 0.6))

train_keys = all_keys[:train_end]
val_keys = all_keys[train_end:val_end]
test_keys = all_keys[val_end:test_end]

queries_subset_train, queries_subset_val, queries_subset_test = (
    {key: train_queries[key] for key in subset}
    for subset in (train_keys, val_keys, test_keys)
)

**Transfer Learning**

In [None]:
# Configure logging
logging.basicConfig(
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
    handlers=[logging.StreamHandler()]
)

# Parse arguments
parser = argparse.ArgumentParser()
parser.add_argument("--train_batch_size", default=64, type=int)
parser.add_argument("--max_seq_length", default=100, type=int)
parser.add_argument("--model_name", required=True)
# NOTE: --max_passages, --negs_to_use and --use_all_queries are parsed but not
# referenced anywhere below.
parser.add_argument("--max_passages", default=0, type=int)
parser.add_argument("--epochs", default=10, type=int)
parser.add_argument("--pooling", default="mean")
parser.add_argument("--negs_to_use", default=None)
# Name of a class in torch.optim; passed to model.fit below.
parser.add_argument("--optimizer_class", default="AdamW")
parser.add_argument("--lr", default=2e-5, type=float)
parser.add_argument("--warmup_steps", default=100, type=int)
parser.add_argument("--weight_decay", default=0.01, type=float)
# A store_true flag whose default is True can never be switched off. Pair it
# with an explicit negative flag that writes to the same destination.
parser.add_argument("--use_pre_trained_model", dest="use_pre_trained_model", action="store_true")
parser.add_argument("--no_pre_trained_model", dest="use_pre_trained_model", action="store_false")
parser.set_defaults(use_pre_trained_model=True)
parser.add_argument("--use_all_queries", default=False, action="store_true")
args = parser.parse_args(["--model_name", "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"])

print(args)

model_name = args.model_name
model_save_path = f'{model_name.replace("/", "-")}-{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}'

# Load or create SBERT model
if args.use_pre_trained_model:
    print("Using pretrained SBERT model")
    model = SentenceTransformer(model_name)
    model.max_seq_length = args.max_seq_length
else:
    print("Creating new SBERT model")
    word_embedding_model = models.Transformer(model_name, max_seq_length=args.max_seq_length)
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), args.pooling)
    model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

# Prepare datasets and evaluators
train_dataset = Dataset(queries_subset_train, corpus=corpus)
dev_dataset = Dataset(queries_subset_val, corpus=corpus)
test_dataset = Dataset(queries_subset_test, corpus=corpus)

dev_evaluator = TripletEvaluator.from_input_examples(dev_dataset, name='dev')
test_evaluator = TripletEvaluator.from_input_examples(test_dataset, name='test')

train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.train_batch_size)
train_loss = losses.TripletLoss(model=model)

# Evaluate before fine-tuning
print("Performance before fine-tuning:")
print(dev_evaluator(model))

# Train the model (--optimizer_class was previously parsed but never used)
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=dev_evaluator,
    epochs=args.epochs,
    warmup_steps=args.warmup_steps,
    weight_decay=args.weight_decay,
    use_amp=True,
    checkpoint_path=model_save_path,
    checkpoint_save_steps=len(train_dataloader),
    optimizer_class=getattr(torch.optim, args.optimizer_class),
    optimizer_params={'lr': args.lr}
)

# Save the model
model.save(model_save_path)

# Evaluate the model on the test set
print("Evaluating model on test set")
print(model.evaluate(test_evaluator))