In [16]:
import cv2
import torch
import torchvision.transforms as transforms
from torchvision.models import resnet50
from PIL import Image
import matplotlib.pyplot as plt
import time

In [2]:
model = torch.load('models/ViT_flowers_Clf_10_epoch.pth')
model.eval()

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

In [3]:
# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

In [4]:
# Define the classes
class_names = ["Daisy", "Dandelion", "Rose", "Sunflower", "Tulip"]

In [5]:
# Define the image transformation
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [24]:
# Open a connection to the camera (0 is usually the default camera)
cap = cv2.VideoCapture(0)
while True:
    # Capture frame-by-frame
    ret, frame = cap.read()

    # Draw a box (you can adjust the coordinates based on your needs)
    cv2.rectangle(frame, (100, 100), (324, 324), (255, 0, 0), 2)

    # Get the content within the box
    start_time = time.time()
    box_content = frame[100:324, 100:324]
    
    # Convert the box content to PIL Image
    box_content_pil = Image.fromarray(cv2.cvtColor(box_content, cv2.COLOR_BGR2RGB))

    # Apply the transformation
    transformed_box_content = image_transform(box_content_pil).unsqueeze(dim=0).to(device)

    with torch.inference_mode():
        prediction = model(transformed_box_content)

    # Convert logits -> prediction probabilities (using torch.softmax() for multi-class classification)
    predicted_probabilities = torch.softmax(prediction, dim=1)

    # Convert prediction probabilities -> prediction labels
    predicted_label = torch.argmax(predicted_probabilities, dim=1).item()

    predicted_class = class_names[predicted_label]
    end_time = time.time()
    frame_processing_time = end_time - start_time

    predicted_prob = predicted_probabilities[0, predicted_label].item()

    if predicted_prob > 0.75 and predicted_class != "No Flower":
        cv2.putText(frame, f'Pred: {predicted_class} | Prob: {predicted_prob:.3f}', (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2, cv2.LINE_AA)
        inference_time_text = f"Inference Time: {frame_processing_time:.4f} seconds"
        cv2.putText(frame, inference_time_text, (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2, cv2.LINE_AA)

    cv2.imshow('Frame', frame)

    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

# Release the capture
cap.release()
cv2.destroyAllWindows()