# In [6]:
import cv2
import torch
from transformers import ViTImageProcessor, ViTForImageClassification

# Hugging Face Hub identifiers for the preprocessing config and the classifier.
PROCESSOR_CHECKPOINT = 'google/vit-base-patch16-224-in21k'
MODEL_CHECKPOINT = 'ChrisGuarino/cat_ds'  # Replace with your model

# Load the image processor and the classification model once, up front.
processor = ViTImageProcessor.from_pretrained(PROCESSOR_CHECKPOINT)
model = ViTForImageClassification.from_pretrained(MODEL_CHECKPOINT)
model.eval()  # put the model in evaluation mode

# Human-readable names, looked up by the predicted class index.
class_labels = [
    'Prim',
    'Rupe',
]  # Replace with your actual labels

# Start the webcam (device index 0)
cap = cv2.VideoCapture(0)

# Everything that touches the camera runs inside try/finally so the capture
# device and any OpenCV windows are released even if preprocessing, inference
# or the label lookup raises part-way through the loop.
try:
    if not cap.isOpened():
        # Without this check an unavailable camera would only surface as a
        # failed first read, and the script would exit without any message.
        raise RuntimeError("Could not open webcam (device index 0)")

    while True:
        ret, frame = cap.read()
        if not ret:
            # No frame was returned: stop the loop.
            break

        # Preprocess the frame: resize to 224x224 and convert BGR -> RGB
        frame_resized = cv2.resize(frame, (224, 224))
        frame_rgb = cv2.cvtColor(frame_resized, cv2.COLOR_BGR2RGB)
        inputs = processor(images=frame_rgb, return_tensors="pt")

        # Get predictions (gradient tracking disabled for inference)
        with torch.no_grad():
            predictions = model(**inputs).logits

        # Convert predictions to probabilities and get the highest probability class
        probabilities = torch.nn.functional.softmax(predictions, dim=-1)
        predicted_class_idx = probabilities.argmax(-1).item()
        predicted_class = class_labels[predicted_class_idx]

        # Display the prediction on the frame
        cv2.putText(frame_resized, predicted_class, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

        # Display the resulting frame
        cv2.imshow('Frame', frame_resized)

        # Break the loop with 'q'
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Release the capture and close windows on every exit path
    cap.release()
    cv2.destroyAllWindows()
