In [1]:
# This notebook performs a one-time analysis of the ferckjalfaga dataset.
# It loads the final V29 model and removes its final classification layer to expose
# the rich feature embeddings (the "faceprints"). It then computes the average
# faceprint (centroid) for each of the core emotion classes and saves them to a file.

In [2]:
import torch
import torch.nn as nn
from transformers import AutoImageProcessor, AutoModelForImageClassification
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import os

In [3]:
# ==============================================================================
# 1. CONFIGURATION
# ==============================================================================
# Every path below can be overridden with an environment variable, so the notebook
# runs on another machine without editing source. The defaults are the original
# author's local paths.

# --- IMPORTANT: Point this at your FINAL, CURATED training dataset ---
# Root folder containing one sub-folder per emotion class.
TRAINING_DATA_DIR = os.environ.get(
    "CENTROIDS_TRAINING_DATA_DIR",
    "/Users/natalyagrokh/AI/ml_expressions/img_datasets/ferckjalfaga_dataset",
)

# Path to the final, production-ready model (loaded with `from_pretrained`).
MODEL_PATH = os.environ.get(
    "CENTROIDS_MODEL_PATH",
    "/Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V29_20250710_082807",
)

# Output file for the {label_id: centroid tensor} dict written with torch.save.
CENTROIDS_SAVE_PATH = os.environ.get("CENTROIDS_SAVE_PATH", "emotion_centroids.pt")

# Images per forward pass; lower this if the device runs out of memory.
BATCH_SIZE = 32

In [4]:
# ==============================================================================
# 2. SETUP
# ==============================================================================

# --- Load Model and Processor ---
print(f"--- Loading model from {MODEL_PATH} ---")
model = AutoModelForImageClassification.from_pretrained(MODEL_PATH)
processor = AutoImageProcessor.from_pretrained(MODEL_PATH)

# --- Set up device ---
# Use Apple-silicon MPS when available, otherwise fall back to CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"\nüñ•Ô∏è Using device: {device}")

# --- Modify the model to output feature embeddings ---
# Swapping the classification head for an identity makes `outputs.logits` carry the
# features that would otherwise have been fed into the head.
# Guard: if the head is not named `classifier`, the assignment below would silently
# register an unused module, and the "embeddings" would still be class scores.
if not isinstance(getattr(model, "classifier", None), nn.Module):
    raise AttributeError(
        f"Model {type(model).__name__} has no `classifier` head module; "
        "cannot remove it to expose feature embeddings."
    )
# Informational only. Not every vision config defines `hidden_size`, so don't fail on it.
embedding_size = getattr(model.config, "hidden_size", None)
# nn.Identity ignores constructor arguments, so none are passed.
model.classifier = nn.Identity()
model.to(device).eval()
print("‚úÖ Model modified to output feature embeddings.")

# --- Load and process the dataset ---
print(f"\n--- Loading training data from {TRAINING_DATA_DIR} ---")
# `isdir` (not `exists`) so that a path pointing at a plain file is rejected too.
if not os.path.isdir(TRAINING_DATA_DIR):
    raise FileNotFoundError(
        f"CRITICAL: The specified training data directory does not exist: {TRAINING_DATA_DIR}. "
        "Please update the TRAINING_DATA_DIR path."
    )

# `imagefolder` derives class labels from the sub-folder names.
# NOTE: the resulting label ids are not guaranteed to follow the model's id2label order.
dataset = load_dataset("imagefolder", data_dir=TRAINING_DATA_DIR, split="train")

def transform(examples):
    """Add model-ready 'pixel_values' to a batch of rows, computed on access.

    Registered with `Dataset.set_transform`, so it runs lazily whenever rows are read.
    Each image is converted to RGB before being passed to the model's image processor.

    Args:
        examples: A batch dict whose 'image' entry is a list of PIL images.

    Returns:
        The same dict with a 'pixel_values' tensor added.
    """
    examples["pixel_values"] = processor(
        [img.convert("RGB") for img in examples["image"]], return_tensors="pt"
    )["pixel_values"]
    return examples

dataset.set_transform(transform)

# Batch only the fields the embedding loop needs, so the raw 'image' objects never
# reach PyTorch's default collator.
def custom_collate(batch):
    """
    Assemble a batch dict from a list of dataset rows.

    Stacks every row's 'pixel_values' tensor along a new leading dimension and
    gathers the rows' 'label' values into one tensor. All other fields, such as
    the raw 'image', are dropped.
    """
    collated = {}
    collated['pixel_values'] = torch.stack([row['pixel_values'] for row in batch])
    collated['label'] = torch.tensor([row['label'] for row in batch])
    return collated

# Iterate the lazily-transformed dataset in its stored order (DataLoader does not
# shuffle by default), routing every batch through custom_collate.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    collate_fn=custom_collate,
)
print("‚úÖ Data loaded and prepared with custom collator.")

--- Loading model from /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V29_20250710_082807 ---

üñ•Ô∏è Using device: mps
‚úÖ Model modified to output feature embeddings.

--- Loading training data from /Users/natalyagrokh/AI/ml_expressions/img_datasets/ferckjalfaga_dataset ---


Resolving data files:   0%|          | 0/17461 [00:00<?, ?it/s]

‚úÖ Data loaded and prepared with custom collator.


In [5]:
# ==============================================================================
# 3. CALCULATE EMBEDDINGS PER CLASS
# ==============================================================================

# `imagefolder` numbers classes by sorted sub-folder name, but the model's id2label
# has its own order (e.g. the model puts 'contempt' at id 8, not in alphabetical position).
# Translate dataset label ids into model label ids by class name, so every centroid is
# stored under the id the model uses for that emotion.
dataset_label_names = dataset.features["label"].names
model_name_to_id = {
    str(name).strip().lower(): int(idx) for idx, name in model.config.id2label.items()
}
normalized_dataset_names = [str(name).strip().lower() for name in dataset_label_names]
unmatched_names = [
    original
    for original, normalized in zip(dataset_label_names, normalized_dataset_names)
    if normalized not in model_name_to_id
]
if unmatched_names:
    # Names can't be aligned; keep the original behaviour (ids assumed to coincide).
    print(
        f"WARNING: dataset classes {unmatched_names} are not model labels; "
        "assuming dataset label ids equal model label ids."
    )
    dataset_to_model_id = {i: i for i in range(len(dataset_label_names))}
else:
    dataset_to_model_id = {
        i: model_name_to_id[normalized] for i, normalized in enumerate(normalized_dataset_names)
    }
    if any(i != model_id for i, model_id in dataset_to_model_id.items()):
        print("Remapped dataset label ids to the model's label ids by class name.")

# Dictionary to hold all embeddings, keyed by the MODEL's label id
class_embeddings = {i: [] for i in range(model.config.num_labels)}

print("\n--- Generating embeddings for all training images ---")
with torch.no_grad():
    for batch in tqdm(dataloader, desc="Processing batches"):
        pixel_values = batch['pixel_values'].to(device)
        # The classifier was replaced by nn.Identity, so `.logits` holds feature embeddings.
        embeddings = model(pixel_values=pixel_values).logits

        # Move embeddings to CPU and file each one under its model label id
        cpu_embeddings = embeddings.cpu().numpy()
        for dataset_label, embedding in zip(batch['label'].tolist(), cpu_embeddings):
            class_embeddings[dataset_to_model_id[dataset_label]].append(embedding)

print("‚úÖ All embeddings generated.")


--- Generating embeddings for all training images ---


Processing batches: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 546/546 [05:46<00:00,  1.58it/s]

‚úÖ All embeddings generated.





In [6]:
# ==============================================================================
# 4. CALCULATE AND SAVE CENTROIDS
# ==============================================================================

# Average each class's collected embeddings into a single centroid vector, then
# persist the resulting {label_id: centroid tensor} dict with torch.save.
emotion_centroids = {}
id2label = model.config.id2label

print("\n--- Calculating average embedding (centroid) for each class ---")
for class_id, vectors in class_embeddings.items():
    if len(vectors) == 0:
        # Nothing to average for a class with no images.
        print(f"‚ö†Ô∏è Warning: No images found for class {id2label[class_id]}. Skipping centroid calculation.")
        continue

    # Presumably a (num_images, embedding_dim) matrix; average over images (axis 0).
    stacked = np.asarray(vectors)
    emotion_centroids[class_id] = torch.from_numpy(stacked.mean(axis=0))
    print(f"  -> Centroid calculated for '{id2label[class_id]}'")

torch.save(emotion_centroids, CENTROIDS_SAVE_PATH)
print(f"\n‚úÖ Centroids successfully calculated and saved to: {CENTROIDS_SAVE_PATH}")


--- Calculating average embedding (centroid) for each class ---
  -> Centroid calculated for 'anger'
  -> Centroid calculated for 'disgust'
  -> Centroid calculated for 'fear'
  -> Centroid calculated for 'happiness'
  -> Centroid calculated for 'neutral'
  -> Centroid calculated for 'questioning'
  -> Centroid calculated for 'sadness'
  -> Centroid calculated for 'surprise'
  -> Centroid calculated for 'contempt'
  -> Centroid calculated for 'unknown'

‚úÖ Centroids successfully calculated and saved to: emotion_centroids.pt
