In [1]:
import os
import warnings
import logging

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
logging.getLogger('tensorflow').setLevel(logging.ERROR)
warnings.filterwarnings('ignore')

In [2]:
# Imports
import torch
from torch import nn, optim
from torch.nn import Dropout
from torch.utils.data import random_split, DataLoader
from torchvision import datasets, transforms
from transformers import ViTModel, ViTImageProcessor
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt

E0000 00:00:1752148855.844217    5541 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1752148855.852412    5541 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1752148855.877592    5541 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1752148855.877645    5541 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1752148855.877648    5541 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1752148855.877650    5541 computation_placer.cc:177] computation placer already registered. Please check linka

In [3]:
# Config
MODEL_NAME = "google/vit-base-patch16-224"
DATA_DIR = "../registered_faces"
EMBEDDINGS_DIR = "../notebooks/saved_embeddings"
MODEL_SAVE_PATH = "../notebooks/saved_model/vit_face_classifier.pth"
EPOCHS = 50
BATCH_SIZE = 4
LR = 1e-4
FREEZE_BACKBONE = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [4]:
# Load & preprocess dataset
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor()
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Dataset & Dataloader split
full_dataset = datasets.ImageFolder(root=DATA_DIR, transform=transform)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# Apply different transform to val set
val_dataset.dataset.transform = val_transform

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Classes: {full_dataset.classes}")

Classes: ['Harsh', 'Mummy', 'Papa']


In [5]:
# Dataset & Dataloader split
full_dataset = datasets.ImageFolder(root=DATA_DIR, transform=transform)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

val_dataset.dataset.transform = val_transform

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Classes: {full_dataset.classes}")

Classes: ['Harsh', 'Mummy', 'Papa']


In [6]:
# Build model
vit = ViTModel.from_pretrained(MODEL_NAME)

if FREEZE_BACKBONE:
    for param in vit.parameters():
        param.requires_grad = False

class FaceClassifier(nn.Module):
    def __init__(self, vit, num_classes, dropout=0.3):
        super().__init__()
        self.vit = vit
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(vit.config.hidden_size, num_classes)
        
    def forward(self, x):
        x = self.vit(pixel_values=x).last_hidden_state[:, 0]
        x = self.dropout(x)
        return self.classifier(x)

model = FaceClassifier(vit, len(full_dataset.classes), dropout=0.3).to(device)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
# Training setup
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR)
best_val_loss = float('inf')
patience = 3
counter = 0

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_train_loss = total_loss / len(train_loader)

    # Validation loop
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
    avg_val_loss = val_loss / len(val_loader)

    print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

    # Early Stopping Check
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        counter = 0
        torch.save({
            'model_state_dict': model.state_dict(),
            'class_to_idx': full_dataset.class_to_idx
        }, MODEL_SAVE_PATH)
        print(f"Saved best model at epoch {epoch+1}")
    else:
        counter += 1
        print(f"No improvement. Patience counter: {counter}/{patience}")
        if counter >= patience:
            print("Early stopping triggered.")
            break

Epoch 1/50: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.16it/s]


Epoch 1/50 | Train Loss: 1.0867 | Val Loss: 1.0284
Saved best model at epoch 1


Epoch 2/50: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.55it/s]


Epoch 2/50 | Train Loss: 0.9683 | Val Loss: 0.9362
Saved best model at epoch 2


Epoch 3/50: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.52it/s]


Epoch 3/50 | Train Loss: 0.8200 | Val Loss: 0.8533
Saved best model at epoch 3


Epoch 4/50: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:02<00:00,  7.24it/s]


Epoch 4/50 | Train Loss: 0.7696 | Val Loss: 0.7879
Saved best model at epoch 4


Epoch 5/50: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.47it/s]


Epoch 5/50 | Train Loss: 0.6375 | Val Loss: 0.7354
Saved best model at epoch 5


Epoch 6/50: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.44it/s]


Epoch 6/50 | Train Loss: 0.6115 | Val Loss: 0.6830
Saved best model at epoch 6


Epoch 7/50: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:02<00:00,  6.90it/s]


Epoch 7/50 | Train Loss: 0.5209 | Val Loss: 0.6423
Saved best model at epoch 7


Epoch 8/50: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.31it/s]


Epoch 8/50 | Train Loss: 0.5099 | Val Loss: 0.6043
Saved best model at epoch 8


Epoch 9/50: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.50it/s]


Epoch 9/50 | Train Loss: 0.4351 | Val Loss: 0.5757
Saved best model at epoch 9


Epoch 10/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.53it/s]


Epoch 10/50 | Train Loss: 0.4017 | Val Loss: 0.5474
Saved best model at epoch 10


Epoch 11/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.54it/s]


Epoch 11/50 | Train Loss: 0.3868 | Val Loss: 0.5237
Saved best model at epoch 11


Epoch 12/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:04<00:00,  3.64it/s]


Epoch 12/50 | Train Loss: 0.3295 | Val Loss: 0.5011
Saved best model at epoch 12


Epoch 13/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.51it/s]


Epoch 13/50 | Train Loss: 0.3157 | Val Loss: 0.4844
Saved best model at epoch 13


Epoch 14/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:02<00:00,  7.70it/s]


Epoch 14/50 | Train Loss: 0.2826 | Val Loss: 0.4663
Saved best model at epoch 14


Epoch 15/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.28it/s]


Epoch 15/50 | Train Loss: 0.2904 | Val Loss: 0.4507
Saved best model at epoch 15


Epoch 16/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.33it/s]


Epoch 16/50 | Train Loss: 0.2251 | Val Loss: 0.4393
Saved best model at epoch 16


Epoch 17/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.50it/s]


Epoch 17/50 | Train Loss: 0.2276 | Val Loss: 0.4289
Saved best model at epoch 17


Epoch 18/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:02<00:00,  7.75it/s]


Epoch 18/50 | Train Loss: 0.2112 | Val Loss: 0.4169
Saved best model at epoch 18


Epoch 19/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.01it/s]


Epoch 19/50 | Train Loss: 0.2192 | Val Loss: 0.4069
Saved best model at epoch 19


Epoch 20/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.41it/s]


Epoch 20/50 | Train Loss: 0.2110 | Val Loss: 0.3988
Saved best model at epoch 20


Epoch 21/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:02<00:00,  7.36it/s]


Epoch 21/50 | Train Loss: 0.1867 | Val Loss: 0.3915
Saved best model at epoch 21


Epoch 22/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.48it/s]


Epoch 22/50 | Train Loss: 0.1701 | Val Loss: 0.3833
Saved best model at epoch 22


Epoch 23/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.55it/s]


Epoch 23/50 | Train Loss: 0.1608 | Val Loss: 0.3778
Saved best model at epoch 23


Epoch 24/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.53it/s]


Epoch 24/50 | Train Loss: 0.1405 | Val Loss: 0.3723
Saved best model at epoch 24


Epoch 25/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:02<00:00,  7.90it/s]


Epoch 25/50 | Train Loss: 0.1368 | Val Loss: 0.3656
Saved best model at epoch 25


Epoch 26/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.60it/s]


Epoch 26/50 | Train Loss: 0.1393 | Val Loss: 0.3590
Saved best model at epoch 26


Epoch 27/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:04<00:00,  3.60it/s]


Epoch 27/50 | Train Loss: 0.1408 | Val Loss: 0.3538
Saved best model at epoch 27


Epoch 28/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.28it/s]


Epoch 28/50 | Train Loss: 0.1257 | Val Loss: 0.3508
Saved best model at epoch 28


Epoch 29/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.51it/s]


Epoch 29/50 | Train Loss: 0.1334 | Val Loss: 0.3451
Saved best model at epoch 29


Epoch 30/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.54it/s]


Epoch 30/50 | Train Loss: 0.1119 | Val Loss: 0.3402
Saved best model at epoch 30


Epoch 31/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.52it/s]


Epoch 31/50 | Train Loss: 0.1155 | Val Loss: 0.3358
Saved best model at epoch 31


Epoch 32/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:02<00:00,  7.06it/s]


Epoch 32/50 | Train Loss: 0.1081 | Val Loss: 0.3322
Saved best model at epoch 32


Epoch 33/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.51it/s]


Epoch 33/50 | Train Loss: 0.1132 | Val Loss: 0.3268
Saved best model at epoch 33


Epoch 34/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.33it/s]


Epoch 34/50 | Train Loss: 0.1029 | Val Loss: 0.3237
Saved best model at epoch 34


Epoch 35/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.43it/s]


Epoch 35/50 | Train Loss: 0.0990 | Val Loss: 0.3209
Saved best model at epoch 35


Epoch 36/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.30it/s]


Epoch 36/50 | Train Loss: 0.0823 | Val Loss: 0.3179
Saved best model at epoch 36


Epoch 37/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.20it/s]


Epoch 37/50 | Train Loss: 0.0752 | Val Loss: 0.3152
Saved best model at epoch 37


Epoch 38/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.53it/s]


Epoch 38/50 | Train Loss: 0.0876 | Val Loss: 0.3139
Saved best model at epoch 38


Epoch 39/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:02<00:00,  7.42it/s]


Epoch 39/50 | Train Loss: 0.0725 | Val Loss: 0.3114
Saved best model at epoch 39


Epoch 40/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.50it/s]


Epoch 40/50 | Train Loss: 0.0667 | Val Loss: 0.3086
Saved best model at epoch 40


Epoch 41/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.42it/s]


Epoch 41/50 | Train Loss: 0.0654 | Val Loss: 0.3077
Saved best model at epoch 41


Epoch 42/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.37it/s]


Epoch 42/50 | Train Loss: 0.0678 | Val Loss: 0.3050
Saved best model at epoch 42


Epoch 43/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.48it/s]


Epoch 43/50 | Train Loss: 0.0683 | Val Loss: 0.3028
Saved best model at epoch 43


Epoch 44/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.40it/s]


Epoch 44/50 | Train Loss: 0.0685 | Val Loss: 0.3008
Saved best model at epoch 44


Epoch 45/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.39it/s]


Epoch 45/50 | Train Loss: 0.0611 | Val Loss: 0.2998
Saved best model at epoch 45


Epoch 46/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:02<00:00,  6.97it/s]


Epoch 46/50 | Train Loss: 0.0577 | Val Loss: 0.2983
Saved best model at epoch 46


Epoch 47/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.48it/s]


Epoch 47/50 | Train Loss: 0.0569 | Val Loss: 0.2971
Saved best model at epoch 47


Epoch 48/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.20it/s]


Epoch 48/50 | Train Loss: 0.0582 | Val Loss: 0.2953
Saved best model at epoch 48


Epoch 49/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:02<00:00,  6.73it/s]


Epoch 49/50 | Train Loss: 0.0518 | Val Loss: 0.2942
Saved best model at epoch 49


Epoch 50/50: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.37it/s]


Epoch 50/50 | Train Loss: 0.0540 | Val Loss: 0.2928
Saved best model at epoch 50


In [8]:
# Extract & save embeddings
model.eval()
embeddings, names = [], []

loader = DataLoader(full_dataset, batch_size=BATCH_SIZE, shuffle=False)

with torch.no_grad():
    for imgs, labels in loader:
        imgs = imgs.to(device)
        outputs = model.vit(pixel_values=imgs)
        pooled = outputs.last_hidden_state[:, 0].cpu().numpy()
        embeddings.append(pooled)
        names += [full_dataset.classes[label] for label in labels]

embeddings = np.vstack(embeddings)
os.makedirs(EMBEDDINGS_DIR, exist_ok=True)
np.save(os.path.join(EMBEDDINGS_DIR, "face_embeddings.npy"), embeddings)
np.save(os.path.join(EMBEDDINGS_DIR, "face_names.npy"), np.array(names))

print("Saved face embeddings and names.")

Saved face embeddings and names.


In [9]:
# Test on a single image
EMBEDDINGS_FILE = os.path.join(EMBEDDINGS_DIR, "face_embeddings.npy")
NAMES_FILE = os.path.join(EMBEDDINGS_DIR, "face_names.npy")

def recognize(image_path):
    img = Image.open(image_path).convert("RGB")
    img_tensor = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model.vit(pixel_values=img_tensor)
        emb = output.last_hidden_state[:,0].cpu().numpy()
    loaded_embs = np.load(EMBEDDINGS_FILE)
    loaded_names = np.load(NAMES_FILE)
    sims = cosine_similarity(emb, loaded_embs)[0]
    best_idx = np.argmax(sims)
    
    print(f"Predicted: {loaded_names[best_idx]} (score: {sims[best_idx]:.2f})")

In [10]:
# Test
root_folder = "../registered_faces"
EMBEDDINGS_FILE = "saved_embeddings/face_embeddings.npy"
NAMES_FILE = "saved_embeddings/face_names.npy"

for person in os.listdir(root_folder):
    person_folder = os.path.join(root_folder, person)
    if not os.path.isdir(person_folder):
        continue
        
    print(f"\n--- Testing images in folder: {person} ---")
    
    for img_file in os.listdir(person_folder):
        img_path = os.path.join(person_folder, img_file)
        recognize(img_path)


--- Testing images in folder: Harsh ---
Predicted: Harsh (score: 0.87)
Predicted: Harsh (score: 0.73)
Predicted: Harsh (score: 0.73)
Predicted: Harsh (score: 0.78)
Predicted: Harsh (score: 0.65)
Predicted: Harsh (score: 0.79)
Predicted: Harsh (score: 0.75)
Predicted: Harsh (score: 0.70)
Predicted: Harsh (score: 0.84)
Predicted: Harsh (score: 0.84)
Predicted: Harsh (score: 0.77)
Predicted: Harsh (score: 0.93)
Predicted: Harsh (score: 0.74)
Predicted: Harsh (score: 0.75)
Predicted: Harsh (score: 0.70)
Predicted: Harsh (score: 0.76)
Predicted: Harsh (score: 0.77)
Predicted: Harsh (score: 0.77)
Predicted: Harsh (score: 0.83)
Predicted: Harsh (score: 0.73)
Predicted: Harsh (score: 0.79)
Predicted: Harsh (score: 0.77)
Predicted: Harsh (score: 0.73)
Predicted: Harsh (score: 0.77)
Predicted: Harsh (score: 1.00)
Predicted: Harsh (score: 0.79)
Predicted: Harsh (score: 0.82)
Predicted: Harsh (score: 0.82)
Predicted: Harsh (score: 0.75)
Predicted: Harsh (score: 0.87)
Predicted: Harsh (score: 0.84

In [11]:
# Save model
torch.save({
    'model_state_dict': model.state_dict(),
    'class_to_idx': full_dataset.class_to_idx
}, MODEL_SAVE_PATH)

print(f"Model saved to {MODEL_SAVE_PATH}")

Model saved to ../notebooks/saved_model/vit_face_classifier.pth
