In [3]:
import cv2
import torch
from transformers import ViTImageProcessor, ViTForImageClassification

# Live webcam cat classifier: grabs frames from camera 0, classifies each one
# with a fine-tuned ViT model, and overlays "<label> (<confidence>)" on the
# displayed frame. Press 'q' in the window to quit.

# Initialize the image processor and model.
# NOTE(review): preprocessing comes from the base ViT checkpoint while the
# weights come from 'ChrisGuarino/cat_ds' — confirm the fine-tuned model was
# trained with the same preprocessing configuration.
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
model = ViTForImageClassification.from_pretrained('ChrisGuarino/cat_ds')  # Replace with your model
model.eval()  # inference mode (disables dropout etc.)

# Display names indexed by the model's output class index. 'No Cat' is also the
# fallback label shown when the top probability is below the threshold.
class_labels = ['Prim', 'Rupe', 'No Cat']  # Replace with your actual labels

# Fail fast if the model can emit a class index that has no display label,
# instead of hitting an IndexError in the middle of the capture loop.
if model.config.num_labels > len(class_labels):
    raise ValueError(
        f'model has {model.config.num_labels} output classes but only '
        f'{len(class_labels)} labels are defined in class_labels'
    )

# Confidence threshold: predictions whose top softmax probability is below this
# are displayed as 'No Cat'.
# NOTE(review): the logged logits have 2 columns, i.e. a 2-class model. The max
# of a 2-way softmax is always >= 0.5, so a threshold of 0.5 can never trigger
# and 'No Cat' will never be shown. Raise the threshold or add a real
# 'No Cat' class to the model if that fallback is wanted.
confidence_threshold = 0.5

# Print raw logits for every frame (debug aid); set to False to silence.
print_logits = True

# Start the webcam (device 0). Cleanup lives in `finally` so the camera handle
# and windows are released even if something in the loop raises.
cap = cv2.VideoCapture(0)
try:
    if not cap.isOpened():
        raise RuntimeError('Could not open webcam (device 0)')

    while True:
        ret, frame = cap.read()
        if not ret:
            # No frame returned (camera disconnected / stream ended): stop.
            break

        # Preprocess: shrink to 224x224 for display, and convert OpenCV's BGR
        # channel order to RGB before handing the frame to the processor.
        frame_resized = cv2.resize(frame, (224, 224))
        frame_rgb = cv2.cvtColor(frame_resized, cv2.COLOR_BGR2RGB)
        inputs = processor(images=frame_rgb, return_tensors="pt")

        # Get predictions (no autograd bookkeeping needed for inference).
        with torch.no_grad():
            predictions = model(**inputs).logits
        if print_logits:
            print(predictions)

        # Convert logits to probabilities and take the most likely class.
        probabilities = torch.nn.functional.softmax(predictions, dim=-1)
        confidences, predicted_class_idx = torch.max(probabilities, dim=-1)
        top_confidence = confidences.item()
        predicted_class = class_labels[predicted_class_idx.item()]

        # Below-threshold predictions fall back to 'No Cat' with confidence 0.
        if top_confidence < confidence_threshold:
            label = 'No Cat'
            confidence = 0
        else:
            label = predicted_class
            confidence = top_confidence

        # Overlay the prediction on the (BGR) frame and show it.
        display_text = f'{label} ({confidence:.2f})'
        cv2.putText(frame_resized, display_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
        cv2.imshow('Frame', frame_resized)

        # Break the loop with 'q' (mask to the low byte of the key code).
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Release the capture and close windows on every exit path.
    cap.release()
    cv2.destroyAllWindows()


tensor([[ 0.1145, -0.1426]])
tensor([[ 0.3684, -0.0427]])
tensor([[0.3152, 0.0069]])
tensor([[0.2576, 0.0402]])
tensor([[-0.0507,  0.3502]])
tensor([[-0.5480,  0.9041]])
tensor([[-0.8661,  1.2233]])
tensor([[-0.8024,  1.0979]])
tensor([[0.2012, 0.1271]])
tensor([[ 0.3361, -0.0227]])
tensor([[ 0.3417, -0.0142]])
tensor([[-0.2128,  0.4447]])
tensor([[-0.8164,  1.0626]])
tensor([[-0.5232,  0.8070]])
tensor([[-0.6395,  0.9411]])
tensor([[-0.4717,  0.7772]])
tensor([[-0.4196,  0.7348]])
tensor([[-0.2317,  0.5971]])
tensor([[-0.4114,  0.6923]])
tensor([[-0.3210,  0.5954]])
tensor([[-0.0508,  0.3587]])
tensor([[-0.2964,  0.5974]])
tensor([[ 0.4879, -0.2001]])
tensor([[ 0.3666, -0.0075]])
tensor([[0.1581, 0.1284]])
tensor([[ 0.4052, -0.1135]])
tensor([[ 0.4673, -0.2009]])
tensor([[ 0.4838, -0.2158]])
tensor([[ 0.4735, -0.2073]])
tensor([[ 0.4998, -0.2351]])
tensor([[ 0.5029, -0.2121]])
tensor([[ 0.4138, -0.1374]])
tensor([[ 0.3042, -0.0481]])
tensor([[0.2133, 0.0182]])
tensor([[ 0.4257, -0.187