In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
model_name = "/home/yunpeng/lianqi/fnlp/JudgePred/model_zoo/Qwen3-8B"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
)


  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 7/7 [00:05<00:00,  1.33it/s]


In [2]:
model.device

device(type='cuda', index=0)

In [3]:
# prepare the model input
prompt = "Give me a short introduction to large language model."
messages = [
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False # Switches between thinking and non-thinking modes. Default is True.
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

print("template loaded")

# conduct text completion
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=32768
)
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist() 

print("finish generating")

# parsing thinking content
try:
    # rindex finding 151668 (</think>)
    index = len(output_ids) - output_ids[::-1].index(151668)
except ValueError:
    index = 0

thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")

print("thinking content:", thinking_content)
print("content:", content)


template loaded
finish generating
thinking content: 
content: A large language model (LLM) is an advanced type of artificial intelligence that is trained on vast amounts of text data to understand and generate human-like text. These models can perform a wide range of tasks, such as answering questions, writing stories, coding, and translating languages. LLMs are built using deep learning techniques, particularly transformer architectures, which allow them to process and generate text efficiently. Their ability to comprehend and produce natural language makes them powerful tools in various applications, from customer service chatbots to content creation and research assistance.


In [1]:
import json
import pandas as pd
from collections import defaultdict

# 加载罪名映射
with open('dataset/st1/charges.json') as f:
    charge_map = json.load(f)
id_to_charge = {v: k for k, v in charge_map.items()}

# 处理训练数据
def process_data(file_path):
    cases = []
    with open(file_path) as f:
        for line in f:
            case = json.loads(line)
            defendants = case['defendants']
            charges = []
            imprisonments = []
            
            for outcome in case['outcomes']:
                charge_list = [j['standard_accusation'] for j in outcome['judgment'] if j['standard_accusation']]
                imprisonment_list = [int(j['imprisonment']) if j['imprisonment'] else 0 
                                   for j in outcome['judgment']]
                
                charges.append(charge_list)
                imprisonments.append(imprisonment_list)
            
            cases.append({
                'id': case.get('id', ''),
                'fact': case['fact'],
                'defendants': defendants,
                'charges': charges,
                'imprisonments': imprisonments
            })
    return cases

train_data = process_data('dataset/st1/train.jsonl')
test_data = process_data('dataset/st1/test.jsonl')

# 构建标签矩阵
def build_label_matrices(data):
    charge_labels = []
    imprisonment_labels = []
    
    for case in data:
        case_charge_labels = []
        case_imprisonment_labels = []
        
        for defendant_charges, defendant_imprisonments in zip(case['charges'], case['imprisonments']):
            # 多标签罪名向量
            charge_vec = [0] * len(charge_map)
            for charge in defendant_charges:
                if charge in charge_map:
                    charge_vec[charge_map[charge]] = 1
            
            case_charge_labels.append(charge_vec)
            case_imprisonment_labels.append(defendant_imprisonments)
        
        charge_labels.append(case_charge_labels)
        imprisonment_labels.append(case_imprisonment_labels)
    
    return charge_labels, imprisonment_labels

train_charge_labels, train_imprisonment_labels = build_label_matrices(train_data)

In [9]:
train_data[100]

{'id': '',
 'fact': '某市人民检察院指控，被告人郑某、谢某于2015年5月至9月6日间，在没有取得药品经营许可证的情况下，通过手机微信销售“香港威龙生物延时伟某某”。2015年9月6日21时许，郑某、谢某在帮顾客送货至某市池尾街道邮政局附近时被公安机关抓获，现场缴获“香港威龙生物延时伟某某”2盒，根据《中华人民共和国药品管理法》的规定，被查获的药品应认定为假药。某市人民检察院向法庭提供了作案地点、查获的药品的照片，提取笔录，扣押清单，食品药品监督管理部门出具的函复，被告人供述及户籍证明等证据。认为被告人郑某、谢某的行为触犯了《中华人民共和国刑法》x法条的规定，均已构成x罪。提请本院依法判处。被告人郑某、谢某对公诉机关指控的犯罪事实无异议。经审理查明，2015年5月至9月6日间，被告人郑某、谢某在没有取得药品经营许可证的情况下，通过手机微信销售“香港威龙生物延时伟某某”。2015年9月6日21时许，郑某、谢某为顾客送货至某市池尾街道邮政局附近时被公安机关抓获，现场缴获“香港威龙生物延时伟某某”2盒。经揭阳市食品药品监督管理局鉴定，查获的药品为假药。认定上述事实，有公诉机关提供并经庭审示证、质证，本院予以确认的下列证据证明：1.经被告人郑某、谢某确认无误的作案地点、查获药品的照片。2.提取笔录、扣押清单，证明：某市公安局于2015年9月6日在某市池尾街道邮政局附近，从郑某身上查获“香港威龙生物延时伟某某”2盒。3.揭阳市食品药品监督管理局出具的《关于＜委托鉴定函＞的函复》，证明：经鉴定，查获的“香港威龙生物延时伟某某”为假药。4.被告人郑某、谢某的供述，分别证明：他们两人是夫妻关系。2015年5月至9月6日间，郑某在没有取得药品经营许可证的情况下，通过手机微信销售“香港威龙生物延时伟某某”，已售出8盒。有两次是他们夫妻两个一起去送货给买家。2015年9月6日21时许，他们为顾客送货至某市池尾街道邮政局附近时被公安机关抓获。5.被告人郑某、谢某的户籍证明材料。',
 'defendants': ['郑某', '谢某'],
 'charges': [['生产、销售、提供假药罪'], ['生产、销售、提供假药罪']],
 'imprisonments': [[6], [6]]}

In [2]:
import torch
import torch_geometric
from torch_geometric.data import Data

# 构建案件事实图
def build_fact_graph(fact, defendants):
    # 使用依存分析构建图 (简化版)
    nodes = defendants + ['PROSECUTOR', 'VICTIM', 'EVIDENCE']  # 添加法律角色
    node_features = torch.randn(len(nodes), 128)  # 实际应使用法律实体嵌入
    
    # 创建边 (实际应使用句法分析)
    edges = []
    for i in range(len(nodes)):
        for j in range(i+1, len(nodes)):
            edges.append([i, j])
    
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    return Data(x=node_features, edge_index=edge_index)

# GNN编码器
class LegalGNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch_geometric.nn.GCNConv(128, 256)
        self.conv2 = torch_geometric.nn.GCNConv(256, 512)
    
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = torch.relu(x)
        x = self.conv2(x, edge_index)
        return x

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from transformers import AutoModel, AutoTokenizer, TrainingArguments, Trainer

class QwenLegalModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.qwen = AutoModel.from_pretrained("/home/yunpeng/lianqi/fnlp/JudgePred/model_zoo/Qwen3-8B")
        self.gnn = LegalGNN()
        self.charge_classifier = torch.nn.Linear(512, len(charge_map))
        self.imprisonment_predictor = torch.nn.Linear(512, 1)  # 预测单个罪名刑期
    
    def forward(self, input_ids, attention_mask, graph_data):
        # 文本编码
        text_emb = self.qwen(input_ids, attention_mask).last_hidden_state[:, 0, :]
        
        # 图编码
        graph_emb = self.gnn(graph_data)
        defendant_embs = graph_emb.x[:len(graph_data.defendants)]
        
        # 融合特征
        fused_emb = torch.cat([text_emb, defendant_embs], dim=1)
        
        # 罪名预测
        charge_logits = self.charge_classifier(fused_emb)
        
        # 刑期预测
        imprisonment_preds = self.imprisonment_predictor(fused_emb)
        
        return charge_logits, imprisonment_preds

In [4]:
class LegalLoss(torch.nn.Module):
    def __init__(self, alpha=0.7):
        super().__init__()
        self.alpha = alpha  # 罪名损失权重
        self.charge_loss = torch.nn.BCEWithLogitsLoss()
        self.imprisonment_loss = torch.nn.L1Loss()  # MAE更适合刑期预测
    
    def forward(self, charge_logits, imprisonment_preds, charge_labels, imprisonment_labels):
        # 罪名损失
        charge_loss = self.charge_loss(charge_logits, charge_labels.float())
        
        # 刑期损失 (只计算实际存在的罪名)
        imp_loss = 0
        count = 0
        for i, def_imprisonments in enumerate(imprisonment_labels):
            for j, term in enumerate(def_imprisonments):
                imp_loss += self.imprisonment_loss(imprisonment_preds[i][j], torch.tensor([term]).float())
                count += 1
        
        imp_loss /= max(count, 1)  # 防止除零
        
        total_loss = self.alpha * charge_loss + (1 - self.alpha) * imp_loss
        return total_loss, charge_loss, imp_loss

In [None]:
# 初始化模型
model = QwenLegalModel()
tokenizer = AutoTokenizer.from_pretrained("/home/yunpeng/lianqi/fnlp/JudgePred/model_zoo/Qwen3-8B")

# 训练参数
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=5,
    per_device_train_batch_size=2,  # 8张A100可调至4
    gradient_accumulation_steps=8,
    learning_rate=2e-5,
    fp16=True,
    logging_dir='./logs',
)

# 自定义数据收集函数
def collate_fn(batch):
    import pdb
    pdb.set_trace()
    facts = [item['fact'] for item in batch]
    defendants_list = [item['defendants'] for item in batch]
    
    # 文本编码
    text_encodings = tokenizer(facts, padding=True, truncation=True, return_tensors="pt")
    
    # 构建图数据
    graph_data = [build_fact_graph(fact, defs) for fact, defs in zip(facts, defendants_list)]
    
    return {
        'input_ids': text_encodings['input_ids'],
        'attention_mask': text_encodings['attention_mask'],
        'graph_data': graph_data,
        'charge_labels': [item['charge_labels'] for item in batch],
        'imprisonment_labels': [item['imprisonment_labels'] for item in batch]
    }

# 创建Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    data_collator=collate_fn,
    # compute_metrics=compute_metrics  # 需自定义评估函数
)

# 开始训练
trainer.train()

Loading checkpoint shards: 100%|██████████| 7/7 [00:00<00:00, 124.29it/s]


Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


> [0;32m/tmp/ipykernel_1882579/1492525209.py[0m(20)[0;36mcollate_fn[0;34m()[0m
[0;32m     18 [0;31m    [0;32mimport[0m [0mpdb[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     19 [0;31m    [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 20 [0;31m    [0mfacts[0m [0;34m=[0m [0;34m[[0m[0mitem[0m[0;34m[[0m[0;34m'fact'[0m[0;34m][0m [0;32mfor[0m [0mitem[0m [0;32min[0m [0mbatch[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     21 [0;31m    [0mdefendants_list[0m [0;34m=[0m [0;34m[[0m[0mitem[0m[0;34m[[0m[0;34m'defendants'[0m[0;34m][0m [0;32mfor[0m [0mitem[0m [0;32min[0m [0mbatch[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     22 [0;31m[0;34m[0m[0m
[0m
<class 'list'>
16
{}
{}
