In [1]:
!pip install datasets transformers



In [1]:
from datasets import load_dataset

ds = load_dataset("nguyenkhoa/celeba-spoof-for-face-antispoofing-test")
train_ds = ds["test"].train_test_split(train_size=0.2, seed=42)["train"]
print(train_ds)


Dataset({
    features: ['cropped_image', 'labels', 'labelNames'],
    num_rows: 13434
})


In [2]:
import os
from PIL import Image

train_dir = "celeba_spoof/train"

os.makedirs(os.path.join(train_dir, "real"), exist_ok=True)
os.makedirs(os.path.join(train_dir, "spoof"), exist_ok=True)

for i, example in enumerate(train_ds):
    img = example["cropped_image"]

    if img is None:
        print(f"Skipping image {i} because it's None")
        continue

    if not isinstance(img, Image.Image):
        img = Image.open(img)

    label = "real" if example["labels"] == 0 else "spoof"
    
    img_path = os.path.join(train_dir, label, f"{i}.jpg")
    img.save(img_path)

print("20% training data saved successfully!")


Skipping image 443 because it's None
Skipping image 468 because it's None
Skipping image 855 because it's None
Skipping image 1134 because it's None
Skipping image 1317 because it's None
Skipping image 1521 because it's None
Skipping image 1740 because it's None
Skipping image 1994 because it's None
Skipping image 2051 because it's None
Skipping image 2156 because it's None
Skipping image 2313 because it's None
Skipping image 2550 because it's None
Skipping image 2662 because it's None
Skipping image 2804 because it's None
Skipping image 3009 because it's None
Skipping image 3112 because it's None
Skipping image 3249 because it's None
Skipping image 3264 because it's None
Skipping image 3463 because it's None
Skipping image 3526 because it's None
Skipping image 3591 because it's None
Skipping image 3956 because it's None
Skipping image 4052 because it's None
Skipping image 4059 because it's None
Skipping image 4095 because it's None
Skipping image 4437 because it's None
Skipping image 

In [3]:
from transformers import ViTImageProcessor

model_name = 'google/vit-base-patch16-224-in21k'
processor = ViTImageProcessor.from_pretrained(model_name)

In [4]:
print(processor)

ViTImageProcessor {
  "do_convert_rgb": null,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "ViTImageProcessor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}



In [5]:
def process_example(example):
    if example["cropped_image"] is None:
        return None

    inputs = processor(example["cropped_image"], return_tensors="pt")
    inputs["labels"] = torch.tensor(example["labels"])
    return inputs


In [6]:
import torch

processed_ds = train_ds.map(process_example, remove_columns=["cropped_image", "labelNames"])

In [7]:
print(processed_ds)

Dataset({
    features: ['labels', 'pixel_values'],
    num_rows: 13342
})


In [8]:
import torch

def collate_fn(batch):
    pixel_values = torch.stack([torch.tensor(x['pixel_values']) for x in batch])
    labels = torch.tensor([x['labels'] for x in batch])
    
    pixel_values = pixel_values.squeeze(1)
    return {'pixel_values': pixel_values, 'labels': labels}


In [9]:
import numpy as np
import evaluate

metric = evaluate.load("accuracy")
def compute_metrics(p):
    predictions = np.argmax(p.predictions, axis=1)  # Convert logits to class labels
    return metric.compute(predictions=predictions, references=p.label_ids)


In [10]:
labels = ["real", "spoof"]

In [11]:
from transformers import ViTForImageClassification

model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=len(labels),
    id2label={str(i): label for i, label in enumerate(labels)},
    label2id={label: str(i) for i, label in enumerate(labels)}
)

print(model.config)


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTConfig {
  "_attn_implementation_autoset": true,
  "_name_or_path": "google/vit-base-patch16-224-in21k",
  "architectures": [
    "ViTModel"
  ],
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "id2label": {
    "0": "real",
    "1": "spoof"
  },
  "image_size": 224,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "real": "0",
    "spoof": "1"
  },
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "qkv_bias": true,
  "torch_dtype": "float32",
  "transformers_version": "4.49.0"
}



In [12]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./vit-celeba-spoof",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=4,
    fp16=True,
    logging_steps=10,
    learning_rate=5e-5,
    save_total_limit=2,
    remove_unused_columns=False,
    push_to_hub=False,
    report_to="tensorboard",
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    greater_is_better=False,
)




In [13]:
from datasets import DatasetDict

splits = processed_ds.train_test_split(test_size=0.2, seed=42)
prepared_ds = DatasetDict({
    "train": splits["train"],
    "validation": splits["test"]
})


In [14]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    tokenizer=processor,
)


  trainer = Trainer(


In [15]:
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.0195,0.0187,0.996628
2,0.0012,0.010695,0.997377
3,0.0007,0.018983,0.996628
4,0.0006,0.004042,0.998876


***** train metrics *****
  epoch                    =          4.0
  total_flos               = 3081083316GF
  train_loss               =       0.0218
  train_runtime            =   0:59:35.38
  train_samples_per_second =       11.941
  train_steps_per_second   =        0.747


In [22]:
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

predictions = trainer.predict(prepared_ds["validation"])

preds = np.argmax(predictions.predictions, axis=-1)  
labels = predictions.label_ids  

accuracy = accuracy_score(labels, preds)
precision = precision_score(labels, preds, average='weighted')
recall = recall_score(labels, preds, average='weighted')
f1 = f1_score(labels, preds, average='weighted')

print(f"Validation Accuracy: {accuracy:.4f}")
print(f"Validation Precision: {precision:.4f}")
print(f"Validation Recall: {recall:.4f}")
print(f"Validation F1-score: {f1:.4f}")


Validation Accuracy: 0.9989
Validation Precision: 0.9989
Validation Recall: 0.9989
Validation F1-score: 0.9989
