In [1]:
# Imports
import os
import torch
from torch import nn, optim
from torchvision import datasets, transforms
from transformers import ViTModel, ViTImageProcessor
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm
2025-07-09 08:16:18.738600: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1752048979.085706    2131 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1752048979.181875    2131 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1752048980.035431    2131 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1752048980.035562    2131 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1752048980.035567    2131

In [4]:
# Configuration & device
MODEL_NAME = "google/vit-base-patch16-224"
DATA_DIR = "../registered_faces"
EMBEDDINGS_FILE = "face_embeddings.npy"
NAMES_FILE = "face_names.npy"
EPOCHS = 5
BATCH_SIZE = 4
LR = 1e-4

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [5]:
# Load & preprocess dataset
processor = ViTImageProcessor.from_pretrained(MODEL_NAME)

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

dataset = datasets.ImageFolder(root=DATA_DIR, transform=transform)
loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

print(f"Classes: {dataset.classes}")

Classes: ['Harsh', 'Mummy', 'Papa']


In [6]:
# Build model (ViT + classifier head)
vit = ViTModel.from_pretrained(MODEL_NAME)
class FaceClassifier(nn.Module):
    def __init__(self, vit, num_classes):
        super().__init__()
        self.vit = vit
        self.classifier = nn.Linear(vit.config.hidden_size, num_classes)
    def forward(self, x):
        outputs = self.vit(pixel_values=x)
        pooled = outputs.last_hidden_state[:, 0]
        return self.classifier(pooled)

model = FaceClassifier(vit, num_classes=len(dataset.classes)).to(device)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
# Training loop
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

model.train()
for epoch in range(EPOCHS):
    epoch_loss = 0
    for imgs, labels in tqdm(loader):
        inputs = imgs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    print(f"Epoch [{epoch+1}/{EPOCHS}] Loss: {epoch_loss/len(loader):.4f}")

100%|█████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:10<00:00,  1.34s/it]


Epoch [1/5] Loss: 0.7869


100%|█████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:07<00:00,  1.05it/s]


Epoch [2/5] Loss: 0.0535


100%|█████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:07<00:00,  1.05it/s]


Epoch [3/5] Loss: 0.0152


100%|█████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:08<00:00,  1.02s/it]


Epoch [4/5] Loss: 0.0070


100%|█████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:05<00:00,  1.45it/s]

Epoch [5/5] Loss: 0.0039





In [15]:
# Extract and save embeddings
model.eval()
embeddings = []
names = []

with torch.no_grad():
    for imgs, labels in loader:
        inputs = imgs.to(device)
        outputs = model.vit(pixel_values=inputs)
        pooled = outputs.last_hidden_state[:, 0]
        emb = pooled.cpu().numpy()
        embeddings.append(emb)
        for label in labels:
            names.append(dataset.classes[label])

embeddings = np.vstack(embeddings)

np.save("saved_embeddings/" + EMBEDDINGS_FILE, embeddings)
np.save("saved_embeddings/" + NAMES_FILE, np.array(names))

print("Saved embeddings & names to 'saved_embeddings' folder.")

Saved embeddings & names to 'saved_embeddings' folder.


In [17]:
# Test on a single image
def recognize(image_path):
    img = Image.open(image_path).convert("RGB")
    img_tensor = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model.vit(pixel_values=img_tensor)
        emb = output.last_hidden_state[:,0].cpu().numpy()
    loaded_embs = np.load("saved_embeddings/" + EMBEDDINGS_FILE)
    loaded_names = np.load("saved_embeddings/" + NAMES_FILE)
    sims = cosine_similarity(emb, loaded_embs)[0]
    best_idx = np.argmax(sims)
    print(f"Predicted: {loaded_names[best_idx]} (score: {sims[best_idx]:.2f})")

# Test
root_folder = "../registered_faces"
for person in os.listdir(root_folder):
    person_folder = os.path.join(root_folder, person)
    if not os.path.isdir(person_folder):
        continue
    print(f"\n=== Testing images in folder: {person} ===")
    for img_file in os.listdir(person_folder):
        img_path = os.path.join(person_folder, img_file)
        recognize(img_path)


=== Testing images in folder: Harsh ===
Predicted: Harsh (score: 1.00)
Predicted: Harsh (score: 1.00)
Predicted: Harsh (score: 1.00)
Predicted: Harsh (score: 1.00)
Predicted: Harsh (score: 1.00)
Predicted: Harsh (score: 1.00)
Predicted: Harsh (score: 1.00)
Predicted: Harsh (score: 1.00)
Predicted: Harsh (score: 1.00)
Predicted: Harsh (score: 1.00)

=== Testing images in folder: Mummy ===
Predicted: Mummy (score: 1.00)
Predicted: Mummy (score: 1.00)
Predicted: Mummy (score: 1.00)
Predicted: Mummy (score: 1.00)
Predicted: Mummy (score: 1.00)
Predicted: Mummy (score: 1.00)
Predicted: Mummy (score: 1.00)
Predicted: Mummy (score: 1.00)
Predicted: Mummy (score: 1.00)
Predicted: Mummy (score: 1.00)

=== Testing images in folder: Papa ===
Predicted: Papa (score: 1.00)
Predicted: Papa (score: 1.00)
Predicted: Papa (score: 1.00)
Predicted: Papa (score: 1.00)
Predicted: Papa (score: 1.00)
Predicted: Papa (score: 1.00)
Predicted: Papa (score: 1.00)
Predicted: Papa (score: 1.00)
Predicted: Papa (s

In [18]:
# Save trained model
MODEL_SAVE_PATH = "saved_model/vit_face_classifier.pth"

torch.save({
    'model_state_dict': model.state_dict(),
    'class_to_idx': dataset.class_to_idx
}, MODEL_SAVE_PATH)

print(f"Model saved to {MODEL_SAVE_PATH}")

Model saved to saved_model/vit_face_classifier.pth
