In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
import numpy as np
import pandas as pd
import os
from tqdm import tqdm
import bitsandbytes as bnb
import torch
import torch.nn as nn
import transformers
from datasets import Dataset
from peft import LoraConfig, PeftConfig
from trl import SFTTrainer
from trl import setup_chat_format
from transformers import (AutoModelForCausalLM, 
                          AutoTokenizer, 
                          BitsAndBytesConfig, 
                          TrainingArguments, 
                          pipeline, 
                          logging)
from sklearn.metrics import (accuracy_score, 
                             classification_report, 
                             confusion_matrix)
from sklearn.model_selection import train_test_split

In [4]:
print(f"pytorch version {torch.__version__}")


def _to_model_device(toks, lm):
    # 最靠谱：直接问模型“输入嵌入层”在哪个 device
    dev = lm.get_input_embeddings().weight.device
    return {k: v.to(dev) for k, v in toks.items()}

pytorch version 2.5.1+cu121


In [5]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"working on {device}")

working on cuda:0


In [6]:
filename = "data/all-data.csv"

df = pd.read_csv(filename, 
                 names=["sentiment", "text"],
                 encoding="utf-8", encoding_errors="replace")

X_train = list()
X_test = list()

for sentiment in ["positive", "neutral", "negative"]:
    train, test  = train_test_split(df[df.sentiment==sentiment], 
                                    train_size=300,
                                    test_size=300, 
                                    random_state=42)
    # print(test)
    X_train.append(train)
    X_test.append(test)
X_train = pd.concat(X_train).sample(frac=1, random_state=10)
X_test = pd.concat(X_test)
# print(X_test)

eval_idx = [idx for idx in df.index if idx not in list(X_train.index) + list(X_test.index)]
X_eval = df[df.index.isin(eval_idx)]
X_eval = (X_eval
          .groupby('sentiment', group_keys=False)
          .apply(lambda x: x.sample(n=50, random_state=10, replace=True)))
X_train = X_train.reset_index(drop=True)

# 扰动用的
import copy
X_train_new = copy.deepcopy(X_train)
X_test_new  = copy.deepcopy(X_test)
X_eval_new  = copy.deepcopy(X_eval)

def generate_prompt(data_point):
    return f"""
            Analyze the sentiment of the news headline enclosed in square brackets, 
            determine if it is positive, neutral, or negative, and return the answer as 
            the corresponding sentiment label "positive" or "neutral" or "negative".

            [{data_point["text"]}] = {data_point["sentiment"]}
            """.strip()

def generate_test_prompt(data_point):
    return f"""
            Analyze the sentiment of the news headline enclosed in square brackets, 
            determine if it is positive, neutral, or negative, and return the answer as 
            the corresponding sentiment label "positive" or "neutral" or "negative".

            [{data_point["text"]}] = """.strip()

X_train = pd.DataFrame(X_train.apply(generate_prompt, axis=1), 
                       columns=["text"])
X_eval = pd.DataFrame(X_eval.apply(generate_prompt, axis=1), 
                      columns=["text"])

y_true = X_test.sentiment
X_test = pd.DataFrame(X_test.apply(generate_test_prompt, axis=1), columns=["text"])
print(X_test.head())
train_data = Dataset.from_pandas(X_train)
eval_data = Dataset.from_pandas(X_eval)

                                                   text
567   Analyze the sentiment of the news headline enc...
1752  Analyze the sentiment of the news headline enc...
995   Analyze the sentiment of the news headline enc...
601   Analyze the sentiment of the news headline enc...
568   Analyze the sentiment of the news headline enc...


In [7]:
def evaluate(y_true, y_pred):
    labels = ['positive', 'neutral', 'negative']
    mapping = {'positive': 2, 'neutral': 1, 'none':1, 'negative': 0}
    def map_func(x):
        return mapping.get(x, 1)
    
    y_true = np.vectorize(map_func)(y_true)
    y_pred = np.vectorize(map_func)(y_pred)
    
    # Calculate accuracy
    accuracy = accuracy_score(y_true=y_true, y_pred=y_pred)
    print(f'Accuracy: {accuracy:.3f}')
    
    # Generate accuracy report
    unique_labels = set(y_true)  # Get unique labels
    
    for label in unique_labels:
        label_indices = [i for i in range(len(y_true)) 
                         if y_true[i] == label]
        label_y_true = [y_true[i] for i in label_indices]
        label_y_pred = [y_pred[i] for i in label_indices]
        accuracy = accuracy_score(label_y_true, label_y_pred)
        print(f'Accuracy for label {label}: {accuracy:.3f}')
        
    # Generate classification report
    class_report = classification_report(y_true=y_true, y_pred=y_pred)
    print('\nClassification Report:')
    print(class_report)
    
    # Generate confusion matrix
    conf_matrix = confusion_matrix(y_true=y_true, y_pred=y_pred, labels=[0, 1, 2])
    print('\nConfusion Matrix:')
    print(conf_matrix)

In [8]:
from transformers import AutoTokenizer, AutoModelForCausalLM

In [9]:
model_name = "meta-llama/Llama-2-7b-hf"
cache_dir = "/data3/zhenglon/huggingface/transformers"

compute_dtype = getattr(torch, "float16")

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, 
    bnb_4bit_quant_type="nf4", 
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device,
    torch_dtype=compute_dtype,
    quantization_config=bnb_config,
    cache_dir=cache_dir,
    local_files_only=True,   # 强制只用你刚缓存到本地的文件
)

model.config.use_cache = False
model.config.pretraining_tp = 1

tokenizer = AutoTokenizer.from_pretrained(
    model_name, 
    trust_remote_code=True,
    cache_dir=cache_dir,
    local_files_only=True,
)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

model, tokenizer = setup_chat_format(model, tokenizer)

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.48s/it]


In [10]:
def predict(X_test, model, tokenizer):
    y_pred = []
    for i in tqdm(range(len(X_test))):
        prompt = X_test.iloc[i]["text"]
        pipe = pipeline(task="text-generation", 
                        model=model, 
                        tokenizer=tokenizer, 
                        max_new_tokens = 1, 
                        do_sample=False,
                       )
        result = pipe(prompt)
        # print(result)
        answer = result[0]['generated_text'].split("=")[-1]
        if "positive" in answer:
            y_pred.append("positive")
        elif "negative" in answer:
            y_pred.append("negative")
        elif "neutral" in answer:
            y_pred.append("neutral")
        else:
            y_pred.append("none")
    return y_pred

# 推理

In [None]:
# y_pred = predict(X_test, model, tokenizer)

In [12]:
# evaluate(y_true, y_pred)

# Baseline 大部分都预测为了 label 1

In [11]:
output_dir="trained_weigths"

peft_config = LoraConfig(
        lora_alpha=16,  
        lora_dropout=0.1,
        r=64, # rank 表示LoRA规模
        bias="none",
        target_modules="all-linear",
        task_type="CAUSAL_LM",
)

training_arguments = TrainingArguments(
    output_dir=output_dir,                    # directory to save and repository id
    num_train_epochs=3,                       # number of training epochs
    per_device_train_batch_size=1,            # batch size per device during training
    gradient_accumulation_steps=8,            # number of steps before performing a backward/update pass
    gradient_checkpointing=True,              # use gradient checkpointing to save memory
    optim="paged_adamw_32bit",
    save_steps=0,
    logging_steps=25,                         # log every 10 steps
    learning_rate=2e-4,                       # learning rate, based on QLoRA paper
    weight_decay=0.001,
    fp16=True,
    bf16=False,
    max_grad_norm=0.3,                        # max gradient norm based on QLoRA paper
    max_steps=-1,
    warmup_ratio=0.03,                        # warmup ratio based on QLoRA paper
    group_by_length=True,
    lr_scheduler_type="cosine",               # use cosine learning rate scheduler
    report_to="tensorboard",                  # report metrics to tensorboard
    evaluation_strategy="epoch"               # save checkpoint every epoch
)

trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    train_dataset=train_data,
    eval_dataset=eval_data,
    peft_config=peft_config,
    dataset_text_field="text",
    tokenizer=tokenizer,
    max_seq_length=1024,
    packing=False,
    dataset_kwargs={
        "add_special_tokens": False,
        "append_concat_token": False,
    }
)

Map: 100%|██████████| 900/900 [00:00<00:00, 5505.67 examples/s]
Map: 100%|██████████| 150/150 [00:00<00:00, 6391.55 examples/s]


# LoRA 训练

In [12]:
# Train model
trainer.train()

You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss
0,0.8009,0.698089
2,0.5153,0.713891


TrainOutput(global_step=336, training_loss=0.7165863045624324, metrics={'train_runtime': 1258.2582, 'train_samples_per_second': 2.146, 'train_steps_per_second': 0.267, 'total_flos': 1.0717877554348032e+16, 'train_loss': 0.7165863045624324, 'epoch': 2.99})

In [13]:
# Save trained model and tokenizer
trainer.save_model()
tokenizer.save_pretrained(output_dir)

('trained_weigths/tokenizer_config.json',
 'trained_weigths/special_tokens_map.json',
 'trained_weigths/tokenizer.json')

In [14]:
import gc

del [model, tokenizer, peft_config, trainer, train_data, eval_data, bnb_config, training_arguments]
del [df, X_train, X_eval]
del [TrainingArguments, SFTTrainer, LoraConfig, BitsAndBytesConfig]

In [15]:
for _ in range(100):
    torch.cuda.empty_cache()
    gc.collect()

In [16]:
from peft import AutoPeftModelForCausalLM
finetuned_model = "./trained_weigths/"
compute_dtype = getattr(torch, "float16")
cache_dir = "/data3/zhenglon/huggingface/transformers"

tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    cache_dir=cache_dir,
    local_files_only=True,   # 只用本地缓存
    use_fast=True,
    trust_remote_code=True,
)

model = AutoPeftModelForCausalLM.from_pretrained(
     finetuned_model,
     torch_dtype=compute_dtype,
     return_dict=True,
     low_cpu_mem_usage=True,
     device_map=device,
)
print(f"working on {device}")
merged_model = model.merge_and_unload()
merged_model.save_pretrained("./merged_model",safe_serialization=True, max_shard_size="2GB")
tokenizer.save_pretrained("./merged_model")

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.15s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


working on cuda:0


('./merged_model/tokenizer_config.json',
 './merged_model/special_tokens_map.json',
 './merged_model/tokenizer.json')

In [18]:
# 训练时最好有 pad_token_id
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# 仅 LoRA 参与训练（保险起见再显式设一次）
for n, p in model.named_parameters():
    p.requires_grad = ("lora_" in n)

# 可选：查看可训练参数量
try:
    model.print_trainable_parameters()
except Exception:
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total     = sum(p.numel() for p in model.parameters())
    print(f"Trainable params: {trainable:,} / {total:,}")


trainable params: 0 || all params: 6,738,432,000 || trainable%: 0.0


In [17]:
y_pred = predict(X_test, merged_model, tokenizer)
evaluate(y_true, y_pred)

100%|██████████| 900/900 [00:27<00:00, 32.75it/s]

Accuracy: 0.834
Accuracy for label 0: 0.887
Accuracy for label 1: 0.840
Accuracy for label 2: 0.777

Classification Report:
              precision    recall  f1-score   support

           0       0.95      0.89      0.92       300
           1       0.72      0.84      0.78       300
           2       0.86      0.78      0.82       300

    accuracy                           0.83       900
   macro avg       0.84      0.83      0.84       900
weighted avg       0.84      0.83      0.84       900


Confusion Matrix:
[[266  31   3]
 [ 13 252  35]
 [  1  66 233]]





In [20]:
evaluation = pd.DataFrame({'text': X_test["text"], 
                           'y_true':y_true, 
                           'y_pred': y_pred},
                         )
evaluation.to_csv("test_predictions.csv", index=False)

In [22]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
from sentence_transformers import SentenceTransformer, util
import spacy, re, random
nlp = spacy.load("en_core_web_sm")

# fin_tokenizer = AutoTokenizer.from_pretrained("ProsusAI/finbert")
# fin_model = AutoModelForSequenceClassification.from_pretrained("ProsusAI/finbert")

_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
_FIN = "ProsusAI/finbert"
fin_tokenizer = AutoTokenizer.from_pretrained(_FIN)
fin_model     = AutoModelForSequenceClassification.from_pretrained(_FIN).to(_DEVICE).eval()

sbert = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

In [23]:
# todo
SYN_DICT = {
    "rise": ["increase","climb","advance"],
    "fall": ["decline","drop","slide","dip"],
    "strong": ["robust","solid","firm"],
    "weak": ["soft","fragile","subdued"]
}

def synonym_replace(text, max_repl=2):
    doc = nlp(text)
    repl_idx = [i for i,t in enumerate(doc) 
                if t.pos_ in {"ADJ","ADV","NOUN","VERB"} and t.ent_type_=="" and t.lemma_.lower() in SYN_DICT]
    random.shuffle(repl_idx)
    out = list(token.text for token in doc)
    for i in repl_idx[:max_repl]:
        lemma = doc[i].lemma_.lower()
        cand = random.choice(SYN_DICT[lemma])
        out[i] = cand
    return spacy.tokens.doc.Doc(doc.vocab, words=out).text


In [24]:
# todo
SYN_DICT = {
    "rise": ["increase","climb","advance"],
    "fall": ["decline","drop","slide","dip"],
    "strong": ["robust","solid","firm"],
    "weak": ["soft","fragile","subdued"]
}

def synonym_replace(text, max_repl=2):
    doc = nlp(text)
    repl_idx = [i for i,t in enumerate(doc) 
                if t.pos_ in {"ADJ","ADV","NOUN","VERB"} and t.ent_type_=="" and t.lemma_.lower() in SYN_DICT]
    random.shuffle(repl_idx)
    out = list(token.text for token in doc)
    for i in repl_idx[:max_repl]:
        lemma = doc[i].lemma_.lower()
        cand = random.choice(SYN_DICT[lemma])
        out[i] = cand
    return spacy.tokens.doc.Doc(doc.vocab, words=out).text


In [25]:
#真实性过滤
# 数据增强

import re
import random

def template_paraphrase_v1(
    text: str,
    *,
    max_len_delta: float = 0.20,
    keep_polarity_fn=None,  # 可传入一个函数，如 finbert_pred -> 返回 {"label": ...}
    seed: int | None = 42
) -> str:
    """
    模板释义改写（v1）—— 为金融新闻标题设计的“轻度”改写器。
    规则集：副词缓和、强度词中和、等价短语替换、轻度语态改写、名词化/动词化。
    - 控制：改写后长度变化≤max_len_delta（默认20%）
    - 可选：若提供 keep_polarity_fn( text )-> {"label": str }，则保证改写前后情感极性一致；不一致则回退原句。
    """
    if seed is not None:
        random.seed(seed)

    orig = text
    L0   = len(orig.split())

    # ---------- 1) 等价短语替换（expectations / 指标描述） ----------
    # 注意：先做短语级，再做词级，避免相互覆盖
    rules_phrase = [
        (r"\bbeat expectations\b",              "exceeded expectations"),
        (r"\boutperformed expectations\b",      "exceeded expectations"),
        (r"\bmissed expectations\b",            "failed to meet expectations"),
        (r"\bfell short of expectations\b",     "failed to meet expectations"),
        (r"\bmet expectations\b",               "matched expectations"),
        (r"\babove guidance\b",                 "above its guidance"),
        (r"\bbelow guidance\b",                 "below its guidance"),
        (r"\bstrong demand\b",                  "robust demand"),
        (r"\bweak demand\b",                    "subdued demand"),
    ]

    # ---------- 2) 强度副词中和（夸张→温和） ----------
    rules_intensity = [
        (r"\b(sharply|dramatically|significantly|substantially|considerably)\b", "moderately"),
        (r"\b(soared|surged|skyrocketed)\b", "jumped"),
        (r"\b(plunged|tumbled|cratered)\b",  "fell"),
    ]

    # ---------- 3) 动词缓和（加入副词/轻改结构） ----------
    # 选择一个轻度副词，避免句子到处都是同一个词
    soften_adv = random.choice(["slightly", "modestly", "marginally", "somewhat"])

    def add_soft_adv(m):
        # 如 rose -> rose slightly / increased -> increased modestly
        return f"{m.group(1)} {soften_adv}"

    rules_verbs_soften = [
        # 上涨类
        (r"\b(rose|increased|climbed|gained|advanced)\b", add_soft_adv),
        # 下跌类
        (r"\b(fell|declined|dropped|slid|weakened)\b",     add_soft_adv),
    ]

    # ---------- 4) 轻度语态/结构改写 ----------
    rules_syntax = [
        (r"\bThe company reported\b", "A report from the company indicated"),
        (r"\bThe firm reported\b",    "A report from the firm indicated"),
        (r"\brevenues? (rose|increased)\b", r"revenue \1"),
        (r"\bprofits? (fell|declined)\b",  r"profit \1"),
        (r"\bwas up\b", "rose"),
        (r"\bwas down\b", "declined"),
    ]

    # ---------- 5) 依次应用规则 ----------
    new = text

    for pat, rep in rules_phrase:
        new = re.sub(pat, rep, new, flags=re.I)

    for pat, rep in rules_intensity:
        new = re.sub(pat, rep, new, flags=re.I)

    for pat, rep in rules_verbs_soften:
        new = re.sub(pat, rep, new, flags=re.I)

    for pat, rep in rules_syntax:
        new = re.sub(pat, rep, new, flags=re.I)

    # ---------- 6) 长度变化控制 ----------
    L1 = len(new.split())
    if abs(L1 - L0) / max(1, L0) > max_len_delta:
        # 超过约束就回退到原句
        new = orig

    # ---------- 7) 可选：极性一致性检查（如传入 finbert_pred） ----------
    if callable(keep_polarity_fn):
        try:
            if keep_polarity_fn(orig)["label"] != keep_polarity_fn(new)["label"]:
                new = orig  # 极性变了，回退
        except Exception:
            # 避免评估器异常导致崩溃
            new = orig

    return new


In [26]:
# Todo 实体替换
# SECTOR_PEERS = {"Apple":"Microsoft", "JPMorgan":"Bank of America", "Exxon":"Chevron"}
SECTOR_PEERS ={
    "Barclays": "Citigroup",
    "Citigroup": "Goldman Sachs",
    "Goldman Sachs": "Deutsche Bank",
    "Deutsche Bank": "Credit Suisse",
    "Credit Suisse": "JPMorgan",
    "JPMorgan": "UBS",
    "UBS": "Bank of America",
    "Bank of America": "Morgan Stanley",
    "Morgan Stanley": "Barclays",
    "United": "American Airlines",
    "American Airlines": "British Airways",
    "British Airways": "Delta",
    "Delta": "Emirates",
    "Emirates": "United",
    "Orange": "China Mobile",
    "China Mobile": "Verizon",
    "Verizon": "Vodafone",
    "Vodafone": "AT&T",
    "AT&T": "Orange",
    "Target": "CVS",
    "CVS": "Home Depot",
    "Home Depot": "Target",
    "Microsoft": "Google",
    "Google": "Apple",
    "Apple": "Intel",
    "Intel": "IBM",
    "IBM": "Qualcomm",
    "Qualcomm": "Microsoft",
    "BP": "Shell",
    "Shell": "BP",
    "Volkswagen": "Ford",
    "Ford": "Honda",
    "Honda": "Volkswagen",
    "Sony": "Fox",
    "Fox": "Sony",
    "X": "Snap"
}

def perturb_numbers_entities(text):
    def tweak_num(m):
        x = float(m.group(1))
        y = round(x + random.choice([-0.2, -0.1, 0.1, 0.2]), 1)  # 百分点微调
        return f"{y}%"
    text2 = re.sub(r"(\d+\.?\d*)\s?%", tweak_num, text)

    for a,b in SECTOR_PEERS.items():
        if a in text2:
            text2 = text2.replace(a,b,1); break
    return text2


In [27]:
ALLOWED_NER = {"ORG","GPE","PRODUCT","MONEY","PERCENT"}

def entity_label_set(text: str) -> set:
    doc = nlp(text)
    # 若管线里没有 NER，返回空集（相当于不触发该约束）
    if "ner" not in nlp.pipe_names:
        return set()
    return {ent.label_ for ent in doc.ents if ent.label_ in ALLOWED_NER}

def pass_filter(src: str, tgt: str, sim_th: float = 0.85, len_ratio: float = 0.2):
    # 语义相似度（SBERT）
    emb = sbert.encode([src, tgt], convert_to_tensor=True, normalize_embeddings=True)
    sim = float(util.cos_sim(emb[0], emb[1]))
    if sim < sim_th:
        return False, {"sim": sim}

    # 长度变化约束
    L0, L1 = len(src.split()), len(tgt.split())
    if abs(L1 - L0) / max(1, L0) > len_ratio:
        return False, {"sim": sim, "len_ok": False}

    # 实体“类型集合”一致（允许同行替换）
    src_labels = entity_label_set(src)
    tgt_labels = entity_label_set(tgt)
    if src_labels != tgt_labels:
        return False, {"sim": sim, "ner_ok": False, "src_ner": sorted(src_labels), "tgt_ner": sorted(tgt_labels)}

    return True, {"sim": sim}

In [28]:
# 是否对抗
# def finbert_pred(text):
#     scores = fin_pipe(text)[0]  # list of dicts: {'label':'positive','score':...}
#     scores = {d['label'].lower(): d['score'] for d in scores}
#     label = max(scores, key=scores.get)
#     margin = scores[label] - max([v for k,v in scores.items() if k!=label])
#     return label, margin, scores

@torch.no_grad()
def finbert_pred(text: str):
    enc = fin_tokenizer(text, return_tensors="pt", truncation=True).to(_DEVICE)
    logits = fin_model(**enc).logits               # [1, 3]
    probs  = F.softmax(logits, dim=-1).squeeze(0)  # torch.Tensor，不用 numpy
    id2label = {i: fin_model.config.id2label[i].lower() for i in range(probs.numel())}
    scores   = {id2label[i]: float(probs[i].item()) for i in range(probs.numel())}
    label    = max(scores, key=scores.get)
    top = scores[label]; second = max(v for k,v in scores.items() if k != label)
    margin = top - second
    return label, margin, scores

def adversarial_tag(orig, pert, margin_drop=0.2):
    y0, m0, s0 = finbert_pred(orig)
    y1, m1, s1 = finbert_pred(pert)
    flipped = (y0 != y1)
    weakened = (m0 - m1) >= margin_drop
    return {"orig":y0,"pert":y1,"flipped":flipped,"weakened":weakened,"m0":m0,"m1":m1}


In [29]:
# 
import torch.nn.functional as F

def gen_perturb_sample(text):
    cands = [
        synonym_replace(text),
        template_paraphrase_v1(text),
        perturb_numbers_entities(text),
    ]
    outs = []
    for t in cands:
        ok, info = pass_filter(text, t)
        if not ok: 
            outs.append({"orig":text,"pert":t,"passed":False, **info}); continue
        tag = adversarial_tag(text, t)
        outs.append({"orig":text,"pert":t,"passed":True, **tag})
    return outs

# 对一批金融标题生成结果表
def build_library(headlines, K_per_type=1):
    rows = []
    for h in headlines:
        rows += gen_perturb_sample(h)
    return rows

# print(X_train_.text.tolist()[:1])

build_library(X_train_new.text.tolist()[:100])


[{'orig': 'neutral',
  'pert': 'neutral',
  'passed': True,
  'flipped': False,
  'weakened': False,
  'm0': 0.9077935703098774,
  'm1': 0.9077935703098774},
 {'orig': 'neutral',
  'pert': 'neutral',
  'passed': True,
  'flipped': False,
  'weakened': False,
  'm0': 0.9077935703098774,
  'm1': 0.9077935703098774},
 {'orig': 'neutral',
  'pert': 'neutral',
  'passed': True,
  'flipped': False,
  'weakened': False,
  'm0': 0.9077935703098774,
  'm1': 0.9077935703098774},
 {'orig': 'positive',
  'pert': 'positive',
  'passed': True,
  'flipped': False,
  'weakened': False,
  'm0': 0.9347926788032055,
  'm1': 0.9347926788032055},
 {'orig': 'positive',
  'pert': 'positive',
  'passed': True,
  'flipped': False,
  'weakened': False,
  'm0': 0.9347926788032055,
  'm1': 0.9342988766729832},
 {'orig': 'positive',
  'pert': 'positive',
  'passed': True,
  'flipped': False,
  'weakened': False,
  'm0': 0.9347926788032055,
  'm1': 0.9347926788032055},
 {'orig': 'negative',
  'pert': 'negative',
  

In [30]:
# To Do
# 扰动结果分析



In [31]:
# 1) FinBERT 句向量相似度（像不像同一句话）
# 先粘贴一版可用的接口（不用懂内部）
from transformers import AutoTokenizer, AutoModel
import torch, torch.nn.functional as F

_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
_FIN = "ProsusAI/finbert"
_tok = AutoTokenizer.from_pretrained(_FIN)
_enc = AutoModel.from_pretrained(_FIN).to(_DEVICE).eval()

@torch.no_grad()
def _embed(text):
    t = _tok(text, return_tensors="pt", truncation=True, padding=True, max_length=128).to(_DEVICE)
    out = _enc(**t).last_hidden_state[:,0,:]               # 取 [CLS]
    return F.normalize(out, p=2, dim=-1).squeeze(0)

def cosine_sim(a, b):
    return float(F.cosine_similarity(a[None], b[None]).item())

# 2) 简单可读性（先占位：用长度限制替代，后面再换 PPL 也行）
def readability_ok(text, max_len=80):
    return len(text) <= max_len

In [32]:
TAU_SEM = 0.85   # 语义相似度阈值（先用 0.85）
def pass_filter(orig, pert):
    sim = cosine_sim(_embed(orig), _embed(pert))
    ok = (sim >= TAU_SEM) and readability_ok(pert)
    return ok, {"sim": sim}

In [33]:
# ====== 1) FinBERT (reward judge) 概率接口 ======
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# 用于“奖励”的 FinBERT（带分类头）
FIN_CLS = "yiyanghkust/finbert-tone"
fin_tok = AutoTokenizer.from_pretrained(FIN_CLS)
fin_mdl = AutoModelForSequenceClassification.from_pretrained(FIN_CLS).to(DEVICE).eval()

@torch.no_grad()
def finbert_probs(texts):
    toks = fin_tok(texts, padding=True, truncation=True, return_tensors="pt", max_length=256).to(DEVICE)
    logits = fin_mdl(**toks).logits
    return logits.softmax(-1)  # [B, 3]

In [34]:
# =========================
# Verbalizer 体检与清洗工具
# 依赖：只需要已经就绪的 `tokenizer`
# =========================

from typing import Dict, List, Tuple
import re

def tokenize_no_special(text: str, tokenizer) -> List[int]:
    """对字符串做分词，返回 token id 列表（不加特殊符号）。"""
    return tokenizer.encode(text, add_special_tokens=False)

def inspect_tokens(words: List[str], tokenizer) -> List[Tuple[str, List[int]]]:
    """
    对每个词做体检，返回 (词, token_id_list)。
    建议传入带前导空格的词（如 ' positive'），更易单 token。
    """
    report = []
    for w in words:
        ids = tokenize_no_special(w, tokenizer)
        report.append((w, ids))
    return report

def try_variants_for_single_token(word: str, tokenizer, max_variants: int = 6) -> Tuple[str, List[int]]:
    """
    自动尝试若干“变体”，力争找出一个单 token 的版本。
    返回 (选中的词形, token_ids)。若都不行，返回最后一次的 ids。
    变体规则可按需扩展。
    """
    candidates = []
    w = word
    # 常用变体：加/去前导空格、首字母小写、大写、带连字符等
    variants = [
        w, 
        " " + w.lstrip(), 
        w.strip(), 
        w.lower() if w != w.lower() else w,
        w.upper() if w != w.upper() else w,
        " " + w.lower().lstrip(),
        " " + re.sub(r"\s+", "-", w.strip()),  # 空格->连字符
    ]
    # 去重保序
    seen = set(); uniq_variants = []
    for v in variants:
        if v not in seen:
            uniq_variants.append(v); seen.add(v)

    for v in uniq_variants[:max_variants]:
        ids = tokenize_no_special(v, tokenizer)
        candidates.append((v, ids))
        if len(ids) == 1:
            return v, ids
    # 都不是单 token，则返回最后一个尝试
    return candidates[-1]

def build_verbalizers(
    seed_dict: Dict[str, List[str]], 
    tokenizer, 
    prefer_single_token: bool = True, 
    allow_last_token_fallback: bool = True,
    auto_variant_search: bool = True,
) -> Dict[str, Dict[str, List]]:
    """
    清洗入口：
    - 优先保留单 token；
    - 若允许回退，把多 token 词用“最后一个 token”近似；
    - 去重同一 token id；
    - 输出 {label: {"tokens": [ids...], "words": [对应词...], "dropped": [...], "multi_token": [...]} }
    """
    out = {}
    for label, words in seed_dict.items():
        kept_tokens, kept_words = [], []
        dropped, multi_info = [], []
        seen_ids = set()

        for w in words:
            if auto_variant_search:
                w_chosen, ids = try_variants_for_single_token(w, tokenizer)
            else:
                w_chosen, ids = w, tokenize_no_special(w, tokenizer)

            if len(ids) == 1:
                tid = ids[0]
                if tid not in seen_ids:
                    kept_tokens.append(tid); kept_words.append(w_chosen); seen_ids.add(tid)
                else:
                    dropped.append((w_chosen, ids, "dup_token"))
            else:
                # 多 token 情况
                multi_info.append((w_chosen, ids))
                if prefer_single_token:
                    if allow_last_token_fallback and len(ids) > 0:
                        tid = ids[-1]
                        if tid not in seen_ids:
                            kept_tokens.append(tid); kept_words.append(w_chosen+" (last)"); seen_ids.add(tid)
                        else:
                            dropped.append((w_chosen, ids, "dup_last_token"))
                    else:
                        dropped.append((w_chosen, ids, "multi_token_dropped"))
                else:
                    # 不强制单 token，直接聚合多 token：这里通常不建议
                    # 若一定要，改为 kept_tokens.extend(ids) 并在聚合时特殊处理
                    dropped.append((w_chosen, ids, "multi_token_ignored"))

        out[label] = {
            "tokens": kept_tokens,
            "words": kept_words,
            "dropped": dropped,       # 被丢弃或因重复删除的
            "multi_token": multi_info # 原始多 token 情况备案
        }
    return out

def pretty_print_verbalizer_report(verb_clean: Dict[str, Dict], tokenizer, label_order=None, max_show=999):
    """打印清洗结果与诊断信息。"""
    if label_order is None: label_order = list(verb_clean.keys())
    for label in label_order:
        blk = verb_clean[label]
        print(f"\n=== {label.upper()} ===")
        print(f"Kept {len(blk['tokens'])} tokens:")
        for w, tid in zip(blk["words"][:max_show], blk["tokens"][:max_show]):
            print(f"  - {w!r:<18} -> id={tid:<6} token_str={tokenizer.convert_ids_to_tokens([tid])[0]!r}")
        if blk["multi_token"]:
            print(f"Multi-token encountered ({len(blk['multi_token'])}):")
            for w, ids in blk["multi_token"][:max_show]:
                toks = tokenizer.convert_ids_to_tokens(ids)
                print(f"  * {w!r} -> ids={ids} toks={toks}")
        if blk["dropped"]:
            print(f"Dropped ({len(blk['dropped'])}):")
            for w, ids, reason in blk["dropped"][:max_show]:
                print(f"  x {w!r} -> {reason}  ids={ids}")

# ===== 示例：把你的初始词表丢进来体检清洗 =====
SEEDS = {
    "negative": [" negative"," bearish"," down"," slump"," plunge"," drop"," fall"," miss",
                 " lower"," downgrade"," warn"," loss"," decline"," weak"," soften"],
    "neutral":  [" neutral"," unchanged"," flat"," steady"," stable"," mixed"," inline"," rangebound",
                 " muted"," sideways"],
    "positive": [" positive"," bullish"," up"," rally"," surge"," jump"," soar"," beat",
                 " raise"," upgrade"," profit"," growth"," expand"," strong"," outperform"],
}

# 跑清洗（注意：这里用的是你已经加载好的 LLaMA tokenizer）
VERB = build_verbalizers(
    SEEDS, tokenizer,
    prefer_single_token=True, 
    allow_last_token_fallback=True,
    auto_variant_search=True
)


# 运行完你会看到一份清晰的报告：

# Kept：最终参与聚合的 token ids 与对应词形（后缀 “(last)” 表示走了“最后 token 回退”）。

# Multi-token：原始被拆的词及其 token 列表，便于你手工替换同义词。

# Dropped：因为重复 token 或策略丢弃的词。
pretty_print_verbalizer_report(VERB, tokenizer, label_order=["negative","neutral","positive"])


=== NEGATIVE ===
Kept 15 tokens:
  - 'negative'         -> id=8178   token_str='▁negative'
  - ' BEARISH (last)'  -> id=29950  token_str='H'
  - 'down'             -> id=1623   token_str='▁down'
  - ' SLUMP (last)'    -> id=3580   token_str='MP'
  - ' PLUNGE (last)'   -> id=1692   token_str='GE'
  - 'drop'             -> id=5768   token_str='▁drop'
  - 'fall'             -> id=6416   token_str='▁fall'
  - 'miss'             -> id=3052   token_str='▁miss'
  - 'lower'            -> id=5224   token_str='▁lower'
  - ' DOWNGRADE (last)' -> id=2287   token_str='DE'
  - 'warn'             -> id=29383  token_str='▁warn'
  - 'loss'             -> id=6410   token_str='▁loss'
  - ' DECLINE (last)'  -> id=8895   token_str='INE'
  - 'weak'             -> id=8062   token_str='▁weak'
  - ' SOFTEN (last)'   -> id=1430   token_str='EN'
Multi-token encountered (6):
  * ' BEARISH' -> ids=[29871, 20700, 1718, 3235, 29950] toks=['▁', '▁BE', 'AR', 'IS', 'H']
  * ' SLUMP' -> ids=[29871, 27146, 29965, 3580] 

In [35]:
# step 1: 构造/清洗 verbalizers（你前面已经跑完也行）
SEEDS = {
    "negative": [" negative"," bearish"," down"," slump"," plunge"," drop"," fall"," miss",
                 " lower"," downgrade"," warn"," loss"," decline"," weak"," soften"],
    "neutral":  [" neutral"," unchanged"," flat"," steady"," stable"," mixed"," inline"," rangebound",
                 " muted"," sideways"],
    "positive": [" positive"," bullish"," up"," rally"," surge"," jump"," soar"," beat",
                 " raise"," upgrade"," profit"," growth"," expand"," strong"," outperform"],
}

# 直接使用你之前提供的 build_verbalizers / pretty_print_verbalizer_report
VERB = build_verbalizers(
    SEEDS, tokenizer,
    prefer_single_token=True,
    allow_last_token_fallback=True,
    auto_variant_search=True
)
pretty_print_verbalizer_report(VERB, tokenizer, label_order=["negative","neutral","positive"])


=== NEGATIVE ===
Kept 15 tokens:
  - 'negative'         -> id=8178   token_str='▁negative'
  - ' BEARISH (last)'  -> id=29950  token_str='H'
  - 'down'             -> id=1623   token_str='▁down'
  - ' SLUMP (last)'    -> id=3580   token_str='MP'
  - ' PLUNGE (last)'   -> id=1692   token_str='GE'
  - 'drop'             -> id=5768   token_str='▁drop'
  - 'fall'             -> id=6416   token_str='▁fall'
  - 'miss'             -> id=3052   token_str='▁miss'
  - 'lower'            -> id=5224   token_str='▁lower'
  - ' DOWNGRADE (last)' -> id=2287   token_str='DE'
  - 'warn'             -> id=29383  token_str='▁warn'
  - 'loss'             -> id=6410   token_str='▁loss'
  - ' DECLINE (last)'  -> id=8895   token_str='INE'
  - 'weak'             -> id=8062   token_str='▁weak'
  - ' SOFTEN (last)'   -> id=1430   token_str='EN'
Multi-token encountered (6):
  * ' BEARISH' -> ids=[29871, 20700, 1718, 3235, 29950] toks=['▁', '▁BE', 'AR', 'IS', 'H']
  * ' SLUMP' -> ids=[29871, 27146, 29965, 3580] 

In [36]:
LABEL_ORDER = ["negative","neutral","positive"]

@torch.no_grad()
def probs_from_causallm(texts, tokenizer, causal_lm, VERB, label_order=LABEL_ORDER, max_length=256):
    causal_lm.eval()  # 推理更稳
    # 1) 构造 prompt（你也可以换成自己的分类prompt或chat模板）
    prompts = [f"Text: {t}\nSentiment:" for t in texts]

    # 2) 编码（不要手动 .to(cuda)，交给 HF 处理；对 device_map='auto' 兼容）
    toks = tokenizer(prompts, padding=True, truncation=True,
                     return_tensors="pt", max_length=max_length)
    
    toks = _to_model_device(toks, model)

    # 3) 前向
    out  = causal_lm(**toks)

    # 4) 取“下一个词”的 logits（最后一个非 pad 位置之后的预测）
    last_idx    = (toks["attention_mask"].sum(dim=1) - 1)                 # [B]
    logits_last = out.logits[torch.arange(out.logits.size(0)), last_idx]  # [B, V]

    # 5) 用 VERB 聚合到三类分数（logsumexp），再 softmax 得到概率
    scores = []
    for label in label_order:
        ids = VERB[label]["tokens"]
        if len(ids) == 0:
            sc = torch.full((logits_last.size(0),), -1e9, dtype=logits_last.dtype, device=logits_last.device)
        elif len(ids) == 1:
            sc = logits_last[:, ids[0]]
        else:
            sc = torch.logsumexp(logits_last[:, ids], dim=1)
        scores.append(sc)
    scores = torch.stack(scores, dim=1)   # [B, 3]
    probs  = scores.softmax(dim=1)        # [B, 3] -> 顺序=LABEL_ORDER
    return probs

In [37]:
texts = [
    "Apple shares rallied after beating earnings expectations.",
    "The company issued a profit warning and the stock fell.",
    "Indexes were little changed in thin trading."
]

p_x = probs_from_causallm(texts, tokenizer, model, VERB)   # [B,3], 顺序[neg, neu, pos]
print(p_x)
print("pred:", p_x.argmax(dim=-1).tolist())  # 2->positive, 0->negative, 1->neutral

tensor([[0.0174, 0.0128, 0.9698],
        [0.9776, 0.0166, 0.0058],
        [0.4102, 0.0610, 0.5287]], device='cuda:0')
pred: [2, 0, 2]


In [38]:
# 你可以按需要微调词表
LABEL_ORDER = ["negative", "neutral", "positive"]  # 统一顺序


class StudentClassifierForCausalLM(nn.Module):
    """
    把 LLaMA CausalLM + tokenizer 包装成三分类“概率模型”（不改你原模型结构）
    """
    def __init__(self, causal_lm, tokenizer, verbalizers=VERB, max_length=256, use_chat_template=False):
        super().__init__()
        self.causal_lm = causal_lm
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.use_chat_template = use_chat_template

        # 预编码 verbalizer -> token ids（尽量取单 token；若多 token，取最后一个作为近似）
        self.label_token_ids = {}
        for label, words in verbalizers.items():
            ids = []
            for w in words:
                toks = tokenizer.encode(w, add_special_tokens=False)
                if len(toks) == 1:
                    ids.append(toks[0])
                elif len(toks) > 1:
                    ids.append(toks[-1])  # 退路：取最后一个 token 聚合
            # 去重
            ids = list(sorted(set(ids)))
            if not ids:
                # 再保底一个常见 token，避免空
                ids = tokenizer.encode(" positive", add_special_tokens=False)[-1:]
            self.label_token_ids[label] = ids

    def _build_prompts(self, texts):
        # 与你之前 prompt 逻辑类似：让下一个词直接就是情感词
        # 若你用了 chat 模板 (setup_chat_format)，可以切换 use_chat_template=True
        if not self.use_chat_template:
            return [f"Text: {t}\nSentiment:" for t in texts]
        else:
            prompts = []
            for t in texts:
                msgs = [{"role":"user","content":f"Classify the sentiment of the following text into negative, neutral, or positive.\nText: {t}\nAnswer one word: Sentiment:"}]
                prompts.append(self.tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True))
            return prompts

    @torch.no_grad()
    def forward(self, texts):
        self.causal_lm.eval()
        prompts = self._build_prompts(texts)

        toks = self.tokenizer(
            prompts, padding=True, truncation=True,
            return_tensors="pt", max_length=self.max_length
        )
        
        toks = _to_model_device(toks, self.causal_lm)  # 兼容 device_map='auto'

        out = self.causal_lm(**toks)
        # 取“下一个词”的 logits
        last_idx = (toks["attention_mask"].sum(dim=1) - 1)  # [B]
        logits_last = out.logits[torch.arange(out.logits.size(0)), last_idx, :]  # [B, V]

        # 把若干 verbalizer token 聚合成 3 类得分（logsumexp）
        label_scores = []
        for label in LABEL_ORDER:
            ids = self.label_token_ids[label]
            if len(ids) == 1:
                score = logits_last[:, ids[0]]
            else:
                score = torch.logsumexp(logits_last[:, ids], dim=1)
            label_scores.append(score)
        scores = torch.stack(label_scores, dim=1)  # [B, 3]
        probs  = scores.softmax(dim=1)
        return probs  # [B,3]



# 直接用你已有的 `model` 和 `tokenizer`
student_model = StudentClassifierForCausalLM(
    causal_lm = model,
    tokenizer = tokenizer,
    verbalizers = VERB,
    max_length = 256,
    use_chat_template = False  # 如果你的 setup_chat_format 必须走 chat 模板，就改成 True
)

@torch.no_grad()
def student_probs(texts):
    return student_model.forward(texts)  # -> [B,3]

def predict_by_probs(texts, batch_size=64):
    y_pred = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i+batch_size]
        p = student_probs(batch)                  # [B,3]
        idx = p.argmax(dim=-1).tolist()
        for j in idx:
            y_pred.append(["negative","neutral","positive"][j])
    return y_pred


In [39]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
_ = fin_mdl.to(DEVICE).eval()

# 把 finbert 输出列顺序对齐到 LABEL_ORDER
FIN_ID2LABEL = {int(k): v.lower() for k, v in fin_mdl.config.id2label.items()}  # e.g., {0:'positive',1:'negative',2:'neutral'}
FIN_LABEL2IDX = {v: k for k, v in FIN_ID2LABEL.items()}

def _finbert_reorder_indices(label_order):
    # 返回一个 index 列表，用于把 q[:, finbert_idx] 变成 q[:, label_order_idx]
    idxs = []
    for lab in label_order:
        if lab not in FIN_LABEL2IDX:
            raise ValueError(f"FinBERT没有这个label: {lab}. 现有: {list(FIN_LABEL2IDX.keys())}")
        idxs.append(FIN_LABEL2IDX[lab])
    return torch.tensor(idxs, device=DEVICE)

FIN_REORDER_IDX = _finbert_reorder_indices(LABEL_ORDER)  # shape [3]

@torch.no_grad()
def finbert_probs(texts):
    toks = fin_tokenizer(texts, padding=True, truncation=True,
                         return_tensors="pt", max_length=256).to(DEVICE)
    logits = fin_mdl(**toks).logits            # [B, 3]（FinBERT的自身顺序）
    q = logits.softmax(-1)
    # 重排到你的 LABEL_ORDER
    q = q.index_select(dim=1, index=FIN_REORDER_IDX)  # [B, 3]，顺序与 LABEL_ORDER 一致
    return q


In [None]:
# todo 奖励函数 READ

def _sym_kl(p, q):
    p = p.clamp_min(1e-12); q = q.clamp_min(1e-12)
    return (F.kl_div(p.log(), q, reduction="none").sum(-1) +
            F.kl_div(q.log(), p, reduction="none").sum(-1))

def consistency_reward(p_x, p_xt):
    # D：x 与 x~ 的分布要接近 → 负对称KL当奖励
    return - _sym_kl(p_x, p_xt)

@torch.no_grad()
def finbert_align_reward(texts, p_x):
    # C2：学生分布贴近 FinBERT（判官）
    q_x = finbert_probs(texts)   # 已对齐到 LABEL_ORDER
    q_x = q_x.clamp_min(1e-12); p = p_x.clamp_min(1e-12)
    return 1.0 - F.kl_div(q_x.log(), p, reduction="none").sum(-1)

def calibration_reward(p_x):
    # 轻量置信度正则（-Entropy）
    p = p_x.clamp_min(1e-12)
    return (p * p.log()).sum(-1)

W_CONS, W_FIN, W_CAL = 1.0, 0.5, 0.01

def total_reward(p_x, p_xt, texts):
    return (W_CONS * consistency_reward(p_x, p_xt)
            + W_FIN  * finbert_align_reward(texts, p_x)
            + W_CAL  * calibration_reward(p_x))


In [41]:
import random

def pick_one_pert(text):
    cands = gen_perturb_sample(text)        # 你的函数：返回多个候选及 passed 标志
    passed = [c for c in cands if c.get("passed", False)]
    if not passed: 
        return None
    return random.choice(passed)["pert"]


In [None]:
from torch.optim import AdamW

EPS_CLIP   = 0.1
LR         = 1e-5
KL_COEF    = 0.01
EMA_ALPHA  = 0.9
baseline_ema = 0.0

# 只优化可训练参数（如果你挂了 LoRA，只会包含 LoRA 的参数）
trainable_params = [p for p in student_model.causal_lm.parameters() if p.requires_grad]
print( f"Training {len(trainable_params)} parameters." ) # 0
optimizer = AdamW(trainable_params, lr=LR)

def _batched_probs_with_grad(texts):
    """在训练图里取 p(x) 概率；student_model 内部会用 causal_lm 前向"""
    student_model.causal_lm.train()
    # 用包装器直接拿 [B,3] 概率（带梯度）
    # 取消 no_grad：我们需要反向
    prompts = student_model._build_prompts(texts)
    toks = student_model.tokenizer(
        prompts, padding=True, truncation=True,
        return_tensors="pt", max_length=student_model.max_length
    )
    toks = _to_model_device(toks, student_model.causal_lm)
    out = student_model.causal_lm(**toks)
    last_idx = (toks["attention_mask"].sum(dim=1) - 1)
    logits_last = out.logits[torch.arange(out.logits.size(0)), last_idx, :]
    label_scores = []
    for label in LABEL_ORDER:
        ids = student_model.label_token_ids[label]
        if len(ids) == 1:
            score = logits_last[:, ids[0]]
        else:
            score = torch.logsumexp(logits_last[:, ids], dim=1)
        label_scores.append(score)
    scores = torch.stack(label_scores, dim=1)
    probs  = scores.softmax(dim=1)  # [B,3]
    # 选动作（贪心），并得到 logp(a|x)
    actions = probs.argmax(dim=-1)
    logp = (probs.clamp_min(1e-12).log().gather(1, actions[:,None]).squeeze(1))
    return probs, actions, logp

@torch.no_grad()
def _batched_probs_old(texts):
    student_model.causal_lm.eval()
    p = student_probs(texts)  # 你的无梯度接口
    actions = p.argmax(dim=-1)
    logp = (p.clamp_min(1e-12).log().gather(1, actions[:,None]).squeeze(1))
    return p, actions, logp

def grpo_train_step(batch_texts):
    global baseline_ema

    # 1) 构造成对 (x, x̃)
    texts, texts_tilde = [], []
    for t in batch_texts:
        pt = pick_one_pert(t)
        if pt is None:
            continue
        texts.append(t)
        texts_tilde.append(pt)
    if not texts:
        return {"skipped": True}

    # 2) 旧策略
    p_old, _, logp_old = _batched_probs_old(texts)

    # 3) 当前策略（带梯度）
    p_new, _, logp_new = _batched_probs_with_grad(texts)
    p_xt, _, _         = _batched_probs_with_grad(texts_tilde)

    # 4) 奖励 & 优势（奖励不回传）
    with torch.no_grad():
        R = total_reward(p_new.detach(), p_xt.detach(), texts)  # [B]
        baseline_ema = EMA_ALPHA*baseline_ema + (1-EMA_ALPHA)*R.mean().item()
        adv = R - baseline_ema

    # 5) PPO/GRPO 裁剪目标
    ratio = torch.exp(logp_new - logp_old)
    unclipped = - ratio * adv
    clipped   = - torch.clamp(ratio, 1.0-EPS_CLIP, 1.0+EPS_CLIP) * adv
    loss_policy = torch.maximum(unclipped, clipped).mean()

    # 6) 稳定项：新旧分布 KL 正则
    kl_reg = F.kl_div(p_old.clamp_min(1e-12).log(), p_new.clamp_min(1e-12), reduction="batchmean")
    loss = loss_policy + KL_COEF * kl_reg

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
    optimizer.step()

    with torch.no_grad():
        flip_rate = (p_new.argmax(-1) != p_xt.argmax(-1)).float().mean().item()

    return {"loss": float(loss.item()), "R": float(R.mean().item()), "flip": flip_rate}


Training 448 parameters.


In [None]:
from tqdm import trange
import random

def run_grpo(train_texts, batch_size=32, steps=200):
    for step in trange(steps):
        batch = random.sample(train_texts, k=min(batch_size, len(train_texts)))
        stats = grpo_train_step(batch)
        if stats.get("skipped"): 
            continue
        if step % 10 == 0:
            print(f"[{step}] loss={stats['loss']:.4f}  R={stats['R']:.3f}  flip={stats['flip']:.3f}")

# 评估（用你现有的预测接口已换成概率版）
def evaluate_macro_f1(texts, labels, batch_size=64):
    from sklearn.metrics import f1_score
    preds = predict_by_probs(texts, batch_size=batch_size)   # 使用上面写好的概率->argmax
    return f1_score(labels, preds, labels=["negative","neutral","positive"], average="macro")


In [44]:
# E1) 准备训练/验证集（把列名改成你实际的）
train_texts = X_train_new["text"].astype(str).tolist()
val_texts   = X_eval_new["text"].astype(str).tolist()
val_labels  = X_eval_new["sentiment"].astype(str).tolist()

# E2) 先跑一个 sanity 小步
run_grpo(train_texts, batch_size=32, steps=200)

# E3) 评估（用你的概率预测接口）
print("Macro-F1 after GRPO:", evaluate_macro_f1(val_texts, val_labels))


  0%|          | 1/200 [00:02<08:29,  2.56s/it]

[0] loss=1.8453  R=-2.050  flip=0.000


  6%|▌         | 11/200 [00:26<07:28,  2.37s/it]

[10] loss=nan  R=nan  flip=0.000


 10%|█         | 21/200 [00:50<07:04,  2.37s/it]

[20] loss=nan  R=nan  flip=0.000


 16%|█▌        | 31/200 [01:13<06:39,  2.36s/it]

[30] loss=nan  R=nan  flip=0.000


 20%|██        | 41/200 [01:37<06:16,  2.37s/it]

[40] loss=nan  R=nan  flip=0.000


 26%|██▌       | 51/200 [02:01<05:54,  2.38s/it]

[50] loss=nan  R=nan  flip=0.000


 30%|███       | 61/200 [02:24<05:30,  2.38s/it]

[60] loss=nan  R=nan  flip=0.000


 36%|███▌      | 71/200 [02:48<05:05,  2.37s/it]

[70] loss=nan  R=nan  flip=0.000


 40%|████      | 81/200 [03:12<04:41,  2.37s/it]

[80] loss=nan  R=nan  flip=0.000


 46%|████▌     | 91/200 [03:35<04:18,  2.37s/it]

[90] loss=nan  R=nan  flip=0.000


 50%|█████     | 101/200 [03:59<03:55,  2.37s/it]

[100] loss=nan  R=nan  flip=0.000


 56%|█████▌    | 111/200 [04:23<03:30,  2.37s/it]

[110] loss=nan  R=nan  flip=0.000


 60%|██████    | 121/200 [04:46<03:07,  2.37s/it]

[120] loss=nan  R=nan  flip=0.000


 66%|██████▌   | 131/200 [05:10<02:43,  2.37s/it]

[130] loss=nan  R=nan  flip=0.000


 70%|███████   | 141/200 [05:34<02:19,  2.37s/it]

[140] loss=nan  R=nan  flip=0.000


 76%|███████▌  | 151/200 [05:58<01:56,  2.37s/it]

[150] loss=nan  R=nan  flip=0.000


 80%|████████  | 161/200 [06:21<01:32,  2.37s/it]

[160] loss=nan  R=nan  flip=0.000


 86%|████████▌ | 171/200 [06:45<01:08,  2.37s/it]

[170] loss=nan  R=nan  flip=0.000


 90%|█████████ | 181/200 [07:09<00:45,  2.37s/it]

[180] loss=nan  R=nan  flip=0.000


 96%|█████████▌| 191/200 [07:32<00:21,  2.37s/it]

[190] loss=nan  R=nan  flip=0.000


100%|██████████| 200/200 [07:54<00:00,  2.37s/it]


Macro-F1 after GRPO: 0.16666666666666666
