# **Project Notebook**

This notebook documents the steps taken to produce the data - a series of relevant *subject, relationship, object* triplets - that are later visualised as a network. The process is broken down as follows:

0. **Package Installation**.

1. **Downloading the Data** using E-Utilities and Biopython.

2. **Relationship Extraction** using REBEL.

3. **Key-Word Extraction** using KeyBERT and BioBERT.

4. **Lemmatization and Entity Resolution** using Stanza.

5. **Visualisation**.

6. **Evaluation** using PyKEEN.

7. **Other Attempts at Entity Resolution** (extra).

## **0. Package Installation.**

The cell below installs the packages required to run the notebook.

In [None]:
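import sys
!{sys.executable} -m pip install pandas
!{sys.executable} -m pip install keybert
!{sys.executable} -m pip install nltk
!{sys.executable} -m pip install torch
!{sys.executable} -m pip install transformers
!{sys.executable} -m pip install Levenshtein
!{sys.executable} -m pip install scikit-learn
!{sys.executable} -m pip install BioPython
!{sys.executable} -m pip install pykeen
# Also used later in the notebook.
!{sys.executable} -m pip install stanza
!{sys.executable} -m pip install networkx
!{sys.executable} -m pip install tqdm

# nltk must be imported before its resources can be downloaded.
import nltk
nltk.download('all')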

## **1. Downloading the Data using E-Utilities and Biopython.**
To download the data, we will make use of E-Utilities (NCBI Entrez Programming Utilities), a set of tools designed to facilitate the process of downloading large sets of bioinformatic data.

A general introduction to the E-Utilities:
- https://www.ncbi.nlm.nih.gov/books/NBK25497/
- *'A set of nine server-side programs that provide a stable interface into the Entrez query and database system at the NCBI'*.
- Uses a fixed URL syntax that translates a standard set of input parameters into the values necessary for various NCBI software components to search for and retrieve the requested data.
- The E-utilities are therefore the structured interface to the Entrez system, which currently includes 38 databases covering a variety of biomedical data, including nucleotide and protein sequences, gene records, three-dimensional molecular structures, and the biomedical literature.
- To access data, a program sends an E-utility URL to NCBI, then retrieves and processes the response.
- Any programming language that can send a URL to the E-utilities server and interpret the XML response can be used (e.g. Python, Perl, Java, C++).
- NCBI requests that users limit requests to no more than 3 per second.
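As a rough illustration of the fixed URL syntax described above, an ESearch request can be assembled with the standard library alone. The helper name `build_esearch_url` is ours and purely illustrative; the sketch only constructs the URL and sends nothing to NCBI.

```python
from urllib.parse import urlencode

# Base address of the E-utilities server.
EUTILS_BASE = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"

def build_esearch_url(term, db="pubmed", retmax=20):
    """Return an ESearch URL for the given search term (illustrative helper)."""
    params = {"db": db, "term": term, "retmax": retmax}
    return f"{EUTILS_BASE}/esearch.fcgi?{urlencode(params)}"

url = build_esearch_url("biology")
print(url)
```

In practice, the notebook leaves this step to Biopython, which builds and sends the equivalent request.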

A combination of *ESearch* and *EFetch* can be used to find and retrieve the relevant data.

To make use of these tools, a program must send an 'E-Utility URL' to NCBI. ```BioPython``` is a library that provides a module called Entrez, which sends these URLs from Python.

Below, it is used to fetch the abstracts of up to 100,000 articles matching the term *'biology'*. The abstracts are then saved to a single text file, 'data/bio_corpus.txt'.

In [None]:
from Bio import Entrez
import xml.etree.ElementTree as ET

# Setting email.
Entrez.email = 'aidanlowrie@example.com'

# Setting search word.
search_word = 'biology'

# Search for PubMed articles related to the search word 'biology', returning up to 100,000 results. 
search_handle = Entrez.esearch(db='pubmed', term=search_word, retmax=100000)
record = Entrez.read(search_handle)
search_handle.close()

# A list of uids is necessary for fetching the actual abstracts.
uids = record['IdList']

# Fetch the abstracts in XML form, so that the actual abstract may be extracted.
fetch_handle = Entrez.efetch(db="pubmed", id=','.join(uids), 
                             rettype="abstract", retmode="xml") # Return data in XML form.
data = fetch_handle.read()
fetch_handle.close()

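# Parsing the XML.
root = ET.fromstring(data)
# Extracting the abstracts themselves from the returned data.
abstracts = root.findall(".//AbstractText")

# Write the abstracts to a file.
with open("data/bio_corpus.txt", "w") as file:
    for abstract in abstracts:
        # itertext() also collects text nested inside inline tags (e.g. <i> or <sub>);
        # abstract.text alone would be cut short, or None, in those cases.
        text = "".join(abstract.itertext()).strip()
        if text:
            file.write(text + "\n\n") # Two newlines are added between abstracts, for clarity.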
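Sending every UID in one request can be heavy on the server, so one option is to fetch in batches and pause between calls to respect the rate limit mentioned earlier. The sketch below shows only the batching logic. `fetch_batch` is a hypothetical callable (for instance, a thin wrapper around `Entrez.efetch`), and the batch size and delay are assumptions rather than values used in this project.

```python
import time

def chunked(items, size):
    """Split a list into consecutive batches of at most `size` items."""
    return [items[i:i + size] for i in range(0, len(items), size)]

def fetch_in_batches(uids, fetch_batch, batch_size=200, delay=0.34):
    """Call `fetch_batch` once per batch of UIDs, pausing between calls.

    `fetch_batch` is a hypothetical callable that takes a list of UIDs
    and returns a list of results.
    """
    results = []
    for batch in chunked(uids, batch_size):
        results.extend(fetch_batch(batch))
        time.sleep(delay)  # roughly three requests per second at most
    return results

# Stand-in fetcher, so that the sketch runs without network access.
demo = fetch_in_batches(
    ["1", "2", "3", "4", "5"],
    fetch_batch=lambda batch: [f"abstract-{uid}" for uid in batch],
    batch_size=2,
    delay=0,
)
print(demo)
```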

## **2. Relationship Extraction using REBEL.**
Relationship extraction involves finding triples - *subject (head)*, *relationship (type)* and *object (tail)* - in a corpus. While this can be achieved through a variety of techniques, including pattern matching and supervised machine learning, we chose to use REBEL (Relation Extraction By End-to-end Language generation).

[REBEL](https://aclanthology.org/2021.findings-emnlp.204.pdf) is an open-source seq2seq relationship extraction model released in 2021. We chose it for its accessibility and for its state-of-the-art performance at the time of release.

Before passing the corpus into the model, we used ```NLTK``` to break it into sentences.

In [None]:
# Importing nltk resources.
import nltk
from nltk.corpus.reader import PlaintextCorpusReader
from nltk.tokenize import sent_tokenize, word_tokenize

# Opening corpus.
with open("data/bio_corpus.txt", "r") as file:
    corpus_text = file.read()

corpus_sentences = sent_tokenize(corpus_text)

The ```extract_triplets``` function was taken directly from the [Hugging Face REBEL model card](https://huggingface.co/Babelscape/rebel-large). It parses the text generated by REBEL into a list of triplets.

In [None]:
# Parse REBEL output into a list of triplets. 
def extract_triplets(text):
    triplets = []
    relation, subject, object_ = '', '', ''
    text = text.strip()
    current = 'x'
    for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split():
        if token == "<triplet>":
            current = 't'
            if relation != '':
                triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            if relation != '':
                triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    if subject != '' and relation != '' and object_ != '':
        triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
    return triplets
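To make the linearised format easier to picture, here is a simplified, stand-alone parser. It is not the model card's function, only our own illustrative re-implementation, and it assumes well-formed output in which each `<triplet>` head is followed by one or more `<subj> tail <obj> relation` groups. The example string is made up for illustration.

```python
import re

def parse_linearised(text):
    """Simplified parser for REBEL-style output (illustrative only)."""
    # Drop the sequence markers <s>, </s> and <pad>.
    text = re.sub(r"</?s>|<pad>", "", text)
    triplets = []
    # Everything after each <triplet> marker starts with the head entity.
    for chunk in text.split("<triplet>")[1:]:
        head, *groups = chunk.split("<subj>")
        for group in groups:
            if "<obj>" not in group:
                continue  # incomplete group, skipped in this sketch
            tail, relation = group.split("<obj>", 1)
            triplets.append({"head": head.strip(),
                             "type": relation.strip(),
                             "tail": tail.strip()})
    return triplets

example = ("<s><triplet> telomere <subj> chromosome <obj> part of "
           "<subj> DNA <obj> has part</s>")
parsed = parse_linearised(example)
print(parsed)
```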

The code below loads the REBEL model and writes the extracted triplets to a file. (This had to be carried out in batches over several nights, which is why the code was adapted to include a ```start_line```.)

In [None]:
import csv
import pandas as pd
from transformers import pipeline

# Loading the model.
triplet_extractor = pipeline('text2text-generation', model='Babelscape/rebel-large', tokenizer='Babelscape/rebel-large')

# Establishing a start line (the index of the first sentence to process in this batch).
start_line = 0

# The model can only handle inputs of at most 1024 tokens. Sentences exceeding this limit are skipped. (This is very rare, but the check is necessary.)
max_token_length = 1024

# Opening a csv file for triplet storage.
with open("data/triplets_batch4.csv", "w", newline="") as file:
    # Field names.
    field_names = ['head', 'type', 'tail']
    writer = csv.DictWriter(file, fieldnames=field_names)
    writer.writeheader()
    for i, sentence in enumerate(corpus_sentences):
        # '>=' so that the sentence at start_line itself is included.
        if i >= start_line and len(triplet_extractor.tokenizer.encode(sentence)) <= max_token_length:
            # Generate token ids, then decode them without stripping REBEL's special tokens.
            output = triplet_extractor(sentence, return_tensors=True, return_text=False)
            extracted_text = triplet_extractor.tokenizer.batch_decode([output[0]["generated_token_ids"]])
            triplets = extract_triplets(extracted_text[0])
            for triplet in triplets:
                writer.writerow(triplet)

# Merge triples into a single file.
def csv_union(csv_path_list, output_csv_path):
    dfs = []
    # Named 'csv_path' to avoid confusion with the csv module imported above.
    for csv_path in csv_path_list:
        dfs.append(pd.read_csv(csv_path))
    df_union = pd.concat(dfs)
    df_union.to_csv(output_csv_path, index=False)

csv_union(csv_path_list=['data/triplets_batch1.csv', 'data/triplets_batch2.csv', 'data/triplets_batch3.csv', 'data/triplets_batch4.csv'], output_csv_path='data/triplets.csv')
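An alternative to writing one file per batch and merging them afterwards would be to append to a single file and carry the next start index from run to run. The sketch below illustrates that idea with a stand-in extractor. The names `append_triplets` and `fake_extract` are hypothetical, and this is not how the batches above were actually produced.

```python
import csv
import os
import tempfile

FIELD_NAMES = ["head", "type", "tail"]

def append_triplets(sentences, start_line, stop_line, extract, path):
    """Process sentences[start_line:stop_line] and append their triplets to `path`.

    `extract` is any callable mapping a sentence to a list of triplet dicts.
    Returns the index at which the next run should start.
    """
    needs_header = not os.path.exists(path) or os.path.getsize(path) == 0
    with open(path, "a", newline="") as file:
        writer = csv.DictWriter(file, fieldnames=FIELD_NAMES)
        if needs_header:
            writer.writeheader()
        for sentence in sentences[start_line:stop_line]:
            writer.writerows(extract(sentence))
    return min(stop_line, len(sentences))

# Stand-in extractor, so that the sketch runs without loading REBEL.
def fake_extract(sentence):
    return [{"head": sentence, "type": "instance of", "tail": "sentence"}]

sentences = ["a", "b", "c", "d", "e"]
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "triplets.csv")
    next_start = append_triplets(sentences, 0, 3, fake_extract, path)           # first run
    next_start = append_triplets(sentences, next_start, 5, fake_extract, path)  # second run
    with open(path, newline="") as file:
        rows = list(csv.DictReader(file))

print(next_start, [row["head"] for row in rows])
```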

## **3. Key-Word Extraction using KeyBERT and BioBERT.**

Originally, we tried to extract keywords by getting cosine similarity scores of terms against an average 'keyword embedding', which was generated by averaging the embeddings of a range of biology-related terms. This proved to be ineffective. We later found out about ```KeyBERT```, a library designed to facilitate the process of keyword extraction. 

KeyBERT works similarly, but generates an '**abstract embedding**' for each abstract, rather than relying on a curated 'keyword embedding'. Candidate keywords taken from the abstract are ranked against this embedding by **cosine similarity**, just like in the original method. The candidates are n-grams whose length range can be specified.

In the code, KeyBERT is used on top of ```BioBERT``` (a version of BERT pre-trained on a biological corpus).
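The ranking idea can be sketched without any model at all: score each candidate embedding against the abstract embedding by cosine similarity and sort. The three-dimensional vectors below are made up purely for illustration; real BioBERT embeddings have many more dimensions.

```python
import math

def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0

# Hypothetical embeddings, for illustration only.
abstract_embedding = [0.9, 0.1, 0.3]
candidate_embeddings = {
    "telomere": [0.8, 0.2, 0.3],
    "study": [0.1, 0.9, 0.2],
    "chromosome": [0.7, 0.0, 0.5],
}

# Candidates closest to the abstract embedding come first.
ranked = sorted(
    candidate_embeddings,
    key=lambda word: cosine(candidate_embeddings[word], abstract_embedding),
    reverse=True,
)
print(ranked)
```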

In [None]:
import csv
import pandas as pd
from keybert import KeyBERT
from nltk.corpus import stopwords
from transformers import AutoModel, AutoTokenizer
from sklearn.metrics.pairwise import cosine_similarity

# Loading the corpus. Abstracts are separated by blank lines, which are skipped here.
with open('data/bio_corpus.txt', 'r') as file:
    abstracts = [line.strip() for line in file if line.strip()]

# Loading the BioBERT model.
model_name = "dmis-lab/biobert-v1.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
biobert_model = AutoModel.from_pretrained(model_name)

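# Loading the KeyBERT model, running on BioBERT.
# KeyBERT is given the model name rather than the raw AutoModel object: a raw AutoModel is not
# one of the inputs KeyBERT recognises, whereas a model name is loaded through
# sentence-transformers, which wraps BioBERT with mean pooling.
kw_model = KeyBERT(model=model_name)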

# Loading triplets dataframe. The file already has a header row (written by csv_union), so no column names are passed.
df = pd.read_csv('data/triplets.csv')
# Remove null values.
df = df[df['head'].notna() & df['tail'].notna()]

# Creating a list of stopwords from NLTK.
stop_words = list(set(stopwords.words('english')))

# Extracting keywords and writing them to a CSV.
keywords = set()
with open("data/keywords.csv", "w", newline="") as file:
    writer = csv.DictWriter(file, fieldnames=['Keyword'])
    writer.writeheader()
    for abstract in abstracts:
        # Extracting unigrams, bigrams and trigrams in turn.
        for n in (1, 2, 3):
            extracted = kw_model.extract_keywords(abstract, keyphrase_ngram_range=(n, n), stop_words=stop_words)
            new_keywords = {keyword for keyword, _ in extracted}
            # Only keywords that have not been seen before are written.
            novel_keywords = new_keywords - keywords
            keywords.update(novel_keywords)
            for novel_keyword in novel_keywords:
                writer.writerow({'Keyword': novel_keyword})

# Filter the dataframe based on a list of keywords.
def filter_dataframe(df, relevant_column_names, filter_set):
    # Work on a copy, so that the caller's dataframe is not modified.
    df = df.copy()
    filter_set = set(filter_set)
    for column_name in relevant_column_names:
        # Split hyphenated words.
        df[column_name] = df[column_name].str.replace('-', ' ', regex=False)
        # Remove rows whose value in this column is in the filter set.
        df = df[~df[column_name].isin(filter_set)].copy()
    return df

keyword_data = pd.read_csv('data/keywords.csv')
keywords = keyword_data['Keyword'].tolist()
filtered_df = filter_dataframe(df=df, relevant_column_names=['head', 'tail'], filter_set=keywords)
filtered_df.to_csv('data/filtered_triplets.csv', index=False)

## **4. Lemmatization and Entity Resolution using Stanza.**

The library ```stanza``` was chosen for the task of lemmatization because it offers lemmatization models trained on biomedical text, which makes it less likely to lemmatize domain-specific terms incorrectly. Only unigrams were lemmatized and merged, as lemmatized bigrams and trigrams might be more difficult for humans to interpret.

In [None]:
import pandas as pd
import stanza

# Loading the data.
df = pd.read_csv('data/filtered_triplets.csv')

# Downloading the lemmatizer.
stanza.download('en', package='craft')
nlp = stanza.Pipeline('en', processors='tokenize,lemma', package='craft')

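# Lemmatize a given text. Only unigrams are lemmatized; anything else is returned unchanged.
def lemmatize_text(text):
    # Guard against non-string values (e.g. NaN) that may appear after reading the CSV.
    if not isinstance(text, str):
        return text
    if ' ' not in text:
        doc = nlp(text)
        lemmas = [word.lemma for sent in doc.sentences for word in sent.words]
        return lemmas[0] if lemmas else text
    else:
        return text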

# Lemmatizing the head and tail columns.
df['head'] = df['head'].apply(lemmatize_text)
df['tail'] = df['tail'].apply(lemmatize_text)

# Creating new 'pair' column. This will be how we determine the most common relationships between pairs. 
df['pair'] = df['head'] + ',' + df['tail']

# Grouping by pair and type, counting frequency.
df_grouped = df.groupby(['pair', 'type', 'head', 'tail']).size().reset_index(name='counts')

# Sorting by pair and counts, and dropping all but one instance of the most common (first). 
df_most_common = df_grouped.sort_values(['pair', 'counts'], ascending=False).drop_duplicates(subset='pair').sort_index()

# Dropping unnecessary 'pair' and 'counts' columns.
df_most_common = df_most_common.drop(columns=['pair', 'counts'])

# Reordering.
df_most_common = df_most_common[['head', 'type', 'tail']]

# Saving to CSV.
df_most_common.to_csv('data/filtered_lemmatized_triplets.csv', index=False)
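The 'most common relationship per pair' step can also be pictured in plain Python: count every (head, relationship, tail) triple, then keep the relationship with the highest count for each (head, tail) pair. The triples below are invented for illustration, and ties are resolved in favour of the relationship seen first, which is an assumption of this sketch.

```python
from collections import Counter

# Hypothetical triples, for illustration only.
triples = [
    ("telomere", "part of", "chromosome"),
    ("telomere", "part of", "chromosome"),
    ("telomere", "subclass of", "chromosome"),
    ("enzyme", "subclass of", "protein"),
]

counts = Counter(triples)

# For each (head, tail) pair, remember the relationship with the highest count.
best = {}
for (head, relation, tail), count in counts.items():
    pair = (head, tail)
    if pair not in best or count > best[pair][1]:
        best[pair] = (relation, count)

most_common = [(head, relation, tail) for (head, tail), (relation, _) in best.items()]
print(most_common)
```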




## **5. Visualisation.**

Visualisation code can be found in this notebook (Aidan) and a separate one (Boray), using different methods.

In [None]:
import pandas as pd
import numpy as np
import networkx as nx

# Read the CSV file into a pandas dataframe.
df = pd.read_csv("data/filtered_lemmatized_triplets.csv")

# Create NetworkX DiGraph from dataframe.
df.rename(columns={'type':'relation'}, inplace=True)
G = nx.from_pandas_edgelist(df, source='head', target='tail', edge_attr='relation', create_using=nx.DiGraph())

# Assign edge labels.
for u, v, data in G.edges(data=True):
    data['label'] = data['relation']

nx.write_gexf(G, "big_triplet_network.gexf")

import random

def filter_edges_by_relation(G, relation_type):
    H = G.copy()
    for u, v, data in G.edges(data=True):
        if data['relation'] != relation_type:
            H.remove_edge(u, v)
    return H

def prune_graph(G, target_node, radius):
    # Create an ego graph from the target node. This is a new graph, so G itself is left untouched.
    H = nx.ego_graph(G, target_node, radius=radius, center=True, undirected=False)
    
    # Iterate over all edges.
    for u, v in list(H.edges):
        # If there are edges in both directions...
        if H.has_edge(u, v) and H.has_edge(v, u):
            # Compute shortest paths *from* the target node.
            path_u = nx.shortest_path_length(H, target_node, u)
            path_v = nx.shortest_path_length(H, target_node, v)
            # Remove the edge that is farther from the target_node.
            if path_u < path_v:
                H.remove_edge(v, u)
            elif path_u > path_v:
                H.remove_edge(u, v)
            else:
                # If there's a tie, randomly remove one of the two edges.
                if random.choice([True, False]):
                    H.remove_edge(u, v)
                else:
                    H.remove_edge(v, u)
    return H

# Create subgraph for a particular target node and write it to a GEXF file.
def generate_subgraph(G, target_node, radius, relation_type=None):
    H = prune_graph(G, target_node=target_node, radius=radius)
    if relation_type is not None:
        H = filter_edges_by_relation(H, relation_type)
    nx.write_gexf(H, f'graphs/{target_node}_{radius}_{relation_type}.gexf')

The ```generate_subgraph``` function above can be used to generate a subgraph around a target term (node) in the knowledge graph, with only one edge allowed between any two terms (preferring outgoing edges from the node closest to the target). The call below does this for the term *'telomere'*. Examples of graphs created by this method can be found in the 'data/graphs' directory.

In [None]:
generate_subgraph(G=G, target_node='telomere', radius=1, relation_type=None)
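The pruning rule can be illustrated on a toy edge set without NetworkX. The sketch below computes hop counts from the target with a breadth-first search, keeps only edges between nodes within the radius, and, wherever two nodes are linked in both directions, keeps the edge leaving the node nearer the target. Unlike the notebook code, ties are broken deterministically rather than at random, and the edges themselves are invented for illustration.

```python
from collections import deque

def bfs_distances(edges, source, radius):
    """Hop counts from `source` along directed edges, up to `radius` hops."""
    adjacency = {}
    for u, v in edges:
        adjacency.setdefault(u, []).append(v)
    distances = {source: 0}
    queue = deque([source])
    while queue:
        node = queue.popleft()
        if distances[node] == radius:
            continue  # do not expand beyond the radius
        for neighbour in adjacency.get(node, []):
            if neighbour not in distances:
                distances[neighbour] = distances[node] + 1
                queue.append(neighbour)
    return distances

def prune_reciprocal(edges, source, radius):
    """Keep one edge per linked pair, preferring the edge leaving the nearer node."""
    dist = bfs_distances(edges, source, radius)
    kept = {(u, v) for u, v in edges if u in dist and v in dist}
    for u, v in sorted(kept):
        if (u, v) in kept and (v, u) in kept:
            if dist[u] <= dist[v]:
                kept.discard((v, u))  # tie or u nearer: drop the reverse edge
            else:
                kept.discard((u, v))
    return kept

# Hypothetical edges, for illustration only.
edges = {
    ("telomere", "chromosome"), ("chromosome", "telomere"),
    ("chromosome", "dna"), ("dna", "chromosome"),
    ("dna", "gene"),
}
pruned = prune_reciprocal(edges, source="telomere", radius=2)
print(sorted(pruned))
```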

## **6. Evaluation using PyKEEN.**

PyKEEN is a library for training and evaluating knowledge graph embedding models, which learn to predict missing links from triplets such as those produced above. Below, PyKEEN is used to train such a model for 100 epochs on 80% of the triplets, and the model's performance is then evaluated on the remaining 20%. By evaluating the performance of the model, we can infer information about the KG itself.

In [None]:
from pykeen.pipeline import pipeline
from pykeen.triples import TriplesFactory

# Loading data.
df = pd.read_csv('drive/MyDrive/data/filtered_lemmatized_triplets.csv')
df = df.astype(str)

# Creating a TriplesFactory object to pass into the pipeline.
triples_factory = TriplesFactory.from_labeled_triples(df.values)
# Splitting object into training and testing data.
training, testing = triples_factory.split([.8, .2])  # 80% training, 20% testing.

# Training the model.
pipeline_result = pipeline(
    model='TransE',
    training=training,
    testing=testing,
    random_seed=101,
    device='cuda', 
    training_kwargs=dict(num_epochs=100),
)

# Evaluating the model.
results_df = pipeline_result.metric_results.to_df()
# Saving the evaluation.
results_df.to_csv('drive/MyDrive/data/metric_results_lemmatized_triplets_model.csv')
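Link-prediction evaluation of this kind is usually summarised with rank-based metrics such as mean reciprocal rank (MRR) and Hits@k. As a reminder of what those numbers mean, the sketch below computes both from a list of ranks, where each rank is the position of the true entity among the model's scored candidates. The ranks are invented for illustration and are not results from this project.

```python
def mean_reciprocal_rank(ranks):
    """Average of 1/rank; 1.0 means the true entity was always ranked first."""
    return sum(1 / rank for rank in ranks) / len(ranks)

def hits_at_k(ranks, k):
    """Fraction of test triples whose true entity was ranked in the top k."""
    return sum(rank <= k for rank in ranks) / len(ranks)

# Hypothetical ranks for four test triples.
ranks = [1, 2, 4, 10]
mrr = mean_reciprocal_rank(ranks)
hits_3 = hits_at_k(ranks, k=3)
print(mrr, hits_3)
```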

This concludes the processes used to create and evaluate the KG.

## **7. Other Attempts at Entity Resolution. (Extra)**

#### **Using BioBERT Embeddings and Levenshtein Distance.**
While lemmatization is useful for normalising text, we hoped that using BioBERT embeddings would allow for entity resolution in more ambiguous cases (such as 'Alzheimers' and 'Alzheimers Disease'). We made several attempts to perform entity resolution using BioBERT embeddings, but these proved ineffective. Our final attempt, which combines BioBERT embeddings with Levenshtein (edit) distance, is below.

In [None]:
import torch
import random
import Levenshtein

# Get a list of each unique term in the df.
unique_words = list(set(df["head"].unique().tolist() 
                        + df["tail"].unique().tolist()))

# Tokenise.
unique_words_tokenised = [tokenizer(word, return_tensors="pt") for word in unique_words]

# Feeding the tokenised words into the model to get a list of unique_word_embeddings.
with torch.no_grad():
    unique_word_outputs = [biobert_model(**tokens) for tokens in unique_words_tokenised]
unique_word_embeddings = [output.last_hidden_state.mean(dim=1).numpy() for output in unique_word_outputs]

# Creating a DataFrame with the words and their embeddings
df_embeddings = pd.DataFrame({
    'word': unique_words,
    'embedding': unique_word_embeddings
})

def get_cosine_similarity(df_embeddings, word_1, word_2):
    # Getting the embeddings for the two words
    embedding1_df = df_embeddings[df_embeddings['word'] == word_1]['embedding']
    embedding2_df = df_embeddings[df_embeddings['word'] == word_2]['embedding']
    if embedding1_df.empty or embedding2_df.empty:
        return 0
    else:
        embedding1 = embedding1_df.values[0]
        embedding2 = embedding2_df.values[0]
        # Computing and returning cosine similarity scores
        return cosine_similarity(embedding1, embedding2)[0][0] # type: ignore

def get_levenshtein_similarity(word_1, word_2):
    return 1 - Levenshtein.distance(word_1, word_2) / max(len(word_1), len(word_2))

# Merging similar words by comparing each unique word against every later word in the list.
def merge_similar_words(df, df_embeddings, cosine_threshold, levenshtein_threshold):
    word_replacements = {}
    unique_words = df_embeddings['word'].to_list()
    # Iterating by index: removing items from a list while looping over it skips elements.
    for i, word_1 in enumerate(unique_words):
        for word_2 in unique_words[i + 1:]:
            levenshtein_score = get_levenshtein_similarity(word_1=word_1, word_2=word_2)
            cosine_score = get_cosine_similarity(df_embeddings=df_embeddings, word_1=word_1, word_2=word_2)
            if (cosine_score > cosine_threshold and levenshtein_score > levenshtein_threshold) or levenshtein_score > 0.85:
                print('SUCCESS', '1', word_1, '2', word_2, 'cos', cosine_score, 'lev', levenshtein_score)
                # Separate names are used for the candidates, so that the loop variables are not overwritten.
                candidate_1, candidate_2 = word_1, word_2
                if word_1.lower() == word_1 or word_2.lower() == word_2:
                    candidate_1 = word_1.lower()
                    candidate_2 = word_2.lower()
                # The shorter candidate wins; ties are broken at random.
                if len(candidate_1) < len(candidate_2):
                    winner = candidate_1
                elif len(candidate_2) < len(candidate_1):
                    winner = candidate_2
                else:
                    winner = random.choice((candidate_1, candidate_2))
                # Both original spellings are mapped to the winner, as these are what appear in the dataframe.
                for original in (word_1, word_2):
                    if original != winner:
                        word_replacements[original] = winner
    df_replaced = df.replace(word_replacements)
    return df_replaced

df_replaced = merge_similar_words(df=df, df_embeddings=df_embeddings, cosine_threshold=0.97, levenshtein_threshold=0.6)
df_replaced.to_csv('data/final_triplets.csv', index=False)

The main issue was that some words were scored as semantically very similar to just about everything. To limit this, the cosine threshold was raised and the Levenshtein scores were introduced, but problems remained. Terms like *RNA virus* and *DNA virus* would be merged by the algorithm, despite clearly referring to different entities. Another approach was needed.
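The *RNA virus* / *DNA virus* case is easy to reproduce by hand. The two terms differ by a single substitution in nine characters, so their normalised similarity is 1 - 1/9, or roughly 0.89, which is above the 0.85 cut-off used in the code. The sketch below uses a plain dynamic-programming edit distance in place of the ```Levenshtein``` package, so it runs with the standard library alone.

```python
def levenshtein_distance(a, b):
    """Edit distance via the standard row-by-row dynamic programme."""
    previous = list(range(len(b) + 1))
    for i, char_a in enumerate(a, start=1):
        current = [i]
        for j, char_b in enumerate(b, start=1):
            cost = 0 if char_a == char_b else 1
            current.append(min(previous[j] + 1,          # deletion
                               current[j - 1] + 1,       # insertion
                               previous[j - 1] + cost))  # substitution
        previous = current
    return previous[-1]

def levenshtein_similarity(a, b):
    """1 minus the distance, normalised by the longer string's length."""
    longest = max(len(a), len(b))
    return 1 - levenshtein_distance(a, b) / longest if longest else 1.0

score = levenshtein_similarity("RNA virus", "DNA virus")
print(score, score > 0.85)
```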

#### **Finetuning an Entity Resolution Model.**
Next, we tried to finetune a BioBERT model for the task of entity resolution. I still think that this is a promising approach, but a suitable dataset does not exist yet. I tried to build one myself, but this proved infeasible with the time and resources available. With a large enough dataset, I anticipate that this could be a successful solution to the entity resolution task.

In [None]:
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import DataLoader, Dataset
import torch
import pandas as pd
from torch.optim import Adam
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from itertools import combinations

# Loading BioBERT model and tokenizer.
tokenizer = BertTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.1')
model = BertForSequenceClassification.from_pretrained('dmis-lab/biobert-base-cased-v1.1')

# Preparing the dataset.
class MergingDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.tokenizer = tokenizer
        self.data = data

    def __getitem__(self, idx):
        text_a, text_b, label = self.data[idx]
        inputs = self.tokenizer.encode_plus(text_a, text_b,
                                            padding='max_length',
                                            max_length=512,
                                            truncation=True,
                                            return_tensors='pt')
        inputs = {key: tensor.squeeze(0) for key, tensor in inputs.items()}
        inputs['labels'] = torch.tensor(label, dtype=torch.long)
        return inputs

    def __len__(self):
        return len(self.data)

# Loading the data.
data_df = pd.read_csv('drive/MyDrive/data/classifier_training_data.csv', usecols=[0, 1, 2], header=None, names=['Text_A', 'Text_B', 'Label'])
data_df = data_df.applymap(lambda s:s.lower() if type(s) == str else s)
data = [(row['Text_A'], row['Text_B'], int(row['Label'])) for index, row in data_df.iterrows()]

train_data, val_data = train_test_split(data, test_size=0.05, random_state=10101)

train_dataset = MergingDataset(train_data, tokenizer)
val_dataset = MergingDataset(val_data, tokenizer)

# Fine-tuning the model.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False)

optimizer = Adam(model.parameters(), lr=1e-4)

model.train()
for epoch in range(15):
    # Training iterates over the training set; the validation set is only used for evaluation below.
    for batch in tqdm(train_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    model.eval()
    total_eval_accuracy = 0
    total_eval_loss = 0
    for batch in val_dataloader:
        with torch.no_grad():
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss
            total_eval_loss += loss.item()

            logits = outputs.logits
            predictions = torch.argmax(logits, dim=-1)
            correct_predictions = predictions == batch['labels']
            total_eval_accuracy += correct_predictions.sum().item()

    average_val_accuracy = total_eval_accuracy / len(val_dataset)
    average_val_loss = total_eval_loss / len(val_dataloader)

    print(f"Validation Accuracy for epoch {epoch+1}: {average_val_accuracy}")
    print(f"Validation Loss for epoch {epoch + 1}: {average_val_loss}")

    model.train()

model.save_pretrained("drive/MyDrive/data/er_model_2")

# Loading BioBERT model and tokenizer.
tokenizer = BertTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.1')
model = BertForSequenceClassification.from_pretrained('drive/MyDrive/data/er_model_2')

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
model.eval()

# Loading triplets dataframe. The file already has a header row, so no column names are passed.
df = pd.read_csv('drive/MyDrive/data/filtered_triplets.csv')
df = df.applymap(lambda s:s.lower() if type(s) == str else s)


# Remove null values.
df = df[df['head'].notna() & df['tail'].notna()]

# Get a list of each unique term in the df.
unique_terms = list(set(df["head"].unique().tolist() 
                        + df["tail"].unique().tolist()))

to_merge = set()

for term1, term2 in combinations(unique_terms, 2):
    inputs = tokenizer.encode_plus(term1, term2, padding='max_length', max_length=512, truncation=True, return_tensors='pt')
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        prediction = torch.argmax(logits, dim=-1).item()

        if prediction == 1:  # If the model predicts that the terms should be merged...
            # Choose the shorter term to keep.
            shorter_term = term1 if len(term1) <= len(term2) else term2
            print(f'Merging {term1} and {term2}, keeping {shorter_term}')
            to_merge.add(shorter_term)

# `to_merge` now contains the shorter (kept) term from every pair that the model predicted should be merged.
print(f"Terms to merge: {to_merge}")
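To actually apply the model's decisions, the pairwise predictions would need to be turned into a replacement map. One possible way to do this, sketched below, is a small union-find that groups terms linked by any positive prediction and keeps the shortest term of each group as its representative. The function name `build_replacements`, the alphabetical tie-break, and the example pairs are all assumptions made for illustration. This step was not part of the original experiment.

```python
def build_replacements(merge_pairs):
    """Map every merged term to the shortest term in its group."""
    parent = {}

    def find(term):
        parent.setdefault(term, term)
        while parent[term] != term:
            parent[term] = parent[parent[term]]  # path halving
            term = parent[term]
        return term

    for term_a, term_b in merge_pairs:
        root_a, root_b = find(term_a), find(term_b)
        if root_a != root_b:
            # The shorter root is kept; ties are broken alphabetically.
            keep, drop = sorted((root_a, root_b), key=lambda t: (len(t), t))
            parent[drop] = keep

    return {term: find(term) for term in parent if find(term) != term}

# Hypothetical positive predictions, for illustration only.
pairs = [
    ("alzheimers disease", "alzheimers"),
    ("alzheimer", "alzheimers"),
    ("dna", "deoxyribonucleic acid"),
]
replacements = build_replacements(pairs)
print(replacements)
```

A map of this shape could then be passed to `DataFrame.replace`, as in the earlier Levenshtein-based attempt.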