In [14]:
import json
from transformers import AutoTokenizer


tokenizer = AutoTokenizer.from_pretrained("llm_collections/Llama-2-7b-chat-hf")

max_instruction = 0
max_input = 0
max_output = 0
with open("data/trecis/high_sampled_5000_test.json", "r", encoding="utf8") as file:
    for data in json.load(file):
        inst_len = len(tokenizer(data["instruction"]).input_ids)
        max_instruction = max(max_instruction, inst_len)
        input_string = """<s>[INST] <<SYS>>\n\n{system_prompt}\n\n<</SYS>>\n\n{user_message} [/INST]""".format(
            system_prompt="""<s> [INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
    user_message=f"{data['instruction']}\n{data['input']}"
        )
        input_len = len(tokenizer(input_string).input_ids)
        max_input = max(max_input, input_len)
        max_output = max(max_output, len(tokenizer(data["output"]).input_ids))


print(f"max_instruction: {max_instruction}")
print(f"max_input:       {max_input}")
print(f"max_output:      {max_output}")

max_instruction: 1131
max_input:       1496
max_output:      66


In [1]:
# easy to fine-tuning llama
from pathlib import Path
import json
from tqdm.auto import tqdm
from sklearn.metrics import f1_score, classification_report
import re
import numpy as np

# fpath = "records/chatglm2/chatglm_zeroshot_high_level/val_results.jsonl"    #  0.28231
# fpath = "records/baichuan/baichuan_zeroshot_high_level/val_results.jsonl"   # 0.33408
# fpath = "records/baichuan/baichuan_zeroshot_high_level_10/val_results.jsonl"   # 0.33408
# fpath = "records/ernie-bot/ernie_bot_zeroshot_high_level/val_results.jsonl"
# fpath = "records/ernie-bot/ernie_bot_zeroshot_high_level_gen_prompt/val_results.jsonl" 


# 忽略大小写
request_pattern = re.compile("request", re.IGNORECASE)
call_to_action_pattern = re.compile(r"call[\-\_\s]?(?:to|for)[\-\_\s]?action", re.IGNORECASE)
report_pattern = re.compile("report", re.IGNORECASE)
other_pattern = re.compile("other", re.IGNORECASE)

true_pattern = re.compile("yes", re.IGNORECASE)
false_pattern = re.compile("no", re.IGNORECASE)

def soft_post_processing_top_level(response: str):
    request = 1.0 if request_pattern.search(response) else 0.0
    call_to_action = 1.0 if call_to_action_pattern.search(response) else 0.0
    report = 1.0 if report_pattern.search(response) or re.search("报告", response) else 0.0
    other = 1.0 if other_pattern.search(response) else 0.0
    # 修正
    if request + call_to_action + report + other == 0 or request + call_to_action + report + other >= 4.0:
        other = 1.0
    
    return [request, call_to_action, report, other]

def multi_label_to_top_level(multi_label):
    request = 1.0 if multi_label[:3].sum() > 0 else 0.0
    call_to_action = 1.0 if multi_label[3:6].sum() > 0 else 0.0
    report = 1.0 if multi_label[6:-5].sum() > 0 else 0.0
    other = 1.0 if multi_label[-5:].sum() > 0 else 0.0
    return [request, call_to_action, report, other]

high_level_info_types = [
    'Request-GoodsServices',
    'Request-SearchAndRescue',
    'Request-InformationWanted',
    'CallToAction-Volunteer',
    'CallToAction-Donations',
    'CallToAction-MovePeople',
    'Report-FirstPartyObservation',
    'Report-ThirdPartyObservation',
    'Report-Weather',
    'Report-Location',
    'Report-EmergingThreats',
    'Report-NewSubEvent',
    'Report-MultimediaShare',
    'Report-ServiceAvailable',
    'Report-Factoid',
    'Report-Official',
    'Report-News',
    'Report-CleanUp',
    'Report-Hashtags',
    'Report-OriginalEvent',
    'Other-ContextualInformation',
    'Other-Advice',
    'Other-Sentiment',
    'Other-Discussion',
    'Other-Irrelevant',
    ]

def soft_post_processing_high_level(response: str):
    high_level_flag_list = [(info_type in response) for info_type in high_level_info_types]
    return high_level_flag_list

def soft_post_processing_inference(response: str):
    if "I'm just an AI" in response:
        # a bad response
        return 0.0

    true_flag = true_pattern.search(response)
    return 1.0 if true_flag else 0.0

def report_result_llama(fpath, intent_level="top", task_type="mc"):
    soft_post_processing_map = {
        "high": (soft_post_processing_high_level, high_level_info_types),
        "top": (soft_post_processing_top_level, ["request", "call-to-action", "report", "other"])
    }
    soft_post_processing_fct, intent_labels = soft_post_processing_map[intent_level]

    if task_type == "if":
        soft_post_processing_fct = soft_post_processing_inference

    if isinstance(fpath, str):
        fpath = [fpath]
    lines = []
    for fp in fpath:
        with open(fp, "r", encoding="utf8") as file:
            lines.extend(file.readlines())

    predictions = []
    ground_truths = []
    bad_response_counter = 0
    for line_str in tqdm(lines):
        line = json.loads(line_str)
        response = line["predict"]
        try:
            pred_label = soft_post_processing_fct(response)
        except Exception as e:
            print(f"【{line['label']}: {response}】 {e}")
        else:
            if "I'm just an AI" in response:
                bad_response_counter += 1
            predictions.append(pred_label)
            ground_truths.append(soft_post_processing_fct(line["label"]))
    
    print(f"parsered : total = {len(predictions)}: {len(lines)}; num bad response = {bad_response_counter}")
    if task_type == "if":
        ground_truths = np.array(ground_truths).reshape(-1, 25)
        predictions = np.array(predictions).reshape(-1, 25)
        for idx in range(predictions.shape[0]):
            if all(predictions[idx] == 0.0):
                # 全 false
                predictions[idx, -1] = 1.0
            elif predictions[idx, -1] and any(predictions[idx, :-1]):
                # 同时有 Irrelevant 和 其他标签
                predictions[idx, -1] = 0.0

    print(ground_truths.shape)
    f1 = f1_score(ground_truths, predictions, average="macro")
    print("macro f1: {}".format(f1))

    report = classification_report(ground_truths, predictions, target_names=intent_labels , digits=5)
    print(report)
    return f1

# for exp_dir in Path("records/chatglm2/").glob("chatglm2_zeroshot_high_level_*"):
#     print(exp_dir.name)
#     report_result(exp_dir / "val_results.jsonl")

# exp_dir = Path("saves/LLaMA2-7B-Chat/lora/trecis-mc_if-train-random-1")
# f1_scores = []
# for level in ["top", "high"]:
#     for ds in ["val", "test"]:
#         print(f"{exp_dir.parents[1]} instruction tuning on {ds} set ({level})")
#         f1 = report_result_llama(
#             exp_dir / f"results_{ds}_{level}" / "generated_predictions.jsonl",
#             intent_level=level
#             )
#         f1_scores.append(f1)

# print("{:.6f}|{:.6f}|{:.6f}|{:.6f}|".format(*f1_scores))

train_task_tag = "trecis-mc-train-random-1"
f1 = report_result_llama(
    [
        f"saves/LLaMA2-7B-Chat/lora/{train_task_tag}/results_if_val_high_0/generated_predictions.jsonl",
        f"saves/LLaMA2-7B-Chat/lora/{train_task_tag}/results_if_val_high_1/generated_predictions.jsonl",
    ],
    intent_level="high",
    task_type="if"
    )
print(f"{f1:.6f}")

f1 = report_result_llama(
    [
        f"saves/LLaMA2-7B-Chat/lora/{train_task_tag}/results_if_test_high_0/generated_predictions.jsonl",
        f"saves/LLaMA2-7B-Chat/lora/{train_task_tag}/results_if_test_high_1/generated_predictions.jsonl",
    ],
    intent_level="high",
    task_type="if"
    )
print(f"{f1:.6f}")

# f1 = report_result_llama(
#     [
#         "saves/LLaMA2-7B-Chat/zero_shot/trecis-if/results_if_val_high/generated_predictions.jsonl",
#         "saves/LLaMA2-7B-Chat/zero_shot/trecis-if/results_if_val_high_1/generated_predictions.jsonl",
#     ],
#     intent_level="high",
#     task_type="if"
#     )
# print(f"{f1:.6f}")

# f1 = report_result_llama(
#     [
#         "saves/LLaMA2-7B-Chat/zero_shot/trecis-if/results_if_test_high/generated_predictions.jsonl",
#         "saves/LLaMA2-7B-Chat/zero_shot/trecis-if/results_if_test_high_1/generated_predictions.jsonl",
#     ],
#     intent_level="high",
#     task_type="if"
#     )
# print(f"{f1:.6f}")

# print("llama-2-7B-chat zero-shot on val set (top)")
# report_result_llama(Path("saves/LLaMA2-7B-Chat/zero_shot/trecis-mc/results_val_top/generated_predictions.jsonl"))
# print("llama-2-7B-chat zero-shot on sampled test set (top)")
# report_result_llama(Path("saves/LLaMA2-7B-Chat/zero_shot/trecis-mc/results_test_top/generated_predictions.jsonl"))

# print("llama-2-7B-chat zero-shot on val set (high)")
# report_result_llama(
#     Path("saves/LLaMA2-7B-Chat/zero_shot/trecis-mc/results_val_high/generated_predictions.jsonl"),
#     intent_level="high"
#     )
# print("llama-2-7B-chat zero-shot on sampled test set (high)")
# report_result_llama(
#     Path("saves/LLaMA2-7B-Chat/zero_shot/trecis-mc/results_test_high/generated_predictions.jsonl"),
#     intent_level="high"
#     )


  0%|          | 0/169700 [00:00<?, ?it/s]

parsered : total = 169700: 169700; num bad response = 0
(6788, 25)
macro f1: 0.4044971808933921
                              precision    recall  f1-score   support

       Request-GoodsServices    0.05263   0.04348   0.04762        23
     Request-SearchAndRescue    0.41379   0.40000   0.40678        30
   Request-InformationWanted    0.19403   0.28261   0.23009        46
      CallToAction-Volunteer    0.26786   0.46875   0.34091        32
      CallToAction-Donations    0.50259   0.77600   0.61006       125
     CallToAction-MovePeople    0.33523   0.76623   0.46640        77
Report-FirstPartyObservation    0.16742   0.63731   0.26518       579
Report-ThirdPartyObservation    0.34629   0.94336   0.50661      1889
              Report-Weather    0.54167   0.80984   0.64915       915
             Report-Location    0.49360   0.94992   0.64963      2556
      Report-EmergingThreats    0.36211   0.77139   0.49286       783
          Report-NewSubEvent    0.18408   0.12803   0.15102    

  0%|          | 0/125000 [00:00<?, ?it/s]

parsered : total = 125000: 125000; num bad response = 0
(5000, 25)
macro f1: 0.413598243324785
                              precision    recall  f1-score   support

       Request-GoodsServices    0.53846   0.24138   0.33333        87
     Request-SearchAndRescue    0.15686   0.29630   0.20513        27
   Request-InformationWanted    0.48810   0.47953   0.48378       171
      CallToAction-Volunteer    0.32673   0.54098   0.40741        61
      CallToAction-Donations    0.47668   0.78632   0.59355       117
     CallToAction-MovePeople    0.46067   0.60073   0.52146       273
Report-FirstPartyObservation    0.15480   0.68120   0.25227       367
Report-ThirdPartyObservation    0.24742   0.88638   0.38686      1109
              Report-Weather    0.60806   0.77066   0.67977      1077
             Report-Location    0.66674   0.93473   0.77831      3187
      Report-EmergingThreats    0.28418   0.79673   0.41894       733
          Report-NewSubEvent    0.13876   0.13551   0.13712     

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
