In [1]:
import cv2
import PIL 
from facenet_pytorch import MTCNN
from PIL import Image,ImageDraw

import torch
import numpy as np
from torchvision import datasets, models, transforms

In [2]:
# Preprocessing applied to every face crop taken from the webcam feed:
# resize to 224x224, then convert the image to a tensor.
data_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

In [3]:
# Use the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the trained classifier.
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved on a GPU can still be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file, which can execute arbitrary code --
# only load checkpoints from a trusted source.
model = torch.load('face_mask_model.pth', map_location=device)

# Keep the model on the same device as the input tensors it will receive,
# then switch to inference mode.
model = model.to(device)
model.eval()

print(device)

cuda


In [4]:
# Face detector: MTCNN running on the selected device.
# keep_all=True asks the detector to report every face in a frame.
mtcnn = MTCNN(device=device, keep_all=True)

In [None]:
# Real-time loop: grab webcam frames, detect faces with MTCNN, classify each
# face crop with the model and overlay the label and confidence on the
# preview window. Press "q" in the window to stop.
cap = cv2.VideoCapture(0)
sm = torch.nn.Softmax(dim=1)

# Padding (in pixels) added on every side of a detected face box
t = 5
# Colour used for both the rectangle and the text overlay
color = (0, 255, 0)

try:
    while True:
        ret, image = cap.read()
        if not ret:
            # Camera could not be opened or stopped delivering frames;
            # without this check cv2.flip would be called on None.
            print("Failed to read a frame from the webcam")
            break

        # Flip image horizontally
        image = cv2.flip(image, 1)

        # The detector and the PIL conversion below work on an RGB copy;
        # `image` stays in OpenCV's BGR order and is the one drawn on.
        frame = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        height, width = frame.shape[:2]

        # Detect faces in image
        boxes, _ = mtcnn.detect(frame)

        if boxes is not None:
            for box in boxes:
                # Pad the box, then clamp it to the frame. A negative start
                # index would otherwise be interpreted by NumPy as counting
                # from the end and yield an empty or wrong crop.
                startX = max(int(box[0]) - t, 0)
                startY = max(int(box[1]) - t, 0)
                endX = min(int(box[2]) + t, width)
                endY = min(int(box[3]) + t, height)
                if endX <= startX or endY <= startY:
                    # Nothing left of the box inside the frame
                    continue

                # Draw rectangle over the image
                cv2.rectangle(image, (startX, startY), (endX, endY), color, 2)

                face = frame[startY:endY, startX:endX]
                img = PIL.Image.fromarray(face)
                with torch.no_grad():
                    img = data_transform(img)
                    img = img.to(device)
                    # Add the batch dimension expected by the model
                    img = img.view(-1, 3, 224, 224)
                    res = model(img)
                    probabilities = sm(res)

                # Output index 0 is treated as the "Mask" class
                is_mask = torch.argmax(res).item() == 0
                txt = 'Mask' if is_mask else 'No Mask'

                conf_value = probabilities[0][0 if is_mask else 1].item()
                conf = f'Confidence: {conf_value*100:.2f}%'

                cv2.putText(image, txt, (startX, startY - 20), cv2.FONT_HERSHEY_SIMPLEX, 0.45, color, 1)
                cv2.putText(image, conf, (startX, startY - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.45, color, 1)

        cv2.imshow("Video", image)
        k = cv2.waitKey(1) & 0xFF
        if k == ord("q"):
            break
finally:
    # Always free the camera and close the window, including when the cell
    # is interrupted or an exception is raised inside the loop.
    cap.release()
    cv2.destroyAllWindows()