In [26]:
import cv2
import torch
from transformers import ViTImageProcessor, ViTForImageClassification

# --- Classification settings -------------------------------------------------
# Minimum top-class probability required to trust a prediction; anything
# below it is displayed as 'No Cat' by the capture loop.
confidence_threshold = 0.8

# Human-readable names, indexed by the model's predicted class index.
class_labels = ['Prim', 'Rupe', 'No Cat']

# --- Model -------------------------------------------------------------------
# The processor turns raw frames into model inputs; the classifier is switched
# to evaluation mode right away since it is only used for inference here.
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
model = ViTForImageClassification.from_pretrained('ChrisGuarino/model').eval()

# --- Video source ------------------------------------------------------------
# Device index 0 is passed to OpenCV to open the capture device.
cap = cv2.VideoCapture(0)

# Main loop: grab a frame, classify it, overlay the label on the frame and
# show it. Runs until 'q' is pressed or the camera stops returning frames.
# The try/finally guarantees the camera and the window are released even if
# inference or display raises (or the cell is interrupted).
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            # No frame was returned by the capture device; stop.
            break

        # Preprocess the frame: shrink it and convert OpenCV's BGR to RGB.
        # NOTE(review): 244 may be a typo for 224 (the processor checkpoint
        # name ends in "patch16-224") -- confirm before changing, since this
        # frame is also the one that gets displayed.
        frame_resized = cv2.resize(frame, (244, 244))
        frame_rgb = cv2.cvtColor(frame_resized, cv2.COLOR_BGR2RGB)
        inputs = processor(images=frame_rgb, return_tensors="pt")

        # Get predictions without tracking gradients (inference only).
        with torch.no_grad():
            predictions = model(**inputs).logits

            # Convert logits to probabilities and pick the most likely class.
            probabilities = torch.nn.functional.softmax(predictions, dim=-1)
            confidences, predicted_class_idx = torch.max(probabilities, dim=-1)
        print(probabilities)

        # Pull the scalar values out of the tensors once.
        top_confidence = confidences.item()
        top_index = predicted_class_idx.item()

        # Low-confidence predictions are reported as 'No Cat' with a
        # confidence of 0 instead of the model's top class.
        if top_confidence < confidence_threshold:
            label = 'No Cat'
            confidence = 0
        else:
            label = class_labels[top_index]
            confidence = top_confidence

        # Prepare the display text
        display_text = f'{label} ({confidence:.2f})'

        # Draw the prediction onto the frame and show it.
        cv2.putText(frame_resized, display_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
        cv2.imshow('Frame', frame_resized)

        # Break the loop with 'q'
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Release the capture and close the preview window when done.
    cap.release()
    cv2.destroyAllWindows()


tensor([[0.9390, 0.0340, 0.0270]])
tensor([[0.9355, 0.0406, 0.0239]])
tensor([[0.9138, 0.0640, 0.0222]])
tensor([[0.9347, 0.0396, 0.0258]])
tensor([[0.9377, 0.0369, 0.0254]])
tensor([[0.9371, 0.0364, 0.0265]])
tensor([[0.9461, 0.0306, 0.0234]])
tensor([[0.9406, 0.0337, 0.0257]])
tensor([[0.9379, 0.0372, 0.0248]])
tensor([[0.9386, 0.0367, 0.0248]])
tensor([[0.9389, 0.0333, 0.0278]])
tensor([[0.9398, 0.0343, 0.0259]])
tensor([[0.9423, 0.0326, 0.0251]])
tensor([[0.9398, 0.0330, 0.0272]])
tensor([[0.9391, 0.0340, 0.0268]])
tensor([[0.9402, 0.0333, 0.0265]])
tensor([[0.9409, 0.0312, 0.0279]])
tensor([[0.9546, 0.0248, 0.0205]])
tensor([[0.9216, 0.0476, 0.0308]])
tensor([[0.0107, 0.0143, 0.9750]])
tensor([[0.0782, 0.0540, 0.8678]])
tensor([[0.0819, 0.0567, 0.8614]])
tensor([[0.0812, 0.0562, 0.8626]])
tensor([[0.0813, 0.0561, 0.8626]])
tensor([[0.0821, 0.0567, 0.8613]])
tensor([[0.0828, 0.0576, 0.8596]])
tensor([[0.0838, 0.0585, 0.8577]])
tensor([[0.0826, 0.0574, 0.8599]])
tensor([[0.0806, 0.0