In [None]:
import torch
import torchvision.transforms as transforms
from torchvision.models import vit_b_16
from PIL import Image
import cv2
import numpy as np
from collections import OrderedDict

# Device setup: use the GPU when PyTorch can see one, otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# 1. Build ViT-B/16 with a 2-way classification head (helmet / no-helmet) and no
#    pretrained weights; the trained weights are loaded from the checkpoint below.
#    NOTE(review): torchvision >= 0.13 deprecates `pretrained` in favour of
#    `weights=None`; kept as-is so older torchvision versions still work.
model = vit_b_16(pretrained=False, num_classes=2)

# 2. Read the checkpoint onto `device`. Assumed to be a plain state_dict
#    (key -> tensor mapping) — TODO confirm against the training script.
#    NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
state_dict = torch.load('my_vit_model.pth', map_location=device)

# 3. This model expects its classifier parameters under `heads.head.*`. If the
#    checkpoint stores them as `heads.weight` / `heads.bias`, rename those two
#    keys; every other entry is copied through unchanged, in the same order.
head_key_renames = {
    'heads.weight': 'heads.head.weight',
    'heads.bias': 'heads.head.bias',
}
new_state_dict = OrderedDict(
    (head_key_renames.get(key, key), value) for key, value in state_dict.items()
)

# 4. strict=True makes any missing or unexpected key raise instead of being
#    silently ignored; then move to `device` and switch to evaluation mode.
model.load_state_dict(new_state_dict, strict=True)
model = model.to(device).eval()

# Per-frame preprocessing. The steps and normalisation statistics must match
# what the model saw during training — TODO confirm against the training code.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize(256),      # int size: shorter edge -> 256 px, aspect ratio kept
        transforms.CenterCrop(224),  # central 224x224 patch
        transforms.ToTensor(),       # PIL image in [0, 255] -> float CxHxW tensor in [0, 1]
        transforms.Normalize(imagenet_mean, imagenet_std),  # per-channel (x - mean) / std
    ]
)

# Webcam processing: classify every frame from the default camera (index 0)
# and show the annotated stream until 'q' is pressed or the stream ends.
cap = cv2.VideoCapture(0)
class_names = ['NO HELMET', 'HELMET']  # index order must match the model's output classes

if not cap.isOpened():
    # Without this check a missing/busy camera makes cap.read() fail on the first
    # iteration and the loop exits silently with no explanation.
    cap.release()
    raise RuntimeError("Could not open webcam (device index 0)")

try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # Preprocess: OpenCV delivers BGR frames; convert to an RGB PIL image for
        # `transform`, then add a leading batch dimension (batch of 1).
        pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        img_tensor = transform(pil_img).unsqueeze(0).to(device)

        # Predict
        with torch.no_grad():
            output = model(img_tensor)
            probs = torch.softmax(output, dim=1)

        # Get results for the single image in the batch. Index row 0 explicitly
        # instead of arg-maxing over the flattened (batch x classes) tensor.
        pred_class = probs[0].argmax().item()
        confidence = probs[0, pred_class].item()

        # Display: green for class 1 ('HELMET'), red otherwise (OpenCV colours are BGR).
        label = f"{class_names[pred_class]} ({confidence:.2f})"
        color = (0, 255, 0) if pred_class == 1 else (0, 0, 255)
        cv2.putText(frame, label, (20, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)

        # Fixed border coloured by the prediction. NOTE: this is not a detection
        # box — the model classifies the whole frame and returns no location.
        cv2.rectangle(frame, (50, 50), (frame.shape[1] - 50, frame.shape[0] - 100), color, 2)

        cv2.imshow('Helmet Detection', frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Always release the camera and close the window, even if inference raises
    # or the loop is interrupted (e.g. KeyboardInterrupt / kernel interrupt),
    # so the capture handle is not left open.
    cap.release()
    cv2.destroyAllWindows()