In [1]:
import warnings
import os

warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

In [2]:
# Imports
import os
import torch
from torch import nn, optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from transformers import ViTModel
from facenet_pytorch import MTCNN
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from tqdm import tqdm
from PIL import Image
from collections import deque, Counter

E0000 00:00:1752061028.475797    2487 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1752061028.597678    2487 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1752061029.652998    2487 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1752061029.653064    2487 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1752061029.653070    2487 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1752061029.653074    2487 computation_placer.cc:177] computation placer already registered. Please check linka

In [3]:
# Configuration
MODEL_NAME = "google/vit-base-patch16-224"
DATA_DIR = "../registered_faces"
EMBEDDINGS_FILE = "face_embeddings.npy"
NAMES_FILE = "face_names.npy"
MODEL_SAVE_PATH = "saved_model/vit_face_classifier.pth"
EPOCHS = 5
BATCH_SIZE = 4
LR = 1e-4
SMOOTHING_WINDOW = 5

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [4]:
# Initialize MTCNN
mtcnn = MTCNN(image_size=224, margin=20)

# Data augmentation
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    transforms.ToTensor()
])

# Custom Dataset with face detection
class FaceDataset(Dataset):
    def __init__(self, root_dir, transform):
        self.samples = []
        self.labels = []
        self.class_to_idx = {}
        self.transform = transform
        classes = sorted(os.listdir(root_dir))
        for idx, class_name in enumerate(classes):
            self.class_to_idx[class_name] = idx
            class_folder = os.path.join(root_dir, class_name)
            if not os.path.isdir(class_folder):
                continue
            for img_name in os.listdir(class_folder):
                img_path = os.path.join(class_folder, img_name)
                self.samples.append(img_path)
                self.labels.append(idx)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path = self.samples[idx]
        label = self.labels[idx]
        img = Image.open(img_path).convert("RGB")
        face = mtcnn(img)
        if face is None:
            # fallback: return blank tensor
            face = torch.zeros(3, 224, 224)
        else:
            face = transforms.ToPILImage()(face)
            face = self.transform(face)
        return face, label

In [5]:
# Load dataset
dataset = FaceDataset(DATA_DIR, transform)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
idx_to_class = {v:k for k,v in dataset.class_to_idx.items()}
print(f"Classes: {list(dataset.class_to_idx.keys())}")

Classes: ['Harsh', 'Mummy', 'Papa']


In [6]:
# Build model
vit = ViTModel.from_pretrained(MODEL_NAME)
class FaceClassifier(nn.Module):
    def __init__(self, vit, num_classes):
        super().__init__()
        self.vit = vit
        self.classifier = nn.Linear(vit.config.hidden_size, num_classes)
    def forward(self, x):
        outputs = self.vit(pixel_values=x)
        pooled = outputs.last_hidden_state[:,0]
        return self.classifier(pooled)

model = FaceClassifier(vit, num_classes=len(dataset.class_to_idx)).to(device)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
# Training
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)
model.train()
for epoch in range(EPOCHS):
    epoch_loss = 0
    for imgs, labels in tqdm(loader):
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    print(f"Epoch [{epoch+1}/{EPOCHS}] Loss: {epoch_loss/len(loader):.4f}")

100%|█████████████████████████████████████████████████████████████████████████████| 8/8 [00:51<00:00,  6.45s/it]


Epoch [1/5] Loss: 1.3365


100%|█████████████████████████████████████████████████████████████████████████████| 8/8 [00:50<00:00,  6.30s/it]


Epoch [2/5] Loss: 0.4654


100%|█████████████████████████████████████████████████████████████████████████████| 8/8 [00:50<00:00,  6.31s/it]


Epoch [3/5] Loss: 0.2338


100%|█████████████████████████████████████████████████████████████████████████████| 8/8 [00:47<00:00,  5.99s/it]


Epoch [4/5] Loss: 0.0956


100%|█████████████████████████████████████████████████████████████████████████████| 8/8 [00:53<00:00,  6.71s/it]

Epoch [5/5] Loss: 0.0439





In [8]:
# Extract embeddings from trained model
model.eval()
embeddings, names = [], []
with torch.no_grad():
    for imgs, labels in loader:
        imgs = imgs.to(device)
        pooled = model.vit(pixel_values=imgs).last_hidden_state[:,0]
        embeddings.append(pooled.cpu().numpy())
        for label in labels:
            names.append(idx_to_class[label.item()])
embeddings = np.vstack(embeddings)

np.save("saved_embeddings/" + EMBEDDINGS_FILE, embeddings)
np.save("saved_embeddings/" + NAMES_FILE, np.array(names))
print("Saved embeddings & names.")

Saved embeddings & names.


In [9]:
# Temporal smoothing with weighted scores
recent_predictions = deque(maxlen=SMOOTHING_WINDOW)
recent_scores = deque(maxlen=SMOOTHING_WINDOW)

def recognize_with_smoothing(image_path):
    img = Image.open(image_path).convert("RGB")
    face = mtcnn(img)
    if face is None:
        print("No face detected.")
        return
    face = transforms.ToPILImage()(face)
    img_tensor = transform(face).unsqueeze(0).to(device)
    with torch.no_grad():
        pooled = model.vit(pixel_values=img_tensor).last_hidden_state[:,0].cpu().numpy()
    loaded_embs = np.load("saved_embeddings/" + EMBEDDINGS_FILE)
    loaded_names = np.load("saved_embeddings/" + NAMES_FILE)
    sims = cosine_similarity(pooled, loaded_embs)[0]
    best_idx = np.argmax(sims)
    predicted = loaded_names[best_idx]
    recent_predictions.append(predicted)
    recent_scores.append(sims[best_idx])
    # weighted smoothing
    weighted = {}
    for name, score in zip(recent_predictions, recent_scores):
        weighted[name] = weighted.get(name, 0) + score
    most_common = max(weighted, key=weighted.get)
    print(f"Predicted: {predicted} (score: {sims[best_idx]:.2f}), Smoothed: {most_common}")

In [13]:
# Test all images
root_folder = "../registered_faces"

for person in os.listdir(root_folder):
    person_folder = os.path.join(root_folder, person)
    if not os.path.isdir(person_folder):
        continue
    print(f"\n--- Testing images in folder: {person} ---")
    recent_predictions.clear()
    recent_scores.clear()
    for img_file in os.listdir(person_folder):
        img_path = os.path.join(person_folder, img_file)
        recognize_with_smoothing(img_path)


--- Testing images in folder: Harsh ---
Predicted: Harsh (score: 0.83), Smoothed: Harsh
Predicted: Harsh (score: 0.84), Smoothed: Harsh
Predicted: Harsh (score: 0.93), Smoothed: Harsh
Predicted: Harsh (score: 0.89), Smoothed: Harsh
Predicted: Harsh (score: 0.90), Smoothed: Harsh
Predicted: Harsh (score: 0.97), Smoothed: Harsh
Predicted: Harsh (score: 0.84), Smoothed: Harsh
Predicted: Harsh (score: 0.79), Smoothed: Harsh
Predicted: Harsh (score: 0.88), Smoothed: Harsh
Predicted: Harsh (score: 0.87), Smoothed: Harsh

--- Testing images in folder: Mummy ---
Predicted: Mummy (score: 0.91), Smoothed: Mummy
Predicted: Mummy (score: 0.91), Smoothed: Mummy
Predicted: Mummy (score: 0.95), Smoothed: Mummy
Predicted: Mummy (score: 0.92), Smoothed: Mummy
Predicted: Mummy (score: 0.88), Smoothed: Mummy
Predicted: Mummy (score: 0.84), Smoothed: Mummy
Predicted: Mummy (score: 0.81), Smoothed: Mummy
Predicted: Mummy (score: 0.95), Smoothed: Mummy
Predicted: Mummy (score: 0.81), Smoothed: Mummy
Predic

In [14]:
# Save model
torch.save({
    'model_state_dict': model.state_dict(),
    'class_to_idx': dataset.class_to_idx
}, MODEL_SAVE_PATH)

print(f"Model saved to {MODEL_SAVE_PATH}")

Model saved to saved_model/vit_face_classifier.pth
