In [1]:
!pip install faiss-cpu
!pip install sentence-transformers


Collecting faiss-cpu
  Downloading faiss_cpu-1.10.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.4 kB)
Downloading faiss_cpu-1.10.0-cp311-cp311-manylinux_2_28_x86_64.whl (30.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m30.7/30.7 MB[0m [31m57.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: faiss-cpu
Successfully installed faiss-cpu-1.10.0
[0mCollecting sentence-transformers
  Downloading sentence_transformers-4.0.2-py3-none-any.whl.metadata (13 kB)
Downloading sentence_transformers-4.0.2-py3-none-any.whl (340 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m340.6/340.6 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: sentence-transformers
Successfully installed sentence-transformers-4.0.2
[0m

In [4]:
import pandas as pd
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np

df = pd.read_csv("./unique_diseases_with_NHS.csv", encoding='latin1')
df["Disease"] = df["Disease"].str.strip()
df["Symptoms"] = df["Symptoms"].str.strip()
df["Guidelines"] = df["Guidelines"].str.strip()

df["entry_text"] = df.apply(lambda row: f"Disease: {row['Disease']}. Symptoms: {row['Symptoms']}. Guidelines: {row['Guidelines']}.", axis=1)
corpus = df["entry_text"].tolist()
print("构造了 {} 条知识条目.".format(len(corpus)))

embedder = SentenceTransformer("all-MiniLM-L6-v2")
corpus_embeddings = embedder.encode(corpus, convert_to_numpy=True)
embedding_dim = corpus_embeddings.shape[1]
index = faiss.IndexFlatL2(embedding_dim)
index.add(corpus_embeddings)
print("FAISS 索引构建成功，包含 {} 个条目。".format(index.ntotal))

query = "chest pain, shortness of breath, and dizziness"
query_embedding = embedder.encode([query], convert_to_numpy=True)
distances, indices = index.search(np.array(query_embedding).astype('float32'), k=5)
print("\n检索到的最近邻索引:", indices[0])
print("对应的知识条目：")
for idx in indices[0]:
    print(corpus[idx])

构造了 60 条知识条目.


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

FAISS 索引构建成功，包含 60 个条目。

检索到的最近邻索引: [27 54  4 21 35]
对应的知识条目：
Disease: heart attack. Symptoms: You might also feel short of breath, sweat a lot, feel nauseous, or even vomit. Some people experience a fast or irregular heartbeat, feel very tired, or have a sense of anxiety or doom.. Guidelines: If you think you or someone else is having a heart attack, call 911 right away. While waiting for help, chew an aspirin (unless you are allergic) to help prevent blood clots..
Disease: shortness of breath. Symptoms: Shortness of breath, also called dyspnea, can feel like tightness in the chest, difficulty breathing in fully, or a sensation of not getting enough air.. Guidelines: First, I¡¯ll want to check your heart, lungs, and oxygen levels. Treatment depends on the cause ¡ª could be an inhaler for asthma, diuretics for heart failure, antibiotics for an infection, or even blood thinners if there¡¯s a clot..
Disease: hypertension. Symptoms: Symptoms of hypertension include headaches, dizziness, c

In [54]:
import re

def extract_symptoms(text):
    pattern = r"Instruction: The patient is experiencing The patient is experiencing (.+?)\.\. What is the most likely diagnosis"
    match = re.search(pattern, text)
    if match:
        return match.group(1).strip()
    return None

test_df = pd.read_csv("./Final_Prompt-Style_Test_Set.csv")
print("测试数据集列：", test_df.columns.tolist())
test_df['symptoms'] = test_df['prompt'].apply(extract_symptoms)
test_df

测试数据集列： ['prompt', 'diagnosis']


Unnamed: 0,prompt,diagnosis,symptoms
0,Instruction: The patient is experiencing The p...,Acne,"skin rash, blackheads, scurring"
1,Instruction: The patient is experiencing The p...,Acne,"skin rash, pus filled pimples, blackheads, scu..."
2,Instruction: The patient is experiencing The p...,Hyperthyroidism,"fatigue, mood swings, weight loss, restlessnes..."
3,Instruction: The patient is experiencing The p...,AIDS,"muscle wasting, patches in throat, high fever,..."
4,Instruction: The patient is experiencing The p...,Chronic cholestasis,"itching, vomiting, yellowish skin, nausea, los..."
...,...,...,...
979,Instruction: The patient is experiencing The p...,Dimorphic hemmorhoids(piles),"constipation, pain during bowel movements, pai..."
980,Instruction: The patient is experiencing The p...,AIDS,"muscle wasting, patches in throat, high fever,..."
981,Instruction: The patient is experiencing The p...,Dengue,"skin rash, chills, joint pain, vomiting, fatig..."
982,Instruction: The patient is experiencing The p...,Acne,"skin rash, pus filled pimples, scurring"


In [6]:
def retrieve_knowledge(user_input, top_k=3):
    query_embedding = embedder.encode([user_input], convert_to_numpy=True)
    distances, indices = index.search(np.array(query_embedding).astype('float32'), top_k)
    retrieved_texts = [corpus[idx] for idx in indices[0]]
    return retrieved_texts

In [25]:
# symptoms = "chest pain, shortness of breath, and dizziness"
# retrieved_context = retrieve_knowledge(symptoms, top_k=3)
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
from huggingface_hub import login

login(token="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")

In [55]:
symptoms = "chest pain, shortness of breath, and dizziness"
retrieved_context = retrieve_knowledge(symptoms, top_k=3)

def get_context(symptom_text):
    retrieved = retrieve_knowledge(symptom_text, top_k=3)
    return " ||| ".join(retrieved) 
test_df['retrieved_context'] = test_df['symptoms'].apply(get_context)
test_df

Unnamed: 0,prompt,diagnosis,symptoms,retrieved_context
0,Instruction: The patient is experiencing The p...,Acne,"skin rash, blackheads, scurring",Disease: fungal infection. Symptoms: On the sk...
1,Instruction: The patient is experiencing The p...,Acne,"skin rash, pus filled pimples, blackheads, scu...",Disease: acne. Symptoms: Symptoms of acne?incl...
2,Instruction: The patient is experiencing The p...,Hyperthyroidism,"fatigue, mood swings, weight loss, restlessnes...","Disease: stress, anxiety and low mood. Symptom..."
3,Instruction: The patient is experiencing The p...,AIDS,"muscle wasting, patches in throat, high fever,...",Disease: common cold. Symptoms: Typical sympto...
4,Instruction: The patient is experiencing The p...,Chronic cholestasis,"itching, vomiting, yellowish skin, nausea, los...",Disease: typhoid. Symptoms: Common signs inclu...
...,...,...,...,...
979,Instruction: The patient is experiencing The p...,Dimorphic hemmorhoids(piles),"constipation, pain during bowel movements, pai...",Disease: dimorphic hemmorhoids(piles). Symptom...
980,Instruction: The patient is experiencing The p...,AIDS,"muscle wasting, patches in throat, high fever,...",Disease: common cold. Symptoms: Typical sympto...
981,Instruction: The patient is experiencing The p...,Dengue,"skin rash, chills, joint pain, vomiting, fatig...",Disease: typhoid. Symptoms: Common signs inclu...
982,Instruction: The patient is experiencing The p...,Acne,"skin rash, pus filled pimples, scurring",Disease: fungal infection. Symptoms: On the sk...


In [27]:
final_prompt = f"""Instruction: The patient is experiencing {symptoms}.
Context: {" ".join(retrieved_context)}
What is the most likely diagnosis and what do you recommend?
Output:"""

print("\n最终 Prompt：")
print(final_prompt)


最终 Prompt：
Instruction: The patient is experiencing chest pain, shortness of breath, and dizziness.
Context: Disease: heart attack. Symptoms: You might also feel short of breath, sweat a lot, feel nauseous, or even vomit. Some people experience a fast or irregular heartbeat, feel very tired, or have a sense of anxiety or doom.. Guidelines: If you think you or someone else is having a heart attack, call 911 right away. While waiting for help, chew an aspirin (unless you are allergic) to help prevent blood clots.. Disease: shortness of breath. Symptoms: Shortness of breath, also called dyspnea, can feel like tightness in the chest, difficulty breathing in fully, or a sensation of not getting enough air.. Guidelines: First, I¡¯ll want to check your heart, lungs, and oxygen levels. Treatment depends on the cause ¡ª could be an inhaler for asthma, diuretics for heart failure, antibiotics for an infection, or even blood thinners if there¡¯s a clot.. Disease: hypertension. Symptoms: Symptoms

In [30]:
model_path = "./lora_llama_medical_finetuned"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    device_map="auto"
)

tokenizer = AutoTokenizer.from_pretrained(model_path)
inputs_final = tokenizer(final_prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids_final = model.generate(
        **inputs_final,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        eos_token_id=tokenizer.eos_token_id
    )
output_text_final = tokenizer.decode(output_ids_final[0], skip_special_tokens=True)
print("\n最终生成输出：")
print(output_text_final)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.



最终生成输出：
Instruction: The patient is experiencing chest pain, shortness of breath, and dizziness.
Context: Disease: heart attack. Symptoms: You might also feel short of breath, sweat a lot, feel nauseous, or even vomit. Some people experience a fast or irregular heartbeat, feel very tired, or have a sense of anxiety or doom.. Guidelines: If you think you or someone else is having a heart attack, call 911 right away. While waiting for help, chew an aspirin (unless you are allergic) to help prevent blood clots.. Disease: shortness of breath. Symptoms: Shortness of breath, also called dyspnea, can feel like tightness in the chest, difficulty breathing in fully, or a sensation of not getting enough air.. Guidelines: First, I¡¯ll want to check your heart, lungs, and oxygen levels. Treatment depends on the cause ¡ª could be an inhaler for asthma, diuretics for heart failure, antibiotics for an infection, or even blood thinners if there¡¯s a clot.. Disease: hypertension. Symptoms: Symptoms of

In [56]:
def generate_output(symptoms, retrieved_context):
    final_prompt = f"""Instruction: The patient is experiencing {symptoms}.
Context: {" ".join(retrieved_context)}
What is the most likely diagnosis and what do you recommend?
Output:"""
    
    inputs = tokenizer(final_prompt, return_tensors="pt").to(model.device)
    
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            eos_token_id=tokenizer.eos_token_id
        )
    
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return output_text

def get_generated_output(row):
    symptoms = row["symptoms"]
    retrieved_context = row["retrieved_context"]
    if isinstance(retrieved_context, str):
        retrieved_context = retrieved_context.split(" ||| ")
    return generate_output(symptoms, retrieved_context)

test_df['generated_output'] = test_df.apply(get_generated_output, axis=1)
test_df

Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128009 for

Unnamed: 0,prompt,diagnosis,symptoms,retrieved_context,generated_output
0,Instruction: The patient is experiencing The p...,Acne,"skin rash, blackheads, scurring",Disease: fungal infection. Symptoms: On the sk...,Instruction: The patient is experiencing skin ...
1,Instruction: The patient is experiencing The p...,Acne,"skin rash, pus filled pimples, blackheads, scu...",Disease: acne. Symptoms: Symptoms of acne?incl...,Instruction: The patient is experiencing skin ...
2,Instruction: The patient is experiencing The p...,Hyperthyroidism,"fatigue, mood swings, weight loss, restlessnes...","Disease: stress, anxiety and low mood. Symptom...",Instruction: The patient is experiencing fatig...
3,Instruction: The patient is experiencing The p...,AIDS,"muscle wasting, patches in throat, high fever,...",Disease: common cold. Symptoms: Typical sympto...,Instruction: The patient is experiencing muscl...
4,Instruction: The patient is experiencing The p...,Chronic cholestasis,"itching, vomiting, yellowish skin, nausea, los...",Disease: typhoid. Symptoms: Common signs inclu...,Instruction: The patient is experiencing itchi...
...,...,...,...,...,...
979,Instruction: The patient is experiencing The p...,Dimorphic hemmorhoids(piles),"constipation, pain during bowel movements, pai...",Disease: dimorphic hemmorhoids(piles). Symptom...,Instruction: The patient is experiencing const...
980,Instruction: The patient is experiencing The p...,AIDS,"muscle wasting, patches in throat, high fever,...",Disease: common cold. Symptoms: Typical sympto...,Instruction: The patient is experiencing muscl...
981,Instruction: The patient is experiencing The p...,Dengue,"skin rash, chills, joint pain, vomiting, fatig...",Disease: typhoid. Symptoms: Common signs inclu...,Instruction: The patient is experiencing skin ...
982,Instruction: The patient is experiencing The p...,Acne,"skin rash, pus filled pimples, scurring",Disease: fungal infection. Symptoms: On the sk...,Instruction: The patient is experiencing skin ...


In [31]:
symptoms = "chest pain, shortness of breath, and dizziness"
prompt = f"""Instruction: The patient is experiencing {symptoms}. What is the most likely diagnosis and what do you recommend?
Output:"""



inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    output_ids = model.generate(
        **inputs,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        eos_token_id=tokenizer.eos_token_id
    )

output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print("生成输出：")
print(output_text)

Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.


生成输出：
Instruction: The patient is experiencing chest pain, shortness of breath, and dizziness. What is the most likely diagnosis and what do you recommend?
Output: Diagnosis: Heart attack
Advice:
- call ambulance
- chew or swallow asprin
- keep calm
- keep breathing
- keep still
- avoid heavy lifting
- keep informed
- keep follow up
- keep medicine
- keep follow up
- keep emergency contact
- keep record
- keep follow up
- keep follow up
- keep follow up
- keep follow up
- keep follow up
- keep follow up
- keep follow up
- keep follow up
- keep follow up
- keep follow up
- keep follow up
- keep follow up
- keep follow up
- keep follow up
- keep follow up
- keep follow up
- keep follow up
- keep follow up
- keep follow up
- keep follow up
- keep follow up
- keep follow up
- keep follow up
- keep follow up
- keep follow up
- keep follow up
- keep follow up
- keep


In [47]:
result_df = pd.DataFrame({
    "prompt": [prompt],
    "generated_output": [output_text],
    "diagnosis": ["hypertension"]
})

output_csv = "./Final_Prompt-Style_Test_Set_with_generated_April3.csv"
result_df.to_csv(output_csv, index=False)
print("生成的结果已保存到:", output_csv)


import sacrebleu
from rapidfuzz import fuzz
from rouge_score import rouge_scorer
import nltk
from nltk.tokenize import word_tokenize
from nltk.translate.meteor_score import meteor_score

# nltk.download('punkt')

def extract_diagnosis(output_text):
    if "Doctor:" in output_text:
        diag = output_text.split("Doctor:")[-1].split("\n")[0].strip().lower()
    elif "Output:" in output_text:
        diag = output_text.split("Output:")[-1].split("\n")[0].strip().lower()
    else:
        diag = output_text.strip().lower()
    return diag

df_eval = pd.read_csv(output_csv)
df_eval["extracted_diagnosis"] = df_eval["generated_output"].apply(extract_diagnosis)
df_eval["diagnosis_lower"] = df_eval["diagnosis"].str.lower()


def fuzzy_match_accuracy(ground_truths, predictions, threshold=80):
    matches = 0
    for gt, pred in zip(ground_truths, predictions):
        score = fuzz.ratio(gt, pred)
        if score >= threshold:
            matches += 1
    return matches / len(ground_truths)

gt_list = df_eval["diagnosis_lower"].tolist()
pred_list = df_eval["extracted_diagnosis"].tolist()
fuzzy_acc = fuzzy_match_accuracy(gt_list, pred_list, threshold=80)
print("Fuzzy Matching Accuracy: {:.2f}%".format(fuzzy_acc * 100))

bleu = sacrebleu.corpus_bleu(pred_list, [gt_list])
print("Corpus BLEU Score: {:.2f}".format(bleu.score))

def compute_average_rouge(ground_truths, predictions):
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
    rouge1_f1, rougel_f1 = [], []
    for gt, pred in zip(ground_truths, predictions):
        scores = scorer.score(gt, pred)
        rouge1_f1.append(scores['rouge1'].fmeasure)
        rougel_f1.append(scores['rougeL'].fmeasure)
    avg_rouge1 = sum(rouge1_f1) / len(rouge1_f1)
    avg_rougel = sum(rougel_f1) / len(rougel_f1)
    return avg_rouge1, avg_rougel

avg_rouge1, avg_rougel = compute_average_rouge(gt_list, pred_list)
print("Average ROUGE-1 F1: {:.2f}".format(avg_rouge1))
print("Average ROUGE-L F1: {:.2f}".format(avg_rougel))

def compute_average_meteor(ground_truths, predictions):
    scores = []
    for gt, pred in zip(ground_truths, predictions):
        gt_tokens = word_tokenize(gt)
        pred_tokens = word_tokenize(pred)
        scores.append(meteor_score([gt_tokens], pred_tokens))
    return sum(scores) / len(scores)

avg_meteor = compute_average_meteor(gt_list, pred_list)
print("Average METEOR Score: {:.2f}".format(avg_meteor))

生成的结果已保存到: ./Final_Prompt-Style_Test_Set_with_generated_April3.csv
Fuzzy Matching Accuracy: 0.00%
Corpus BLEU Score: 0.00
Average ROUGE-1 F1: 0.00
Average ROUGE-L F1: 0.00


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Average METEOR Score: 0.00


In [2]:
!pip install rapidfuzz
!pip install rouge-score
nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('wordnet')

Collecting rapidfuzz
  Downloading rapidfuzz-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading rapidfuzz-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m93.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rapidfuzz
Successfully installed rapidfuzz-3.13.0
[0mCollecting rouge-score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting absl-py (from rouge-score)
  Downloading absl_py-2.2.2-py3-none-any.whl.metadata (2.6 kB)
Collecting nltk (from rouge-score)
  Downloading nltk-3.9.1-py3-none-any.whl.metadata (2.9 kB)
Downloading absl_py-2.2.2-py3-none-any.whl (135 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m135.6/135.6 kB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading nltk-3.9.1-py3-none-any.whl (1.5 MB)
[2K   [90m━━━━━━━━

NameError: name 'nltk' is not defined

In [49]:
df_eval

Unnamed: 0,prompt,generated_output,diagnosis,extracted_diagnosis,diagnosis_lower
0,Instruction: The patient is experiencing chest...,Instruction: The patient is experiencing chest...,hypertension,diagnosis: heart attack,hypertension


In [57]:
def build_prompt(row):
    symptoms = row['symptoms']
    context = row['retrieved_context']
    if isinstance(context, str):
        context = context.split(" ||| ")
    prompt = f"""Instruction: The patient is experiencing {symptoms}.
Context: {" ".join(context)}
What is the most likely diagnosis and what do you recommend?
Output:"""
    return prompt

test_df['prompt'] = test_df.apply(build_prompt, axis=1)

result_df = test_df[['prompt', 'generated_output', 'diagnosis']].copy()

output_csv = "./Final_Prompt-Style_Test_Set_with_generated_April3.csv"
result_df.to_csv(output_csv, index=False)
print("生成的结果已保存到:", output_csv)


import sacrebleu
from rapidfuzz import fuzz
from rouge_score import rouge_scorer
import nltk
from nltk.tokenize import word_tokenize
from nltk.translate.meteor_score import meteor_score

# nltk.download('punkt')

def extract_diagnosis(output_text):
    if "Doctor:" in output_text:
        diag = output_text.split("Doctor:")[-1].split("\n")[0].strip().lower()
    elif "Output:" in output_text:
        diag = output_text.split("Output:")[-1].split("\n")[0].strip().lower()
    else:
        diag = output_text.strip().lower()
    return diag

df_eval = pd.read_csv(output_csv)
df_eval["extracted_diagnosis"] = df_eval["generated_output"].apply(extract_diagnosis)
df_eval["diagnosis_lower"] = df_eval["diagnosis"].str.lower()


def fuzzy_match_accuracy(ground_truths, predictions, threshold=80):
    matches = 0
    for gt, pred in zip(ground_truths, predictions):
        score = fuzz.ratio(gt, pred)
        if score >= threshold:
            matches += 1
    return matches / len(ground_truths)

gt_list = df_eval["diagnosis_lower"].tolist()
pred_list = df_eval["extracted_diagnosis"].tolist()
fuzzy_acc = fuzzy_match_accuracy(gt_list, pred_list, threshold=80)
print("Fuzzy Matching Accuracy: {:.2f}%".format(fuzzy_acc * 100))

bleu = sacrebleu.corpus_bleu(pred_list, [gt_list])
print("Corpus BLEU Score: {:.2f}".format(bleu.score))

def compute_average_rouge(ground_truths, predictions):
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
    rouge1_f1, rougel_f1 = [], []
    for gt, pred in zip(ground_truths, predictions):
        scores = scorer.score(gt, pred)
        rouge1_f1.append(scores['rouge1'].fmeasure)
        rougel_f1.append(scores['rougeL'].fmeasure)
    avg_rouge1 = sum(rouge1_f1) / len(rouge1_f1)
    avg_rougel = sum(rougel_f1) / len(rougel_f1)
    return avg_rouge1, avg_rougel

avg_rouge1, avg_rougel = compute_average_rouge(gt_list, pred_list)
print("Average ROUGE-1 F1: {:.2f}".format(avg_rouge1))
print("Average ROUGE-L F1: {:.2f}".format(avg_rougel))

def compute_average_meteor(ground_truths, predictions):
    scores = []
    for gt, pred in zip(ground_truths, predictions):
        gt_tokens = word_tokenize(gt)
        pred_tokens = word_tokenize(pred)
        scores.append(meteor_score([gt_tokens], pred_tokens))
    return sum(scores) / len(scores)

avg_meteor = compute_average_meteor(gt_list, pred_list)
print("Average METEOR Score: {:.2f}".format(avg_meteor))

生成的结果已保存到: ./Final_Prompt-Style_Test_Set_with_generated_April3.csv
Fuzzy Matching Accuracy: 4.67%
Corpus BLEU Score: 9.28
Average ROUGE-1 F1: 0.27
Average ROUGE-L F1: 0.27
Average METEOR Score: 0.22


In [58]:
df_eval.to_csv("./df_eval.csv", index=False, encoding="utf-8")