In [1]:
import cv2
import torch
from transformers import ViTImageProcessor, ViTForImageClassification

# Preprocessing comes from the base ViT checkpoint; the classification weights
# are loaded separately (replace the model id with your own) and the model is
# switched to evaluation mode right away.
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
model = ViTForImageClassification.from_pretrained('ChrisGuarino/model').eval()

# Display names looked up by predicted class index (replace with your actual labels).
# NOTE(review): the order is assumed to match the model's output indices -- confirm.
class_labels = ['Prim', 'Rupe', 'No Cat']

# Top-class probabilities below this value are reported as 'No Cat'.
confidence_threshold = 0.8

# Open the capture device at index 0.
cap = cv2.VideoCapture(0)

# Live classification loop: read webcam frames, classify each one, and overlay
# the predicted label until 'q' is pressed or the camera stops returning frames.
# The try/finally guarantees the camera and the preview window are released even
# if preprocessing or inference raises (or the kernel is interrupted) mid-loop.
try:
    if not cap.isOpened():
        # Fail loudly instead of silently falling straight through the loop.
        raise RuntimeError('Could not open the webcam')

    while True:
        ret, frame = cap.read()
        if not ret:
            # No frame was returned, so stop.
            break

        # Preprocess the frame: resize for display, then convert OpenCV's BGR
        # channel order to RGB before handing it to the image processor.
        frame_resized = cv2.resize(frame, (400, 400))
        frame_rgb = cv2.cvtColor(frame_resized, cv2.COLOR_BGR2RGB)
        inputs = processor(images=frame_rgb, return_tensors="pt")

        # Get predictions (gradients are not needed for inference)
        with torch.no_grad():
            predictions = model(**inputs).logits

            # Convert logits to probabilities and pick the most likely class
            probabilities = torch.nn.functional.softmax(predictions, dim=-1)
            confidences, predicted_class_idx = torch.max(probabilities, dim=-1)
        # Per-frame debug output of the class probabilities
        print(probabilities)

        # Fall back to 'No Cat' when the top class is not confident enough
        top_confidence = confidences.item()
        if top_confidence < confidence_threshold:
            label = 'No Cat'
            confidence = 0
        else:
            label = class_labels[predicted_class_idx.item()]
            confidence = top_confidence

        # Prepare the display text
        display_text = f'{label} ({confidence:.2f})'

        # Draw the prediction on the (resized, still BGR) frame
        cv2.putText(frame_resized, display_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

        # Display the resulting frame
        cv2.imshow('Cat or Not', frame_resized)

        # Break the loop with 'q'
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Release the capture and close the preview window no matter how the loop ended
    cap.release()
    cv2.destroyAllWindows()


tensor([[0.0318, 0.0238, 0.9444]])
tensor([[0.0317, 0.0238, 0.9445]])
tensor([[0.0339, 0.0242, 0.9419]])
tensor([[0.0329, 0.0241, 0.9431]])
tensor([[0.0321, 0.0239, 0.9440]])
tensor([[0.0312, 0.0234, 0.9454]])
tensor([[0.0336, 0.0244, 0.9421]])
tensor([[0.0338, 0.0248, 0.9414]])
tensor([[0.0331, 0.0235, 0.9434]])
tensor([[0.0320, 0.0238, 0.9442]])
tensor([[0.0306, 0.0230, 0.9463]])
tensor([[0.0336, 0.0236, 0.9428]])
tensor([[0.0323, 0.0227, 0.9450]])
tensor([[0.0346, 0.0222, 0.9432]])
tensor([[0.0340, 0.0232, 0.9428]])
tensor([[0.0346, 0.0271, 0.9383]])
tensor([[0.0328, 0.0240, 0.9432]])
tensor([[0.0339, 0.0233, 0.9428]])
tensor([[0.0334, 0.0234, 0.9432]])
tensor([[0.0337, 0.0238, 0.9425]])
tensor([[0.0339, 0.0248, 0.9413]])
tensor([[0.0348, 0.0245, 0.9407]])
tensor([[0.0338, 0.0242, 0.9420]])
tensor([[0.0341, 0.0229, 0.9429]])
tensor([[0.0328, 0.0228, 0.9444]])
tensor([[0.0269, 0.0203, 0.9529]])
tensor([[0.0298, 0.0215, 0.9487]])
tensor([[0.0309, 0.0214, 0.9477]])
tensor([[0.0367, 0.0