In [1]:
from sagemaker.jumpstart.estimator import JumpStartEstimator

# Re-attach to the already-completed JumpStart fine-tuning job so its trained
# artifacts can be deployed below without re-running training.
training_job_name = "jumpstart-dft-huggingface-llm-gemma-20240713-173038"
model_id = "huggingface-llm-gemma-7b-instruct"

# NOTE(review): the SDK logs that the wildcard model version '*' is used here;
# consider passing model_version explicitly for reproducible runs (confirm which
# version the training job was created with first).
attached_estimator = JumpStartEstimator.attach(
    training_job_name=training_job_name,
    model_id=model_id,
)


sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml


Using model 'huggingface-llm-gemma-7b-instruct' with wildcard version identifier '*'. You can pin to version '1.2.0' for more stable results. Note that models may have different input/output signatures after a major version upgrade.



2024-07-13 17:51:22 Starting - Preparing the instances for training
2024-07-13 17:51:22 Downloading - Downloading the training image
2024-07-13 17:51:22 Training - Training image download completed. Training in progress.
2024-07-13 17:51:22 Uploading - Uploading generated training model
2024-07-13 17:51:22 Completed - Training job completed


In [3]:
# Host the fine-tuned model behind a real-time SageMaker endpoint; the returned
# predictor is used by the evaluation cell below.
instance_type = "ml.g5.12xlarge"

fine_tuned_model = attached_estimator.deploy(
    instance_type=instance_type,
)
# NOTE(review): no later cell deletes this endpoint; call
# fine_tuned_model.delete_endpoint() once evaluation is finished so the instance
# does not keep running.

--------!

In [22]:
import sacrebleu
from rouge import Rouge
import pandas as pd
import re
import json

def extract_content(response_text, section=3, delimiter="###"):
    """Return one delimiter-separated section of a model response, whitespace-stripped.

    The response is split on ``delimiter`` and the piece at index ``section``
    (default 3, i.e. the text after the third delimiter) is returned with leading
    and trailing whitespace removed.

    NOTE(review): index 3 presumably matches the prompt/response template used when
    fine-tuning (the evaluation prompt itself contains two ``###`` markers) —
    confirm against the training template.

    Args:
        response_text: generated text returned by the model.
        section: zero-based index of the section to return.
        delimiter: separator string between sections.

    Returns:
        The requested section, stripped.

    Raises:
        IndexError: if the response has too few sections; the message reports how
            many sections were found so failures are diagnosable.
    """
    parts = response_text.split(delimiter)
    if section >= len(parts):
        raise IndexError(
            f"expected at least {section + 1} '{delimiter}'-separated sections, "
            f"got {len(parts)}"
        )
    return parts[section].strip()



    
def evaluate_testjsonl_with_gemma(reference_path, csv_file_path, predictor=None, max_new_tokens=256):
    """Generate a referral summary for every test record and score it against the reference.

    Each non-blank line of ``reference_path`` is parsed as a JSON object; the keys
    read are ``id``, ``name``, ``instruction``, ``whole_letter`` and
    ``referral_content``. For each record the prompt
    ``"{instruction}\\n\\n###\\n\\n{whole_letter}\\n\\n###"`` is sent to the
    predictor, the referral section is pulled out of the generated text with
    ``extract_content``, and BLEU (sacrebleu) and ROUGE-L F1 (rouge) are computed
    against ``referral_content``. Per-record results are written to
    ``csv_file_path``.

    If the response cannot be parsed or the extracted section is empty, the
    prediction is recorded as ``"extract failure"`` for that record (and scored as
    such) instead of silently reusing the previous record's prediction.

    Args:
        reference_path: path to the JSONL test set.
        csv_file_path: destination path for the per-record CSV.
        predictor: object exposing ``predict(payload)``; defaults to the
            notebook-global ``fine_tuned_model`` endpoint.
        max_new_tokens: generation budget sent in the request ``parameters``.

    Returns:
        List of per-record BLEU scores, in file order.
    """
    if predictor is None:
        predictor = fine_tuned_model  # notebook-global endpoint from the deploy cell

    with open(reference_path, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing empty line) instead of crashing json.loads.
        test_data_json = [json.loads(line) for line in f if line.strip()]

    rouge = Rouge()
    bleu_score_list = []

    for single_test in test_data_json:
        reference_text = single_test["referral_content"]
        prompt = f"{single_test['instruction']}\n\n###\n\n{single_test['whole_letter']}\n\n###"
        response = predictor.predict(
            {"inputs": prompt, "parameters": {"max_new_tokens": max_new_tokens}}
        )

        try:
            prediction_text = response[0]["generated_text"].strip()
            extracted_referral_content = extract_content(prediction_text)
            if not extracted_referral_content.strip():
                # Rouge.get_scores raises on an empty hypothesis; treat as a failure.
                raise ValueError("extracted referral section is empty")
        except Exception as err:
            # Without this reset, a failure reused the previous record's prediction
            # (or hit NameError on the first record) and scored stale text.
            print(f"extraction failed for id={single_test.get('id')!r}: {err!r}")
            extracted_referral_content = "extract failure"

        single_test["predict_referral_content"] = extracted_referral_content
        print("prediction: " + extracted_referral_content + "\n")
        print("ground_truth: " + reference_text)
        print("=============================")

        bleu_score = sacrebleu.corpus_bleu([extracted_referral_content], [[reference_text]]).score
        bleu_score_list.append(bleu_score)
        single_test["bleu"] = bleu_score

        # Previously computed and discarded; now kept in the CSV as ROUGE-L F1.
        rouge_score = rouge.get_scores(extracted_referral_content, reference_text)
        single_test["rouge_l_f"] = rouge_score[0]["rouge-l"]["f"]

    columns = [
        "id",
        "name",
        "instruction",
        "whole_letter",
        "referral_content",
        "predict_referral_content",
        "bleu",
        "rouge_l_f",
    ]
    df = pd.DataFrame(
        [{col: record[col] for col in columns} for record in test_data_json],
        columns=columns,
    )
    df.to_csv(csv_file_path, index=False, encoding="utf-8")

    print(f"CSV file has been saved to {csv_file_path}")

    return bleu_score_list
              
    

In [23]:
# Run generation + scoring over the held-out test set and persist per-record results.
bleu_score_list = evaluate_testjsonl_with_gemma(
    reference_path="../train_test_data/test.jsonl",
    csv_file_path="./gemma_trainModel_result.csv",
)

prediction:  The referral reason is the presence of XFM on the anterior surface of the IOL in the right eye. Given the potential implications and the need for specialized care, I recommend further investigation and management.

ground_truth: The referral reason is the presence of XFM on the anterior surface of the IOL in the right eye. Given the potential implications and the need for specialized care, I recommend further investigation and management.
prediction:  The reason for referral is due to the presence of a neurovascular conflict seen on MRI, which has led to dislocation and atrophy of the trigeminal root, causing significant paroxysmal pain. Given her family history of TN, advanced evaluation and potential surgical intervention might be required.

ground_truth: The reason for referral is due to the presence of a neurovascular conflict seen on MRI, which has led to dislocation and atrophy of the trigeminal root, causing significant paroxysmal pain. Given her family history of T

In [24]:
def analyze_predict_data(bleu_score_list, perfect_score=100.0, good_threshold=70.0):
    """Summarize BLEU scores and print the headline statistics.

    Prints the number and share of scores reaching ``perfect_score``, the number
    and share strictly above ``good_threshold``, and the mean score.

    Args:
        bleu_score_list: iterable of numeric BLEU scores (sacrebleu ``.score`` values).
        perfect_score: scores ``>=`` this count as perfect. ``>=`` rather than ``==``
            because float rounding can land a perfect score slightly above 100.
        good_threshold: scores strictly greater than this are counted as good.

    Returns:
        dict with keys ``n``, ``count_perfect``, ``prob_perfect``, ``count_gt_good``,
        ``prob_gt_good`` and ``average``; ``None`` if there are no scores (the
        original raised ZeroDivisionError in that case).
    """
    scores = list(bleu_score_list)
    n = len(scores)
    if n == 0:
        print("bleu_score_list is empty; nothing to analyze")
        return None

    count_perfect = sum(1 for score in scores if score >= perfect_score)
    count_gt_good = sum(1 for score in scores if score > good_threshold)
    stats = {
        "n": n,
        "count_perfect": count_perfect,
        "prob_perfect": count_perfect / n,
        "count_gt_good": count_gt_good,
        "prob_gt_good": count_gt_good / n,
        "average": sum(scores) / n,
    }

    # The first label now says ">=" (大于等于) to match the comparison, and the
    # shares are formatted as real percentages instead of raw fractions.
    print(f"分数大于等于{perfect_score:g}的个数：{count_perfect}, 占所有数据的百分比为： {stats['prob_perfect']:.2%}")
    print(f"分数大于{good_threshold:g}的个数：{count_gt_good}, 占所有数据的百分比为： {stats['prob_gt_good']:.2%}")
    print(f"bleu平均分数: {stats['average']}")
    return stats

In [25]:
# NOTE(review): the recorded run below reports BLEU 100 on every test sample —
# double-check for train/test overlap or reference text leaking into the prompt.
analyze_predict_data(bleu_score_list=bleu_score_list);

分数大于100的个数：21, 占所有数据的百分比为： 1.0
分数大于70的个数：21, 占所有数据的百分比为： 1.0
bleu平均分数: 100.00000000000003
