In [1]:
from langchain_ollama import OllamaLLM
import evaluate
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Handle to the model under evaluation, addressed by its hf.co identifier
# through the langchain_ollama client. Every later cell queries the model via
# this object.
# NOTE(review): presumably requires a running local Ollama server with this
# model available — confirm before Restart & Run All.
llm = OllamaLLM(model="hf.co/sathvik123/llama3-ChatDoc")

In [5]:
# Single ad-hoc question sent straight to the model (no prompt template);
# the returned string is shown as this cell's output.
# NOTE(review): looks like a manual smoke test rather than part of the
# scored evaluation — confirm.
llm.invoke("What are the treatments for vitamin D-dependent rickets?")

'Vitamin D-dependent rickets is a genetic disorder that affects bone mineralization and density. It can be divided into two forms:\n1. Vitamin D-dependent rickets type 1 (VDDR1): This form results from a deficiency of the enzyme 25-hydroxyvitamin D-1 alpha-hydroxylase, which converts vitamin D to its active metabolite.\n2. Vitamin D-dependent rickets type 2 (VDDR2): This form is caused by resistance to the action of vitamin D in bone cells.\n\nThe recommended treatments for both forms are as follows:\nA. Vitamin D supplementation: Vitamin D and its analogues (e.g., calcitriol) can be used to treat VDR1.\nB. Calciferol or cholecalciferol supplements: Vitamin D preparations are administered orally, usually in the form of ergocalciferol, cholecalciferol, or dihydrotachysterone (DHT).\nC. Calcium supplementation: Calcium is important for bone health; patients with vitamin D-dependent rickets should take calcium-rich foods like dairy products, fish, and leafy greens.\nD. Phosphate therapy: 

In [6]:
def generate_predictions(llm, dataset, prompt_template):
    """Query the model once per input row and collect cleaned texts.

    Args:
        llm: Object exposing ``invoke(prompt)``; its return value must
            support ``.strip()``.
        dataset: Mapping whose ``"input"`` entry is an iterable of question
            strings and whose ``"output"`` entry is an iterable of reference
            answer strings.
        prompt_template: ``str.format`` template; each question is passed
            to it as the sole positional argument.

    Returns:
        A ``(predictions, references)`` tuple of lists. Every entry has its
        leading and trailing whitespace removed. The two lists are built
        independently, so their lengths follow the respective columns.
    """
    # All model calls are issued first, in column order, before the
    # reference column is touched.
    predictions = [
        llm.invoke(prompt_template.format(question)).strip()
        for question in dataset["input"]
    ]
    references = [answer.strip() for answer in dataset["output"]]
    return predictions, references

In [7]:
def compute_scores(predictions, references):
    """Score predictions against references with ROUGE, BLEU and METEOR.

    Args:
        predictions: Sequence of generated texts.
        references: Sequence of reference texts, passed alongside
            ``predictions`` to each metric's ``compute``.

    Returns:
        Dict with keys ``"ROUGE"``, ``"BLEU"`` and ``"METEOR"`` (in that
        order), each holding whatever the corresponding metric's
        ``compute`` call returned.
    """
    # Load every metric up front so a failing load surfaces before any
    # scoring work is done.
    loaded_metrics = {
        "ROUGE": evaluate.load("rouge"),
        "BLEU": evaluate.load("bleu"),
        "METEOR": evaluate.load("meteor"),
    }

    # Dict order is insertion order, so metrics are computed and reported
    # as ROUGE, BLEU, METEOR.
    return {
        label: metric.compute(predictions=predictions, references=references)
        for label, metric in loaded_metrics.items()
    }

In [8]:
# Prompt template used for every evaluated question. The single `{}` is a
# str.format placeholder; generate_predictions fills it with one dataset
# question per model call. The leading and trailing newlines are part of the
# string that is sent to the model.
my_prompt = """
instruction:
You are a doctor, please answer the medical questions based on the patient's description. Answer clearly and professionally.
question:{}
"""

In [9]:
# Only the held-out "test" split is loaded. generate_predictions later reads
# its "input" and "output" columns.
data = load_dataset(
    "sathvik123/llama3-medical-dataset",
    split="test",
)

Generating train split: 100%|██████████| 89732/89732 [00:00<00:00, 287644.07 examples/s]
Generating test split: 100%|██████████| 22433/22433 [00:00<00:00, 305145.90 examples/s]


In [12]:
# Generate predictions for the first 100 test rows only; each row costs one
# model call inside generate_predictions.
# NOTE(review): relies on `data[:100]` yielding a column-name -> list mapping
# (generate_predictions indexes it with "input"/"output") — confirm against
# the installed `datasets` version.
predictions, references = generate_predictions(llm, data[:100], my_prompt)

In [13]:
# Left as a bare last expression so the notebook displays the returned dict of
# ROUGE / BLEU / METEOR results as the cell output.
compute_scores(predictions, references)

[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\sathv\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\sathv\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     C:\Users\sathv\AppData\Roaming\nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


{'ROUGE': {'rouge1': 0.28727765213178447,
  'rouge2': 0.03895816149286061,
  'rougeL': 0.1431158009622706,
  'rougeLsum': 0.15246203020417862},
 'BLEU': {'bleu': 0.016538649204154533,
  'precisions': [0.2527598896044158,
   0.03498727735368957,
   0.00657282456956724,
   0.0012871518839223028],
  'brevity_penalty': 1.0,
  'length_ratio': 1.5229422066549911,
  'translation_length': 17392,
  'reference_length': 11420},
 'METEOR': {'meteor': 0.25571056331054287}}