In [6]:
from langchain_ollama import OllamaLLM
import evaluate
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
llm = OllamaLLM(model="hf.co/sathvik123/llama3-ChatDoc")

In [8]:
llm.invoke("What are the treatments for vitamin D-dependent rickets?")

"Vitamin D-dependent rickets is a type of inherited disorder that affects bone development. It can be treated in several ways, including:\n\n1. Vitamin D supplements: Patients with vitamin D-dependent rickets may need to take large doses of vitamin D supplements throughout their lives.\n2. Calcitriol supplements: Calcitriol is the active form of vitamin D that the body uses to regulate calcium levels. Taking calcitriol can help to correct calcium imbalances and improve bone mineralization.\n3. Vitamin D analogs: Some patients with vitamin D-dependent rickets may benefit from taking vitamin D analogs, such as paricalcitol, which are more selective for vitamin D receptors in the bones.\n4. Bisphosphonates: These medications can help to reduce calcium excretion and improve bone mineralization.\n5. Hormone therapy: Some patients with vitamin D-dependent rickets may also require hormone therapy, including thyroid hormone replacement if they have concurrent hypothyroidism.\n6. Parathyroid su

In [22]:
def generate_predictions(llm, dataset, prompt_template):
    predictions = []
    references = []

    for row in dataset["input"]:
        input = prompt_template.format(row)
        response = llm.invoke(input)
        predictions.append(response.strip())

    for row in dataset["output"]:
        references.append(row.strip())

    return predictions, references

In [10]:
def compute_scores(predictions, references):
    rouge_metric = evaluate.load("rouge")
    #bleu_metric = evaluate.load("bleu")
    meteor_metric = evaluate.load("meteor")

    rouge_score = rouge_metric.compute(predictions=predictions, references=references)
    #bleu_score = bleu_metric.compute(predictions=predictions, references=references)
    meteor_score = meteor_metric.compute(predictions=predictions, references=references)

    return {
        "ROUGE": rouge_score,
        "METEOR": meteor_score
    }

In [20]:
my_prompt = """
<|start_header_id|>system<|end_header_id|>If you are a doctor, please answer the medical questions based on the patient's description.<|eot_id|><|start_header_id|>user<|end_header_id|> This is the question: {}<|eot_id|>
"""

In [12]:
data = load_dataset("sathvik123/llama3-medical-dataset", split = "test")

In [25]:
# Generate predictions
predictions, references = generate_predictions(llm, data[:100], my_prompt)

In [26]:
compute_scores(predictions, references)

[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\sathv\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\sathv\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     C:\Users\sathv\AppData\Roaming\nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


{'ROUGE': {'rouge1': 0.2928535355420423,
  'rouge2': 0.03969550563416629,
  'rougeL': 0.14719890945601122,
  'rougeLsum': 0.1497006169422635},
 'METEOR': {'meteor': 0.2438378550734367}}