In [None]:
!pip install transformers tqdm more_itertools scikit-learn torch bioc

Collecting bioc
  Downloading bioc-2.1-py3-none-any.whl.metadata (4.6 kB)
Collecting jsonlines>=1.2.0 (from bioc)
  Downloading jsonlines-4.0.0-py3-none-any.whl.metadata (1.6 kB)
Collecting intervaltree (from bioc)
  Downloading intervaltree-3.1.0.tar.gz (32 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting docopt (from bioc)
  Downloading docopt-0.6.2.tar.gz (25 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Downloading bioc-2.1-py3-none-any.whl (33 kB)
Downloading jsonlines-4.0.0-py3-none-any.whl (8.7 kB)
Building wheels for collected packages: docopt, intervaltree
  Building wheel for docopt (setup.py) ... [?25l[?25hdone
  Created wheel for docopt: filename=docopt-0.6.2-py2.py3-none-any.whl size=13706 sha256=fb15323952ac2b3d42cec29cbb9ec06881c51039803e71e1802f58344ca617a5
  Stored in directory: /root/.cache/pip/wheels/1a/bf/a1/4cee4f7678c68c5875ca89eaccf460593539805c3906722228
  Building wheel for intervaltree (setup.py) ... [?25l[?25hdone
  Created w

In [None]:
# utils file

import gzip
import random
import bioc
import json

# Note: .gz and .tar.gz are different formats.
# A .tar.gz is a gzip-compressed tar archive that bundles several files, e.g.:
# source = "example_bioc_files.tar.gz"
# A plain .gz is a single gzip-compressed file, which is what gzip.open() below reads, e.g.:
# BC5CDR = "processed_sources/bc5cdr_train.bioc.xml.gz"  # relative paths resolve against the current working directory


def extract_first_n_docs(path: str, n: int | None = None):
    """
    Read a gzipped BioC XML file at `path` and return a list with info for the first `n` documents.
    If n is None (default), return info for all documents.
    Each list item is a dict: {'id': ..., 'passages': [ {'text': ..., 'annotations': [ {...}, ... ]}, ... ] }
    """
    if n is not None and n < 0:
        raise ValueError("n must be non-negative")

    results = []
    with gzip.open(path, 'rt', encoding='utf-8') as file:
        data = file.read()

    collection = bioc.biocxml.loads(data)
    docs = collection.documents if n is None else collection.documents[:n]
    for document in docs:
        doc_info = {'id': document.id, 'passages': []}
        for passage in document.passages:
            passage_info = {'text': passage.text, 'annotations': []}
            for anno in passage.annotations:
                # serialize common fields of a BioC annotation
                anno_info = {
                    'id': getattr(anno, 'id', None),
                    'text': getattr(anno, 'text', None),
                    'infons': dict(getattr(anno, 'infons', {})),
                    'locations': [
                        {'offset': getattr(loc, 'offset', None), 'length': getattr(loc, 'length', None)}
                        for loc in getattr(anno, 'locations', [])
                    ]
                }
                passage_info['annotations'].append(anno_info)
            doc_info['passages'].append(passage_info)
        results.append(doc_info)

    return results

def get_mention_names_id_pairs(path: str):
    """
    Return (mention text, concept id) tuples for every annotation in the gzipped
    BioC XML file at `path`, in document -> passage -> annotation order.

    The concept id is the annotation's 'concept_id' infon, or None when that
    infon is absent. Duplicate mentions are kept.
    """
    return [
        (annotation['text'], annotation['infons'].get('concept_id'))
        for document in extract_first_n_docs(path)
        for passage in document['passages']
        for annotation in passage['annotations']
    ]


# mesh = "processed_sources/mesh2015.json.gz"

def read_first_n_from_json_gz(path: str, n: int | None = None) -> list:
    """
    Read a gzipped JSON file at `path` and return the first `n` elements if the top-level
    JSON value is a list. If n is None (default), return all entries.
    Raises ValueError for non-list top-level or invalid `n`.
    """
    if n is not None and n < 0:
        raise ValueError("n must be non-negative")

    with gzip.open(path, 'rt', encoding='utf-8') as file:
        data = json.load(file)

    if not isinstance(data, list):
        raise ValueError(f"JSON top-level is {type(data).__name__}, expected list")

    return data if n is None else data[:n]


def get_entity_name_id_pairs(path: str):
    """
    Return (entity name, entity id) tuples, one per record, in file order, from
    the gzipped JSON list at `path`.

    Every record must be a mapping with 'name' and 'id' keys; a record missing
    either key raises KeyError.
    """
    return [(record['name'], record['id']) for record in read_first_n_from_json_gz(path)]


# function to get an entry by id
def get_entry_by_id(data: list, entry_id: str) -> dict | None:
    """
    Get a JSON entry by its ID from a list of entries.
    """
    for item in data:
        if item.get("id") == entry_id:
            return item
    return None

In [None]:
# Two kinds of text are fed to the model, embedded, and then compared:
# 1. the mention text as it appears in a BC5CDR document
# 2. the canonical name of every MeSH entity
#
# Plan:
# - extract (mention text, concept id) pairs from BC5CDR
# - extract (entity name, entity id) pairs from MeSH 2015
# - load the pretrained model
# - build positive and negative pairs at a 1:4 ratio
# - fine-tune so that the dot product / cosine similarity of the two embeddings
#   is high for a correct mention-entity match and low otherwise

# Import the helpers explicitly rather than `from utils import *`: it documents
# exactly what this notebook depends on and keeps utils' own imports
# (gzip, json, bioc, random) from leaking into the notebook namespace.
from utils import get_entity_name_id_pairs, get_mention_names_id_pairs

# BC5CDR training set: (mention text, concept id) pairs
bc5cdr_name_id_pairs = get_mention_names_id_pairs("processed_sources/bc5cdr_train.bioc.xml.gz")

# MeSH 2015: (entity name, entity id) pairs
mesh_name_id_pairs = get_entity_name_id_pairs("processed_sources/mesh2015.json.gz")

In [None]:
# First five (mention text, concept id) pairs from BC5CDR.
bc5cdr_name_id_pairs[0:5]

[('Naloxone', 'MESH:D009270'),
 ('clonidine', 'MESH:D003000'),
 ('hypertensive', 'MESH:D006973'),
 ('clonidine', 'MESH:D003000'),
 ('hypotensive', 'MESH:D007022')]

In [None]:
# First five (entity name, entity id) pairs from MeSH 2015.
mesh_name_id_pairs[0:5]

[('Calcimycin', 'MESH:D000001'),
 ('Temefos', 'MESH:D000002'),
 ('Abattoirs', 'MESH:D000003'),
 ('Abbreviations as Topic', 'MESH:D000004'),
 ('Abdomen', 'MESH:D000005')]

In [None]:
# Positive pairs: each BC5CDR mention paired with the canonical MeSH name of the
# concept it is annotated with. Mentions whose concept id is missing, or absent
# from the MeSH 2015 dump, are skipped. If an id occurs more than once in MeSH,
# the last name wins.
mesh_id_to_name = dict(
    (entity_id, entity_name) for entity_name, entity_id in mesh_name_id_pairs
)

positive_pairs = [
    (mention_text, mesh_id_to_name[concept_id])
    for mention_text, concept_id in bc5cdr_name_id_pairs
    if concept_id in mesh_id_to_name
]

In [None]:
positive_pairs[:5]

[('Naloxone', 'Naloxone'),
 ('clonidine', 'Clonidine'),
 ('hypertensive', 'Hypertension'),
 ('clonidine', 'Clonidine'),
 ('hypotensive', 'Hypotension')]

In [None]:
# Negative pairs, NEGATIVES_PER_POSITIVE times as many as positives.
# There are several negative-sampling techniques; this one samples uniformly:
# take the mention of one random positive pair and the canonical entity name of
# another random positive pair.
#
# Each positive pair is (mention text, MeSH entity name). The entity side of a
# negative must therefore be element [1], exactly as in the positives. Taking
# element [0] would pair a mention with another *mention* string. Positives and
# negatives would then differ in the kind of text on the entity side, and the
# model could learn that shortcut instead of matching.
import random

NEGATIVES_PER_POSITIVE = 4
NEGATIVE_SAMPLING_SEED = 43  # fixed so the sampled negatives are reproducible

positive_pair_set = set(positive_pairs)
if positive_pairs and len({pair[1] for pair in positive_pairs}) < 2:
    # With a single distinct entity no valid negative exists and the loop below would never end.
    raise ValueError("need at least two distinct entities to sample negative pairs")

negative_rng = random.Random(NEGATIVE_SAMPLING_SEED)
negative_pairs = []
while len(negative_pairs) < NEGATIVES_PER_POSITIVE * len(positive_pairs):
    mention_name = negative_rng.choice(positive_pairs)[0]
    entity_name = negative_rng.choice(positive_pairs)[1]
    # Reject any draw that is a known positive. This covers drawing the mention's
    # own entity, and also mentions linked to different entities in different documents.
    if (mention_name, entity_name) not in positive_pair_set:
        negative_pairs.append((mention_name, entity_name))


In [None]:
# First five sampled (mention text, non-matching entity name) pairs.
negative_pairs[0:5]

[('ESRD', 'vitamin K'),
 ('cardiomyopathy', 'N-pyrimidinyl-2-phenoxyacetamide'),
 ('osteopenia', 'renal dysfunction'),
 ('Dopamine', 'appetite suppressants'),
 ('estradiol', 'cognitive impairment')]

In [None]:
# Turn both pair lists into labelled records (positives first, then negatives):
# label True marks a correct mention-entity match, False a sampled non-match.
training_data = [
    {'mention_name': mention, 'entity_name': entity, 'label': is_match}
    for pairs, is_match in ((positive_pairs, True), (negative_pairs, False))
    for mention, entity in pairs
]

In [None]:
# Split the labelled records 60/20/20 into train / validation / test.
# Stratifying on the label keeps the 1:4 positive:negative ratio as equal across
# the three splits as their sizes allow, instead of leaving it to chance.
# NOTE(review): the same (mention, entity) pair can occur several times (e.g.
# ('clonidine', 'Clonidine') in the positive preview), so identical pairs may land
# in different splits and inflate validation/test scores; consider deduplicating
# or grouping by mention. TODO: evaluate.
from sklearn.model_selection import train_test_split

SPLIT_SEED = 43

train_pairs, valtest_pairs = train_test_split(
    training_data,
    train_size=0.6,
    random_state=SPLIT_SEED,
    stratify=[record['label'] for record in training_data],
)
val_pairs, test_pairs = train_test_split(
    valtest_pairs,
    train_size=0.5,
    random_state=SPLIT_SEED,
    stratify=[record['label'] for record in valtest_pairs],
)

In [None]:
# Inspect the first batch-sized slice (8 records) of the training split.
batch = train_pairs[0:8]
batch

[{'mention_name': '3alpha-hydroxy-3beta-methyl-5alpha-pregnan-20-one',
  'entity_name': 'propofol',
  'label': False},
 {'mention_name': 'catecholamine',
  'entity_name': 'Catecholamines',
  'label': True},
 {'mention_name': 'Heparan sulphate',
  'entity_name': 'phenylephrine',
  'label': False},
 {'mention_name': 'AMI', 'entity_name': 'Lithium', 'label': False},
 {'mention_name': 'dipyridamole',
  'entity_name': 'retention of urine',
  'label': False},
 {'mention_name': 'iopamidol', 'entity_name': 'Fentanyl', 'label': False},
 {'mention_name': 'nephropathy',
  'entity_name': 'streptozotocin',
  'label': False},
 {'mention_name': 'thrombosis', 'entity_name': 'nephrotoxic', 'label': False}]

In [None]:
# Fine-tune the base SapBERT model as a bi-encoder.
# The mention and the entity name are encoded separately by the same model.
# The dot product of their [CLS] vectors is used as a logit for BCEWithLogitsLoss.
# Loading the model directly with AutoModel (not a pipeline) is what allows fine-tuning.
import random

import torch
from more_itertools import chunked
from sklearn.metrics import f1_score
from tqdm.auto import tqdm
from transformers import AutoModel, AutoTokenizer

# --- configuration ---
model_name = "cambridgeltl/SapBERT-from-PubMedBERT-fulltext"
output_dir = "model/name_only_entity_linking_finetuned_model"
batch_size = 8
num_epochs = 3
learning_rate = 1e-5
max_length = 512   # truncation limit; padding=True only pads each batch to its longest item
TRAIN_SEED = 43    # seeds torch (dropout) and the per-epoch shuffle

torch.manual_seed(TRAIN_SEED)
shuffle_rng = random.Random(TRAIN_SEED)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Use the GPU when the notebook has one.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

# NOTE(review): the logged BCE losses (~5 at epoch 0) suggest the raw dot products
# are large in magnitude, i.e. saturated logits. Consider cosine similarity with a
# scale factor instead. TODO: evaluate.
loss_func = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


def _encode_cls(texts):
    """Tokenize `texts` and return the position-0 ([CLS]) vector of the last hidden layer for each."""
    tokenized = tokenizer(texts, max_length=max_length, padding=True, truncation=True, return_tensors='pt')
    outputs = model(
        input_ids=tokenized['input_ids'].to(device),
        attention_mask=tokenized['attention_mask'].to(device),
    )
    return outputs.last_hidden_state[:, 0, :]


def _score_batch(batch):
    """Return (logits, labels) for a list of {'mention_name', 'entity_name', 'label'} records.

    Each logit is the dot product of the mention and entity [CLS] vectors of one record.
    """
    mention_vectors = _encode_cls([record['mention_name'] for record in batch])
    entity_vectors = _encode_cls([record['entity_name'] for record in batch])
    logits = (mention_vectors * entity_vectors).sum(dim=1)
    labels = torch.tensor([float(record['label']) for record in batch], device=device)
    return logits, labels


def _run_epoch(pairs, train):
    """Make one pass over `pairs`, updating the model only when `train` is True.

    Returns (mean batch loss, predictions, gold labels). A record is predicted
    positive when its logit is > 0, i.e. sigmoid probability > 0.5.
    """
    model.train(train)  # train(False) is eval(): disables dropout for validation
    batches = list(chunked(pairs, batch_size))
    total_loss = 0.0
    predictions, gold_labels = [], []

    with torch.set_grad_enabled(train):
        for batch in tqdm(batches, desc='train' if train else 'val'):
            if train:
                optimizer.zero_grad()
            logits, labels = _score_batch(batch)
            loss = loss_func(logits, labels)
            if train:
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            predictions += [logit > 0 for logit in logits.detach().cpu().tolist()]
            gold_labels += [record['label'] for record in batch]

    # An empty split yields 0.0 instead of a ZeroDivisionError.
    return total_loss / max(len(batches), 1), predictions, gold_labels


for epoch in range(num_epochs):
    # Shuffle a copy so `train_pairs` keeps the order produced by the split
    # (and the earlier batch preview stays valid).
    epoch_train_pairs = shuffle_rng.sample(train_pairs, k=len(train_pairs))

    train_loss, train_predictions, train_labels = _run_epoch(epoch_train_pairs, train=True)
    val_loss, val_predictions, val_labels = _run_epoch(val_pairs, train=False)

    train_f1 = f1_score(train_labels, train_predictions, zero_division=0)
    val_f1 = f1_score(val_labels, val_predictions, zero_division=0)

    print(f"{epoch=} {train_loss=:.4f} {train_f1=:.4f} {val_loss=:.4f} {val_f1=:.4f}")

# Save the tokenizer next to the weights so the directory can be reloaded on its
# own with AutoTokenizer/AutoModel.from_pretrained(output_dir).
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

  0%|          | 0/3497 [00:00<?, ?it/s]

  0%|          | 0/1166 [00:00<?, ?it/s]

epoch=0 train_loss=5.8265 train_f1=0.4979 val_loss=5.3866 val_f1=0.5205


  0%|          | 0/3497 [00:00<?, ?it/s]

  0%|          | 0/1166 [00:00<?, ?it/s]

epoch=1 train_loss=3.3259 train_f1=0.5555 val_loss=5.1988 val_f1=0.5279


  0%|          | 0/3497 [00:00<?, ?it/s]

  0%|          | 0/1166 [00:00<?, ?it/s]

epoch=2 train_loss=2.3499 train_f1=0.6035 val_loss=5.0622 val_f1=0.5302
