In [1]:
import os
from PIL import Image
import json
from torchvision import transforms
import torch
from sklearn.metrics import precision_recall_fscore_support
from tqdm import tqdm
from transformers import ViTForImageClassification
import numpy as np


def write_to_txt(filepath, content):
    with open(filepath, "w") as file:
        file.write(content)

def load_images_from_directory(root_path: str):
    dataset = []
    for label_folder in os.listdir(root_path):
        parts = label_folder.split('_')
        image_id = '_'.join(parts[:-1]) 
        image_label = parts[-1]  
        dataset.append((image_id, image_label))
    return dataset


def calculate_prediction_changes(original_scores, masked_scores, top_n_indices):
    idx = top_n_indices[0]
    original_score = original_scores[idx]
    masked_score = masked_scores[idx]
    change = max(0, original_score - masked_score)
    if original_score > 0:
        percentage = change / original_score
    elif original_score < 0:
        percentage = change / abs(original_score)
    else:
        percentage = 0
    return change, percentage

current_dir = "/home/workstation/code/XAImethods/CAIN"
dataset_path = f"{current_dir}/results/imagenet/val_images10k_attack/defocus_blur/2_vit/google/vit-large-patch32-384/HiResCAM"
dataset = load_images_from_directory(dataset_path)

with open(f"{current_dir}/imagenet/imagenet_class_index.json", "r") as f:
    imagenet_class_index = json.load(f)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
label_to_index_description = {v[0]: (k, v[1]) for k, v in imagenet_class_index.items()}
model = ViTForImageClassification.from_pretrained('google/vit-large-patch32-384').to(device)
model.eval()

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform = transforms.Compose([transforms.Resize((384, 384)), transforms.ToTensor(), normalize])

def ensure_rgb(img):
    if img.mode != 'RGB':
        return img.convert('RGB')
    return img

# ... [其它导入和函数定义]

true_labels = []
predicted_labels_original = []
predicted_labels_masked = []
changes = []
percentages = []
for image_id, label in tqdm(dataset):
    original_img_path = os.path.join(dataset_path, image_id + "_" + label, 'original.jpg')
    masked_img_path = os.path.join(dataset_path, image_id + "_" + label, 'masked_image.jpg')

    original_img = ensure_rgb(Image.open(original_img_path))
    original_img_tensor = transform(original_img).to(device)

    masked_img = ensure_rgb(Image.open(masked_img_path))
    masked_img_tensor = transform(masked_img).to(device)

    with torch.no_grad():
        original_logits = model(original_img_tensor.unsqueeze(0)).logits
        original_scores = torch.nn.functional.softmax(original_logits, dim=1).squeeze().tolist()
        predicted_label_original = torch.argmax(original_logits, dim=1).item()

        masked_logits = model(masked_img_tensor.unsqueeze(0)).logits
        masked_scores = torch.nn.functional.softmax(masked_logits, dim=1).squeeze().tolist()
        predicted_label_masked = torch.argmax(masked_logits, dim=1).item()
    
    top_n_indices = [predicted_label_original]
    change, percentage = calculate_prediction_changes(original_scores, masked_scores, top_n_indices)
    changes.append(change)
    percentages.append(percentage)

    

    index_str, _ = label_to_index_description.get(label, (None, None))
    if index_str is None:
        continue
    true_label = int(index_str)
    true_labels.append(true_label)
    predicted_labels_original.append(predicted_label_original)
    predicted_labels_masked.append(predicted_label_masked)

precision_original, recall_original, f1_original, _ = precision_recall_fscore_support(true_labels, predicted_labels_original, average='macro')
precision_masked, recall_masked, f1_masked, _ = precision_recall_fscore_support(true_labels, predicted_labels_masked, average='macro')

mean_changes = np.mean(changes)
mean_percentages = np.mean(percentages)
q25_changes = np.percentile(changes, 25)
median_changes = np.median(changes)
q75_changes = np.percentile(changes, 75)
q25_percentages = np.percentile(percentages, 25)
median_percentages = np.median(percentages)
q75_percentages = np.percentile(percentages, 75)

results = [
    f"--- Original Dataset ---\n",
    f"Precision: {precision_original:.4f}\n",
    f"Recall: {recall_original:.4f}\n",
    f"F1 Score: {f1_original:.4f}\n",
    f"--- Masked Dataset ---\n",
    f"Precision: {precision_masked:.4f}\n",
    f"Recall: {recall_masked:.4f}\n",
    f"F1 Score: {f1_masked:.4f}\n",
    f"\n--- Prediction Changes ---\n",
    f"Average Change: {mean_changes:.4f}\n",
    f"Average Percentage Change: {mean_percentages:.4f}\n",
    f"Q25 Change: {q25_changes:.4f}\n",
    f"Median Change: {median_changes:.4f}\n",
    f"Q75 Change: {q75_changes:.4f}\n",
    f"Q25 Percentage Change: {q25_percentages:.4f}\n",
    f"Median Percentage Change: {median_percentages:.4f}\n",
    f"Q75 Percentage Change: {q75_percentages:.4f}\n"
]
results
results_path = os.path.join(dataset_path + "_vit", "results.txt")
write_to_txt(results_path, ''.join(results))

  from .autonotebook import tqdm as notebook_tqdm
100%|██████████| 912/912 [00:45<00:00, 19.96it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
