In [2]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import torch
from tqdm import tqdm
import plotly.express as px  
import random
from datasets import Dataset, DatasetDict, IterableDataset, load_dataset,load_from_disk
from transformer_lens import HookedTransformer
from typing import Any, Generator, Iterator, Literal, cast
from sae_lens import SAE
from transformers import (
    AutoTokenizer,
    LlavaNextForConditionalGeneration,
    LlavaNextProcessor,
    AutoModelForCausalLM,
)
from pathlib import Path



from transformer_lens.HookedLlava import HookedLlava
from sae_lens.activation_visualization import (
    load_llava_model,
    load_sae,
    separate_feature,
    run_model,
)
# os.environ["TMP_DIR"]="/home/yaodong/tmp"
# os.environ["CUDA_VISIBLE_DEVICES"] = "6,7" 
# Model name passed as the first argument to load_llava_model
# (looks like a Hugging Face repo id — confirm).
model_name = "llava-hf/llava-v1.6-mistral-7b-hf"
# Local model path passed as the second argument to load_llava_model.
model_path="/data/models/llava-v1.6-mistral-7b-hf"
# Path handed to load_sae.
sae_path="/data/changye/model/llavasae_obliec100k_SAEV"
# Device string handed to load_sae and run_model.
sae_device="cuda:1"
# Device string handed to load_llava_model; processor outputs are moved here too.
device="cuda:0"

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
 # Load the model
# load_llava_model returns the processor and the hooked language model.
# n_devices and stop_at_layer are forwarded as-is; their exact semantics are
# defined in sae_lens.activation_visualization (not visible in this notebook).
processor,  hook_language_model = load_llava_model(
        model_name, model_path, device,n_devices=2,stop_at_layer=17
    )
# Load the SAE from sae_path onto sae_device.
sae = load_sae(sae_path, sae_device)
# del vision_model
# the cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (eg: instantiating an activation store)
# Note that this is not the same as the SAEs config dict, rather it is whatever was in the HF repo, from which we can extract the SAE config dict
# We also return the feature sparsities which are stored in HF for convenience. 
# sae, cfg_dict, sparsity = SAE.from_pretrained(
#     release = "gpt2-small-res-jb", # see other options in sae_lens/pretrained_saes.yaml
#     sae_id = "blocks.8.hook_resid_pre", # won't always be a hook point
#     device = device
# )



## Loading the dataset

In [None]:
import os
# NOTE(review): `datasets` is already imported by an earlier cell, so setting
# this variable here may be too late to influence it — confirm. The explicit
# cache_dir passed to load_dataset below does not depend on it.
os.environ["HF_DATASETS_CACHE"] = "/aifs4su/yaodong/changye/tmp"
dataset_path="/aifs4su/yaodong/hantao/datasets/MMInstruct-GPT4V"
# Prompt template pieces, currently disabled:
# system_prompt= "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. "
# user_prompt= 'USER: \n<image> {input}'
# assistant_prompt= '\nASSISTANT: {output}'
# split_token= 'ASSISTANT:'
train_dataset = load_dataset(
            dataset_path,
            'qa_en',
            split="train",
            trust_remote_code=True,
            cache_dir="/aifs4su/yaodong/changye/tmp"
        )
print(train_dataset)

# Draw a random subset of rows. A dedicated, seeded generator makes the
# selection reproducible across kernel restarts without touching the global
# `random` state, and the requested size is capped at the dataset size so
# that small datasets do not make `sample` raise ValueError.
sample_size = 1000
sample_seed = 0
total_size = len(train_dataset)
random_indices = random.Random(sample_seed).sample(
    range(total_size), min(sample_size, total_size)
)
sampled_dataset = train_dataset.select(random_indices)

Generating train split: 216462 examples [00:00, 249622.62 examples/s]

Dataset({
    features: ['id', 'image', 'conversations'],
    num_rows: 216462
})





In [13]:
from PIL import Image
import io
from io import BytesIO
# Sanity check on the first record: print its 'image' field, open the file it
# names under the image directory, then resize to 336x336 and convert to RGBA
# (the same resize/convert steps format_sample applies).
print(train_dataset[0]['image'])
image = Image.open("/aifs4su/yaodong/changye/images/"+train_dataset[0]['image'])
image = image.resize((336, 336)).convert('RGBA')
print(image)

social_relation/0001/00000770.jpg
<PIL.Image.Image image mode=RGBA size=336x336 at 0x155137B0DE90>


In [None]:


# Default prompt template pieces used by format_sample.
DEFAULT_SYSTEM_PROMPT = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions. "
)
DEFAULT_USER_PROMPT = 'USER: \n<image> {input}'
DEFAULT_ASSISTANT_PROMPT = '\nASSISTANT: {output}'


# Formatting function applied to each dataset record.
def format_sample(
    raw_sample: dict[str, Any],
    *,
    system_prompt: str = DEFAULT_SYSTEM_PROMPT,
    user_prompt: str = DEFAULT_USER_PROMPT,
    assistant_prompt: str = DEFAULT_ASSISTANT_PROMPT,
) -> dict[str, Any]:
    """
    Build the model prompt and the resized image for one raw sample.

    Args:
        raw_sample: Record providing the keys 'question', 'image' and
            'image_name'. 'image' must expose PIL-style ``resize`` and
            ``convert`` methods.
        system_prompt: Text placed at the start of the prompt.
        user_prompt: Template containing an ``{input}`` placeholder that
            receives the cleaned question.
        assistant_prompt: Template containing an ``{output}`` placeholder; it
            is filled with an empty string so the prompt ends where the
            assistant's answer would begin.

    Returns:
        A dict with 'prompt' (the assembled prompt string), 'image' (the image
        resized to 336x336 and converted to RGBA) and 'image_name' (copied
        from the input unchanged).

    The prompt pieces are keyword-only parameters with defaults so that the
    function no longer depends on notebook globals being defined beforehand.
    """
    # Strip every '<image>' placeholder (with or without an adjacent newline)
    # from the question; the user template supplies its own placeholder.
    prompt = (
        raw_sample['question']
        .replace('<image>\n', '')
        .replace('\n<image>', '')
        .replace('<image>', '')
    )

    # Normalise the image to a fixed size and colour mode.
    image = raw_sample['image']
    image = image.resize((336, 336))
    image = image.convert('RGBA')

    # Assemble the final prompt.
    formatted_prompt = (
        f'{system_prompt}'
        f'{user_prompt.format(input=prompt)}'
        f'{assistant_prompt.format(output="")}'
    )

    return {
        'prompt': formatted_prompt,
        'image': image,
        'image_name': raw_sample['image_name'],
    }

# Apply format_sample to every sampled record with Dataset.map.
# NOTE(review): format_sample reads 'question' and 'image_name', and the
# columns removed here are 'chosen', 'rejected' and 'question', but the
# train_dataset printed earlier lists features ['id', 'image', 'conversations'].
# Confirm which dataset sampled_dataset is meant to hold before running.
formatted_dataset = sampled_dataset.map(
    format_sample,
    num_proc=80,  # adjust to the number of CPU cores available
    remove_columns=['chosen','rejected','question'],
)
# print(formatted_dataset)
# Materialise the whole dataset as a dict of columns; later cells index
# formatted_sample['prompt'], ['image'] and ['image_name'].
formatted_sample = formatted_dataset[:]
# print(formatted_sample['image_name'][0])

# Rebuild a Dataset from the materialised columns.
hf_dataset = Dataset.from_dict(formatted_sample)

# Save in Arrow format
save_path = "/data/changye/data/SPA_VL1k"
os.makedirs(save_path, exist_ok=True)
hf_dataset.save_to_disk(save_path)
print(f"Dataset saved to {save_path}")




In [None]:
# image_name_list=[]

# for data in tqdm(train_dataset):
#     image_name=data['image_name']
#     if image_name in image_name_list:
#         print("error!")
#         break
#     else:
#         image_name_list.append(image_name)

In [None]:
# Run the processor over all prompts and images in a single call and move the
# resulting tensors to `device`.
inputs = processor(
        text=formatted_sample['prompt'],
        images=formatted_sample['image'],
        return_tensors='pt',
        padding='max_length',  # pad to the maximum length
        max_length=256,  # maximum length
    ).to(device)

# Print the shape of the processed input ids
print((inputs['input_ids'].shape))
torch.cuda.empty_cache()


In [None]:
# for batch in processed_dataset:
#     # print(dir(batch))
#     image_indices, feature_act = run_model(batch, hook_language_model, sae, sae_device)
#     break  

# Run the hooked language model and the SAE on the processed inputs in one
# call; run_model returns the pair (image_indices, feature_act).
image_indices, feature_act = run_model(inputs, hook_language_model, sae, sae_device)


In [None]:
# Show the shapes of both run_model outputs, one per line.
print(image_indices.shape, feature_act.shape, sep="\n")




In [None]:
# Combine the image indices and feature activations via separate_feature
# (imported from sae_lens.activation_visualization), then print the length of
# the entry at index 1 as a quick check.
cooccurrence_feature=separate_feature(image_indices, feature_act)
print(len(cooccurrence_feature[1]))

In [None]:
# Map each image name to its entry in cooccurrence_feature (matched by
# position), then persist the mapping to disk in fixed-size chunks.
image_names = formatted_sample['image_name']
# NOTE(review): entries sharing an image name overwrite each other here —
# confirm that image names are unique in the sampled data.
data_dict = {
    image_names[i]: cooccurrence_feature[i]
    for i in range(len(cooccurrence_feature))
}
# Print a summary rather than the whole mapping, which would flood the output.
print(f"Collected features for {len(data_dict)} images")

batch_size = 10000
# Materialise the items once instead of rebuilding the list for every batch.
data_items = list(data_dict.items())
for batch_index, start in enumerate(range(0, len(data_items), batch_size)):
    batch_dict = dict(data_items[start:start + batch_size])
    torch.save(batch_dict, f'data_batch_{batch_index}.pt')

In [None]:
from datasets import load_dataset
# Load the 'text-image-to-text' configuration of PKU-Alignment/Align-Anything
# and keep its 'train' split.
# NOTE(review): this rebinds train_dataset, which an earlier cell bound to the
# MMInstruct-GPT4V data; re-running earlier cells afterwards will see this one.
train_dataset = load_dataset('PKU-Alignment/Align-Anything',name='text-image-to-text',cache_dir="/mnt/file2/changye/dataset/Align-Anything_preference")['train']

In [None]:
# Print whatever train_dataset is currently bound to.
print(train_dataset)