In [None]:
import os
import sys
import json
import torch
from transformers import AutoTokenizer
import torch.distributed as dist

# Add the parent of the current working directory (treated here as the
# project root) to sys.path so that the `lit` imports below can resolve.
sys.path.append(os.path.dirname(os.getcwd()))

from lit.configs.train_config import train_config
from lit.utils.dataset_utils import (
    LatentQADataset, 
    DataCollatorForLatentQA,
    NUM_READ_TOKENS_TO_SHIFT,
    NUM_WRITE_TOKENS_TO_SHIFT,
    DECODER_CHAT_TEMPLATES,
    get_dataset,
    get_dataloaders,
    get_dist_batch_sampler,
    mask_inputs,
    print_dataset_samples
)
from lit.utils.infra_utils import get_tokenizer, get_model_config_name

In [None]:
# Start from the project defaults and override only what this walkthrough needs.
args = train_config()

args.batch_size_training = 2
args.train_system = "../data/train/system.json"
args.train_stimulus_completion = "../data/train/stimulus_completion.json"
# args.train_stimulus = "../data/train/stimulus.json"
# args.train_control = "../data/train/control.json"
args.train_qa = "../data/train/qa.json"
# args.filter = "goal"

# The checkpoint lives at a machine-specific absolute path. Allow it to be
# overridden through the LATENTQA_MODEL_PATH environment variable so the
# notebook can run elsewhere; the previous hard-coded path stays the default.
model_name = os.environ.get(
    "LATENTQA_MODEL_PATH",
    "/data1/ckx/hf-checkpoints/meta-llama/Llama-3.1-8B-Instruct",
)

tokenizer = get_tokenizer(model_name)

dataset = get_dataset(args, tokenizer)

# get_dataset函数

In [None]:
import json
import random
from collections import defaultdict
# Notebook-level switches read by build_data_and_idx.
train = True
FILTER = args.filter.split("-")

with open(args.train_qa, "r") as f:
    qa_data = json.load(f)

# NUM_QA is the number of QA entries per label; every label must have the
# same count, otherwise the assertion fails.
NUM_QA = max(len(qa_data[label]) for label in qa_data)
assert NUM_QA == min(len(qa_data[label]) for label in qa_data)

def build_data_and_idx(path):
    """Load one behaviour JSON file and enumerate its sample index triples.

    Reads the notebook-level names ``FILTER``, ``NUM_QA``, ``args`` and
    ``train``, which must be defined before this function is called.

    Args:
        path: Path to a JSON file containing a list of records. Each record
            needs a ``"label"`` key; the dialog fields are optional. An empty
            string means "no file".

    Returns:
        A ``(data, id_tuples)`` pair.

        ``data`` is a ``defaultdict(list)`` mapping each label to a list of
        7-tuples ``(system, control_user, control_thought, control_model,
        stimulus_user, stimulus_thought, stimulus_model)``; fields missing
        from a record become ``""``.

        ``id_tuples`` is a list of ``(label_idx, data_idx, qa_idx)`` triples.
        It is empty when ``path`` is ``""`` or when no record is left after
        filtering.

    Raises:
        AssertionError: If the labels do not all hold the same number of
            records.
    """
    fields = (
        "system",
        "control_user",
        "control_thought",
        "control_model",
        "stimulus_user",
        "stimulus_thought",
        "stimulus_model",
    )

    # Get data
    print("Loading data from:", path)
    data = defaultdict(list)
    if path == "":
        return data, []

    with open(path, "r") as f:
        raw_data = json.load(f)
    for item in raw_data:
        # Skip records whose label prefix (text before the first "-") is
        # listed in FILTER.
        if item["label"].split("-")[0] in FILTER:
            continue
        data[item["label"]].append(tuple(item.get(name, "") for name in fields))
    print("Loaded {} labels".format(len(data)))

    # With nothing left (empty file or everything filtered out) the max()
    # below would raise ValueError on an empty sequence. Return the same
    # empty result as the `path == ""` case instead.
    if not data:
        print("No records left after filtering; returning an empty index.")
        return data, []

    # Get id tuples
    per_label_counts = [len(records) for records in data.values()]
    NUM_BEHAVIORS = max(per_label_counts)
    assert NUM_BEHAVIORS == min(per_label_counts)
    print("NUM_BEHAVIORS:", NUM_BEHAVIORS)

    id_tuples = range(len(data) * NUM_BEHAVIORS * NUM_QA)
    print(id_tuples)
    if args.train_percent == 1 or not train:
        flat_ids = list(id_tuples)
    else:
        # Keep a random subset of the flat indices when training on a fraction.
        flat_ids = random.sample(
            id_tuples, int(len(id_tuples) * args.train_percent)
        )

    # Decode each flat index into (label_idx, data_idx, qa_idx); qa_idx varies
    # fastest, then data_idx, then label_idx.
    triples = [
        (
            flat_id // (NUM_BEHAVIORS * NUM_QA),
            (flat_id // NUM_QA) % NUM_BEHAVIORS,
            flat_id % NUM_QA,
        )
        for flat_id in flat_ids
    ]
    return data, triples

In [None]:
# The label prefixes that build_data_and_idx skips (args.filter split on "-").
FILTER

In [None]:
# Parsed contents of args.train_qa, indexed by label.
qa_data

In [None]:
# Choose the train_* or eval_* path attributes on args according to `train`.
_path_prefix = "train" if train else "eval"
p0 = getattr(args, _path_prefix + "_system")
p1 = getattr(args, _path_prefix + "_stimulus_completion")
p2 = getattr(args, _path_prefix + "_stimulus")
p3 = getattr(args, _path_prefix + "_control")

# Each result is the (data, id_tuples) pair returned by build_data_and_idx.
data_system = build_data_and_idx(p0)
data_stimulus_completion = build_data_and_idx(p1)
data_stimulus = build_data_and_idx(p2)
data_control = build_data_and_idx(p3)

In [None]:
# First element of the build_data_and_idx result: the label -> records mapping.
data_control[0]

In [None]:
# Second element: the list of (label_idx, data_idx, qa_idx) triples.
data_control[1]

# dataset类

In [None]:
# NOTE(review): the layout of `dataset.data` is defined in get_dataset, which
# is not visible here; presumably it parallels what build_data_and_idx
# builds — confirm.
dataset.data[1]

In [None]:
# First index entry; presumably a (label_idx, data_idx, qa_idx) triple like
# the ones build_data_and_idx produces — confirm.
dataset.id_tuples[0]

In [None]:
# Size of entry 1 of dataset.labels.
len(dataset.labels[1])

In [None]:
# QA data held by the dataset object; compare with `qa_data` loaded by hand.
dataset.qa_data

In [None]:
# Index 17777 is the sample inspected under the `mask_type = "user"` heading
# later in this notebook.
dataset.get_behavior_qa(17777)

In [None]:
# NOTE(review): presumably per-sample lengths used when forming batches — confirm.
dataset.lengths

In [None]:
# Total number of samples in the dataset.
len(dataset)

In [None]:
# Full record for sample 0; the keys used later in this notebook are
# 'read_prompt', 'dialog' and 'mask_type'.
dataset[0]

In [None]:
# print() shows the prompt with real line breaks rather than the escaped repr.
print(dataset[64839]['read_prompt'])

## mask_type = "user"

In [None]:
# Read prompt of sample 17777, the example for the "user" mask type.
print(dataset[17777]['read_prompt'])

In [None]:
# Dialog of the same sample; it is appended after the filler turn when the
# write queries are built further down.
dataset[17777]['dialog']

In [None]:
# Mask type stored on the sample; expected to match the section heading.
dataset[17777]['mask_type']

## mask_type = "system"

In [None]:
# Read prompt of sample 0, the example for the "system" mask type.
print(dataset[0]["read_prompt"])

In [None]:
# Dialog of sample 0.
dataset[0]['dialog']

In [None]:
# Mask type stored on sample 0; expected to match the section heading.
dataset[0]['mask_type']

# get_dataloaders

In [None]:
# Build a batch sampler for the "train" split; the result is only displayed,
# not kept.
get_dist_batch_sampler(dataset, args, "train")

In [None]:
# get_dataloaders is given args and the tokenizer, not the `dataset` object
# created above, and returns a (train, eval) pair.
train_dataloader, eval_dataloader = get_dataloaders(args,tokenizer)

In [None]:
# Number of batches in the training dataloader.
len(train_dataloader)

## lqa_tokenize 函数

In [None]:
# Two read prompts to tokenize together: samples 0 and 17777.
batch = [dataset[sample_idx]['read_prompt'] for sample_idx in (0, 17777)]
batch

In [None]:
# Tokenize both read prompts into one padded batch of PyTorch tensors.
# NOTE(review): add_special_tokens=False presumably because the prompt text
# already carries the chat-template special tokens — confirm.
tokenized_read = tokenizer(batch, 
                   return_tensors="pt",
                   padding=True,
                   add_special_tokens=False,
    )
tokenized_read

In [None]:
# Decode both rows back to text, separated by a marker line, to see what the
# tokenizer produced (including any padding).
print(tokenizer.decode(tokenized_read.input_ids[0]))
print('+++++++++++++++')
print(tokenizer.decode(tokenized_read.input_ids[1]))

In [None]:
# Count the attended (attention_mask == 1) positions in each row.
read_lengths = tokenized_read.attention_mask.sum(dim=1)
read_lengths

In [None]:
# Sets a `_debug_print` attribute on the mask_inputs function object.
# NOTE(review): presumably mask_inputs checks this flag to print debug
# output — confirm in lit.utils.dataset_utils.
mask_inputs._debug_print = True

In [None]:
# Build a mask over the read tokens for the 'system' and 'user' mask types.
# Its per-row sum is taken below as verb_lengths.
verb_mask = mask_inputs(tokenized_read.input_ids, 
                        tokenizer.name_or_path,
                        mask_type=['system','user'],
                        # mask_type = None,
                        mask_all_but_last=False)


In [None]:
# 验证字符串编码
encoded = tokenizer.encode("\n\n", add_special_tokens=False)
print(encoded)  # 输出: [271]
for ids in (
        torch.tensor([128006, 9125, 128007, 271]),
        torch.tensor([128006, 882, 128007, 271]),
        torch.tensor([128006, 78191, 128007, 271]),
        torch.tensor([128006, 36013, 128007, 271]),
    ):
    print(tokenizer.decode(ids))

In [None]:
# Display the mask returned by mask_inputs for the read batch.
verb_mask

In [None]:

# Per-row sum of verb_mask.
verb_lengths = verb_mask.sum(dim=1)
verb_lengths

In [None]:
# Shown again for side-by-side comparison with verb_lengths.
read_lengths

In [None]:
# Per-row difference between the attended-token count and the verb_mask sum;
# used below to size the "? " filler turn of each write query.
pad_lengths = read_lengths - verb_lengths
pad_lengths

In [None]:
# `batch` is rebound here: from this cell on it holds the two dialogs, no
# longer the read prompts.
batch = [dataset[0]['dialog'], dataset[17777]['dialog']]

In [None]:
# Dialog of sample 0 as stored in the rebound `batch`.
batch[0]

In [None]:
# For each sample, put a filler user turn -- "? " repeated
# (pad_lengths[i] - 1) times -- in front of its dialog, then render the
# conversation to text with the decoder chat template for this model.
queries = []
for i, pad_len in enumerate(pad_lengths):
    filler_turn = {"role": "user", "content": "? " * (pad_len - 1)}
    query = [filler_turn, *batch[i]]
    print(query)
    rendered = tokenizer.apply_chat_template(
        query,
        tokenize=False,
        add_generation_prompt=False,
        chat_template=DECODER_CHAT_TEMPLATES[get_model_config_name(model_name)],
    )
    queries.append(rendered)



In [None]:
# Rendered write prompts, one string per sample.
queries

In [None]:
# Tokenize the rendered write prompts into one padded batch of PyTorch
# tensors, with the same tokenizer options as the read batch.
tokenized_write = tokenizer(
        queries,
        return_tensors="pt",
        padding=True,
        add_special_tokens=False,
    )
tokenized_write

In [None]:
# Count the attended (attention_mask == 1) positions in each write row.
write_lengths = tokenized_write.attention_mask.sum(dim=1)
write_lengths

In [None]:
# Mask over the write tokens; it is OR-ed with the padding mask in the next
# cell to decide which label positions are set to -100.
# NOTE(review): the exact effect of shift_start / modify_chat_template is
# defined inside mask_inputs and is not visible here — confirm.
user_inputs_mask = mask_inputs(
            tokenized_write.input_ids,
            tokenizer.name_or_path,
            mask_type=None,
            shift_start=True,
            modify_chat_template=True,
        )
# This walkthrough requires a left-padding tokenizer.
assert tokenizer.padding_side == "left"


In [None]:
# Labels start as a copy of the input ids.
tokenized_write["labels"] = tokenized_write.input_ids.clone()
# Positions that are padding (attention_mask == 0) or flagged by
# user_inputs_mask are overwritten with -100.
# NOTE(review): -100 is the default ignore_index of
# torch.nn.CrossEntropyLoss; presumably the training loss skips these
# positions — confirm.
mask = (tokenized_write.attention_mask == 0) | user_inputs_mask
tokenized_write["labels"][mask] = -100
tokenized_write["labels"]

In [None]:
# Pull one batch from the training dataloader to inspect the collated output.
first_batch = next(iter(train_dataloader))

In [None]:
# Display the collated batch.
first_batch

In [None]:
# Project helper: print two samples from the training dataloader using the
# tokenizer.
print_dataset_samples(train_dataloader, tokenizer, num_samples=2)