In [1]:
import os
import warnings
import logging

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
logging.getLogger('tensorflow').setLevel(logging.ERROR)
warnings.filterwarnings('ignore')

In [2]:
# Imports
import torch
from torch import nn, optim
from torch.utils.data import random_split, DataLoader
from torchvision import datasets, transforms
from transformers import ViTModel, ViTImageProcessor
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt

E0000 00:00:1752146344.498452    4411 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1752146344.506680    4411 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1752146344.537697    4411 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1752146344.537731    4411 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1752146344.537733    4411 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1752146344.537735    4411 computation_placer.cc:177] computation placer already registered. Please check linka

In [3]:
# Config
MODEL_NAME = "google/vit-base-patch16-224"
DATA_DIR = "../registered_faces"
EMBEDDINGS_DIR = "../notebooks/saved_embeddings"
MODEL_SAVE_PATH = "../notebooks/saved_model/vit_face_classifier.pth"
EPOCHS = 10
BATCH_SIZE = 4
LR = 1e-4
FREEZE_BACKBONE = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [4]:
# Load & preprocess dataset
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor()
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Dataset & Dataloader split
full_dataset = datasets.ImageFolder(root=DATA_DIR, transform=transform)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# Apply different transform to val set
val_dataset.dataset.transform = val_transform

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Classes: {full_dataset.classes}")

Classes: ['Harsh', 'Mummy', 'Papa']


In [5]:
# Dataset & Dataloader split
full_dataset = datasets.ImageFolder(root=DATA_DIR, transform=transform)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

val_dataset.dataset.transform = val_transform

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Classes: {full_dataset.classes}")

Classes: ['Harsh', 'Mummy', 'Papa']


In [6]:
# # Build model (ViT + classifier head)
# vit = ViTModel.from_pretrained(MODEL_NAME)

# class FaceClassifier(nn.Module):
#     def __init__(self, vit, num_classes):
#         super().__init__()
#         self.vit = vit
#         self.classifier = nn.Linear(vit.config.hidden_size, num_classes)

#     def forward(self, x):
#         outputs = self.vit(pixel_values=x)
#         pooled = outputs.last_hidden_state[:, 0]
#         return self.classifier(pooled)

# model = FaceClassifier(vit, len(dataset.classes)).to(device)

# Build model
vit = ViTModel.from_pretrained(MODEL_NAME)

if FREEZE_BACKBONE:
    for param in vit.parameters():
        param.requires_grad = False

class FaceClassifier(nn.Module):
    def __init__(self, vit, num_classes):
        super().__init__()
        self.vit = vit
        self.classifier = nn.Linear(vit.config.hidden_size, num_classes)
    def forward(self, x):
        x = self.vit(pixel_values=x).last_hidden_state[:, 0]
        return self.classifier(x)

model = FaceClassifier(vit, len(full_dataset.classes)).to(device)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
# # Training loop
# criterion = nn.CrossEntropyLoss()
# optimizer = optim.Adam(model.parameters(), lr=LR)

# model.train()
# for epoch in range(EPOCHS):
#     total_loss = 0
#     for imgs, labels in tqdm(loader):
#         imgs, labels = imgs.to(device), labels.to(device)
#         outputs = model(imgs)
#         loss = criterion(outputs, labels)
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         total_loss += loss.item()
#     print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {total_loss/len(loader):.4f}")

# Training setup
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR)

# Training loop
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_train_loss = total_loss / len(train_loader)

    # Validation loop
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
    avg_val_loss = val_loss / len(val_loader)

    print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

Epoch 1/10: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.23it/s]


Epoch 1/10 | Train Loss: 1.1378 | Val Loss: 1.1030


Epoch 2/10: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:02<00:00,  7.15it/s]


Epoch 2/10 | Train Loss: 0.9206 | Val Loss: 1.0423


Epoch 3/10: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.54it/s]


Epoch 3/10 | Train Loss: 0.7947 | Val Loss: 0.9679


Epoch 4/10: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.50it/s]


Epoch 4/10 | Train Loss: 0.6960 | Val Loss: 0.8666


Epoch 5/10: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.51it/s]


Epoch 5/10 | Train Loss: 0.6147 | Val Loss: 0.7975


Epoch 6/10: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.57it/s]


Epoch 6/10 | Train Loss: 0.5421 | Val Loss: 0.7291


Epoch 7/10: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:02<00:00,  6.86it/s]


Epoch 7/10 | Train Loss: 0.4823 | Val Loss: 0.6727


Epoch 8/10: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.56it/s]


Epoch 8/10 | Train Loss: 0.4315 | Val Loss: 0.6278


Epoch 9/10: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.51it/s]


Epoch 9/10 | Train Loss: 0.3871 | Val Loss: 0.5819


Epoch 10/10: 100%|██████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.55it/s]


Epoch 10/10 | Train Loss: 0.3476 | Val Loss: 0.5494


In [8]:
# Extract & save embeddings
model.eval()
embeddings, names = [], []

loader = DataLoader(full_dataset, batch_size=BATCH_SIZE, shuffle=False)

with torch.no_grad():
    for imgs, labels in loader:
        imgs = imgs.to(device)
        outputs = model.vit(pixel_values=imgs)
        pooled = outputs.last_hidden_state[:, 0].cpu().numpy()
        embeddings.append(pooled)
        names += [full_dataset.classes[label] for label in labels]

embeddings = np.vstack(embeddings)
os.makedirs(EMBEDDINGS_DIR, exist_ok=True)
np.save(os.path.join(EMBEDDINGS_DIR, "face_embeddings.npy"), embeddings)
np.save(os.path.join(EMBEDDINGS_DIR, "face_names.npy"), np.array(names))

print("Saved face embeddings and names.")

Saved face embeddings and names.


In [9]:
# Test on a single image
EMBEDDINGS_FILE = os.path.join(EMBEDDINGS_DIR, "face_embeddings.npy")
NAMES_FILE = os.path.join(EMBEDDINGS_DIR, "face_names.npy")

def recognize(image_path):
    img = Image.open(image_path).convert("RGB")
    img_tensor = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model.vit(pixel_values=img_tensor)
        emb = output.last_hidden_state[:,0].cpu().numpy()
    loaded_embs = np.load(EMBEDDINGS_FILE)
    loaded_names = np.load(NAMES_FILE)
    sims = cosine_similarity(emb, loaded_embs)[0]
    best_idx = np.argmax(sims)
    
    print(f"Predicted: {loaded_names[best_idx]} (score: {sims[best_idx]:.2f})")

In [10]:
# Test
root_folder = "../registered_faces"
EMBEDDINGS_FILE = "saved_embeddings/face_embeddings.npy"
NAMES_FILE = "saved_embeddings/face_names.npy"

for person in os.listdir(root_folder):
    person_folder = os.path.join(root_folder, person)
    if not os.path.isdir(person_folder):
        continue
        
    print(f"\n--- Testing images in folder: {person} ---")
    
    for img_file in os.listdir(person_folder):
        img_path = os.path.join(person_folder, img_file)
        recognize(img_path)


--- Testing images in folder: Harsh ---
Predicted: Harsh (score: 0.71)
Predicted: Harsh (score: 0.76)
Predicted: Harsh (score: 0.77)
Predicted: Harsh (score: 0.65)
Predicted: Harsh (score: 0.89)
Predicted: Harsh (score: 0.73)
Predicted: Harsh (score: 0.72)
Predicted: Harsh (score: 0.65)
Predicted: Harsh (score: 0.82)
Predicted: Harsh (score: 0.67)
Predicted: Harsh (score: 0.80)
Predicted: Harsh (score: 0.85)
Predicted: Harsh (score: 0.65)
Predicted: Harsh (score: 0.77)
Predicted: Harsh (score: 0.86)
Predicted: Harsh (score: 0.76)
Predicted: Harsh (score: 0.75)
Predicted: Harsh (score: 0.83)
Predicted: Harsh (score: 0.80)
Predicted: Harsh (score: 0.64)
Predicted: Harsh (score: 0.72)
Predicted: Harsh (score: 0.83)
Predicted: Harsh (score: 0.70)
Predicted: Harsh (score: 0.85)
Predicted: Harsh (score: 0.88)
Predicted: Harsh (score: 0.73)
Predicted: Harsh (score: 0.68)
Predicted: Harsh (score: 0.90)
Predicted: Harsh (score: 0.78)
Predicted: Harsh (score: 0.91)
Predicted: Harsh (score: 0.86

In [11]:
# Save model
torch.save({
    'model_state_dict': model.state_dict(),
    'class_to_idx': full_dataset.class_to_idx
}, MODEL_SAVE_PATH)

print(f"Model saved to {MODEL_SAVE_PATH}")

Model saved to ../notebooks/saved_model/vit_face_classifier.pth
