# Load data and Transform Data

In [None]:
!pip install rouge_score bert_score openai

In [None]:
!pip install -qqq bitsandbytes torch transformers peft accelerate datasets loralib einops trl

In [None]:
import json
import os
from pprint import pprint

import bitsandbytes as bnb
import pandas as pd
import torch
import torch.nn as nn
import transformers
from datasets import load_dataset
import numpy as np
from trl import GRPOConfig, GRPOTrainer

from peft import (
    LoraConfig,
    PeftConfig,
    PeftModel,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq
)
# Base checkpoint; the same name is used further down when the model itself is loaded.
MODEL_NAME = "Qwen/Qwen2.5-0.5B"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token




In [None]:
data=pd.read_csv('summary.csv')
import pandas as pd
from sklearn.model_selection import train_test_split


def stratified_split_by_length(df, length_column='text_length', train_size=0.8, val_size=0.1, test_size=0.1, bins=3, random_state=42):
    """Split ``df`` into train / validation / test frames, stratified by length.

    The values of ``length_column`` are cut into ``bins`` equal-width intervals
    and both splits are stratified on that interval, so each partition keeps a
    similar length distribution.

    Args:
        df: Source frame. It is NOT modified; binning happens on a copy.
        length_column: Numeric column used to build the stratification bins.
        train_size: Fraction of ``df`` assigned to the training partition.
        val_size: Fraction of ``df`` assigned to the validation partition.
        test_size: Kept for interface compatibility. It is not read: the test
            partition is whatever remains after train and validation are taken.
        bins: Number of equal-width bins passed to ``pd.cut``.
        random_state: Seed forwarded to both ``train_test_split`` calls.

    Returns:
        ``(train, val, test)`` frames without the temporary ``length_bin`` column.
    """
    # ``assign`` returns a new frame, so the caller's ``df`` never gains the
    # helper column (the previous in-place assignment leaked it into ``df``).
    binned = df.assign(length_bin=pd.cut(df[length_column], bins=bins))

    train, temp = train_test_split(binned, train_size=train_size, stratify=binned['length_bin'], random_state=random_state)

    # Express the validation share relative to what is left after the train split.
    remaining = 1 - train_size
    relative_val_size = val_size / remaining

    val, test = train_test_split(temp, train_size=relative_val_size, stratify=temp['length_bin'], random_state=random_state)

    train = train.drop(columns=['length_bin'])
    val = val.drop(columns=['length_bin'])
    test = test.drop(columns=['length_bin'])

    return train, val, test


# Split the loaded frame with the default proportions and report each partition's size.
train, val, test = stratified_split_by_length(data, length_column='text_length', bins=3)

for split_name, split_frame in (("Train", train), ("Val", val), ("Test", test)):
    print(f"Size of {split_name} {len(split_frame)}")


Size of Train 4000
Size of Val 500
Size of Test 500


In [None]:

def process_summarization_data(data_point, max_length=2048):
    """Turn one ``text``/``summary`` record into a causal-LM training example.

    The record is rendered as ``prompt + summary + eos`` and tokenized with the
    module-level ``tokenizer``. Label positions that belong to the prompt are
    set to ``-100`` so that only the summary (and the EOS token) are supervised.

    Args:
        data_point: Mapping / row exposing ``'text'`` and ``'summary'`` strings.
        max_length: Maximum number of tokens kept for the full sequence
            (defaults to the previously hard-coded 2048).

    Returns:
        dict with 1-D tensors ``input_ids``, ``attention_mask`` and ``labels``.
    """
    article = data_point['text']
    summary = data_point['summary']

    # Build the prompt once; the same string is used for the full sequence and
    # for locating where the summary starts, so the two can never drift apart.
    prompt = "请根据以下专利文本生成摘要：\n\n" + article + "\n\n摘要："
    full_text = prompt + summary + tokenizer.eos_token

    tokenized = tokenizer(full_text, truncation=True, max_length=max_length, return_tensors='pt')
    input_ids = tokenized.input_ids[0]
    attention_mask = tokenized.attention_mask[0]

    # Number of tokens the prompt occupies, capped at the (possibly truncated)
    # sequence length so the mask boundary is always a valid position.
    prompt_only = tokenizer(prompt, return_tensors='pt')
    summary_start_idx = min(len(prompt_only.input_ids[0]), len(input_ids))

    # NOTE(review): when the prompt alone fills ``max_length`` every label is
    # -100 and the example carries no training signal — consider filtering or
    # shortening such rows upstream.
    labels = input_ids.clone()
    labels[:summary_start_idx] = -100

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels
    }

# Tokenize every training row; the result is a Series holding one example dict per row.
processed_data = train.apply(process_summarization_data, axis=1)


In [None]:
from torch.utils.data import Dataset
class DictionaryDataset(Dataset):
    """Map-style dataset over examples that are already fully prepared.

    The examples are held in memory as a plain list and handed back unchanged.
    """

    def __init__(self, processed_series):
        # ``tolist`` drops the series index, so items are addressed by position.
        examples = processed_series.tolist()
        self.data = examples

    def __len__(self):
        """Number of stored examples."""
        total = len(self.data)
        return total

    def __getitem__(self, idx):
        """Return the example stored at position ``idx``."""
        example = self.data[idx]
        return example
train_dataset = DictionaryDataset(processed_data)

# Evaluation functions

In [None]:
def evaluate_with_metrics(model, val_df, tokenizer, batch_size=1, num_samples=None):
    """Generate a summary for every row of ``val_df`` and score it with BERTScore.

    Args:
        model: Decoder-only generation model exposing ``eval()``, ``device`` and
            ``generate()``; ``generate`` is expected to return the prompt tokens
            followed by the newly generated tokens.
        val_df: DataFrame with ``'text'`` (article) and ``'summary'`` (reference).
        tokenizer: Tokenizer matching ``model``.
        batch_size: Number of rows scored together by BERTScore. Generation is
            still performed one row at a time.
        num_samples: If set and smaller than ``len(val_df)``, evaluate a fixed
            random subset of that size (``random_state=42``).

    Returns:
        dict mapping ``bertscore_precision`` / ``bertscore_recall`` /
        ``bertscore_f1`` to the mean over all scored rows. A metric with no
        scored rows is omitted.
    """
    import torch
    from bert_score import BERTScorer
    import numpy as np
    from tqdm import tqdm
    model.eval()

    if num_samples and num_samples < len(val_df):
        val_df = val_df.sample(num_samples, random_state=42)
    bert_scorer = BERTScorer(lang="zh", rescale_with_baseline=True)

    all_scores = {
        'bertscore_precision': [], 'bertscore_recall': [], 'bertscore_f1': []
    }

    gen_kwargs = {
        'max_new_tokens': 256,
        'num_beams': 1,
        "use_cache": True,
    }

    for i in tqdm(range(0, len(val_df), batch_size)):
        batch = val_df.iloc[i:i+batch_size]

        references = []
        generated_summaries = []

        for _, row in batch.iterrows():
            article = row['text']
            reference = row['summary']

            prompt = "请根据以下专利文本生成摘要：\n\n" + article + "\n\n摘要："
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048).to(model.device)

            with torch.no_grad():
                # Best effort: a failure on one row is reported and that row is skipped.
                try:
                    outputs = model.generate(
                        inputs.input_ids,
                        attention_mask=inputs.attention_mask,
                        pad_token_id=tokenizer.pad_token_id,
                        **gen_kwargs
                    )

                    # Decode only the tokens produced after the prompt. Splitting
                    # the full decoded text on "摘要：" is unreliable: the marker
                    # may reappear in the generation, or be missing when the
                    # prompt was truncated.
                    prompt_len = inputs.input_ids.shape[1]
                    summary = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

                    references.append(reference)
                    generated_summaries.append(summary)

                except Exception as e:
                    print(f"wrong: {e}")
                    continue

        if references and generated_summaries:
            try:
                P, R, F1 = bert_scorer.score(generated_summaries, references)
                all_scores['bertscore_precision'].extend(P.cpu().numpy())
                all_scores['bertscore_recall'].extend(R.cpu().numpy())
                all_scores['bertscore_f1'].extend(F1.cpu().numpy())
            except Exception as e:
                print(f"BERTScore计算错误: {e}")

        torch.cuda.empty_cache()

        # Running report every 50 scored samples; skipped while nothing has been
        # scored so the mean of an empty list (NaN) is never printed.
        n_scored = len(all_scores['bertscore_f1'])
        if n_scored and n_scored % 50 == 0:
            print(f"BERTScore F1: {np.mean(all_scores['bertscore_f1']):.4f}")
            print(f"BERTScore recall: {np.mean(all_scores['bertscore_recall']):.4f}")
            print(f"BERTScore precision: {np.mean(all_scores['bertscore_precision']):.4f}")

    avg_scores = {}
    for metric, values in all_scores.items():
        if values:
            avg_scores[metric] = np.mean(values)

    # "\n" here: the earlier "\final" was parsed as a form feed plus "inal".
    print("\nfinal results:")
    for metric, value in avg_scores.items():
        print(f"{metric}: {value:.4f}")

    return avg_scores


In [None]:
import json
import random
import time
import pandas as pd
import numpy as np
from openai import OpenAI
from tqdm import tqdm


# The judge API key is read from the environment so that no credential is stored
# in the notebook. The key that used to be hard-coded here should be revoked.
API_KEY = os.environ.get("DEEPSEEK_API_KEY", "")
if not API_KEY:
    print("WARNING: DEEPSEEK_API_KEY is not set; LLM-judge requests will fail until it is exported.")
SAMPLE_SIZE = 50   # number of rows sent to the LLM judge per evaluation run
MAX_RETRIES = 3    # attempts per judged summary before giving up
RETRY_DELAY = 2    # seconds to wait between attempts


SYSTEM_PROMPT = """
你是一个专业文本评估助手。请根据以下原文和摘要，从以下维度评估：
1. 准确性：摘要是否准确反映原文关键信息（1-5分）
2. 连贯性：摘要是否逻辑通顺、结构清晰（1-5分）
3. 完整性：是否涵盖主要观点和重要细节（1-5分）

请用JSON格式回复，包含各维度评分、总分（平均分）和简短评价。格式如下：
{
  "准确性": 4,
  "连贯性": 4,
  "完整性": 3,
  "平均分": 3.67,
  "评价": "摘要准确反映了原文主要观点，表述流畅，但缺少一些重要细节。"
}
"""


def evaluate_summary(original_text, generated_summary, api_key, client=None):
    """Ask the LLM judge to rate one generated summary.

    The judge is prompted with ``SYSTEM_PROMPT`` and must answer with a JSON
    object containing the keys 准确性, 连贯性, 完整性 and 平均分. Up to
    ``MAX_RETRIES`` attempts are made.

    Args:
        original_text: The source article.
        generated_summary: The summary to be rated.
        api_key: Key used to build a client when ``client`` is None.
        client: Optional pre-built OpenAI-compatible client to reuse.

    Returns:
        The parsed judge reply (a dict containing at least the four required
        keys), or a dict with those four keys set to 0 when every attempt failed.
    """
    if client is None:
        client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com")

    user_message = f"原文：{original_text}\n生成摘要：{generated_summary}"
    required_fields = ["准确性", "连贯性", "完整性", "平均分"]

    for attempt in range(MAX_RETRIES):
        try:
            response = client.chat.completions.create(
                model="deepseek-chat",
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_message}
                ],
                temperature=0.2,
                stream=False
            )
            # ``content`` may be None; normalise so the parsing below is uniform.
            result = response.choices[0].message.content or ""
        except Exception as e:
            # Still best effort, but no longer silent: report why the call failed.
            print(f"evaluate_summary: request failed (attempt {attempt + 1}/{MAX_RETRIES}): {e}")
            time.sleep(RETRY_DELAY)
            continue

        # Chat models frequently wrap the JSON in ``` fences or add a sentence
        # around it; keep only the outermost {...} span when one exists.
        start, end = result.find("{"), result.rfind("}")
        payload = result[start:end + 1] if 0 <= start < end else result

        try:
            evaluation_dict = json.loads(payload)
        except json.JSONDecodeError:
            time.sleep(RETRY_DELAY)
            continue

        if isinstance(evaluation_dict, dict):
            missing = [f for f in required_fields if f not in evaluation_dict]
        else:
            missing = list(required_fields)

        if not missing:
            return evaluation_dict

        # Well-formed JSON but incomplete: show what was missing, then retry.
        print(f"{missing}")
        print(f"{result}")

    # Sentinel for "could not be judged" (valid scores are 1-5).
    return {
        "准确性": 0,
        "连贯性": 0,
        "完整性": 0,
        "平均分": 0,
    }

def run_evaluation(data, model, seed=42, output_tag="gpro"):
    """Generate summaries for a sample of ``data`` and have the LLM judge rate them.

    Uses the notebook-level ``tokenizer``, ``API_KEY`` and ``SAMPLE_SIZE``.

    Args:
        data: DataFrame with a ``'text'`` column holding the articles.
        model: Decoder-only generation model; ``generate`` is expected to return
            the prompt tokens followed by the newly generated tokens.
        seed: Seed for choosing the evaluated rows, so that successive calls
            (e.g. before and after fine-tuning) judge the same articles.
            Pass ``None`` for a different sample on every call.
        output_tag: Suffix of the two CSV files that are written
            (``evaluation_results_<tag>.csv``, ``evaluation_summary_<tag>.csv``).

    Returns:
        dict mapping each judge metric to its mean over the successfully judged
        summaries (0 when none were judged).

    Raises:
        RuntimeError: if ``API_KEY`` is empty.
    """
    if not API_KEY:
        # Fail loudly: otherwise every request fails and is recorded as a 0 score.
        raise RuntimeError("API_KEY is empty; set the judge API key before calling run_evaluation.")

    gen_kwargs = {
        'max_new_tokens': 256,
        'num_beams': 1,
        "use_cache": True,
    }
    if len(data) <= SAMPLE_SIZE:
        sample_indices = list(range(len(data)))
    else:
        # Private generator: reproducible without touching the global random state.
        sample_indices = random.Random(seed).sample(range(len(data)), SAMPLE_SIZE)

    sampled_data = data.iloc[sample_indices].reset_index(drop=True)

    client = OpenAI(api_key=API_KEY, base_url="https://api.deepseek.com")

    all_evaluations = []
    for i, row in tqdm(sampled_data.iterrows(), total=len(sampled_data)):
        article = row['text']
        prompt = "请根据以下专利文本生成摘要：\n\n" + article + "\n\n摘要："
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048).to(model.device)
        with torch.no_grad():
            # Best effort: a failure on one row is reported and that row is skipped.
            try:
                outputs = model.generate(
                    inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    pad_token_id=tokenizer.pad_token_id,
                    **gen_kwargs
                )

                # Decode only the tokens produced after the prompt; splitting the
                # full text on "摘要：" breaks when the marker reappears in the
                # generation or was lost to prompt truncation.
                prompt_len = inputs.input_ids.shape[1]
                summary = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
            except Exception as e:
                print(f"wrong: {e}")
                continue
        evaluation = evaluate_summary(article, summary, API_KEY, client)

        result = {
            # ``i`` is positional because ``sampled_data`` was re-indexed above.
            'index': sample_indices[i],
            'original': article,
            'summary': summary,
            **evaluation
        }
        all_evaluations.append(result)

        # Small pause between judge requests.
        time.sleep(0.5)

    metric_names = ["准确性", "连贯性", "完整性", "平均分"]
    scores = {metric: [] for metric in metric_names}

    failed = 0
    for eval_result in all_evaluations:
        # All four metrics equal to 0 is the judge's "could not evaluate"
        # sentinel (real scores are 1-5); keep it out of the averages.
        if all(eval_result.get(metric) == 0 for metric in metric_names):
            failed += 1
            continue
        for metric in metric_names:
            value = eval_result.get(metric)
            if isinstance(value, (int, float)):
                scores[metric].append(value)

    if failed:
        print(f"excluded {failed} failed judge evaluations from the averages")

    average_scores = {metric: np.mean(values) if values else 0
                     for metric, values in scores.items()}

    for metric, score in average_scores.items():
        print(f"{metric}: {score:.2f}")

    # Per-sample results (including failed rows) and the one-row summary.
    results_df = pd.DataFrame(all_evaluations)
    results_df.to_csv(f"evaluation_results_{output_tag}.csv", index=False)

    summary_df = pd.DataFrame([average_scores])
    summary_df.to_csv(f"evaluation_summary_{output_tag}.csv", index=False)

    return average_scores



# Fine-tune (SFT)

## Set config

In [None]:
# 4-bit quantisation settings for loading the base model: NF4 quantisation type,
# bfloat16 as the compute dtype, and double quantisation enabled.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Load the base model named by MODEL_NAME with the quantisation config above;
# device placement is left to `device_map="auto"`.

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    trust_remote_code=True,
    quantization_config=bnb_config,
)



# Function to display trainable parameters
def print_trainable_parameters(model):
    """
    Print how many of the model's parameters require gradients.

    Emits a single line with the trainable element count, the total element
    count and the trainable share in percent. Returns nothing.
    """
    # (element count, requires_grad) for every named parameter of the model.
    sizes = [(param.numel(), param.requires_grad) for _, param in model.named_parameters()]
    all_param = sum(count for count, _ in sizes)
    trainable_params = sum(count for count, needs_grad in sizes if needs_grad)
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

# Configure LoRA for parameter-efficient fine-tuning: rank-32 adapters
# (alpha 64, dropout 0.05, no bias training) on the attention projections
# (q/k/v/o) and the MLP projections (gate/up/down).
lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ]
)

# Apply LoRA to the model. From here on `model` is the PEFT-wrapped model.
# NOTE(review): `prepare_model_for_kbit_training` is imported at the top of the
# notebook but never called on the quantised model — confirm that is intended.
model = get_peft_model(model, lora_config)
print_trainable_parameters(model)

trainable params: 17596416 || all params: 332715904 || trainable%: 5.288721034507566


## Test before fine-tuning

In [None]:
results = evaluate_with_metrics(model, test, tokenizer)

 10%|█         | 50/500 [10:59<1:41:40, 13.56s/it]

BERTScore F1: 0.3932
BERTScore recall: 0.4078
BERTScore precision: 0.3782


 20%|██        | 100/500 [21:54<1:26:22, 12.96s/it]

BERTScore F1: 0.3871
BERTScore recall: 0.3995
BERTScore precision: 0.3750


 30%|███       | 150/500 [32:52<1:15:19, 12.91s/it]

BERTScore F1: 0.3837
BERTScore recall: 0.3956
BERTScore precision: 0.3718


 40%|████      | 200/500 [43:51<1:08:08, 13.63s/it]

BERTScore F1: 0.3806
BERTScore recall: 0.3919
BERTScore precision: 0.3692


 50%|█████     | 250/500 [54:50<56:07, 13.47s/it]  

BERTScore F1: 0.3803
BERTScore recall: 0.3918
BERTScore precision: 0.3687


 60%|██████    | 300/500 [1:05:49<42:46, 12.83s/it]

BERTScore F1: 0.3796
BERTScore recall: 0.3911
BERTScore precision: 0.3679


 70%|███████   | 350/500 [1:16:50<33:15, 13.30s/it]

BERTScore F1: 0.3792
BERTScore recall: 0.3914
BERTScore precision: 0.3667


 80%|████████  | 400/500 [1:27:47<21:07, 12.68s/it]

BERTScore F1: 0.3806
BERTScore recall: 0.3927
BERTScore precision: 0.3686


 90%|█████████ | 450/500 [1:40:17<10:32, 12.65s/it]

BERTScore F1: 0.3798
BERTScore recall: 0.3920
BERTScore precision: 0.3676


100%|██████████| 500/500 [1:52:34<00:00, 13.51s/it]

BERTScore F1: 0.3799
BERTScore recall: 0.3918
BERTScore precision: 0.3680
inal results:
bertscore_precision: 0.3680
bertscore_recall: 0.3918
bertscore_f1: 0.3799





The four scores printed below are, in order: accuracy (准确性), coherence (连贯性), completeness (完整性), and their average (平均分).

In [None]:
run_evaluation(test,model)

100%|██████████| 50/50 [29:52<00:00, 35.86s/it]

准确性: 4.32
连贯性: 3.58
完整性: 2.62
平均分: 3.51





{'准确性': 4.32, '连贯性': 3.58, '完整性': 2.62, '平均分': 3.5070000000000006}

In [None]:
torch.cuda.empty_cache()

## Fine tune

In [None]:
# Directory for SFT checkpoints and the final saved model.
OUTPUT_DIR = "summarization_sft"

# One example per device step with 16 accumulation steps, i.e. 16 examples per
# optimizer update on each device; cosine schedule with 10% warm-up.
training_args = transformers.TrainingArguments(
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    num_train_epochs=2,
    learning_rate=5e-4,
    bf16=True,
    save_total_limit=3,
    logging_steps=50,
    output_dir=OUTPUT_DIR,
    #max_steps=250,   # Limit steps for demonstration
    optim="paged_adamw_8bit",
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    report_to="tensorboard",
)

# Initialize trainer. No eval_dataset is passed, so no evaluation runs during
# training; the collator pads each batch to its longest sequence.
trainer = transformers.Trainer(
    model=model,
    train_dataset=train_dataset,
    args=training_args,
    # eval_dataset=processed_val,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)
# Disable the generation cache on the model config while training.
model.config.use_cache = False

# Run the training, then save the resulting model to OUTPUT_DIR.
trainer.train()
trainer.save_model()

  0%|          | 0/250 [00:00<?, ?it/s]

  batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)


{'loss': 1.247, 'grad_norm': 0.7882800698280334, 'learning_rate': 0.0004849231551964771, 'epoch': 0.4}
{'loss': 1.1144, 'grad_norm': 0.7078947424888611, 'learning_rate': 0.000375, 'epoch': 0.8}
{'loss': 0.9729, 'grad_norm': 0.7393576502799988, 'learning_rate': 0.00020658795558326743, 'epoch': 1.2}
{'loss': 0.8538, 'grad_norm': 0.7185181379318237, 'learning_rate': 5.848888922025553e-05, 'epoch': 1.6}
{'loss': 0.8284, 'grad_norm': 0.6915761828422546, 'learning_rate': 0.0, 'epoch': 2.0}
{'train_runtime': 8499.578, 'train_samples_per_second': 0.941, 'train_steps_per_second': 0.029, 'train_loss': 1.003307830810547, 'epoch': 2.0}


In [None]:
#results = evaluate_with_metrics(model, val, tokenizer)
#run_evaluation(val,model)

## Evaluate after fine-tuning

In [None]:
torch.cuda.empty_cache()

In [None]:
results = evaluate_with_metrics(model, test, tokenizer)

 10%|█         | 50/500 [07:53<1:12:54,  9.72s/it]

BERTScore F1: 0.6206
BERTScore recall: 0.6017
BERTScore precision: 0.6398


 20%|██        | 100/500 [15:30<57:40,  8.65s/it] 

BERTScore F1: 0.6111
BERTScore recall: 0.5936
BERTScore precision: 0.6287


 30%|███       | 150/500 [23:21<49:20,  8.46s/it]  

BERTScore F1: 0.6116
BERTScore recall: 0.5943
BERTScore precision: 0.6290


 40%|████      | 200/500 [31:06<45:21,  9.07s/it]

BERTScore F1: 0.6100
BERTScore recall: 0.5925
BERTScore precision: 0.6276


 50%|█████     | 250/500 [38:52<38:38,  9.28s/it]

BERTScore F1: 0.6101
BERTScore recall: 0.5931
BERTScore precision: 0.6271


 60%|██████    | 300/500 [46:34<29:46,  8.93s/it]

BERTScore F1: 0.6103
BERTScore recall: 0.5929
BERTScore precision: 0.6278


 70%|███████   | 350/500 [54:13<23:26,  9.37s/it]

BERTScore F1: 0.6105
BERTScore recall: 0.5933
BERTScore precision: 0.6277


 80%|████████  | 400/500 [1:01:46<14:14,  8.55s/it]

BERTScore F1: 0.6108
BERTScore recall: 0.5939
BERTScore precision: 0.6278


 90%|█████████ | 450/500 [1:09:29<07:32,  9.05s/it]

BERTScore F1: 0.6104
BERTScore recall: 0.5934
BERTScore precision: 0.6275


100%|██████████| 500/500 [1:17:07<00:00,  9.25s/it]

BERTScore F1: 0.6097
BERTScore recall: 0.5921
BERTScore precision: 0.6274
inal results:
bertscore_precision: 0.6274
bertscore_recall: 0.5921
bertscore_f1: 0.6097





The four scores printed below are, in order: accuracy (准确性), coherence (连贯性), completeness (完整性), and their average (平均分).

In [None]:
run_evaluation(test,model)

100%|██████████| 50/50 [18:17<00:00, 21.95s/it]

准确性: 4.48
连贯性: 4.46
完整性: 3.52
平均分: 4.16





{'准确性': 4.48, '连贯性': 4.46, '完整性': 3.52, '平均分': 4.1564}

# GRPO

## Prepare data

In [None]:
!pip install thulac

In [None]:
import random


# Number of training rows reused as GRPO prompts, capped at the size of `train`
# so that sampling cannot fail on a smaller dataset.
num_samples = min(1000, len(train))

# Private seeded generator: the same rows are selected on every run, and the
# global `random` state is left untouched.
indices = random.Random(42).sample(range(len(train)), num_samples)

# One record per sampled row. "ground_truth" carries the reference summary that
# the reward function compares each generated completion against.
grpo_dataset = [
    {
        "prompt": "请根据以下专利文本生成摘要：\n\n" + train['text'].iloc[i] + "\n\n摘要：",
        "completion": train['summary'].iloc[i],
        "ground_truth": train['summary'].iloc[i],
    }
    for i in indices
]

## reward function

In [None]:
import thulac
from collections import Counter
from nltk.util import ngrams

thu = thulac.thulac()

def calculate_ngram_diversity(text, n=2):
    """Share of distinct word n-grams among all word n-grams of ``text``.

    ``text`` is segmented with the module-level ``thu`` segmenter; the first
    element of each segmentation result is taken as the word. Returns 0 when
    the text yields no n-gram at all.
    """
    words = []
    for segment in thu.cut(text):
        words.append(segment[0])

    grams = list(ngrams(words, n))
    if not grams:
        return 0

    # Distinct n-grams over total n-grams.
    return len(set(grams)) / len(grams)


Model loaded succeed


In [None]:
import random
import requests
import json
import time
import torch
from bert_score import BERTScorer

from collections import Counter
from nltk.util import ngrams
import torch
from bert_score import BERTScorer
import numpy as np

global_bert_scorer = BERTScorer(model_type="bert-base-chinese", lang="zh", rescale_with_baseline=True)



def hybrid_reward(completions, ground_truth, **kwargs):
    """Score each completion as 0.7 * BERTScore-F1 + 0.3 * n-gram diversity.

    Args:
        completions: Generated texts to be rewarded.
        ground_truth: Reference texts, aligned one-to-one with ``completions``.
        **kwargs: Accepted and ignored.

    Returns:
        A list with one float reward per completion, in input order.
    """
    # Score the whole batch at once; only the F1 component is used.
    with torch.no_grad():
        _, _, f1_scores = global_bert_scorer.score(completions, ground_truth, batch_size=8)

    return [
        (0.7 * f1.item() +
         0.3 * calculate_ngram_diversity(text))
        for text, f1 in zip(completions, f1_scores)
    ]



## grpo process

In [None]:
from bert_score import score
# Reassigns OUTPUT_DIR (previously the SFT directory) for the GRPO run.
OUTPUT_DIR = "summarization_grpo"


# Configure GRPO. The run is bounded by max_steps=50; most optional settings
# below are left commented out, so library defaults apply for them.
config = GRPOConfig(
    # per_device_train_batch_size=8,
    # gradient_accumulation_steps=4,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=3,
    learning_rate=1e-4,
    # bf16=True,
    # fp16=False,
    # optim="adamw_torch",
    # max_grad_norm=1.0,
    # save_total_limit=1,
    #num_train_epochs=1,
    logging_steps=2,
    max_steps=50,
    output_dir=OUTPUT_DIR,
    report_to="tensorboard",
    # # 6. Add memory-optimisation options
    # gradient_checkpointing=True,
    # deepspeed=None,
)

# Initialize GRPO trainer. `model` is the same LoRA-wrapped object that was
# fine-tuned in the SFT section, so GRPO continues from those weights.
grpo_trainer = GRPOTrainer(
    model=model,
    args=config,
    train_dataset=grpo_dataset,
    processing_class=tokenizer,
    reward_funcs=hybrid_reward
)

# 7. Tune model settings
model.config.use_cache = False      # Disable the cache to save memory


# Run the GRPO training.
# NOTE(review): no explicit save_model() call follows in this cell — confirm
# whether the GRPO-tuned adapter is meant to be saved.
grpo_trainer.train()

  0%|          | 0/50 [00:00<?, ?it/s]

{'loss': -0.0147, 'grad_norm': 2.273724317550659, 'learning_rate': 9.6e-05, 'completion_length': 163.9375, 'rewards/hybrid_reward': 0.6773315072059631, 'reward': 0.6773315072059631, 'reward_std': 0.02627504492799441, 'kl': 0.7613118489583334, 'clip_ratio': 0.0, 'epoch': 0.01}
{'loss': 0.0495, 'grad_norm': 2.59893798828125, 'learning_rate': 9.200000000000001e-05, 'completion_length': 164.5625, 'rewards/hybrid_reward': 0.6795501212279002, 'reward': 0.6795501212279002, 'reward_std': 0.025207906495779753, 'kl': 0.821533203125, 'clip_ratio': 0.0, 'epoch': 0.01}
{'loss': 0.0401, 'grad_norm': 2.4737348556518555, 'learning_rate': 8.800000000000001e-05, 'completion_length': 161.66666666666666, 'rewards/hybrid_reward': 0.6599637269973755, 'reward': 0.6599637269973755, 'reward_std': 0.017015948270757992, 'kl': 0.7909342447916666, 'clip_ratio': 0.0, 'epoch': 0.02}
{'loss': 0.0341, 'grad_norm': 3.011964797973633, 'learning_rate': 8.4e-05, 'completion_length': 158.25, 'rewards/hybrid_reward': 0.6744

TrainOutput(global_step=50, training_loss=0.02712998364120722, metrics={'train_runtime': 3244.4053, 'train_samples_per_second': 0.37, 'train_steps_per_second': 0.015, 'total_flos': 0.0, 'train_loss': 0.02712998364120722})

In [None]:
torch.cuda.empty_cache()

## load model
If the notebook cannot stay running long enough to train, you can instead load a model that was fine-tuned earlier by uncommenting the cell below.

In [None]:
# from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# from peft import PeftModel, PeftConfig

# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16,
#     bnb_4bit_use_double_quant=True,
# )

# base_model_name = "Qwen/Qwen2.5-0.5B"
# model = AutoModelForCausalLM.from_pretrained(
#     base_model_name,
#     device_map="auto",
#     trust_remote_code=True,
#     quantization_config=bnb_config,
# )

# tokenizer = AutoTokenizer.from_pretrained(base_model_name)
# tokenizer.pad_token = tokenizer.eos_token

# checkpoint_path = "summarization_sft/checkpoint-250/"
# model = PeftModel.from_pretrained(model, checkpoint_path)

# model.eval()

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Qwen2ForCausalLM(
      (model): Qwen2Model(
        (embed_tokens): Embedding(151936, 896)
        (layers): ModuleList(
          (0-23): 24 x Qwen2DecoderLayer(
            (self_attn): Qwen2SdpaAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=896, out_features=896, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=896, out_features=32, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=32, out_features=896, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lora.

## evaluate after grpo process

In [None]:
results = evaluate_with_metrics(model, test, tokenizer)

  1%|          | 5/500 [00:52<1:30:25, 10.96s/it]

BERTScore F1: 0.6156
BERTScore recall: 0.6028
BERTScore precision: 0.6283


  2%|▏         | 10/500 [01:40<1:18:25,  9.60s/it]

BERTScore F1: 0.6308
BERTScore recall: 0.6171
BERTScore precision: 0.6446


  3%|▎         | 15/500 [02:28<1:21:42, 10.11s/it]

BERTScore F1: 0.6197
BERTScore recall: 0.6039
BERTScore precision: 0.6355


  4%|▍         | 20/500 [03:15<1:15:38,  9.46s/it]

BERTScore F1: 0.6183
BERTScore recall: 0.5986
BERTScore precision: 0.6383


  5%|▌         | 25/500 [03:59<1:12:41,  9.18s/it]

BERTScore F1: 0.6156
BERTScore recall: 0.5944
BERTScore precision: 0.6373


  6%|▌         | 30/500 [04:41<1:06:56,  8.55s/it]

BERTScore F1: 0.6182
BERTScore recall: 0.5952
BERTScore precision: 0.6417


  7%|▋         | 35/500 [05:25<1:09:57,  9.03s/it]

BERTScore F1: 0.6154
BERTScore recall: 0.5926
BERTScore precision: 0.6387


  8%|▊         | 40/500 [06:08<1:03:11,  8.24s/it]

BERTScore F1: 0.6168
BERTScore recall: 0.5941
BERTScore precision: 0.6400


  9%|▉         | 45/500 [06:58<1:11:58,  9.49s/it]

BERTScore F1: 0.6165
BERTScore recall: 0.5954
BERTScore precision: 0.6380


 10%|█         | 50/500 [07:47<1:14:10,  9.89s/it]

BERTScore F1: 0.6197
BERTScore recall: 0.6001
BERTScore precision: 0.6397


 11%|█         | 55/500 [08:37<1:14:17, 10.02s/it]

BERTScore F1: 0.6216
BERTScore recall: 0.6028
BERTScore precision: 0.6407


 12%|█▏        | 60/500 [09:20<1:05:34,  8.94s/it]

BERTScore F1: 0.6195
BERTScore recall: 0.6001
BERTScore precision: 0.6391


 13%|█▎        | 65/500 [10:00<57:04,  7.87s/it]  

BERTScore F1: 0.6209
BERTScore recall: 0.6014
BERTScore precision: 0.6407


 14%|█▍        | 70/500 [10:47<1:06:22,  9.26s/it]

BERTScore F1: 0.6172
BERTScore recall: 0.5987
BERTScore precision: 0.6358


 15%|█▌        | 75/500 [11:34<1:06:22,  9.37s/it]

BERTScore F1: 0.6152
BERTScore recall: 0.5981
BERTScore precision: 0.6324


 16%|█▌        | 80/500 [12:19<1:03:14,  9.04s/it]

BERTScore F1: 0.6134
BERTScore recall: 0.5969
BERTScore precision: 0.6301


 17%|█▋        | 85/500 [13:06<1:07:17,  9.73s/it]

BERTScore F1: 0.6129
BERTScore recall: 0.5957
BERTScore precision: 0.6303


 18%|█▊        | 90/500 [13:52<1:02:04,  9.08s/it]

BERTScore F1: 0.6140
BERTScore recall: 0.5968
BERTScore precision: 0.6313


 19%|█▉        | 95/500 [14:38<1:03:07,  9.35s/it]

BERTScore F1: 0.6130
BERTScore recall: 0.5959
BERTScore precision: 0.6302


 20%|██        | 100/500 [15:20<57:35,  8.64s/it] 

BERTScore F1: 0.6134
BERTScore recall: 0.5965
BERTScore precision: 0.6305


 21%|██        | 105/500 [16:04<56:31,  8.59s/it]

BERTScore F1: 0.6124
BERTScore recall: 0.5956
BERTScore precision: 0.6293


 22%|██▏       | 110/500 [16:50<59:59,  9.23s/it]  

BERTScore F1: 0.6128
BERTScore recall: 0.5959
BERTScore precision: 0.6300


 23%|██▎       | 115/500 [17:35<57:13,  8.92s/it]

BERTScore F1: 0.6132
BERTScore recall: 0.5963
BERTScore precision: 0.6302


 24%|██▍       | 120/500 [18:20<56:10,  8.87s/it]

BERTScore F1: 0.6130
BERTScore recall: 0.5962
BERTScore precision: 0.6300


 25%|██▌       | 125/500 [19:08<1:02:18,  9.97s/it]

BERTScore F1: 0.6117
BERTScore recall: 0.5951
BERTScore precision: 0.6285


 26%|██▌       | 130/500 [19:55<57:45,  9.37s/it]  

BERTScore F1: 0.6123
BERTScore recall: 0.5954
BERTScore precision: 0.6294


 27%|██▋       | 135/500 [20:44<58:47,  9.66s/it]

BERTScore F1: 0.6117
BERTScore recall: 0.5949
BERTScore precision: 0.6286


 28%|██▊       | 140/500 [21:30<56:04,  9.35s/it]

BERTScore F1: 0.6115
BERTScore recall: 0.5951
BERTScore precision: 0.6280


 29%|██▉       | 145/500 [22:13<52:02,  8.80s/it]

BERTScore F1: 0.6115
BERTScore recall: 0.5951
BERTScore precision: 0.6281


 30%|███       | 150/500 [22:58<52:24,  8.99s/it]

BERTScore F1: 0.6117
BERTScore recall: 0.5954
BERTScore precision: 0.6281


 31%|███       | 155/500 [23:42<51:05,  8.88s/it]

BERTScore F1: 0.6117
BERTScore recall: 0.5954
BERTScore precision: 0.6281


 32%|███▏      | 160/500 [24:23<45:08,  7.97s/it]

BERTScore F1: 0.6117
BERTScore recall: 0.5952
BERTScore precision: 0.6283


 33%|███▎      | 165/500 [25:10<51:16,  9.18s/it]

BERTScore F1: 0.6106
BERTScore recall: 0.5944
BERTScore precision: 0.6269


 34%|███▍      | 170/500 [25:56<49:27,  8.99s/it]

BERTScore F1: 0.6102
BERTScore recall: 0.5937
BERTScore precision: 0.6267


 35%|███▌      | 175/500 [26:38<45:36,  8.42s/it]

BERTScore F1: 0.6100
BERTScore recall: 0.5934
BERTScore precision: 0.6267


 36%|███▌      | 180/500 [27:27<49:23,  9.26s/it]

BERTScore F1: 0.6098
BERTScore recall: 0.5930
BERTScore precision: 0.6266


 37%|███▋      | 185/500 [28:14<48:53,  9.31s/it]

BERTScore F1: 0.6102
BERTScore recall: 0.5932
BERTScore precision: 0.6272


 38%|███▊      | 190/500 [28:59<45:21,  8.78s/it]

BERTScore F1: 0.6094
BERTScore recall: 0.5926
BERTScore precision: 0.6262


 39%|███▉      | 195/500 [29:44<44:57,  8.84s/it]

BERTScore F1: 0.6087
BERTScore recall: 0.5919
BERTScore precision: 0.6256


 40%|████      | 200/500 [30:27<43:07,  8.63s/it]

BERTScore F1: 0.6094
BERTScore recall: 0.5923
BERTScore precision: 0.6265


 41%|████      | 205/500 [31:15<45:46,  9.31s/it]

BERTScore F1: 0.6095
BERTScore recall: 0.5926
BERTScore precision: 0.6265


 42%|████▏     | 210/500 [32:04<46:18,  9.58s/it]

BERTScore F1: 0.6106
BERTScore recall: 0.5940
BERTScore precision: 0.6271


 43%|████▎     | 215/500 [32:51<43:56,  9.25s/it]

BERTScore F1: 0.6103
BERTScore recall: 0.5938
BERTScore precision: 0.6269


 44%|████▍     | 220/500 [33:33<39:31,  8.47s/it]

BERTScore F1: 0.6101
BERTScore recall: 0.5934
BERTScore precision: 0.6267


 45%|████▌     | 225/500 [34:14<37:24,  8.16s/it]

BERTScore F1: 0.6100
BERTScore recall: 0.5932
BERTScore precision: 0.6270


 46%|████▌     | 230/500 [35:02<40:55,  9.09s/it]

BERTScore F1: 0.6100
BERTScore recall: 0.5932
BERTScore precision: 0.6267


 47%|████▋     | 235/500 [35:44<37:21,  8.46s/it]

BERTScore F1: 0.6092
BERTScore recall: 0.5923
BERTScore precision: 0.6261


 48%|████▊     | 240/500 [36:33<42:01,  9.70s/it]

BERTScore F1: 0.6091
BERTScore recall: 0.5922
BERTScore precision: 0.6260


 49%|████▉     | 245/500 [37:23<44:21, 10.44s/it]

BERTScore F1: 0.6097
BERTScore recall: 0.5927
BERTScore precision: 0.6267


 50%|█████     | 250/500 [38:08<38:10,  9.16s/it]

BERTScore F1: 0.6099
BERTScore recall: 0.5926
BERTScore precision: 0.6272


 51%|█████     | 255/500 [38:53<39:23,  9.65s/it]

BERTScore F1: 0.6104
BERTScore recall: 0.5932
BERTScore precision: 0.6276


 52%|█████▏    | 260/500 [39:36<35:37,  8.91s/it]

BERTScore F1: 0.6108
BERTScore recall: 0.5935
BERTScore precision: 0.6281


 53%|█████▎    | 265/500 [40:20<32:47,  8.37s/it]

BERTScore F1: 0.6109
BERTScore recall: 0.5934
BERTScore precision: 0.6284


 54%|█████▍    | 270/500 [41:04<34:09,  8.91s/it]

BERTScore F1: 0.6105
BERTScore recall: 0.5927
BERTScore precision: 0.6285


 55%|█████▌    | 275/500 [41:49<33:08,  8.84s/it]

BERTScore F1: 0.6108
BERTScore recall: 0.5930
BERTScore precision: 0.6287


 56%|█████▌    | 280/500 [42:36<33:48,  9.22s/it]

BERTScore F1: 0.6107
BERTScore recall: 0.5931
BERTScore precision: 0.6283


 57%|█████▋    | 285/500 [43:26<36:13, 10.11s/it]

BERTScore F1: 0.6107
BERTScore recall: 0.5930
BERTScore precision: 0.6284


 58%|█████▊    | 290/500 [44:12<31:44,  9.07s/it]

BERTScore F1: 0.6104
BERTScore recall: 0.5929
BERTScore precision: 0.6280


 59%|█████▉    | 295/500 [44:55<28:38,  8.38s/it]

BERTScore F1: 0.6102
BERTScore recall: 0.5924
BERTScore precision: 0.6281


 60%|██████    | 300/500 [45:40<29:20,  8.80s/it]

BERTScore F1: 0.6102
BERTScore recall: 0.5925
BERTScore precision: 0.6280


 61%|██████    | 305/500 [46:28<29:33,  9.10s/it]

BERTScore F1: 0.6098
BERTScore recall: 0.5921
BERTScore precision: 0.6275


 62%|██████▏   | 310/500 [47:11<27:55,  8.82s/it]

BERTScore F1: 0.6101
BERTScore recall: 0.5925
BERTScore precision: 0.6277


 63%|██████▎   | 315/500 [47:56<26:52,  8.72s/it]

BERTScore F1: 0.6104
BERTScore recall: 0.5929
BERTScore precision: 0.6280


 64%|██████▍   | 320/500 [48:40<25:57,  8.66s/it]

BERTScore F1: 0.6105
BERTScore recall: 0.5931
BERTScore precision: 0.6280


 65%|██████▌   | 325/500 [49:28<27:29,  9.43s/it]

BERTScore F1: 0.6100
BERTScore recall: 0.5925
BERTScore precision: 0.6276


 66%|██████▌   | 330/500 [50:15<28:16,  9.98s/it]

BERTScore F1: 0.6100
BERTScore recall: 0.5926
BERTScore precision: 0.6275


 67%|██████▋   | 335/500 [51:03<27:25,  9.97s/it]

BERTScore F1: 0.6105
BERTScore recall: 0.5930
BERTScore precision: 0.6280


 68%|██████▊   | 340/500 [51:50<27:08, 10.18s/it]

BERTScore F1: 0.6104
BERTScore recall: 0.5931
BERTScore precision: 0.6278


 69%|██████▉   | 345/500 [52:37<25:18,  9.80s/it]

BERTScore F1: 0.6103
BERTScore recall: 0.5929
BERTScore precision: 0.6276


 70%|███████   | 350/500 [53:20<22:16,  8.91s/it]

BERTScore F1: 0.6106
BERTScore recall: 0.5931
BERTScore precision: 0.6281


 71%|███████   | 355/500 [54:08<22:11,  9.18s/it]

BERTScore F1: 0.6105
BERTScore recall: 0.5932
BERTScore precision: 0.6279


 72%|███████▏  | 360/500 [54:52<20:57,  8.98s/it]

BERTScore F1: 0.6106
BERTScore recall: 0.5932
BERTScore precision: 0.6280


 73%|███████▎  | 365/500 [55:33<18:02,  8.02s/it]

BERTScore F1: 0.6101
BERTScore recall: 0.5927
BERTScore precision: 0.6275


 74%|███████▍  | 370/500 [56:18<19:47,  9.14s/it]

BERTScore F1: 0.6101
BERTScore recall: 0.5928
BERTScore precision: 0.6274


 75%|███████▌  | 375/500 [57:02<18:12,  8.74s/it]

BERTScore F1: 0.6099
BERTScore recall: 0.5926
BERTScore precision: 0.6273


 76%|███████▌  | 380/500 [57:48<19:00,  9.50s/it]

BERTScore F1: 0.6102
BERTScore recall: 0.5931
BERTScore precision: 0.6274


 77%|███████▋  | 385/500 [58:32<17:01,  8.88s/it]

BERTScore F1: 0.6103
BERTScore recall: 0.5932
BERTScore precision: 0.6274


 78%|███████▊  | 390/500 [59:20<16:59,  9.27s/it]

BERTScore F1: 0.6106
BERTScore recall: 0.5937
BERTScore precision: 0.6275


 79%|███████▉  | 395/500 [1:00:07<17:12,  9.84s/it]

BERTScore F1: 0.6107
BERTScore recall: 0.5938
BERTScore precision: 0.6276


 80%|████████  | 400/500 [1:00:47<14:18,  8.59s/it]

BERTScore F1: 0.6109
BERTScore recall: 0.5939
BERTScore precision: 0.6279


 81%|████████  | 405/500 [1:01:34<14:29,  9.15s/it]

BERTScore F1: 0.6110
BERTScore recall: 0.5941
BERTScore precision: 0.6279


 82%|████████▏ | 410/500 [1:02:18<13:51,  9.24s/it]

BERTScore F1: 0.6111
BERTScore recall: 0.5944
BERTScore precision: 0.6278


 83%|████████▎ | 415/500 [1:03:05<12:49,  9.05s/it]

BERTScore F1: 0.6109
BERTScore recall: 0.5944
BERTScore precision: 0.6275


 84%|████████▍ | 420/500 [1:03:54<12:38,  9.48s/it]

BERTScore F1: 0.6110
BERTScore recall: 0.5946
BERTScore precision: 0.6274


 85%|████████▌ | 425/500 [1:04:37<10:53,  8.72s/it]

BERTScore F1: 0.6109
BERTScore recall: 0.5942
BERTScore precision: 0.6276


 86%|████████▌ | 430/500 [1:05:24<10:46,  9.23s/it]

BERTScore F1: 0.6112
BERTScore recall: 0.5947
BERTScore precision: 0.6279


 87%|████████▋ | 435/500 [1:06:06<09:39,  8.91s/it]

BERTScore F1: 0.6109
BERTScore recall: 0.5943
BERTScore precision: 0.6276


 88%|████████▊ | 440/500 [1:06:47<08:18,  8.31s/it]

BERTScore F1: 0.6110
BERTScore recall: 0.5943
BERTScore precision: 0.6278


 89%|████████▉ | 445/500 [1:07:37<08:50,  9.64s/it]

BERTScore F1: 0.6106
BERTScore recall: 0.5938
BERTScore precision: 0.6275


 90%|█████████ | 450/500 [1:08:23<07:39,  9.18s/it]

BERTScore F1: 0.6106
BERTScore recall: 0.5936
BERTScore precision: 0.6277


 91%|█████████ | 455/500 [1:09:13<06:58,  9.31s/it]

BERTScore F1: 0.6107
BERTScore recall: 0.5937
BERTScore precision: 0.6277


 92%|█████████▏| 460/500 [1:09:59<05:43,  8.60s/it]

BERTScore F1: 0.6104
BERTScore recall: 0.5935
BERTScore precision: 0.6274


 93%|█████████▎| 465/500 [1:10:41<05:03,  8.66s/it]

BERTScore F1: 0.6106
BERTScore recall: 0.5935
BERTScore precision: 0.6278


 94%|█████████▍| 470/500 [1:11:24<04:26,  8.87s/it]

BERTScore F1: 0.6105
BERTScore recall: 0.5933
BERTScore precision: 0.6278


 95%|█████████▌| 475/500 [1:12:05<03:30,  8.42s/it]

BERTScore F1: 0.6106
BERTScore recall: 0.5934
BERTScore precision: 0.6279


 96%|█████████▌| 480/500 [1:12:52<03:10,  9.50s/it]

BERTScore F1: 0.6107
BERTScore recall: 0.5933
BERTScore precision: 0.6281


 97%|█████████▋| 485/500 [1:13:45<02:37, 10.52s/it]

BERTScore F1: 0.6110
BERTScore recall: 0.5937
BERTScore precision: 0.6283


 98%|█████████▊| 490/500 [1:14:38<01:44, 10.43s/it]

BERTScore F1: 0.6111
BERTScore recall: 0.5939
BERTScore precision: 0.6283


 99%|█████████▉| 495/500 [1:15:26<00:52, 10.42s/it]

BERTScore F1: 0.6105
BERTScore recall: 0.5934
BERTScore precision: 0.6277


100%|██████████| 500/500 [1:16:11<00:00,  9.14s/it]

BERTScore F1: 0.6104
BERTScore recall: 0.5931
BERTScore precision: 0.6278
Final results:
bertscore_precision: 0.6278
bertscore_recall: 0.5931
bertscore_f1: 0.6104





In [None]:
# Run the evaluation helper on the held-out `test` split with `model`.
# The call is kept as the cell's last expression so that whatever it
# returns is displayed as the cell output.
run_evaluation(test, model)

100%|██████████| 50/50 [16:35<00:00, 19.91s/it]

准确性: 4.64
连贯性: 4.62
完整性: 3.70
平均分: 4.32





{'准确性': 4.64, '连贯性': 4.62, '完整性': 3.7, '平均分': 4.3229999999999995}