In [11]:
# Imports
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from transformers import ViTModel
from tqdm import tqdm

In [12]:
# Config
FACE_DIR = "../registered_faces"
EMOTION_DIR = "../emotions_data"
SPOOF_DIR = "../spoof_datasets/spoof"
SAVE_PATH = "./saved_model/triple_head_vit.pth"
BATCH_SIZE = 8
EPOCHS = 10
LR = 1e-4
VAL_SPLIT = 0.2

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [13]:
# Transforms
base_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

emotion_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor()
])

In [14]:
# Datasets
face_dataset = datasets.ImageFolder(FACE_DIR, transform=base_transform)
emotion_dataset = datasets.ImageFolder(EMOTION_DIR, transform=emotion_transform)
spoof_dataset = datasets.ImageFolder(SPOOF_DIR, transform=base_transform)

# Split emotion & spoof into train/val
val_len = int(len(emotion_dataset) * VAL_SPLIT)
emotion_train, emotion_val = random_split(emotion_dataset, [len(emotion_dataset)-val_len, val_len])

val_len_spf = int(len(spoof_dataset) * VAL_SPLIT)
spoof_train, spoof_val = random_split(spoof_dataset, [len(spoof_dataset)-val_len_spf, val_len_spf])

In [15]:
# Loaders
face_loader = DataLoader(face_dataset, batch_size=BATCH_SIZE, shuffle=True)
emotion_loader = DataLoader(emotion_train, batch_size=BATCH_SIZE, shuffle=True)
spoof_loader = DataLoader(spoof_train, batch_size=BATCH_SIZE, shuffle=True)

In [16]:
# Validation
emotion_val_loader = DataLoader(emotion_val, batch_size=BATCH_SIZE)
spoof_val_loader = DataLoader(spoof_val, batch_size=BATCH_SIZE)

In [17]:
# Model
class TripleHeadViT(nn.Module):
    def __init__(self, vit, face_classes, emotion_classes):
        super().__init__()
        self.vit = vit
        self.dropout = nn.Dropout(0.3)
        self.face_head = nn.Linear(vit.config.hidden_size, face_classes)
        self.emotion_head = nn.Linear(vit.config.hidden_size, emotion_classes)
        self.spoof_head = nn.Linear(vit.config.hidden_size, 1)  # Binary class

    def forward(self, x):
        features = self.vit(pixel_values=x).last_hidden_state[:, 0]
        features = self.dropout(features)
        return self.face_head(features), self.emotion_head(features), self.spoof_head(features)

In [18]:
vit = ViTModel.from_pretrained("google/vit-base-patch16-224")
for name, param in vit.named_parameters():
    if "encoder.layer.11" not in name and "encoder.layer.10" not in name:
        param.requires_grad = False

model = TripleHeadViT(vit, len(face_dataset.classes), len(emotion_dataset.classes)).to(device)

# Losses
face_criterion = nn.CrossEntropyLoss()
emotion_criterion = nn.CrossEntropyLoss()
spoof_criterion = nn.BCEWithLogitsLoss()

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [19]:
# Training Loop
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0

    face_correct, emotion_correct, spoof_correct = 0, 0, 0
    face_total, emotion_total, spoof_total = 0, 0, 0

    face_iter = iter(face_loader)
    emotion_iter = iter(emotion_loader)
    spoof_iter = iter(spoof_loader)

    steps = min(len(face_iter), len(emotion_iter), len(spoof_iter))

    for _ in tqdm(range(steps), desc=f"Epoch {epoch+1}/{EPOCHS}"):
        x_face, y_face = next(face_iter)
        x_emotion, y_emotion = next(emotion_iter)
        x_spoof, y_spoof = next(spoof_iter)

        x_face, y_face = x_face.to(device), y_face.to(device)
        x_emotion, y_emotion = x_emotion.to(device), y_emotion.to(device)
        x_spoof, y_spoof = x_spoof.to(device), y_spoof.to(device).float().unsqueeze(1)

        x = torch.cat([x_face, x_emotion, x_spoof], dim=0)
        optimizer.zero_grad()
        face_logits, emotion_logits, spoof_logits = model(x)

        face_loss = face_criterion(face_logits[:len(y_face)], y_face)
        emotion_loss = emotion_criterion(emotion_logits[len(y_face):len(y_face)+len(y_emotion)], y_emotion)
        spoof_loss = spoof_criterion(spoof_logits[-len(y_spoof):], y_spoof)
        loss = face_loss + emotion_loss + spoof_loss

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        # Accuracy
        face_preds = face_logits[:len(y_face)].argmax(1)
        face_correct += (face_preds == y_face).sum().item()
        face_total += len(y_face)

        emotion_preds = emotion_logits[len(y_face):len(y_face)+len(y_emotion)].argmax(1)
        emotion_correct += (emotion_preds == y_emotion).sum().item()
        emotion_total += len(y_emotion)

        spoof_preds = torch.sigmoid(spoof_logits[-len(y_spoof):]) > 0.5
        spoof_correct += (spoof_preds.squeeze().int() == y_spoof.squeeze().int()).sum().item()
        spoof_total += len(y_spoof)

    # Calculate train accuracies
    train_face_acc = 100 * face_correct / face_total
    train_emotion_acc = 100 * emotion_correct / emotion_total
    train_spoof_acc = 100 * spoof_correct / spoof_total

    # Validation
    model.eval()
    val_emotion_correct, val_emotion_total = 0, 0
    val_spoof_correct, val_spoof_total = 0, 0

    with torch.no_grad():
        for x_val, y_val in emotion_val_loader:
            x_val, y_val = x_val.to(device), y_val.to(device)
            _, emotion_logits, _ = model(x_val)
            preds = emotion_logits.argmax(1)
            val_emotion_correct += (preds == y_val).sum().item()
            val_emotion_total += len(y_val)

        for x_val, y_val in spoof_val_loader:
            x_val, y_val = x_val.to(device), y_val.to(device).float().unsqueeze(1)
            _, _, spoof_logits = model(x_val)
            preds = (torch.sigmoid(spoof_logits) > 0.5).int()
            val_spoof_correct += (preds == y_val.int()).sum().item()
            val_spoof_total += len(y_val)

    val_emotion_acc = 100 * val_emotion_correct / val_emotion_total
    val_spoof_acc = 100 * val_spoof_correct / val_spoof_total

    # Log all results
    print(f"[Epoch {epoch+1}] Loss: {total_loss/steps:.4f} | "
          f"Face Acc (Train): {train_face_acc:.2f}% | "
          f"Emotion Acc (Train): {train_emotion_acc:.2f}% | Val: {val_emotion_acc:.2f}% | "
          f"Spoof Acc (Train): {train_spoof_acc:.2f}% | Val: {val_spoof_acc:.2f}%")

Epoch 1/10: 100%|█████████████████████████████████████████████████████████████████████████| 8/8 [00:12<00:00,  1.55s/it]


[Epoch 1] Loss: 3.5566 | Face Acc (Train): 42.19% | Emotion Acc (Train): 28.12% | Val: 25.78% | Spoof Acc (Train): 93.75% | Val: 93.75%


Epoch 2/10: 100%|█████████████████████████████████████████████████████████████████████████| 8/8 [00:10<00:00,  1.33s/it]


[Epoch 2] Loss: 2.6630 | Face Acc (Train): 82.81% | Emotion Acc (Train): 20.31% | Val: 29.29% | Spoof Acc (Train): 96.88% | Val: 100.00%


Epoch 3/10: 100%|█████████████████████████████████████████████████████████████████████████| 8/8 [00:11<00:00,  1.48s/it]


[Epoch 3] Loss: 2.1749 | Face Acc (Train): 90.62% | Emotion Acc (Train): 34.38% | Val: 31.84% | Spoof Acc (Train): 100.00% | Val: 100.00%


Epoch 4/10: 100%|█████████████████████████████████████████████████████████████████████████| 8/8 [00:10<00:00,  1.37s/it]


[Epoch 4] Loss: 2.2029 | Face Acc (Train): 92.19% | Emotion Acc (Train): 23.44% | Val: 35.15% | Spoof Acc (Train): 100.00% | Val: 100.00%


Epoch 5/10: 100%|█████████████████████████████████████████████████████████████████████████| 8/8 [00:09<00:00,  1.24s/it]


[Epoch 5] Loss: 1.7408 | Face Acc (Train): 96.88% | Emotion Acc (Train): 45.31% | Val: 36.30% | Spoof Acc (Train): 100.00% | Val: 100.00%


Epoch 6/10: 100%|█████████████████████████████████████████████████████████████████████████| 8/8 [00:10<00:00,  1.30s/it]


[Epoch 6] Loss: 1.9947 | Face Acc (Train): 98.44% | Emotion Acc (Train): 25.00% | Val: 39.36% | Spoof Acc (Train): 100.00% | Val: 100.00%


Epoch 7/10: 100%|█████████████████████████████████████████████████████████████████████████| 8/8 [00:10<00:00,  1.34s/it]


[Epoch 7] Loss: 1.8334 | Face Acc (Train): 98.44% | Emotion Acc (Train): 34.38% | Val: 40.39% | Spoof Acc (Train): 100.00% | Val: 100.00%


Epoch 8/10: 100%|█████████████████████████████████████████████████████████████████████████| 8/8 [00:09<00:00,  1.20s/it]


[Epoch 8] Loss: 1.7481 | Face Acc (Train): 100.00% | Emotion Acc (Train): 35.94% | Val: 42.27% | Spoof Acc (Train): 100.00% | Val: 100.00%


Epoch 9/10: 100%|█████████████████████████████████████████████████████████████████████████| 8/8 [00:09<00:00,  1.24s/it]


[Epoch 9] Loss: 1.4817 | Face Acc (Train): 100.00% | Emotion Acc (Train): 45.31% | Val: 43.11% | Spoof Acc (Train): 100.00% | Val: 100.00%


Epoch 10/10: 100%|████████████████████████████████████████████████████████████████████████| 8/8 [00:09<00:00,  1.13s/it]


[Epoch 10] Loss: 1.4538 | Face Acc (Train): 100.00% | Emotion Acc (Train): 50.00% | Val: 44.42% | Spoof Acc (Train): 100.00% | Val: 100.00%


In [20]:
# SAVE
torch.save({
    "model_state_dict": model.state_dict(),
    "face_classes": face_dataset.classes,
    "emotion_classes": emotion_dataset.classes
}, SAVE_PATH)

print(f"Triple-head model saved at {SAVE_PATH}")

Triple-head model saved at ./saved_model/triple_head_vit.pth
