In [1]:
# Imports
import os, torch
from torch import nn, optim
from torchvision import datasets, transforms
from transformers import ViTModel, ViTImageProcessor
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
from tqdm import tqdm
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm
2025-07-15 06:38:42.175712: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1752561522.390397     424 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1752561522.454531     424 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1752561523.002295     424 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1752561523.002335     424 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1752561523.002338     424

In [2]:
# Config
MODEL_NAME = "google/vit-base-patch16-224"
DATA_DIR = "../spoof_datasets/spoof"
MODEL_SAVE_PATH = "./saved_model/vit_spoof_classifier.pth"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Preprocess
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

dataset = datasets.ImageFolder(DATA_DIR, transform=transform)
loader = DataLoader(dataset, batch_size=8, shuffle=True)

In [4]:
# Build Model
vit = ViTModel.from_pretrained(MODEL_NAME)

class SpoofClassifier(nn.Module):
    def __init__(self, vit):
        super().__init__()
        self.vit = vit
        self.head = nn.Linear(vit.config.hidden_size, 2)
        
    def forward(self, x):
        features = self.vit(pixel_values=x).last_hidden_state[:, 0]
        return self.head(features)

model = SpoofClassifier(vit).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
# Train
model.train()
for epoch in range(5):
    total_loss = 0
    all_preds, all_labels = [], []
    for imgs, labels in tqdm(loader, desc=f"Epoch {epoch+1}"):
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        optimizer.zero_grad(); loss.backward(); optimizer.step()
        total_loss += loss.item()

        all_preds.extend(outputs.argmax(1).cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
    
    acc = accuracy_score(all_labels, all_preds)
    print(f"Epoch {epoch+1}, Loss: {total_loss/len(loader):.4f}, Acc: {acc:.4f}")

Epoch 1: 100%|██████████████████████████████████████████████████████████████████████████| 10/10 [00:11<00:00,  1.19s/it]


Epoch 1, Loss: 0.2926, Acc: 0.8125


Epoch 2: 100%|██████████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.21it/s]


Epoch 2, Loss: 0.0026, Acc: 1.0000


Epoch 3: 100%|██████████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.22it/s]


Epoch 3, Loss: 0.0007, Acc: 1.0000


Epoch 4: 100%|██████████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.23it/s]


Epoch 4, Loss: 0.0004, Acc: 1.0000


Epoch 5: 100%|██████████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.13it/s]

Epoch 5, Loss: 0.0003, Acc: 1.0000





In [6]:
# Save
torch.save(model.state_dict(), MODEL_SAVE_PATH)

print("Saved spoof detection model to:", MODEL_SAVE_PATH)

Saved spoof detection model to: ./saved_model/vit_spoof_classifier.pth


In [9]:
model.eval()

SpoofClassifier(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermediate_act_f