In [1]:
import os
from PIL import Image
import json
from torchvision import transforms
import torch
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
from tqdm import tqdm
from transformers import SegformerForImageClassification
from functools import partial
import shutil

def load_images_from_directory(root_path: str, extensions=('.png', '.jpg', '.jpeg')):
    """Load every image stored one level below *root_path*, labelled by sub-directory.

    The expected layout is ``root_path/<label>/<image file>``: each immediate
    sub-directory name is used as the label of the images it contains. Regular
    files directly inside ``root_path`` and files whose name does not end with
    one of *extensions* (case-insensitive) are ignored.

    Args:
        root_path: Directory whose sub-directories hold the images.
        extensions: Filename suffixes (compared lower-cased) treated as images.

    Returns:
        A list of ``(image, label, filename)`` tuples, ordered by label and then
        by filename so repeated runs visit the images in the same order.
    """
    dataset = []
    for label in sorted(os.listdir(root_path)):
        label_path = os.path.join(root_path, label)
        if not os.path.isdir(label_path):
            continue
        for image_file in sorted(os.listdir(label_path)):
            if not image_file.lower().endswith(tuple(extensions)):
                continue
            image_path = os.path.join(label_path, image_file)
            # Image.open is lazy and keeps the file handle open until the pixel
            # data is read. Decode inside a context manager so the handle is
            # closed right away instead of holding one descriptor per image.
            with Image.open(image_path) as img:
                img.load()
            dataset.append((img, label, image_file))
    return dataset

current_dir = "/home/workstation/code/XAImethods/CAIN"
dataset_path = f"{current_dir}/imagenet/val_images5k"
dataset = load_images_from_directory(dataset_path)

# Load the ImageNet class index. Judging by how it is unpacked below, each value
# is a [label, description] pair keyed by the class-index string; the label is
# matched against the dataset's directory names (presumably WordNet synset ids
# such as "n01440764" -- confirm against the JSON file).
with open(f"{current_dir}/imagenet/imagenet_class_index.json", "r", encoding="utf-8") as f:
    imagenet_class_index = json.load(f)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Inverted index: directory label -> (class-index string, description).
label_to_index_description = {v[0]: (k, v[1]) for k, v in imagenet_class_index.items()}

# Initialize the Segformer (MiT-b0) image classifier and switch it to inference mode.
model = SegformerForImageClassification.from_pretrained('nvidia/mit-b0').to(device)
model.eval()

# Standard ImageNet channel mean/std normalization.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

transform = transforms.Compose([
    # Resize to a fixed (height, width) of (480, 640); this ignores the source
    # aspect ratio. NOTE(review): confirm this resolution matches what the
    # checkpoint was evaluated with.
    transforms.Resize((480, 640)),
    transforms.ToTensor(),
    normalize
])

def ensure_rgb(img):
    """Return *img* in 'RGB' mode.

    An image already in 'RGB' mode is returned as the same object; any other
    mode is converted via ``img.convert('RGB')``.
    """
    return img if img.mode == 'RGB' else img.convert('RGB')

true_labels = []
predicted_labels = []

# Correctly classified images are copied into a sibling directory named after
# the dataset directory with "_mit" appended, keeping the per-label layout.
target_dir = dataset_path + "_mit"
os.makedirs(target_dir, exist_ok=True)

for img, label, filename in tqdm(dataset):
    # Resolve the ground-truth index first: images whose directory label is not
    # in the class index are skipped anyway, so don't spend a forward pass on them.
    index_str, _ = label_to_index_description.get(label, (None, None))
    if index_str is None:
        continue
    true_label = int(index_str)

    img = ensure_rgb(img)
    img_tensor = transform(img).to(device)

    # Model prediction (batch of one).
    with torch.no_grad():
        logits = model(img_tensor.unsqueeze(0)).logits
        # NOTE(review): assumes the model's output index order matches the keys
        # of imagenet_class_index.json -- confirm against model.config.id2label.
        predicted_label = torch.argmax(logits, dim=1).item()

    true_labels.append(true_label)
    predicted_labels.append(predicted_label)

    # If prediction is correct, copy the image to the target directory.
    if true_label == predicted_label:
        source_path = os.path.join(dataset_path, label, filename)
        target_label_dir = os.path.join(target_dir, label)
        os.makedirs(target_label_dir, exist_ok=True)
        shutil.copy(source_path, target_label_dir)

from collections import Counter

if not true_labels:
    print("No images with a known ImageNet label were evaluated.")
else:
    # For single-label multi-class predictions, micro-averaged precision,
    # recall and F1 are all equal to overall accuracy.
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, predicted_labels, average='micro')
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")

    # One-vs-rest confusion counts for every class that occurs in the ground
    # truth. Counting in a single pass avoids one full scan per class:
    #   TP = true == c and pred == c
    #   FP = pred == c minus TP
    #   FN = true == c minus TP
    #   TN = everything else
    n_samples = len(true_labels)
    true_counts = Counter(true_labels)
    pred_counts = Counter(predicted_labels)
    hit_counts = Counter(t for t, p in zip(true_labels, predicted_labels) if t == p)
    for target_class in sorted(true_counts):
        tp = hit_counts[target_class]
        fp = pred_counts[target_class] - tp
        fn = true_counts[target_class] - tp
        tn = n_samples - tp - fp - fn
        print(f"For class {target_class}: TP={tp}, FP={fp}, FN={fn}, TN={tn}")


  from .autonotebook import tqdm as notebook_tqdm
 58%|█████▊    | 2876/5000 [00:54<00:40, 52.35it/s]


KeyboardInterrupt: 