In [6]:
import os
import sys
import json
import pandas as pd
from tqdm import tqdm
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModel

import os
import json
from multiprocessing import Pool
from torch.utils.data import Dataset

# Make the project-local `Static` module importable. Guarded so that re-running
# this cell does not keep appending the same directory to sys.path.
COMPONENTS_DIR = '/root/tuning_space/Components/'
if COMPONENTS_DIR not in sys.path:
    sys.path.append(COMPONENTS_DIR)
# `si` is indexed by special-token names ('[gMASK]', 'sop', 'eop') in the dataset class.
from Static import prompt_dict, si

raw_data_path='/root/autodl-tmp/weights/smile/data'
# JSONL file with one query/answer record per line, consumed by instruction_dataset.
aim_path='/root/autodl-tmp/dataset/OESD-GG-zh_cn-1/single_query.jsonl'
model_path = '/root/autodl-tmp/weights/chatglm3-6b'

# trust_remote_code=True executes the tokenizer's Python code shipped in model_path;
# only point this at a trusted local checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)



In [7]:
class instruction_dataset(Dataset):
    """Tokenize a JSONL file of single-turn query/answer pairs into fixed-length examples.

    Each non-blank line of ``data_path`` is a JSON object; ``query_key`` and
    ``answer_key`` name the fields holding the prompt and the reply. Every
    example is laid out as

        query_ids + [gMASK] + [sop] + answer_ids + [eop] + padding

    and padded with ``tokenizer.pad_token_id`` to exactly ``truncate_length``
    tokens. Special-token ids are read from the module-level ``si`` mapping.

    Args:
        data_path: Path to a UTF-8 JSONL file. Blank lines are skipped.
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``
            and ``pad_token_id``.
        truncate_length: Fixed length of every output sequence.
        max_query_length: Maximum number of query tokens kept (extra tokens are
            dropped from the end).
        query_key: JSON field holding the query text.
        answer_key: JSON field holding the answer text.
        max_sample: If not None, only the first ``max_sample`` records are used.
        num_workers: Number of worker processes. With ``num_workers <= 1`` the
            records are processed in the current process and no Pool is created.

    Raises:
        ValueError: If ``truncate_length`` cannot hold ``max_query_length`` query
            tokens plus the special tokens.
    """

    # Special tokens inserted into every example: [gMASK], sop, eop.
    _NUM_SPECIAL_TOKENS = 3

    def __init__(self, data_path, tokenizer, truncate_length, max_query_length, query_key, answer_key, max_sample=None, num_workers=12):
        super().__init__()
        min_length = max_query_length + self._NUM_SPECIAL_TOKENS
        if truncate_length < min_length:
            # Otherwise the answer budget is negative and slicing would silently
            # drop tokens from the end of the answer instead of truncating it.
            raise ValueError(
                f'truncate_length ({truncate_length}) must be at least '
                f'max_query_length + {self._NUM_SPECIAL_TOKENS} ({min_length})'
            )
        self.tokenizer = tokenizer
        self.truncate_length = truncate_length
        self.max_query_length = max_query_length
        self.query_key = query_key
        self.answer_key = answer_key
        self.examples = []

        # JSONL is UTF-8 (the data is Chinese); do not rely on the locale default.
        with open(data_path, 'r', encoding='utf-8') as file:
            lines = [line for line in file if line.strip()]
        if max_sample is not None:
            lines = lines[:max_sample]

        if num_workers <= 1:
            self.examples = [self.process_line(line) for line in lines]
        else:
            # Pool.map pickles the bound method, i.e. this instance including the
            # tokenizer, to ship work to the worker processes.
            with Pool(num_workers) as p:
                self.examples = p.map(self.process_line, lines)

    def process_line(self, line):
        """Tokenize one JSONL record into fixed-length model inputs.

        Args:
            line: One JSON-encoded record containing ``query_key`` and ``answer_key``.

        Returns:
            dict with the raw ``query`` / ``answer`` strings and three lists of
            length ``truncate_length``:
              * ``input_ids``: query, [gMASK], sop, answer, eop, then pad ids;
              * ``labels``: -100 on the query, [gMASK], sop and padding
                positions, otherwise the token id (answer tokens and eop);
              * ``attention_mask``: True up to and including eop, False on padding.
        """
        sample = json.loads(line)
        max_answer_length = self.truncate_length - self.max_query_length - self._NUM_SPECIAL_TOKENS

        query = sample[self.query_key]
        query_ids = self.tokenizer.encode(query, add_special_tokens=False)[:self.max_query_length]

        answer = sample[self.answer_key]
        # Truncate on the token count (not the character count) so the total
        # never exceeds truncate_length.
        answer_ids = self.tokenizer.encode(answer, add_special_tokens=False)[:max_answer_length]

        input_ids = query_ids + [si['[gMASK]'], si['sop']] + answer_ids + [si['eop']]
        # Position computed from lengths rather than input_ids.index(...), so a
        # special id occurring inside the encoded text cannot shift it.
        sop_index = len(query_ids) + 1
        padding_length = self.truncate_length - len(input_ids)

        # Labels are aligned with input_ids; the model is assumed to shift them left
        # by one internally (TODO confirm for ChatGLM3). Everything up to and
        # including sop is masked, supervision covers the answer tokens and eop,
        # and padding is masked.
        labels = [-100] * (sop_index + 1) + input_ids[sop_index + 1:] + [-100] * padding_length

        # Attend to every real token (through eop); ignore the padding.
        attention_mask = [True] * len(input_ids) + [False] * padding_length

        input_ids = input_ids + [self.tokenizer.pad_token_id] * padding_length

        return {
            'query': query,
            'answer': answer,
            'input_ids': input_ids,
            'labels': labels,
            'attention_mask': attention_mask,
        }

    def __len__(self):
        """Number of tokenized examples."""
        return len(self.examples)

    def __getitem__(self, item):
        """Return the pre-tokenized example dict at index ``item``."""
        return self.examples[item]

In [8]:
# Small debugging build of the dataset: 5 records, tokenized in-process.
# NOTE(review): the saved error output shows KeyError 'Assistant'; 'Assisstant'
# presumably matches the field name actually present in the JSONL — verify.
dataset_kwargs = dict(
    data_path=aim_path,
    tokenizer=tokenizer,
    truncate_length=192,
    max_query_length=64,
    query_key='User',
    answer_key='Assisstant',
    max_sample=5,
    num_workers=1,
)
temp = instruction_dataset(**dataset_kwargs)

KeyError: 'Assistant'

In [None]:
# Grab the first example (None if the dataset is empty) via the sequence protocol.
sample = next(iter(temp), None)

In [None]:
# The three token-level fields must share one length (truncate_length).
assert len({len(sample[key]) for key in ('input_ids', 'labels', 'attention_mask')}) == 1
# Lengths labelled by field; query/answer are raw strings, so theirs are character counts.
{key: len(value) for key, value in sample.items()}

In [None]:
# Pull the model-facing fields out of the example for inspection.
input_ids, labels, attention_mask = (
    sample[key] for key in ('input_ids', 'labels', 'attention_mask')
)

In [None]:
# Decode the whole padded sequence (padding positions included).
tokenizer.decode(token_ids=input_ids)

In [None]:
# labels hold -100 at ignored positions, which is not a valid token id and
# cannot be decoded; keep only the supervised tokens (answer + eop).
tokenizer.decode([token_id for token_id in labels if token_id != -100])