In [1]:
# ==============================================================================
# This notebook trains a binary classification model, referred to as the
# "Gatekeeper," whose sole job is to distinguish between a genuine, classifiable
# emotional expression and a non-emotional facial action (e.g., mid-speech
# movements). It uses the manually curated "CorrectionSet" as its training data.
# The resulting model will be used as the first-stage filter in our main video
# analysis pipeline.
# ==============================================================================

In [2]:
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification, TrainingArguments, Trainer
from datasets import load_dataset
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import os
from datetime import datetime

In [3]:
# ==============================================================================
# 1. CONFIGURATION
# ==============================================================================

# --- Path to the curated CorrectionSet dataset ---
# The location can be overridden without editing the notebook by exporting
# GATEKEEPER_DATASET_PATH before starting the kernel; the original
# machine-specific path is kept as the default so existing runs are unchanged.
DATASET_PATH = os.environ.get(
    "GATEKEEPER_DATASET_PATH",
    "/Users/natalyagrokh/AI/ml_expressions/img_expressions/data_flywheel/V6_20250716_112248/CorrectionSet",
)

# --- Define where to save the new Gatekeeper model ---
# Likewise overridable through GATEKEEPER_OUTPUT_DIR_ROOT.
OUTPUT_DIR_ROOT = os.environ.get(
    "GATEKEEPER_OUTPUT_DIR_ROOT",
    "/Users/natalyagrokh/AI/ml_expressions/img_expressions/data_flywheel/gatekeeper_models",
)

# Each run gets its own timestamped sub-folder, so earlier models are never
# overwritten by a re-run of this cell.
MODEL_NAME = f"gatekeeper_V1_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
FINAL_OUTPUT_DIR = os.path.join(OUTPUT_DIR_ROOT, MODEL_NAME)

os.makedirs(FINAL_OUTPUT_DIR, exist_ok=True)

In [4]:
# ==============================================================================
# 2. DATA LOADING & PREPARATION
# ==============================================================================

# --- Load the dataset from the folders ---
# The "imagefolder" builder infers the class labels from the sub-folder names
# (expected here: 'Emotion', 'Non-Emotional_Action').
print(f"--- Loading dataset from: {DATASET_PATH} ---")
dataset = load_dataset("imagefolder", data_dir=DATASET_PATH)

# --- Stratified 80/20 train/eval split ---
# A fixed seed makes the split -- and therefore every evaluation metric reported
# below -- reproducible across "Restart Kernel -> Run All". Without it, each run
# evaluates on a different random subset and results are not comparable.
SPLIT_SEED = 42
train_test_split = dataset['train'].train_test_split(
    test_size=0.2,
    stratify_by_column='label',
    seed=SPLIT_SEED,
)
train_dataset = train_test_split['train']
eval_dataset = train_test_split['test']

# --- Load a pre-trained model processor ---
# We can use the same processor as before for consistency.
processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

# --- On-the-fly preprocessing applied to each accessed batch ---
def transform(examples):
    """Add model-ready ``pixel_values`` to a batch of examples.

    Args:
        examples: Batch dict holding an ``"image"`` entry, a sequence of
            objects exposing ``.convert("RGB")`` (presumably PIL images
            produced by the imagefolder loader).

    Returns:
        The same dict, with a ``"pixel_values"`` entry holding the processor
        output for the batch (requested as PyTorch tensors).
    """
    rgb_images = []
    for picture in examples["image"]:
        # Force three channels so every image goes through the processor
        # in the same colour mode.
        rgb_images.append(picture.convert("RGB"))

    encoded = processor(rgb_images, return_tensors="pt")
    examples["pixel_values"] = encoded['pixel_values']
    return examples

# Register the preprocessing on both splits; it runs when examples are accessed,
# so no processed copy of the dataset is stored.
for split in (train_dataset, eval_dataset):
    split.set_transform(transform)

# --- Data collator: list of per-example dicts -> one batch dict ---
def collate_fn(batch):
    """Assemble individual examples into a single batch.

    Args:
        batch: Sequence of dicts, each with a ``'pixel_values'`` tensor and a
            ``'label'`` value.

    Returns:
        Dict with ``'pixel_values'`` (the per-example tensors stacked along a
        new leading dimension) and ``'labels'`` (a tensor built from the
        per-example labels).
    """
    pixel_list = []
    label_list = []
    for sample in batch:
        pixel_list.append(sample['pixel_values'])
        label_list.append(sample['label'])

    stacked_pixels = torch.stack(pixel_list)
    label_tensor = torch.tensor(label_list)
    # Note the key is 'labels' (plural) on output, 'label' on input.
    return {'pixel_values': stacked_pixels, 'labels': label_tensor}

--- Loading dataset from: /Users/natalyagrokh/AI/ml_expressions/img_expressions/data_flywheel/V6_20250716_112248/CorrectionSet ---


Resolving data files:   0%|          | 0/2974 [00:00<?, ?it/s]

Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.


In [5]:
# ==============================================================================
# 3. MODEL TRAINING
# ==============================================================================

# --- Load the pre-trained backbone for fine-tuning ---
# The class names come from the dataset itself, so the model config carries
# human-readable labels alongside the integer ids.
labels = dataset['train'].features['label'].names
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}

model = AutoModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)

# --- Training hyper-parameters ---
# Gathered in one mapping so the run configuration reads as a table.
training_kwargs = dict(
    output_dir=FINAL_OUTPUT_DIR,
    num_train_epochs=4,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    logging_steps=10,
    # Evaluate and checkpoint on the same schedule (once per epoch) so the
    # best checkpoint can be restored when training finishes.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    # Keep all dataset columns: the on-the-fly transform reads the 'image'
    # column, which is not a model input.
    remove_unused_columns=False,
)
training_args = TrainingArguments(**training_kwargs)

# --- Evaluation metrics ---
def compute_metrics(p):
    """Compute accuracy, F1, precision and recall for an evaluation pass.

    Args:
        p: Object exposing ``predictions`` (per-class scores; the argmax over
            axis 1 is taken as the predicted class) and ``label_ids`` (the
            ground-truth class ids).

    Returns:
        Dict with the keys ``'accuracy'``, ``'f1'``, ``'precision'`` and
        ``'recall'``.
    """
    y_true = p.label_ids
    y_pred = np.argmax(p.predictions, axis=1)

    accuracy = accuracy_score(y_true, y_pred)
    # average='binary' scores a single positive class (sklearn default
    # pos_label=1).
    # NOTE(review): with labels inferred from folder names, id 1 is presumably
    # 'Non-Emotional_Action' -- confirm this is the intended positive class.
    prf = precision_recall_fscore_support(y_true, y_pred, average='binary')
    prec, rec, f_score = prf[0], prf[1], prf[2]

    return {'accuracy': accuracy, 'f1': f_score, 'precision': prec, 'recall': rec}

# --- Initialize and Run the Trainer ---
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    # `tokenizer=` is deprecated for Trainer and emits a FutureWarning;
    # `processing_class` is its replacement. Passing the image processor here
    # lets it be saved alongside the model weights.
    processing_class=processor,
)

print(f"\n--- Starting training for Gatekeeper Model: {MODEL_NAME} ---")
trainer.train()

# With load_best_model_at_end=True the best checkpoint is now loaded into the
# trainer, but so far only per-epoch `checkpoint-*` sub-folders exist on disk.
# Write the best model (and the processor) to the top level of
# FINAL_OUTPUT_DIR so the path printed below can be loaded directly with
# `from_pretrained`.
trainer.save_model(FINAL_OUTPUT_DIR)

print(f"\n✅ Training complete. Best model saved to: {FINAL_OUTPUT_DIR}")

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(



--- Starting training for Gatekeeper Model: gatekeeper_V1_20250721_123628 ---




Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,0.2546,0.346971,0.852101,0.865443,0.792717,0.952862
2,0.2104,0.288236,0.890756,0.896332,0.851515,0.946128
3,0.0751,0.287729,0.905882,0.906977,0.895082,0.919192
4,0.0257,0.306111,0.904202,0.904523,0.9,0.909091





✅ Training complete. Best model saved to: /Users/natalyagrokh/AI/ml_expressions/img_expressions/data_flywheel/gatekeeper_models/gatekeeper_V1_20250721_123628
