In [1]:
import torch
import clip
from PIL import Image
import os

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [3]:
model, preprocess = clip.load("ViT-B/32", device=device)

In [4]:
# Define dataset path
training_damage_path = "../dataset/training/00-damage"
training_whole_path = "../dataset/training/01-whole"

validation_damage_path = "../dataset/validation/00-damage"
validation_whole_path = "../dataset/validation/01-whole"

In [5]:
def get_images_from_folder(folder_path):
    images = []
    for img_file in os.listdir(folder_path):
        if img_file.endswith(('.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG')):
            image = preprocess(Image.open(os.path.join(folder_path, img_file))).unsqueeze(0).to(device)
            images.append(image)
    return torch.cat(images, dim=0)

def evaluate(images, label, text_descriptions):
    with torch.no_grad():
        image_features = model.encode_image(images)
        text_features = model.encode_text(clip.tokenize(text_descriptions).to(device))
        similarities = image_features @ text_features.T
        predictions = similarities.argmax(dim=1)
        correct_predictions = (predictions == label).sum().item()
    return correct_predictions, len(images)

In [6]:
# Load pictures
training_damage_images = get_images_from_folder(training_damage_path)
training_whole_images = get_images_from_folder(training_whole_path)

validation_damage_images = get_images_from_folder(validation_damage_path)
validation_whole_images = get_images_from_folder(validation_whole_path)

damage_images = torch.cat((training_damage_images, validation_damage_images), dim=0)
whole_images = torch.cat((training_whole_images, validation_whole_images), dim=0)

# Text description
text_descriptions = ["damaged car", "whole car"]

# Evaluate
correct_damage, total_damage = evaluate(damage_images, 0, text_descriptions)
correct_whole, total_whole = evaluate(whole_images, 1, text_descriptions)
total_correct = correct_damage + correct_whole
total_images = total_damage + total_whole

# Calcuate
accuracy_damage = correct_damage / total_damage
accuracy_whole = correct_whole / total_whole
overall_accuracy = total_correct / total_images

print(f"Accuracy for damage: {accuracy_damage*100:.2f}%")
print(f"Accuracy for whole: {accuracy_whole*100:.2f}%")
print(f"Overall accuracy: {overall_accuracy*100:.2f}%")

Accuracy for damage: 64.43%
Accuracy for whole: 99.91%
Overall accuracy: 82.17%
