In [3]:
import torch
from models import BiTLikeModel

# Load the model checkpoint onto the CPU.
# NOTE(review): torch.load unpickles the file — only load checkpoints from trusted
# sources. The code below assumes 'best.pt' holds a plain state_dict
# (parameter name -> tensor); confirm against how the checkpoint was saved.
checkpoint = torch.load('best.pt', map_location=torch.device('cpu'))

# Initialize the model
model = BiTLikeModel(2)  # Replace with your model class

# Get the model's state_dict
model_state_dict = model.state_dict()

# Keep only checkpoint entries whose name AND shape match the model, so that e.g.
# an fc layer trained for a different number of classes is skipped instead of
# making load_state_dict fail.
filtered_checkpoint = {
    k: v
    for k, v in checkpoint.items()
    if k in model_state_dict and model_state_dict[k].shape == v.shape
}

# Report what was not transferred: a silently empty match would otherwise leave
# the whole network at its random initialization while still "loading successfully".
skipped_keys = sorted(set(checkpoint) - set(filtered_checkpoint))
if skipped_keys:
    print(f"Skipped {len(skipped_keys)} checkpoint entries (first 10): {skipped_keys[:10]}")
if not filtered_checkpoint:
    print("WARNING: no checkpoint weights matched the model; it is randomly initialized")

# Load the compatible weights
model_state_dict.update(filtered_checkpoint)
model.load_state_dict(model_state_dict)

# Re-initialize the fc layer ONLY when the checkpoint did not supply it.
# Resetting it unconditionally would discard a trained classifier head and make
# every prediction made with this model random.
fc_param_keys = ('fc.weight', 'fc.bias')
if not all(key in filtered_checkpoint for key in fc_param_keys):
    torch.nn.init.xavier_uniform_(model.fc.weight)
    torch.nn.init.zeros_(model.fc.bias)
    print("fc layer was not loaded from the checkpoint; re-initialized randomly")

# Set the model to evaluation mode
model.eval()

print("Model loaded successfully!")

Model loaded successfully!


In [4]:
import torch
from torchvision import transforms
from PIL import Image

def predict_image_class(model, image_path):
    """
    Predicts the class of an image using the given model.

    Args:
        model (torch.nn.Module): The trained model for inference. The caller is
            expected to have put it in eval mode.
        image_path (str | os.PathLike | PIL.Image.Image): Path to the image file,
            or an already-loaded PIL image (avoids a round-trip through disk).

    Returns:
        int: Predicted class index (argmax of the model output along dim 1).
    """
    # Define the image preprocessing pipeline
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),  # Resize to match the model's input size
        transforms.ToTensor(),         # Convert image to tensor
        # ImageNet channel mean/std (the torchvision convention)
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Load the image; the context manager closes the file once pixels are decoded.
    if isinstance(image_path, Image.Image):
        image = image_path.convert('RGB')
    else:
        with Image.open(image_path) as img:
            image = img.convert('RGB')
    input_tensor = preprocess(image).unsqueeze(0)  # Add batch dimension

    # Feed the input on the same device as the model's weights (CPU or GPU);
    # a CPU tensor fed to a CUDA model would raise.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        input_tensor = input_tensor.to(first_param.device)

    # Perform inference. Softmax is monotonic, so argmax over the raw logits gives
    # the same class without the extra step (and without float rounding ties).
    with torch.no_grad():
        output = model(input_tensor)

    return torch.argmax(output, dim=1).item()

In [6]:
import cv2
import os
import time

def capture_and_predict(model, capture_interval=5):
    """
    Automatically captures images from the webcam at regular intervals and predicts their class using the given model.

    Each frame is shown in a "Webcam" window, written to a temporary JPEG,
    classified with ``predict_image_class`` and the result printed. The loop stops
    when a frame cannot be read or 'Esc' is pressed. The webcam, the window and the
    temporary file are always released, even if the loop is interrupted
    (e.g. KeyboardInterrupt in a notebook) or prediction raises.

    Args:
        model (torch.nn.Module): The trained model for inference.
        capture_interval (int | float): Time interval (in seconds) between captures.

    Returns:
        None

    Raises:
        RuntimeError: If the webcam cannot be opened.
    """
    import tempfile

    # cv2.waitKey requires an integer millisecond count; 0 would block until a key
    # press, so clamp to at least 1 ms.
    delay_ms = max(1, int(capture_interval * 1000))

    # Initialize the webcam
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        cap.release()
        raise RuntimeError("Could not open webcam")

    temp_image_path = None
    try:
        # A unique temp path avoids clobbering a user's file in the working directory.
        fd, temp_image_path = tempfile.mkstemp(suffix=".jpg")
        os.close(fd)

        print("Capturing images automatically. Press 'Esc' to exit.")

        while True:
            ret, frame = cap.read()
            if not ret:
                print("Failed to grab frame")
                break

            # Display the webcam feed
            cv2.imshow("Webcam", frame)

            # Save the captured frame (overwritten on every iteration)
            cv2.imwrite(temp_image_path, frame)
            print("Image captured!")

            # Predict the class of the captured image
            predicted_class = predict_image_class(model, temp_image_path)
            print(f"Predicted class: {predicted_class}")

            # Wait for the interval or until 'Esc' is pressed. Mask to the low byte:
            # on some platforms waitKey sets extra modifier bits in the key code.
            if (cv2.waitKey(delay_ms) & 0xFF) == 27:
                break
    finally:
        # Release the webcam, close the window and delete the temp image no matter
        # how the loop ended.
        cap.release()
        cv2.destroyAllWindows()
        if temp_image_path is not None and os.path.exists(temp_image_path):
            os.remove(temp_image_path)

In [8]:
# Start the webcam capture/predict loop (press Esc in the "Webcam" window to stop).
capture_and_predict(model, capture_interval=5)

Capturing images automatically. Press 'Esc' to exit.
Image captured!
Predicted class: 0


KeyboardInterrupt: 

In [15]:
# Check if webcam is available: try to open device 0, report the result,
# then release the device so later cells can open it again.
import cv2

cap = cv2.VideoCapture(0)
print("Webcam initialized successfully" if cap.isOpened() else "Could not open webcam")
cap.release()

Webcam initialized successfully
