In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch


def filter_and_pad_tensor_foro(tokens, activation, pad_token_id=128001, special_token_start=128000):
    """Remove special tokens from the content part of a sequence and re-pad it.

    The trailing run of one repeated special token (if any) is treated as
    padding and kept untouched.  Every token before that run whose id is
    ``>= special_token_start`` is dropped together with its activation, and the
    result is padded back to the original length.

    Args:
        tokens: 1-D tensor of token ids (assumed 1-D; the code compares single
            elements against integers).
        activation: tensor indexed along its first dimension in step with
            ``tokens``.
        pad_token_id: token id used for padding when the sequence does not end
            in a special token.  When it does, that trailing token id is used
            instead.
        special_token_start: ids greater than or equal to this value are
            treated as special tokens (presumably the tokenizer's reserved id
            range -- confirm against the tokenizer in use).

    Returns:
        ``(tokens, activation)`` tensors with the same leading length as the
        input.  They are rebuilt with ``torch.tensor`` from Python lists, so
        their dtypes are inferred rather than copied from the inputs.
    """
    original_length = len(tokens)
    if original_length == 0:
        # Nothing to filter; avoids indexing the last element of an empty sequence.
        return tokens[:0].clone(), activation[:0].clone()

    tokens_list = tokens.tolist()
    activation_list = activation.tolist()

    # Locate the start of the trailing padding run (content is [0, content_end)).
    content_end = original_length
    if tokens_list[-1] >= special_token_start:
        pad_token_id = tokens_list[-1]
        while content_end > 0 and tokens_list[content_end - 1] == pad_token_id:
            content_end -= 1

    # Keep the padding run as-is and every non-special content token.
    kept_indices = [
        idx
        for idx, token in enumerate(tokens_list)
        if idx >= content_end or token < special_token_start
    ]
    filtered_tokens_list = [tokens_list[idx] for idx in kept_indices]
    filtered_activation_list = [activation_list[idx] for idx in kept_indices]

    num_to_pad = original_length - len(kept_indices)
    if num_to_pad > 0:
        # Pad tokens with the pad id (not with a copy of the last real token),
        # so padded positions stay recognisable downstream.  The activation of
        # the last kept position is repeated to fill the padded rows.
        filtered_tokens_list.extend([pad_token_id] * num_to_pad)
        filtered_activation_list.extend([filtered_activation_list[-1]] * num_to_pad)

    return torch.tensor(filtered_tokens_list), torch.tensor(filtered_activation_list)


def filter_and_pad_tensor_fori(tokens, activation, otemplate, pad_token_id=128009, special_token_start=128000):
    """Keep only the tokens that follow ``otemplate`` in order, then re-pad.

    The trailing run of one repeated special token (if any) is treated as
    padding and kept untouched.  The content before it is scanned left to
    right: a token is kept when it equals the next unmatched element of
    ``otemplate`` (greedy subsequence match); every other content token is
    dropped together with its activation.  The result is padded back to the
    original length.

    Args:
        tokens: 1-D tensor of token ids (assumed 1-D; the code compares single
            elements against integers).
        activation: tensor indexed along its first dimension in step with
            ``tokens``.
        otemplate: 1-D tensor of token ids to match as a subsequence.
        pad_token_id: token id used for padding when the sequence does not end
            in a special token.  When it does, that trailing token id is used
            instead.
        special_token_start: a trailing token with id greater than or equal to
            this value is treated as padding (presumably the tokenizer's
            reserved id range -- confirm against the tokenizer in use).

    Returns:
        ``(tokens, activation)`` tensors with the same leading length as the
        input.  They are rebuilt with ``torch.tensor`` from Python lists, so
        their dtypes are inferred rather than copied from the inputs.
    """
    original_length = len(tokens)
    if original_length == 0:
        # Nothing to filter; avoids indexing the last element of an empty sequence.
        return tokens[:0].clone(), activation[:0].clone()

    tokens_list = tokens.tolist()
    activation_list = activation.tolist()
    template_list = otemplate.tolist()
    template_length = len(template_list)

    # Locate the start of the trailing padding run (content is [0, content_end)).
    content_end = original_length
    if tokens_list[-1] >= special_token_start:
        pad_token_id = tokens_list[-1]
        while content_end > 0 and tokens_list[content_end - 1] == pad_token_id:
            content_end -= 1

    # Greedy in-order match of the content against the template.
    kept_indices = []
    template_pos = 0
    for idx in range(content_end):
        if template_pos < template_length and tokens_list[idx] == template_list[template_pos]:
            kept_indices.append(idx)
            template_pos += 1
    # The padding run is always preserved.
    kept_indices.extend(range(content_end, original_length))

    filtered_tokens_list = [tokens_list[idx] for idx in kept_indices]
    filtered_activation_list = [activation_list[idx] for idx in kept_indices]

    num_to_pad = original_length - len(kept_indices)
    if num_to_pad > 0:
        if filtered_activation_list:
            # Repeat the activation of the last kept position.
            pad_activation = filtered_activation_list[-1]
        else:
            # Nothing survived (no template match and no trailing padding):
            # fall back to an all-zero activation instead of indexing an empty list.
            pad_activation = torch.zeros_like(activation[0]).tolist()
        # Pad tokens with the pad id (not with a copy of the last real token),
        # so padded positions stay recognisable downstream.
        filtered_tokens_list.extend([pad_token_id] * num_to_pad)
        filtered_activation_list.extend([pad_activation] * num_to_pad)

    return torch.tensor(filtered_tokens_list), torch.tensor(filtered_activation_list)


def filter_and_pad_tensor_forb(tokens, activation, otemplate, pad_token_id=128001, special_token_start=128000):
    """Keep only the tokens that follow ``otemplate`` in order, then re-pad.

    A trailing run of one repeated special token, when present, is considered
    padding and is preserved.  Content tokens before that run are matched
    greedily, left to right, against ``otemplate``; tokens that do not advance
    the match are removed along with their activations.  The output is padded
    back to the input length.

    Args:
        tokens: 1-D tensor of token ids (assumed 1-D; the code compares single
            elements against integers).
        activation: tensor indexed along its first dimension in step with
            ``tokens``.
        otemplate: 1-D tensor of token ids to match as a subsequence.
        pad_token_id: fallback padding id, used when the sequence does not end
            in a special token; otherwise the trailing token id is used.
        special_token_start: a trailing token with id greater than or equal to
            this value is treated as padding (presumably the tokenizer's
            reserved id range -- confirm against the tokenizer in use).

    Returns:
        ``(tokens, activation)`` tensors whose leading length equals the
        input's.  Both are rebuilt with ``torch.tensor`` from Python lists, so
        dtypes are inferred rather than copied from the inputs.
    """
    original_length = len(tokens)
    if original_length == 0:
        # Empty input: return empty slices rather than indexing tokens[-1].
        return tokens[:0].clone(), activation[:0].clone()

    tokens_list = tokens.tolist()
    activation_list = activation.tolist()
    template_list = otemplate.tolist()
    template_length = len(template_list)

    # Find where the trailing padding run begins; content is [0, content_end).
    content_end = original_length
    last_token = tokens_list[-1]
    if last_token >= special_token_start:
        pad_token_id = last_token
        while content_end > 0 and tokens_list[content_end - 1] == pad_token_id:
            content_end -= 1

    # Walk the content once, keeping tokens that match the template in order.
    kept_indices = []
    matched = 0
    for idx, token in enumerate(tokens_list[:content_end]):
        if matched < template_length and token == template_list[matched]:
            kept_indices.append(idx)
            matched += 1
    # Padding positions are never removed.
    kept_indices.extend(range(content_end, original_length))

    filtered_tokens_list = [tokens_list[idx] for idx in kept_indices]
    filtered_activation_list = [activation_list[idx] for idx in kept_indices]

    num_to_pad = original_length - len(kept_indices)
    if num_to_pad > 0:
        if filtered_activation_list:
            # Reuse the activation of the last kept position for the padded rows.
            pad_activation = filtered_activation_list[-1]
        else:
            # No token survived and there was no padding run: use zeros instead
            # of indexing an empty list (which raised IndexError).
            pad_activation = torch.zeros_like(activation[0]).tolist()
        # Tokens are padded with the pad id rather than a copy of the last
        # real token, so padded positions remain identifiable.
        filtered_tokens_list.extend([pad_token_id] * num_to_pad)
        filtered_activation_list.extend([pad_activation] * num_to_pad)

    return torch.tensor(filtered_tokens_list), torch.tensor(filtered_activation_list)

In [4]:
import os

from safetensors.torch import load_file, save_file

In [18]:
def process_file(args):
    """Filter one safetensors file and write the result to the output folder.

    Takes a single tuple so it can be used directly with ``Pool.map``.

    Args:
        args: ``(input_folder, output_folder, filename, model_name)`` or
            ``(input_folder, output_folder, filename, model_name, otemplate)``.
            ``model_name`` selects the filter: ``"o"``, ``"i"`` or ``"b"``.
            ``otemplate`` is required for ``"i"`` and ``"b"`` because their
            filter functions take a template argument; it is ignored for ``"o"``.

    Returns:
        A human-readable status string: success, unknown model, missing
        template, or missing input file.
    """
    input_folder, output_folder, filename, model_name, *extra = args
    otemplate = extra[0] if extra else None

    if model_name not in ("o", "i", "b"):
        return f"there is no such model : {model_name}"

    # The "i" and "b" filters need a template; calling them with only
    # (tokens, activation) raised TypeError before this check existed.
    needs_template = model_name != "o"
    if needs_template and otemplate is None:
        return f"model {model_name} requires an otemplate as the fifth element of args"

    input_path = os.path.join(input_folder, filename)
    output_path = os.path.join(output_folder, filename)
    if not os.path.exists(input_path):
        return f"File not found: {input_path}"

    tensor_data = load_file(input_path)

    # Only the first row of each tensor is filtered, matching the stored layout
    # this notebook operates on.
    tokens = tensor_data["tokens"][0]
    activation = tensor_data["activation"][0]

    if model_name == "o":
        new_tokens, new_activation = filter_and_pad_tensor_foro(tokens, activation)
    elif model_name == "i":
        new_tokens, new_activation = filter_and_pad_tensor_fori(tokens, activation, otemplate)
    else:
        new_tokens, new_activation = filter_and_pad_tensor_forb(tokens, activation, otemplate)

    # Written back in place so the other entries of the file are saved unchanged.
    tensor_data["tokens"][0] = new_tokens
    tensor_data["activation"][0] = new_activation

    save_file(tensor_data, output_path)
    return f"Processed and saved: {output_path}"

In [19]:
from multiprocessing import Pool, cpu_count


def main(
    input_folder="/inspire/hdd/global_user/hezhengfu-240208120186/jiaxing_activations/reasondata-o-2d-l15/blocks.15.hook_resid_post",
    output_folder="/inspire/hdd/global_user/hezhengfu-240208120186/jiaxing_activations/reasondata-o-2d-l15-f/blocks.15.hook_resid_post",
    model_name="o",
    num_shards=8,
    num_chunks=1557,
    max_workers=16,
):
    """Filter every shard/chunk file of an activation dump in parallel.

    Builds the expected file names ``shard-{s}-chunk-{c:08d}.safetensors`` for
    all shard/chunk combinations, hands each one to ``process_file`` through a
    process pool, and prints the status string returned for every file once
    the whole batch has finished.

    Args:
        input_folder: folder containing the source ``.safetensors`` files.
        output_folder: folder the filtered files are written to; created if it
            does not exist.
        model_name: filter selector forwarded to ``process_file``.
        num_shards: number of shards to enumerate (``range(num_shards)``).
        num_chunks: number of chunks per shard (``range(num_chunks)``).
        max_workers: upper bound on the number of worker processes; the actual
            count is also capped by ``cpu_count()``.
    """
    os.makedirs(output_folder, exist_ok=True)

    all_files = [
        f"shard-{shard}-chunk-{chunk:08d}.safetensors"
        for shard in range(num_shards)
        for chunk in range(num_chunks)
    ]
    tasks = [(input_folder, output_folder, filename, model_name) for filename in all_files]

    # Always start at least one worker, even if max_workers is given as 0.
    cpu_cores = max(1, min(cpu_count(), max_workers))
    with Pool(cpu_cores) as pool:
        results = pool.map(process_file, tasks)

    for result in results:
        print(result)


# Run the parallel filtering job with main()'s default input/output folders.
main()

KeyboardInterrupt: 

In [13]:
# Sequential (single-process) pass over every shard/chunk file: load it, filter
# the first row of "tokens"/"activation", and save the result to the -f folder.
input_folder = "/inspire/hdd/global_user/hezhengfu-240208120186/jiaxing_activations/reasondata-o-2d-l15/blocks.15.hook_resid_post"
output_folder = (
    "/inspire/hdd/global_user/hezhengfu-240208120186/jiaxing_activations/reasondata-o-2d-l15-f/blocks.15.hook_resid_post"
)
os.makedirs(output_folder, exist_ok=True)

for shard, chunk in ((s, c) for s in range(8) for c in range(1557)):
    filename = f"shard-{shard}-chunk-{chunk:08d}.safetensors"
    input_path = os.path.join(input_folder, filename)
    output_path = os.path.join(output_folder, filename)

    # Report missing inputs and move on to the next file.
    if not os.path.exists(input_path):
        print(f"File not found: {input_path}")
        continue

    tensor_data = load_file(input_path)

    tokens = tensor_data["tokens"][0]
    activation = tensor_data["activation"][0]

    # Write the filtered row back in place, then persist the whole dict.
    filtered = filter_and_pad_tensor_foro(tokens, activation)
    tensor_data["tokens"][0] = filtered[0]
    tensor_data["activation"][0] = filtered[1]

    save_file(tensor_data, output_path)
    print(f"Processed and saved: {output_path}")

Processed and saved: /inspire/hdd/global_user/hezhengfu-240208120186/jiaxing_activations/reasondata-o-2d-l15/blocks.15.hook_resid_post/shard-0-chunk-00000000.safetensors
Processed and saved: /inspire/hdd/global_user/hezhengfu-240208120186/jiaxing_activations/reasondata-o-2d-l15/blocks.15.hook_resid_post/shard-0-chunk-00000001.safetensors
Processed and saved: /inspire/hdd/global_user/hezhengfu-240208120186/jiaxing_activations/reasondata-o-2d-l15/blocks.15.hook_resid_post/shard-0-chunk-00000002.safetensors
Processed and saved: /inspire/hdd/global_user/hezhengfu-240208120186/jiaxing_activations/reasondata-o-2d-l15/blocks.15.hook_resid_post/shard-0-chunk-00000003.safetensors
Processed and saved: /inspire/hdd/global_user/hezhengfu-240208120186/jiaxing_activations/reasondata-o-2d-l15/blocks.15.hook_resid_post/shard-0-chunk-00000004.safetensors
Processed and saved: /inspire/hdd/global_user/hezhengfu-240208120186/jiaxing_activations/reasondata-o-2d-l15/blocks.15.hook_resid_post/shard-0-chunk-0

KeyboardInterrupt: 

In [12]:
# Sanity-check the dataset layout: does a shard file sit directly under the
# dataset root, does the root exist, and what does the root contain?
shard_probe_path = "/inspire/hdd/global_user/hezhengfu-240208120186/jiaxing_activations/reasondata-o-2d-l15/shard-0-chunk-00000000.safetensors"
dataset_root = "/inspire/hdd/global_user/hezhengfu-240208120186/jiaxing_activations/reasondata-o-2d-l15"

for candidate in (shard_probe_path, dataset_root):
    print(os.path.exists(candidate))
print(os.listdir(dataset_root))

False
True
['blocks.15.hook_resid_post']
