In [1]:
import os
import warnings
import logging

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
logging.getLogger('tensorflow').setLevel(logging.ERROR)
warnings.filterwarnings('ignore')

In [2]:
# Imports
import torch
from torch import nn, optim
from torchvision import datasets, transforms
from transformers import ViTModel, ViTImageProcessor
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt

E0000 00:00:1752142944.183949    3227 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1752142944.202482    3227 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1752142944.309577    3227 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1752142944.309635    3227 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1752142944.309638    3227 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1752142944.309640    3227 computation_placer.cc:177] computation placer already registered. Please check linka

In [3]:
# Config
MODEL_NAME = "google/vit-base-patch16-224"
DATA_DIR = "../registered_faces"
EMBEDDINGS_DIR = "../notebooks/saved_embeddings"
MODEL_SAVE_PATH = "../notebooks/saved_model/vit_face_classifier.pth"
EPOCHS = 5
BATCH_SIZE = 4
LR = 1e-4

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [4]:
# Load & preprocess dataset
processor = ViTImageProcessor.from_pretrained(MODEL_NAME)

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

dataset = datasets.ImageFolder(root=DATA_DIR, transform=transform)
loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

print(f"Classes: {dataset.classes}")

Classes: ['Harsh', 'Mummy', 'Papa']


In [5]:
# Build model (ViT + classifier head)
vit = ViTModel.from_pretrained(MODEL_NAME)

class FaceClassifier(nn.Module):
    def __init__(self, vit, num_classes):
        super().__init__()
        self.vit = vit
        self.classifier = nn.Linear(vit.config.hidden_size, num_classes)

    def forward(self, x):
        outputs = self.vit(pixel_values=x)
        pooled = outputs.last_hidden_state[:, 0]
        return self.classifier(pooled)

model = FaceClassifier(vit, len(dataset.classes)).to(device)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
# Training loop
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

model.train()
for epoch in range(EPOCHS):
    total_loss = 0
    for imgs, labels in tqdm(loader):
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {total_loss/len(loader):.4f}")

100%|█████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:07<00:00,  1.10it/s]


Epoch 1/5, Loss: 0.8335


100%|█████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:06<00:00,  1.29it/s]


Epoch 2/5, Loss: 0.0662


100%|█████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:06<00:00,  1.31it/s]


Epoch 3/5, Loss: 0.0169


100%|█████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:06<00:00,  1.32it/s]


Epoch 4/5, Loss: 0.0076


100%|█████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:03<00:00,  2.25it/s]

Epoch 5/5, Loss: 0.0044





In [7]:
# Extract & save embeddings
model.eval()
embeddings, names = [], []

with torch.no_grad():
    for imgs, labels in loader:
        inputs = imgs.to(device)
        outputs = model.vit(pixel_values=inputs)
        pooled = outputs.last_hidden_state[:, 0].cpu().numpy()
        embeddings.append(pooled)
        names += [dataset.classes[label] for label in labels]

embeddings = np.vstack(embeddings)
os.makedirs(EMBEDDINGS_DIR, exist_ok=True)
np.save(f"{EMBEDDINGS_DIR}/face_embeddings.npy", embeddings)
np.save(f"{EMBEDDINGS_DIR}/face_names.npy", np.array(names))

print("Saved model and embeddings.")

Saved model and embeddings.


In [11]:
# Test on a single image
def recognize(image_path):
    img = Image.open(image_path).convert("RGB")
    img_tensor = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model.vit(pixel_values=img_tensor)
        emb = output.last_hidden_state[:,0].cpu().numpy()
    loaded_embs = np.load(EMBEDDINGS_FILE)
    loaded_names = np.load(NAMES_FILE)
    sims = cosine_similarity(emb, loaded_embs)[0]
    best_idx = np.argmax(sims)
    
    print(f"Predicted: {loaded_names[best_idx]} (score: {sims[best_idx]:.2f})")

In [14]:
# Test
root_folder = "../registered_faces"
EMBEDDINGS_FILE = "saved_embeddings/face_embeddings.npy"
NAMES_FILE = "saved_embeddings/face_names.npy"

for person in os.listdir(root_folder):
    person_folder = os.path.join(root_folder, person)
    if not os.path.isdir(person_folder):
        continue
        
    print(f"\n=== Testing images in folder: {person} ===")
    
    for img_file in os.listdir(person_folder):
        img_path = os.path.join(person_folder, img_file)
        recognize(img_path)


=== Testing images in folder: Harsh ===
Predicted: Harsh (score: 1.00)
Predicted: Harsh (score: 1.00)
Predicted: Harsh (score: 1.00)
Predicted: Harsh (score: 1.00)
Predicted: Harsh (score: 1.00)
Predicted: Harsh (score: 1.00)
Predicted: Harsh (score: 1.00)
Predicted: Harsh (score: 1.00)
Predicted: Harsh (score: 1.00)
Predicted: Harsh (score: 1.00)

=== Testing images in folder: Mummy ===
Predicted: Mummy (score: 1.00)
Predicted: Mummy (score: 1.00)
Predicted: Mummy (score: 1.00)
Predicted: Mummy (score: 1.00)
Predicted: Mummy (score: 1.00)
Predicted: Mummy (score: 1.00)
Predicted: Mummy (score: 1.00)
Predicted: Mummy (score: 1.00)
Predicted: Mummy (score: 1.00)
Predicted: Mummy (score: 1.00)

=== Testing images in folder: Papa ===
Predicted: Papa (score: 1.00)
Predicted: Papa (score: 1.00)
Predicted: Papa (score: 1.00)
Predicted: Papa (score: 1.00)
Predicted: Papa (score: 1.00)
Predicted: Papa (score: 1.00)
Predicted: Papa (score: 1.00)
Predicted: Papa (score: 1.00)
Predicted: Papa (s

In [15]:
# Save model
torch.save({
    'model_state_dict': model.state_dict(),
    'class_to_idx': dataset.class_to_idx
}, MODEL_SAVE_PATH)

print(f"Model saved to {MODEL_SAVE_PATH}")

Model saved to ../notebooks/saved_model/vit_face_classifier.pth
