# BioDEX-DSPy

Biomedical publications can describe adverse drug reactions. Extracting and coding these reactions is a real-world task which is vital for public drug safety, but automating this procedure has been challenging [(D'Oosterlinck et al., 2023)](https://arxiv.org/abs/2305.13395).

In this notebook, we will build a state-of-the-art reaction extractor using the [DSPy framework](https://github.com/stanfordnlp/dspy) [(Khattab et al., 2023)](https://arxiv.org/pdf/2310.03714.pdf). We will combine in-context learning and retrieval to build a pipeline that reads a full biomedical paper, predicts adverse reactions, and encodeds these reactions into one of ~26k admissable reaction terms.

## Setup
- downloads data
- installs packages

In [3]:
%load_ext autoreload
%autoreload 2

import sys
import os

# Set up the cache for this notebook
os.environ["DSP_NOTEBOOK_CACHEDIR"] = os.path.join('.', 'cache')

# download biomedical terms and embeddings
!wget -nc 'https://www.dropbox.com/scl/fi/5ywdea0xjkb10os1o6ryj/embeddings-FremyCompany-BioLORD-STAMB2-v1.pt?rlkey=nek172noiyrpn588jt7dunl66&dl=0' -O 'embeddings[FremyCompany--BioLORD-STAMB2-v1].pt'
!wget -nc 'https://www.dropbox.com/scl/fi/f92z0vg42icsn5g89f3wu/reaction_terms.txt?rlkey=ot8qasqr3r9getbn9epyji0aa&dl=0' -O "reaction_terms.txt"
!wget -nc 'https://www.dropbox.com/scl/fi/cgu0eal9m7q0xrswp49g5/cache.zip?rlkey=x3lnpc5vz1t0di7igzthrky8h&dl=0' -O "cache.zip"
!unzip -n -q cache.zip -y

# install packages
!pip install dspy-ai datasets sentence_transformers torch openai

import datasets
import dspy
from dspy.evaluate import Evaluate
import tqdm
import sentence_transformers
from sentence_transformers import SentenceTransformer
import torch
import os
from functools import lru_cache
from collections import defaultdict, Counter
import math
import re
from functools import partial



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
File ‘embeddings[FremyCompany--BioLORD-STAMB2-v1].pt’ already there; not retrieving.
File ‘reaction_terms.txt’ already there; not retrieving.
File ‘cache.zip’ already there; not retrieving.
caution: filename not matched:  -y


## Setting up the data
We will be using the BioDEX dataset [(D'Oosterlinck et al., 2023)](https://arxiv.org/abs/2305.13395). BioDEX features full biomedical publications and associated expert-created drug safety reports. For our purposes, we will use the biomedical reactions discussed in these reports.

- [BioDEX Github](https://github.com/KarelDO/BioDEX)
- [BioDEX HuggingFace](https://huggingface.co/BioDEX)

Let's first write some helper functions to parse and normalize biomedical reactions:

In [4]:
# normalize one reaction string
def normalize(reaction: str) -> str:
    # Remove leading and trailing newlines
    reaction = reaction.strip('\n')

    # Remove leading and trailing punctuation and newlines
    reaction = re.sub(r'^[^\w\s]+|[^\w\s]+$', '', reaction, flags=re.UNICODE)

    # Remove leading and trailing newlines
    reaction = reaction.strip('\n')

    return reaction.strip().lower()

# given a csv string of reactions, parse into a list
def extract_reactions_from_string(reactions: str) -> list[str]:
  return [normalize(r) for r in reactions.split(',')]

# given a list of csv trings of reactions, parse into a list
def extract_reactions_from_strings(reactions: list[str]) -> list[str]:
  reactions = [normalize(r) for r in reactions]
  reactions = ", ".join(reactions)
  return extract_reactions_from_string(reactions)

# process a biodex datapoint
def preprocess_example(example: dict) -> dspy.Example:
    title = example['title']
    abstract = example['abstract']
    context = example['fulltext_processed'].split('\n\nTEXT:\n', 1)[-1]
    reactions = extract_reactions_from_string(example['reactions'])

    example = dict(title=title, abstract=abstract, context=context, reactions=reactions)
    example['labels'] = dspy.Example(reactions=reactions)

    return example

Great, now let's load the dataset and create a tiny training and validation set to easily apply in-context learning without being too expensive.

In [5]:
dataset = datasets.load_dataset("BioDEX/BioDEX-Reactions")
official_trainset, official_devset = dataset['train'], dataset['validation']
trainset, devset = [], []

for example in tqdm.tqdm(official_trainset):
    if len(trainset) >= 1000: break
    trainset.append(preprocess_example(example))

for example in tqdm.tqdm(official_devset):
    if len(devset) >= 500: break
    devset.append(preprocess_example(example))

trainsetX = [dspy.Example(**x).with_inputs('title', 'abstract', 'context', 'labels') for x in trainset]
trainset = [dspy.Example(**x).with_inputs('title', 'abstract', 'context') for x in trainset]
devsetX = [dspy.Example(**x).with_inputs('title', 'abstract', 'context', 'labels') for x in devset]
devset = [dspy.Example(**x).with_inputs('title', 'abstract', 'context') for x in devset]

print(len(trainset), len(devset))

  9%|▊         | 1000/11543 [00:00<00:01, 6514.35it/s]
 17%|█▋        | 500/2886 [00:00<00:00, 6430.38it/s]

1000 500





Let's look at one example to familiarize ourself.

Each datapoints has a:
- titel
- abstract
- context (body of the paper)
- a list of associated reactions

In [6]:
# show one title and abstract, as well as the reactions in the final expert-created drug safety report
print(trainset[0].title)
print()
print(trainset[0].abstract)
print()
print(trainset[0].labels().reactions)

HIV-1 Drug Resistance by Ultra-Deep Sequencing Following Short Course Zidovudine, Single-Dose Nevirapine, and Single-Dose Tenofovir with Emtricitabine for Prevention of Mother-to-Child Transmission.

Antiretroviral drug resistance following pMTCT strategies remains a significant problem. With rapid advancements in next generation sequencing technologies, there is more focus on HIV drug-resistant variants of low frequency, or the so-called minority variants. In South Africa, AZT monotherapy for pMTCT, similar to World Health Organization option A, has been used since 2008. In 2010, a single dose of co-formulated TDF/FTC was included in the strategy for prevention of resistance conferred by single-dose nevirapine (sd NVP). The study was conducted in KwaZulu-Natal, South Africa, among pMTCT participants who received AZT monotherapy from 14 weeks of gestation, intrapartum AZT and sd NVP, and postpartum sd TDF/FTC. Twenty-six specimens collected at 6 weeks post-delivery were successfully se

## Setting up evaluation code

We will optimize Recall@10: given 10 unique reaction predictions from our pipeline, how many of the true predictions did we find.


In [7]:
def metric_recall(gold: list[str], pred: list[str]) -> float:
  """ Given a gold and predicted list of reactions, normalize and compute recall."""
  gold = [normalize(r) for r in gold]
  pred = [normalize(r) for r in pred]

  gold, pred = set(gold), set(pred)

  intersection = gold.intersection(pred)

  recall = len(intersection) / len(gold)
  return recall

def metric_recallK(gold: list[str], pred: list[str], K:int=10) -> float:
  return metric_recall(gold, pred[:K])

# wrap the recall@K metric so it can take dspy Examples
def dspy_metric_recall10(gold: dspy.Example, pred: dspy.Example, trace=None) -> float:
  return metric_recallK(gold.reactions, pred.reactions, K=10)

def dspy_metric_recall20(gold: dspy.Example, pred: dspy.Example, trace=None) -> float:
  return metric_recallK(gold.reactions, pred.reactions, K=20)
  
def dspy_metric_recall30(gold: dspy.Example, pred: dspy.Example, trace=None) -> float:
  return metric_recallK(gold.reactions, pred.reactions, K=30)


## Set up a neural grounding
The output space of BioDEX contains ~26k distinct reaction classes. Even naively enumerating these classes, without demonstrations, might overflow the in-context learning window. Thus, we need a more scalable way of combining reactions with in-context learning. To solve this, we set up a retrieval-based grounder. This grounded embeds the reaction names and retrieves the K nearest neighbors given an ungrounded reaction prediction. Additionally, the grounder returns the predicted semantic similarty, so that we can use this later on as a measure of prediction confidence. Optinally, the grounder incorporates the prior distribution of reactions in the BioDEX trainset, so that prior statistics can be combined with in-context learning.

We use the BioLORD model [(Remy et al., 2022)](https://arxiv.org/abs/2210.11892) to produce our embeddings, but feel free to experiment with any type of embeddings.

Switch the runtime to GPU when you are using a new encoder to speed up embedding.

In [8]:
class ReactionGrounder():
  """ Matches a given reaction string to one of the BioDEX reaction terms.
  Implements several ways of grounding reactions. Each grounding function returns a list of length K.
  Each list item consists of a similarity score and the associated grounded term. """

  def __init__(self, model_name:str='FremyCompany/BioLORD-STAMB2-v1', trainset: list[dict]=[]):

      self.model_name = model_name
      self.friendly_model_name = self.model_name.replace('/','--')
      self.trainset = trainset

      self.model = SentenceTransformer(self.model_name)
      self.model.to('cpu')

      self.reaction_terms = self._load_reaction_terms()
      self.reaction_terms_to_count = self._calculate_counts()
      self.reaction_embeddings = self._load_embeddings()

  def _load_reaction_terms(self) -> list[str]:
      """Get all reaction terms and normalize them."""
      reaction_filename = 'reaction_terms.txt'
      return [normalize(r) for r in open(reaction_filename).read().splitlines()]

  def _load_embeddings(self) -> torch.Tensor:
      """Load or create embeddings for all reaction terms."""
      reaction_embeddings_filename = f'embeddings[{self.friendly_model_name}].pt'

      # If the file exists, load. Else, create embeddings.
      if os.path.isfile(reaction_embeddings_filename):
          with open(reaction_embeddings_filename, "rb") as f:
              reaction_embeddings = torch.load(f, map_location=torch.device('cpu'))
      else:
          self.model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
          reaction_embeddings = self.model.encode(self.reaction_terms, convert_to_tensor = True, show_progress_bar = True)
          with open(reaction_embeddings_filename, "wb") as f:
              torch.save(reaction_embeddings, f)
          self.model.to(torch.device('cpu'))
      return reaction_embeddings

  def _calculate_counts(self) -> dict[str, int]:
      """ Given a training set, count how many times each reaction occured as prior."""
      reactions = []
      for example in self.trainset:
        reactions.extend(extract_reactions_from_string(example['reactions']))
      counts = defaultdict(lambda: 0, Counter(reactions))
      return counts

  @lru_cache(maxsize=100000)
  def ground(self, reaction: str, K:int=3) -> list[tuple[float, str]]:
      """ Finds K closest matches based on semantic embedding similarity. """
      query_embeddings = self.model.encode(reaction, convert_to_tensor = True)
      query_result = sentence_transformers.util.semantic_search(query_embeddings, self.reaction_embeddings, query_chunk_size=64, top_k=K)[0]

      # get (score, term) tuples
      matches = []
      for result in query_result:
        score = result['score']
        term = self.reaction_terms[result['corpus_id']]
        matches.append((score, term))

      return sorted(matches, reverse=True)

  @lru_cache(maxsize=100000)
  def ground_with_prior(self, reaction, K=3):
      """ Finds 3*K closest matches based on semantic embedding similarity.
      Incorporates the prior counts, returns K most likely terms."""
      matches = self.ground(reaction, K=3*K)

      # heuristically incorporate prior into the similarity scores
      new_matches = []
      for score, term in matches:
        prior = self.reaction_terms_to_count[term]

        score = score * max(2,math.log(prior + math.e))
        # score = score * math.log(prior + math.e)
        # score = score * math.log(prior + .1)
        new_matches.append((score, term))

      return sorted(new_matches, reverse=True)[:K]

# create the grounder
grounder = ReactionGrounder(trainset=official_trainset)


No sentence-transformers model found with name /Users/kldooste/.cache/torch/sentence_transformers/FremyCompany_BioLORD-STAMB2-v1. Creating a new one with MEAN pooling.


Let's play with some of the groundings! Look how different queries get grounded, and how incorporating a prior term changes some of the retrieved reactions.

In [9]:
K = 5
queries = [
    'pain',
    'fever',
    'i have a runny nose'
]

for q in queries:
  result = grounder.ground(q, K=K)
  result_with_prior = grounder.ground_with_prior(q, K=K)

  print("Query: ", q)

  print("\t Ground without prior:")
  for score, term in result:
    print(f"\t\t{term} - {score} ")
  print("\t Ground with prior:")
  for score, term in result_with_prior:
    print(f"\t\t{term} - {score} ")


Query:  pain
	 Ground without prior:
		pain - 1.000000238418579 
		tenderness - 0.8098210096359253 
		inflammatory pain - 0.7327480316162109 
		headache - 0.7172471284866333 
		discomfort - 0.7165692448616028 
	 Ground with prior:
		pain - 4.679520433152032 
		headache - 3.8610861348756327 
		tenderness - 2.5640926425656185 
		discomfort - 2.0598477859496542 
		musculoskeletal pain - 1.8701876261560002 
Query:  fever
	 Ground without prior:
		pyrexia - 0.923504114151001 
		body temperature increased - 0.8943093419075012 
		hyperpyrexia - 0.8813413381576538 
		febrile infection - 0.7856964468955994 
		body temperature abnormal - 0.7851898670196533 
	 Ground with prior:
		pyrexia - 5.612130707861001 
		chills - 2.5693159432673554 
		hyperthermia - 2.3369348762331836 
		body temperature increased - 2.2742648513245327 
		hyperpyrexia - 2.09049834993291 
Query:  i have a runny nose
	 Ground without prior:
		nasal congestion - 0.6704621315002441 
		nasal flaring - 0.6388666033744812 
		rebou

## DSPy program
We are almost ready to build our DSPy program!

Our program will have the following input-output behavior:
- input: title, abstract, body of a paper
- output: a list of predicted reactions, sorted by how confident we are.

There are many potential pipelines we can construct. In this work, we will:
1. chunk a paper in different windows
2. predict ungrounded reaction terms per chunk using in-context learning
3. find the top K nearest grounded neighbors per ungrounded reaction using the grounder defined above
4. resolve all grounded reactions into a final prediction using the similarity scores from the grounding step

Because of the abstractions provided by DSPy, we can quickly iterate on the design of this pipeline. For example, if a reaction term is predicted in two chunks, should we add the confidence of this term or take the maximum for the final prediction? This is all trivally implemented in DSPy.




Let's first implement the function that is used to resolve a list of predicted reactions and scores:

In [10]:
from operator import add

def resolve_reactions(reactions: list[(float, str)], resolve_f = add) -> list[str]:
  """ Applies a resolve function across all duplicate predicted reactions to aggregate their similarity score.
  Sorts the resulting reactions according to aggregated score."""

  reactions_to_score = defaultdict(lambda: .0)

  for score, term in reactions:
    reactions_to_score[term] = resolve_f(reactions_to_score[term], score)

  reactions = sorted(reactions_to_score.items(), key=lambda x: x[1], reverse=True)
  reactions = [(r[1], r[0]) for r in reactions]
  return reactions

Let's also implement a chunker class.

In [11]:
class Chunker:
    def __init__(self, context_window=3000, max_windows=5):
        self.context_window = context_window
        self.max_windows = max_windows
        self.window_overlap = 0.02

    def __call__(self, paper):
        snippet_idx = 0

        while snippet_idx < self.max_windows and paper:
            endpos = int(self.context_window * (1.0 + self.window_overlap))
            snippet, paper = paper[:endpos], paper[endpos:]

            next_newline_pos = snippet.rfind('\n')
            if paper and next_newline_pos != -1 and next_newline_pos >= self.context_window // 2:
                paper = snippet[next_newline_pos+1:] + paper
                snippet = snippet[:next_newline_pos]

            yield snippet_idx, snippet.strip()
            snippet_idx += 1

Great! Now let's build the DSPy program, starting with the signature for the ChainOfThought in-context module. Notice how this is everything we need to use ChainOfThought. No sloppy prompt-engineering, but well-defined interfaces:

In [12]:
class PredictReactions(dspy.Signature):
    __doc__ = f"""Given a snippet from a medical article, identify the adverse drug reactions affecting the patient. If none are mentioned in the snippet, say '\n'."""

    title = dspy.InputField()
    context = dspy.InputField()
    reactions = dspy.OutputField(desc="list of comma-separated adverse drug reactions", format=lambda x: ', '.join(x) if isinstance(x, list) else x)


Awesome, let's create our pipeline. `PredictThenGround` will use ChainOfThought and chunking to predict reactions, and will optionally use a grounding function to resolve these to the final space of reactions.

In [13]:
# Grounding with prior
HINT = "Is any of the following reactions discussed in the article snippet? The valid candidates are:"

class PredictThenGround(dspy.Module):
    def __init__(self, context_window=3000, max_windows=5, num_preds=1, grounding_function = lambda r: [(1.0, r)], resolve_function = add):
        super().__init__()

        # devides a biomedical paper into chunks
        self.chunk = Chunker(context_window=context_window, max_windows=max_windows)
        # given a paper title and body, predict a list of ungrounded reactions using a CoT
        self.predict = dspy.ChainOfThoughtWithHint(PredictReactions, n=num_preds)
        # ungrounded reaction -> grounded neighbors and similarity scores
        self.grounding_function = grounding_function
        # operator to combine similarity scores over multiple duplicate predictions
        self.resolve_function = resolve_function

    def forward(self, title, abstract, context, labels=None):
        hint = f"{HINT} {', '.join(labels.reactions)}." if labels else None
        reactions = []

        # for each chunk in the paper
        for _, snippet in self.chunk(abstract + '\n\n' + context):
            # use the LM to predict ungrounded reactions
            chunk_reactions = self.predict(title=title, context=[snippet], hint=hint)
            reactions.extend(extract_reactions_from_strings(chunk_reactions.completions.reactions))

        # for each ungrounded reaction, get grounded reactions and grounding confidence
        grounded_reactions = sorted([r for sublist in [self.grounding_function(r) for r in reactions] for r in sublist], reverse=True)
        # aggregate / pool duplicate predictions and sort based on confidence
        resolved_reactions = resolve_reactions(grounded_reactions, resolve_f=self.resolve_function)
        # get a final list of grounded reactions
        reactions = [r[1] for r in resolved_reactions]

        # track all of these predictions
        return dspy.Prediction(reactions=reactions, resolved_reactions=resolved_reactions, grounded_reactions=grounded_reactions, ungrounded_reactions=reactions)

Let's get ready to run some models! We'll use `gpt-3.5-turbo-1106`. All executions are cached, so you don't need an API to run this notebook. However, if you change anything, you might need to add your OpenAI API key to execute new GPT calls.

In [14]:
# set DSPy to use gpt-3.5
turbo11 = dspy.OpenAI(model='gpt-3.5-turbo-1106', max_tokens=150)
dspy.settings.configure(lm=turbo11)

The following code creates an evaluation helper which we will use to evaluate all our DSPy programs.

In [15]:
# create an evaluation helper function
evaluateR10 = Evaluate(devset=trainset[100:150], metric=dspy_metric_recall10, num_threads=8, display_progress=True, display_table=0, max_errors=100)
evaluateR20 = Evaluate(devset=trainset[100:150], metric=dspy_metric_recall20, num_threads=8, display_progress=True, display_table=0, max_errors=100)
evaluateR30 = Evaluate(devset=trainset[100:150], metric=dspy_metric_recall30, num_threads=8, display_progress=True, display_table=0, max_errors=100)

### Simplest CoT pipeline
Let's evaluate the simplest CoT pipeline possible, which uses no grounding and only processes the first chunk of the paper.

In [16]:
# create the pipeline
pipeline_no_grounding_no_chunking = PredictThenGround(max_windows=1)

# get an example
example = trainset[0]
example_x = example.inputs()
example_y = example.labels()

# get a prediction
prediction = pipeline_no_grounding_no_chunking(**example_x)

# print the reactions predicted
print('Predicted reactions: ', prediction.reactions)
print('Gold reactions: ', example_y.reactions)
print('Recall@10: ', dspy_metric_recall10(example_y, prediction))

Predicted reactions:  ['thymidine analogue mutations', 'nnrti resistance']
Gold reactions:  ['drug resistance', 'exposure during pregnancy', 'viral mutation identified']
Recall@10:  0.0


This does not look that great. Let's run a more thorough evaluation.

In [17]:
evaluateR10(pipeline_no_grounding_no_chunking)

Average Metric: 2.5325460487225193 / 50  (5.1): 100%|██████████| 50/50 [00:00<00:00, 493.04it/s] 

Average Metric: 2.5325460487225193 / 50  (5.1%)





5.07

5.07% Recall@10, not great.

### CoT pipeline with chunking
Let's evaluate the CoT pipeline now over 5 chunks.

In [18]:
# create the pipeline
pipeline_no_grounding = PredictThenGround(max_windows=5)

# get an example
example = trainset[0]
example_x = example.inputs()
example_y = example.labels()

# get a prediction
prediction = pipeline_no_grounding(**example_x)

# print the reactions predicted
print('Predicted reactions: ', prediction.reactions)
print('Gold reactions: ', example_y.reactions)
print('Recall@10: ', dspy_metric_recall10(example_y, prediction))

evaluateR10(pipeline_no_grounding)

Predicted reactions:  ['nnrti resistance', 'y188c', 'y181c', 'v90i', 'v108i', 'v106m', 'v106a', 'thymidine analogue mutations (tams', 'thymidine analogue mutations', 'tams', 't69s', 'n/a', 'l100v', 'k70r', 'k65r', 'k103n', 'k101e', 'intermediate nnrti resistance', 'high-level resistance to nvp and efv', 'g190a', 'f227l', 'a98g', '']
Gold reactions:  ['drug resistance', 'exposure during pregnancy', 'viral mutation identified']
Recall@10:  0.0


Average Metric: 4.870875160271446 / 50  (9.7): 100%|██████████| 50/50 [00:00<00:00, 405.58it/s]  

Average Metric: 4.870875160271446 / 50  (9.7%)





9.74

Still no traction on this example, but at least we are already outputting more reactions, increasing the final evaluation metric to 9.74!

### CoT pipeline with chunking and per-chunk grounding
Let's add the simplest grounding function. For every ungrounded prediction the LM makes, our pipeline will now find the 3 nearest grounded reaction terms.

When the same reaction term is predicted multiple times, we resolve to adding the similarity scores.

In [19]:
K = 3

# greate the grounding function
grounding_function = partial(grounder.ground, K=K)

# greate the pipeline
pipeline_with_ground = PredictThenGround(max_windows=5, grounding_function=grounding_function)

# get an example
example = trainset[0]
example_x = example.inputs()
example_y = example.labels()

# get a predcition
prediction = pipeline_with_ground(**example_x)

# print the reactions predicted
print('Predicted reactions: ', prediction.reactions)
print('Gold reactions: ', example_y.reactions)
print('Recall@10: ', dspy_metric_recall10(example_y, prediction))
evaluateR10(pipeline_with_ground)

Predicted reactions:  ['carbohydrate antigen 125', 'multiple-drug resistance', 'nat2 polymorphism', 'hiv tropism identified', 'carbohydrate antigen 27.29', 'nat1 polymorphism', 'hiv infection cdc category a', 'carbohydrate antigen 19-9', 'mutagenic effect', 'ret gene mutation', 'hiv antigen', 'k-ras status assay', 'carbohydrate antigen 549', 'carbohydrate antigen 50', 'c-kit gene mutation', 'transgenerational epigenetic inheritance', 'acquired gene mutation', 'sustained viral response', 'pathogen resistance', 'vascular resistance systemic', 'human t-cell lymphotropic virus infection', 'human t-cell lymphotropic virus type i infection', 'n-ras gene mutation', 'k-ras gene mutation', 'nyha classification', 'slow virus infection', 'melas syndrome', 'no adverse event', 'varicella post vaccine', 'covid-19', 'blood group o', 'viral mutation identified', 'hiv infection cdc group i', 'intermediate syndrome', 'reverse tri-iodothyronine', 'blood group b', 'aase syndrome', 'troponin t', 'tri-iodot

Average Metric: 12.755123995371672 / 50  (25.5): 100%|██████████| 50/50 [00:10<00:00,  4.81it/s]

Average Metric: 12.755123995371672 / 50  (25.5%)





25.51

Our final metric jumped to 25.51! For our initial example, we actually correctly predict 'viral mutation identified', but not as one of the top-10 reactions. Let's try to solve this by incorporating the prior distribution of the reactions in the grounding.

### CoT pipeline with chunking and per-chunk grounding and priors
Let's use the grounding function that takes the priors into account.

In [20]:
# let's make a better pipeline with grounding, now taking the prior into account.

K = 3

grounding_function = partial(grounder.ground_with_prior, K=K)
pipeline_with_ground_and_prior = PredictThenGround(max_windows=5, grounding_function=grounding_function)

# get an example
example = trainset[0]
example_x = example.inputs()
example_y = example.labels()

# get a predcition
prediction = pipeline_with_ground_and_prior(**example_x)

# print the reactions predicted
print('Predicted reactions: ', prediction.reactions)
print('Gold reactions: ', example_y.reactions)
print('Recall@10: ', dspy_metric_recall10(example_y, prediction))

evaluateR10(pipeline_with_ground_and_prior)

Predicted reactions:  ['multiple-drug resistance', 'pathogen resistance', 'drug resistance', 'carbohydrate antigen 125', 'viral mutation identified', 'hiv tropism identified', 'carbohydrate antigen 27.29', 'acquired gene mutation', 'carbohydrate antigen 19-9', 'mutagenic effect', 'no adverse event', 'hiv antigen', 'hiv infection cdc category a', 'k-ras status assay', 'carbohydrate antigen 549', 'covid-19', 'virologic failure', 'carbohydrate antigen 50', 'c-kit gene mutation', 'transgenerational epigenetic inheritance', 'bk virus infection', 'ret gene mutation', 'human t-cell lymphotropic virus infection', 'human t-cell lymphotropic virus type i infection', 'n-ras gene mutation', 'nat2 polymorphism', 'k-ras gene mutation', 'nyha classification', 'slow virus infection', 'melas syndrome', 'covid-19 pneumonia', 'blood group o', 'intermediate syndrome', 'reverse tri-iodothyronine', 'aase syndrome', 'troponin t', 'tri-iodothyronine', 'divorced', 'homosexuality', 'married']
Gold reactions:  [

Average Metric: 14.939415830127901 / 50  (29.9): 100%|██████████| 50/50 [00:10<00:00,  4.83it/s]

Average Metric: 14.939415830127901 / 50  (29.9%)





29.88

29.88%, Great! We got strong zero-shot performance by combining in-context learning with grounding and priors. Using the DSPy framework, this was all rather easy to code up and highly modular.

## Analyze best program
We're going to look at some prediction to understand what is going on. Open the output in a text editor if you want to see it.

In [21]:
def print_score_and_term(ls):
    for score, term in ls:
        print('\t\t', term, '\t', score)

for example in trainset[100:150]:
    # get an example
    example_x = example.inputs()
    example_y = example.labels()

    # get a predcition
    prediction = pipeline_with_ground_and_prior(**example_x)

    # print the reactions predicted
    print('Gold reactions: ',)
    print('\t', example_y.reactions)
    print('Predicted reactions: ')
    print('\t', prediction.reactions)
    print('Ungrounded reactions: ')
    print('\t', prediction.ungrounded_reactions)
    print('Recall@10: ', dspy_metric_recall10(example_y, prediction))
    print('Recall@20: ', dspy_metric_recall20(example_y, prediction))
    print('Recall@30: ', dspy_metric_recall30(example_y, prediction))
    print('Resolved reactions: ')
    print_score_and_term(prediction.resolved_reactions)
    print('Grounded reactions: ')
    print_score_and_term(prediction.grounded_reactions)
    print('--------')

Gold reactions: 
	 ['blindness', 'consciousness fluctuating', 'encephalitis', 'encephalopathy', 'eye pain', 'eyelid ptosis', 'infection reactivation', 'loss of consciousness', 'mydriasis', 'off label use', 'ophthalmoplegia', 'orbital apex syndrome', 'paraesthesia', 'periorbital oedema', 'pupil fixed', 'pyrexia', 'varicella zoster virus infection']
Predicted reactions: 
	 ['ophthalmoplegia', 'gaze palsy', 'extraocular muscle paresis', 'encephalopathy', 'no adverse event', 'encephalitis', 'hepatic encephalopathy', 'toxic encephalopathy', 'meningitis', 'meningitis aseptic', 'encephalitis viral', 'adverse drug reaction', 'panencephalitis', 'orbital apex syndrome', 'meningitis noninfective', 'no reaction on previous exposure to drug', 'orbital compartment syndrome', 'cavernous sinus syndrome']
Ungrounded reactions: 
	 ['ophthalmoplegia', 'gaze palsy', 'extraocular muscle paresis', 'encephalopathy', 'no adverse event', 'encephalitis', 'hepatic encephalopathy', 'toxic encephalopathy', 'mening

# Compiling programs
Typically with in-context learning, we'd spend some time writing a better prompt or gathering good few-shot demonstrations. However, this is brittle and time-consuming. Luckily DSPy can help us by compiling our prompt for us!

This pushes the performance to 35%!

In [22]:
from dspy.teleprompt import BootstrapFewShotWithRandomSearch
tp = BootstrapFewShotWithRandomSearch(metric=dspy_metric_recall10, max_bootstrapped_demos=2, max_labeled_demos=0, max_rounds=1,
                                       num_candidate_programs=20, num_threads=8, teacher_settings=dict(lm=turbo11))


compiledR = tp.compile(pipeline_with_ground_and_prior, teacher=pipeline_with_ground_and_prior, trainset=trainsetX[0:100], valset=trainset[100:150], restrict=range(20))


./cache/compiler
Going to sample between 1 and 2 traces per predictor.
Will attempt to train 20 candidate sets.
-3 range(0, 20)
-2 range(0, 20)
-1 range(0, 20)


  2%|▏         | 2/100 [00:00<00:29,  3.30it/s]


Bootstrapped 2 full traces after 3 examples in round 0.


Average Metric: 15.112314319667261 / 50  (30.2): 100%|██████████| 50/50 [00:08<00:00,  5.77it/s]


Average Metric: 15.112314319667261 / 50  (30.2%)
Score: 30.22 for set: [2]
New best score: 30.22 for seed 0
Scores so far: [30.22]
Best score: 30.22


  1%|          | 1/100 [00:00<00:34,  2.88it/s]


Bootstrapped 1 full traces after 2 examples in round 0.


Average Metric: 14.80803624480095 / 50  (29.6): 100%|██████████| 50/50 [00:07<00:00,  6.54it/s] 


Average Metric: 14.80803624480095 / 50  (29.6%)
Score: 29.62 for set: [1]
Scores so far: [30.22, 29.62]
Best score: 30.22


  1%|          | 1/100 [00:00<00:13,  7.40it/s]


Bootstrapped 1 full traces after 2 examples in round 0.


Average Metric: 17.35727866904337 / 50  (34.7): 100%|██████████| 50/50 [00:14<00:00,  3.57it/s] 


Average Metric: 17.35727866904337 / 50  (34.7%)
Score: 34.71 for set: [1]
New best score: 34.71 for seed 2
Scores so far: [30.22, 29.62, 34.71]
Best score: 34.71
Average of max per entry across top 1 scores: 0.3471455733808675
Average of max per entry across top 2 scores: 0.4029887106357694
Average of max per entry across top 3 scores: 0.4114887106357694
Average of max per entry across top 5 scores: 0.4114887106357694
Average of max per entry across top 8 scores: 0.4114887106357694
Average of max per entry across top 9999 scores: 0.4114887106357694


  1%|          | 1/100 [00:00<00:35,  2.83it/s]


Bootstrapped 1 full traces after 2 examples in round 0.


Average Metric: 16.347177658942364 / 50  (32.7): 100%|██████████| 50/50 [00:10<00:00,  4.62it/s]


Average Metric: 16.347177658942364 / 50  (32.7%)
Score: 32.69 for set: [1]
Scores so far: [30.22, 29.62, 34.71, 32.69]
Best score: 34.71
Average of max per entry across top 1 scores: 0.3471455733808675
Average of max per entry across top 2 scores: 0.3970546642899584
Average of max per entry across top 3 scores: 0.42389780154486034
Average of max per entry across top 5 scores: 0.4323978015448603
Average of max per entry across top 8 scores: 0.4323978015448603
Average of max per entry across top 9999 scores: 0.4323978015448603


  1%|          | 1/100 [00:00<00:25,  3.96it/s]


Bootstrapped 1 full traces after 2 examples in round 0.


Average Metric: 17.726203208556147 / 50  (35.5): 100%|██████████| 50/50 [00:09<00:00,  5.24it/s]


Average Metric: 17.726203208556147 / 50  (35.5%)
Score: 35.45 for set: [1]
New best score: 35.45 for seed 4
Scores so far: [30.22, 29.62, 34.71, 32.69, 35.45]
Best score: 35.45
Average of max per entry across top 1 scores: 0.3545240641711229
Average of max per entry across top 2 scores: 0.39821093285799164
Average of max per entry across top 3 scores: 0.41245335710041586
Average of max per entry across top 5 scores: 0.4346200237670825
Average of max per entry across top 8 scores: 0.4346200237670825
Average of max per entry across top 9999 scores: 0.4346200237670825


  2%|▏         | 2/100 [00:00<00:21,  4.66it/s]


Bootstrapped 2 full traces after 3 examples in round 0.


Average Metric: 12.51916221033868 / 50  (25.0): 100%|██████████| 50/50 [00:06<00:00,  7.50it/s] 


Average Metric: 12.51916221033868 / 50  (25.0%)
Score: 25.04 for set: [2]
Scores so far: [30.22, 29.62, 34.71, 32.69, 35.45, 25.04]
Best score: 35.45
Average of max per entry across top 1 scores: 0.3545240641711229
Average of max per entry across top 2 scores: 0.39821093285799164
Average of max per entry across top 3 scores: 0.41245335710041586
Average of max per entry across top 5 scores: 0.4346200237670825
Average of max per entry across top 8 scores: 0.4587361853832441
Average of max per entry across top 9999 scores: 0.4587361853832441


  1%|          | 1/100 [00:00<00:29,  3.38it/s]


Bootstrapped 1 full traces after 2 examples in round 0.


Average Metric: 14.85307721174594 / 50  (29.7): 100%|██████████| 50/50 [00:12<00:00,  3.98it/s] 


Average Metric: 14.85307721174594 / 50  (29.7%)
Score: 29.71 for set: [1]
Scores so far: [30.22, 29.62, 34.71, 32.69, 35.45, 25.04, 29.71]
Best score: 35.45
Average of max per entry across top 1 scores: 0.3545240641711229
Average of max per entry across top 2 scores: 0.39821093285799164
Average of max per entry across top 3 scores: 0.41245335710041586
Average of max per entry across top 5 scores: 0.4358393220126966
Average of max per entry across top 8 scores: 0.46645548362885825
Average of max per entry across top 9999 scores: 0.46645548362885825


  2%|▏         | 2/100 [00:00<00:30,  3.26it/s]


Bootstrapped 2 full traces after 3 examples in round 0.


Average Metric: 11.73581402257873 / 50  (23.5): 100%|██████████| 50/50 [00:04<00:00, 11.94it/s] 


Average Metric: 11.73581402257873 / 50  (23.5%)
Score: 23.47 for set: [2]
Scores so far: [30.22, 29.62, 34.71, 32.69, 35.45, 25.04, 29.71, 23.47]
Best score: 35.45
Average of max per entry across top 1 scores: 0.3545240641711229
Average of max per entry across top 2 scores: 0.39821093285799164
Average of max per entry across top 3 scores: 0.41245335710041586
Average of max per entry across top 5 scores: 0.4358393220126966
Average of max per entry across top 8 scores: 0.48645548362885827
Average of max per entry across top 9999 scores: 0.48645548362885827


  1%|          | 1/100 [00:00<00:05, 17.72it/s]


Bootstrapped 1 full traces after 2 examples in round 0.


Average Metric: 13.458036244800951 / 50  (26.9): 100%|██████████| 50/50 [00:05<00:00,  9.04it/s]


Average Metric: 13.458036244800951 / 50  (26.9%)
Score: 26.92 for set: [1]
Scores so far: [30.22, 29.62, 34.71, 32.69, 35.45, 25.04, 29.71, 23.47, 26.92]
Best score: 35.45
Average of max per entry across top 1 scores: 0.3545240641711229
Average of max per entry across top 2 scores: 0.39821093285799164
Average of max per entry across top 3 scores: 0.41245335710041586
Average of max per entry across top 5 scores: 0.4358393220126966
Average of max per entry across top 8 scores: 0.4931221502955249
Average of max per entry across top 9999 scores: 0.5131221502955249


  2%|▏         | 2/100 [00:00<00:15,  6.25it/s]


Bootstrapped 2 full traces after 3 examples in round 0.


Average Metric: 17.537046939988116 / 50  (35.1): 100%|██████████| 50/50 [00:09<00:00,  5.20it/s]


Average Metric: 17.537046939988116 / 50  (35.1%)
Score: 35.07 for set: [2]
Scores so far: [30.22, 29.62, 34.71, 32.69, 35.45, 25.04, 29.71, 23.47, 26.92, 35.07]
Best score: 35.45
Average of max per entry across top 1 scores: 0.3545240641711229
Average of max per entry across top 2 scores: 0.38235204991087346
Average of max per entry across top 3 scores: 0.4226298276886512
Average of max per entry across top 5 scores: 0.44912982768865123
Average of max per entry across top 8 scores: 0.4893491259342653
Average of max per entry across top 9999 scores: 0.5209652875504269


  1%|          | 1/100 [00:00<00:13,  7.10it/s]


Bootstrapped 1 full traces after 2 examples in round 0.


Average Metric: 14.391889483065953 / 50  (28.8): 100%|██████████| 50/50 [00:04<00:00, 10.86it/s]


Average Metric: 14.391889483065953 / 50  (28.8%)
Score: 28.78 for set: [1]
Scores so far: [30.22, 29.62, 34.71, 32.69, 35.45, 25.04, 29.71, 23.47, 26.92, 35.07, 28.78]
Best score: 35.45
Average of max per entry across top 1 scores: 0.3545240641711229
Average of max per entry across top 2 scores: 0.38235204991087346
Average of max per entry across top 3 scores: 0.4226298276886512
Average of max per entry across top 5 scores: 0.44912982768865123
Average of max per entry across top 8 scores: 0.4551824592675986
Average of max per entry across top 9999 scores: 0.5234652875504269


  2%|▏         | 2/100 [00:00<00:09, 10.26it/s]


Bootstrapped 2 full traces after 3 examples in round 0.


Average Metric: 12.664111705288178 / 50  (25.3): 100%|██████████| 50/50 [00:02<00:00, 17.71it/s]


Average Metric: 12.664111705288178 / 50  (25.3%)
Score: 25.33 for set: [2]
Scores so far: [30.22, 29.62, 34.71, 32.69, 35.45, 25.04, 29.71, 23.47, 26.92, 35.07, 28.78, 25.33]
Best score: 35.45
Average of max per entry across top 1 scores: 0.3545240641711229
Average of max per entry across top 2 scores: 0.38235204991087346
Average of max per entry across top 3 scores: 0.4226298276886512
Average of max per entry across top 5 scores: 0.44912982768865123
Average of max per entry across top 8 scores: 0.4551824592675986
Average of max per entry across top 9999 scores: 0.5234652875504269


  2%|▏         | 2/100 [00:00<00:24,  3.92it/s]


Bootstrapped 2 full traces after 3 examples in round 0.


Average Metric: 12.636824123588829 / 50  (25.3): 100%|██████████| 50/50 [00:05<00:00,  8.90it/s]


Average Metric: 12.636824123588829 / 50  (25.3%)
Score: 25.27 for set: [2]
Scores so far: [30.22, 29.62, 34.71, 32.69, 35.45, 25.04, 29.71, 23.47, 26.92, 35.07, 28.78, 25.33, 25.27]
Best score: 35.45
Average of max per entry across top 1 scores: 0.3545240641711229
Average of max per entry across top 2 scores: 0.38235204991087346
Average of max per entry across top 3 scores: 0.4226298276886512
Average of max per entry across top 5 scores: 0.44912982768865123
Average of max per entry across top 8 scores: 0.4551824592675986
Average of max per entry across top 9999 scores: 0.5434652875504269


  2%|▏         | 2/100 [00:00<00:07, 12.68it/s]


Bootstrapped 2 full traces after 3 examples in round 0.


Average Metric: 15.349702911467615 / 50  (30.7): 100%|██████████| 50/50 [00:05<00:00,  9.38it/s]


Average Metric: 15.349702911467615 / 50  (30.7%)
Score: 30.7 for set: [2]
Scores so far: [30.22, 29.62, 34.71, 32.69, 35.45, 25.04, 29.71, 23.47, 26.92, 35.07, 28.78, 25.33, 25.27, 30.7]
Best score: 35.45
Average of max per entry across top 1 scores: 0.3545240641711229
Average of max per entry across top 2 scores: 0.38235204991087346
Average of max per entry across top 3 scores: 0.4226298276886512
Average of max per entry across top 5 scores: 0.46546316102198454
Average of max per entry across top 8 scores: 0.4846824592675986
Average of max per entry across top 9999 scores: 0.5454652875504269


  1%|          | 1/100 [00:00<00:32,  3.03it/s]


Bootstrapped 1 full traces after 2 examples in round 0.


Average Metric: 13.963324420677363 / 50  (27.9): 100%|██████████| 50/50 [00:09<00:00,  5.46it/s]


Average Metric: 13.963324420677363 / 50  (27.9%)
Score: 27.93 for set: [1]
Scores so far: [30.22, 29.62, 34.71, 32.69, 35.45, 25.04, 29.71, 23.47, 26.92, 35.07, 28.78, 25.33, 25.27, 30.7, 27.93]
Best score: 35.45
Average of max per entry across top 1 scores: 0.3545240641711229
Average of max per entry across top 2 scores: 0.38235204991087346
Average of max per entry across top 3 scores: 0.4226298276886512
Average of max per entry across top 5 scores: 0.46546316102198454
Average of max per entry across top 8 scores: 0.4846824592675986
Average of max per entry across top 9999 scores: 0.549465287550427


  1%|          | 1/100 [00:00<00:21,  4.51it/s]


Bootstrapped 1 full traces after 2 examples in round 0.


Average Metric: 15.328490790255495 / 50  (30.7): 100%|██████████| 50/50 [00:06<00:00,  7.39it/s]


Average Metric: 15.328490790255495 / 50  (30.7%)
Score: 30.66 for set: [1]
Scores so far: [30.22, 29.62, 34.71, 32.69, 35.45, 25.04, 29.71, 23.47, 26.92, 35.07, 28.78, 25.33, 25.27, 30.7, 27.93, 30.66]
Best score: 35.45
Average of max per entry across top 1 scores: 0.3545240641711229
Average of max per entry across top 2 scores: 0.38235204991087346
Average of max per entry across top 3 scores: 0.4226298276886512
Average of max per entry across top 5 scores: 0.46546316102198454
Average of max per entry across top 8 scores: 0.4871824592675986
Average of max per entry across top 9999 scores: 0.554465287550427


  2%|▏         | 2/100 [00:00<00:28,  3.48it/s]


Bootstrapped 2 full traces after 3 examples in round 0.


Average Metric: 13.475445632798573 / 50  (27.0): 100%|██████████| 50/50 [00:05<00:00,  9.03it/s]


Average Metric: 13.475445632798573 / 50  (27.0%)
Score: 26.95 for set: [2]
Scores so far: [30.22, 29.62, 34.71, 32.69, 35.45, 25.04, 29.71, 23.47, 26.92, 35.07, 28.78, 25.33, 25.27, 30.7, 27.93, 30.66, 26.95]
Best score: 35.45
Average of max per entry across top 1 scores: 0.3545240641711229
Average of max per entry across top 2 scores: 0.38235204991087346
Average of max per entry across top 3 scores: 0.4226298276886512
Average of max per entry across top 5 scores: 0.46546316102198454
Average of max per entry across top 8 scores: 0.4871824592675986
Average of max per entry across top 9999 scores: 0.554465287550427


  2%|▏         | 2/100 [00:00<00:15,  6.27it/s]


Bootstrapped 2 full traces after 3 examples in round 0.


Average Metric: 11.933065953654191 / 50  (23.9): 100%|██████████| 50/50 [00:03<00:00, 16.60it/s]


Average Metric: 11.933065953654191 / 50  (23.9%)
Score: 23.87 for set: [2]
Scores so far: [30.22, 29.62, 34.71, 32.69, 35.45, 25.04, 29.71, 23.47, 26.92, 35.07, 28.78, 25.33, 25.27, 30.7, 27.93, 30.66, 26.95, 23.87]
Best score: 35.45
Average of max per entry across top 1 scores: 0.3545240641711229
Average of max per entry across top 2 scores: 0.38235204991087346
Average of max per entry across top 3 scores: 0.4226298276886512
Average of max per entry across top 5 scores: 0.46546316102198454
Average of max per entry across top 8 scores: 0.4871824592675986
Average of max per entry across top 9999 scores: 0.556965287550427


  1%|          | 1/100 [00:00<01:03,  1.55it/s]


Bootstrapped 1 full traces after 2 examples in round 0.


Average Metric: 16.45903149138443 / 50  (32.9): 100%|██████████| 50/50 [00:10<00:00,  4.98it/s] 


Average Metric: 16.45903149138443 / 50  (32.9%)
Score: 32.92 for set: [1]
Scores so far: [30.22, 29.62, 34.71, 32.69, 35.45, 25.04, 29.71, 23.47, 26.92, 35.07, 28.78, 25.33, 25.27, 30.7, 27.93, 30.66, 26.95, 23.87, 32.92]
Best score: 35.45
Average of max per entry across top 1 scores: 0.3545240641711229
Average of max per entry across top 2 scores: 0.38235204991087346
Average of max per entry across top 3 scores: 0.4226298276886512
Average of max per entry across top 5 scores: 0.4424126559714795
Average of max per entry across top 8 scores: 0.4910793226381462
Average of max per entry across top 9999 scores: 0.556965287550427


  1%|          | 1/100 [00:00<00:07, 13.62it/s]


Bootstrapped 1 full traces after 2 examples in round 0.


Average Metric: 11.264111705288176 / 50  (22.5): 100%|██████████| 50/50 [00:03<00:00, 15.22it/s]

Average Metric: 11.264111705288176 / 50  (22.5%)
Score: 22.53 for set: [1]
Scores so far: [30.22, 29.62, 34.71, 32.69, 35.45, 25.04, 29.71, 23.47, 26.92, 35.07, 28.78, 25.33, 25.27, 30.7, 27.93, 30.66, 26.95, 23.87, 32.92, 22.53]
Best score: 35.45
Average of max per entry across top 1 scores: 0.3545240641711229
Average of max per entry across top 2 scores: 0.38235204991087346
Average of max per entry across top 3 scores: 0.4226298276886512
Average of max per entry across top 5 scores: 0.4424126559714795
Average of max per entry across top 8 scores: 0.4910793226381462
Average of max per entry across top 9999 scores: 0.556965287550427
20 candidate programs found.



