In [None]:
import torch
import torchvision.transforms as transforms
from torchvision.models import vit_b_16
from PIL import Image
import cv2
import numpy as np
from collections import OrderedDict

# Load the frontal-face Haar cascade that ships with the OpenCV Python package.
_cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
face_cascade = cv2.CascadeClassifier(_cascade_path)

# CascadeClassifier does not raise when the XML cannot be read; it yields an
# empty classifier that only fails later inside detectMultiScale. Fail fast
# here with a message that names the file instead.
if face_cascade.empty():
    raise RuntimeError(f"Could not load Haar cascade from {_cascade_path!r}")

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build a ViT-B/16 with a two-class output and no pretrained weights; the
# parameters are filled in from the checkpoint loaded below.
# NOTE(review): torchvision >= 0.13 deprecates `pretrained=` in favour of
# `weights=None` -- consider switching once the minimum version allows it.
model = vit_b_16(pretrained=False, num_classes=2)

# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
state_dict = torch.load('my_vit_model.pth', map_location=device)

# Rename the classifier keys 'heads.weight' / 'heads.bias' (when present) to
# 'heads.head.weight' / 'heads.head.bias'; every other entry is carried over
# unchanged and in its original order.
_renamed_keys = {
    'heads.weight': 'heads.head.weight',
    'heads.bias': 'heads.head.bias',
}
new_state_dict = OrderedDict(
    (_renamed_keys.get(name, name), tensor)
    for name, tensor in state_dict.items()
)

# Load the remapped weights, move the model to the chosen device and switch it
# to inference mode.
model.load_state_dict(new_state_dict)
model = model.to(device).eval()

# Preprocessing applied to every crop before it reaches the model: resize,
# take the central 224x224 patch, convert to a tensor, then normalise each
# channel with the mean / std values listed below.
_preprocess_steps = [
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocess_steps)

# Webcam processing: detect faces in each frame, classify the enlarged head
# crop as HELMET / NO HELMET, and display the annotated frame until 'q' is
# pressed or the camera stops delivering frames.
class_names = ['NO HELMET', 'HELMET']


def _head_box(x, y, w, h, frame_w, frame_h):
    """Return the head region (left, top, right, bottom) for a face box.

    The face box is widened by about 20% of its width on each side and
    extended upwards by about 50% of its height so that a helmet above the
    face is included. Each corner is then clamped to the frame, so the box
    never leaves the image and does not drift when it touches an edge.
    """
    left = int(x) - int(w * 0.2)
    top = int(y) - int(h * 0.5)
    right = left + int(w) + int(w * 0.4)
    bottom = top + int(h) + int(h * 0.5)
    return max(0, left), max(0, top), min(frame_w, right), min(frame_h, bottom)


def _classify_head(head_bgr):
    """Classify a BGR head crop; return (class index, softmax confidence)."""
    pil_img = Image.fromarray(cv2.cvtColor(head_bgr, cv2.COLOR_BGR2RGB))
    img_tensor = transform(pil_img).unsqueeze(0).to(device)
    with torch.no_grad():
        # Batch of one: take row 0 so the argmax below is over classes only.
        probs = torch.softmax(model(img_tensor), dim=1)[0]
    pred_class = int(probs.argmax().item())
    return pred_class, probs[pred_class].item()


cap = cv2.VideoCapture(0)

try:
    if not cap.isOpened():
        raise RuntimeError("Could not open webcam (device index 0)")

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        frame_h, frame_w = frame.shape[:2]

        # Convert to grayscale for face detection
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

        # Detect faces
        faces = face_cascade.detectMultiScale(gray, 1.1, 4)

        for (x, y, w, h) in faces:
            # Expand face ROI to include head/helmet area
            left, top, right, bottom = _head_box(x, y, w, h, frame_w, frame_h)
            if right <= left or bottom <= top:
                # Nothing left to classify after clamping to the frame.
                continue

            # Extract head region and classify it
            head_roi = frame[top:bottom, left:right]
            pred_class, confidence = _classify_head(head_roi)

            # Draw bounding box and label
            color = (0, 255, 0) if pred_class == 1 else (0, 0, 255)
            cv2.rectangle(frame, (left, top), (right, bottom), color, 2)

            label = f"{class_names[pred_class]} ({confidence:.2f})"
            # Put the label above the box, or just inside it when the box is
            # too close to the top edge for the text to be visible.
            label_y = top - 10 if top - 10 >= 20 else top + 25
            cv2.putText(frame, label, (left, label_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)

        cv2.imshow('Helmet Detection', frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Always free the camera and close the preview window, even when the
    # loop exits through an exception.
    cap.release()
    cv2.destroyAllWindows()

