## Check the loss diff manually

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import torch
from functools import partial

def encode_with_prompt_completion_format(example, tokenizer, max_seq_length, add_bos=False):
    '''
    Tokenize one example that carries 'prompt' and 'completion' fields.

    The prompt and completion are joined into one string and tokenized in a single
    pass; tokenizing them separately would pad/truncate the prompt on its own, and
    following that directly with the completion would not make sense.

    Args:
        example: mapping with string fields 'prompt' and 'completion'.
        tokenizer: callable tokenizer exposing `eos_token` / `bos_token`; called with
            return_tensors='pt' and expected to return an object with `input_ids`.
        max_seq_length: forwarded to the tokenizer as `max_length` (with truncation on).
        add_bos: when True, `tokenizer.bos_token` is prepended to the text.

    Returns:
        dict with flattened 'input_ids', 'labels' (a clone of the input ids, since
        prompt masking is disabled below) and an all-ones 'attention_mask'.
    '''
    whitespace = (' ', '\n', '\t')
    prompt = example['prompt']
    completion = example['completion']

    # Insert a single space only when neither side already supplies whitespace at the seam.
    already_separated = prompt.endswith(whitespace) or completion.startswith(whitespace)
    separator = '' if already_separated else ' '
    full_text = prompt + separator + completion + tokenizer.eos_token
    if add_bos:
        full_text = tokenizer.bos_token + full_text

    encoding = tokenizer(full_text, return_tensors='pt', max_length=max_seq_length, truncation=True)
    input_ids = encoding.input_ids
    labels = input_ids.clone()

    # The prompt is still tokenized on its own so its length is available for masking;
    # the masking itself is switched off, so every token keeps its label.
    tokenized_prompt = tokenizer(prompt, return_tensors='pt', max_length=max_seq_length, truncation=True)
    # labels[:, :tokenized_prompt.input_ids.shape[1]] = -100

    return {
        'input_ids': input_ids.flatten(),
        'labels': labels.flatten(),
        'attention_mask': torch.ones_like(input_ids).flatten(),
    }


def encode_with_messages_format(example, tokenizer, max_seq_length, add_bos=False):
    '''
    Tokenize one example that carries a 'messages' field.

    Each message is a dict with 'role' ('system', 'user' or 'assistant') and 'content'.
    All messages are rendered into one string, with role tags as delimiters, and
    tokenized together.

    Args:
        example: mapping with a non-empty 'messages' list.
        tokenizer: callable tokenizer exposing `eos_token` / `bos_token`; called with
            return_tensors='pt' and expected to return an object with `input_ids`.
        max_seq_length: forwarded to the tokenizer as `max_length` (with truncation on).
        add_bos: when True, `tokenizer.bos_token` is prepended to the rendered text.

    Returns:
        dict with flattened 'input_ids', 'labels' (a clone of the input ids, since the
        masking of non-assistant spans is disabled below) and an all-ones 'attention_mask'.

    Raises:
        ValueError: if 'messages' is empty or a message has an unknown role.
    '''
    messages = example['messages']
    if len(messages) == 0:
        raise ValueError('messages field is empty.')

    tokenize_kwargs = dict(return_tensors='pt', max_length=max_seq_length, truncation=True)

    def _concat_messages(messages):
        # Render every message as "<|role|>\ncontent\n"; assistant turns are additionally
        # closed with the EOS token before the newline.
        rendered = []
        for msg in messages:
            role = msg["role"]
            if role not in ("system", "user", "assistant"):
                raise ValueError("Invalid role: {}".format(role))
            body = msg["content"].strip()
            closing = tokenizer.eos_token + "\n" if role == "assistant" else "\n"
            rendered.append("<|" + role + "|>\n" + body + closing)
        return "".join(rendered)

    def _num_tokens(text):
        # Token count of `text` under the same truncation settings as the full example.
        return tokenizer(text, **tokenize_kwargs).input_ids.shape[1]

    example_text = _concat_messages(messages).strip()
    if add_bos:
        example_text = tokenizer.bos_token + example_text
    input_ids = tokenizer(example_text, **tokenize_kwargs).input_ids
    labels = input_ids.clone()

    # Walk the non-assistant messages and locate their token spans. The span boundaries
    # are still computed, but the masking assignment is switched off, so labels are
    # left untouched.
    last_idx = len(messages) - 1
    for message_idx, message in enumerate(messages):
        if message["role"] == "assistant":
            continue

        if message_idx == 0:
            message_start_idx = 0
        else:
            message_start_idx = _num_tokens(_concat_messages(messages[:message_idx]))

        messages_so_far = _concat_messages(messages[:message_idx + 1])
        if message_idx < last_idx and messages[message_idx + 1]["role"] == "assistant":
            # the span also swallows the role tag of the assistant turn that follows
            messages_so_far += "<|assistant|>\n"
        message_end_idx = _num_tokens(messages_so_far)
        # labels[:, message_start_idx:message_end_idx] = -100

        # Everything beyond this point was truncated away; nothing left to inspect.
        if message_end_idx >= max_seq_length:
            break

    return {
        'input_ids': input_ids.flatten(),
        'labels': labels.flatten(),
        'attention_mask': torch.ones_like(input_ids).flatten(),
    }

# dataset_name='test_dataset'
dataset_name = 'filtered-cured-50k-all-iter-sample-subset-small-new_6'

model_name_or_path = "meta-llama/Llama-3.2-3B"
data_path = f"selected_data/{dataset_name}.json"

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
raw_dataset = load_dataset("json", data_files=data_path)

# Pick the encoder that matches the dataset schema. Failing loudly here avoids a
# confusing NameError on `encode_function` further down when the schema is neither.
train_columns = raw_dataset["train"].column_names
if "prompt" in train_columns and "completion" in train_columns:
    selected_encoder = encode_with_prompt_completion_format
elif "messages" in train_columns:
    selected_encoder = encode_with_messages_format
else:
    raise ValueError(
        "Dataset must contain either 'prompt' and 'completion' columns or a 'messages' column; "
        f"found columns: {train_columns}"
    )

encode_function = partial(
    selected_encoder,
    tokenizer=tokenizer,
    max_seq_length=2048,
    add_bos=False,
)

# Attach each example's original position as an 'idx' column.
raw_dataset = raw_dataset.map(
    lambda example, idx: {"idx": idx},
    with_indices=True,
    desc="Adding idx column",
)

# Tokenize one example at a time; the original columns are kept alongside the new ones.
lm_datasets = raw_dataset.map(
    encode_function,
    batched=False,
    # remove_columns=[name for name in raw_dataset["train"].column_names if name not in ["idx", "input_ids", "labels", "attention_mask"]],
    desc="Tokenizing and reformatting instruction data",
)

train_dataset = lm_datasets['train']
raw_labels = train_dataset['labels']

Adding idx column:   0%|          | 0/1000 [00:00<?, ? examples/s]

Tokenizing and reformatting instruction data:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [17]:
import torch
import numpy as np


def get_global_top_k_indices(data, k):
    """Return the positions of the k largest values across a list of lists.

    Args:
        data: sequence of sequences of comparable values (here: per-token loss values).
        k: number of entries to keep, taken from the front of the descending order.

    Returns:
        List of (row_index, column_index + 1) tuples, ordered from the largest value
        down. The column index is shifted by one so that it addresses the label
        position the value belongs to rather than the position inside `data`.
        Ties keep their original row-major order.
    """
    scored = []
    for row_idx, row in enumerate(data):
        for col_idx, score in enumerate(row):
            scored.append((score, row_idx, col_idx))

    # Stable descending sort on the value only, so equal values stay in scan order.
    scored.sort(key=lambda entry: entry[0], reverse=True)

    # +1 corrects for the first label having no value of its own in `data`.
    return [(row_idx, col_idx + 1) for _, row_idx, col_idx in scored[:k]]



# Which dataset was scored and which two checkpoints produced the token losses.
data_type = 'filtered-cured-50k-all-iter-sample-subset-small-new_6'
model_type = 'filtered-cured-50k-all-iter-sample-subset-small-new_4'
new_model_type = 'filtered-cured-50k-all-iter-sample-subset-small-new_5'
# Fraction of tokens to keep in both selection schemes.
data_prop = 0.2

losses_pre = torch.load(f"results/loss/token_losses_{data_type}_{model_type}.pt")
losses_cur = torch.load(f"results/loss/token_losses_{data_type}_{new_model_type}.pt")

# Start from "ignore everything": one -100 per label position.
selected_labels = [[-100] * len(label) for label in raw_labels]

# Per-token loss difference between the two models (previous minus current).
loss_diff = [
    (np.array(prev_loss) - np.array(cur_loss)).tolist()
    for prev_loss, cur_loss in zip(losses_pre, losses_cur)
]

all_token_count = sum(map(len, raw_labels))
print(f"#### all token counting: {all_token_count}\n")

print(f"model pair: ({model_type}, {new_model_type}) -- dataset: {data_type}")


# --- Global selection: top tokens by loss diff across the whole dataset ---
selected_global_labels = [[-100] * len(label) for label in raw_labels]
select_global_tokens_indices = get_global_top_k_indices(loss_diff, int(all_token_count * data_prop))
for sample_idx, label_pos in select_global_tokens_indices:
    selected_global_labels[sample_idx][label_pos] = raw_labels[sample_idx][label_pos]

# --- Sample-level selection: top tokens by loss diff inside each sample ---
selected_sample_labels = [[-100] * len(label) for label in raw_labels]
# The +1 shifts a loss-diff position to the label position it scores.
select_sample_tokens_indices = [
    (torch.topk(torch.tensor(diff), k=int(len(diff) * data_prop), largest=True).indices + 1).tolist()
    for diff in loss_diff
]
for sample_row, picked_positions, label in zip(selected_sample_labels, select_sample_tokens_indices, raw_labels):
    for label_pos in picked_positions:
        sample_row[label_pos] = label[label_pos]


# --- Compare the two selections ---
all_token_count = sum(map(len, selected_sample_labels))

valid_match_count = 0
valid_union_count = 0
for sample_row, global_row in zip(selected_sample_labels, selected_global_labels):
    # zip keeps the rows aligned (and would stop at the shorter one if lengths differed)
    pairs = list(zip(sample_row, global_row))
    # position kept by both schemes
    valid_match_count += sum(1 for sl, gl in pairs if sl != -100 and sl == gl)
    # position kept by at least one scheme
    valid_union_count += sum(1 for sl, gl in pairs if not (sl == -100 and gl == -100))

# Both ratios are expressed relative to the total number of label positions.
similarity_ratio = valid_match_count / all_token_count
union_ratio = valid_union_count / all_token_count

print(f"Similarity ratio: {similarity_ratio:.2%}")
print(f"Union ratio: {union_ratio:.2%}")

  losses_pre = torch.load(f"results/loss/token_losses_{data_type}_{model_type}.pt")
  losses_cur = torch.load(f"results/loss/token_losses_{data_type}_{new_model_type}.pt")


#### all token counting: 487626

model pair: (filtered-cured-50k-all-iter-sample-subset-small-new_4, filtered-cured-50k-all-iter-sample-subset-small-new_5) -- dataset: filtered-cured-50k-all-iter-sample-subset-small-new_6
Similarity ratio: 18.02%
Union ratio: 21.85%


In [28]:
# Sanity check: print, for every sample, the number of loss-diff entries next to the
# number of labels, so the two can be compared by eye (one line per sample).
for loss, labels in zip(loss_diff, raw_labels):
    print(f"len loss: {len(loss)} -- len labels: {len(labels)}")

len loss: 72 -- len labels: 73
len loss: 1322 -- len labels: 1323
len loss: 231 -- len labels: 232
len loss: 167 -- len labels: 168
len loss: 337 -- len labels: 338
len loss: 272 -- len labels: 273
len loss: 435 -- len labels: 436
len loss: 121 -- len labels: 122
len loss: 250 -- len labels: 251
len loss: 113 -- len labels: 114
len loss: 46 -- len labels: 47
len loss: 223 -- len labels: 224
len loss: 788 -- len labels: 789
len loss: 311 -- len labels: 312
len loss: 199 -- len labels: 200
len loss: 530 -- len labels: 531
len loss: 468 -- len labels: 469
len loss: 67 -- len labels: 68
len loss: 682 -- len labels: 683
len loss: 342 -- len labels: 343
len loss: 549 -- len labels: 550
len loss: 418 -- len labels: 419
len loss: 423 -- len labels: 424
len loss: 125 -- len labels: 126
len loss: 32 -- len labels: 33
len loss: 142 -- len labels: 143
len loss: 797 -- len labels: 798
len loss: 925 -- len labels: 926
len loss: 621 -- len labels: 622
len loss: 712 -- len labels: 713
len loss: 359 --

In [12]:
def count_selected_tokens(label_rows):
    """Return how many entries across all rows differ from the ignore value -100."""
    return sum(1 for row in label_rows for label in row if label != -100)


# Number of label positions kept by each selection scheme.
selected_sample_labels_count = count_selected_tokens(selected_sample_labels)
selected_global_labels_count = count_selected_tokens(selected_global_labels)

print(f"selected_sample_labels_count: {selected_sample_labels_count}")
print(f"selected_global_labels_count: {selected_global_labels_count}")


selected_sample_labels_count: 486626
selected_global_labels_count: 486626


In [30]:
# Show the loss diff behind each of the first 100 globally selected positions.
# The stored position is a label position, so step back by one to index loss_diff.
for sample_idx, label_pos in select_global_tokens_indices[:100]:
    print(loss_diff[sample_idx][label_pos - 1])

8.806870087981224
7.24859094619751
5.935330390930176
5.923524379730225
5.831694602966309
5.696369171142578
5.599660396575928
5.513375282287598
5.452551603317261
5.446721792221069
5.339729070663452
5.325538635253906
5.273169368505478
5.250312805175781
5.1521932780742645
5.13873291015625
5.120546251535416
5.079023122787476
5.059002876281738
5.033239841461182
4.9899067878723145
4.970749855041504
4.970494270324707
4.877408921718597
4.873149871826172
4.8416032791137695
4.835258483886719
4.813791990280151
4.802964925765991
4.719444274902344
4.6886162757873535
4.616905689239502
4.6069276332855225
4.573083877563477
4.572339057922363
4.5693793296813965
4.558010935783386
4.553904183208942
4.536604881286621
4.53176212310791
4.52959406375885
4.526368111371994
4.510976791381836
4.508610963821411
4.50351881980896
4.497838497161865
4.497535824775696
4.488606750965118
4.387574315071106
4.382802486419678
4.359728693962097
4.353085041046143
4.350334644317627
4.335271120071411
4.323821067810059
4.3050212

In [29]:
# Redo the per-sample top-k, this time also collecting the selected loss-diff values.
# NOTE: the positions stored here are raw loss-diff positions (no +1 shift), and this
# rebinds `select_sample_tokens_indices`.
select_sample_tokens_indices = []
loss_selected = []
for diff in loss_diff:
    top = torch.topk(torch.tensor(diff), k=int(len(diff) * data_prop), largest=True)
    select_sample_tokens_indices.append(top.indices.tolist())
    loss_selected += top.values.tolist()

# The 100 largest selected values, shown as the cell output.
sorted(loss_selected, reverse=True)[:100]

[8.806870460510254,
 7.24859094619751,
 5.935330390930176,
 5.923524379730225,
 5.831694602966309,
 5.696369171142578,
 5.599660396575928,
 5.513375282287598,
 5.45255184173584,
 5.446722030639648,
 5.339729309082031,
 5.325538635253906,
 5.27316951751709,
 5.250312805175781,
 5.152193069458008,
 5.13873291015625,
 5.120546340942383,
 5.079023361206055,
 5.059002876281738,
 5.033239841461182,
 4.9899067878723145,
 4.970749855041504,
 4.970494270324707,
 4.877408981323242,
 4.873149871826172,
 4.8416032791137695,
 4.835258483886719,
 4.8137922286987305,
 4.80296516418457,
 4.719444274902344,
 4.6886162757873535,
 4.616905689239502,
 4.606927871704102,
 4.573083877563477,
 4.572339057922363,
 4.5693793296813965,
 4.558011054992676,
 4.553904056549072,
 4.536604881286621,
 4.53176212310791,
 4.5295939445495605,
 4.526368141174316,
 4.510976791381836,
 4.508610725402832,
 4.503519058227539,
 4.497838497161865,
 4.497535705566406,
 4.488606929779053,
 4.387574195861816,
 4.382802486419678,
