In [1]:
import cv2
import torch
from transformers import ViTImageProcessor, ViTForImageClassification

# Initialize the image processor and model
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
model = ViTForImageClassification.from_pretrained('ChrisGuarino/model')  # Replace with your model
model.eval()

# Define your class labels
class_labels = ['Prim', 'Rupe', 'No Cat']  # Replace with your actual labels

# Start the webcam
cap = cv2.VideoCapture(0)

#Confidence Threshold
confidence_threshold = 0.5  # Define a threshold

while True:
    ret, frame = cap.read()

    if ret:
        # Preprocess the frame
        frame_resized = cv2.resize(frame, (224, 224))
        frame_rgb = cv2.cvtColor(frame_resized, cv2.COLOR_BGR2RGB)
        inputs = processor(images=frame_rgb, return_tensors="pt")

        # Get predictions
        with torch.no_grad():
            predictions = model(**inputs).logits

            # Convert predictions to probabilities and get the highest probability class
            probabilities = torch.nn.functional.softmax(predictions, dim=-1)
            print(predictions)
            confidences, predicted_class_idx = torch.max(probabilities, dim=-1)
            predicted_class = class_labels[predicted_class_idx]#Something with +1 to shift the labels if we add a No Cat label

        # Check if confidence is above the threshold
        if confidences.item() < confidence_threshold:
            label = 'No Cat'
            confidence = 0
        else:
            label = class_labels[predicted_class_idx.item()]  # +1 to account for 'No Cat'
            confidence = confidences.item()

        # Prepare the display text
        display_text = f'{label} ({confidence:.2f})'

        # Display the prediction on the frame
        cv2.putText(frame_resized, display_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

        # Display the resulting frame
        cv2.imshow('Frame', frame_resized)

        # Break the loop with 'q'
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    else:
        break

# Release the capture when done
cap.release()
cv2.destroyAllWindows()


  from .autonotebook import tqdm as notebook_tqdm
config.json: 100%|██████████| 761/761 [00:00<00:00, 93.5kB/s]
model.safetensors: 100%|██████████| 343M/343M [00:10<00:00, 32.8MB/s] 


tensor([[-1.3676, -0.2928,  1.7800]])
tensor([[-1.2919, -0.2023,  1.6204]])
tensor([[-1.2654, -0.2423,  1.6587]])
tensor([[-1.3328, -0.3608,  1.7941]])
tensor([[-1.3539, -0.3331,  1.8035]])
tensor([[-1.3420, -0.3464,  1.7897]])
tensor([[-1.3233, -0.4154,  1.8518]])
tensor([[-1.3586, -0.4452,  1.8693]])
tensor([[-1.3516, -0.3096,  1.7404]])
tensor([[-1.3409, -0.3579,  1.7930]])
tensor([[-1.3797, -0.3506,  1.8389]])
tensor([[-1.3515, -0.3042,  1.7406]])
tensor([[-1.3506, -0.4442,  1.8725]])
tensor([[-1.3248, -0.3814,  1.8247]])
tensor([[-1.3516, -0.3381,  1.8197]])
tensor([[-1.3341, -0.3179,  1.7642]])
tensor([[-1.3574, -0.3848,  1.8454]])
tensor([[-1.3546, -0.3002,  1.7646]])
tensor([[-1.3626, -0.3976,  1.8316]])
tensor([[-1.3495, -0.2802,  1.7457]])
tensor([[-1.3249, -0.3830,  1.8313]])
tensor([[-1.3247, -0.4240,  1.8353]])
tensor([[-1.3149, -0.3744,  1.8177]])
tensor([[-1.3821, -0.2951,  1.8141]])
tensor([[-1.3668, -0.3268,  1.7877]])
tensor([[-1.3299, -0.2681,  1.7434]])
tensor([[-1.