In [1]:
import os
import warnings
import logging

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
logging.getLogger('tensorflow').setLevel(logging.ERROR)
warnings.filterwarnings('ignore')

In [2]:
# Imports
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from transformers import ViTModel
from sklearn.metrics import accuracy_score
from tqdm import tqdm
from PIL import Image

E0000 00:00:1752517212.021468   10838 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1752517212.028192   10838 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1752517212.050792   10838 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1752517212.050821   10838 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1752517212.050823   10838 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1752517212.050825   10838 computation_placer.cc:177] computation placer already registered. Please check linka

In [3]:
# Config
REGISTERED_FACE_DIR = "../registered_faces"
FER_DIR = "../emotions_data"
SAVE_PATH = "/saved_model/dual_head_vit.pth"
EPOCHS = 10
BATCH_SIZE = 8
LR = 1e-4

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Transformers
common_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

face_dataset = datasets.ImageFolder(REGISTERED_FACE_DIR, transform=common_transform)
emotion_dataset = datasets.ImageFolder(FER_DIR, transform=common_transform)

face_loader = DataLoader(face_dataset, batch_size=BATCH_SIZE, shuffle=True)
emotion_loader = DataLoader(emotion_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [5]:
# MODEL
class DualHeadViT(nn.Module):
    def __init__(self, vit, face_classes, emotion_classes):
        super().__init__()
        self.vit = vit
        self.face_head = nn.Linear(vit.config.hidden_size, face_classes)
        self.emotion_head = nn.Linear(vit.config.hidden_size, emotion_classes)

    def forward(self, x):
        features = self.vit(pixel_values=x).last_hidden_state[:, 0]
        face_out = self.face_head(features)
        emotion_out = self.emotion_head(features)
        return face_out, emotion_out

In [6]:
# Load base ViT
vit = ViTModel.from_pretrained("google/vit-base-patch16-224")
model = DualHeadViT(vit, len(face_dataset.classes), len(emotion_dataset.classes)).to(device)

# Loss + Optimizer
face_criterion = nn.CrossEntropyLoss()
emotion_criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
# TRAINING LOOP
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    face_correct, emotion_correct = 0, 0
    face_total, emotion_total = 0, 0

    face_iter = iter(face_loader)
    emotion_iter = iter(emotion_loader)
    steps = min(len(face_iter), len(emotion_iter))

    for _ in tqdm(range(steps), desc=f"Epoch {epoch+1}/{EPOCHS}"):
        try:
            x_face, y_face = next(face_iter)
            x_emotion, y_emotion = next(emotion_iter)
        except StopIteration:
            break

        x_face, y_face = x_face.to(device), y_face.to(device)
        x_emotion, y_emotion = x_emotion.to(device), y_emotion.to(device)

        # Combine batches
        x = torch.cat([x_face, x_emotion], dim=0)
        face_labels = torch.cat([y_face, torch.zeros_like(y_face)], dim=0)
        emotion_labels = torch.cat([torch.zeros_like(y_emotion), y_emotion], dim=0)

        optimizer.zero_grad()
        face_logits, emotion_logits = model(x)

        # Only compute loss on relevant parts
        face_loss = face_criterion(face_logits[:len(y_face)], y_face)
        emotion_loss = emotion_criterion(emotion_logits[len(y_face):], y_emotion)
        loss = face_loss + emotion_loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        face_preds = face_logits[:len(y_face)].argmax(1)
        face_correct += (face_preds == y_face).sum().item()
        face_total += len(y_face)

        emotion_preds = emotion_logits[len(y_face):].argmax(1)
        emotion_correct += (emotion_preds == y_emotion).sum().item()
        emotion_total += len(y_emotion)

    acc_face = 100 * face_correct / face_total
    acc_emotion = 100 * emotion_correct / emotion_total

    print(f"Loss: {total_loss/steps:.4f} | Face Acc: {acc_face:.2f}% | Emotion Acc: {acc_emotion:.2f}%")

Epoch 1/10: 100%|███████████████████████████████████████████████████████████████████████| 22/22 [00:32<00:00,  1.48s/it]


Loss: 2.1971 | Face Acc: 85.23% | Emotion Acc: 27.84%


Epoch 2/10: 100%|███████████████████████████████████████████████████████████████████████| 22/22 [00:32<00:00,  1.48s/it]


Loss: 1.6087 | Face Acc: 100.00% | Emotion Acc: 39.77%


Epoch 3/10: 100%|███████████████████████████████████████████████████████████████████████| 22/22 [00:32<00:00,  1.49s/it]


Loss: 1.5030 | Face Acc: 100.00% | Emotion Acc: 44.89%


Epoch 4/10: 100%|███████████████████████████████████████████████████████████████████████| 22/22 [00:32<00:00,  1.47s/it]


Loss: 1.4416 | Face Acc: 100.00% | Emotion Acc: 48.30%


Epoch 5/10: 100%|███████████████████████████████████████████████████████████████████████| 22/22 [00:33<00:00,  1.51s/it]


Loss: 1.3399 | Face Acc: 100.00% | Emotion Acc: 49.43%


Epoch 6/10: 100%|███████████████████████████████████████████████████████████████████████| 22/22 [00:32<00:00,  1.47s/it]


Loss: 1.3480 | Face Acc: 100.00% | Emotion Acc: 45.45%


Epoch 7/10: 100%|███████████████████████████████████████████████████████████████████████| 22/22 [00:32<00:00,  1.47s/it]


Loss: 1.2136 | Face Acc: 100.00% | Emotion Acc: 48.86%


Epoch 8/10: 100%|███████████████████████████████████████████████████████████████████████| 22/22 [00:32<00:00,  1.48s/it]


Loss: 1.2903 | Face Acc: 100.00% | Emotion Acc: 56.82%


Epoch 9/10: 100%|███████████████████████████████████████████████████████████████████████| 22/22 [00:32<00:00,  1.47s/it]


Loss: 1.3278 | Face Acc: 100.00% | Emotion Acc: 46.59%


Epoch 10/10: 100%|██████████████████████████████████████████████████████████████████████| 22/22 [00:32<00:00,  1.48s/it]

Loss: 1.1824 | Face Acc: 100.00% | Emotion Acc: 60.80%





In [9]:
# SAVE
SAVE_PATH = "./saved_model/dual_head_vit.pth"

torch.save({
    "model_state_dict": model.state_dict(),
    "face_classes": face_dataset.classes,
    "emotion_classes": emotion_dataset.classes
}, SAVE_PATH)

print(f"Model saved to {SAVE_PATH}")

Model saved to ./saved_model/dual_head_vit.pth
