# In [None]:
import torch
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image
import os
import pandas as pd

# -------------------------
# CONFIG
# -------------------------
# The defaults point at the original author's workstation. Set VIT_MODEL_PATH /
# VIT_TEST_DIR to run elsewhere (e.g. on a CI runner) without editing the file.
MODEL_PATH = os.environ.get(
    "VIT_MODEL_PATH",
    r"C:\Users\Ahmed Pasha\Desktop\vit_cicd\models\best_vit_model.pth",
)
IMAGE_FOLDER = os.environ.get(
    "VIT_TEST_DIR",
    r"C:\Users\Ahmed Pasha\Desktop\vit_cicd\data_set\test",
)
# Size of the classification head; must match the head stored in MODEL_PATH
# or the strict state-dict load will fail.
NUM_CLASSES = 6

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# -------------------------
# LOAD MODEL & PROCESSOR
# -------------------------
# Preprocessing comes from the same hub checkpoint as the backbone, so images
# are prepared the way that backbone expects.
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

# Build the architecture with a NUM_CLASSES-way head. If the hub checkpoint's
# own head has a different shape, ignore_mismatched_sizes lets from_pretrained
# re-initialise it instead of failing; the strict load below then overwrites
# every weight with the fine-tuned values.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    ignore_mismatched_sizes=True,
    num_labels=NUM_CLASSES,
)

# NOTE(review): torch.load unpickles MODEL_PATH; only point it at checkpoints
# you trust. map_location remaps the stored tensors onto `device`, so a
# GPU-saved checkpoint also loads on a CPU-only machine.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))

# Move to the target device and switch to inference mode (e.g. no dropout);
# both calls return the same module instance.
model = model.to(device).eval()


# Index -> class-name lookup for the classifier output.
# NOTE(review): assumes training indexed classes in sorted folder-name order
# (as torchvision's ImageFolder does) — confirm against the training script.
# Only sub-directories count as classes: a stray file such as desktop.ini or
# .DS_Store would otherwise shift every index after it.
class_names = sorted(
    entry
    for entry in os.listdir(IMAGE_FOLDER)
    if os.path.isdir(os.path.join(IMAGE_FOLDER, entry))
)
# A missing or extra folder would silently mislabel every prediction, so stop
# early instead of writing a wrong CSV.
if len(class_names) != NUM_CLASSES:
    raise ValueError(
        f"Expected {NUM_CLASSES} class folders in {IMAGE_FOLDER!r}, "
        f"found {len(class_names)}: {class_names}"
    )

# -------------------------
# PREDICT FOLDER
# -------------------------
# Layout read here: IMAGE_FOLDER/<class_dir>/<image files>. Each image's
# sub-directory name is recorded as its actual (ground-truth) class.
results = []

# Sorted traversal makes the printed log and the CSV row order reproducible;
# raw os.listdir order is arbitrary and platform-dependent.
for class_dir in sorted(os.listdir(IMAGE_FOLDER)):
    class_path = os.path.join(IMAGE_FOLDER, class_dir)

    # Stray files at the top level are not classes; skip them.
    if not os.path.isdir(class_path):
        continue

    for img_name in sorted(os.listdir(class_path)):
        img_path = os.path.join(class_path, img_name)

        # Guard only the file read/decode: unreadable or non-image files are
        # reported and skipped. Model errors (which would repeat for every
        # image) are no longer swallowed.
        try:
            # `with` closes the file handle; convert() loads the pixels and
            # returns an independent copy, so `image` stays usable afterwards.
            with Image.open(img_path) as img:
                image = img.convert("RGB")
        except Exception as e:
            print(f"❌ Error processing {img_path}: {e}")
            continue

        inputs = processor(images=image, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            logits = model(**inputs).logits
        pred_idx = torch.argmax(logits, dim=1).item()
        predicted = class_names[pred_idx]

        results.append({
            "image": img_name,
            "actual_class": class_dir,
            "predicted_class": predicted,
        })

        print(f"{img_name} → {predicted}")

# -------------------------
# SAVE RESULTS
# -------------------------
# Name the columns explicitly so the CSV still gets a header row and a fixed
# column order even when no image could be processed (empty `results`).
df = pd.DataFrame(results, columns=["image", "actual_class", "predicted_class"])
df.to_csv("folder_predictions.csv", index=False)

print("\n✅ Predictions saved to folder_predictions.csv")
