### Token mask label level

In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
import csv
import torch
import torchvision.transforms as T

import csv
import torchvision.transforms.functional as TF


import pickle

data = "generated"


def preprocess(img, target_image_size=256):
    s = min(img.size)
    if s < target_image_size:
        raise ValueError(f'Min dimension for image {s} < {target_image_size}')
    r = target_image_size / s
    s = (round(r * img.size[1]), round(r * img.size[0]))
    img = TF.resize(img, s, interpolation=Image.LANCZOS)
    img = TF.center_crop(img, output_size=2 * [target_image_size])
    img = torch.unsqueeze(T.ToTensor()(img), 0)
    return img


def load_top_tokens_baseline(input_csv, top_n):
    # 存储Token和Count的列表
    tokens = []
    
    # 读取CSV文件
    with open(input_csv, 'r', encoding='utf-8') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            token = int(row['Token'])  # 读取Token
            count = int(row['Count'])  # 读取Count
            tokens.append((token, count))  # 将Token和Count存入列表
    
    # 按照Count降序排序
    tokens.sort(key=lambda x: x[1], reverse=True)
    
    # 获取前N个Token
    top_n_tokens = [token for token, _ in tokens[:top_n]]
    
    return top_n_tokens

def load_top_tokens(csv_path, top_n, token_number):
    target_token_list = []
    """
    从指定的csv文件中加载Top N的某一行token及其对应的文件列表。

    参数:
    csv_path (str): 要读取的csv文件路径。
    top_n (int): 要查询的Top N级别（如1, 5, 10, 20）。
    row_num (int): 要查询的行号（从1开始，排除表头）。

    返回:
    token (str): 对应的token索引。
    files (list): 对应的文件列表。
    """
    with open(csv_path, 'r') as csvfile:
        reader = csv.reader(csvfile)
        in_top_n_section = False
        current_row = 0

        for row in reader:
            # 检查是否到了Top N的部分
            if f"Top {top_n} Tokens" in row:
                in_top_n_section = True
                next(reader)
                current_row = 0
                continue

            # 如果到了Top N部分，开始读取指定行
            if in_top_n_section:
                if "Top" in row[0] or current_row == token_number:
                    break
                token = int(row[0])  # 获取token索引
                files = row[2].split('; ')  # 获取文件名列表，并按分号分割
                target_token_list.append(token)
                current_row += 1

    return target_token_list  # 如果未找到，则返回None


# 提供要查找的特定 token
target_label = 83  # 替换为你要查找的token索引

# # 提供要处理的npy文件名列表

if data == "generated":
    csv_path = f"/data2/ty45972_data2/taming-transformers/codebook_explanation_classification/results/Explanation/generated_data/label/Net1/label_activation_statistics/label_{target_label}.csv"
    test_csv = f"/data2/ty45972_data2/taming-transformers/codebook_explanation_classification/datasets/VQGAN_16384_generated_new/test_embeddings.csv"
    image_base_path = '/data2/ty45972_data2/taming-transformers/datasets/imagenet_VQGAN_generated/'
    baseline_path = f"/data2/ty45972_data2/taming-transformers/codebook_explanation_classification/results/Explanation/baseline_statistics/label_{target_label}.csv"
    with open('/data2/ty45972_data2/taming-transformers/codebook_explanation_classification/datasets/VQGAN_16384_generated_new/test_token_indices.pkl', 'rb') as f:
        token_dict = pickle.load(f)

# elif data == "original":
#     csv_path = f"/data2/ty45972_data2/taming-transformers/codebook_explanation_classification/results/Explanation/original_data/label/Net1/label_activation_statistics/label_{target_label}.csv"
#     activation_results_path = f"/data2/ty45972_data2/taming-transformers/codebook_explanation_classification/results/Explanation/original_data/label/Net1/label_activation_results/label_{target_label}.csv"
#     image_base_path = "/data2/ty45972_data2/taming-transformers/datasets/imagenet/train"
#     with open('/data2/ty45972_data2/taming-transformers/codebook_explanation_classification/datasets/VQGAN_16384_original/train_token_indices.pkl', 'rb') as f:
#         token_dict = pickle.load(f)

top_n = 20  # 表示查找Top n Tokens
token_num = 50  # 查找第1行的token及其文件列表


target_token_list = load_top_tokens(csv_path, top_n, token_num)
target_token_list_baseline = load_top_tokens_baseline(baseline_path, token_num)
print(f"baseline token list is {target_token_list_baseline}")

if target_token_list:
    print(f"target token list is {len(target_token_list)}")
else:
    print("Cannot find the specific token")

npy_file_list = []
with open(test_csv, 'r', encoding='utf-8') as csvfile:
    reader = csv.reader(csvfile)
    next(reader)  # 跳过表头
    
    # 遍历CSV文件的每一行
    for row in reader:
        filename = row[0]  # 获取filename列（第1列）
        label = int(row[1])     # 获取label列（第2列）
        
        # 如果label是指定的label，则将filename添加到列表中
        if label == target_label:
            npy_file_list.append(filename)
print(f"npy_file_list is {len(npy_file_list)}")
# 定义每个图像的token网格大小，假设为16x16
grid_size = 16
image_size = 256
patch_size = image_size // grid_size

def visualize_token_on_image(npy_filename, token_dict, target_token_list):
    # 提取子文件夹和图片名信息
    subfolder, image_name = npy_filename.split('_')
    image_name = image_name.replace('.npy', '.png')
    
    # 构建图像路径
    if data == "generated":
        image_path = os.path.join(image_base_path, subfolder, image_name)
    elif data == "original":
        image_path = os.path.join(image_base_path, subfolder, npy_filename.replace(".npy", ".JPEG"))
    
    print(f"Image path is {image_path}")
    # 检查文件是否存在
    if not os.path.exists(image_path):
        print(f"Image {image_path} does not exist.")
        return
    
    # 打开图像
    image = Image.open(image_path)
    if data == "original":
        processed_img = preprocess(image)
        processed_img_pil = Image.fromarray((processed_img.squeeze(0).permute(1, 2, 0).numpy() * 255).astype(np.uint8))
        image = processed_img_pil
    
    # 获取该文件对应的 token 列表
    token_list = token_dict.get(npy_filename)
    
    if token_list is None:
        print(f"No token list found for {npy_filename}.")
        return
    
    # 查找目标 token 的所有索引位置
    token_positions = [i for i, token in enumerate(token_list) if token in target_token_list]

    

    # 在图像上mask掉每个目标token的位置
    draw = ImageDraw.Draw(image)
    for token_position in token_positions:
        row = token_position // grid_size
        col = token_position % grid_size
        
        # 计算token在原图中的坐标
        left = col * patch_size
        upper = row * patch_size
        right = left + patch_size
        lower = upper + patch_size
        
        # 用黑色填充这些区域，表示mask
        draw.rectangle([left, upper, right, lower], fill=(0, 0, 0))
    
    # 可视化图像
    plt.figure(figsize=(6, 6))
    plt.imshow(image)
    plt.title(f"token list")
    plt.axis('off')
    plt.show()

# 遍历所有 npy 文件，进行可视化

for i, npy_file in enumerate(npy_file_list):
    visualize_token_on_image(npy_file, token_dict, target_token_list_baseline)
    if i > 10:
        break


### Save Token indice pkl

In [None]:
import pickle
from tqdm import tqdm
import ast

def load_token_indices(embedding_csv_path):
    # 生成保存token indices的文件名
    token_indices_save_path = embedding_csv_path.replace("test_embeddings.csv", "test_token_indices.pkl")
    
    # 检查是否已经存在保存的token_indices文件
    if os.path.exists(token_indices_save_path):
        print(f"Loading token indices from {token_indices_save_path}")
        with open(token_indices_save_path, 'rb') as f:
            token_indices_dict = pickle.load(f)
        return token_indices_dict
    
    # 如果没有保存的token_indices文件，就进行处理
    print(f"Processing token indices from {embedding_csv_path}")
    token_indices_dict = {}

    # 先计算文件中的总行数，以便显示进度条
    with open(embedding_csv_path, 'r') as infile:
        total_lines = sum(1 for _ in infile) - 1  # 减去header行

    # 重新打开文件并读取内容，同时显示进度条
    with open(embedding_csv_path, 'r') as infile:
        reader = csv.reader(infile)
        next(reader)  # 跳过header
        for row in tqdm(reader, total=total_lines, desc="Loading token indices"):
            npy_file = row[0]
            token_indices = ast.literal_eval(row[2])
            token_indices_dict[npy_file] = token_indices
    
    # 保存处理后的token_indices_dict
    with open(token_indices_save_path, 'wb') as f:
        pickle.dump(token_indices_dict, f)
    print(f"Token indices saved to {token_indices_save_path}")
    
    return token_indices_dict

embedding_csv_path = "/data2/ty45972_data2/taming-transformers/codebook_explanation_classification/datasets/VQGAN_16384_generated_new/test_embeddings.csv"
token_indices_dict = load_token_indices(embedding_csv_path)

In [4]:
import torch
import torchvision
from torchvision import transforms
from PIL import Image, ImageDraw
import os
import pickle
import csv
import matplotlib.pyplot as plt

# 1. 加载预训练的ResNet50模型
model = torchvision.models.vit_b_32(pretrained=True)
model.eval()

# 2. 定义图像预处理步骤（调整为224x224尺寸）
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# 图像预处理函数
def preprocess_image(image):
    input_tensor = preprocess(image)
    input_batch = input_tensor.unsqueeze(0)  # 添加批次维度
    return input_batch

# 加载top N tokens的函数
def load_top_tokens_baseline(input_csv, top_n):
    tokens = []
    with open(input_csv, 'r', encoding='utf-8') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            token = int(row['Token'])
            count = int(row['Count'])
            tokens.append((token, count))
    
    tokens.sort(key=lambda x: x[1], reverse=True)
    top_n_tokens = [token for token, _ in tokens[:top_n]]
    
    return top_n_tokens

# 加载token列表
def load_top_tokens(csv_path, top_n, token_number):
    target_token_list = []
    with open(csv_path, 'r') as csvfile:
        reader = csv.reader(csvfile)
        in_top_n_section = False
        current_row = 0

        for row in reader:
            if f"Top {top_n} Tokens" in row:
                in_top_n_section = True
                next(reader)
                current_row = 0
                continue

            if in_top_n_section:
                if "Top" in row[0] or current_row == token_number:
                    break
                token = int(row[0])
                target_token_list.append(token)
                current_row += 1

    return target_token_list

# 定义路径和目标label
target_label = 83
csv_path = f"/data2/ty45972_data2/taming-transformers/codebook_explanation_classification/results/Explanation/generated_data/label/Net1/label_activation_statistics/label_{target_label}.csv"
test_csv = f"/data2/ty45972_data2/taming-transformers/codebook_explanation_classification/datasets/VQGAN_16384_generated_new/test_embeddings.csv"
image_base_path = '/data2/ty45972_data2/taming-transformers/datasets/imagenet_VQGAN_generated/'
baseline_path = f"/data2/ty45972_data2/taming-transformers/codebook_explanation_classification/results/Explanation/baseline_statistics/label_{target_label}.csv"

# 加载token字典
with open('/data2/ty45972_data2/taming-transformers/codebook_explanation_classification/datasets/VQGAN_16384_generated_new/test_token_indices.pkl', 'rb') as f:
    token_dict = pickle.load(f)

top_n = 20  # 查找前N个token
token_num = 50  # 查找第1行的token及其文件列表

# 加载tokens
target_token_list = load_top_tokens(csv_path, top_n, token_num)
target_token_list_baseline = load_top_tokens_baseline(baseline_path, token_num)
print(f"baseline token list is {target_token_list_baseline}")

if target_token_list:
    print(f"target token list is {len(target_token_list)}")
else:
    print("Cannot find the specific token")

# 从test CSV加载文件列表
npy_file_list = []
with open(test_csv, 'r', encoding='utf-8') as csvfile:
    reader = csv.reader(csvfile)
    next(reader)
    for row in reader:
        filename = row[0]
        label = int(row[1])
        if label == target_label:
            npy_file_list.append(filename)
print(f"npy_file_list is {len(npy_file_list)}")

# 定义图像网格
grid_size = 16
image_size = 256
patch_size = image_size // grid_size

# mask和对比logits函数
def visualize_token_on_image_and_compare(npy_filename, token_dict, target_token_list):
    subfolder, image_name = npy_filename.split('_')
    image_name = image_name.replace('.npy', '.png')
    image_path = os.path.join(image_base_path, subfolder, image_name)
    
    print(f"Image path is {image_path}")
    if not os.path.exists(image_path):
        print(f"Image {image_path} does not exist.")
        return
    
    image = Image.open(image_path)
    token_list = token_dict.get(npy_filename)
    
    if token_list is None:
        print(f"No token list found for {npy_filename}.")
        return
    
    token_positions = [i for i, token in enumerate(token_list) if token in target_token_list]

    # 原始图像预测logits
    input_batch = preprocess_image(image).to(device)
    with torch.no_grad():
        original_output = model(input_batch)
    original_logits = original_output[0, target_label].item()
    print(f"Original logits for target label: {original_logits}")
    
    # Mask token区域并预测logits
    draw = ImageDraw.Draw(image)
    for token_position in token_positions:
        row = token_position // grid_size
        col = token_position % grid_size
        left = col * patch_size
        upper = row * patch_size
        right = left + patch_size
        lower = upper + patch_size
        draw.rectangle([left, upper, right, lower], fill=(0, 0, 0))
    
    masked_batch = preprocess_image(image).to(device)
    with torch.no_grad():
        masked_output = model(masked_batch)
    masked_logits = masked_output[0, target_label].item()
    print(f"Masked logits for target label: {masked_logits}")

    # 对比logits差异
    logits_diff = original_logits - masked_logits
    print(f"Difference in logits for target label: {logits_diff}")
    print("=" * 50)

# 遍历所有文件，进行可视化和logits对比
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for i, npy_file in enumerate(npy_file_list):
    visualize_token_on_image_and_compare(npy_file, token_dict, target_token_list)
    if i > 10:
        break




baseline token list is [4815, 10227, 11456, 11718, 12618, 7749, 948, 13670, 13291, 7504, 6172, 7467, 14447, 14703, 13909, 6965, 9470, 6523, 5772, 4493, 12159, 11004, 5722, 6805, 601, 1304, 10627, 10278, 12016, 11196, 11147, 8867, 15523, 14247, 12623, 6628, 774, 11591, 6783, 3932, 7771, 4945, 14037, 2528, 14975, 3812, 2127, 6690, 6386, 2255]
target token list is 50
npy_file_list is 50
Image path is /data2/ty45972_data2/taming-transformers/datasets/imagenet_VQGAN_generated/83/001300.png
Original logits for target label: 8.720815658569336
Masked logits for target label: 6.315000057220459
Difference in logits for target label: 2.405815601348877
Image path is /data2/ty45972_data2/taming-transformers/datasets/imagenet_VQGAN_generated/83/001301.png
Original logits for target label: 8.419610977172852
Masked logits for target label: 3.4305028915405273
Difference in logits for target label: 4.989108085632324
Image path is /data2/ty45972_data2/taming-transformers/datasets/imagenet_VQGAN_generated