In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from transformers import ViTModel
import torch.nn as nn
import json
import os
from PIL import Image

# CONFIG
MODEL_PATH = "./saved_model/triple_head_vit.pth"
FEEDBACK_FACE_DIR = "../feedback_frames/self_training/face"
FEEDBACK_EMOTION_DIR = "../feedback_frames/self_training/emotion"
BATCH_SIZE = 8
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm
2025-08-04 06:08:08.669502: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1754287688.925176    2399 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1754287689.001388    2399 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1754287689.661944    2399 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1754287689.661997    2399 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1754287689.661999    2399

In [2]:
# Load original class lists
with open('./saved_model/face_classes.json', 'r') as f:
    face_classes = json.load(f)
with open('./saved_model/emotion_classes.json', 'r') as f:
    emotion_classes = json.load(f)

# Transforms
feedback_face_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
feedback_emotion_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

In [3]:
# Feedback Dataset
class FeedbackDataset(torch.utils.data.Dataset):
    def __init__(self, base_dir, class_names, transform):
        self.transform = transform
        self.samples = []
        for cls in class_names:
            cls_path = os.path.join(base_dir, cls)
            if os.path.isdir(cls_path):
                for img_name in os.listdir(cls_path):
                    if img_name.lower().endswith((".jpg", ".jpeg", ".png")):
                        self.samples.append((os.path.join(cls_path, img_name), class_names.index(cls)))
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, idx):
        path, label = self.samples[idx]
        img = Image.open(path).convert("RGB")
        return self.transform(img), label

In [4]:
# Prepare DataLoaders
face_feedback = FeedbackDataset(FEEDBACK_FACE_DIR, face_classes, feedback_face_transform)
emotion_feedback = FeedbackDataset(FEEDBACK_EMOTION_DIR, emotion_classes, feedback_emotion_transform)
face_loader = DataLoader(face_feedback, batch_size=BATCH_SIZE)
emotion_loader = DataLoader(emotion_feedback, batch_size=BATCH_SIZE)

In [5]:
# Model definition
class TripleHeadViT(nn.Module):
    def __init__(self, vit, face_classes, emotion_classes):
        super().__init__()
        self.vit = vit
        self.dropout = nn.Dropout(0.3)
        self.face_head = nn.Linear(vit.config.hidden_size, face_classes)
        self.emotion_head = nn.Linear(vit.config.hidden_size, emotion_classes)
        self.spoof_head = nn.Linear(vit.config.hidden_size, 1)
    def forward(self, x):
        features = self.vit(pixel_values=x).last_hidden_state[:, 0]
        features = self.dropout(features)
        return self.face_head(features), self.emotion_head(features), self.spoof_head(features)

# Load model
ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
vit = ViTModel.from_pretrained("google/vit-base-patch16-224")
model = TripleHeadViT(vit, len(face_classes), len(emotion_classes)).to(DEVICE)
model.load_state_dict(ckpt["model_state_dict"], strict=True)
model.eval()

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


TripleHeadViT(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermediate_act_fn)

In [6]:
# Evaluation
def evaluate(loader, head):
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs = imgs.to(DEVICE)
            labels = labels.to(DEVICE)
            face_logits, emotion_logits, _ = model(imgs)
            if head == 'face':
                preds = face_logits.argmax(dim=1)
            elif head == 'emotion':
                preds = emotion_logits.argmax(dim=1)
            else:
                raise ValueError("Head must be 'face' or 'emotion'")
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return 100.0 * correct / total if total > 0 else 0.0

face_acc = evaluate(face_loader, 'face')
emotion_acc = evaluate(emotion_loader, 'emotion')

print("Evaluation on feedback dataset:")
print(f"  Face recognition accuracy:   {face_acc:.2f}%")
print(f"  Emotion classification accuracy: {emotion_acc:.2f}%")

Evaluation on feedback dataset:
  Face recognition accuracy:   100.00%
  Emotion classification accuracy: 10.00%
