# In [23]:  (Jupyter notebook cell prompt; kept as a comment so the file is valid Python)
import cv2
import torch
import numpy as np
from cnn_to_mlp import CNNtoMLP  # import your model class
import torch.nn.functional as F

# Parameters
IMG_SIZE = 128  # side length in pixels of the square image fed to the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters passed positionally to CNNtoMLP below. The author labels
# them "best values" -- presumably the outcome of a hyperparameter search, and
# presumably they must match the configuration model_weights.pth was saved
# with (TODO confirm against the training code).
output_dim = 2
num_hidden_layers = 4
neurons_per_layer = 112
dropout_rate = 0.1

# Load your trained model weights
# NOTE(security): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
model = CNNtoMLP(output_dim, num_hidden_layers, neurons_per_layer, dropout_rate)
# map_location remaps the stored tensors onto the device selected above, so a
# checkpoint saved on GPU can still be loaded on a CPU-only machine.
model.load_state_dict(torch.load("model_weights.pth", map_location=device))
model.to(device)
model.eval()  # evaluation mode (e.g. dropout layers become inactive)

# OpenCV face detector
cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
face_cascade = cv2.CascadeClassifier(cascade_path)
# CascadeClassifier does not raise when the XML cannot be loaded; it just
# returns an empty classifier that fails later inside detectMultiScale.
# Fail fast here with a message that names the missing file.
if face_cascade.empty():
    raise OSError(f"Could not load Haar cascade from {cascade_path}")

def preprocess_face(face_img):
    """Turn a cropped face image into a batched float tensor for the model.

    The crop is resized to IMG_SIZE x IMG_SIZE, cast to float32, divided by
    255, reordered from height-width-channel to channel-height-width, given a
    leading batch axis of size 1 and moved to ``device``.

    Args:
        face_img: Image array with a trailing channel axis (the transpose
            below requires exactly three axes). Assumed to hold 8-bit pixel
            values so that the division yields [0, 1] -- TODO confirm.

    Returns:
        A float32 tensor of shape (1, C, IMG_SIZE, IMG_SIZE) on ``device``.

    NOTE(review): the channel order of ``face_img`` is passed through
    unchanged. The caller crops from an OpenCV (BGR) frame; confirm this is
    the order the model was trained with.
    """
    target_shape = (IMG_SIZE, IMG_SIZE)
    scaled = cv2.resize(face_img, target_shape).astype(np.float32) / 255.0

    # PyTorch convolution layers expect channels first: HWC -> CHW.
    channels_first = np.transpose(scaled, (2, 0, 1))

    batch = torch.from_numpy(channels_first).unsqueeze(0)
    return batch.to(device)

# Start webcam
cap = cv2.VideoCapture(0)

# The try/finally guarantees the camera handle and the preview window are
# released even when detection, preprocessing or inference raises mid-stream.
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            # No frame delivered (camera unavailable or stream ended).
            break

        # The Haar cascade is run on a grayscale copy of the BGR frame.
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5)

        for (x, y, w, h) in faces:
            # Crop the detected face from the colour frame, not the gray one.
            face_img = frame[y:y+h, x:x+w]
            face_input = preprocess_face(face_img)

            # Predict mask or no mask
            with torch.no_grad():
                outputs = model(face_input)
                probs = F.softmax(outputs, dim=1)
                prob_mask = probs[0][1].item()
                prob_no_mask = probs[0][0].item()

            # NOTE(review): as written, class index 1 winning yields the
            # "No Mask" label, which reads as the opposite of the variable
            # names above. The runtime mapping is kept unchanged here; confirm
            # it against the label encoding used during training.
            label = "No Mask" if prob_mask > prob_no_mask else "Mask"
            color = (0, 0, 255) if label == "No Mask" else (0, 255, 0)  # BGR: red / green
            # Softmax probability of the winning class, as a percentage.
            confidence = max(prob_mask, prob_no_mask) * 100

            # Draw rectangle and label
            cv2.rectangle(frame, (x, y), (x+w, y+h), color, 2)
            # putText's origin is the bottom-left corner of the text, so for a
            # face touching the top edge y-10 would put the label off-frame;
            # clamp the baseline so the text stays visible.
            text_y = max(y - 10, 20)
            cv2.putText(frame, f"{label}: {confidence:.1f}%", (x, text_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2)

        cv2.imshow("Real-Time Mask Detection", frame)

        # Poll the keyboard for ~1 ms; 'q' quits the loop.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    cap.release()
    cv2.destroyAllWindows()
