In [1]:
import os
import sys
import json
import pandas as pd
from tqdm import tqdm
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModel

In [2]:
%matplotlib inline
%config InlineBackend.figure_format='svg'

In [3]:
sys.path.append('/root/tuning_space/Components/')
from Static import prompt_dict, si

In [4]:
# Presumably the directory holding the raw SMILE data; it is not referenced
# again in the visible cells -- TODO confirm it is still needed.
raw_data_path='/root/autodl-tmp/weights/smile/data'
# JSONL conversation file that is passed to the dataset class as `data_path`.
aim_path='/root/autodl-tmp/dataset/smile/smile_conversation.jsonl'
# Local checkpoint directory the tokenizer is loaded from.
model_path = '/root/autodl-tmp/weights/chatglm3-6b'

In [5]:
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

In [6]:
def build_labels(input_ids):
    """Build the training labels for one token-id sequence.

    Tokens located after a ``si["sop"]`` marker and up to (and including) the
    next ``si["eop"]`` marker keep their id as label; every other position,
    the ``sop`` marker included, is set to -100.

    Args:
        input_ids: sequence of token ids.

    Returns:
        A list of the same length as ``input_ids`` holding label ids / -100.
    """
    labels = [-100 for _ in range(len(input_ids))]
    answer_open = False  # True while we are between a sop and the next eop
    for pos, token in enumerate(input_ids):
        if token == si["sop"]:
            # The marker itself is not supervised, only what follows it.
            answer_open = True
            continue
        closes_answer = token == si["eop"]
        if closes_answer:
            answer_open = False
        # An eop is always supervised, even when no sop preceded it.
        if closes_answer or answer_open:
            labels[pos] = token
    return labels

In [7]:
# Show the odd offsets 1, 3, 5, 7, 9 produced by a stride-2 range, one per line.
print(*range(1, 10, 2), sep='\n')

1
3
5
7
9


In [8]:
import os
import json
from itertools import chain
from multiprocessing import Pool
from torch.utils.data import Dataset

class conversation_dataset_trad(Dataset):
    """Multi-turn chat dataset that expands every conversation into several samples.

    Each line of the JSONL file must be an object with a ``conversation`` list
    whose items are single-key dicts ``{role: text}``. For one conversation the
    turns at offsets -1, -3, -5, ... are used as prediction targets in turn;
    everything before the target is encoded as context.

    Args:
        data_path: path of the JSONL file.
        tokenizer: object exposing ``encode(text, add_special_tokens=False)``
            and ``pad_token_id``.
        truncate_length: fixed length of every produced sequence.
        query_key: dict key that marks a user turn.
        answer_key: dict key that marks an assistant turn.
        max_sample: if truthy, only the first ``max_sample`` conversations are used.
        num_workers: number of worker processes used for tokenisation.
    """

    def __init__(self, data_path, tokenizer, truncate_length, query_key, answer_key, max_sample=None, num_workers=12):
        super().__init__()
        self.tokenizer = tokenizer
        self.truncate_length = truncate_length
        self.query_key = query_key
        self.answer_key = answer_key
        self.max_sample = max_sample
        self.examples = []  # flat list of sample dicts, filled below

        # Read the file with an explicit encoding (the data is non-ASCII text)
        # and ignore blank lines, which would otherwise crash json.loads.
        with open(data_path, 'r', encoding='utf-8') as file:
            lines = [line for line in file if line.strip()]
        if max_sample:
            lines = lines[:max_sample]

        # Tokenise the conversations in parallel, then flatten the per-line lists.
        with Pool(num_workers) as p:
            per_line = p.map(self.process_line, lines)
        self.examples = list(chain.from_iterable(per_line))

    def process_line(self, line):
        """Turn one JSONL line into a list of sample dicts.

        Returns one sample for each target offset 1, 3, 5, ... that exists in
        the conversation. A sample that cannot be built (malformed turn) is
        reported with a warning and skipped; the remaining offsets are still
        processed.
        """
        import warnings

        conversation = json.loads(line)['conversation']
        if not isinstance(conversation, list):
            warnings.warn(f"Skipping line: 'conversation' is not a list ({type(conversation).__name__})")
            return []

        result = []
        # Bounded by the real conversation length instead of relying on an
        # IndexError (and a hard-coded upper bound) to stop the loop.
        for last_index in range(1, len(conversation) + 1, 2):
            try:
                result.append(self._build_example(conversation, last_index))
            except Exception as e:
                # Best effort: one bad sample must not discard the others.
                warnings.warn(f"Skipping sample with target offset -{last_index}: {e!r}")
        return result

    def _encode_turn(self, sample):
        """Return the role-prefixed token ids of one context turn ([] for an unknown role)."""
        role = next(iter(sample))
        if role == self.query_key:
            prefix = [si["<|user|>"], si['\n']]
        elif role == self.answer_key:
            prefix = [si["<|assistant|>"], si['\n']]
        else:
            return []
        return prefix + self.tokenizer.encode(sample[role], add_special_tokens=False)

    def _build_example(self, conversation, last_index):
        """Build the sample whose prediction target is ``conversation[-last_index]``."""
        # Context: every turn that precedes the target.
        input_ids = []
        for sample in conversation[:-last_index]:
            input_ids += self._encode_turn(sample)

        # Target turn, always rendered as an assistant reply.
        # NOTE(review): the text is taken whatever the role key is; this assumes
        # that stepping back two turns at a time lands on answer turns -- confirm.
        sample = conversation[-last_index]
        role = next(iter(sample))
        input_ids += [si["<|assistant|>"], si['\n'], si["[gMASK]"], si['sop']]
        input_ids += self.tokenizer.encode(sample[role], add_special_tokens=False)
        input_ids += [si["eop"]]

        # Truncate on the right (re-adding the end marker), then pad.
        # NOTE(review): right truncation can cut into, or entirely remove, the
        # target reply when the context is long.
        if len(input_ids) > self.truncate_length - 1:
            input_ids = input_ids[:self.truncate_length - 1] + [si["eop"]]
        real_length = len(input_ids)
        input_ids += [self.tokenizer.pad_token_id] * (self.truncate_length - real_length)

        labels = build_labels(input_ids)

        # The mask is derived from the unpadded length. Searching for the first
        # pad id fails when the sequence fills truncate_length exactly.
        attention_mask = [True] * real_length + [False] * (self.truncate_length - real_length)

        return {
            'input_ids': input_ids,
            'labels': labels,
            'attention_mask': attention_mask,
        }

    def __len__(self):
        """Number of samples."""
        return len(self.examples)

    def __getitem__(self, item):
        """Return the sample dict stored at index ``item``."""
        return self.examples[item]

In [9]:
%%time
temp=conversation_dataset_trad(data_path=aim_path, tokenizer=tokenizer, truncate_length=1024, query_key='client', answer_key='counselor', max_sample=5000, num_workers=32)

CPU times: user 4.92 s, sys: 2.34 s, total: 7.26 s
Wall time: 7.36 s


In [10]:
# Count the samples by iterating the dataset once.
count = sum(1 for _ in temp)
print(count)

27613


In [64]:
# Take the first item whose position is a non-zero multiple of 3
# (None if the dataset has no such item).
sample = next(
    (item for times, item in enumerate(temp) if times % 3 == 0 and times != 0),
    None,
)

In [65]:
for item in sample.values():
    print(len(item))

1024
1024
1024


In [66]:
input_ids=sample['input_ids']
labels=sample['labels']
attention_mask=sample['attention_mask']

In [67]:
tokenizer.decode(input_ids)

'<|user|>\n 高三后的迷茫，高考前的恐惧，能给我一些建议么？<|assistant|>\n 看到你的提问感觉你很焦虑，这个状态在高中高压下很容易出现。我想说的是，我看到了你的决心。这点是很多人没有的！高考最重要的不是知识是心态。是必胜的心态！什么放松吧缓缓吧，都是站着说话不腰疼，保送的又不是我，我能放松什么？！我有我的目标，我怎么可能放弃！有目标就好办，计划！缺个计划，缺个时间合理配置的复习计划。<|user|>\n 你说的对，我是非常焦虑，确实需要调整心态。我也明白高考的心态很重要，但是要怎样才能真正拥有必胜的心态呢？<|assistant|>\n[gMASK]sop 首先，你要明确自己的目标，既然你想考本科，那就要为此做好准备。然后，你需要制定一个合理的复习计划，根据自己的情况来安排每天的学习时间和内容。这样可以帮助你更好地掌控整个复习过程，减少焦虑感。eop'

In [68]:
tokenizer.decode(labels)

'首先，你要明确自己的目标，既然你想考本科，那就要为此做好准备。然后，你需要制定一个合理的复习计划，根据自己的情况来安排每天的学习时间和内容。这样可以帮助你更好地掌控整个复习过程，减少焦虑感。eop'

In [None]:
print(sample)

In [None]:
input_ids.index(46173)

In [None]:
labels.index(46173)

In [None]:
input_ids.index(64793), labels.index(64793)

In [None]:
text=[30910, 31857, 31822, 37817, 32044, 54622, 54657, 34843, 31123, 31646, 32218, 54534, 32959, 37826, 54578, 33903, 31755, 31155, 33103, 46173, 31123, 54546, 34518, 31822, 34887, 31155, 41037, 54532, 32497, 31631, 54530, 31404, 33390, 35215, 31663, 31848, 54532, 35098, 31155, 54532, 55020, 55442, 38888, 31404, 31642, 34794, 55370, 41652, 55370, 31123, 31700, 53221, 33742, 54535, 56278, 56289, 31123, 54685, 55244, 54530, 54892, 31663, 54546, 31123, 40601, 34794, 31642, 43638, 37736, 31791, 31919, 31123, 54546, 52687, 33154, 31404, 54536, 31919, 37501, 54879, 31123, 31864, 31404, 55478, 54550, 31864, 31123, 55478, 54550, 31643, 32584, 33339, 54530, 37047, 31864, 31155, 64793]
tokenizer.decode(text)

In [None]:
tokenizer.decode()

In [None]:
tokenizer.decode([64795, 13, 36431, 54645, 32313, 40739, 31123, 33390, 33711, 36355, 31123, 54558, 33575, 31784, 32108, 54705, 31514, 64796, 13, 64790, 64792, 30910, 31857, 31822, 37817, 32044, 54622, 54657, 34843, 31123, 31646, 32218, 54534, 32959, 37826, 54578, 33903, 31755, 31155, 33103, 46173, 31123, 54546, 34518, 31822, 34887, 31155, 41037, 54532, 32497, 31631, 54530, 31404, 33390, 35215, 31663, 31848, 54532, 35098, 31155, 54532, 55020, 55442, 38888, 31404, 31642, 34794, 55370, 41652, 55370, 31123, 31700, 53221, 33742, 54535, 56278, 56289, 31123, 54685, 55244, 54530, 54892, 31663, 54546, 31123, 40601, 34794, 31642, 43638, 37736, 31791, 31919, 31123, 54546, 52687, 33154, 31404, 54536, 31919, 37501, 54879, 31123, 31864, 31404, 55478, 54550, 31864, 31123, 55478, 54550, 31643, 32584, 33339, 54530, 37047, 31864, 31155, 64793, 64795, 13, 36474, 32980, 54570, 31123, 33030, 31685, 34843, 31123, 32967, 31665, 32271, 35098, 31155, 33876, 32855, 33390, 38888, 37523, 31123, 31694, 54552, 33613, 32017, 32047, 32104, 55020, 55442, 38888, 55282, 31514, 64796, 13, 64790, 64792, 30910, 32342, 31123, 34526, 32309, 31674, 31919, 31123, 33647, 37902, 54814, 32894, 31123, 54728, 32926, 35016, 51197, 31155, 32043, 31123, 44531, 32624, 31623, 37239, 37047, 31864, 31123, 52120, 31689, 54556, 32158, 32096, 34769, 40036, 31795, 31155, 31676, 41230, 54622, 34450, 40416, 32307, 37047, 31807, 31123, 32382, 34843, 54706, 31155, 64793, 64795, 13, 34211, 32967, 47409, 37239, 37047, 31864, 31123, 42701, 31897, 31643, 33793, 54571, 31123, 39463, 54960, 54708, 54873, 31699, 31155, 37474, 33575, 31784, 32108, 55398, 31514, 64796, 13, 64790, 64792, 30910, 32276, 31628, 31404, 34738, 54708, 54589, 35872, 31123, 39887, 54736, 54726, 39546, 31155, 32096, 54573, 42096, 37234, 31123, 31676, 31759, 31803, 31822, 31658, 33045, 31155, 31701, 31123, 31738, 32886, 54551, 54811, 36322, 31123, 54573, 55208, 54573, 55379, 35073, 32333, 31123, 54744, 54588, 54615, 31687, 31123, 31628, 42096, 42733, 37234, 
31123, 33041, 32169, 33467, 32289, 54814, 33177, 37234, 31123, 32843, 34216, 54725, 54736, 31155, 64793, 64795, 13, 50700, 35029, 31123, 54546, 43903, 32375, 35483, 31123, 37845, 54545, 32677, 31155, 43467, 54546, 49086, 55282, 31514, 64796, 13, 64790, 64792, 30910, 31844, 32375, 35483, 31123, 54551, 35565, 31917, 31822, 32128, 35022, 31155, 31654, 32411, 35029, 31123, 34738, 37318, 54701, 54761, 31764, 31123, 31633, 31632, 54541, 31788, 55437, 31822, 31639, 54617, 35134, 31155, 31844, 33519, 34550, 35221, 31123, 47076, 32192, 31674, 36319, 31123, 41406, 41167, 31658, 31123, 31781, 32702, 54626, 33095, 54530, 31155, 64793, 64795, 13, 34211, 37824, 32823, 31822, 32108, 31155, 35094, 40328, 32271, 35098, 31123, 32624, 31623, 37239, 37047, 31864, 31123, 54724, 35087, 44035, 31764, 47683, 31639, 31155, 38505, 32100, 31862, 31123, 54567, 31781, 31759, 32084, 31674, 31919, 31155, 64796, 13, 64790, 64792, 30910, 48895, 31759, 55268, 50589, 31404, 36597, 50139, 33213, 31123, 31875, 31862, 31123, 38505, 54622, 31781, 31759, 32058, 53626, 31155, 34649, 31404, 64793, 64793, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [None]:
attention=[True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, 
True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, 
False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, 
False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]

In [None]:
# Number of entries of the attention mask that are True.
count = sum(1 for flag in attention if flag == True)
count

In [None]:
inputs=[64795, 13, 36431, 54645, 32313, 40739, 31123, 33390, 33711, 36355, 31123, 54558, 33575, 31784, 32108, 54705, 31514, 64796, 13, 64790, 64792, 30910, 31857, 31822, 37817, 32044, 54622, 54657, 34843, 31123, 31646, 32218, 54534, 32959, 37826, 54578, 33903, 31755, 31155, 33103, 46173, 31123, 54546, 34518, 31822, 34887, 31155, 41037, 54532, 32497, 31631, 54530, 31404, 33390, 35215, 31663, 31848, 54532, 35098, 31155, 54532, 55020, 55442, 38888, 31404, 31642, 34794, 55370, 41652, 55370, 31123, 31700, 53221, 33742, 54535, 56278, 56289, 31123, 54685, 55244, 54530, 54892, 31663, 54546, 31123, 40601, 34794, 31642, 43638, 37736, 31791, 31919, 31123, 54546, 52687, 33154, 31404, 54536, 31919, 37501, 54879, 31123, 31864, 31404, 55478, 54550, 31864, 31123, 55478, 54550, 31643, 32584, 33339, 54530, 37047, 31864, 31155, 64793, 64795, 13, 36474, 32980, 54570, 31123, 33030, 31685, 34843, 31123, 32967, 31665, 32271, 35098, 31155, 33876, 32855, 33390, 38888, 37523, 31123, 31694, 54552, 33613, 32017, 32047, 32104, 55020, 55442, 38888, 55282, 31514, 64796, 13, 64790, 64792, 30910, 32342, 31123, 34526, 32309, 31674, 31919, 31123, 33647, 37902, 54814, 32894, 31123, 54728, 32926, 35016, 51197, 31155, 32043, 31123, 44531, 32624, 31623, 37239, 37047, 31864, 31123, 52120, 31689, 54556, 32158, 32096, 34769, 40036, 31795, 31155, 31676, 41230, 54622, 34450, 40416, 32307, 37047, 31807, 31123, 32382, 34843, 54706, 31155, 64793, 64795, 13, 34211, 32967, 47409, 37239, 37047, 31864, 31123, 42701, 31897, 31643, 33793, 54571, 31123, 39463, 54960, 54708, 54873, 31699, 31155, 37474, 33575, 31784, 32108, 55398, 31514, 64796, 13, 64790, 64792, 30910, 32276, 31628, 31404, 34738, 54708, 54589, 35872, 31123, 39887, 54736, 54726, 39546, 31155, 32096, 54573, 42096, 37234, 31123, 31676, 31759, 31803, 31822, 31658, 33045, 31155, 31701, 31123, 31738, 32886, 54551, 54811, 36322, 31123, 54573, 55208, 54573, 55379, 35073, 32333, 31123, 54744, 54588, 54615, 31687, 31123, 31628, 42096, 42733, 37234, 31123, 33041, 
32169, 33467, 32289, 54814, 33177, 37234, 31123, 32843, 34216, 54725, 54736, 31155, 64793, 64795, 13, 50700, 35029, 31123, 54546, 43903, 32375, 35483, 31123, 37845, 54545, 32677, 31155, 43467, 54546, 49086, 55282, 31514, 64796, 13, 64790, 64792, 30910, 31844, 32375, 35483, 31123, 54551, 35565, 31917, 31822, 32128, 35022, 31155, 31654, 32411, 35029, 31123, 34738, 37318, 54701, 54761, 31764, 31123, 31633, 31632, 54541, 31788, 55437, 31822, 31639, 54617, 35134, 31155, 31844, 33519, 34550, 35221, 31123, 47076, 32192, 31674, 36319, 31123, 41406, 41167, 31658, 31123, 31781, 32702, 54626, 33095, 54530, 31155, 64793, 64795, 13, 34211, 37824, 32823, 31822, 32108, 31155, 35094, 40328, 32271, 35098, 31123, 32624, 31623, 37239, 37047, 31864, 31123, 54724, 35087, 44035, 31764, 47683, 31639, 31155, 38505, 32100, 31862, 31123, 54567, 31781, 31759, 32084, 31674, 31919, 31155, 64796, 13, 64790, 64792, 30910, 48895, 31759, 55268, 50589, 31404, 36597, 50139, 33213, 31123, 31875, 31862, 31123, 38505, 54622, 31781, 31759, 32058, 53626, 31155, 34649, 31404, 64793, 64793, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

In [None]:
# Count the non-zero token ids of `inputs` (the trailing zeros look like
# padding) so the total can be compared with the number of True entries in
# `attention`. The previous version iterated `attention` again, which only
# repeated the earlier count.
count = sum(1 for item in inputs if item != 0)
count

In [None]:
# This is not training code; it is a scratch check for why the loss does not decrease as expected.
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, AutoModel
#tokenizer = AutoTokenizer.from_pretrained('/root/autodl-fs/weights/chatglm3-6b', trust_remote_code=True)

query_ids = tokenizer.encode('中国的首都在哪', add_special_tokens=False)
answer_ids = tokenizer.encode('北京', add_special_tokens=False)
# In this toy sequence the ids 0, 1 and 2 appear to stand in for the
# gmask / sop / eop markers (see the symbolic example below).
input_ids = query_ids + [0] + [1] + answer_ids + [2]
print(input_ids)

#input_ids=['a','b','c','d','gmask','sop','e','f','g','eop','pad','pad','pad']
pre_context_length = input_ids.index(1)  # position of the sop stand-in
end_answer_index = input_ids.index(2)    # position of the eop stand-in

# Mask everything up to and including sop, then supervise the answer and eop.
# The slice starts right after sop so that labels[i] lines up with
# input_ids[i]. Starting it at sop itself made `labels` one element too long
# and shifted every label one position to the right.
labels = [-100] * (pre_context_length + 1) + input_ids[pre_context_length + 1: end_answer_index + 1]
labels = labels + [-100] * (len(input_ids) - len(labels))
assert len(labels) == len(input_ids)
print(labels)

In [None]:
tokenizer.decode([30910, 32914, 54895, 32595, 55315, 0, 1, 37536, 2])

In [None]:
def bulid_labels(input_ids):
    """Build labels that supervise the assistant turns of a whole dialogue.

    Tokens that follow a ``si["<|assistant|>"]`` marker keep their id as label
    until a ``si["<|user|>"]`` marker or a ``si["eop"]`` marker is met. The
    ``eop`` marker is always labelled with its own id; the two role markers
    and every other position get -100.

    Args:
        input_ids: sequence of token ids.

    Returns:
        A list of the same length as ``input_ids`` holding label ids / -100.
    """
    labels = [-100 for _ in range(len(input_ids))]
    in_reply = False  # True while we are inside an assistant turn
    for pos, token in enumerate(input_ids):
        if token == si["<|assistant|>"]:
            in_reply = True
            continue
        if token == si["<|user|>"]:
            in_reply = False
            continue
        is_eop = token == si["eop"]
        if is_eop or in_reply:
            labels[pos] = token
        if is_eop:
            # The end marker closes the current assistant turn.
            in_reply = False
    return labels
   

class conversation_dataset(Dataset):
    """Chat dataset that encodes every conversation as one single sample.

    Each line of the JSONL file is an object with a ``conversation`` list of
    single-key dicts ``{role: text}``. The whole dialogue becomes one
    fixed-length sequence; labels come from ``bulid_labels``.

    Args:
        data_path: path of the JSONL file.
        tokenizer: object exposing ``encode(text, add_special_tokens=False)``
            and ``pad_token_id``.
        truncate_length: fixed length of every produced sequence.
        query_key: dict key that marks a user turn.
        answer_key: dict key that marks an assistant turn.
        max_sample: if truthy, only the first ``max_sample`` lines are used.
        num_workers: number of worker processes used for tokenisation.
    """

    def __init__(self, data_path, tokenizer, truncate_length, query_key, answer_key, max_sample=None, num_workers=12):
        super().__init__()
        self.tokenizer = tokenizer
        self.truncate_length = truncate_length
        self.query_key = query_key
        self.answer_key = answer_key
        self.max_sample = max_sample
        self.examples = []  # replaced by the list of sample dicts below

        # Load the raw lines, optionally keeping only the first max_sample.
        with open(data_path, 'r') as file:
            raw_lines = file.readlines()
        if max_sample:
            raw_lines = raw_lines[:max_sample]

        # One sample per line, tokenised in a pool of worker processes.
        with Pool(num_workers) as workers:
            self.examples = workers.map(self.process_line, raw_lines)

    def process_line(self, line):
        """Encode one JSONL line into ``input_ids`` / ``labels`` / ``attention_mask``."""
        turns = json.loads(line)['conversation']
        # The queries and answers of a conversation are assumed to be short, so
        # no per-turn length check is performed here.

        # The sequence opens with the generation-start markers.
        token_ids = [si["[gMASK]"], si['sop']]

        # Each known turn contributes its role marker followed by its text;
        # turns with any other role key are ignored.
        for turn in turns:
            speaker = next(iter(turn))
            if speaker == self.query_key:
                marker = si["<|user|>"]
            elif speaker == self.answer_key:
                marker = si["<|assistant|>"]
            else:
                continue
            token_ids.append(marker)
            token_ids.extend(self.tokenizer.encode(turn[speaker], add_special_tokens=False))

        # Leave room for the end marker, append it, then pad to the fixed length.
        token_ids = token_ids[:self.truncate_length - 1]
        token_ids.append(si["eop"])
        token_ids.extend([self.tokenizer.pad_token_id] * (self.truncate_length - len(token_ids)))

        label_ids = bulid_labels(token_ids)

        # Attend to everything up to and including the first end marker.
        attended = token_ids.index(si['eop']) + 1
        mask = [True] * attended
        mask.extend([False] * (self.truncate_length - len(mask)))

        return {
            'input_ids': token_ids,
            'labels': label_ids,
            'attention_mask': mask,
        }

    def __len__(self):
        """Number of samples."""
        return len(self.examples)

    def __getitem__(self, item):
        """Return the sample dict stored at index ``item``."""
        return self.examples[item]