# In [8]:
# utils.py
import matplotlib.pyplot as plt
import mediapipe as mp
from PIL import Image
import numpy as np
import torch
from transformers import ResNetForImageClassification
from torchvision import transforms

class ImageWithFaces:
    """Hold an image path together with an initially empty ``faces`` list.

    Attributes:
        image_path: The value passed to the constructor, stored unchanged.
        faces: A list that starts empty; every instance gets its own list.
    """

    def __init__(self, image_path):
        # Create the per-instance list first so instances never share state.
        self.faces = list()
        self.image_path = image_path

def init_resnet(model_path='final_model_resnet_50.pt'):
    """Load the fine-tuned two-class ResNet-50 and its input transform.

    Args:
        model_path: Path to the saved state dict to load into the model.
            Defaults to ``'final_model_resnet_50.pt'``, the checkpoint name
            that used to be hard-coded here.

    Returns:
        A ``(model, preprocess)`` tuple: the model switched to eval mode,
        and a torchvision transform that converts an image to a tensor and
        normalizes its three channels.
    """
    # NOTE(review): depending on the installed torch version, torch.load may
    # unpickle arbitrary objects -- only load checkpoints from a trusted source.
    # The weights are mapped to CPU so loading does not require a GPU.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))

    # Start from the pretrained architecture, then replace the final linear
    # layer with a 2-output head before loading the fine-tuned weights.
    model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")
    model.classifier[1] = torch.nn.Linear(in_features=2048, out_features=2, bias=True)
    model.load_state_dict(state_dict)
    model.eval()

    # Per-channel mean/std are the values commonly used with
    # ImageNet-pretrained models.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return model, preprocess


def classify_faces(face, model, preprocess):
    """Classify a single face image and return the predicted class index.

    Args:
        face: Image pixels as an array, e.g. the result of ``plt.imread``.
            Both uint8 arrays and float arrays scaled to ``[0, 1]`` (what
            matplotlib returns for PNG files) are accepted, in grayscale,
            RGB or RGBA layout.
        model: Callable model whose output exposes a ``logits`` attribute.
        preprocess: Callable that turns a PIL image into a tensor.

    Returns:
        A tensor holding the index of the largest logit for the single
        image in the batch.
    """
    pixels = np.asarray(face)
    if np.issubdtype(pixels.dtype, np.floating):
        # Image.fromarray cannot build an RGB image from float data, so
        # rescale [0, 1] floats to 8-bit values first.
        pixels = np.rint(np.clip(pixels, 0.0, 1.0) * 255.0).astype(np.uint8)

    # Force three channels: grayscale or RGBA input would otherwise not
    # match a 3-channel normalization step in ``preprocess``.
    face_image = Image.fromarray(pixels).convert("RGB")

    # Add a leading batch dimension of size one.
    input_batch = preprocess(face_image).unsqueeze(0)
    with torch.no_grad():
        output = model(input_batch)
    return output.logits.argmax(dim=1)


def predict_faces_from_image(image_path, *, model=None, preprocess=None):
    """Read an image file and classify it.

    Args:
        image_path: Path of the image file to read with ``plt.imread``.
        model: Optional already-loaded model. When omitted, one is obtained
            from ``init_resnet()``.
        preprocess: Optional already-built transform. When omitted, the one
            returned by ``init_resnet()`` is used.

    Returns:
        Whatever ``classify_faces`` returns for the loaded image.
    """
    # Loading the network is expensive; only do it when the caller did not
    # supply a model and transform to reuse across calls.
    if model is None or preprocess is None:
        default_model, default_preprocess = init_resnet()
        if model is None:
            model = default_model
        if preprocess is None:
            preprocess = default_preprocess

    image = plt.imread(image_path)
    return classify_faces(image, model, preprocess)
    

# Example usage: classify a sample image file and print the result.
predictions = predict_faces_from_image(image_path='face_4.jpg')
print(predictions)


# Output: tensor([0])
