In [1]:
import torch
import torchvision.transforms as transforms
import cv2
import numpy as np
from PIL import Image
from facenet_pytorch import MTCNN
from timm import create_model

# Import custom modules
from scripts.face_detection import detect_faces
from scripts.vit_sr import load_vit_sr, enhance_face

# Real-time face enhancement demo: grab webcam frames, detect faces, run each
# face crop through the ViT-SR model, and paste the result back into the frame.

# Load ViT-SR model on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vit_sr = load_vit_sr(device)

# Initialize Face Detector (MTCNN)
face_detector = MTCNN(keep_all=False, device=device)

# Open Webcam (device index 0)
cap = cv2.VideoCapture(0)

try:
    while True:
        ret, frame = cap.read()
        if not ret:
            # No frame could be read (camera missing, unplugged, or stream ended).
            break

        frame_h, frame_w = frame.shape[:2]

        # OpenCV delivers BGR; the detector is fed an RGB copy of the frame.
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Detect Faces
        faces, boxes = detect_faces(rgb_frame, face_detector)
        if faces is None or boxes is None:
            # NOTE(review): detect_faces is project-local; this guards against it
            # forwarding a None "no detection" result -- confirm against the helper.
            faces, boxes = (), ()

        for face, box in zip(faces, boxes):
            x1, y1, x2, y2 = map(int, box)

            # Clamp the box to the frame. Negative or oversized coordinates
            # would otherwise select a slice whose shape does not match the
            # resized patch, making the assignment below raise.
            x1, y1 = max(x1, 0), max(y1, 0)
            x2, y2 = min(x2, frame_w), min(y2, frame_h)
            if x2 <= x1 or y2 <= y1:
                # Degenerate box: cv2.resize rejects a zero-sized target.
                continue

            # Convert Face to PIL Image
            face_pil = Image.fromarray(face)

            # Apply ViT-SR Enhancement
            enhanced_face = enhance_face(face_pil, vit_sr, device)

            # Resize back to the (clamped) box size and overlay.
            # NOTE(review): the crop presumably comes from the RGB frame while
            # `frame` is BGR -- confirm enhance_face's output channel order, or
            # the pasted patch may show swapped red/blue channels.
            enhanced_face_resized = cv2.resize(enhanced_face, (x2 - x1, y2 - y1))
            frame[y1:y2, x1:x2] = enhanced_face_resized

        # Display Output
        cv2.imshow("ViT-SR Face Enhancement", frame)

        # Quit when the user presses 'q'.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Always free the camera and close the window, even if the loop raised.
    cap.release()
    cv2.destroyAllWindows()


ModuleNotFoundError: No module named 'scripts.face_detection'