In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import torch
from functools import partial

def encode_with_prompt_completion_format(example, tokenizer, max_seq_length, add_bos=False, mask_prompt=False):
    '''
    Tokenize one example that has 'prompt' and 'completion' fields.

    The prompt and completion are concatenated and tokenized together; tokenizing them
    separately would pad/truncate the prompt on its own, and the completion would no
    longer follow it directly.

    Args:
        example: mapping with string fields 'prompt' and 'completion'.
        tokenizer: callable tokenizer exposing `eos_token` (and `bos_token` when
            `add_bos` is True). It is called with `return_tensors='pt'` and its result
            must expose a 2-D `input_ids` tensor.
        max_seq_length: passed to the tokenizer as `max_length` with `truncation=True`.
        add_bos: if True, prepend `tokenizer.bos_token` to the text before tokenizing.
        mask_prompt: if True, overwrite the labels of the prompt tokens with -100 so they
            are excluded from the loss. Defaults to False, in which case `labels` is an
            unmasked copy of `input_ids` and the prompt is not tokenized a second time.

    Returns:
        dict with flattened 'input_ids', 'labels' and 'attention_mask' tensors; the
        attention mask is all ones.
    '''
    prompt = example['prompt']
    completion = example['completion']
    whitespace = (' ', '\n', '\t')
    # Insert a separating space only when neither side already provides whitespace.
    if not prompt.endswith(whitespace) and not completion.startswith(whitespace):
        example_text = prompt + ' ' + completion
    else:
        example_text = prompt + completion
    example_text = example_text + tokenizer.eos_token
    # `bos_token` is only read when it is actually requested.
    bos_prefix = tokenizer.bos_token if add_bos else ''
    example_text = bos_prefix + example_text
    tokenized_example = tokenizer(example_text, return_tensors='pt', max_length=max_seq_length, truncation=True)
    input_ids = tokenized_example.input_ids
    labels = input_ids.clone()
    if mask_prompt:
        # The prompt length is measured with the same BOS prefix as the full text so the
        # mask boundary lines up with `input_ids`.
        tokenized_prompt = tokenizer(
            bos_prefix + prompt, return_tensors='pt', max_length=max_seq_length, truncation=True
        )
        labels[:, :tokenized_prompt.input_ids.shape[1]] = -100
    attention_mask = torch.ones_like(input_ids)
    return {
        'input_ids': input_ids.flatten(),
        'labels': labels.flatten(),
        'attention_mask': attention_mask.flatten(),
    }


def encode_with_messages_format(example, tokenizer, max_seq_length, add_bos=False, mask_prompt=False):
    '''
    Tokenize one example that has a 'messages' field.

    Each message is a dict with 'role' and 'content' fields. All messages are concatenated
    with role headers ("<|system|>", "<|user|>", "<|assistant|>") as delimiters and
    tokenized together.

    Args:
        example: mapping with a non-empty 'messages' list.
        tokenizer: callable tokenizer exposing `eos_token` (and `bos_token` when
            `add_bos` is True). It is called with `return_tensors='pt'` and its result
            must expose a 2-D `input_ids` tensor.
        max_seq_length: passed to the tokenizer as `max_length` with `truncation=True`.
        add_bos: if True, prepend `tokenizer.bos_token` to the text before tokenizing.
        mask_prompt: if True, overwrite the labels of every non-assistant span (including
            the "<|assistant|>" header that precedes a reply) with -100 so they are
            excluded from the loss. Defaults to False, in which case `labels` is an
            unmasked copy of `input_ids` and no extra tokenizer calls are made.

    Returns:
        dict with flattened 'input_ids', 'labels' and 'attention_mask' tensors; the
        attention mask is all ones.

    Raises:
        ValueError: if 'messages' is empty or a message has a role other than
            'system', 'user' or 'assistant'.
    '''
    messages = example['messages']
    if len(messages) == 0:
        raise ValueError('messages field is empty.')

    def _concat_messages(messages):
        # Render messages as "<|role|>\ncontent\n"; assistant turns end with the EOS token.
        message_text = ""
        for message in messages:
            if message["role"] == "system":
                message_text += "<|system|>\n" + message["content"].strip() + "\n"
            elif message["role"] == "user":
                message_text += "<|user|>\n" + message["content"].strip() + "\n"
            elif message["role"] == "assistant":
                message_text += "<|assistant|>\n" + message["content"].strip() + tokenizer.eos_token + "\n"
            else:
                raise ValueError("Invalid role: {}".format(message["role"]))
        return message_text

    # `bos_token` is only read when it is actually requested.
    bos_prefix = tokenizer.bos_token if add_bos else ""

    def _num_tokens(text):
        # Token count of a conversation prefix, measured with the same BOS prefix and
        # truncation settings as the full example so offsets line up with `input_ids`.
        return tokenizer(
            bos_prefix + text, return_tensors='pt', max_length=max_seq_length, truncation=True
        ).input_ids.shape[1]

    example_text = bos_prefix + _concat_messages(messages).strip()
    tokenized_example = tokenizer(example_text, return_tensors='pt', max_length=max_seq_length, truncation=True)
    input_ids = tokenized_example.input_ids
    labels = input_ids.clone()

    if mask_prompt:
        # Mask the non-assistant parts so they do not contribute to the loss.
        for message_idx, message in enumerate(messages):
            if message["role"] == "assistant":
                continue
            if message_idx == 0:
                message_start_idx = 0
            else:
                message_start_idx = _num_tokens(_concat_messages(messages[:message_idx]))
            messages_so_far = _concat_messages(messages[:message_idx + 1])
            if message_idx < len(messages) - 1 and messages[message_idx + 1]["role"] == "assistant":
                # The role header of the following assistant turn is masked as well.
                messages_so_far += "<|assistant|>\n"
            message_end_idx = _num_tokens(messages_so_far)
            labels[:, message_start_idx:message_end_idx] = -100

            # Everything past this point was truncated away; nothing left to mask.
            if message_end_idx >= max_seq_length:
                break

    attention_mask = torch.ones_like(input_ids)
    return {
        'input_ids': input_ids.flatten(),
        'labels': labels.flatten(),
        'attention_mask': attention_mask.flatten(),
    }

dataset_name = 'random_1'
model_name_or_path = "meta-llama/Llama-3.2-3B"
data_path = f"selected_data/{dataset_name}.json"
# Single source of truth for the truncation length used by both encoders.
MAX_SEQ_LENGTH = 2048

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
raw_dataset = load_dataset("json", data_files=data_path)

# Pick the encoder that matches the dataset schema.
column_names = raw_dataset["train"].column_names
if "prompt" in column_names and "completion" in column_names:
    encode_function = partial(
        encode_with_prompt_completion_format,
        tokenizer=tokenizer,
        max_seq_length=MAX_SEQ_LENGTH,
        add_bos=False,
    )
elif "messages" in column_names:
    encode_function = partial(
        encode_with_messages_format,
        tokenizer=tokenizer,
        max_seq_length=MAX_SEQ_LENGTH,
        add_bos=False,
    )
else:
    # Fail here instead of hitting a NameError (or silently reusing an `encode_function`
    # left over from an earlier kernel run) in the tokenization step below.
    raise ValueError(
        "The dataset needs either 'prompt' and 'completion' columns or a 'messages' column; "
        f"found columns: {column_names}"
    )

# Attach each row's position as an 'idx' column.
raw_dataset = raw_dataset.map(
    lambda example, idx: {"idx": idx},
    with_indices=True,
    desc="Adding idx column",
)

# `remove_columns` is not passed, so the original columns are kept next to the
# 'input_ids' / 'labels' / 'attention_mask' columns produced by the encoder.
lm_datasets = raw_dataset.map(
    encode_function,
    batched=False,
    desc="Tokenizing and reformatting instruction data",
)

train_dataset = lm_datasets['train']

  from .autonotebook import tqdm as notebook_tqdm
Adding idx column: 100%|██████████| 2000/2000 [00:00<00:00, 12171.62 examples/s]
Tokenizing and reformatting instruction data: 100%|██████████| 2000/2000 [00:05<00:00, 379.63 examples/s]


In [5]:
# Pull the 'labels' column out of the tokenized training split.
labels = train_dataset['labels']

# Total number of label positions summed over every training example.
all_len = sum(len(label) for label in labels)

all_len

664120

## Select tokens

In [65]:
import torch
import numpy as np
from datasets import load_dataset
# NOTE(review): the loss files are keyed by 'random', while `train_dataset` was built from
# 'random_1' in the tokenization cell — confirm that both refer to the same examples.
dataset_name = 'random'
# SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary code; only load
# files you produced yourself. If the files hold only tensors / plain containers, consider
# passing weights_only=True (this is what the FutureWarning emitted by this cell is about).
losses_pre = torch.load(f"token_losses_{dataset_name}_base.pt")
losses_new = torch.load(f"token_losses_{dataset_name}_test.pt")

# Window of examples to process. This cell also relies on `train_dataset` from the
# tokenization cell above.
start = 1000
length = 1000

loss_diff = []              # per-example array: loss_pre - loss_new for each scored position
loss_HL_prop = []           # per-example percentage of positions with a positive difference
select_tokens_indices = []  # per-example label positions chosen for training
for loss1, loss2 in zip(losses_pre[start:start + length], losses_new[start:start + length]):
    diff = np.array(loss1) - np.array(loss2)
    loss_diff.append(diff)
    # Keep the half of the positions with the largest difference.
    _, indices = torch.topk(torch.tensor(diff), k=len(diff) // 2)
    # Shift by one so the index addresses the token in `labels`; presumably loss position i
    # scores the prediction of token i + 1 — confirm against the script that wrote the files.
    select_tokens_indices.append((indices + 1).tolist())
    loss_HL_prop.append(round(np.sum(diff > 0) / len(diff) * 100, 3))

# Slice the rows first so only the processed window is materialized, not the whole column.
window_labels = train_dataset[start:start + length]['labels']

new_labels = []
for selected_indices, label in zip(select_tokens_indices, window_labels):
    # Start fully masked (-100) and restore the original label only at the selected positions.
    new_label = [-100] * len(label)
    for idx in selected_indices:
        new_label[idx] = label[idx]
    new_labels.append(new_label)
    


  losses_pre = torch.load(f"token_losses_{dataset_name}_base.pt")
  losses_new = torch.load(f"token_losses_{dataset_name}_test.pt")


In [66]:
# Indexing a `datasets.Dataset` column returns a copy, so slice-assigning into
# `train_dataset['labels']` never changes the dataset. Rewrite the column with `map`:
# rows inside the processed window get their entry from `new_labels`, all other rows keep
# their current labels.
def _with_selected_labels(example, idx):
    """Return the replacement 'labels' for rows covered by `new_labels`, else the existing ones."""
    offset = idx - start
    if 0 <= offset < len(new_labels):
        return {"labels": new_labels[offset]}
    return {"labels": example["labels"]}


train_dataset = train_dataset.map(
    _with_selected_labels,
    with_indices=True,
    desc="Applying selected-token labels",
)

## Split the data into several subsets for multi-epoch runs

In [2]:
import json
from datasets import load_dataset

data_path = 'selected_data/'
dataset_name = 'filtered-cured'

# Load the full JSON dataset; only its 'train' split is used.
source_file = data_path + f'{dataset_name}_dataset.json'
dataset = load_dataset("json", data_files=source_file)['train']

# Number of examples written to each subset file.
data_size = 2000
# Number of complete subsets; rows beyond subset_size * data_size are not written anywhere.
subset_size = len(dataset) // data_size

# Write consecutive, non-overlapping windows of `data_size` rows, one JSON file per window.
for i in range(subset_size):
    first_row = data_size * i
    selected_indices = list(range(first_row, first_row + data_size))
    subset = dataset.select(selected_indices)
    subset.to_json(data_path + f"{dataset_name}_{i}.json")

Generating train split: 10000 examples [00:00, 86674.93 examples/s]
Creating json from Arrow format: 100%|██████████| 2/2 [00:00<00:00, 25.49ba/s]
Creating json from Arrow format: 100%|██████████| 2/2 [00:00<00:00, 26.00ba/s]
Creating json from Arrow format: 100%|██████████| 2/2 [00:00<00:00, 35.97ba/s]
Creating json from Arrow format: 100%|██████████| 2/2 [00:00<00:00, 25.15ba/s]
Creating json from Arrow format: 100%|██████████| 2/2 [00:00<00:00, 34.96ba/s]
